[rocm] scaled_grouped_mm support gfx942 fp8 data type #3540
| ) | ||
| assert torch.equal(out, ref_out) | ||
| # FP8 matmul allows some error due to precision limits and accumulation order differences | ||
| assert torch.allclose(out, ref_out, rtol=1e-2, atol=2.5) |
This doesn't seem right, why can we use torch.equal with CUDA but need to have a very permissive threshold for ROCM? The operations should be identical.
Mathematically, the operations are the same, but in real floating-point computations, numerical errors are unavoidable. Introducing a tolerance is therefore a standard industry practice. Major operator libraries, such as FlashInfer and TransformerEngine, follow this approach as well.
I understand, but my point is that the CUDA code path is also doing real low-precision FP8 computations, and the resulting difference is within the tolerance threshold of torch.equal.
Your new line of code indicates that the AMD code path accumulates much more error. Why? rtol=1e-2 and atol=2.5 is extremely permissive as well - those would be the most relaxed thresholds in the codebase I think, so I am scrutinizing this carefully.
To clarify what this test is doing:
- Verify that doing an fp8 rowwise torch._scaled_grouped_mm is bitwise identical to doing a for-loop over the groups, doing a separate torch._scaled_mm for each, and concatenating the results.
- In the CUDA codepath, the results are bitwise identical. In this ROCM codepath, apparently they are not?
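As a rough illustration of the reference construction described in the list above, here is a minimal pure-Python sketch. It is a stand-in only: it uses nested lists instead of tensors, it does no FP8 quantization or scaling, and the names `matmul`, `grouped_mm_reference`, and `group_ends` are hypothetical, not part of torch or torchao. It assumes each offset marks the end row of a group along dim0 of A.

```python
def matmul(a, b):
    """Plain matrix multiply on nested lists: (rows x K) @ (K x N)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def grouped_mm_reference(a, b_groups, group_ends):
    """Loop over groups, multiply each row slice of `a` by its own B, concatenate."""
    out = []
    start = 0
    for b, end in zip(b_groups, group_ends):
        out.extend(matmul(a[start:end], b))  # one separate matmul per group
        start = end
    return out


a = [[1, 2], [3, 4], [5, 6]]   # 3 rows, K = 2
b_groups = [
    [[1, 0], [0, 1]],          # group 0: identity
    [[2, 0], [0, 2]],          # group 1: 2 * identity
]
group_ends = [1, 3]            # group 0 owns row 0, group 1 owns rows 1..2
result = grouped_mm_reference(a, b_groups, group_ends)
```

The test under discussion compares the output of a single grouped kernel against a result built this way, which is why any difference in how the two kernels accumulate shows up directly.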
The difference mainly comes from using different implementations: on ROCm, scaled_mm is implemented via hipBLASLt, while scaled_grouped_mm uses CK. These two libraries do not guarantee the same FP8 compute/accumulation behavior (e.g., accumulation order, internal kernel numerics), so bitwise-identical results are not expected and small numerical differences are reasonable.
Also, this case is very large with deep accumulation, so differences from floating-point rounding can accumulate. As noted in the comment at line 236 of test/prototype/moe_training/test_scaled_grouped_mm.py, the max delta occurs at 260 vs 262, which is consistent with larger absolute error at larger magnitudes.
> The difference mainly comes from using different implementations: on ROCm, scaled_mm is implemented via hipBLASLt, while scaled_grouped_mm uses CK. These two libraries do not guarantee the same FP8 compute/accumulation behavior (e.g., accumulation order, internal kernel numerics), so bitwise-identical results are not expected and small numerical differences are reasonable.
this makes sense, let's only enable the looser tolerance for AMD then and add a comment explaining why.
This makes sense, but for CUDA we also use different implementations though (scaled_mm = dispatch to cublas, scaled_grouped_mm = dispatch to a cutlass kernel)?
I also notice this changes the test from total_M=131072 to 4096, so the individual gemms in the grouped gemm are much smaller? what is the error with the original test?
> this makes sense, let's only enable the looser tolerance for AMD then and add a comment explaining why.
(to clarify, this is fine, but i am very curious about the answers to my questions above^. could it be the case the cublas gemm and cutlass grouped gemm impls coincidentally have the same reduction strategies for this shape size, and thus same exact accum rounding error? that would be interesting)
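A small illustration of the accumulation-order point, using ordinary Python floats rather than FP8. This only shows the general floating-point effect and says nothing about how cuBLAS, CUTLASS, hipBLASLt, or CK actually order their reductions.

```python
import math

values = [0.1, 0.2, 0.3]

# Same three addends, two different association orders.
left_to_right = (values[0] + values[1]) + values[2]
right_to_left = values[0] + (values[1] + values[2])

bitwise_equal = left_to_right == right_to_left
close = math.isclose(left_to_right, right_to_left, rel_tol=1e-12)

print(left_to_right, right_to_left, bitwise_equal, close)
```

The two sums differ in the last bit even though they are mathematically identical. Two kernels that happen to use the same reduction order for a given shape would instead agree exactly, which is the coincidence raised in the question above.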
> Also, this case is very large with deep accumulation, so differences from floating-point rounding can accumulate. As noted in the comment at line 236 of test/prototype/moe_training/test_scaled_grouped_mm.py, the max delta occurs at 260 vs 262, which is consistent with larger absolute error at larger magnitudes.
131072 is too large and causes an int32 overflow in the _triton_fp8_per_group_colwise_scales_kernel. I have already fixed this issue.
| ) | ||
| def _test_comm(self, input_size: int): | ||
| from torchao.utils import is_ROCM |
move this to test_comm to match codebase convention
| logger: logging.Logger = logging.getLogger(__name__) | ||
| def _get_float8_dtype(): |
the util is fine, but this should be passed in as configuration to the top level API not looked up just-in-time to minimize complexity
The reason I wrote it this way is that the dtype was previously hardcoded as torch.float8_e4m3fn, but AMD MI300 requires torch.float8_e4m3fnuz. So I added this function to automatically select the correct dtype based on the platform.
However, I'm not entirely sure how you'd like this to be designed. Could you please clarify?
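For context on why the two dtypes are not interchangeable, here is a minimal sketch of how one 8-bit pattern decodes under each format. It assumes the commonly documented layouts (1 sign bit, 4 exponent bits, 3 mantissa bits; exponent bias 7 for float8_e4m3fn and 8 for float8_e4m3fnuz; no infinities in either). The function name `decode_fp8_e4m3` is hypothetical and is not a torch or torchao API.

```python
import math


def decode_fp8_e4m3(byte: int, fnuz: bool) -> float:
    """Decode one 8-bit pattern as e4m3fn (fnuz=False) or e4m3fnuz (fnuz=True)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    mant = byte & 0x7
    if fnuz:
        if byte == 0x80:          # fnuz: the "negative zero" pattern encodes NaN
            return math.nan
        bias = 8
    else:
        if exp == 0xF and mant == 0x7:  # fn: S.1111.111 encodes NaN
            return math.nan
        bias = 7
    if exp == 0:                  # subnormal: no implicit leading 1
        return sign * (mant / 8.0) * 2.0 ** (1 - bias)
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - bias)


max_fn = decode_fp8_e4m3(0x7E, fnuz=False)     # largest finite e4m3fn value
max_fnuz = decode_fp8_e4m3(0x7F, fnuz=True)    # largest finite e4m3fnuz value
same_bits = (decode_fp8_e4m3(0x38, fnuz=False), decode_fp8_e4m3(0x38, fnuz=True))
```

Under these assumptions the same bits decode to values that differ by a factor of two for normal numbers, and the representable maximum differs as well, so the dtype has to match what the hardware kernel expects.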
instead of

    def top_level_user_api(...):
        float8_dtype = _get_float8_dtype()
        ...

it would be better to do this:

    def top_level_user_api(
        float8_dtype: torch.dtype = torch.float8_e4m3fn,
        ...,
    ):
        ...

    # user overrides the dtype
    float8_dtype_for_amd = _get_float8_dtype_from_user_hardware()
    top_level_user_api(float8_dtype=float8_dtype_for_amd)

this ensures that the "magic" dtype selection happens in one place and is controlled by the user
I’ve made the changes.
| @skip_if_rocm("ROCm not supported") | ||
| @pytest.mark.parametrize("m", [131072]) | ||
| def _get_float8_dtype(): |
reuse the util instead of copy-pasting it
Hi @danielvegamyhre @vkuzo, is there anything else in this PR that you would like me to modify or address?
| ) | ||
| assert torch.equal(result1, ref_group_result1) | ||
| assert torch.equal(result2, ref_group_result2) | ||
| # FP8 matmul allows some error due to precision limits and accumulation order differences. |
this should be gated by rocm
| ) | ||
| group_row_end_idx = tl.load(offsets_ptr + offset_idx) | ||
| block_col_offs = block_col_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) | ||
| # Force int64 math for pointer offsets (avoid i32 overflow on large problems) |
thoughts about splitting this change to a separate PR? It's hard to reason about the current PR with all three of these mixed together:
1. adding ROCM
2. very loose tolerances on ROCM (which is a bit unexpected)
3. a fix in this kernel (not clear whether this is for ROCM, CUDA or both, and whether this fix affects the tolerances in (2))
@vkuzo If I split this into multiple PRs, I would need to modify the unit tests as well. Specifically, in test_valid_scaled_grouped_mm_2d_3d, using m = 131072 can trigger an int32 overflow. This is the reason I previously removed that case. danielvegamyhre had asked about it earlier, and this was the underlying issue. Now that I have added m = 131072 back, the int32 overflow issue needs to be fixed accordingly.
At the moment, we are actively working on enabling low-precision support for TorchTitan on AMD GPUs, and this PR is required to move that effort forward. Given this context, would it be possible to keep this change in a single PR instead of splitting it into multiple ones? That would help us move the TorchTitan work forward more efficiently.
it isn't clear how fixing the int32 overflow impacts the tolerances for AMD, would be good to clarify this. Can you set tolerance to zero after your int32 fix? Can you make the tolerance tighter after it? etc
@vkuzo
I have updated rtol and atol to 1e-2, and the tests pass locally. In my understanding, this is a relatively strict tolerance for FP8. I also referenced FlashInfer, which uses the same values. flashinfer-link
The int32 overflow is not related to tolerances. The issue is that the current test case has very large shapes, causing the kernel to exceed the int32 range when computing indices, which leads to incorrect results.
Could you please review it again? Thanks.
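To make the index-overflow mechanism concrete, here is a small sketch that emulates 32-bit signed offset arithmetic in plain Python. The row, stride, and column values are hypothetical and are not taken from the failing test; the helper names are illustrative and this is not the Triton kernel code.

```python
import ctypes

INT32_MAX = 2**31 - 1


def offset_int32(row: int, stride: int, col: int) -> int:
    """row * stride + col as 32-bit signed arithmetic would see it (wraps silently)."""
    return ctypes.c_int32(row * stride + col).value


def offset_int64(row: int, stride: int, col: int) -> int:
    """The same offset computed with 64-bit signed arithmetic."""
    return ctypes.c_int64(row * stride + col).value


row, stride, col = 300_000, 8192, 0   # hypothetical large problem
true_offset = row * stride + col

wrapped = offset_int32(row, stride, col)
widened = offset_int64(row, stride, col)

print(true_offset > INT32_MAX, wrapped, widened)
```

Once the true offset exceeds the int32 range, the wrapped value points at the wrong memory, so the kernel produces incorrect results regardless of which tolerance the test uses. Widening the index math to 64 bits removes the wraparound.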
| "B should be a ScaledGroupedMMTensor" | ||
| ) | ||
| scaling_type = B.scaling_type | ||
| float8_dtype = torch.float8_e4m3fnuz if is_MI300() else torch.float8_e4m3fn |
this should be passed by the user at the very top level of the API, not set automagically in the middle of the codebase
| # - scaled_mm is implemented via hipBLASLt | ||
| # - scaled_grouped_mm uses CK | ||
| # These do not guarantee identical FP8 compute/accumulation behavior (e.g. accumulation order), | ||
| # and this test is very large (deep accumulation), so rounding differences can accumulate. |
I still don't feel great about this. The current comment says the differences are magnified because of the large problem size. Should we add a test for a small problem size which matches exactly on ROCM? rtol/atol 1e-2 can mask various underlying issues with the kernels.
vkuzo left a comment
lgtm at a high level, an equality check in unit tests with a smaller problem size would be nice. I'll let @danielvegamyhre review the triton code.
looks like ruff is failing, could you fix that
Fixed
Based on my experience, numerical differences can still be observed even when the shape is reduced. The issue is not necessarily that “large shapes are inherently inaccurate,” but rather that it is difficult to guarantee that different kernels (or different library implementations) are fully aligned in the following implementation details:
As the problem size grows, the number of accumulation steps increases and the numerical error may further accumulate. That said, even in higher-precision formats such as FP16 or BF16, differences can still be observed when these implementation details differ. In FP8, this effect is more pronounced. Regarding the CUDA path currently achieving exact I also noticed that in |
agreed, the tolerance in the current PR is very high though, actual numerical bugs can slip past it. I think adding some more narrow cases that can pass with a tighter tolerance would be good.
for comparing high precision output to quantized, SQNR makes sense as it measures the strength of the signal vs quantization noise. For comparing quantized output with kernel 1 vs quantized output with kernel 2, something like MSE or comparing the values directly makes more sense to me.
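A minimal sketch of the two metrics mentioned above, written on plain Python lists so it runs without torch. The function names `mse` and `sqnr_db` are illustrative; SQNR is computed here as 10 * log10(signal power / noise power), which is one common definition and is assumed rather than taken from this PR.

```python
import math


def mse(ref, other):
    """Mean squared error between two equal-length sequences."""
    return sum((r - o) ** 2 for r, o in zip(ref, other)) / len(ref)


def sqnr_db(ref, other):
    """Signal-to-quantization-noise ratio in dB, treating `ref` as the signal."""
    signal = sum(r * r for r in ref)
    noise = sum((r - o) ** 2 for r, o in zip(ref, other))
    if noise == 0.0:
        return math.inf   # identical outputs: no noise at all
    return 10.0 * math.log10(signal / noise)


ref = [1.0, 2.0, 3.0, 4.0]
other = [1.0, 2.0, 3.0, 4.5]   # one element off by 0.5

error_mse = mse(ref, other)
error_sqnr = sqnr_db(ref, other)
```

SQNR is relative to the magnitude of the reference signal, whereas MSE reports the raw squared difference between the two outputs, which is closer to what a kernel-vs-kernel comparison is asking.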
| block_col_offs = block_col_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) | ||
| # Force int64 math for pointer offsets (avoid i32 overflow on large problems) | ||
| block_col_offs = (block_col_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)).to(tl.int64) | ||
| stride_input_row_i64 = tl.full((), stride_input_row, tl.int64) |
rather than doing this you can just define the type to be tl.int64 in the kernel function signature i believe. See how _triton_fp8_rowwise_3d_transpose_scales_rhs_kernel does it. we also use int64 there to prevent a similar overflow issue.
| if message: | ||
| skip_message += f": {message}" | ||
| pytest.skip(skip_message) | ||
| raise unittest.SkipTest(skip_message) |
why are you switching to unittest from pytest here? would prefer not to do that please
This isn’t a framework preference change. ROCm doesn’t support nvfp4, so we need to skip ROCm in test_nvfp4.py. In that file, TestComm inherits from FSDPTest, which is a unittest.TestCase-style test. Therefore the skip needs to be implemented via unittest.SkipTest; pytest-style skip/markers won’t take effect under the unittest runner. Pytest can still collect and honor unittest.SkipTest, so this remains compatible with both runners.
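A self-contained sketch of the skip mechanism described above, using only the standard library. The decorator name `skip_if` and the test class are hypothetical stand-ins, not the `skip_if_rocm` helper from the PR.

```python
import functools
import unittest


def skip_if(condition: bool, message: str):
    """Hypothetical decorator: skip the wrapped test when `condition` is true."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if condition:
                # Raised inside the test body, so unittest records a skip.
                raise unittest.SkipTest(message)
            return fn(*args, **kwargs)
        return wrapper
    return decorator


class ExampleTest(unittest.TestCase):
    @skip_if(True, "hypothetical: backend not supported")
    def test_skipped(self):
        self.fail("never runs")

    @skip_if(False, "unused")
    def test_runs(self):
        self.assertEqual(1 + 1, 2)


suite = unittest.defaultTestLoader.loadTestsFromTestCase(ExampleTest)
result = unittest.TestResult()
suite.run(result)
print(len(result.skipped), len(result.failures), len(result.errors))
```

Under the plain unittest runner the first test is recorded as skipped rather than failed, which is the behavior a `unittest.TestCase`-style class such as the one described above relies on.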
I reproduced this in a CUDA environment, and the conclusion is that this test cannot achieve equal on CUDA either.
Environment:
Repro steps:
Results: All 4 cases fail. The failures include both:
Error Log:
| @pytest.mark.parametrize("m", [131072]) | ||
| @pytest.mark.parametrize("n", [8192]) | ||
| @pytest.mark.parametrize("k", [5120]) | ||
| @pytest.mark.parametrize("m", [256, 1024, 4096, 131072]) |
this is a lot of combinations across mkn, does the test finish in a reasonable amount of time?
I'd recommend limiting to max of 3-5 combinations / O(low seconds) runtime instead of having 4 * 3 * 3 here
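One way to act on this suggestion is to replace the stacked parametrize decorators with a single curated list of shapes. The sketch below only counts combinations. The `m` values come from the diff above, while the `n` and `k` lists and the curated tuples are hypothetical placeholders, not the values in the PR.

```python
import itertools

m_values = [256, 1024, 4096, 131072]   # from the diff under review
n_values = [1024, 4096, 8192]          # hypothetical
k_values = [1024, 4096, 5120]          # hypothetical

# Stacking three parametrize decorators runs the full cartesian product.
full_grid = list(itertools.product(m_values, n_values, k_values))

# A curated subset: small, medium, and the original large shape.
curated_shapes = [
    (256, 1024, 1024),
    (4096, 4096, 4096),
    (131072, 8192, 5120),
]

# With pytest this list could be passed as:
#   @pytest.mark.parametrize("m,n,k", curated_shapes)
print(len(full_grid), len(curated_shapes))
```

This keeps coverage of both ends of the size range while cutting the number of test instances by roughly an order of magnitude.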
| # - scaled_grouped_mm uses CK | ||
| # These do not guarantee identical FP8 compute/accumulation behavior (e.g. accumulation order), | ||
| # and this test is very large (deep accumulation), so rounding differences can accumulate. | ||
| assert torch.allclose(out, ref_out, rtol=1e-2, atol=1e-2) |
we have added small test cases but did not make the tolerance tighter for them, is that expected? this doesn't seem to match the "this test is very large (deep accumulation), so rounding differences can accumulate" comment
| # FP8 matmul allows some error due to precision limits and accumulation order differences. | ||
| # Tested with M=131072, K=5120, N=8192, bfloat16 output: | ||
| # 99.9986% of points have error < 0.1, max error ~2 (-262 vs -260, relative error ~0.77%) | ||
| assert torch.allclose(result1, ref_group_result1, rtol=1e-2, atol=1e-2) |
can we make tolerance tighter for small problem sizes
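For reference, `torch.allclose` checks `|actual - expected| <= atol + rtol * |expected|` elementwise. The sketch below applies that formula in plain Python to the -262 vs -260 pair from the code comment above. Which of the two values is the reference is an assumption here, and `is_close` is an illustrative helper, not a torch function.

```python
def is_close(actual: float, expected: float, rtol: float, atol: float) -> bool:
    """Scalar version of the allclose criterion."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


# Worst-case pair quoted in the test comment (assumed: -260 is the reference).
passes_at_1e2 = is_close(-262.0, -260.0, rtol=1e-2, atol=1e-2)   # 2 <= 0.01 + 2.6
passes_at_1e3 = is_close(-262.0, -260.0, rtol=1e-3, atol=1e-3)   # 2 <= 0.001 + 0.26

# The same thresholds applied to a small-magnitude element.
small_passes = is_close(1.005, 1.0, rtol=1e-2, atol=1e-2)        # 0.005 <= 0.02
near_zero_fails = is_close(0.5, 0.0, rtol=1e-2, atol=1e-2)       # 0.5 <= 0.01

print(passes_at_1e2, passes_at_1e3, small_passes, near_zero_fails)
```

Because the allowed error scales with the magnitude of the reference, the same rtol admits an absolute error of about 2.6 near 260 but only about 0.02 near 1. That is why one pair of thresholds can behave very differently across problem sizes.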
@vkuzo 1e-2 is already relatively strict for FP8. Tightening it further would cause some test cases to fail. Typically, BF16/FP16 use tighter tolerances, such as 1e-3 or 1e-4. In low-precision computation, as the problem size increases, the tolerance typically needs to be moderately relaxed.
@vkuzo I also tested on H200, and it did not pass either.
| and in column-major memory layout. | ||
| offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor. | ||
| out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported. | ||
| float8_dtype (torch.dtype): The float8 dtype to use for quantization. Default is torch.float8_e4m3fn. |
this may need to be renamed later, but doesn't have to be in this PR
This reverts commit 2540ac4.
@xiaobochen-amd unfortunately we need to revert this as it seems to have broken MoE training because some callsites haven't been properly updated to handle the new param. Starting next week we'll support emulated mode to run these tests in CI, but in the meantime can you please fix this and then run the tests locally before submitting? thanks!
@danielvegamyhre Thanks for the heads up. Could you please point out which callsites are not updated correctly, or share the failing logs? That would help me fix it much faster.
Sure, the most important one is here: The unit tests you'll want are:
Please also double check benchmarks in |
@xiaobochen-amd you'll also need to update the number of gradients returned from the autograd func, now that you're adding a new parameter to forward(). I have a draft PR I made earlier to do the minimal changes to get it working locally (and fix an unrelated issue) if you want to use it to get started. There may be other changes needed for all tests to pass though: #3712
This PR adds rowwise support for scaled_grouped_mm on gfx942 with float8 dtypes. PyTorch already provides float8 support on gfx942, and this change aligns torchao with the existing PyTorch capability. Relevant unit tests have been added and all tests pass.