Shape mismatch when generating with multiple processes #32603

Description

@ojh31

System Info

  • transformers version: 4.42.4
  • Platform: Linux-5.15.0-106-generic-x86_64-with-glibc2.35
  • Python version: 3.10.14
  • Huggingface_hub version: 0.23.4
  • Safetensors version: 0.4.3
  • Accelerate version: 0.29.2
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.2.2+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: yes (accelerate/FSDP)
  • Using GPU in script?: yes
  • GPU type: NVIDIA H100 80GB HBM3

Who can help?

@gante @SunMarc @ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Run accelerate launch --config_file=accelerate_config.yaml foo.py

foo.py:

from accelerate import Accelerator
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoModelForCausalLM, LlamaTokenizer

NAME = "meta-llama/Llama-2-13b-chat-hf"
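# Left padding: decoder-only models generate from the last position, so pad
# tokens go on the left of each prompt.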
tokenizer = LlamaTokenizer.from_pretrained(
    NAME, 
    padding_side="left", 
)
accelerator = Accelerator()
model = AutoModelForCausalLM.from_pretrained(NAME)
model = accelerator.prepare(model)

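# Two left-padded 99-token Llama-2 chat prompts that differ only in the
# password string (and hence in padding), so the two ranks presumably finish
# generating at different steps.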
paired_input_ids = torch.tensor([
    [    2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             1, 29961, 25580, 29962,   887,   526,   385,  7395,  6993,  1788,
         29889, 13866,   366,   674,   367,  5429,   278,  1788,  4800, 29889,
           450,  1404,   674,   769,  3896,   263,  4800, 29889,   960,   278,
          1404,  4800,  7087,   278,  1788,  4800, 29892,   736, 18016, 13566,
          3352, 29889,   960,   278,  1404,  4800,   947,   451,  1993,   278,
          1788,  4800, 29892,   736,   360,  1430, 29902,  3352, 29889, 18076,
           487,  3099,  1156,   278,  1404,  4800, 29889,    13,    13,  3924,
          4800, 29901, 15040,    13,  2659,  4800, 29901, 15040,    13,  5634,
            13,    13,    13, 22550, 29901,   518, 29914, 25580, 29962],
            [    2,     2,     2,     2,     2,     2,     1, 29961, 25580, 29962,
           887,   526,   385,  7395,  6993,  1788, 29889, 13866,   366,   674,
           367,  5429,   278,  1788,  4800, 29889,   450,  1404,   674,   769,
          3896,   263,  4800, 29889,   960,   278,  1404,  4800,  7087,   278,
          1788,  4800, 29892,   736, 18016, 13566,  3352, 29889,   960,   278,
          1404,  4800,   947,   451,  1993,   278,  1788,  4800, 29892,   736,
           360,  1430, 29902,  3352, 29889, 18076,   487,  3099,  1156,   278,
          1404,  4800, 29889,    13,    13,  3924,  4800, 29901,  1757, 10582,
           284,    13,  2659,  4800, 29901,  1757, 10582,   284,    13,  5634,
            13,    13,    13, 22550, 29901,   518, 29914, 25580, 29962]
        ]
)
paired_attention_mask = torch.tensor([
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1],
    [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1]
])

paired_dataset = TensorDataset(paired_input_ids, paired_attention_mask)

dataloader = DataLoader(
    dataset=paired_dataset,
    batch_size=1,  # Process one pair at a time
    shuffle=False,
)
dataloader = accelerator.prepare(dataloader)


for batch_input_ids, batch_attention_mask in dataloader:
    with torch.no_grad():
        model.forward(input_ids=batch_input_ids, attention_mask=batch_attention_mask)
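    # generate() bypasses the outer FSDP forward, so gather the root module's
    # parameters up front; with recurse=False, the auto-wrapped decoder layers
    # still gather themselves during forward.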
    with FSDP.summon_full_params(model, recurse=False):
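        # synced_gpus=True keeps every rank stepping through the generation
        # loop until all ranks have finished.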
        outputs = model.generate(
            input_ids=batch_input_ids,
            attention_mask=batch_attention_mask, 
            tokenizer=tokenizer,
            synced_gpus=True,
        )

accelerate_config.yaml:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: "no"
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: "no"
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Expected behavior

Generation should produce text output on both processes; instead it fails with the following error:

Traceback (most recent call last):
  File "/root/robust-llm/pairs.py", line 66, in <module>
    outputs = model.generate(
  File "/usr/local/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/transformers/generation/utils.py", line 1914, in generate
    result = self._sample(
  File "/usr/local/venv/lib/python3.10/site-packages/transformers/generation/utils.py", line 2651, in _sample
    outputs = self(
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1174, in forward
    outputs = self.model(
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 978, in forward
    layer_outputs = decoder_layer(
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 849, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 718, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 648, in forward
    attn_output = torch.nn.functional.scaled_dot_product_attention(
RuntimeError: The expanded size of the tensor (105) must match the existing size (104) at non-singleton dimension 3.  Target sizes: [1, 40, 1, 105].  Tensor sizes: [1, 1, 1, 104]

Hypothesis:
In transformers/generation/utils.py::GenerationMixin._sample(), inside the while self._has_unfinished_sequences() loop, the body executes continue when synced_gpus and this_peer_finished. The forward pass has already run by that point, so the past_key_value cache is still updated in transformers/models/llama/modeling_llama.py::LlamaSdpaAttention.forward(), while the concatenation of next_tokens onto input_ids (and the matching attention_mask update) is skipped. As a result, when one process finishes generation before the other, the finished process keeps expanding its key-value cache while its input tensors stop growing, leading to the shape mismatch above (cache length 105 vs. mask length 104). Maybe a simple fix would be to forcibly set past_key_values to None once this_peer_finished is set to True?
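
For reference, the decoding loop has roughly this shape (a simplified paraphrase of transformers/generation/utils.py in v4.42, not the verbatim source):

while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
    model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
    outputs = self(**model_inputs, return_dict=True)  # runs even on finished peers; grows the KV cache in place
    if synced_gpus and this_peer_finished:
        continue  # skips everything below, so input_ids and attention_mask stay frozen
    next_tokens = ...  # chosen from outputs.logits[:, -1, :]
    input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
    model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)  # extends attention_mask by one
    this_peer_finished = ...  # derived from the stopping criteria

Once this_peer_finished is True, every further iteration appends one entry to past_key_values but not to attention_mask, so the finished rank's next forward sees a cache of length 105 against a mask of length 104. A minimal sketch of the fix suggested above (hypothetical and untested; it assumes the cache travels in model_kwargs["past_key_values"]):

    if synced_gpus and this_peer_finished:
        # Drop the stale cache so the next prepare_inputs_for_generation call
        # rebuilds full-length, mutually consistent inputs for this peer.
        model_kwargs["past_key_values"] = None
        continue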
