
Incorrect typing of fsdp, fsdp_config, and sharded_ddp in TrainingArguments #24538

@O-T-O-Z

Description


System Info

  • transformers version: 4.29.2
  • Platform: macOS-14.0-arm64-arm-64bit
  • Python version: 3.10.6
  • Huggingface_hub version: 0.15.1
  • Safetensors version: not installed
  • PyTorch version (GPU?): 1.13.0.dev20220902 (False)
  • Tensorflow version (GPU?): 2.9.0 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Who can help?

@sgugger

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

When a pydantic.BaseModel has a field typed as a TrainingArguments subclass and is instantiated as follows:

from pydantic import BaseModel
from transformers.training_args import TrainingArguments

class MyTrainingArguments(TrainingArguments):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.my_arg = "my_arg"


class MyModel(BaseModel):
    training_args: MyTrainingArguments


model = MyModel(training_args=MyTrainingArguments(output_dir=""))

The following ValidationErrors occur:

ValidationError: 4 validation errors for MyModel
training_args -> debug
  str type expected (type=type_error.str)
training_args -> sharded_ddp
  str type expected (type=type_error.str)
training_args -> fsdp
  str type expected (type=type_error.str)
training_args -> fsdp_config
  str type expected (type=type_error.str)

Since debug was fixed in #24033, my main concern is the other three fields.
After investigating, I found that the __post_init__() method converts these parameters from the str values declared in their annotations into other types, for example dict, bool, or List. This is a problem for Pydantic (and other type checkers), because validation against the str annotation fails, even though the docstring of TrainingArguments documents the following types for these parameters:
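To make the mechanism concrete, here is a minimal standard-library sketch (no pydantic or transformers required) of the same pattern: a dataclass field annotated as str that __post_init__() replaces with a list or a dict. ArgsSketch, FakeFSDPOption and str_annotated_mismatches are hypothetical names for illustration, and the conversion only approximates what TrainingArguments does.

```python
from dataclasses import dataclass, fields
from enum import Enum
from typing import Optional


class FakeFSDPOption(Enum):
    # Hypothetical stand-in for transformers.trainer_utils.FSDPOption
    FULL_SHARD = "full_shard"
    AUTO_WRAP = "auto_wrap"


@dataclass
class ArgsSketch:
    # Annotated as str / Optional[str], like the fields reported in the issue
    fsdp: str = ""
    fsdp_config: Optional[str] = None

    def __post_init__(self):
        # Approximates the reported behaviour: the str input is replaced by a
        # list of enum members, and a missing config becomes a dict
        self.fsdp = [FakeFSDPOption(s) for s in self.fsdp.split()]
        if self.fsdp_config is None:
            self.fsdp_config = {}


def str_annotated_mismatches(obj):
    """Names of str / Optional[str] fields whose runtime value violates the annotation."""
    bad = []
    for f in fields(obj):
        if f.type not in (str, Optional[str]):
            continue  # this sketch only inspects str-like annotations
        value = getattr(obj, f.name)
        if value is None and f.type == Optional[str]:
            continue  # None is allowed for Optional[str]
        if not isinstance(value, str):
            bad.append(f.name)
    return bad


args = ArgsSketch(fsdp="full_shard auto_wrap")
print(str_annotated_mismatches(args))  # ['fsdp', 'fsdp_config']
```

Any validator that trusts the annotations, as Pydantic does when it validates a dataclass field, flags both fields in the same way, which matches the "str type expected" errors above.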

"""
sharded_ddp (`bool`, `str` or list of [`~trainer_utils.ShardedDDPOption`], *optional*, defaults to `False`)
fsdp (`bool`, `str` or list of [`~trainer_utils.FSDPOption`], *optional*, defaults to `False`)
fsdp_config (`str` or `dict`, *optional*)
"""

Expected behavior

I would like to resolve these issues by providing correct type hints, for example:

sharded_ddp: Union[Optional[str], bool, List[ShardedDDPOption]]
fsdp: Union[Optional[str], bool, List[FSDPOption]]
fsdp_config: Union[Optional[str], Dict]

I tested this configuration and it resolves the validation errors.
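As a rough illustration of why widening the annotations helps, the sketch below applies the proposed Union types to a stand-in dataclass and checks the post-conversion values against them with a tiny hand-written checker. It is plain standard library, not Pydantic; WidenedArgsSketch, matches, annotation_mismatches and the Fake*Option enums are hypothetical, and the conversion in __post_init__() is simplified.

```python
from dataclasses import dataclass, fields
from enum import Enum
from typing import Dict, List, Optional, Union, get_args, get_origin


class FakeShardedDDPOption(Enum):
    # Hypothetical stand-in for transformers.trainer_utils.ShardedDDPOption
    SIMPLE = "simple"
    ZERO_DP_2 = "zero_dp_2"


class FakeFSDPOption(Enum):
    # Hypothetical stand-in for transformers.trainer_utils.FSDPOption
    FULL_SHARD = "full_shard"
    AUTO_WRAP = "auto_wrap"


@dataclass
class WidenedArgsSketch:
    # Annotations follow the proposal in this issue
    sharded_ddp: Union[Optional[str], bool, List[FakeShardedDDPOption]] = ""
    fsdp: Union[Optional[str], bool, List[FakeFSDPOption]] = ""
    fsdp_config: Union[Optional[str], Dict] = None

    def __post_init__(self):
        # Simplified conversion: strings become enum lists, None becomes a dict
        if isinstance(self.sharded_ddp, str):
            self.sharded_ddp = [FakeShardedDDPOption(s) for s in self.sharded_ddp.split()]
        if isinstance(self.fsdp, str):
            self.fsdp = [FakeFSDPOption(s) for s in self.fsdp.split()]
        if self.fsdp_config is None:
            self.fsdp_config = {}


def matches(value, annotation):
    """Minimal runtime check for the typing forms used above (not a full validator)."""
    if annotation is type(None):
        return value is None
    origin = get_origin(annotation)
    if origin is Union:
        # Union[Optional[str], ...] is flattened, so None shows up as NoneType here
        return any(matches(value, arg) for arg in get_args(annotation))
    if origin is list:
        (item_type,) = get_args(annotation)
        return isinstance(value, list) and all(matches(v, item_type) for v in value)
    if origin is dict:
        # Covers bare Dict; key/value types are not checked
        return isinstance(value, dict)
    return isinstance(value, annotation)


def annotation_mismatches(obj):
    return [f.name for f in fields(obj) if not matches(getattr(obj, f.name), f.type)]


args = WidenedArgsSketch(sharded_ddp="simple", fsdp="full_shard auto_wrap")
print(annotation_mismatches(args))  # []
```

With the widened annotations, both the raw user inputs (str or bool) and the converted values (enum lists and dicts) are valid, so a checker that trusts the annotations no longer rejects an initialized instance.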
