Skip to content

remove MoE token padding paths#2475

Closed
pianpwk wants to merge 4 commits into
gh/pianpwk/1/basefrom
gh/pianpwk/1/head
Closed

remove MoE token padding paths#2475
pianpwk wants to merge 4 commits into
gh/pianpwk/1/basefrom
gh/pianpwk/1/head

Conversation

@pianpwk

@pianpwk pianpwk commented Mar 3, 2026

Copy link
Copy Markdown
Contributor

Stack from ghstack (oldest at bottom):

Before:

CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --debug.seed 10

[rank0]:[titan] 2026-03-03 15:24:02,281 - root - INFO - step:  1  loss:  8.06832  grad_norm:  1.4637  memory:  1.02GiB(1.08%)  tps: 1,191  tflops: 0.09  mfu: 0.01%
[rank0]:[titan] 2026-03-03 15:24:02,281 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2026-03-03 15:24:02,458 - root - INFO - step:  2  loss:  7.74952  grad_norm:  1.5518  memory:  1.07GiB(1.13%)  tps: 92,412  tflops: 6.62  mfu: 0.67%
[rank0]:[titan] 2026-03-03 15:24:02,506 - root - INFO - step:  3  loss:  6.99885  grad_norm:  2.0142  memory:  1.07GiB(1.13%)  tps: 350,428  tflops: 25.09  mfu: 2.54%
[rank0]:[titan] 2026-03-03 15:24:02,553 - root - INFO - step:  4  loss:  6.13972  grad_norm:  2.4066  memory:  1.07GiB(1.13%)  tps: 350,777  tflops: 25.11  mfu: 2.54%
[rank0]:[titan] 2026-03-03 15:24:02,599 - root - INFO - step:  5  loss:  5.25505  grad_norm:  2.6350  memory:  1.07GiB(1.13%)  tps: 357,529  tflops: 25.59  mfu: 2.59%
[rank0]:[titan] 2026-03-03 15:24:02,651 - root - INFO - step:  6  loss:  4.75111  grad_norm:  2.5402  memory:  1.07GiB(1.13%)  tps: 319,817  tflops: 22.89  mfu: 2.31%
[rank0]:[titan] 2026-03-03 15:24:02,696 - root - INFO - step:  7  loss:  4.42443  grad_norm:  2.4367  memory:  1.07GiB(1.13%)  tps: 366,312  tflops: 26.22  mfu: 2.65%
[rank0]:[titan] 2026-03-03 15:24:02,740 - root - INFO - step:  8  loss:  4.22678  grad_norm:  2.2814  memory:  1.07GiB(1.13%)  tps: 372,464  tflops: 26.66  mfu: 2.70%
[rank0]:[titan] 2026-03-03 15:24:02,787 - root - INFO - step:  9  loss:  4.25165  grad_norm:  2.0176  memory:  1.07GiB(1.13%)  tps: 353,424  tflops: 25.30  mfu: 2.56%
[rank0]:[titan] 2026-03-03 15:24:02,840 - root - INFO - step: 10  loss:  4.04419  grad_norm:  2.0104  memory:  1.07GiB(1.13%)  tps: 308,696  tflops: 22.10  mfu: 2.23%


CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --debug.seed 10 --parallelism.expert_parallel_degree=2

