Fix/casting continue pretraining #1200

Merged
danielhanchen merged 6 commits into unslothai:main from Erland366:fix/casting-continue-pretraining on Oct 27, 2024

Conversation

Erland366 commented Oct 27, 2024

Collaborator

There's an issue where training fails with the error "Attempting to unscale FP16 gradients".

image

After investigation, this is caused by the global dtype: on Colab we use torch.float16 instead of torch.bfloat16. The error does not happen with torch.bfloat16 (on my own RTX 4090, for example). So we need to use torch.float32 on devices that still use torch.float16.
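The rule described above can be sketched as follows. This is only an illustration, not the code changed in this PR: `pick_trainable_dtype` is a hypothetical helper name, and the small `torch.nn.Embedding` stands in for the model's trainable embed_tokens / lm_head weights.

```python
import torch

def pick_trainable_dtype(compute_dtype: torch.dtype) -> torch.dtype:
    """Hypothetical helper: choose the dtype for trainable embedding weights."""
    # The AMP GradScaler refuses to unscale float16 gradients, so weights that
    # are trained directly must be kept in float32 when the global dtype is float16.
    if compute_dtype == torch.float16:
        return torch.float32
    # bfloat16 training does not go through gradient unscaling, so keep the dtype.
    return compute_dtype

# Stand-in for embed_tokens on a float16-only device such as a Colab T4.
embedding = torch.nn.Embedding(8, 4).to(pick_trainable_dtype(torch.float16))
```

With a float16 global dtype the stand-in embedding ends up in float32, while a bfloat16 global dtype is passed through unchanged.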

Here is the result on bfloat16
image

And here's the result on float32 (Colab)
image

Erland366 marked this pull request as ready for review October 27, 2024 12:12
@Erland366

Collaborator Author

Previously I marked this as a draft because somehow I couldn't use batch size 2 as in the example. But now I can use batch size 2, so I'm opening this PR for review instead of leaving it as a draft.

@gautamabambang

Thank you so much for bringing this up in a PR, man 🙏🙏

danielhanchen merged commit fdf25b7 into unslothai:main Oct 27, 2024
@danielhanchen

Member

Oh I totally missed float16 cannot be used and only bfloat16 can be used for continued pretraining - nice catch!

@mombip

mombip commented Apr 25, 2025

I believe this change led to an error when saving a model with float32 embeddings (model.config.torch_dtype is then 'float32').

model = FastLanguageModel.get_peft_model(
    model,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj", "embed_tokens"],
    ... 

When saving the model with model.save_pretrained_merged("model", tokenizer, save_method = "merged_16bit"), I get the error: RuntimeError: Invalid device string: 'float32'. This happens because the string value torch_dtype = 'float32' is not handled: there is no mapping to torch.float32:

unsloth/save.py (551)

    ...
    torch_dtype = internal_model.config.torch_dtype
    if type(torch_dtype) is str:
        if   torch_dtype ==  "float16": torch_dtype = torch.float16
        elif torch_dtype == "bfloat16": torch_dtype = torch.bfloat16
    pass

    # Check modules to save float32 dtype
    state_dict["model.embed_tokens.weight"] = internal_model.model.embed_tokens.weight.data.to(torch_dtype)
    ...

"float16" and "bfloat16" are converted to torch dtypes. "float32" remains a string and causes the error when internal_model.model.embed_tokens.weight.data.to(torch_dtype) is called.
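One possible way to close that gap is a lookup that also covers "float32". This is a minimal sketch, not the fix that landed in unsloth: `_STR_TO_DTYPE` and `resolve_torch_dtype` are hypothetical names, and the zero tensor stands in for the embed_tokens weight.

```python
import torch

# Hypothetical lookup table; the keys mirror the strings seen in config.torch_dtype.
_STR_TO_DTYPE = {
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}

def resolve_torch_dtype(value):
    """Return a torch.dtype for either a torch.dtype or a known dtype string."""
    if isinstance(value, torch.dtype):
        return value
    try:
        return _STR_TO_DTYPE[value]
    except KeyError:
        # Fail early with a clear message instead of passing a string to .to().
        raise ValueError(f"Unsupported torch_dtype string: {value!r}") from None

# Stand-in for internal_model.model.embed_tokens.weight.data
weight = torch.zeros(2, 3, dtype=torch.float32)
converted = weight.to(resolve_torch_dtype("float32"))
```

Raising on an unknown string is a design choice for the sketch: it surfaces the problem at the lookup rather than later inside `.to(...)`, where a leftover string is interpreted as a device name.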

abiswas-realadvice pushed a commit to abiswas-realadvice/unsloth that referenced this pull request May 14, 2026
* Bring back float32 if float16 instead of bfloat16

* Refactor mixed precision handling for lm_head and embed_tokens to ensure correct dtype usage

* Fix dtype retrieval for embed_tokens and lm_head in mixed precision training

* Fix dtype retrieval for embed_tokens and lm_head to use weight dtype in mixed precision training

* Fix dtype handling for embed_tokens and lm_head to ensure correct float32 usage in mixed precision training

* Fix dtype assignment for lm_head modules to ensure correct weight dtype usage in mixed precision training
ayoubzulfiqar pushed a commit to ayoubzulfiqar/unsloth that referenced this pull request Jun 11, 2026
* Bring back float32 if float16 instead of bfloat16

* Refactor mixed precision handling for lm_head and embed_tokens to ensure correct dtype usage

* Fix dtype retrieval for embed_tokens and lm_head in mixed precision training

* Fix dtype retrieval for embed_tokens and lm_head to use weight dtype in mixed precision training

* Fix dtype handling for embed_tokens and lm_head to ensure correct float32 usage in mixed precision training

* Fix dtype assignment for lm_head modules to ensure correct weight dtype usage in mixed precision training