Description
System Info
The constructor for Trainer declares the following parameters:
def __init__(
    self,
    model: Union[PreTrainedModel, nn.Module] = None,
    args: TrainingArguments = None,
    data_collator: Optional[DataCollator] = None,
    train_dataset: Optional[Dataset] = None,
    eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
    tokenizer: Optional[PreTrainedTokenizerBase] = None,
    model_init: Optional[Callable[[], PreTrainedModel]] = None,
    compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
    callbacks: Optional[List[TrainerCallback]] = None,
    optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
    preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
):
evidence:
permalink to source
But the doc says
train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*): The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed.
evidence:
permalink to source
If I try to instantiate a Trainer with an actual IterableDataset for the train_dataset parameter, PyRight complains (rightly) that
Argument of type "IterableDataset" cannot be assigned to parameter
"train_dataset" of type "Dataset[Unknown] | None" in function "__init__"
Type "IterableDataset" cannot be assigned to type "Dataset[Unknown] | None"
"IterableDataset" is incompatible with "Dataset[Unknown]"
"IterableDataset" is incompatible with "None"
Please change the type hints for these parameters to allow for IterableDataset values as well.
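The requested change can be sketched with stand-in classes, so the snippet runs without torch, datasets, or transformers installed. `MapDataset`, `StreamDataset`, and `SketchTrainer` are hypothetical names and not part of any library; the only point is the widened `Optional[Union[...]]` annotation on `train_dataset`.

```python
from typing import Iterator, Optional, Union, get_type_hints


class MapDataset:
    """Hypothetical stand-in for a map-style dataset class."""

    def __init__(self, rows):
        self.rows = list(rows)

    def __getitem__(self, index):
        return self.rows[index]

    def __len__(self):
        return len(self.rows)


class StreamDataset:
    """Hypothetical stand-in for an iterable dataset class (no __len__)."""

    def __init__(self, rows):
        self.rows = list(rows)

    def __iter__(self) -> Iterator:
        return iter(self.rows)


class SketchTrainer:
    """Hypothetical trainer that shows only the widened annotation."""

    def __init__(
        self,
        train_dataset: Optional[Union[MapDataset, StreamDataset]] = None,
    ):
        self.train_dataset = train_dataset


# Both kinds of dataset now satisfy the annotation.
map_trainer = SketchTrainer(train_dataset=MapDataset([1, 2, 3]))
stream_trainer = SketchTrainer(train_dataset=StreamDataset([1, 2, 3]))

# The resolved hint includes the iterable class alongside the map-style one.
train_dataset_hint = get_type_hints(SketchTrainer.__init__)["train_dataset"]
```

With an annotation of this shape, a type checker should accept either class for `train_dataset`; the same widening would presumably apply to `eval_dataset`.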
workaround:
When Pyright complains about your source code, follow the parameter with a comment like this:
train_dataset=my_iterable_ds, # type: ignore (IterableDataset is apparently not envisioned here)
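A possible alternative to `# type: ignore` is `typing.cast`, which asks the checker to treat the value as the annotated type and returns the same object unchanged at runtime. This is a sketch only: `AnnotatedDataset`, `StreamDataset`, and `build_trainer` are hypothetical stand-ins, not library names.

```python
from typing import Optional, cast


class AnnotatedDataset:
    """Hypothetical stand-in for the class named in the type hint."""


class StreamDataset:
    """Hypothetical stand-in for an iterable dataset unrelated to the hint."""


def build_trainer(train_dataset: Optional[AnnotatedDataset] = None) -> dict:
    """Hypothetical constructor that only accepts the annotated type."""
    return {"train_dataset": train_dataset}


stream = StreamDataset()

# cast() performs no conversion or check at runtime; it only affects
# how the static type checker sees the expression.
trainer = build_trainer(train_dataset=cast(AnnotatedDataset, stream))
```

Unlike a line-wide `# type: ignore`, the cast silences only this one argument, so other mistakes on the same line are still reported.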
Who can help?
No response
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
# the core issue is this
iterable_ds = any_dataset.to_iterable_dataset()
trainer = Trainer(model=any_model,
                  train_dataset=iterable_ds,  # causes PyRight warnings
                  args=training_args)

# pedantically complete example
from transformers import BertForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset
# Load a portion of the dataset for quick demonstration
dataset = load_dataset('imdb', split='train[:1%]')
# Convert the loaded dataset to an iterable dataset
iterable_ds = dataset.to_iterable_dataset()
# Load a pre-trained model
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',  # Output directory for model checkpoints
    num_train_epochs=1,  # Total number of training epochs
    per_device_train_batch_size=8,  # Batch size per device during training
    warmup_steps=500,  # Number of warmup steps for learning rate scheduler
    weight_decay=0.01,  # Strength of weight decay
    logging_dir='./logs',  # Directory for storing logs
    logging_steps=10,
)
# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=iterable_ds,  # Causes PyRight warnings
)
Expected behavior
No pyright warnings when passing an IterableDataset object as the train_dataset param of the Trainer constructor.