
[TRITON] Add FP8 support for gfx1200/gfx1201 #2621

Merged

brunomazzottiamd merged 1 commit into ROCm:main from 0xDELUXA:rdna4-fp8-support on Apr 9, 2026

Conversation

@0xDELUXA (Contributor) commented Apr 5, 2026

Motivation

RDNA4 GPUs (gfx1200, gfx1201) natively support FP8 operations but are not recognized as FP8-capable in aiter, so attempting to use FP8 with Flash Attention 3 fails with "RuntimeError: gfx1200 does not support FP8".

Technical Details

RDNA4 uses the standard IEEE/OCP float8_e4m3fn and float8_e5m2 formats, identical to gfx950 (MI350X), not the FNUZ variants used by gfx942 (MI300X). No dtype replacement mapping is needed since the hardware natively supports the standard formats.

Three files changed (a rough sketch of the gating change follows the list):

  • aiter/ops/triton/utils/_triton/arch_info.py: add gfx1200/gfx1201 to is_fp8_avail()
  • aiter/ops/triton/utils/types.py: add gfx1200/gfx1201 to get_fp8_dtypes() and get_fp8_e4m3_dtype()
  • aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/utils.py: add gfx1200/gfx1201 to FP8_ARCHS
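
The gist of these edits is just to include the two new targets in the existing FP8 gating. The following is a minimal illustrative sketch, not the actual aiter code; the helper signatures and bodies shown here are assumptions made for clarity only:

import torch

# Hypothetical sketch: the real helpers live in arch_info.py / types.py and
# are structured differently; only the arch gating idea is illustrated here.
_FP8_OCP_ARCHS = ("gfx950", "gfx1200", "gfx1201")   # standard OCP e4m3fn / e5m2
_FP8_FNUZ_ARCHS = ("gfx942",)                        # FNUZ variants (MI300X)

def is_fp8_avail(arch: str) -> bool:
    # gfx1200/gfx1201 are now reported as FP8-capable.
    return arch in _FP8_OCP_ARCHS + _FP8_FNUZ_ARCHS

def get_fp8_dtypes(arch: str):
    # RDNA4 gets the standard dtypes, same as gfx950; no FNUZ remapping needed.
    if arch in _FP8_OCP_ARCHS:
        return torch.float8_e4m3fn, torch.float8_e5m2
    return torch.float8_e4m3fnuz, torch.float8_e5m2fnuz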

Test Plan

Tested on AMD Radeon RX 9060 XT (gfx1200) running TheRock ROCm 7.13 on Windows via the FA3 interface:

import sys
sys.path.insert(0, 'path/to/flash-attention/hopper')  # FA3 Triton interface
import torch
from flash_attn_interface import _flash_attn_forward, _flash_attn_backward

# FP8 e4m3 inputs, created in bf16 and cast down.
b, s, h, d = 2, 512, 8, 128
dtype = torch.float8_e4m3fn
q = torch.randn(b, s, h, d, device='cuda', dtype=torch.bfloat16).to(dtype)
k = torch.randn(b, s, h, d, device='cuda', dtype=torch.bfloat16).to(dtype)
v = torch.randn(b, s, h, d, device='cuda', dtype=torch.bfloat16).to(dtype)
# Per-(batch, head) descale factors for the FP8 path.
q_descale = torch.ones(b, h, dtype=torch.float32, device='cuda')
k_descale = torch.ones(b, h, dtype=torch.float32, device='cuda')
v_descale = torch.ones(b, h, dtype=torch.float32, device='cuda')

# Forward pass.
out, lse, _, _ = _flash_attn_forward(q, k, v, softmax_scale=d**-0.5, causal=True,
    window_size_left=-1, window_size_right=-1,
    q_descale=q_descale, k_descale=k_descale, v_descale=v_descale)

# Backward pass, accumulating gradients in fp32.
do = torch.randn_like(out)
dq = torch.zeros_like(q, dtype=torch.float32)
dk = torch.zeros_like(k, dtype=torch.float32)
dv = torch.zeros_like(v, dtype=torch.float32)
_flash_attn_backward(do, q, k, v, out, lse, dq=dq, dk=dk, dv=dv,
    softmax_scale=d**-0.5, is_causal=True, window_size_left=-1, window_size_right=-1)

Test Result

Forward: OK, out=torch.Size([2, 512, 8, 128]), dtype=torch.float32
Backward: OK, dq=torch.Size([2, 512, 8, 128]) torch.float32, dk=torch.Size([2, 512, 8, 128]), dv=torch.Size([2, 512, 8, 128])

Note: torch.nn.functional.scaled_dot_product_attention (SDPA Flash via AOTriton) does not support FP8 on RDNA4; it fails with "mul_cuda" not implemented for 'Float8_e4m3fn'. FA3 via the Triton backend would be a viable path for FP8 attention on RDNA4.
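
For reference, the SDPA limitation above corresponds to a call of this shape (a minimal repro sketch; tensor shapes are illustrative, and the error shown is the one reported above):

import torch
import torch.nn.functional as F

# Illustrative (batch, heads, seq_len, head_dim) shapes; values are arbitrary.
q = torch.randn(2, 8, 512, 128, device='cuda', dtype=torch.bfloat16).to(torch.float8_e4m3fn)
k = torch.randn(2, 8, 512, 128, device='cuda', dtype=torch.bfloat16).to(torch.float8_e4m3fn)
v = torch.randn(2, 8, 512, 128, device='cuda', dtype=torch.bfloat16).to(torch.float8_e4m3fn)

# On RDNA4 this fails with: "mul_cuda" not implemented for 'Float8_e4m3fn'
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)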

Although FA3 is not officially targeted at RDNA architectures, it can be successfully built. Benchmarking on RDNA4 indicates that its performance is roughly on par with FA2.

Submission Checklist

@0xDELUXA 0xDELUXA requested a review from a team April 5, 2026 21:56
@github-actions bot commented Apr 5, 2026

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label            Tests
ci:triton-355    Run Triton tests on MI355 in addition to MI325
ci:sglang        SGLang integration tests
ci:atom          ATOM benchmark (DeepSeek-R1 + GPT-OSS)
ci:vllm          vLLM benchmark
ci:all           All of the above

Add labels via the sidebar or gh pr edit 2621 --add-label <label>
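
For example, gh pr edit 2621 --add-label ci:triton-355 would opt this PR into the additional MI355 Triton run.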

@0xDELUXA 0xDELUXA force-pushed the rdna4-fp8-support branch 2 times, most recently from c8f9e9d to 2b4a6ef on April 5, 2026 23:03
@0xDELUXA 0xDELUXA changed the title from [RDNA4] Add FP8 support for gfx1200/gfx1201 to [TRITON] Add FP8 support for gfx1200/gfx1201 on Apr 5, 2026
@brunomazzottiamd (Contributor) left a comment

LGTM! Let's wait for a thumbs up from @micmelesse before merging.

Comment thread on aiter/ops/triton/utils/_triton/arch_info.py (outdated):
@brunomazzottiamd (Contributor) commented:

Can you please rebase on top of latest main? I'm seeing some UT failures that were already fixed. Thanks!

@0xDELUXA 0xDELUXA force-pushed the rdna4-fp8-support branch from 2b4a6ef to 61711e4 on April 9, 2026 15:14
@0xDELUXA (Contributor, Author) commented Apr 9, 2026

Can you please rebase on top of latest main? I'm seeing some UT failures that were already fixed. Thanks!

Done. Thanks for the review!

@0xDELUXA 0xDELUXA force-pushed the rdna4-fp8-support branch 2 times, most recently from 437d7fd to b5e7be5 on April 9, 2026 15:48
@0xDELUXA 0xDELUXA force-pushed the rdna4-fp8-support branch from b5e7be5 to e4296d3 on April 9, 2026 15:49
@micmelesse (Contributor) left a comment

This looks good to me. Let's run it through CI and, if everything is green, we can merge it.

@brunomazzottiamd brunomazzottiamd merged commit bbd6ef1 into ROCm:main Apr 9, 2026
44 checks passed