Remove MoE token padding paths #2475
| torch._check(num_actual_tokens == x.shape[0]) | ||
| else: | ||
| torch._check(num_actual_tokens >= 0) | ||
| torch._check(num_actual_tokens <= x.shape[0]) |
we could delete this? only case supported should be FP8 on SM89
| ValidTokenGroupAlignmentSize = Literal[8, 16, 32] | ||
| # TODO(pianpwk): Consider removing padding paths entirely once HybridEP integration lands, | ||
| # moving padding to communication layer. | ||
| TOKEN_GROUP_ALIGN_SIZE_M = 1 |
@shuhuayu is doing final validation on HybridEP. Once that lands, how about we remove padding in torchtitan entirely? We would also need to remove the special handling in https://github.com/pytorch/torchtitan/pull/2470/changes#diff-1a1b8aa2c436eb5db983e7327f90645990e906b786ec3b7ea96d77b7050416c8R77
The only thing to keep is an arg to adjust the alignment size when calling HybridEP + mxfp8, which you can set at config time, or pass in during runtime (similar to https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama4/parallelize.py#L129)
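A minimal sketch of what "an arg to adjust the alignment size" could look like if set at config time. The `PaddingConfig` class, its `token_group_alignment_size` field, and the `round_up` helper are hypothetical names used for illustration, not torchtitan or HybridEP APIs; only the `Literal[8, 16, 32]` set is taken from the diff above.

```python
from dataclasses import dataclass
from typing import Literal, get_args

# Allowed alignment sizes, as in the diff above.
ValidTokenGroupAlignmentSize = Literal[8, 16, 32]


@dataclass(frozen=True)
class PaddingConfig:
    """Hypothetical config holder: the only padding knob that would remain."""

    token_group_alignment_size: int = 8

    def __post_init__(self) -> None:
        allowed = get_args(ValidTokenGroupAlignmentSize)
        if self.token_group_alignment_size not in allowed:
            raise ValueError(
                f"alignment must be one of {allowed}, "
                f"got {self.token_group_alignment_size}"
            )


def round_up(n: int, multiple: int) -> int:
    """Smallest multiple of `multiple` that is >= n."""
    return (n + multiple - 1) // multiple * multiple


# Example: with mxfp8 the dispatch layer would be told to pad groups to 32.
cfg = PaddingConfig(token_group_alignment_size=32)
padded = round_up(33, cfg.token_group_alignment_size)
```

Under this assumption the model code carries a single integer, and whichever layer performs the dispatch decides how to apply it.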
Before:

```
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --debug.seed 10
[rank0]:[titan] 2026-03-03 15:24:02,281 - root - INFO - step: 1 loss: 8.06832 grad_norm: 1.4637 memory: 1.02GiB(1.08%) tps: 1,191 tflops: 0.09 mfu: 0.01%
[rank0]:[titan] 2026-03-03 15:24:02,281 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2026-03-03 15:24:02,458 - root - INFO - step: 2 loss: 7.74952 grad_norm: 1.5518 memory: 1.07GiB(1.13%) tps: 92,412 tflops: 6.62 mfu: 0.67%
[rank0]:[titan] 2026-03-03 15:24:02,506 - root - INFO - step: 3 loss: 6.99885 grad_norm: 2.0142 memory: 1.07GiB(1.13%) tps: 350,428 tflops: 25.09 mfu: 2.54%
[rank0]:[titan] 2026-03-03 15:24:02,553 - root - INFO - step: 4 loss: 6.13972 grad_norm: 2.4066 memory: 1.07GiB(1.13%) tps: 350,777 tflops: 25.11 mfu: 2.54%
[rank0]:[titan] 2026-03-03 15:24:02,599 - root - INFO - step: 5 loss: 5.25505 grad_norm: 2.6350 memory: 1.07GiB(1.13%) tps: 357,529 tflops: 25.59 mfu: 2.59%
[rank0]:[titan] 2026-03-03 15:24:02,651 - root - INFO - step: 6 loss: 4.75111 grad_norm: 2.5402 memory: 1.07GiB(1.13%) tps: 319,817 tflops: 22.89 mfu: 2.31%
[rank0]:[titan] 2026-03-03 15:24:02,696 - root - INFO - step: 7 loss: 4.42443 grad_norm: 2.4367 memory: 1.07GiB(1.13%) tps: 366,312 tflops: 26.22 mfu: 2.65%
[rank0]:[titan] 2026-03-03 15:24:02,740 - root - INFO - step: 8 loss: 4.22678 grad_norm: 2.2814 memory: 1.07GiB(1.13%) tps: 372,464 tflops: 26.66 mfu: 2.70%
[rank0]:[titan] 2026-03-03 15:24:02,787 - root - INFO - step: 9 loss: 4.25165 grad_norm: 2.0176 memory: 1.07GiB(1.13%) tps: 353,424 tflops: 25.30 mfu: 2.56%
[rank0]:[titan] 2026-03-03 15:24:02,840 - root - INFO - step: 10 loss: 4.04419 grad_norm: 2.0104 memory: 1.07GiB(1.13%) tps: 308,696 tflops: 22.10 mfu: 2.23%

CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --debug.seed 10 --parallelism.expert_parallel_degree=2
[rank0]:[titan] 2026-03-03 15:24:46,108 - root - INFO - step: 1 loss: 8.06832 grad_norm: 1.4637 memory: 1.02GiB(1.08%) tps: 1,091 tflops: 0.08 mfu: 0.01%
[rank0]:[titan] 2026-03-03 15:24:46,108 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2026-03-03 15:24:46,225 - root - INFO - step: 2 loss: 7.74951 grad_norm: 1.5518 memory: 1.07GiB(1.13%) tps: 140,882 tflops: 10.09 mfu: 1.02%
[rank0]:[titan] 2026-03-03 15:24:46,282 - root - INFO - step: 3 loss: 6.99885 grad_norm: 2.0142 memory: 1.07GiB(1.13%) tps: 289,365 tflops: 20.71 mfu: 2.09%
[rank0]:[titan] 2026-03-03 15:24:46,338 - root - INFO - step: 4 loss: 6.13974 grad_norm: 2.4065 memory: 1.07GiB(1.13%) tps: 295,195 tflops: 21.13 mfu: 2.14%
[rank0]:[titan] 2026-03-03 15:24:46,392 - root - INFO - step: 5 loss: 5.25506 grad_norm: 2.6350 memory: 1.07GiB(1.13%) tps: 303,956 tflops: 21.76 mfu: 2.20%
[rank0]:[titan] 2026-03-03 15:24:46,452 - root - INFO - step: 6 loss: 4.75112 grad_norm: 2.5402 memory: 1.07GiB(1.13%) tps: 276,643 tflops: 19.80 mfu: 2.00%
[rank0]:[titan] 2026-03-03 15:24:46,498 - root - INFO - step: 7 loss: 4.42442 grad_norm: 2.4368 memory: 1.07GiB(1.13%) tps: 357,858 tflops: 25.62 mfu: 2.59%
[rank0]:[titan] 2026-03-03 15:24:46,543 - root - INFO - step: 8 loss: 4.22681 grad_norm: 2.2814 memory: 1.07GiB(1.13%) tps: 366,776 tflops: 26.26 mfu: 2.65%
[rank0]:[titan] 2026-03-03 15:24:46,589 - root - INFO - step: 9 loss: 4.25164 grad_norm: 2.0176 memory: 1.07GiB(1.13%) tps: 362,488 tflops: 25.95 mfu: 2.62%
[rank0]:[titan] 2026-03-03 15:24:46,643 - root - INFO - step: 10 loss: 4.04417 grad_norm: 2.0103 memory: 1.07GiB(1.13%) tps: 305,329 tflops: 21.86 mfu: 2.21%
```

After:

```
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --debug.seed 10
[rank0]:[titan] 2026-03-03 15:19:41,425 - root - INFO - step: 1 loss: 8.06832 grad_norm: 1.4637 memory: 1.02GiB(1.08%) tps: 2,393 tflops: 0.17 mfu: 0.02%
[rank0]:[titan] 2026-03-03 15:19:41,425 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:/data/users/pianpwk/torchtitan/torchtitan/distributed/utils.py:395: UserWarning: Set timeout is now only supported for either nccl or gloo.
[rank0]: torch.distributed.distributed_c10d._set_pg_timeout(timeout, group)
[rank0]:[titan] 2026-03-03 15:19:41,520 - root - INFO - step: 2 loss: 7.74951 grad_norm: 1.5518 memory: 1.07GiB(1.13%) tps: 173,417 tflops: 12.41 mfu: 1.26%
[rank0]:[titan] 2026-03-03 15:19:41,566 - root - INFO - step: 3 loss: 6.99886 grad_norm: 2.0142 memory: 1.07GiB(1.13%) tps: 359,111 tflops: 25.71 mfu: 2.60%
[rank0]:[titan] 2026-03-03 15:19:41,612 - root - INFO - step: 4 loss: 6.13973 grad_norm: 2.4065 memory: 1.07GiB(1.13%) tps: 354,734 tflops: 25.39 mfu: 2.57%
[rank0]:[titan] 2026-03-03 15:19:41,656 - root - INFO - step: 5 loss: 5.25507 grad_norm: 2.6350 memory: 1.07GiB(1.13%) tps: 375,163 tflops: 26.86 mfu: 2.72%
[rank0]:[titan] 2026-03-03 15:19:41,707 - root - INFO - step: 6 loss: 4.75113 grad_norm: 2.5402 memory: 1.07GiB(1.13%) tps: 325,100 tflops: 23.27 mfu: 2.35%
[rank0]:[titan] 2026-03-03 15:19:41,751 - root - INFO - step: 7 loss: 4.42440 grad_norm: 2.4367 memory: 1.07GiB(1.13%) tps: 377,937 tflops: 27.05 mfu: 2.74%
[rank0]:[titan] 2026-03-03 15:19:41,795 - root - INFO - step: 8 loss: 4.22678 grad_norm: 2.2815 memory: 1.07GiB(1.13%) tps: 375,085 tflops: 26.85 mfu: 2.71%
[rank0]:[titan] 2026-03-03 15:19:41,841 - root - INFO - step: 9 loss: 4.25165 grad_norm: 2.0176 memory: 1.07GiB(1.13%) tps: 359,168 tflops: 25.71 mfu: 2.60%
[rank0]:[titan] 2026-03-03 15:19:41,892 - root - INFO - step: 10 loss: 4.04418 grad_norm: 2.0104 memory: 1.07GiB(1.13%) tps: 326,346 tflops: 23.36 mfu: 2.36%

CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --debug.seed 10 --parallelism.expert_parallel_degree=2
[rank0]:[titan] 2026-03-03 15:21:22,800 - root - INFO - step: 1 loss: 8.06832 grad_norm: 1.4637 memory: 1.02GiB(1.08%) tps: 2,419 tflops: 0.17 mfu: 0.02%
[rank0]:[titan] 2026-03-03 15:21:22,800 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2026-03-03 15:21:22,869 - root - INFO - step: 2 loss: 7.74951 grad_norm: 1.5518 memory: 1.07GiB(1.13%) tps: 239,517 tflops: 17.15 mfu: 1.73%
[rank0]:[titan] 2026-03-03 15:21:22,915 - root - INFO - step: 3 loss: 6.99886 grad_norm: 2.0142 memory: 1.07GiB(1.13%) tps: 356,371 tflops: 25.51 mfu: 2.58%
[rank0]:[titan] 2026-03-03 15:21:22,962 - root - INFO - step: 4 loss: 6.13973 grad_norm: 2.4065 memory: 1.07GiB(1.13%) tps: 350,649 tflops: 25.10 mfu: 2.54%
[rank0]:[titan] 2026-03-03 15:21:23,009 - root - INFO - step: 5 loss: 5.25507 grad_norm: 2.6350 memory: 1.07GiB(1.13%) tps: 355,429 tflops: 25.44 mfu: 2.57%
[rank0]:[titan] 2026-03-03 15:21:23,065 - root - INFO - step: 6 loss: 4.75114 grad_norm: 2.5402 memory: 1.07GiB(1.13%) tps: 294,755 tflops: 21.10 mfu: 2.13%
[rank0]:[titan] 2026-03-03 15:21:23,115 - root - INFO - step: 7 loss: 4.42441 grad_norm: 2.4367 memory: 1.07GiB(1.13%) tps: 331,022 tflops: 23.70 mfu: 2.40%
[rank0]:[titan] 2026-03-03 15:21:23,161 - root - INFO - step: 8 loss: 4.22680 grad_norm: 2.2815 memory: 1.07GiB(1.13%) tps: 355,568 tflops: 25.45 mfu: 2.57%
[rank0]:[titan] 2026-03-03 15:21:23,206 - root - INFO - step: 9 loss: 4.25164 grad_norm: 2.0176 memory: 1.07GiB(1.13%) tps: 375,135 tflops: 26.85 mfu: 2.72%
[rank0]:[titan] 2026-03-03 15:21:23,255 - root - INFO - step: 10 loss: 4.04419 grad_norm: 2.0104 memory: 1.07GiB(1.13%) tps: 335,830 tflops: 24.04 mfu: 2.43%
```
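One way to read the Before/After logs above is as a loss-parity check: with the same seed, the per-step losses should agree to within noise. A small sketch of that comparison, assuming the `step: N loss: X` log format shown above. `parse_losses`, `max_loss_gap`, and the `1e-4` tolerance are illustrative choices, not part of torchtitan; the sample lines are trimmed copies of steps 1, 2, and 10 from the non-EP runs.

```python
import re

# Matches the "step: N loss: X" portion of a torchtitan-style log line.
STEP_RE = re.compile(r"step:\s*(\d+)\s+loss:\s*([0-9.]+)")


def parse_losses(log_text: str) -> dict[int, float]:
    """Map step number -> loss for every step entry found in the log text."""
    return {int(step): float(loss) for step, loss in STEP_RE.findall(log_text)}


def max_loss_gap(before: str, after: str) -> float:
    """Largest absolute loss difference over the steps present in both logs."""
    b, a = parse_losses(before), parse_losses(after)
    shared = b.keys() & a.keys()
    if not shared:
        raise ValueError("logs share no step numbers")
    return max(abs(b[s] - a[s]) for s in shared)


before_log = """
step: 1 loss: 8.06832 grad_norm: 1.4637
step: 2 loss: 7.74952 grad_norm: 1.5518
step: 10 loss: 4.04419 grad_norm: 2.0104
"""
after_log = """
step: 1 loss: 8.06832 grad_norm: 1.4637
step: 2 loss: 7.74951 grad_norm: 1.5518
step: 10 loss: 4.04418 grad_norm: 2.0104
"""

gap = max_loss_gap(before_log, after_log)
losses_match = gap < 1e-4  # tolerance is an assumption, chosen for illustration
```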
tianyu-l left a comment
We can't just do deletion. HybridEP kernels require a handle to do proper padding if mxfp8 grouped gemm is used. I would recommend sending it when applying expert parallel.
Please remove this file and instead use pytorch code -- make sure the impl doesn't introduce d2h sync. Check rakkit's PR for reference.
| padded_max_len = _round_up(x_padded_per_expert, TOKEN_GROUP_ALIGN_SIZE_M) | ||
| with torch.no_grad(): | ||
| (permuted_indices, num_tokens_per_expert, _offsets,) = generate_permute_indices( | ||
| (permuted_indices, num_tokens_per_expert, _offsets) = generate_permute_indices( |
these utils can be put into expert_parallel.py, because single device no longer needs them
Maybe a dumb question - would we still want to support fp8/mxfp8 in the non-EP case? Right now _permute is responsible for both rank->expert-major reordering (no-op for non-EP), as well as alignment padding. So we'd want to keep this path around for low-precision no-EP?
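To make the two responsibilities concrete, here is a pure-Python sketch of rank-major to expert-major reordering followed by alignment padding. It is an illustration only, not the torchtitan or torchao implementation: `expert_major_indices` is a hypothetical name, padding slots are marked with `-1`, and real kernels may treat empty groups differently.

```python
def expert_major_indices(
    tokens_per_rank_expert: list[list[int]], alignment: int
) -> tuple[list[int], list[int]]:
    """Return (permuted_indices, padded_tokens_per_expert).

    tokens_per_rank_expert[r][e] is the number of tokens that source rank r
    sent to local expert e. The input layout is rank-major: all of rank 0's
    tokens (grouped by expert), then rank 1's, and so on.
    """
    num_ranks = len(tokens_per_rank_expert)
    num_experts = len(tokens_per_rank_expert[0])

    # Start offset of every (rank, expert) chunk in the rank-major layout.
    starts: list[list[int]] = []
    offset = 0
    for r in range(num_ranks):
        row = []
        for e in range(num_experts):
            row.append(offset)
            offset += tokens_per_rank_expert[r][e]
        starts.append(row)

    permuted: list[int] = []
    padded_counts: list[int] = []
    for e in range(num_experts):
        # Gather this expert's tokens from every source rank.
        group: list[int] = []
        for r in range(num_ranks):
            s = starts[r][e]
            group.extend(range(s, s + tokens_per_rank_expert[r][e]))
        # Pad the group up to the next multiple of `alignment`.
        padded_len = -(-len(group) // alignment) * alignment
        group.extend([-1] * (padded_len - len(group)))
        permuted.extend(group)
        padded_counts.append(padded_len)
    return permuted, padded_counts


# Two source ranks, two local experts, groups aligned to 4 tokens.
ep_indices, ep_counts = expert_major_indices([[2, 1], [1, 3]], alignment=4)

# Single rank (non-EP) with alignment 1: the permutation is the identity.
no_ep_indices, no_ep_counts = expert_major_indices([[2, 3]], alignment=1)
```

With a single source rank the reordering step does nothing, so padding is the only reason a non-EP low-precision path would still need this logic.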
for what it's worth i think not using EP for training MoEs is fairly uncommon (@tianyu-l correct me if i am wrong here) but we should still support it.
@pianpwk you can delete the "rank-major to expert-major + pad with triton kernel" from torchtitan and just import it from torchao here. Since it is only needed for fp8/mxfp8, we decided it makes more sense for it to live there.
should we change mxfp8.py as well?
@tianyu-l i thought hybridEP did the padding for mxfp8 grouped mm as part of the custom all2all dispatch impl?
@danielvegamyhre Yes but you have to tell it whether or not you need padding in all2all, and how much. Also, it sounds like the hybridEP all2all dispatch won't happen in mxfp8, right?
Right but that is just a number, "32" in this case, we don't need group size padding kernels / logic for that. I think I misunderstood your original comment here. Were you just saying that we can delete the padding kernels/logic but need to keep
Yeah my understanding is it happens in default precision / bf16
Yeah that's what I meant.
| # For fp8 grouped GEMM, token group sizes must be multiples of 16 | ||
| # (16 byte alignment / 1 byte per elem = 16 elements) | ||
| set_token_group_alignment_size_m(FP8_GROUP_ALIGNMENT_SIZE) |
multiple of 16 is still needed for fp8
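The arithmetic in the code comment above (16-byte alignment divided by bytes per element) can be written out as a tiny helper. This is a sketch of the calculation only; `min_group_multiple` is a hypothetical name, and the 2-byte case is shown purely as the same division applied to a wider element type.

```python
ALIGNMENT_BYTES = 16  # alignment stated in the code comment above


def min_group_multiple(bytes_per_elem: int) -> int:
    """Number of elements that fill one 16-byte aligned unit."""
    if bytes_per_elem <= 0 or ALIGNMENT_BYTES % bytes_per_elem != 0:
        raise ValueError(f"unsupported element size: {bytes_per_elem}")
    return ALIGNMENT_BYTES // bytes_per_elem


fp8_multiple = min_group_multiple(1)       # 1 byte per element -> 16
two_byte_multiple = min_group_multiple(2)  # 2 bytes per element -> 8
```

The value 32 discussed for mxfp8 does not fall out of this division, so it would have to be passed in as its own setting.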
any progress on this? @pianpwk
Thanks, should we just close this and keep centralized tracking in #2620? @danielvegamyhre
@wwwjn yes, i'll close |