Type annotation for train_dataset and eval_dataset params of Trainer incompatible with IterableDataset #29678

@stevemadere

System Info

The constructor for Trainer declares the following parameters:

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    ):


evidence:
permalink to source

But the doc says

    train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*):
        The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the
        `model.forward()` method are automatically removed.

evidence:
permalink to source

If I try to instantiate a Trainer with an actual IterableDataset for the train_dataset parameter, Pyright complains (rightly) that:

Argument of type "IterableDataset" cannot be assigned to parameter
"train_dataset" of type "Dataset[Unknown] | None" in function "__init__"
  Type "IterableDataset" cannot be assigned to type "Dataset[Unknown] | None"
    "IterableDataset" is incompatible with "Dataset[Unknown]"
    "IterableDataset" is incompatible with "None"

Please change the type hints for these parameters to allow for IterableDataset values as well.
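
A minimal sketch of the kind of widening I mean, using stand-in classes so it runs without torch or datasets installed. `Dataset`, `IterableDataset`, `AnyDataset`, and `SketchTrainer` are hypothetical placeholders, not the real transformers definitions, and an actual fix may well look different.

```python
from typing import Dict, Optional, Union


class Dataset:
    """Stand-in for torch.utils.data.Dataset (assumption, not the real class)."""


class IterableDataset:
    """Stand-in for an iterable dataset type that a type checker does not
    see as a subclass of Dataset (assumption, not the real class)."""


# Hypothetical alias: accept either dataset flavour.
AnyDataset = Union[Dataset, IterableDataset]


class SketchTrainer:
    """Hypothetical placeholder for Trainer; only the two dataset params are modelled."""

    def __init__(
        self,
        train_dataset: Optional[AnyDataset] = None,
        eval_dataset: Optional[Union[AnyDataset, Dict[str, AnyDataset]]] = None,
    ) -> None:
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset


# With the widened hints, this call should no longer trip the type checker.
trainer = SketchTrainer(
    train_dataset=IterableDataset(),
    eval_dataset={"validation": IterableDataset()},
)
```

The point is only that both parameters accept either dataset type, including inside the dict form of eval_dataset.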

@muellerz , @pacman100

workaround:

When Pyright complains about your source code, follow the offending argument with a comment like this:
train_dataset=my_iterable_ds, # type: ignore (IterableDataset is apparently not envisioned here)
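
An alternative that avoids suppressing diagnostics on that line is `typing.cast`, which does nothing at runtime. This is only a sketch with stand-in classes: `Dataset`, `IterableDataset`, and `make_trainer` are hypothetical placeholders, and with the real library you would cast to whatever type the Trainer annotation expects.

```python
from typing import Optional, cast


class Dataset:
    """Stand-in for the type the annotation currently expects (assumption)."""


class IterableDataset:
    """Stand-in for the iterable dataset actually being passed (assumption)."""


def make_trainer(train_dataset: Optional[Dataset] = None) -> Optional[Dataset]:
    # Hypothetical placeholder for Trainer(...): just hands the dataset back.
    return train_dataset


my_iterable_ds = IterableDataset()

# cast() only changes what the type checker believes; the object is untouched.
result = make_trainer(train_dataset=cast(Dataset, my_iterable_ds))
```

Like the `# type: ignore` comment, this hides the mismatch rather than fixing it, so it is still only a stopgap until the annotations change.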

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

# The core issue is this:
iterable_ds = any_dataset.to_iterable_dataset()
trainer = Trainer(
    model=any_model,
    train_dataset=iterable_ds,  # causes Pyright warnings
    args=training_args,
)

# Pedantically complete example
from transformers import BertForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset

# Load a portion of the dataset for quick demonstration
dataset = load_dataset('imdb', split='train[:1%]')

# Convert the loaded dataset to an iterable dataset
iterable_ds = dataset.to_iterable_dataset()

# Load a pre-trained model
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')

# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',          # Output directory for model checkpoints
    num_train_epochs=1,              # Total number of training epochs
    per_device_train_batch_size=8,   # Batch size per device during training
    warmup_steps=500,                # Number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # Strength of weight decay
    logging_dir='./logs',            # Directory for storing logs
    logging_steps=10,
)

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args, 
    train_dataset=iterable_ds, # Causes PyRight warnings
)

Expected behavior

No pyright warnings when passing an IterableDataset object as the train_dataset param of the Trainer constructor.
