Fix EP token group padding issue #1718
# Make sure max_len of permuted token indices is divisible by TOKEN_GROUP_ALIGN_SIZE_M,
# by padding it to the nearest multiple of TOKEN_GROUP_ALIGN_SIZE_M.
ceil_div = lambda x, y: (x + y - 1) // y
nit: define a regular function instead of using a lambda
Added round up util for this
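For readers following along, here is a minimal sketch of what such a round-up helper could look like as a regular function. The body of the util added in this PR is not shown in the conversation, so the name `round_up` and the input check below are assumptions for illustration, not the PR's actual code.

```python
def round_up(x: int, y: int) -> int:
    """Round x up to the nearest multiple of y.

    Hypothetical helper for illustration; the PR's own util may differ.
    """
    if y <= 0:
        raise ValueError("y must be a positive integer")
    # Ceiling division (same formula as the original lambda), then scale back up.
    return ((x + y - 1) // y) * y
```

A named function like this can carry a docstring and type hints, and shows up by name in tracebacks, which is the usual reason to prefer it over a lambda bound to a variable.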
    return total_norm


def _round_up(x: int, y: int) -> int:
probably should put in torchtitan/tools/utils.py instead of torchtitan/distributed/utils
# Make sure max_len of permuted token indices is divisible by TOKEN_GROUP_ALIGN_SIZE_M,
# by padding it to the nearest multiple of TOKEN_GROUP_ALIGN_SIZE_M.
x_padded_per_expert = (
    x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M
oh so the previous issue was caused by x.shape[0] not divisible by TOKEN_GROUP_ALIGN_SIZE_M?
Yeah that's my understanding.
The experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M term is an upper bound on the padding: token group sizes vary, so each group needs a different amount of padding, but never more than TOKEN_GROUP_ALIGN_SIZE_M per group. However, it doesn't account for the original total M (x.shape[0]) itself potentially not being divisible by the alignment size.
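The arithmetic behind this explanation can be checked with a small, torch-free sketch. Everything below is illustrative: the alignment value of 8, the group sizes, and the helper names `round_up` and `padded_max_len` are assumptions chosen for the example, not values or functions taken from torchtitan.

```python
ALIGN = 8  # stand-in for TOKEN_GROUP_ALIGN_SIZE_M (illustrative value)


def round_up(x: int, y: int) -> int:
    """Round x up to the nearest multiple of y (hypothetical helper)."""
    return ((x + y - 1) // y) * y


def padded_max_len(num_tokens: int, experts_per_ep_rank: int, align: int = ALIGN) -> int:
    """Upper bound on the padded length, rounded up to a multiple of align."""
    # Each token group needs at most `align` extra rows of padding.
    upper_bound = num_tokens + experts_per_ep_rank * align
    # The bound inherits any misalignment of num_tokens, so round it up too.
    return round_up(upper_bound, align)


group_sizes = [5, 3, 2]           # variable token group sizes, one per local expert
num_tokens = sum(group_sizes)     # plays the role of x.shape[0]; 10 is not a multiple of 8

# Padding each group to a multiple of ALIGN stays within the upper bound...
padded_total = sum(round_up(g, ALIGN) for g in group_sizes)
unrounded = num_tokens + len(group_sizes) * ALIGN

# ...but the unrounded bound is misaligned whenever num_tokens is.
print(padded_total, unrounded, unrounded % ALIGN)
print(padded_max_len(num_tokens, len(group_sizes)))
```

With these numbers the per-group padding fits inside the bound, yet the bound itself leaves a remainder when divided by the alignment size, which is the case the extra round-up is meant to cover.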
Confirmed the CUDNN issue is resolved in today's nightly CUDA 12.8 build. Reverted that change.
Fixes pytorch#1651

## Summary
- Round up `max_len` of permuted token indices in the expert parallel decorator to be a multiple of the token group alignment size.

## Test plan
- Llama4 debug model with FSDP=2, EP=2: `NGPU=2 CONFIG_FILE="./torchtitan/experiments/llama4/train_configs/debug_model.toml" ./run_train.sh --parallelism.data_parallel_shard_degree=2 --parallelism.expert_parallel_degree=2 --compile.enable`