
Trainer is not fully compatible with a customized optimizer #15784

@allanj

Description


Environment info

  • transformers version: 4.16.2
  • Platform: Linux-5.4.56.bsk.6-amd64-x86_64-with-debian-10.11
  • Python version: 3.7.3
  • PyTorch version (GPU?): 1.10.0 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: yes

Information

Model I am using: BART for conditional generation (`BartForConditionalGeneration`)

The problem arises when using:

  • the Trainer with a customized optimizer
  • my own modified scripts (details below)

The task I am working on is:

  • generation, but I suspect the problem applies to other tasks as well

To reproduce

Steps to reproduce the behavior:
The reason for creating a customized optimizer: I want to set different learning rates for different groups of parameters.
Setup: distributed training with sharded_ddp=simple (see the sketch below for how the flag is set).
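For reference, a minimal sketch of enabling that mode via the training arguments (the output directory and learning rate are placeholder values):

from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="out",        # placeholder
    sharded_ddp="simple",    # routes the Trainer through fairscale's ShardedDDP + OSS
    learning_rate=5e-5,      # placeholder default LR
)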

  1. Create the customized optimizer:
from typing import List

import torch.nn as nn
from fairscale.optim import OSS
from transformers import (
    AdamW,
    BartForConditionalGeneration,
    Trainer,
    TrainingArguments,
    get_linear_schedule_with_warmup,
)
from transformers.trainer_pt_utils import get_parameter_names
from transformers.trainer_utils import ShardedDDPOption

model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')

def create_optimizer_and_scheduler(model: nn.Module, sharded_ddp: List[ShardedDDPOption],
                                   total_training_steps: int,
                                   default_lr: float, resnet_lr: float,
                                   optim_args: TrainingArguments, weight_decay=0.0):
    # Apply weight decay to everything except biases and LayerNorm weights,
    # mirroring the grouping the default Trainer uses.
    decay_parameters = get_parameter_names(model, [nn.LayerNorm])
    decay_parameters = [name for name in decay_parameters if "bias" not in name]
    param_no_decay = [p for n, p in model.named_parameters() if n not in decay_parameters]
    resnet_param_with_decay = [p for n, p in model.named_parameters() if "patch_embedding" in n and n in decay_parameters]
    other_param_with_decay = [p for n, p in model.named_parameters() if "patch_embedding" not in n and n in decay_parameters]
    optimizer_grouped_parameters = [
        {
            "params": other_param_with_decay,
            "weight_decay": weight_decay,
            "lr": default_lr,
        },
        {
            "params": resnet_param_with_decay,
            "weight_decay": weight_decay,
            "lr": resnet_lr,
        },
        {
            "params": param_no_decay,
            "weight_decay": 0.0,
            "lr": default_lr,
        },
    ]
    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(optim_args)

    if ShardedDDPOption.SIMPLE in sharded_ddp:
        # fairscale's OSS shards the optimizer state across the data-parallel ranks
        optimizer = OSS(
            params=optimizer_grouped_parameters,
            optim=optimizer_cls,
            **optimizer_kwargs,
        )
    else:
        optimizer = AdamW(optimizer_grouped_parameters, **optimizer_kwargs)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0,
                                                num_training_steps=total_training_steps)
    return optimizer, scheduler
  2. Pass the optimizer and scheduler to the Trainer:
trainer = Seq2SeqTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset.remove_columns('metadata'),
            eval_dataset=eval_dataset.remove_columns('metadata'),
            data_collator=data_collator,
            tokenizer=tokenizer,
            compute_metrics=(build_compute_metrics_fn(tokenizer=tokenizer)),
            # Trainer takes an `optimizers` tuple of (optimizer, lr_scheduler)
            optimizers=create_optimizer_and_scheduler(model=model, ....)
        )

Expected behavior

Training should start normally. Instead, the following error is raised:

Traceback (most recent call last):
  File "vl_bart_main.py", line 216, in <module>
    hf_trainer_main()
  File "vl_bart_main.py", line 157, in hf_trainer_main
    train_result = trainer.train()
  File "/home/tiger/.local/lib/python3.7/site-packages/transformers/trainer.py", line 1365, in train
    tr_loss_step = self.training_step(model, inputs)
  File "/home/tiger/.local/lib/python3.7/site-packages/transformers/trainer.py", line 1940, in training_step
    loss = self.compute_loss(model, inputs)
  File "/home/tiger/.local/lib/python3.7/site-packages/transformers/trainer.py", line 1972, in compute_loss
    outputs = model(**inputs)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1120, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/home/tiger/.local/lib/python3.7/site-packages/fairscale/nn/data_parallel/sharded_ddp.py", line 219, in forward
    self.refresh_trainable()
  File "/home/tiger/.local/lib/python3.7/site-packages/fairscale/nn/data_parallel/sharded_ddp.py", line 300, in refresh_trainable
    optim.refresh_trainable()
  File "/home/tiger/.local/lib/python3.7/site-packages/fairscale/optim/oss.py", line 478, in refresh_trainable
    self._setup_flat_buffers()
  File "/home/tiger/.local/lib/python3.7/site-packages/fairscale/optim/oss.py", line 652, in _setup_flat_buffers
    bucket.add_param(param)
  File "/usr/local/lib/python3.7/dist-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/home/tiger/.local/lib/python3.7/site-packages/fairscale/nn/misc/param_bucket.py", line 69, in add_param
    self._add_param_as_view(param)
  File "/usr/local/lib/python3.7/dist-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/home/tiger/.local/lib/python3.7/site-packages/fairscale/nn/misc/param_bucket.py", line 81, in _add_param_as_view
    ), f"Different devices for the bucket and the param, cannot proceed: {param.device} - {self.buffer.device}"
AssertionError: Different devices for the bucket and the param, cannot proceed: cuda:2 - cpu

Suspected Reason

The model is moved to the GPU (and wrapped for sharded DDP) during the Trainer's initialization, but those changes are not reflected in the optimizer, which was instantiated earlier while the parameters were still on the CPU. This matches the assertion above: OSS's flat bucket lives on cpu while the parameter is already on cuda:2.
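The device mismatch can be illustrated without fairscale. A minimal sketch (assuming a CUDA device is available; nn.Linear stands in for the real model): a flat buffer allocated while the parameters still live on the CPU keeps its device even after the model is moved in place, which is exactly what the assertion catches.

import torch
import torch.nn as nn

model = nn.Linear(4, 4)  # parameters start on the CPU, like a freshly loaded checkpoint

# an OSS-style flat buffer, allocated on the parameters' *current* device (CPU here)
flat_buffer = torch.zeros(sum(p.numel() for p in model.parameters()))

model.to("cuda:0")  # what the Trainer does to the model during initialization

param = next(model.parameters())
print(param.device, flat_buffer.device)  # cuda:0 vs. cpu -- the mismatch in the traceback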

My Workaround Solution

Create the optimizer after initializing the trainer, so that it is built from the model the Trainer has already moved to the device:

model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
trainer = Seq2SeqTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset.remove_columns('metadata'),
            eval_dataset=eval_dataset.remove_columns('metadata'),
            data_collator=data_collator,
            tokenizer=tokenizer,
            compute_metrics=(build_compute_metrics_fn(tokenizer=tokenizer)),
        )
## build the optimizer/scheduler from the trainer's (already device-placed) model
optim_scheduler = create_optimizer_and_scheduler(model=trainer.model, ....)
## override the default optimizer and scheduler
trainer.optimizer = optim_scheduler[0]
trainer.lr_scheduler = optim_scheduler[1]

But this approach does not feel very elegant.
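A cleaner pattern may be to subclass the Trainer and override create_optimizer, which is only called inside train(), after the model has been moved to its device. This is a sketch under assumptions: the subclass name is made up, max_steps is assumed to be set in the training arguments, and the resnet_lr value is a placeholder.

from transformers import Seq2SeqTrainer

class CustomLRSeq2SeqTrainer(Seq2SeqTrainer):
    # create_optimizer runs after the Trainer has placed the model on its device,
    # so OSS allocates its flat buckets against the final (CUDA) parameters.
    def create_optimizer(self):
        if self.optimizer is None:
            optimizer, scheduler = create_optimizer_and_scheduler(
                model=self.model,
                sharded_ddp=self.args.sharded_ddp,
                total_training_steps=self.args.max_steps,  # assumes max_steps is set
                default_lr=self.args.learning_rate,
                resnet_lr=1e-5,                            # placeholder value
                optim_args=self.args,
                weight_decay=self.args.weight_decay,
            )
            self.optimizer = optimizer
            self.lr_scheduler = scheduler  # reuse the matching scheduler too
        return self.optimizer

With this, the Seq2SeqTrainer call above needs no optimizers argument at all.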
