[ROCm] Enable FSDP2 Float8 and affine quantized tensor parallel tests on ROCm#3992
Merged
danielvegamyhre merged 3 commits intoMar 11, 2026
Merged
Conversation
… on ROCm Remove blanket ROCm test skips and fix FP8 hardware capability gates to support AMD MI300/MI350 GPUs alongside NVIDIA SM89+/SM90+. test/float8/test_fsdp2/test_fsdp2.py: - Replace dual module-level skip (is_sm_at_least_89 + ROCm skip) with a single gate: is_sm_at_least_89() or is_MI300() or is_MI350() - Import e4m3_dtype from config and use it in test_amax_allreduce_device_mesh instead of hardcoded torch.float8_e4m3fn (MI300 uses float8_e4m3fnuz) test/dtypes/test_affine_quantized_tensor_parallel.py: - Remove module-level pytest.skip on ROCm that blocked all TP tests (Int8wo, Int4wo, Int8dq) even though they have no FP8 dependency - Fix Float8 TP class gate: use is_sm_at_least_90() instead of raw get_device_capability() >= (9, 0), which incorrectly passes on ROCm where gfx90a (MI250X) maps to (9, 0) despite lacking FP8 support Validated on MI250X (gfx90a, 8 GPUs): - FSDP2 Float8: correctly skipped (MI250X lacks FP8) - Affine quantized TP: 4 passed, 2 skipped (Int8wo 3/3, Int8dq 1/1) - Float8 TP classes correctly not defined on non-FP8 hardware
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3992
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New FailureAs of commit dd82a6b with merge base f04500f ( NEW FAILURE - The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Contributor
|
@brucechanglongxu please fix ruff linter error fyi @jerryzh168 as this touches AQT test as well |
…ensor_parallel.py The pytest import was left over after removing the module-level pytest.skip on ROCm.
Contributor
|
please fix ruff linter @brucechanglongxu |
danielvegamyhre
approved these changes
Mar 10, 2026
danielvegamyhre
approved these changes
Mar 11, 2026
danielvegamyhre
pushed a commit
that referenced
this pull request
Mar 11, 2026
… on ROCm (#3992) * [ROCm] Enable FSDP2 Float8 and affine quantized tensor parallel tests on ROCm Remove blanket ROCm test skips and fix FP8 hardware capability gates to support AMD MI300/MI350 GPUs alongside NVIDIA SM89+/SM90+. test/float8/test_fsdp2/test_fsdp2.py: - Replace dual module-level skip (is_sm_at_least_89 + ROCm skip) with a single gate: is_sm_at_least_89() or is_MI300() or is_MI350() - Import e4m3_dtype from config and use it in test_amax_allreduce_device_mesh instead of hardcoded torch.float8_e4m3fn (MI300 uses float8_e4m3fnuz) test/dtypes/test_affine_quantized_tensor_parallel.py: - Remove module-level pytest.skip on ROCm that blocked all TP tests (Int8wo, Int4wo, Int8dq) even though they have no FP8 dependency - Fix Float8 TP class gate: use is_sm_at_least_90() instead of raw get_device_capability() >= (9, 0), which incorrectly passes on ROCm where gfx90a (MI250X) maps to (9, 0) despite lacking FP8 support Validated on MI250X (gfx90a, 8 GPUs): - FSDP2 Float8: correctly skipped (MI250X lacks FP8) - Affine quantized TP: 4 passed, 2 skipped (Int8wo 3/3, Int8dq 1/1) - Float8 TP classes correctly not defined on non-FP8 hardware * Fix ruff F401: remove unused pytest import in test_affine_quantized_tensor_parallel.py The pytest import was left over after removing the module-level pytest.skip on ROCm. * Fix ruff format: break long pytest.skip line
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Currently these tests are blanket-skipped on ROCm even though most of them have no FP8 dependency. This patch removes the unconditional ROCm skips and replaces the hardware capability gates with proper checks that cover both CUDA (SM89+/SM90+) and ROCm (MI300/MI350).
Changes in
test/float8/test_fsdp2/test_fsdp2.py:is_sm_at_least_89, one for ROCm) into a single gate that also accepts MI300/MI350.e4m3_dtypefrom config instead of hardcodedtorch.float8_e4m3fnintest_amax_allreduce_device_mesh, since MI300 usesfloat8_e4m3fnuz.Changes in
test/dtypes/test_affine_quantized_tensor_parallel.py:pytest.skipon ROCm that was blocking all TP tests including Int8wo, Int4wo, and Int8dq which don't require FP8.is_sm_at_least_90()instead of rawget_device_capability() >= (9, 0). The old check incorrectly passes on MI250X (gfx90a reports capability 9.0 but lacks FP8).Tested on MI250X (gfx90a, 8 GPUs): FSDP2 Float8 tests correctly skip, affine quantized TP passes 4/6 (Int8wo 3/3, Int8dq 1/1, Float8 classes correctly not defined).
cc: @danielvegamyhre