EP: token alignment not working as expected #1651

@danielvegamyhre

Bug description

Summary

After adding logging to the torchtitan expert_parallel wrapper and running fp8 rowwise MoE training (where the token group alignment size is set to 16), I see that the alignment size is set correctly, but the resulting M dimension (16333) is not divisible by 16:


[rank0]:[titan] 2025-08-28 10:05:31,175 - root - INFO - TOKEN_GROUP_ALIGN_SIZE_M = 16
[rank0]:[titan] 2025-08-28 10:05:31,454 - root - INFO - input_shape = torch.Size([16333, 5120])
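
For reference, here is a minimal sketch of the divisibility check on the logged shape, and of what rounding each token group up to the alignment size would look like. The helper `round_up` and the per-expert group sizes are hypothetical, chosen only so that they sum to the logged M; they are not taken from torchtitan.

```python
ALIGN = 16  # matches TOKEN_GROUP_ALIGN_SIZE_M in the log above


def round_up(n: int, align: int) -> int:
    """Round n up to the nearest multiple of align (hypothetical helper)."""
    return ((n + align - 1) // align) * align


# The logged M dimension is not aligned.
m = 16333
remainder = m % ALIGN            # 13, so M is not a multiple of 16
padded_m = round_up(m, ALIGN)    # 16336, the next multiple of 16

# Hypothetical per-expert token counts that sum to the logged M.
group_sizes = [4100, 4050, 4083, 4100]
assert sum(group_sizes) == m

# If every group is padded to a multiple of ALIGN, the total is aligned too.
padded_groups = [round_up(g, ALIGN) for g in group_sizes]
padded_total = sum(padded_groups)

print(remainder, padded_m, padded_total % ALIGN)
```

If each group were padded this way, the total M could not come out as 16333, which suggests the padding is not being applied on this code path.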

This causes an error in scaled_grouped_mm, which expects the contracting dimension (M, for the GEMM grad_weight = grad_output_t @ input) to be divisible by 16:

  RuntimeError: strides should be multiple of 16 bytes
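
To illustrate why an unaligned M could trigger this message, here is a rough sketch of the stride arithmetic. It assumes 1-byte fp8 elements and operands stored contiguously along the contracting dimension M, so that the leading stride is M elements. Those layout assumptions and the helper names are mine, not confirmed details of scaled_grouped_mm.

```python
FP8_ITEMSIZE = 1       # bytes per float8 element
REQUIRED_ALIGNMENT = 16  # bytes, from the error message


def leading_stride_bytes(contracting_dim: int, itemsize: int = FP8_ITEMSIZE) -> int:
    """Byte stride between consecutive rows of an operand that is
    contiguous along the contracting dimension (assumed layout)."""
    return contracting_dim * itemsize


def is_aligned(nbytes: int, alignment: int = REQUIRED_ALIGNMENT) -> bool:
    """True if nbytes is a multiple of the required byte alignment."""
    return nbytes % alignment == 0


# grad_weight = grad_output_t @ input contracts over M:
#   grad_output_t: (N, M), input: (M, K)
logged_m = 16333   # from the log above
aligned_m = 16336  # next multiple of 16

print(is_aligned(leading_stride_bytes(logged_m)))   # stride of 16333 bytes
print(is_aligned(leading_stride_bytes(aligned_m)))  # stride of 16336 bytes
```

Under these assumptions, the stride in bytes equals M, so the 16-byte stride requirement reduces to M being divisible by 16.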

According to git blame, the last PR that touched this code was #1561.

cc @tianyu-l

Repro

NGPU=4 CONFIG_FILE="./torchtitan/experiments/llama4/train_configs/debug_model.toml" ./run_train.sh --training.steps=50 --model.converters="float8" --float8.recipe_name="rowwise" --float8.moe_fqns_prototype="experts,shared_expert" --parallelism.data_parallel_shard_degree=2 --parallelism.expert_parallel_degree=2 --parallelism.tensor_parallel_degree=2

Versions

  • torchtitan with latest main branch
  • torchao latest main branch
