
[FSDP] caffe2 error in forward method when using fsdp #82461

@pacman100

Description

🐛 Describe the bug

When using FSDP, calling `model.generate()` during inference/evaluation with 🤗 Transformers models (gpt2, blenderbot, t5 ...) throws a Caffe2 lazy-allocation error.

Steps to reproduce the error:

  1. Code is here: run_seq2seq_no_trainer.py
  2. Use 🤗 Accelerate's FSDP integration with the `config.yaml` below, on 2 NVIDIA TITAN RTX GPUs:
```yaml
compute_environment: LOCAL_MACHINE
deepspeed_config: {}
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: BlenderbotDecoderLayer
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: 'fp16'
num_machines: 1
num_processes: 2
use_cpu: false
```
  3. Run the launch command below:
```bash
accelerate launch --config_file config.yaml \
    run_seq2seq_no_trainer.py \
    --dataset_name "smangrul/MuDoConv" \
    --max_source_length 128 \
    --source_prefix "chatbot: " \
    --max_target_length 64 \
    --val_max_target_length 64 \
    --val_min_target_length 20 \
    --n_val_batch_generations 5 \
    --n_train 10000 \
    --n_val 1000 \
    --pad_to_max_length \
    --num_beams 10 \
    --model_name_or_path "facebook/blenderbot-400M-distill" \
    --per_device_train_batch_size 100 \
    --per_device_eval_batch_size 50 \
    --learning_rate 1e-6 \
    --weight_decay 0.0 \
    --num_train_epochs 1 \
    --gradient_accumulation_steps 1 \
    --num_warmup_steps 100 \
    --output_dir "/tmp/fsdp_test" \
    --seed 25 \
    --logging_steps 100
```
  4. Output with the error and stack trace (both ranks raise the same error):
```text
Traceback (most recent call last):
  File "run_seq2seq_no_trainer.py", line 911, in <module>
    main()
  File "run_seq2seq_no_trainer.py", line 840, in main
    bleu_score = evaluate(args, model, metric, tokenizer, eval_dataloader, accelerator, config.max_length)
  File "run_seq2seq_no_trainer.py", line 392, in evaluate
    generated_tokens = unwrapped_model.generate(
  File "/home/sourab/dev/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/sourab/transformers/src/transformers/generation_utils.py", line 1181, in generate
    model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
  File "/home/sourab/transformers/src/transformers/generation_utils.py", line 525, in _prepare_encoder_decoder_kwargs_for_generation
    model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)
  File "/home/sourab/dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/sourab/transformers/src/transformers/models/blenderbot/modeling_blenderbot.py", line 736, in forward
    inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
  File "/home/sourab/dev/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/sourab/dev/lib/python3.8/site-packages/torch/nn/modules/sparse.py", line 158, in forward
    return F.embedding(
  File "/home/sourab/dev/lib/python3.8/site-packages/torch/nn/functional.py", line 2199, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory.
```
  5. The error disappears and everything works if I run `model(**dummy_batch)` once before the prediction loop. The relevant snippet in run_seq2seq_no_trainer.py is shown below:
```python
import torch
from accelerate.utils import DistributedType


def evaluate(args, model, metric, tokenizer, eval_dataloader, accelerator, max_length):
    model.eval()
    if args.val_max_target_length is None:
        args.val_max_target_length = args.max_target_length

    gen_kwargs = {
        "max_length": args.val_max_target_length if args is not None else max_length,
        "num_beams": args.num_beams,
        "min_length": args.val_min_target_length,
        "length_penalty": False,
        "no_repeat_ngram_size": 3,
        "encoder_no_repeat_ngram_size": 3,
        "repetition_penalty": 1.2,
    }
    samples_seen = 0
    for step, batch in enumerate(eval_dataloader):
        # Had to run this once at the start of the eval loop, otherwise generate() raised the Caffe2 error.
        # So, before calling `model.generate` directly, pass a batch through the FSDP-wrapped model.
        # Uncomment the lines below for a successful run with FSDP:
        # if samples_seen == 0 and accelerator.distributed_type == DistributedType.FSDP:
        #     model(**batch)
        with torch.no_grad():
            unwrapped_model = accelerator.unwrap_model(model)
            generated_tokens = unwrapped_model.generate(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                synced_gpus=True,
                **gen_kwargs,
            )
        # ... (rest of the evaluation loop omitted)
```

With those lines uncommented, everything works; the output is below:

(Screenshot, 2022-07-29: output of the successful run.)
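A minimal sketch of the same workaround, assuming the warm-up only needs to happen once per evaluation run. Instead of keying on `samples_seen == 0` (which depends on the omitted part of the loop updating `samples_seen`), a small flag object remembers whether the FSDP-wrapped `model` has already run one forward pass. The class name `WarmupOnce` is hypothetical and not part of 🤗 Accelerate or PyTorch; it is meant to be called inside the loop's existing `torch.no_grad()` block, right before `unwrapped_model.generate(...)`.

```python
from typing import Any, Callable, Dict


class WarmupOnce:
    """Hypothetical helper: run one forward pass through the FSDP-wrapped
    model before the first `generate()` call on the unwrapped model.

    It does not enter `torch.no_grad()` itself; call it inside the
    evaluation loop's existing `with torch.no_grad():` block.
    """

    def __init__(self, wrapped_model: Callable[..., Any], enabled: bool) -> None:
        self.wrapped_model = wrapped_model  # the FSDP-wrapped `model` from the evaluation loop
        self.enabled = enabled              # e.g. accelerator.distributed_type == DistributedType.FSDP
        self.done = False

    def __call__(self, batch: Dict[str, Any]) -> None:
        # Only the first call does any work, and only when FSDP is in use.
        if self.enabled and not self.done:
            self.wrapped_model(**batch)  # output is discarded; only the side effect matters
            self.done = True


# Usage inside evaluate() (sketch, not runnable on its own):
#     warmup = WarmupOnce(model, accelerator.distributed_type == DistributedType.FSDP)
#     for step, batch in enumerate(eval_dataloader):
#         with torch.no_grad():
#             warmup(batch)
#             unwrapped_model = accelerator.unwrap_model(model)
#             generated_tokens = unwrapped_model.generate(...)
```

As in the issue, the forward output is thrown away; the call is only there so FSDP has run once on the wrapped root module before `generate()` is called on the unwrapped module. Every rank should make the call, since an FSDP forward pass with sharded parameters involves collective communication across ranks.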

Expected Output

`model.generate()` runs under FSDP without the Caffe2 lazy-allocation error, and without needing a forward pass through the model first.

Versions

```text
Collecting environment information...
PyTorch version: 1.12.0+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.4 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-6ubuntu2) 7.5.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.8.10 (default, Jun 22 2022, 20:18:18) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.4.0-122-generic-x86_64-with-glibc2.29
Is CUDA available: True
CUDA runtime version: 10.2.89
GPU models and configuration:
GPU 0: NVIDIA TITAN RTX
GPU 1: NVIDIA TITAN RTX

Nvidia driver version: 510.73.08
cuDNN version: Probably one of the following:
/usr/local/cuda-10.1/targets/x86_64-linux/lib/libcudnn.so.7.6.5
/usr/local/cuda-10.2/targets/x86_64-linux/lib/libcudnn.so.7.6.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.23.0
[pip3] torch==1.12.0
[pip3] torchaudio==0.12.0
[pip3] torchvision==0.13.0
[conda] Could not collect
```

cc @ezyang @gchanan @zou3519 @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang @kwen2501

Metadata

Labels: high priority, module: fsdp, oncall: distributed, triage review, triaged