AdamW optimizer for bfloat16 models in pytorch 🔥.
However, using bfloat16 in the torch ecosystem is ... awkward (torch AMP is quite non-transparent, and was initially developed with a focus on fp16, which behaves very differently from bf16).
If you just convert all weights and inputs to bfloat16, you're likely to run into the issue of stale weights: updates are too small to modify the bfloat16 weight (see the Gopher paper, section C2, for a large-scale example).
There are two possible remedies:

- use stochastic rounding when writing the updated weight back to bfloat16, or
- accumulate the rounding error and carry it over into subsequent updates (compensated summation).
As a recent study has shown, both options are fully competitive in quality with float32 training. That's what we implement here in a convenient wrapper.
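To see the stale-weight problem concretely, here is a small illustration (not part of the library): an update below bfloat16's rounding threshold leaves the weight unchanged.

```python
import torch

# bfloat16 keeps only 8 bits of significand precision, so near 1.0 the spacing
# between representable values is about 2**-7 ≈ 0.008. Any update smaller than
# roughly half of that is rounded away entirely.
w = torch.tensor(1.0, dtype=torch.bfloat16)
update = torch.tensor(1e-3, dtype=torch.bfloat16)

print(w + update == w)                                 # tensor(True): the weight never moves
print(torch.tensor(1.0) + torch.tensor(1e-3) == 1.0)   # tensor(False): float32 keeps the update
```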
```bash
pip install git+https://github.com/arogozhnikov/adamw_bfloat16.git
```
Use it as a drop-in replacement for PyTorch's AdamW:
```python
import torch
from adamw_bfloat16 import LR, AdamW_BF16

model = model.to(torch.bfloat16)

# default preheat and decay
optimizer = AdamW_BF16(model.parameters())

# or configure the LR schedule using the built-in scheduling
optimizer = AdamW_BF16(model.parameters(), lr_function=LR(lr=1e-4, preheat_steps=5000, decay_power=-0.25))

# in the training loop:
loss.backward()
optimizer.step()
optimizer.zero_grad()
```
Or you can even replace the last two lines with a single call:
```python
optimizer.step(zero_grad=True)
```
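Putting it together, here is a minimal end-to-end sketch; the toy model, synthetic data, and loss are placeholders for illustration, while the optimizer calls follow the API shown above:

```python
import torch
from torch import nn
from adamw_bfloat16 import LR, AdamW_BF16

# toy model and synthetic data, purely for illustration
model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 10)).to(torch.bfloat16)
optimizer = AdamW_BF16(model.parameters(), lr_function=LR(lr=1e-4, preheat_steps=5000, decay_power=-0.25))

for step in range(10_000):
    x = torch.randn(16, 32, dtype=torch.bfloat16)   # inputs are also bfloat16
    target = torch.randint(0, 10, (16,))

    loss = nn.functional.cross_entropy(model(x), target)
    loss.backward()
    optimizer.step(zero_grad=True)   # update weights and zero gradients in one call
```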
This optimizer also simplifies the training code by removing:

- the separate gradient zeroing after `.step()` (pass `zero_grad=True` instead)
- the external LR scheduler (scheduling is built into the optimizer)

It uses ~25% less memory per parameter compared to built-in AdamW.