[rank0]:[titan] 2026-03-03 15:24:46,108 - root - INFO - step:  1  loss:  8.06832  grad_norm:  1.4637  memory:  1.02GiB(1.08%)  tps: 1,091  tflops: 0.08  mfu: 0.01%
[rank0]:[titan] 2026-03-03 15:24:46,108 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2026-03-03 15:24:46,225 - root - INFO - step:  2  loss:  7.74951  grad_norm:  1.5518  memory:  1.07GiB(1.13%)  tps: 140,882  tflops: 10.09  mfu: 1.02%
[rank0]:[titan] 2026-03-03 15:24:46,282 - root - INFO - step:  3  loss:  6.99885  grad_norm:  2.0142  memory:  1.07GiB(1.13%)  tps: 289,365  tflops: 20.71  mfu: 2.09%
[rank0]:[titan] 2026-03-03 15:24:46,338 - root - INFO - step:  4  loss:  6.13974  grad_norm:  2.4065  memory:  1.07GiB(1.13%)  tps: 295,195  tflops: 21.13  mfu: 2.14%
[rank0]:[titan] 2026-03-03 15:24:46,392 - root - INFO - step:  5  loss:  5.25506  grad_norm:  2.6350  memory:  1.07GiB(1.13%)  tps: 303,956  tflops: 21.76  mfu: 2.20%
[rank0]:[titan] 2026-03-03 15:24:46,452 - root - INFO - step:  6  loss:  4.75112  grad_norm:  2.5402  memory:  1.07GiB(1.13%)  tps: 276,643  tflops: 19.80  mfu: 2.00%
[rank0]:[titan] 2026-03-03 15:24:46,498 - root - INFO - step:  7  loss:  4.42442  grad_norm:  2.4368  memory:  1.07GiB(1.13%)  tps: 357,858  tflops: 25.62  mfu: 2.59%
[rank0]:[titan] 2026-03-03 15:24:46,543 - root - INFO - step:  8  loss:  4.22681  grad_norm:  2.2814  memory:  1.07GiB(1.13%)  tps: 366,776  tflops: 26.26  mfu: 2.65%
[rank0]:[titan] 2026-03-03 15:24:46,589 - root - INFO - step:  9  loss:  4.25164  grad_norm:  2.0176  memory:  1.07GiB(1.13%)  tps: 362,488  tflops: 25.95  mfu: 2.62%
[rank0]:[titan] 2026-03-03 15:24:46,643 - root - INFO - step: 10  loss:  4.04417  grad_norm:  2.0103  memory:  1.07GiB(1.13%)  tps: 305,329  tflops: 21.86  mfu: 2.21%

After:

CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --debug.seed 10

[rank0]:[titan] 2026-03-03 15:19:41,425 - root - INFO - step:  1  loss:  8.06832  grad_norm:  1.4637  memory:  1.02GiB(1.08%)  tps: 2,393  tflops: 0.17  mfu: 0.02%
[rank0]:[titan] 2026-03-03 15:19:41,425 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:/data/users/pianpwk/torchtitan/torchtitan/distributed/utils.py:395: UserWarning: Set timeout is now only supported for either nccl or gloo.
[rank0]:  torch.distributed.distributed_c10d._set_pg_timeout(timeout, group)
[rank0]:[titan] 2026-03-03 15:19:41,520 - root - INFO - step:  2  loss:  7.74951  grad_norm:  1.5518  memory:  1.07GiB(1.13%)  tps: 173,417  tflops: 12.41  mfu: 1.26%
[rank0]:[titan] 2026-03-03 15:19:41,566 - root - INFO - step:  3  loss:  6.99886  grad_norm:  2.0142  memory:  1.07GiB(1.13%)  tps: 359,111  tflops: 25.71  mfu: 2.60%
[rank0]:[titan] 2026-03-03 15:19:41,612 - root - INFO - step:  4  loss:  6.13973  grad_norm:  2.4065  memory:  1.07GiB(1.13%)  tps: 354,734  tflops: 25.39  mfu: 2.57%
[rank0]:[titan] 2026-03-03 15:19:41,656 - root - INFO - step:  5  loss:  5.25507  grad_norm:  2.6350  memory:  1.07GiB(1.13%)  tps: 375,163  tflops: 26.86  mfu: 2.72%
[rank0]:[titan] 2026-03-03 15:19:41,707 - root - INFO - step:  6  loss:  4.75113  grad_norm:  2.5402  memory:  1.07GiB(1.13%)  tps: 325,100  tflops: 23.27  mfu: 2.35%
[rank0]:[titan] 2026-03-03 15:19:41,751 - root - INFO - step:  7  loss:  4.42440  grad_norm:  2.4367  memory:  1.07GiB(1.13%)  tps: 377,937  tflops: 27.05  mfu: 2.74%
[rank0]:[titan] 2026-03-03 15:19:41,795 - root - INFO - step:  8  loss:  4.22678  grad_norm:  2.2815  memory:  1.07GiB(1.13%)  tps: 375,085  tflops: 26.85  mfu: 2.71%
[rank0]:[titan] 2026-03-03 15:19:41,841 - root - INFO - step:  9  loss:  4.25165  grad_norm:  2.0176  memory:  1.07GiB(1.13%)  tps: 359,168  tflops: 25.71  mfu: 2.60%
[rank0]:[titan] 2026-03-03 15:19:41,892 - root - INFO - step: 10  loss:  4.04418  grad_norm:  2.0104  memory:  1.07GiB(1.13%)  tps: 326,346  tflops: 23.36  mfu: 2.36%


CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --debug.seed 10 --parallelism.expert_parallel_degree=2

[rank0]:[titan] 2026-03-03 15:21:22,800 - root - INFO - step:  1  loss:  8.06832  grad_norm:  1.4637  memory:  1.02GiB(1.08%)  tps: 2,419  tflops: 0.17  mfu: 0.02%
[rank0]:[titan] 2026-03-03 15:21:22,800 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2026-03-03 15:21:22,869 - root - INFO - step:  2  loss:  7.74951  grad_norm:  1.5518  memory:  1.07GiB(1.13%)  tps: 239,517  tflops: 17.15  mfu: 1.73%
[rank0]:[titan] 2026-03-03 15:21:22,915 - root - INFO - step:  3  loss:  6.99886  grad_norm:  2.0142  memory:  1.07GiB(1.13%)  tps: 356,371  tflops: 25.51  mfu: 2.58%
[rank0]:[titan] 2026-03-03 15:21:22,962 - root - INFO - step:  4  loss:  6.13973  grad_norm:  2.4065  memory:  1.07GiB(1.13%)  tps: 350,649  tflops: 25.10  mfu: 2.54%
[rank0]:[titan] 2026-03-03 15:21:23,009 - root - INFO - step:  5  loss:  5.25507  grad_norm:  2.6350  memory:  1.07GiB(1.13%)  tps: 355,429  tflops: 25.44  mfu: 2.57%
[rank0]:[titan] 2026-03-03 15:21:23,065 - root - INFO - step:  6  loss:  4.75114  grad_norm:  2.5402  memory:  1.07GiB(1.13%)  tps: 294,755  tflops: 21.10  mfu: 2.13%
[rank0]:[titan] 2026-03-03 15:21:23,115 - root - INFO - step:  7  loss:  4.42441  grad_norm:  2.4367  memory:  1.07GiB(1.13%)  tps: 331,022  tflops: 23.70  mfu: 2.40%
[rank0]:[titan] 2026-03-03 15:21:23,161 - root - INFO - step:  8  loss:  4.22680  grad_norm:  2.2815  memory:  1.07GiB(1.13%)  tps: 355,568  tflops: 25.45  mfu: 2.57%
[rank0]:[titan] 2026-03-03 15:21:23,206 - root - INFO - step:  9  loss:  4.25164  grad_norm:  2.0176  memory:  1.07GiB(1.13%)  tps: 375,135  tflops: 26.85  mfu: 2.72%
[rank0]:[titan] 2026-03-03 15:21:23,255 - root - INFO - step: 10  loss:  4.04419  grad_norm:  2.0104  memory:  1.07GiB(1.13%)  tps: 335,830  tflops: 24.04  mfu: 2.43%

[ghstack-poisoned]
pianpwk added a commit that referenced this pull request Mar 3, 2026
ghstack-source-id: 33cf118
Pull-Request: #2475
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Mar 3, 2026
[ghstack-poisoned]
pianpwk added a commit that referenced this pull request Mar 3, 2026
ghstack-source-id: 32f97dc
Pull Request resolved: #2475
[ghstack-poisoned]
pianpwk added a commit that referenced this pull request Mar 3, 2026
ghstack-source-id: 2454a6f
Pull Request resolved: #2475
Comment thread torchtitan/models/common/moe/moe.py Outdated
torch._check(num_actual_tokens == x.shape[0])
else:
torch._check(num_actual_tokens >= 0)
torch._check(num_actual_tokens <= x.shape[0])

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could delete this? only case supported should be FP8 on SM89

