
Multi GPU generate with llama shape error #32885

@dakinggg

Description


System Info

  • transformers version: 4.44.0
  • Platform: Linux-5.15.0-91-generic-x86_64-with-glibc2.31
  • Python version: 3.10.12
  • Huggingface_hub version: 0.23.4
  • Safetensors version: 0.4.2
  • Accelerate version: 0.25.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.2.1+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: yes, running with composer as the distributed launcher
  • Using GPU in script?: yes
  • GPU type: NVIDIA A100-SXM4-40GB

Who can help?

@gante @ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Run Llama with synced_gpus=True and an attention mask. This worked fine on transformers 4.40.2 (and 4.41.x), but no longer works. Composer is only used here as a convenient distributed launcher; swapping in a different launcher should not change anything (see the plain torch.distributed sketch after the expected behavior section).

import transformers
import torch
from composer.utils import dist

def main():
    dist.initialize_dist('gpu')

    name = 'meta-llama/Meta-Llama-3-8B-Instruct'
    tokenizer = transformers.AutoTokenizer.from_pretrained(name)
    pad_token_id = tokenizer.eos_token_id
    model = transformers.AutoModelForCausalLM.from_pretrained(name)

    rank = dist.get_global_rank()

    model.to(f'cuda:{rank}')

    # Give each rank a different prompt so the ranks finish generating at different steps
    if dist.get_global_rank() == 0:
        content = 'Write one short sentence.'
    else:
        content = 'Write one long paragraph.'

    messages = [
        {
            'role': 'user',
            'content': content,
        }
    ]

    tokenized_messages = tokenizer.apply_chat_template(messages, return_tensors='pt')

    # Left-pad the prompt with the pad (EOS) token to a fixed total length of 4096
    padded_messages = torch.cat(
        [
            torch.LongTensor((4096 - 20) * [pad_token_id]),
            tokenized_messages[0],  # [seq]
        ],
        dim=0,
    )
    padded_messages = padded_messages.unsqueeze(0)
    padded_messages = padded_messages.to(f'cuda:{rank}')
    # Attention mask is False over the pad positions and True over the real prompt tokens
    attention_mask = ~(padded_messages == pad_token_id)
    attention_mask = attention_mask.to(f'cuda:{rank}')
    output = model.generate(input_ids=padded_messages, attention_mask=attention_mask, synced_gpus=True, max_new_tokens=200)

    print(tokenizer.decode(output[0]))

if __name__ == '__main__':
    main()
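For reference, the manual left padding above can likely also be produced directly by the tokenizer (a minimal sketch, not part of the original repro; it assumes setting the pad token to EOS and padding_side='left'):

# Sketch of an equivalent batch built with the tokenizer's own left padding
# (assumption for illustration, not the original repro code).
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'

prompt_text = tokenizer.apply_chat_template(messages, tokenize=False)
batch = tokenizer(
    prompt_text,
    add_special_tokens=False,  # the chat template already added special tokens
    padding='max_length',
    max_length=4096,
    return_tensors='pt',
).to(f'cuda:{rank}')

output = model.generate(
    input_ids=batch['input_ids'],
    attention_mask=batch['attention_mask'],
    synced_gpus=True,
    max_new_tokens=200,
)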

Running the original repro script results in:

Traceback (most recent call last):
  File "/mnt/workdisk/danielking/github/multi-gpu.py", line 47, in <module>
    main()
  File "/mnt/workdisk/danielking/github/multi-gpu.py", line 42, in main
    output = model.generate(input_ids=padded_messages, attention_mask=attention_mask, synced_gpus=True, max_new_tokens=200)
  File "/mnt/workdisk/danielking/miniconda3/envs/foundry-3.10/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/workdisk/danielking/miniconda3/envs/foundry-3.10/lib/python3.10/site-packages/transformers/generation/utils.py", line 2024, in generate
    result = self._sample(
  File "/mnt/workdisk/danielking/miniconda3/envs/foundry-3.10/lib/python3.10/site-packages/transformers/generation/utils.py", line 2982, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/mnt/workdisk/danielking/miniconda3/envs/foundry-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/workdisk/danielking/miniconda3/envs/foundry-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/workdisk/danielking/miniconda3/envs/foundry-3.10/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1189, in forward
    outputs = self.model(
  File "/mnt/workdisk/danielking/miniconda3/envs/foundry-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/workdisk/danielking/miniconda3/envs/foundry-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/workdisk/danielking/miniconda3/envs/foundry-3.10/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1001, in forward
    layer_outputs = decoder_layer(
  File "/mnt/workdisk/danielking/miniconda3/envs/foundry-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/workdisk/danielking/miniconda3/envs/foundry-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/workdisk/danielking/miniconda3/envs/foundry-3.10/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 734, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/mnt/workdisk/danielking/miniconda3/envs/foundry-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/workdisk/danielking/miniconda3/envs/foundry-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/workdisk/danielking/miniconda3/envs/foundry-3.10/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 660, in forward
    attn_output = torch.nn.functional.scaled_dot_product_attention(
RuntimeError: The expanded size of the tensor (4105) must match the existing size (4104) at non-singleton dimension 3.  Target sizes: [1, 32, 1, 4105].  Tensor sizes: [1, 1, 1, 4104]

Expected behavior

Multi-GPU generate does not error.
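
As noted in the reproduction section, Composer is only used as a convenience launcher. A minimal sketch of the same setup on plain torch.distributed, launched with torchrun --nproc_per_node=2, would look roughly like this (an assumption for illustration; not verified to reproduce the error):

import os

import torch
import torch.distributed as dist
import transformers

def main():
    # torchrun sets RANK, LOCAL_RANK, and WORLD_SIZE in the environment.
    dist.init_process_group(backend='nccl')
    rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(rank)

    name = 'meta-llama/Meta-Llama-3-8B-Instruct'
    tokenizer = transformers.AutoTokenizer.from_pretrained(name)
    pad_token_id = tokenizer.eos_token_id
    model = transformers.AutoModelForCausalLM.from_pretrained(name).to(f'cuda:{rank}')

    # Different prompts per rank so the ranks finish generating at different steps.
    content = 'Write one short sentence.' if rank == 0 else 'Write one long paragraph.'
    messages = [{'role': 'user', 'content': content}]
    tokenized = tokenizer.apply_chat_template(messages, return_tensors='pt')[0]

    # Same manual left padding and attention mask as in the Composer script above.
    padded = torch.cat([torch.LongTensor((4096 - 20) * [pad_token_id]), tokenized], dim=0)
    padded = padded.unsqueeze(0).to(f'cuda:{rank}')
    attention_mask = (padded != pad_token_id)

    output = model.generate(input_ids=padded, attention_mask=attention_mask,
                            synced_gpus=True, max_new_tokens=200)
    print(tokenizer.decode(output[0]))

if __name__ == '__main__':
    main()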

Labels

    Generation, WIP, bug
