Description
Hi here! 🤗 Apparently the TRL CLI command trl sft does not handle the value passed to the --torch_dtype flag correctly: by the time getattr(torch, model_init_kwargs["torch_dtype"]) is called, the value is no longer a string.
Most likely this affects the other trainer implementations too, since the torch_dtype provided via the CLI is converted to a torch.dtype during argument parsing. The SFTTrainer, in this case, then receives torch_dtype=torch.bfloat16 and attempts getattr(torch, torch.bfloat16).
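To illustrate why the second getattr call cannot work, here is a minimal sketch that does not need torch installed. fake_torch and lookup are hypothetical stand-ins (not TRL or PyTorch names); the only point is that Python's built-in getattr requires the attribute name to be a str.

```python
from types import SimpleNamespace

# Hypothetical stand-in for the torch module: an arbitrary object plays
# the role of torch.bfloat16 so the sketch runs without PyTorch.
fake_torch = SimpleNamespace(bfloat16=object())


def lookup(name):
    """Mimic the getattr(torch, ...) lookup described above."""
    return getattr(fake_torch, name)


# A string name resolves to the attribute as expected.
resolved = lookup("bfloat16")

# Passing the already-resolved object as the name raises TypeError,
# because getattr only accepts string attribute names.
try:
    lookup(fake_torch.bfloat16)
    raised_type_error = False
except TypeError:
    raised_type_error = True
```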
So there are two potential fixes:
- Handling the received type for torch_dtype within each ...Trainer subclass, so that it is passed to model_init_kwargs as is, without calling getattr(torch, ...)
- Keeping torch_dtype as a string and letting each ...Trainer subclass do the str -> torch.dtype conversion instead, which is more convenient IMO
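A rough sketch of what a tolerant conversion could look like, covering both fixes at once. This is not TRL's code: resolve_dtype is a hypothetical helper, and it takes the module as a parameter (an assumption made here so the sketch can be exercised without importing torch). Passing "auto" and None through unchanged is also an assumption about what the caller wants.

```python
from typing import Any


def resolve_dtype(value: Any, namespace: Any) -> Any:
    """Return a dtype whether `value` is a string name or already resolved.

    `namespace` is the module to look names up in (e.g. the torch module).
    """
    # Assumption: None and "auto" are left for the caller to handle.
    if value is None:
        return None
    if isinstance(value, str):
        if value == "auto":
            return value
        if not hasattr(namespace, value):
            raise ValueError(f"Unknown dtype name: {value!r}")
        resolved = getattr(namespace, value)
        # If the namespace exposes a `dtype` class (as torch does), reject
        # names that resolve to something else, such as a submodule.
        dtype_cls = getattr(namespace, "dtype", None)
        if isinstance(dtype_cls, type) and not isinstance(resolved, dtype_cls):
            raise ValueError(f"{value!r} is not a dtype")
        return resolved
    # Anything that is not a string is assumed to be an already-converted
    # dtype and is returned untouched, so no getattr call is attempted.
    return value
```

With PyTorch installed, this would be called as resolve_dtype("bfloat16", torch) or resolve_dtype(torch.bfloat16, torch); under the assumptions above, both return torch.bfloat16.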
To reproduce
trl sft --model_name_or_path=facebook/opt-125m --dataset_name=imdb --dataset_text_field=text --max_steps=1 --torch_dtype=bfloat16 --output_dir=./test