Fix/casting continue pretraining#1200
Conversation
…ure correct dtype usage
…in mixed precision training
…at32 usage in mixed precision training
…pe usage in mixed precision training
|
Previously I drafted this because somehow I can't use BS 2 just like the example. BUt now I can use BS 2. So opening this PR instead of Draft |
|
Thankyou so much for bringing this up in PR man🙏🙏 |
|
Oh I totally missed float16 cannot be used and only bfloat16 can be used for continued pretraining - nice catch! |
|
I believe this change lead to error while saving model with During saving model with unsloth/save.py (551) For |
* Bring back float32 if float16 instead of bfloat16 * Refactor mixed precision handling for lm_head and embed_tokens to ensure correct dtype usage * Fix dtype retrieval for embed_tokens and lm_head in mixed precision training * Fix dtype retrieval for embed_tokens and lm_head to use weight dtype in mixed precision training * Fix dtype handling for embed_tokens and lm_head to ensure correct float32 usage in mixed precision training * Fix dtype assignment for lm_head modules to ensure correct weight dtype usage in mixed precision training
* Bring back float32 if float16 instead of bfloat16 * Refactor mixed precision handling for lm_head and embed_tokens to ensure correct dtype usage * Fix dtype retrieval for embed_tokens and lm_head in mixed precision training * Fix dtype retrieval for embed_tokens and lm_head to use weight dtype in mixed precision training * Fix dtype handling for embed_tokens and lm_head to ensure correct float32 usage in mixed precision training * Fix dtype assignment for lm_head modules to ensure correct weight dtype usage in mixed precision training
Theres' this issue of attempting unscale FP16 gradients
After investigation, this is because of global dtype, which is when we use it on colab, we will use
torch.float16instead oftorch.bfloat16. This error does not happened if we usetorch.bfloat16(my own RTX4090 for example). So we need to usetorch.float32on device that is still usingtorch.float16Here is the result on bfloat16

Abd here's the result on float32 (colab)
