fix RoPE t range issue for fp16 #26602

Merged
Rocketknight1 merged 1 commit into huggingface:main from rui-ren:avoid-fp16-issue-for-range-ops on Oct 6, 2023

Conversation

@rui-ren
Contributor

@rui-ren rui-ren commented Oct 4, 2023

Issue

Sometimes when training with fp16, the dtype of self.inv_freq is changed from fp32 to fp16. In that scenario the position tensor t is also created with an fp16 dtype, like

t = torch.arange(seq_len, device=device, dtype=torch.float16)

After exporting to an ONNX graph, however, the ONNX Range op does not support fp16, as documented here.

Update

Use the following instead, so the range is built in the default dtype and cast afterwards:

t = torch.arange(seq_len, device=device).to(dtype)
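The pattern can be sketched as a small helper (`build_positions` is a hypothetical wrapper for illustration, not part of the actual patch):

```python
import torch

def build_positions(seq_len, device, dtype):
    # Build the index with torch.arange's default integer dtype, then
    # cast to the target dtype. Passing dtype=torch.float16 to arange
    # directly would emit an fp16 Range op on ONNX export, which the
    # ONNX Range operator does not support.
    return torch.arange(seq_len, device=device).to(dtype)

t = build_positions(8, "cpu", torch.float16)
print(t.dtype)  # torch.float16
```

This keeps the exported graph's Range op in an ONNX-supported dtype while still handing the RoPE computation a tensor in the model's working precision.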

@LysandreJik
Member

Seems fair, WDYT @Rocketknight1 ?

@Rocketknight1
Member

This will cause outputs to change numerically a bit when running in float16 or bfloat16 because freqs will be calculated in higher precision. The memory/performance impact probably wouldn't be huge, but this might introduce a small deviation for models that were trained in those precisions! Let me run some tests before I approve it.

@Rocketknight1
Member

After testing, outputs seem equivalent for bfloat16 models so I'm happy to approve this!

@Rocketknight1
Member

@rui-ren let me know if you want to add anything else to this PR, or if you're happy for me to merge it now!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@rui-ren
Contributor Author

rui-ren commented Oct 5, 2023

@Rocketknight1 Please merge this PR. Thank you for your review.

@Rocketknight1 Rocketknight1 merged commit 8749942 into huggingface:main Oct 6, 2023
@Rocketknight1
Member

Done. Thanks for a clean and helpful PR @rui-ren!

@rui-ren rui-ren deleted the avoid-fp16-issue-for-range-ops branch October 6, 2023 14:03