Comment thread torchtitan/models/common/moe/utils.py Outdated
ValidTokenGroupAlignmentSize = Literal[8, 16, 32]
# TODO(pianpwk): Consider removing padding paths entirely once HybridEP integration lands,
# moving padding to communication layer.
TOKEN_GROUP_ALIGN_SIZE_M = 1

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@shuhuayu is doing final validation on HybridEp. Once that lands, how about let's remove padding in torchtitan entirely. Also need to remove special handling in https://github.com/pytorch/torchtitan/pull/2470/changes#diff-1a1b8aa2c436eb5db983e7327f90645990e906b786ec3b7ea96d77b7050416c8R77

The only thing to keep is an arg to adjust the alignment size when calling HybridEP + mxfp8, which you can set at config time, or pass in during runtime (similar to https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama4/parallelize.py#L129)

cc @danielvegamyhre

Before:
```
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --debug.seed 10

[rank0]:[titan] 2026-03-03 15:24:02,281 - root - INFO - step:  1  loss:  8.06832  grad_norm:  1.4637  memory:  1.02GiB(1.08%)  tps: 1,191  tflops: 0.09  mfu: 0.01%
[rank0]:[titan] 2026-03-03 15:24:02,281 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2026-03-03 15:24:02,458 - root - INFO - step:  2  loss:  7.74952  grad_norm:  1.5518  memory:  1.07GiB(1.13%)  tps: 92,412  tflops: 6.62  mfu: 0.67%
[rank0]:[titan] 2026-03-03 15:24:02,506 - root - INFO - step:  3  loss:  6.99885  grad_norm:  2.0142  memory:  1.07GiB(1.13%)  tps: 350,428  tflops: 25.09  mfu: 2.54%
[rank0]:[titan] 2026-03-03 15:24:02,553 - root - INFO - step:  4  loss:  6.13972  grad_norm:  2.4066  memory:  1.07GiB(1.13%)  tps: 350,777  tflops: 25.11  mfu: 2.54%
[rank0]:[titan] 2026-03-03 15:24:02,599 - root - INFO - step:  5  loss:  5.25505  grad_norm:  2.6350  memory:  1.07GiB(1.13%)  tps: 357,529  tflops: 25.59  mfu: 2.59%
[rank0]:[titan] 2026-03-03 15:24:02,651 - root - INFO - step:  6  loss:  4.75111  grad_norm:  2.5402  memory:  1.07GiB(1.13%)  tps: 319,817  tflops: 22.89  mfu: 2.31%
[rank0]:[titan] 2026-03-03 15:24:02,696 - root - INFO - step:  7  loss:  4.42443  grad_norm:  2.4367  memory:  1.07GiB(1.13%)  tps: 366,312  tflops: 26.22  mfu: 2.65%
[rank0]:[titan] 2026-03-03 15:24:02,740 - root - INFO - step:  8  loss:  4.22678  grad_norm:  2.2814  memory:  1.07GiB(1.13%)  tps: 372,464  tflops: 26.66  mfu: 2.70%
[rank0]:[titan] 2026-03-03 15:24:02,787 - root - INFO - step:  9  loss:  4.25165  grad_norm:  2.0176  memory:  1.07GiB(1.13%)  tps: 353,424  tflops: 25.30  mfu: 2.56%
[rank0]:[titan] 2026-03-03 15:24:02,840 - root - INFO - step: 10  loss:  4.04419  grad_norm:  2.0104  memory:  1.07GiB(1.13%)  tps: 308,696  tflops: 22.10  mfu: 2.23%


CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --debug.seed 10 --parallelism.expert_parallel_degree=2

[rank0]:[titan] 2026-03-03 15:24:46,108 - root - INFO - step:  1  loss:  8.06832  grad_norm:  1.4637  memory:  1.02GiB(1.08%)  tps: 1,091  tflops: 0.08  mfu: 0.01%
[rank0]:[titan] 2026-03-03 15:24:46,108 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2026-03-03 15:24:46,225 - root - INFO - step:  2  loss:  7.74951  grad_norm:  1.5518  memory:  1.07GiB(1.13%)  tps: 140,882  tflops: 10.09  mfu: 1.02%
[rank0]:[titan] 2026-03-03 15:24:46,282 - root - INFO - step:  3  loss:  6.99885  grad_norm:  2.0142  memory:  1.07GiB(1.13%)  tps: 289,365  tflops: 20.71  mfu: 2.09%
[rank0]:[titan] 2026-03-03 15:24:46,338 - root - INFO - step:  4  loss:  6.13974  grad_norm:  2.4065  memory:  1.07GiB(1.13%)  tps: 295,195  tflops: 21.13  mfu: 2.14%
[rank0]:[titan] 2026-03-03 15:24:46,392 - root - INFO - step:  5  loss:  5.25506  grad_norm:  2.6350  memory:  1.07GiB(1.13%)  tps: 303,956  tflops: 21.76  mfu: 2.20%
[rank0]:[titan] 2026-03-03 15:24:46,452 - root - INFO - step:  6  loss:  4.75112  grad_norm:  2.5402  memory:  1.07GiB(1.13%)  tps: 276,643  tflops: 19.80  mfu: 2.00%
[rank0]:[titan] 2026-03-03 15:24:46,498 - root - INFO - step:  7  loss:  4.42442  grad_norm:  2.4368  memory:  1.07GiB(1.13%)  tps: 357,858  tflops: 25.62  mfu: 2.59%
[rank0]:[titan] 2026-03-03 15:24:46,543 - root - INFO - step:  8  loss:  4.22681  grad_norm:  2.2814  memory:  1.07GiB(1.13%)  tps: 366,776  tflops: 26.26  mfu: 2.65%
[rank0]:[titan] 2026-03-03 15:24:46,589 - root - INFO - step:  9  loss:  4.25164  grad_norm:  2.0176  memory:  1.07GiB(1.13%)  tps: 362,488  tflops: 25.95  mfu: 2.62%
[rank0]:[titan] 2026-03-03 15:24:46,643 - root - INFO - step: 10  loss:  4.04417  grad_norm:  2.0103  memory:  1.07GiB(1.13%)  tps: 305,329  tflops: 21.86  mfu: 2.21%
```


After:
```
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --debug.seed 10

[rank0]:[titan] 2026-03-03 15:19:41,425 - root - INFO - step:  1  loss:  8.06832  grad_norm:  1.4637  memory:  1.02GiB(1.08%)  tps: 2,393  tflops: 0.17  mfu: 0.02%
[rank0]:[titan] 2026-03-03 15:19:41,425 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:/data/users/pianpwk/torchtitan/torchtitan/distributed/utils.py:395: UserWarning: Set timeout is now only supported for either nccl or gloo.
[rank0]:  torch.distributed.distributed_c10d._set_pg_timeout(timeout, group)
[rank0]:[titan] 2026-03-03 15:19:41,520 - root - INFO - step:  2  loss:  7.74951  grad_norm:  1.5518  memory:  1.07GiB(1.13%)  tps: 173,417  tflops: 12.41  mfu: 1.26%
[rank0]:[titan] 2026-03-03 15:19:41,566 - root - INFO - step:  3  loss:  6.99886  grad_norm:  2.0142  memory:  1.07GiB(1.13%)  tps: 359,111  tflops: 25.71  mfu: 2.60%
[rank0]:[titan] 2026-03-03 15:19:41,612 - root - INFO - step:  4  loss:  6.13973  grad_norm:  2.4065  memory:  1.07GiB(1.13%)  tps: 354,734  tflops: 25.39  mfu: 2.57%
[rank0]:[titan] 2026-03-03 15:19:41,656 - root - INFO - step:  5  loss:  5.25507  grad_norm:  2.6350  memory:  1.07GiB(1.13%)  tps: 375,163  tflops: 26.86  mfu: 2.72%
[rank0]:[titan] 2026-03-03 15:19:41,707 - root - INFO - step:  6  loss:  4.75113  grad_norm:  2.5402  memory:  1.07GiB(1.13%)  tps: 325,100  tflops: 23.27  mfu: 2.35%
[rank0]:[titan] 2026-03-03 15:19:41,751 - root - INFO - step:  7  loss:  4.42440  grad_norm:  2.4367  memory:  1.07GiB(1.13%)  tps: 377,937  tflops: 27.05  mfu: 2.74%
[rank0]:[titan] 2026-03-03 15:19:41,795 - root - INFO - step:  8  loss:  4.22678  grad_norm:  2.2815  memory:  1.07GiB(1.13%)  tps: 375,085  tflops: 26.85  mfu: 2.71%
[rank0]:[titan] 2026-03-03 15:19:41,841 - root - INFO - step:  9  loss:  4.25165  grad_norm:  2.0176  memory:  1.07GiB(1.13%)  tps: 359,168  tflops: 25.71  mfu: 2.60%
[rank0]:[titan] 2026-03-03 15:19:41,892 - root - INFO - step: 10  loss:  4.04418  grad_norm:  2.0104  memory:  1.07GiB(1.13%)  tps: 326,346  tflops: 23.36  mfu: 2.36%


CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --debug.seed 10 --parallelism.expert_parallel_degree=2

[rank0]:[titan] 2026-03-03 15:21:22,800 - root - INFO - step:  1  loss:  8.06832  grad_norm:  1.4637  memory:  1.02GiB(1.08%)  tps: 2,419  tflops: 0.17  mfu: 0.02%
[rank0]:[titan] 2026-03-03 15:21:22,800 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2026-03-03 15:21:22,869 - root - INFO - step:  2  loss:  7.74951  grad_norm:  1.5518  memory:  1.07GiB(1.13%)  tps: 239,517  tflops: 17.15  mfu: 1.73%
[rank0]:[titan] 2026-03-03 15:21:22,915 - root - INFO - step:  3  loss:  6.99886  grad_norm:  2.0142  memory:  1.07GiB(1.13%)  tps: 356,371  tflops: 25.51  mfu: 2.58%
[rank0]:[titan] 2026-03-03 15:21:22,962 - root - INFO - step:  4  loss:  6.13973  grad_norm:  2.4065  memory:  1.07GiB(1.13%)  tps: 350,649  tflops: 25.10  mfu: 2.54%
[rank0]:[titan] 2026-03-03 15:21:23,009 - root - INFO - step:  5  loss:  5.25507  grad_norm:  2.6350  memory:  1.07GiB(1.13%)  tps: 355,429  tflops: 25.44  mfu: 2.57%
[rank0]:[titan] 2026-03-03 15:21:23,065 - root - INFO - step:  6  loss:  4.75114  grad_norm:  2.5402  memory:  1.07GiB(1.13%)  tps: 294,755  tflops: 21.10  mfu: 2.13%
[rank0]:[titan] 2026-03-03 15:21:23,115 - root - INFO - step:  7  loss:  4.42441  grad_norm:  2.4367  memory:  1.07GiB(1.13%)  tps: 331,022  tflops: 23.70  mfu: 2.40%
[rank0]:[titan] 2026-03-03 15:21:23,161 - root - INFO - step:  8  loss:  4.22680  grad_norm:  2.2815  memory:  1.07GiB(1.13%)  tps: 355,568  tflops: 25.45  mfu: 2.57%
[rank0]:[titan] 2026-03-03 15:21:23,206 - root - INFO - step:  9  loss:  4.25164  grad_norm:  2.0176  memory:  1.07GiB(1.13%)  tps: 375,135  tflops: 26.85  mfu: 2.72%
[rank0]:[titan] 2026-03-03 15:21:23,255 - root - INFO - step: 10  loss:  4.04419  grad_norm:  2.0104  memory:  1.07GiB(1.13%)  tps: 335,830  tflops: 24.04  mfu: 2.43%
```

[ghstack-poisoned]
pianpwk added a commit that referenced this pull request Mar 11, 2026
ghstack-source-id: 32dc675
Pull Request resolved: #2475
@pianpwk pianpwk changed the title remove padding for bf16 remove MoE token padding paths Mar 11, 2026
@pianpwk pianpwk marked this pull request as ready for review March 11, 2026 21:28

@tianyu-l tianyu-l left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't just do deletion. HybridEP kernels requires a handle to do proper padding if mxfp8 grouped gemm is used. I would recommend sending it when apply expert parallel.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can remove this file

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove this file and instead use pytorch code -- make sure the impl doesn't introduce d2h sync. Check rakkit's PR for reference.

padded_max_len = _round_up(x_padded_per_expert, TOKEN_GROUP_ALIGN_SIZE_M)
with torch.no_grad():
(permuted_indices, num_tokens_per_expert, _offsets,) = generate_permute_indices(
(permuted_indices, num_tokens_per_expert, _offsets) = generate_permute_indices(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these utils can be put into expert_parallel.py, because single device no longer needs them

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a dumb question - would we still want to support fp8/mxfp8 in the non-EP case? Right now _permute is responsible for both rank->expert-major reordering (no-op for non-EP), as well as alignment padding. So we'd want to keep this path around for low-precision no-EP?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for what it's worth i think not using EP for training MoEs is fairly uncommon (@tianyu-l correct me if i am wrong here) but we should still support it.

@pianpwk you can delete the "rank-major to expert-major + pad with triton kernel" from torchtitan and just import it from torchao here, since it is only needed for fp8/mxfp8 we decided it makes more sense to live there

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pianpwk yes for mxfp8 path, no for fp16

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we change mxfp8.py as well?

@danielvegamyhre

Copy link
Copy Markdown
Contributor

We can't just do deletion. HybridEP kernels requires a handle to do proper padding if mxfp8 grouped gemm is used. I would recommend sending it when apply expert parallel.

@tianyu-l i thought hybridEP did the padding for mxfp8 grouped mm as part of the custom all2all dispatch impl?

@tianyu-l

tianyu-l commented Mar 11, 2026

Copy link
Copy Markdown
Contributor

i thought hybridEP did the padding for mxfp8 grouped mm as part of the custom all2all dispatch impl?

@danielvegamyhre Yes but you have to tell it whether or not you need padding in all2all, and how much.
https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/deepep/hybridep.py#L196

Also, it sounds like the hybridEP all2all dispatch won't happen in mxfp8, right?

@danielvegamyhre

danielvegamyhre commented Mar 11, 2026

Copy link
Copy Markdown
Contributor

i thought hybridEP did the padding for mxfp8 grouped mm as part of the custom all2all dispatch impl?

@danielvegamyhre Yes but you have to tell it whether or not you need padding in all2all, and how much. https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/deepep/hybridep.py#L196

Right but that is just a number, "32" in this case, we don't need group size padding kernels / logic for that. I think I misunderstood your original comment here. Were you just saying that we can delete the padding kernels/logic but need to keep MXFP8_GROUP_ALIGNMENT_SIZE = 32 for defining the alignment size for hybrid EP?

Also, it sounds like the hybridEP all2all dispatch won't happen in mxfp8, right?

Yeah my understanding is it happens in default precision / bf16

@tianyu-l

Copy link
Copy Markdown
Contributor

@danielvegamyhre

Were you just saying that we can delete the padding kernels/logic but need to keep MXFP8_GROUP_ALIGNMENT_SIZE = 32 for defining the alignment size for hybrid EP?

Yeah that's what I meant.


# For fp8 grouped GEMM, token group sizes must be multiples of 16
# (16 byte alignment / 1 byte per elem = 16 elements)
set_token_group_alignment_size_m(FP8_GROUP_ALIGNMENT_SIZE)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

multiple of 16 is still needed for fp8

@tianyu-l

Copy link
Copy Markdown
Contributor

any progress on this? @pianpwk

@danielvegamyhre

Copy link
Copy Markdown
Contributor

any progress on this? @pianpwk

@tianyu-l since @pianpwk is helping out with Dtensor + MXFP8 composability changes I volunteered to take over the padding work, here is the PR: #2620 i haven't tested it yet due to devgpu issues, will let you know once i've tested

@wwwjn

wwwjn commented Mar 27, 2026

Copy link
Copy Markdown
Contributor

any progress on this? @pianpwk

@tianyu-l since @pianpwk is helping out with Dtensor + MXFP8 composability changes I volunteered to take over the padding work, here is the PR: #2620 i haven't tested it yet due to devgpu issues, will let you know once i've tested

Thanks, should we just close this and keep centralized tracking in #2620? @danielvegamyhre

@danielvegamyhre

Copy link
Copy Markdown
Contributor

@wwwjn yes, i'll close

@github-project-automation github-project-automation Bot moved this from In Progress to Done in 26H1 TorchTitan Development Mar 27, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/8gpu CLA Signed This label is managed by the Meta Open Source bot.

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

torch.compile fails with DeepSeekV3 + SimpleFSDP

4 participants