Skip to content

Use relaxed memory ordering for Triton atomics on AMDGPU.#3945

Merged
danielvegamyhre merged 1 commit into
pytorch:mainfrom
wenchenvincent:feat/relaxed_atomics_amdgpu
Feb 27, 2026
Merged

Use relaxed memory ordering for Triton atomics on AMDGPU.#3945
danielvegamyhre merged 1 commit into
pytorch:mainfrom
wenchenvincent:feat/relaxed_atomics_amdgpu

Conversation

@wenchenvincent

Copy link
Copy Markdown
Contributor

tl.atomic_add and the like needs to use relaxed memory ordering on AMDGPU for performance. When the default acquire-release semantic is used, memory fence will be inserted before and after the atomics op and thus hurts performance. Such memory fences are not necessary for the functionality of atomic_add and the like (they usually required for tl.atomic_xchg).

Here is an example result when running the benchmarking for
python benchmarks/prototype/moe_training/fp8_rowwise/bench_triton_fp8_rowwise_3d_transpose_rhs.py

Before the changes:
image
After the change:
image

@pytorch-bot

pytorch-bot Bot commented Feb 25, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3945

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 14d3791 with merge base be10b2d (image):

BROKEN TRUNK - The following job failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Feb 25, 2026
@danielvegamyhre danielvegamyhre self-requested a review February 25, 2026 02:41
@danielvegamyhre

Copy link
Copy Markdown
Contributor

Interesting, I'm surprised the performance impact is so outsized on AMD but not CUDA. I just tested these changes locally on b200 and the results were virtually identical.

Thanks for improving this!

Please use ruff check --fix <dirs> and ruff format <dirs> to fix the linter error

@wenchenvincent wenchenvincent force-pushed the feat/relaxed_atomics_amdgpu branch from 9628926 to a95ca6c Compare February 25, 2026 03:37
@danielvegamyhre danielvegamyhre added topic: improvement Use this tag if this PR is an improvement (doesn't fit into any of the other categories) module: training quantize_ api training flow moe ciflow/rocm topic: performance Use this tag if this PR improves the performance of a feature labels Feb 25, 2026
Comment thread torchao/prototype/hqq/kernels.py Outdated
"SPLIT_K": lambda args: 1
if args["IS_BFLOAT16"]
else args["SPLIT_K"], # atomic add not supported for bfloat16
"BLOCK_K": lambda args: (

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

odd this formatting was changed, is your ruff version freshly pip installed from requirements.txt

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I directly installed ruff with pip without installing it from dev-requirements.txt. It seems that the ruff version that I used was 0.15.2, which was newer than 0.11.6. Now I reintalled ruff 0.11.6 and rerun formatting on this file but it didn't make it to the original format. Let me know what you would like to do it. Shall I revert to the original format or can we keep this format change as is?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see yeah just revert the formattting changes to this file please

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

brucechanglongxu added a commit to brucechanglongxu/ao that referenced this pull request Feb 25, 2026
…kernels

The existing autotune configs for the MoE training FP8 kernels use a
single configuration each (e.g., num_warps=4, num_stages=4, one block
size), which prevents Triton's autotuner from finding better configs
for different hardware targets.

Expand the search space to cover:
- Multiple num_warps values (4, 8) to better saturate both NVIDIA
  (warp size 32) and AMD (wavefront size 64) GPU compute units
- Multiple num_stages values for software pipelining flexibility
  across different cache hierarchies
- Multiple block sizes to adapt to varying matrix dimensions

This is complementary to PR pytorch#3945 (relaxed atomics on AMDGPU) and
targets the same kernels.
@wenchenvincent wenchenvincent force-pushed the feat/relaxed_atomics_amdgpu branch from a95ca6c to bf0e89a Compare February 26, 2026 05:12
@pytorch-bot pytorch-bot Bot removed the ciflow/rocm label Feb 26, 2026
@danielvegamyhre

Copy link
Copy Markdown
Contributor

@wenchenvincent looks like there is still an issue:

torchao/prototype/hqq/kernels.py:393:25: F821 Undefined name `torch`
    |
391 |     else:
392 |         # AMD GPUs need relaxed semantics for better performance
393 |         if tl.constexpr(torch.version.hip is not None):
    |                         ^^^^^ F821
394 |             tl.atomic_add(C, acc, mask=mask, sem="relaxed")
395 |         else:
    |
    ```

@wenchenvincent wenchenvincent force-pushed the feat/relaxed_atomics_amdgpu branch from bf0e89a to 14d3791 Compare February 27, 2026 00:31
@wenchenvincent

Copy link
Copy Markdown
Contributor Author

@wenchenvincent looks like there is still an issue:

torchao/prototype/hqq/kernels.py:393:25: F821 Undefined name `torch`
    |
391 |     else:
392 |         # AMD GPUs need relaxed semantics for better performance
393 |         if tl.constexpr(torch.version.hip is not None):
    |                         ^^^^^ F821
394 |             tl.atomic_add(C, acc, mask=mask, sem="relaxed")
395 |         else:
    |
    ```

Sorry, my bad! Had an omission when reverting and adding the changes. It should be fixed now.

@danielvegamyhre danielvegamyhre merged commit a4ae9cc into pytorch:main Feb 27, 2026
18 of 19 checks passed
danielvegamyhre pushed a commit that referenced this pull request Feb 27, 2026
… performance (#3952)

* Expand Triton autotune configs for MoE FP8 rowwise and jagged scales kernels

The existing autotune configs for the MoE training FP8 kernels use a
single configuration each (e.g., num_warps=4, num_stages=4, one block
size), which prevents Triton's autotuner from finding better configs
for different hardware targets.

Expand the search space to cover:
- Multiple num_warps values (4, 8) to better saturate both NVIDIA
  (warp size 32) and AMD (wavefront size 64) GPU compute units
- Multiple num_stages values for software pipelining flexibility
  across different cache hierarchies
- Multiple block sizes to adapt to varying matrix dimensions

This is complementary to PR #3945 (relaxed atomics on AMDGPU) and
targets the same kernels.

* Gate expanded autotune configs to AMD only, preserve original NVIDIA configs

H100 benchmarks showed ~18% regression on the atomic kernel with the
expanded search space. The autotuner appears to pick suboptimal configs
from the larger candidate set on NVIDIA. Gate the expanded configs
behind torch.version.hip so AMD gets the broader search (4-7% faster
on MI250X) while NVIDIA keeps the original tuned configs.

* Widen autotune search space and add N_GROUPS to scales kernel autotuning key

Two improvements based on MI300X (gfx942) benchmarking:

1. float8_rowwise.py: Widen block size search space for AMD GPUs.
   - Atomic configs: add BLOCK_SIZE_N=256 and BLOCK_SIZE_K=64
   - Reduction configs: add BLOCK_SIZE_N=128, BLOCK_SIZE_K=64, and num_stages=2,4
   - Yields 1.5-2.2x speedup on MI300X for the atomic kernel and
     1.05-1.25x for the reduction kernel across Llama4 MoE shapes.

2. jagged_float8_scales.py: Add N_GROUPS to autotuning key for both
   rowwise and colwise scales kernels. The previous key (M or K only)
   caused the autotuner to cache a single config across all n_groups
   values, but optimal tile sizes differ significantly by n_groups.
   This eliminates cross-n_groups interference and allows each n_groups
   value to independently find its best config.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. module: training quantize_ api training flow moe topic: improvement Use this tag if this PR is an improvement (doesn't fit into any of the other categories) topic: performance Use this tag if this PR improves the performance of a feature

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants