[rocm] scaled_grouped_mm support gfx942 fp8 data type#3540

Merged: vkuzo merged 10 commits into pytorch:main from xiaobochen-amd:dev on Jan 23, 2026

Conversation

@xiaobochen-amd

Contributor

This PR adds rowwise support for scaled_grouped_mm on gfx942 with float8 dtypes. PyTorch already provides float8 support on gfx942, and this change aligns torchao with the existing PyTorch capability. Relevant unit tests have been added and all tests pass.

docker:  rocm/primus:v25.10

torch==2.11.0.dev20251221+rocm7.1

pytest test/prototype/moe_training/test_scaled_grouped_mm.py

@pytorch-bot

pytorch-bot Bot commented Dec 25, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3540

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 63f1339 with merge base 3350b2f (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla

meta-cla Bot commented Dec 25, 2025

Hi @xiaobochen-amd!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

@meta-cla meta-cla Bot added the CLA Signed label Dec 25, 2025
@meta-cla

meta-cla Bot commented Dec 25, 2025

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@danielvegamyhre danielvegamyhre self-requested a review January 6, 2026 22:00
)
assert torch.equal(out, ref_out)
# FP8 matmul allows some error due to precision limits and accumulation order differences
assert torch.allclose(out, ref_out, rtol=1e-2, atol=2.5)

@danielvegamyhre danielvegamyhre Jan 6, 2026

Contributor

This doesn't seem right, why can we use torch.equal with CUDA but need to have a very permissive threshold for ROCM? The operations should be identical.

Contributor Author

Mathematically, the operations are the same, but in real floating-point computations, numerical errors are unavoidable. Introducing a tolerance is therefore a standard industry practice. Major operator libraries, such as FlashInfer and TransformerEngine, follow this approach as well.

https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/test_float8_blockwise_gemm_exact.py#L206

https://github.com/flashinfer-ai/flashinfer/blob/main/tests/gemm/test_groupwise_scaled_gemm_fp8.py#L76

@danielvegamyhre danielvegamyhre Jan 7, 2026

Contributor

I understand, but my point is that the CUDA code path is also doing real low-precision FP8 computations, and the resulting difference is within the tolerance threshold of torch.equal.

Your new line of code indicates that the AMD code path accumulates much more error. Why? rtol=1e-2 and atol=2.5 is extremely permissive as well - those would be the most relaxed thresholds in the codebase I think, so I am scrutinizing this carefully.

To clarify what this test is doing:

  • Verify doing a fp8 rowwise torch._scaled_grouped_mm is bitwise identical to doing a for-loop over the groups, doing a separate torch._scaled_mm for each, and concatenating the results.
  • In the CUDA codepath, the results are bitwise identical. In this ROCM codepath, apparently they are not?
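For readers unfamiliar with the test, the comparison described above can be sketched in plain Python. This is only an illustration of the test's structure: it leaves out FP8 quantization and scales entirely, the function names are hypothetical, and it assumes the offsets are cumulative end indices of each group along the rows of A (which the `group_row_end_idx` name in the kernel under review suggests).

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain dense matmul on nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def grouped_mm_loop_reference(a: Matrix, b_groups: List[Matrix], offs: List[int]) -> Matrix:
    """Reference path: one separate matmul per group, results concatenated by rows."""
    out: Matrix = []
    start = 0
    for b, end in zip(b_groups, offs):
        out.extend(matmul(a[start:end], b))
        start = end
    return out


def grouped_mm_rowwise(a: Matrix, b_groups: List[Matrix], offs: List[int]) -> Matrix:
    """Single pass over the rows, looking up which group each row belongs to."""
    out: Matrix = []
    group = 0
    for i, row in enumerate(a):
        while i >= offs[group]:  # advance past finished (or empty) groups
            group += 1
        out.extend(matmul([row], b_groups[group]))
    return out


# Two groups of two rows each; hypothetical toy data.
a = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
b_groups = [[[1.0, 0.0], [0.0, 1.0]], [[2.0, 0.0], [0.0, 2.0]]]
offs = [2, 4]
ref = grouped_mm_loop_reference(a, b_groups, offs)
out = grouped_mm_rowwise(a, b_groups, offs)
```

Here both paths perform the same arithmetic in the same order, so `out == ref` holds exactly. The question in this thread is whether two independently written GPU kernels can be expected to do the same.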

Contributor Author

The difference mainly comes from using different implementations: on ROCm, scaled_mm is implemented via hipBLASLt, while scaled_grouped_mm uses CK. These two libraries do not guarantee the same FP8 compute/accumulation behavior (e.g., accumulation order, internal kernel numerics), so bitwise-identical results are not expected and small numerical differences are reasonable.

Also, this case is very large with deep accumulation, so differences from floating-point rounding can accumulate. As noted in the comment at line 236 of test/prototype/moe_training/test_scaled_grouped_mm.py, the max delta occurs at 260 vs 262, which is consistent with larger absolute error at larger magnitudes.
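As context for the 260 vs 262 figure, and assuming bfloat16 output as the test comment in this PR states: bfloat16 keeps 8 significant bits, so representable values between 256 and 512 are spaced 2 apart. The sketch below emulates bfloat16 rounding with the standard library (finite inputs only; the helper names are hypothetical) to show that 260 and 262 are neighbouring bfloat16 values. It says nothing about where the divergence between the two kernels originates.

```python
import math
import struct


def to_bf16(x: float) -> float:
    """Round a finite float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Standard round-to-nearest-even trick on the 16 bits that bfloat16 drops.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + rounding_bias) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def bf16_spacing(x: float) -> float:
    """Gap between adjacent bfloat16 values around a nonzero normal x."""
    _, exponent = math.frexp(abs(x))  # abs(x) = m * 2**exponent with 0.5 <= m < 1
    return 2.0 ** (exponent - 8)      # 8 significant bits


spacing_near_260 = bf16_spacing(260.0)
rounded_261 = to_bf16(261.0)
```

Under these assumptions a difference of 2 at this magnitude is a single bfloat16 step, which is consistent with the roughly 0.77% relative error quoted for that element.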

Contributor

> The difference mainly comes from using different implementations: on ROCm, scaled_mm is implemented via hipBLASLt, while scaled_grouped_mm uses CK. These two libraries do not guarantee the same FP8 compute/accumulation behavior (e.g., accumulation order, internal kernel numerics), so bitwise-identical results are not expected and small numerical differences are reasonable.

this makes sense, let's only enable the looser tolerance for AMD then and add a comment explaining why.

@danielvegamyhre danielvegamyhre Jan 8, 2026

Contributor

This makes sense, but for CUDA we also use different implementations though (scaled_mm = dispatch to cublas, scaled_grouped_mm = dispatch to a cutlass kernel)?

I also notice this changes the test from total_M=131072 to 4096, so the individual gemms in the grouped gemm are much smaller? what is the error with the original test?

@danielvegamyhre danielvegamyhre Jan 9, 2026

Contributor

> this makes sense, let's only enable the looser tolerance for AMD then and add a comment explaining why.

(to clarify, this is fine, but i am very curious about the answers to my questions above^. could it be the case the cublas gemm and cutlass grouped gemm impls coincidentally have the same reduction strategies for this shape size, and thus same exact accum rounding error? that would be interesting)

Contributor Author

> Also, this case is very large with deep accumulation, so differences from floating-point rounding can accumulate. As noted in the comment at line 236 of test/prototype/moe_training/test_scaled_grouped_mm.py, the max delta occurs at 260 vs 262, which is consistent with larger absolute error at larger magnitudes.

131072 is too large and causes an int32 overflow in the _triton_fp8_per_group_colwise_scales_kernel. I have already fixed this issue.
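The kind of overflow described here can be reproduced on the CPU by emulating 32-bit index arithmetic. This is a sketch, not the kernel's actual indexing: the row index is the last row of an M = 131072 problem, but the stride of 20000 is a hypothetical value chosen only to push the product past 2**31 - 1.

```python
import ctypes

INT32_MAX = 2**31 - 1


def offset_int32(row: int, stride: int, col: int) -> int:
    """row * stride + col as it would come out of 32-bit signed arithmetic."""
    return ctypes.c_int32(row * stride + col).value  # ctypes wraps silently


def offset_int64(row: int, stride: int, col: int) -> int:
    """The same expression evaluated in 64-bit signed arithmetic."""
    return ctypes.c_int64(row * stride + col).value


row, stride, col = 131071, 20000, 0  # stride is hypothetical
wrapped = offset_int32(row, stride, col)
exact = offset_int64(row, stride, col)
```

With these inputs `exact` exceeds `INT32_MAX` and `wrapped` comes out negative, so a load or store using the 32-bit offset would touch the wrong memory. Promoting the offset math to 64 bits avoids the wraparound.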

Comment thread test/dtypes/test_nf4.py Outdated
)

def _test_comm(self, input_size: int):
from torchao.utils import is_ROCM

Contributor

move this to test_comm to match codebase convention

Contributor Author

fixed

logger: logging.Logger = logging.getLogger(__name__)


def _get_float8_dtype():

Contributor

the util is fine, but this should be passed in as configuration to the top level API not looked up just-in-time to minimize complexity

Contributor Author

The reason I wrote it this way is that the dtype was previously hardcoded as torch.float8_e4m3fn, but AMD MI300 requires torch.float8_e4m3fnuz. So I added this function to automatically select the correct dtype based on the platform.
However, I'm not entirely sure how you'd like this to be designed. Could you please clarify?

Contributor

instead of

def top_level_user_api(...):
    float8_dtype = _get_float8_dtype()
    ...

it would be better to do this:

def top_level_user_api(
    float8_dtype: torch.dtype = torch.float8_e4m3fn,
    ...,
):
    ...

# user overrides the dtype
float8_dtype_for_amd = _get_float8_dtype_from_user_hardware()
top_level_user_api(float8_dtype=float8_dtype_for_amd)

this ensures that the "magic" dtype selection happens in one place and is controlled by the user

Contributor Author

I’ve made the changes.


@skip_if_rocm("ROCm not supported")
@pytest.mark.parametrize("m", [131072])
def _get_float8_dtype():

Contributor

reuse the util instead of copy-pasting it

Contributor Author

fixed

@xiaobochen-amd xiaobochen-amd force-pushed the dev branch 2 times, most recently from 438bbaf to 8ecf355 Compare January 13, 2026 03:40
@xiaobochen-amd

Contributor Author

Hi, @danielvegamyhre @vkuzo is there anything else in this PR that you would like me to modify or address?

)
assert torch.equal(result1, ref_group_result1)
assert torch.equal(result2, ref_group_result2)
# FP8 matmul allows some error due to precision limits and accumulation order differences.

Contributor

this should be gated by rocm

Contributor Author

Fixed.

)
group_row_end_idx = tl.load(offsets_ptr + offset_idx)
block_col_offs = block_col_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
# Force int64 math for pointer offsets (avoid i32 overflow on large problems)

Contributor

thoughts about splitting this change to a separate PR? It's hard to reason about the current PR with all three of these mixed together:

  1. adding ROCM
  2. very loose tolerances on ROCM (which is a bit unexpected)
  3. a fix in this kernel (not clear whether this is for ROCM, CUDA or both, and whether this fix affects the tolerances in (2))

@xiaobochen-amd xiaobochen-amd Jan 13, 2026

Contributor Author

@vkuzo If I split this into multiple PRs, I would need to modify the unit tests as well. Specifically, in test_valid_scaled_grouped_mm_2d_3d, using m = 131072 can trigger an int32 overflow. This is the reason I previously removed that case. danielvegamyhre had asked about it earlier, and this was the underlying issue. Now that I have added m = 131072 back, the int32 overflow issue needs to be fixed accordingly.

At the moment, we are actively working on enabling low-precision support for TorchTitan on AMD GPUs, and this PR is required to move that effort forward. Given this context, would it be possible to keep this change in a single PR instead of splitting it into multiple ones? That would help us move the TorchTitan work forward more efficiently.

Contributor

it isn't clear how fixing the int32 overflow impacts the tolerances for AMD, would be good to clarify this. Can you set tolerance to zero after your int32 fix? Can you make the tolerance tighter after it? etc

Contributor Author

@vkuzo
I have updated rtol and atol to 1e-2, and the tests pass locally. In my understanding, this is a relatively strict tolerance for FP8. I also referenced FlashInfer, which uses the same values. flashinfer-link

The int32 overflow is not related to tolerances. The issue is that the current test case has very large shapes, causing the kernel to exceed the int32 range when computing indices, which leads to incorrect results.

Could you please review it again? Thanks.
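To make the effect of these settings concrete, here is a scalar sketch of the documented `torch.allclose` criterion, |actual - expected| <= atol + rtol * |expected|, written in plain Python. The function name is hypothetical, and the sample values are the -262 vs -260 pair mentioned in this PR plus one made-up small-magnitude pair.

```python
def within_tolerance(actual: float, expected: float, rtol: float, atol: float) -> bool:
    """Scalar form of the allclose check: |actual - expected| <= atol + rtol * |expected|."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


# Large-magnitude element from the thread: the rtol term dominates (about 2.6 of slack).
large_ok = within_tolerance(-262.0, -260.0, rtol=1e-2, atol=1e-2)

# Same element with rtol disabled: only 0.01 of absolute slack remains.
large_atol_only = within_tolerance(-262.0, -260.0, rtol=0.0, atol=1e-2)

# Hypothetical small-magnitude element: 0.02 of error against 0.015 of slack.
small_ok = within_tolerance(0.52, 0.50, rtol=1e-2, atol=1e-2)
```

With rtol=1e-2 the allowed absolute error scales with the magnitude of the reference value, so the same setting is loose for large outputs and comparatively tight near zero.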

"B should be a ScaledGroupedMMTensor"
)
scaling_type = B.scaling_type
float8_dtype = torch.float8_e4m3fnuz if is_MI300() else torch.float8_e4m3fn

Contributor

this should be passed by the user at the very top level of the API, not set automagically in the middle of the codebase

Contributor Author

Fixed

# - scaled_mm is implemented via hipBLASLt
# - scaled_grouped_mm uses CK
# These do not guarantee identical FP8 compute/accumulation behavior (e.g. accumulation order),
# and this test is very large (deep accumulation), so rounding differences can accumulate.

Contributor

I still don't feel great about this. The current comment says the differences are magnified because of the large problem size. Should we add a test for a small problem size which matches exactly on ROCM? rtol/atol 1e-2 can mask various underlying issues with the kernels.

Contributor

+1

@vkuzo vkuzo left a comment

Contributor

lgtm at a high level, an equality check in unit tests with a smaller problem size would be nice. I'll let @danielvegamyhre review the triton code.

@vkuzo

vkuzo commented Jan 14, 2026

Contributor

looks like ruff is failing, could you fix that

@xiaobochen-amd

Contributor Author

> looks like ruff is failing, could you fix that

Fixed

@xiaobochen-amd

Contributor Author

> lgtm at a high level, an equality check in unit tests with a smaller problem size would be nice. I'll let @danielvegamyhre review the triton code.

Based on my experience, numerical differences can still be observed even when the shape is reduced. The issue is not necessarily that “large shapes are inherently inaccurate,” but rather that it is difficult to guarantee that different kernels (or different library implementations) are fully aligned in the following implementation details:

  • blocking / tiling strategy (block size)
  • reduction path
  • accumulation count and accumulation order

As the problem size grows, the number of accumulation steps increases and the numerical error may further accumulate. That said, even in higher-precision formats such as FP16 or BF16, differences can still be observed when these implementation details differ. In FP8, this effect is more pronounced.
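The dependence on reduction strategy can be shown with a deliberately exaggerated sketch. It rounds every partial sum to IEEE half precision using the standard library, which is an assumption made purely for illustration: real FP8 GEMM kernels typically accumulate in a wider format, so the effect there is far smaller. The function names and the block size of 256 are hypothetical.

```python
import struct


def round_fp16(x: float) -> float:
    """Round a float to the nearest IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def sum_sequential(values):
    """One running accumulator, rounded to fp16 after every add."""
    acc = 0.0
    for v in values:
        acc = round_fp16(acc + v)
    return acc


def sum_blocked(values, block):
    """Sum each block separately, then sum the per-block partials."""
    partials = [sum_sequential(values[i:i + block]) for i in range(0, len(values), block)]
    return sum_sequential(partials)


values = [1.0] * 4096
sequential = sum_sequential(values)
blocked = sum_blocked(values, 256)
```

The inputs and the precision are identical in both paths, yet the sequential sum stalls at 2048 (where adding 1 no longer changes an fp16 value) while the blocked sum reaches 4096. Two kernels that tile or reduce differently can disagree in the same way, on a much smaller scale.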

Regarding the CUDA path currently achieving exact equality, I am not sure of the underlying reason. From my prior experience, both in PyTorch’s own operator tests and in other operator libraries (e.g., flashinfer, transformer-engine), it is uncommon to use torch.equal as the correctness criterion for GEMM-like operators.

I also noticed that in test_scaled_grouped_mm.py, the MXFP8 tests use an SNR metric rather than equal. My understanding is that SNR is a looser validation criterion compared to rtol/atol. Could you share the considerations behind choosing SNR instead of equal in this context?

@vkuzo

vkuzo commented Jan 14, 2026

Contributor

> Based on my experience, numerical differences can still be observed even when the shape is reduced. The issue is not necessarily that “large shapes are inherently inaccurate,” but rather that it is difficult to guarantee that different kernels (or different library implementations) are fully aligned in the following implementation details:

agreed, the tolerance in the current PR is very high though, actual numerical bugs can slip past it. I think adding some more narrow cases that can pass with a tighter tolerance would be good.

@vkuzo

vkuzo commented Jan 14, 2026

Contributor

> I also noticed that in test_scaled_grouped_mm.py, the MXFP8 tests use an SNR metric rather than equal.

for comparing high precision output to quantized, SQNR makes sense as it measures the strength of the signal vs quantization noise. For comparing quantized output with kernel 1 vs quantized output with kernel 2, something like MSE or comparing the values directly makes more sense to me.
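The two metrics mentioned above can be sketched in plain Python as follows. This is a minimal illustration using the textbook definitions, with hypothetical function names and toy data; it is not the metric implementation used by the tests in this repository.

```python
import math


def sqnr_db(reference, quantized):
    """Signal-to-quantization-noise ratio in dB: 10 * log10(signal power / noise power)."""
    signal = sum(r * r for r in reference)
    noise = sum((r - q) ** 2 for r, q in zip(reference, quantized))
    if noise == 0.0:
        return math.inf  # identical inputs: no quantization noise at all
    return 10.0 * math.log10(signal / noise)


def mse(a, b):
    """Mean squared error between two equally sized outputs."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


high_precision = [3.0, 4.0]
kernel_1_out = [3.0, 4.5]  # hypothetical quantized output
kernel_2_out = [3.0, 4.5]  # hypothetical output of a second kernel

quality_db = sqnr_db(high_precision, kernel_1_out)
kernel_gap = mse(kernel_1_out, kernel_2_out)
```

SQNR answers how much signal survives quantization relative to a high-precision reference, while MSE (or a direct comparison) between two quantized outputs answers whether two kernels agree with each other.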

block_col_offs = block_col_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
# Force int64 math for pointer offsets (avoid i32 overflow on large problems)
block_col_offs = (block_col_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)).to(tl.int64)
stride_input_row_i64 = tl.full((), stride_input_row, tl.int64)

Contributor

rather than doing this you can just define the type to be tl.int64 in the kernel function signature i believe. See how _triton_fp8_rowwise_3d_transpose_scales_rhs_kernel does it. we also use int64 there to prevent a similar overflow issue.

Comment thread torchao/testing/utils.py
if message:
skip_message += f": {message}"
pytest.skip(skip_message)
raise unittest.SkipTest(skip_message)

@danielvegamyhre danielvegamyhre Jan 14, 2026

Contributor

why are you switching to unittest from pytest here? would prefer not to do that please

Contributor Author

This isn’t a framework preference change. ROCm doesn’t support nvfp4, so we need to skip ROCm in test_nvfp4.py. In that file, TestComm inherits from FSDPTest, which is a unittest.TestCase-style test. Therefore the skip needs to be implemented via unittest.SkipTest; pytest-style skip/markers won’t take effect under the unittest runner. Pytest can still collect and honor unittest.SkipTest, so this remains compatible with both runners.
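The behaviour described here can be checked with the standard library alone. The sketch below is a hypothetical reconstruction of a skip decorator in the spirit of the one under review: `is_rocm` is a stand-in for a real platform check, and the message text is assumed rather than copied from the repository.

```python
import functools
import unittest


def is_rocm() -> bool:
    """Hypothetical stand-in for a real ROCm detection helper."""
    return True


def skip_if_rocm(message: str = ""):
    """Skip the decorated test on ROCm by raising unittest.SkipTest."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if is_rocm():
                skip_message = "Skipping the test in ROCm"
                if message:
                    skip_message += f": {message}"
                raise unittest.SkipTest(skip_message)
            return func(*args, **kwargs)
        return wrapper
    return decorator


class ExampleTest(unittest.TestCase):
    @skip_if_rocm("nvfp4 not supported")
    def test_something(self):
        self.fail("should not run on ROCm")


def run_example() -> unittest.TestResult:
    """Run ExampleTest under the plain unittest machinery and return the result."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(ExampleTest)
    result = unittest.TestResult()
    suite.run(result)
    return result


result = run_example()
```

Under the plain unittest machinery the test is recorded as skipped rather than failed. pytest also treats `unittest.SkipTest` as a skip, so one exception type covers both runners.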

@xiaobochen-amd

Contributor Author

@vkuzo @danielvegamyhre

I reproduced this in a CUDA environment, and the conclusion is that this test does not pass the torch.equal check on CUDA either.

Environment:

  • GPU: H200
  • Docker image: pytorch/pytorch:2.9.1-cuda12.8-cudnn9-devel
  • CUDA: 12.8
  • Driver: 570.172.08
  • TorchAO commit: 21acb9c

Repro steps:

pytest ./test/prototype/moe_training/test_scaled_grouped_mm.py::test_valid_scaled_grouped_mm_2d_3d

Results:

All 4 cases fail. The failures include both:

  • the equal-related numerical comparison issue we have been discussing, and
  • the Triton int32 overflow issue

Error Log:
h200_test_error_log.txt

@pytest.mark.parametrize("m", [131072])
@pytest.mark.parametrize("n", [8192])
@pytest.mark.parametrize("k", [5120])
@pytest.mark.parametrize("m", [256, 1024, 4096, 131072])

Contributor

this is a lot of combinations across mkn, does the test finish in a reasonable amount of time?

I'd recommend limiting to max of 3-5 combinations / O(low seconds) runtime instead of having 4 * 3 * 3 here

# - scaled_grouped_mm uses CK
# These do not guarantee identical FP8 compute/accumulation behavior (e.g. accumulation order),
# and this test is very large (deep accumulation), so rounding differences can accumulate.
assert torch.allclose(out, ref_out, rtol=1e-2, atol=1e-2)

Contributor

we have added small test cases but did not make the tolerance tighter for them, is that expected? this doesn't seem to match the comment "this test is very large (deep accumulation), so rounding differences can accumulate."

# FP8 matmul allows some error due to precision limits and accumulation order differences.
# Tested with M=131072, K=5120, N=8192, bfloat16 output:
# 99.9986% of points have error < 0.1, max error ~2 (-262 vs -260, relative error ~0.77%)
assert torch.allclose(result1, ref_group_result1, rtol=1e-2, atol=1e-2)

Contributor

can we make tolerance tighter for small problem sizes

@xiaobochen-amd

xiaobochen-amd commented Jan 21, 2026

Contributor Author

@vkuzo 1e-2 is already relatively strict for FP8. Tightening it further would cause some test cases to fail. Typically, BF16/FP16 use tighter tolerances, such as 1e-3 or 1e-4. In low-precision computation, as the problem size increases, the tolerance typically needs to be moderately relaxed.
If you print the mean absolute error and mean relative error for different shapes, you can observe that the error increases as the shape grows.

@xiaobochen-amd

Contributor Author

> @vkuzo 1e-2 is already relatively strict for FP8. Tightening it further would cause some test cases to fail. Typically, BF16/FP16 use tighter tolerances, such as 1e-3 or 1e-4. In low-precision computation, as the problem size increases, the tolerance typically needs to be moderately relaxed. If you print the mean absolute error and mean relative error for different shapes, you can observe that the error increases as the shape grows.

@vkuzo I also tested on H200, and it did not pass either.

and in column-major memory layout.
offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
float8_dtype (torch.dtype): The float8 dtype to use for quantization. Default is torch.float8_e4m3fn.

Contributor

this may need to be renamed later, but doesn't have to be in this PR

@vkuzo vkuzo left a comment

Contributor

lgtm, thank you!

@vkuzo vkuzo added the topic: improvement and ciflow/rocm-mi300 labels Jan 22, 2026
@pytorch-bot

pytorch-bot Bot commented Jan 22, 2026

Unknown label ciflow/rocm-mi300.
Currently recognized labels are

  • ciflow/benchmark
  • ciflow/tutorials
  • ciflow/rocm
  • ciflow/4xh100
  • ciflow/xpu

@pytorch-bot

pytorch-bot Bot commented Jan 22, 2026

To add the ciflow label ciflow/rocm please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot pytorch-bot Bot removed the ciflow/rocm label Jan 22, 2026
@vkuzo vkuzo merged commit 2540ac4 into pytorch:main Jan 23, 2026
23 of 24 checks passed
danielvegamyhre added a commit that referenced this pull request Jan 23, 2026
@danielvegamyhre

Contributor

@xiaobochen-amd unfortunately we need to revert this as it seems to have broken MoE training because some callsites haven't been properly updated to correctly handle the new param. Starting next week we'll support emulated mode to run these tests in CI, but in the meantime can you please fix then run the tests locally before submitting? thanks!

danielvegamyhre added a commit that referenced this pull request Jan 23, 2026
Revert "[rocm] scaled_grouped_mm support gfx942 fp8 data type (#3540)"

This reverts commit 2540ac4.
@xiaobochen-amd

Contributor Author

> @xiaobochen-amd unfortunately we need to revert this as it seems to have broken MoE training because some callsites haven't been properly updated to correctly handle the new param. Starting next week we'll support emulated mode to run these tests in CI, but in the meantime can you please fix then run the tests locally before submitting? thanks!

@danielvegamyhre Thanks for the heads up. Could you please point out which callsites are not updated correctly, or share the failing logs? That would help me fix it much faster.

@danielvegamyhre

Contributor

> > @xiaobochen-amd unfortunately we need to revert this as it seems to have broken MoE training because some callsites haven't been properly updated to correctly handle the new param. Starting next week we'll support emulated mode to run these tests in CI, but in the meantime can you please fix then run the tests locally before submitting? thanks!
>
> @danielvegamyhre Thanks for the heads up. Could you please point out which callsites are not updated correctly, or share the failing logs? That would help me fix it much faster.

Sure, the most important one is here:

return _MXFP8GroupedMM.apply(

The unit tests you'll want are:

  • test/prototype/moe_training/test_training.py
  • test/prototype/moe_training/test_scaled_grouped_mm.py
  • test/prototype/moe_training/ep/test_integration.py
  • test/prototype/moe_training/ep/test_compile.py

Please also double check benchmarks in benchmarks/prototype/moe_training

@danielvegamyhre

Contributor

@xiaobochen-amd you'll also need to update the number of gradients returned from the autograd func, now that you're adding a new parameter to forward().

I have a draft PR I made earlier to do the minimal changes to get it working locally (and fix an unrelated issue) if you want to use it to get started. There may be other changes needed for all tests to pass though: #3712

Labels: CLA Signed, device: rocm, topic: improvement
