
[BUG] TRL CLI not capturing torch_dtype correctly #1751

@alvarobartt

Description

Hi here! 🤗 The TRL CLI command trl sft does not handle the value passed to the --torch_dtype flag correctly: by the time the trainer calls getattr(torch, model_init_kwargs["torch_dtype"]), the value is no longer a string.

Most likely this issue affects the rest of the trainer implementations too, since the CLI already converts the provided torch_dtype to a torch.dtype while parsing it. The SFTTrainer, in this case, then receives torch_dtype=torch.bfloat16 and attempts getattr(torch, torch.bfloat16), which fails because the attribute name is not a string.

So there are two potential fixes:

  • Handling the received type for torch_dtype within each ...Trainer subclass, so that an existing torch.dtype is passed to the model_init_kwargs as is, without calling getattr(torch, ...)
  • Keeping torch_dtype as a string in the CLI and letting each ...Trainer subclass do the str -> torch.dtype conversion instead, which is more convenient IMO
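
Both options come down to tolerating either form of the value. A minimal sketch of that idea follows. Note that resolve_dtype is a hypothetical helper, not part of TRL, and fake_torch is a stand-in namespace used only so the snippet runs without PyTorch installed; in practice the torch module would be passed instead. Treating None and "auto" as pass-through values is also an assumption here.

```python
from types import SimpleNamespace
from typing import Any


def resolve_dtype(value: Any, namespace: Any) -> Any:
    """Hypothetical helper: accept a dtype given as a string or as an object.

    Only strings are looked up on the namespace. None, "auto" and values
    that are already dtype objects are returned unchanged.
    """
    if isinstance(value, str) and value != "auto":
        return getattr(namespace, value)
    return value


# Stand-in for the torch module (assumption, for illustration only).
fake_torch = SimpleNamespace(bfloat16=object(), float16=object())

# The failure described above: getattr() rejects a non-string attribute name.
try:
    getattr(fake_torch, fake_torch.bfloat16)
    failed = False
except TypeError:
    failed = True

# String input (what the user types on the command line).
from_string = resolve_dtype("bfloat16", fake_torch)

# Already-converted input (what the trainer currently receives).
from_object = resolve_dtype(fake_torch.bfloat16, fake_torch)
```

With a helper like this, the trainer would not need to care whether the CLI parser has already converted the flag.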

To reproduce

trl sft --model_name_or_path=facebook/opt-125m --dataset_name=imdb --dataset_text_field=text --max_steps=1 --torch_dtype=bfloat16 --output_dir=./test
