Skip to content

Fix rms_norm in fp16/bf16#147203

Closed
riccardofelluga wants to merge 3 commits intopytorch:mainfrom
riccardofelluga:rms-cast
Closed

Fix rms_norm in fp16/bf16#147203
riccardofelluga wants to merge 3 commits intopytorch:mainfrom
riccardofelluga:rms-cast

Conversation

@riccardofelluga
Copy link
Copy Markdown
Contributor

Fixes #134106. This PR moves the upcasted_result down-casting after all computation is done.

Since the multiplication with the weight_opt input is not done in half precision, the current code path is doing the following: fp16 -> fp32 -> fp16 -> fp32 -> fp16. What we want tho is to avoid down-casting and this PR proposes: fp16 -> fp32 -> fp16. This results in better accuracy as it avoids truncating.

@pytorch-bot
Copy link
Copy Markdown

pytorch-bot bot commented Feb 14, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/147203

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 5a0c8cf with merge base 81847d0 (image):

BROKEN TRUNK - The following job failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@linux-foundation-easycla
Copy link
Copy Markdown

linux-foundation-easycla bot commented Feb 14, 2025

CLA Signed

The committers listed above are authorized under a signed CLA.

@riccardofelluga
Copy link
Copy Markdown
Contributor Author

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Feb 14, 2025
@colesbury colesbury requested a review from albanD February 18, 2025 16:32
@colesbury colesbury added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Feb 18, 2025
Copy link
Copy Markdown
Collaborator

@eqy eqy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to add a test that checks for whether the expected tolerances are met?

@riccardofelluga
Copy link
Copy Markdown
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 3, 2025
@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@pytorch-bot pytorch-bot bot removed the ciflow/trunk Trigger trunk jobs on your pull request label Mar 3, 2025
@riccardofelluga
Copy link
Copy Markdown
Contributor Author

@pytorchbot merge

@pytorch-bot
Copy link
Copy Markdown

pytorch-bot bot commented Mar 3, 2025

Pull workflow has not been scheduled for the PR yet. It could be because author doesn't have permissions to run those or skip-checks keywords were added to PR/commits, aborting merge. Please get/give approval for the workflows and/or remove skip ci decorators before next merge attempt. If you think this is a mistake, please contact PyTorch Dev Infra.

@riccardofelluga riccardofelluga requested a review from eqy March 3, 2025 21:48
@riccardofelluga
Copy link
Copy Markdown
Contributor Author

@pytorchbot rebase

@pytorch-bot
Copy link
Copy Markdown

pytorch-bot bot commented Mar 7, 2025

You don't have permissions to rebase this PR since you are a first time contributor. If you think this is a mistake, please contact PyTorch Dev Infra.

@riccardofelluga
Copy link
Copy Markdown
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 7, 2025
@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@riccardofelluga
Copy link
Copy Markdown
Contributor Author

@pytorchbot merge -i

@pytorch-bot
Copy link
Copy Markdown

pytorch-bot bot commented Mar 7, 2025

-i flag is only allowed for users with write permissions

@eqy
Copy link
Copy Markdown
Collaborator

eqy commented Mar 7, 2025

@pytorchmergebot rebase

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Successfully rebased rms-cast onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout rms-cast && git pull --rebase)

@pytorch-bot pytorch-bot bot removed the ciflow/trunk Trigger trunk jobs on your pull request label Mar 7, 2025
@riccardofelluga
Copy link
Copy Markdown
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 7, 2025
@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@danielvegamyhre
Copy link
Copy Markdown
Contributor

@riccardofelluga @eqy This PR seems to break torchtitan float8 training with rowwise scales, when RMS norm is used.

  • Sometime between 2.6.0 and present, a change in pytorch core was introduced that caused loss to not go down and then eventually become NaN after 40 or so steps.
  • I binary searched the commits in this time range and confirmed this commit is what caused the regression (link)
  • I also confirmed the issue reproduces using rmsnorm with the latest nightly build, and does not reproduce using layernorm (link)

Can we either revert this change or look into a fix asap please? This NaN issue is currently blocking the release of a blog post on float8 rowwise training, so we are eager to resolve it as soon as possible. Thanks!

cc @vkuzo @lessw2020

@danielvegamyhre
Copy link
Copy Markdown
Contributor

fyi @drisspg on #147203 (comment) as well since Vasiliy is OOO

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged open source topic: not user facing topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants