
[ROCm] port CK rowwise F8 from fbgemm #140856

Closed
jeffdaily wants to merge 10 commits into pytorch:main from ROCm:ck_rowwise_f8_fbgemm

Conversation

@jeffdaily
Collaborator

@jeffdaily jeffdaily commented Nov 15, 2024

@jeffdaily jeffdaily added module: rocm AMD GPU support for Pytorch rocm This tag is for PRs from ROCm team ciflow/rocm Trigger "default" config CI on ROCm labels Nov 15, 2024
@jeffdaily jeffdaily requested a review from jwfromm November 15, 2024 23:35
@pytorch-bot

pytorch-bot bot commented Nov 15, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140856

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 64cef29 with merge base a7509e9:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@jeffdaily jeffdaily changed the title Ck rowwise f8 fbgemm [ROCm] port CK rowwise F8 from fbgemm Nov 15, 2024
@cpuhrsch cpuhrsch requested a review from drisspg November 16, 2024 00:38
@drisspg drisspg added release notes: nn release notes category module: floatx (formerly float8) For torch.float8_e5m2 and torch.float8_e4m3 and other sub 8-bit float types skip-pr-sanity-checks labels Nov 16, 2024
@drisspg
Contributor

drisspg commented Nov 16, 2024

Left a few comments; I think it looks good. It would be good to also note the increase in binary size from this PR.


x_fp8 = x.to(torch.float8_e4m3fn)
y_fp8 = y.to(torch.float8_e4m3fn).t()
x_fp8 = x.to(e4m3_type)
Contributor

I saw that the code to automatically set this was added in #117822 - IMO in a separate PR we should change these dtypes to be set explicitly by the user / testing framework, to follow the convention used elsewhere in similar files and make it crystal clear which dtypes are being tested where.

Collaborator Author

I disagree but am open to being convinced otherwise. For F8 types there are effectively two sets (e4m3fnuz/e5m2fnuz and e4m3fn/e5m2) and the abstraction introduced in #117822 is useful to avoid much copy/paste code, or decorating many unit tests with allowed types. I like the current solution because it is compact. But developer education was missing; new unit tests that work on CUDA and do not use the abstracted types will not work on ROCm.

@jeffdaily
Collaborator Author

Left a few comments; I think it looks good. It would be good to also note the increase in binary size from this PR.

Binary size increased by 5.4MB.

@drisspg drisspg added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 18, 2024
@zou3519 zou3519 added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Nov 19, 2024
@jeffdaily
Collaborator Author

Before landing, we need to verify this still builds okay for gfx1100 etc.

@jeffdaily jeffdaily requested a review from drisspg November 22, 2024 20:45
@atalman
Contributor

atalman commented Dec 5, 2024

@pytorchmergebot revert -c ghfirst -m "Failing internal build"

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Dec 5, 2024
This reverts commit 291626f.

Reverted #140856 on behalf of https://github.com/atalman due to Failing internal build
@pytorchmergebot
Collaborator

@jeffdaily your PR has been successfully reverted.

@pytorchmergebot pytorchmergebot added Reverted ci-no-td Do not run TD on this PR labels Dec 5, 2024
@facebook-github-bot
Contributor

@atalman has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
This reverts commit 291626f.

Reverted pytorch#140856 on behalf of https://github.com/atalman due to Failing internal build
AmdSampsa pushed a commit to AmdSampsa/pytorch that referenced this pull request Dec 9, 2024
@drisspg
Contributor

drisspg commented Dec 10, 2024

So I think the problem with the PR as stands is this line: https://github.com/pytorch/pytorch/pull/140856/files#diff-19b256efe989af74ad429ef2a1eb6e075784aa18aea04c7d36bb0e790e9a8170R19

Including torch.h is typically done for external C++ extensions, and it messes with some internal build systems. The proper fix would be to refine which headers are actually needed for building.

cc @jeffdaily

@jeffdaily
Collaborator Author

@drisspg removing the #include of torch.h had no effect on the CMake build. Would you be able to check the internal build?

@drisspg
Contributor

drisspg commented Dec 11, 2024

@jeffdaily Yeah will do

@drisspg
Contributor

drisspg commented Dec 12, 2024

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Rebase failed due to Command git -C /home/runner/work/pytorch/pytorch rebase refs/remotes/origin/viable/strict pull/140856/head returned non-zero exit code 1

Rebasing (1/7)
Auto-merging aten/src/ATen/CMakeLists.txt
CONFLICT (content): Merge conflict in aten/src/ATen/CMakeLists.txt
Auto-merging aten/src/ATen/native/cuda/Blas.cpp
Auto-merging test/test_matmul_cuda.py
CONFLICT (content): Merge conflict in test/test_matmul_cuda.py
error: could not apply 1276c3dce54... [ROCm] CK f8 rowwise gemm
hint: Resolve all conflicts manually, mark them as resolved with
hint: "git add/rm <conflicted_files>", then run "git rebase --continue".
hint: You can instead skip this commit: run "git rebase --skip".
hint: To abort and get back to the state before "git rebase", run "git rebase --abort".
hint: Disable this message with "git config advice.mergeConflict false"
Could not apply 1276c3dce54... [ROCm] CK f8 rowwise gemm

Raised by https://github.com/pytorch/pytorch/actions/runs/12304294940

@drisspg
Contributor

drisspg commented Dec 12, 2024

@jeffdaily could you help rebase and then I can import and land internally

@jeffdaily
Collaborator Author

@jeffdaily could you help rebase and then I can import and land internally

Done.

@facebook-github-bot
Contributor

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@jeffdaily
Collaborator Author

@jeffdaily could you help rebase and then I can import and land internally

Any status update?

@drisspg
Contributor

drisspg commented Dec 16, 2024

@jeffdaily Trying to get an AMD GPU to test; there are some issues, but I want to see if I can patch them internally.

drisspg pushed a commit to drisspg/pytorch that referenced this pull request Dec 17, 2024
Summary:
This ports (copies) FBGEMM's implementation from jwfromm.

https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise

cc sunway513 jithunnair-amd pruthvistony ROCmSupport dllehr-amd jataylo hongxiayang naromero77amd yanbing-j vkuzo albanD kadeng penguinwu

Pull Request resolved: pytorch#140856

Reviewed By: atalman

Differential Revision: D66797096

Pulled By: drisspg
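For readers unfamiliar with the technique being ported: rowwise fp8 GEMM keeps one scale per row of each operand so that each row's maximum magnitude maps onto the fp8 range, and the output is rescaled by the outer product of the two scale vectors. A minimal NumPy sketch of the scaling math (a hedged illustration only; it ignores the actual fp8 rounding and the fused epilogue that the CK kernels implement):

```python
import numpy as np

E4M3_MAX = 448.0  # largest finite magnitude representable in float8_e4m3fn

def rowwise_scale(t):
    # One scale per row: row_amax / fp8_max, so scaled values fit the fp8 range.
    amax = np.abs(t).max(axis=1, keepdims=True)
    return amax / E4M3_MAX

x = np.random.randn(4, 8)   # activations, shape (M, K)
w = np.random.randn(5, 8)   # weights, shape (N, K)

sx, sw = rowwise_scale(x), rowwise_scale(w)   # shapes (M, 1) and (N, 1)
x_q = x / sx                # would be cast to fp8 in the real kernel
w_q = w / sw
# The fp8 GEMM computes x_q @ w_q.T; the epilogue applies the outer
# product of the row scales to recover the full-precision result.
y = (x_q @ w_q.T) * (sx * sw.T)
```

Up to floating-point rounding, `y` matches `x @ w.T`; the accuracy benefit over a single per-tensor scale is that one outlier row no longer shrinks the quantization range of every other row.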
@drisspg
Contributor

drisspg commented Dec 17, 2024

Had to unlink and re-export: #143416

@jeffdaily
Collaborator Author

Closing in favor of #143416.

@jeffdaily jeffdaily closed this Jan 2, 2025

Labels

ci-no-td Do not run TD on this PR
ciflow/rocm Trigger "default" config CI on ROCm
ciflow/trunk Trigger trunk jobs on your pull request
Merged
module: floatx (formerly float8) For torch.float8_e5m2 and torch.float8_e4m3 and other sub 8-bit float types
module: rocm AMD GPU support for Pytorch
open source
release notes: nn release notes category
Reverted
rocm priority high priority ROCm PRs from performance or other aspects
rocm This tag is for PRs from ROCm team
skip-pr-sanity-checks
triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

9 participants