Skip to content

Fix EP token group padding issue#1718

Merged
tianyu-l merged 3 commits into
pytorch:mainfrom
danielvegamyhre:group-pad
Sep 18, 2025
Merged

Fix EP token group padding issue#1718
tianyu-l merged 3 commits into
pytorch:mainfrom
danielvegamyhre:group-pad

Conversation

@danielvegamyhre

@danielvegamyhre danielvegamyhre commented Sep 17, 2025

Copy link
Copy Markdown
Contributor

Fixes #1651

Summary

  • Round up max_len of permuted token indicies in expert parallel decorator to be a multiple of token group alignment size.

Test plan

  • Llama4 debug model with FSDP=2, EP=2: NGPU=2 CONFIG_FILE="./torchtitan/experiments/llama4/train_configs/debug_model.toml" ./run_train.sh --parallelism.data_parallel_shard_degree=2 --parallelism.expert_parallel_degree=2 --compile.enable

@danielvegamyhre

Copy link
Copy Markdown
Contributor Author

cc @tianyu-l for review

thanks @vkuzo for pointing this out!


# Make sure max_len of permuted token indicies is divisible by TOKEN_GROUP_ALIGN_SIZE_M,
# by padding it to the nearest multiple of TOKEN_GROUP_ALIGN_SIZE_M.
ceil_div = lambda x, y: (x + y - 1) // y

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: define a regular function instead of using a lambda

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added round up util for this

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Sep 18, 2025
return total_norm


def _round_up(x: int, y: int) -> int:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably should put in torchtitan/tools/utils.py instead of torchtitan/distributed/utils

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

# Make sure max_len of permuted token indicies is divisible by TOKEN_GROUP_ALIGN_SIZE_M,
# by padding it to the nearest multiple of TOKEN_GROUP_ALIGN_SIZE_M.
x_padded_per_expert = (
x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh so the previous issue was caused by x.shape[0] not divisible by TOKEN_GROUP_ALIGN_SIZE_M?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that's my understanding.

The experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M padding does upper bound based padding (for each token group, variable amount of padding will be needed since token group sizes are variable, but at most we will have to add TOKEN_GROUP_ALIGN_SIZE_M per group, so it does that). However, it doesn't account for the original total M (x.shape[0]) potentially not being divisible by alignment size.

@tianyu-l tianyu-l left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, I think it can be merged without changing the attention backends

Depends on previous PR in stack: #1717

@danielvegamyhre

Copy link
Copy Markdown
Contributor Author

LGTM, I think it can be merged without changing the attention backends

Depends on previous PR in stack: #1717

Confirmed CUDNN issue is resolved in todays nightly cuda 12.8 build. Reverted that change.

@tianyu-l tianyu-l merged commit 60645bc into pytorch:main Sep 18, 2025
8 checks passed
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 13, 2026
Fixes pytorch#1651 

## Summary
- Round up `max_len` of permuted token indicies in expert parallel
decorator to be a multiple of token group alignment size.

## Test plan
- Llama4 debug model with FSDP=2, EP=2: `NGPU=2
CONFIG_FILE="./torchtitan/experiments/llama4/train_configs/debug_model.toml"
./run_train.sh --parallelism.data_parallel_shard_degree=2
--parallelism.expert_parallel_degree=2 --compile.enable `
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 25, 2026
Fixes pytorch#1651 

## Summary
- Round up `max_len` of permuted token indicies in expert parallel
decorator to be a multiple of token group alignment size.

## Test plan
- Llama4 debug model with FSDP=2, EP=2: `NGPU=2
CONFIG_FILE="./torchtitan/experiments/llama4/train_configs/debug_model.toml"
./run_train.sh --parallelism.data_parallel_shard_degree=2
--parallelism.expert_parallel_degree=2 --compile.enable `
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

EP: token alignment not working as expected

3 participants