
Bad sequence length management when seq_len%tensor_parallel_degree!=0 #1306

@jc-audet

Bug description

Hey all!

I’ve been trying to set up my own transformer tensor parallelism, following the setup here in TorchTitan. While inspecting the internal shapes of the activations at various places in my model, I’ve noticed that the sequence-parallel shapes around the embeddings and layer norms are not managed correctly when the sequence length is not a multiple of the TP degree. However, I’m confused because I have not seen any warning or docs describing this restriction. Am I doing something wrong?

PS: Thanks for this great repository, this is gold.

Repro:

Setup directly on TorchTitan main

git clone https://github.com/pytorch/torchtitan
cd torchtitan
conda create -n titan python=3.12
conda activate titan
pip install -r requirements.txt
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --force-reinstall
[For AMD GPU] pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.3 --force-reinstall

Replace Transformer.forward with the following (functionally equivalent to the original):

    def forward(self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None):
        if self.model_args.use_flex_attn:
            init_attention_mask(
                input_batch if input_batch is not None else tokens, eos_id=self.eos_id
            )
        
        # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
        h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

        torch.distributed.barrier()
        logger.info(f"[Rank {torch.distributed.get_rank()}] Embedding: {tokens.shape=} -> {h.shape=}")

        for layer in self.layers.values():
            h = layer(h, self.freqs_cis)
        
        h = self.norm(h) if self.norm else h
        output = self.output(h) if self.output else h
        return output

Replace TransformerBlock.forward with the following (functionally equivalent to the original):

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
    ):

        residual = x
        x_normed = self.attention_norm(x)

        torch.distributed.barrier()
        logger.info(f"[Rank {torch.distributed.get_rank()}] attention_norm: {x.shape} -> {x_normed.shape}")

        h = residual + self.attention(x_normed, freqs_cis)

        torch.distributed.barrier()
        logger.info(f"[Rank {torch.distributed.get_rank()}] attention: {x_normed.shape} -> {h.shape}")

        h_normed = self.ffn_norm(h)

        torch.distributed.barrier()
        logger.info(f"[Rank {torch.distributed.get_rank()}] ffn_norm: {h.shape} -> {h_normed.shape}")

        out = h + self.feed_forward(h_normed)

        torch.distributed.barrier()
        logger.info(f"[Rank {torch.distributed.get_rank()}] feed_forward: {h_normed.shape} -> {out.shape}")
        if torch.distributed.get_rank() == 0:
            breakpoint()
        torch.distributed.barrier()
        return out

Change in llama3_8b.toml

seq_len = 8193
tensor_parallel_degree = 4

Run

NGPU=4 LOG_RANK=0,1,2,3 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh

Observed results

[rank0]:[titan] 2025-06-16 19:22:50,757 - root - INFO - [Rank 0] Embedding: tokens.shape=torch.Size([1, 8193]) -> h.shape=torch.Size([1, 2049, 4096])
[rank1]:[titan] 2025-06-16 19:22:50,757 - root - INFO - [Rank 1] Embedding: tokens.shape=torch.Size([1, 8193]) -> h.shape=torch.Size([1, 2049, 4096])
[rank2]:[titan] 2025-06-16 19:22:50,757 - root - INFO - [Rank 2] Embedding: tokens.shape=torch.Size([1, 8193]) -> h.shape=torch.Size([1, 2049, 4096])
[rank3]:[titan] 2025-06-16 19:22:50,757 - root - INFO - [Rank 3] Embedding: tokens.shape=torch.Size([1, 8193]) -> h.shape=torch.Size([1, 2046, 4096])

[rank0]:[titan] 2025-06-16 19:22:50,788 - root - INFO - [Rank 0] attention_norm: torch.Size([1, 2049, 4096]) -> torch.Size([1, 8196, 4096])
[rank1]:[titan] 2025-06-16 19:22:50,788 - root - INFO - [Rank 1] attention_norm: torch.Size([1, 2049, 4096]) -> torch.Size([1, 8196, 4096])
[rank2]:[titan] 2025-06-16 19:22:50,788 - root - INFO - [Rank 2] attention_norm: torch.Size([1, 2049, 4096]) -> torch.Size([1, 8196, 4096])
[rank3]:[titan] 2025-06-16 19:22:50,788 - root - INFO - [Rank 3] attention_norm: torch.Size([1, 2046, 4096]) -> torch.Size([1, 8184, 4096])

No rank has the correct sequence length after passing through the attention norm layer: the norm layers assume the full sequence length is local_size * TP degree (2049 * 4 = 8196). This is especially a problem for Rank 3, which has a smaller local sequence length (2046) and is therefore definitely missing information.

I understand this is caused by the embedding layer's RowwiseParallel, which has use_local_output=True by default. This causes the tensor returned by nn.Embedding to be a plain torch.Tensor, at which point the correct placement and relationship to the other ranks is lost.
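The shard sizes in the logs above can be reproduced with a small sketch. This is plain Python mirroring (as I understand it) the torch.chunk-style splitting that DTensor's Shard placement uses; the exact internals may differ:

```python
import math

# Sketch of how Shard(1) splits an uneven sequence dimension across TP
# ranks: equal chunks of ceil(n / chunks), with the last rank taking
# whatever remains (torch.chunk semantics).
seq_len, tp_degree = 8193, 4
chunk = math.ceil(seq_len / tp_degree)  # 2049
local_lens = [min(chunk, seq_len - r * chunk) for r in range(tp_degree)]
print(local_lens)  # [2049, 2049, 2049, 2046]

# Once use_local_output=True strips the DTensor wrapper, a downstream
# layer can only infer the global length as local_len * tp_degree:
print(local_lens[0] * tp_degree)  # 8196, not 8193
```

This matches the observed logs: ranks 0-2 get 2049 tokens, rank 3 gets 2046, and the norm layers reconstruct a global length of 8196 instead of 8193.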

Fix

By setting use_local_output=False on all RowwiseParallel layers in the parallel plan, i.e.,

    parallelize_module(
        model,
        tp_mesh,
        {
            "tok_embeddings": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
                use_local_output=False,
            ),
...

and in the transformer plan:

        layer_plan = {
            "attention.wo": rowwise_parallel(
                output_layouts=Shard(1),
                use_local_output=False
            ),
...
            "feed_forward.w2": rowwise_parallel(
                output_layouts=Shard(1),
                use_local_output=False
            ),
...

We get the following result with the same setup:

[rank0]:[titan] 2025-06-16 19:44:28,838 - root - INFO - [Rank 0] Embedding: tokens.shape=torch.Size([1, 8193]) -> h.shape=torch.Size([1, 8193, 4096])
[rank1]:[titan] 2025-06-16 19:44:28,838 - root - INFO - [Rank 1] Embedding: tokens.shape=torch.Size([1, 8193]) -> h.shape=torch.Size([1, 8193, 4096])
[rank2]:[titan] 2025-06-16 19:44:28,838 - root - INFO - [Rank 2] Embedding: tokens.shape=torch.Size([1, 8193]) -> h.shape=torch.Size([1, 8193, 4096])
[rank3]:[titan] 2025-06-16 19:44:28,838 - root - INFO - [Rank 3] Embedding: tokens.shape=torch.Size([1, 8193]) -> h.shape=torch.Size([1, 8193, 4096])

[rank0]:[titan] 2025-06-16 19:44:28,887 - root - INFO - [Rank 0] attention_norm: torch.Size([1, 8193, 4096]) -> torch.Size([1, 8193, 4096])
[rank1]:[titan] 2025-06-16 19:44:28,887 - root - INFO - [Rank 1] attention_norm: torch.Size([1, 8193, 4096]) -> torch.Size([1, 8193, 4096])
[rank2]:[titan] 2025-06-16 19:44:28,887 - root - INFO - [Rank 2] attention_norm: torch.Size([1, 8193, 4096]) -> torch.Size([1, 8193, 4096])
[rank3]:[titan] 2025-06-16 19:44:28,887 - root - INFO - [Rank 3] attention_norm: torch.Size([1, 8193, 4096]) -> torch.Size([1, 8193, 4096])

Is there a reason we would not want to set use_local_output=False for all RowwiseParallel layers?

Versions

Version:

>>> import torch
>>> torch.__version__
'2.8.0.dev20250616+cu126'

toml:

# torchtitan Config.toml
# NOTE: this toml config is a preset for 64 A100 GPUs.

[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"

[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 10
enable_tensorboard = true
save_tb_folder = "tb"

[model]
name = "llama3"
flavor = "8B"
tokenizer_path = "./assets/tokenizer/original/tokenizer.model"
# converters = ["float8"]

[optimizer]
name = "AdamW"
lr = 3e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 200  # lr scheduler warm up

[training]
local_batch_size = 1
seq_len = 8193
max_norm = 1.0  # grad norm clipping
steps = 1000
compile = false
dataset = "c4"

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 4
pipeline_parallel_degree = 1
context_parallel_degree = 1

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval = 500
last_save_model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = "selective"  # ["none", "selective", "full"]
selective_ac_option = "op"  # "int" = ac every positive int layer or 'op', ac based on ops policy

[float8]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = ["output"]
