Skip to content

[ROCm] Enable FSDP2 Float8 and affine quantized tensor parallel tests on ROCm#3992

Merged
danielvegamyhre merged 3 commits into
pytorch:mainfrom
brucechanglongxu:fsdp2-rocm-enablement
Mar 11, 2026
Merged

[ROCm] Enable FSDP2 Float8 and affine quantized tensor parallel tests on ROCm#3992
danielvegamyhre merged 3 commits into
pytorch:mainfrom
brucechanglongxu:fsdp2-rocm-enablement

Conversation

@brucechanglongxu

@brucechanglongxu brucechanglongxu commented Mar 4, 2026

Copy link
Copy Markdown
Contributor

Currently these tests are blanket-skipped on ROCm even though most of them have no FP8 dependency. This patch removes the unconditional ROCm skips and replaces the hardware capability gates with proper checks that cover both CUDA (SM89+/SM90+) and ROCm (MI300/MI350).

Changes in test/float8/test_fsdp2/test_fsdp2.py:

  • Collapse the two separate module-level skips (one for is_sm_at_least_89, one for ROCm) into a single gate that also accepts MI300/MI350.
  • Use e4m3_dtype from config instead of hardcoded torch.float8_e4m3fn in test_amax_allreduce_device_mesh, since MI300 uses float8_e4m3fnuz.

Changes in test/dtypes/test_affine_quantized_tensor_parallel.py:

  • Remove the module-level pytest.skip on ROCm that was blocking all TP tests including Int8wo, Int4wo, and Int8dq which don't require FP8.
  • Fix the Float8 TP class gate to use is_sm_at_least_90() instead of raw get_device_capability() >= (9, 0). The old check incorrectly passes on MI250X (gfx90a reports capability 9.0 but lacks FP8).

Tested on MI250X (gfx90a, 8 GPUs): FSDP2 Float8 tests correctly skip, affine quantized TP passes 4/6 (Int8wo 3/3, Int8dq 1/1, Float8 classes correctly not defined).

cc: @danielvegamyhre

… on ROCm

Remove blanket ROCm test skips and fix FP8 hardware capability gates
to support AMD MI300/MI350 GPUs alongside NVIDIA SM89+/SM90+.

test/float8/test_fsdp2/test_fsdp2.py:
- Replace dual module-level skip (is_sm_at_least_89 + ROCm skip) with
  a single gate: is_sm_at_least_89() or is_MI300() or is_MI350()
- Import e4m3_dtype from config and use it in test_amax_allreduce_device_mesh
  instead of hardcoded torch.float8_e4m3fn (MI300 uses float8_e4m3fnuz)

test/dtypes/test_affine_quantized_tensor_parallel.py:
- Remove module-level pytest.skip on ROCm that blocked all TP tests
  (Int8wo, Int4wo, Int8dq) even though they have no FP8 dependency
- Fix Float8 TP class gate: use is_sm_at_least_90() instead of raw
  get_device_capability() >= (9, 0), which incorrectly passes on ROCm
  where gfx90a (MI250X) maps to (9, 0) despite lacking FP8 support

Validated on MI250X (gfx90a, 8 GPUs):
- FSDP2 Float8: correctly skipped (MI250X lacks FP8)
- Affine quantized TP: 4 passed, 2 skipped (Int8wo 3/3, Int8dq 1/1)
- Float8 TP classes correctly not defined on non-FP8 hardware
@pytorch-bot

pytorch-bot Bot commented Mar 4, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3992

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit dd82a6b with merge base f04500f (image):

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 4, 2026
@danielvegamyhre danielvegamyhre self-requested a review March 6, 2026 21:18
@danielvegamyhre

Copy link
Copy Markdown
Contributor

@brucechanglongxu please fix ruff linter error

fyi @jerryzh168 as this touches AQT test as well

…ensor_parallel.py

The pytest import was left over after removing the module-level
pytest.skip on ROCm.
@danielvegamyhre

Copy link
Copy Markdown
Contributor

please fix ruff linter @brucechanglongxu

@danielvegamyhre danielvegamyhre merged commit 605a22e into pytorch:main Mar 11, 2026
18 of 19 checks passed
danielvegamyhre pushed a commit that referenced this pull request Mar 11, 2026
… on ROCm (#3992)

* [ROCm] Enable FSDP2 Float8 and affine quantized tensor parallel tests on ROCm

Remove blanket ROCm test skips and fix FP8 hardware capability gates
to support AMD MI300/MI350 GPUs alongside NVIDIA SM89+/SM90+.

test/float8/test_fsdp2/test_fsdp2.py:
- Replace dual module-level skip (is_sm_at_least_89 + ROCm skip) with
  a single gate: is_sm_at_least_89() or is_MI300() or is_MI350()
- Import e4m3_dtype from config and use it in test_amax_allreduce_device_mesh
  instead of hardcoded torch.float8_e4m3fn (MI300 uses float8_e4m3fnuz)

test/dtypes/test_affine_quantized_tensor_parallel.py:
- Remove module-level pytest.skip on ROCm that blocked all TP tests
  (Int8wo, Int4wo, Int8dq) even though they have no FP8 dependency
- Fix Float8 TP class gate: use is_sm_at_least_90() instead of raw
  get_device_capability() >= (9, 0), which incorrectly passes on ROCm
  where gfx90a (MI250X) maps to (9, 0) despite lacking FP8 support

Validated on MI250X (gfx90a, 8 GPUs):
- FSDP2 Float8: correctly skipped (MI250X lacks FP8)
- Affine quantized TP: 4 passed, 2 skipped (Int8wo 3/3, Int8dq 1/1)
- Float8 TP classes correctly not defined on non-FP8 hardware

* Fix ruff F401: remove unused pytest import in test_affine_quantized_tensor_parallel.py

The pytest import was left over after removing the module-level
pytest.skip on ROCm.

* Fix ruff format: break long pytest.skip line
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. topic: rocm

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants