[train] Add hf trainer support for dictionary of datasets #56484

Merged
matthewdeng merged 1 commit into ray-project:master from wyhong3103:add-hf-eval-dict-support
Sep 12, 2025
Conversation

@wyhong3103
Contributor

Use case

To provide support for evaluating a model via the HuggingFace Trainer using a dictionary of Ray datasets. This aligns with the eval_dataset argument type accepted by get_eval_dataloader in transformers.

Why are these changes needed?

If users provide a dictionary of Ray datasets as eval_dataset to the HuggingFace Trainer, it fails during the evaluation step because the returned dataloader is not backed by a RayTorchIterableDataset.
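
The dispatch this PR adds can be sketched without any Ray or transformers dependencies. All names below (RayIterable, resolve_eval_iterable) are hypothetical stand-ins for _IterableFromIterator and the real lookup in Ray's transformers integration; this only illustrates the string-key-into-dict resolution path:

```python
class RayIterable:
    """Stand-in for the iterable Ray returns from iter_torch_batches."""

    def __init__(self, batches):
        self.batches = batches

    def __iter__(self):
        return iter(self.batches)


def resolve_eval_iterable(eval_dataset, stored_eval_dataset):
    """Accept either a RayIterable directly, or a string key into a dict
    of RayIterables (the new case), guarding the key lookup so a missing
    key yields None instead of a KeyError."""
    if (
        isinstance(eval_dataset, str)
        and isinstance(stored_eval_dataset, dict)
        and eval_dataset in stored_eval_dataset
    ):
        candidate = stored_eval_dataset[eval_dataset]
    else:
        candidate = eval_dataset
    return candidate if isinstance(candidate, RayIterable) else None


store = {"eval_1": RayIterable([1, 2]), "eval_2": RayIterable([3])}
print(list(resolve_eval_iterable("eval_1", store)))  # [1, 2]
print(resolve_eval_iterable("missing", store))       # None, no KeyError
```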

Related issue number

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests; see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

Signed-off-by: yenhong.wong <yenhong.wong@grabtaxi.com>
@wyhong3103 wyhong3103 requested a review from a team as a code owner September 12, 2025 04:38
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for using a dictionary of Ray Datasets for evaluation in the HuggingFace Trainer integration. The changes look good and are well-tested. I've found a potential bug in _transformers_utils.py that could lead to a KeyError and suggested a fix that also improves code clarity. I've also left a couple of comments on the test files to suggest refactoring to reduce code duplication and improve maintainability.

Comment on lines +134 to 143
    if (
        isinstance(eval_dataset, str)
        and isinstance(self.eval_dataset, dict)
        and isinstance(self.eval_dataset[eval_dataset], _IterableFromIterator)
    ):
        dataset = RayTorchIterableDataset(self.eval_dataset[eval_dataset])
        return DataLoader(dataset, batch_size=1, collate_fn=lambda x: x[0])
    elif isinstance(eval_dataset, _IterableFromIterator):
        dataset = RayTorchIterableDataset(eval_dataset)
        return DataLoader(dataset, batch_size=1, collate_fn=lambda x: x[0])
Contributor


high

The current implementation can raise a KeyError if eval_dataset is a string but not a valid key in self.eval_dataset: the expression self.eval_dataset[eval_dataset] would raise before the isinstance check runs. It's safer to check for key existence first.

Additionally, there's some code duplication in creating the DataLoader. I suggest refactoring to address both issues, which also improves readability.

Suggested change

Before:

    if (
        isinstance(eval_dataset, str)
        and isinstance(self.eval_dataset, dict)
        and isinstance(self.eval_dataset[eval_dataset], _IterableFromIterator)
    ):
        dataset = RayTorchIterableDataset(self.eval_dataset[eval_dataset])
        return DataLoader(dataset, batch_size=1, collate_fn=lambda x: x[0])
    elif isinstance(eval_dataset, _IterableFromIterator):
        dataset = RayTorchIterableDataset(eval_dataset)
        return DataLoader(dataset, batch_size=1, collate_fn=lambda x: x[0])

After:

    iterable_dataset = None
    if (
        isinstance(eval_dataset, str)
        and isinstance(self.eval_dataset, dict)
        and eval_dataset in self.eval_dataset
    ):
        iterable_dataset = self.eval_dataset[eval_dataset]
    elif isinstance(eval_dataset, _IterableFromIterator):
        iterable_dataset = eval_dataset

    if isinstance(iterable_dataset, _IterableFromIterator):
        dataset = RayTorchIterableDataset(iterable_dataset)
        return DataLoader(dataset, batch_size=1, collate_fn=lambda x: x[0])

Comment on lines +90 to +107
        if config["use_dict_eval_datasets"]:
            eval_ds_shard_1 = ray.train.get_dataset_shard("eval_1")
            eval_ds_shard_2 = ray.train.get_dataset_shard("eval_2")

            eval_dataset = {
                "eval_1": eval_ds_shard_1.iter_torch_batches(
                    batch_size=BATCH_SIZE_PER_WORKER
                ),
                "eval_2": eval_ds_shard_2.iter_torch_batches(
                    batch_size=BATCH_SIZE_PER_WORKER
                ),
            }
        else:
            eval_ds_shard = ray.train.get_dataset_shard("eval")

            eval_dataset = eval_ds_shard.iter_torch_batches(
                batch_size=BATCH_SIZE_PER_WORKER
            )
Contributor


medium

There's some code duplication in how eval_dataset is created. You could extract the common iter_torch_batches call into a helper function to make this more concise and maintainable.

        def get_eval_batch_iter(shard_name: str):
            shard = ray.train.get_dataset_shard(shard_name)
            return shard.iter_torch_batches(batch_size=BATCH_SIZE_PER_WORKER)

        if config["use_dict_eval_datasets"]:
            eval_dataset = {
                "eval_1": get_eval_batch_iter("eval_1"),
                "eval_2": get_eval_batch_iter("eval_2"),
            }
        else:
            eval_dataset = get_eval_batch_iter("eval")

Comment on lines +307 to +406
def test_e2e_dict_eval_ray_data(ray_start_6_cpus_2_gpus, config_id):
    def train_func(config):
        # Datasets
        if config["use_ray_data"]:
            train_ds_shard = ray.train.get_dataset_shard("train")
            eval_ds_shard_1 = ray.train.get_dataset_shard("eval_1")
            eval_ds_shard_2 = ray.train.get_dataset_shard("eval_2")

            train_dataset = train_ds_shard.iter_torch_batches(
                batch_size=BATCH_SIZE_PER_WORKER
            )
            eval_dataset = {
                "eval_1": eval_ds_shard_1.iter_torch_batches(
                    batch_size=BATCH_SIZE_PER_WORKER
                ),
                "eval_2": eval_ds_shard_2.iter_torch_batches(
                    batch_size=BATCH_SIZE_PER_WORKER
                ),
            }
        else:
            train_df = pd.read_json(train_data)
            validation_df = pd.read_json(validation_data)

            train_dataset = Dataset.from_pandas(train_df)
            eval_dataset = Dataset.from_pandas(validation_df)

        # Model
        model_config = AutoConfig.from_pretrained(MODEL_NAME)
        model = AutoModelForCausalLM.from_config(model_config)

        # HF Transformers Trainer
        training_args = TrainingArguments(
            f"{MODEL_NAME}-wikitext2",
            evaluation_strategy=config["evaluation_strategy"],
            logging_strategy=config["logging_strategy"],
            save_strategy=config["save_strategy"],
            eval_steps=config["eval_steps"],
            save_steps=config["save_steps"],
            logging_steps=config["logging_steps"],
            num_train_epochs=config.get("num_train_epochs", MAX_EPOCHS),
            max_steps=config.get("max_steps", -1),
            learning_rate=config.get("learning_rate", 2e-5),
            per_device_train_batch_size=BATCH_SIZE_PER_WORKER,
            per_device_eval_batch_size=BATCH_SIZE_PER_WORKER,
            weight_decay=0.01,
            disable_tqdm=True,
            no_cuda=config["no_cuda"],
            report_to="none",
        )
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
        )

        # Report to Ray Train
        trainer.add_callback(RayTrainReportCallback())
        trainer = prepare_trainer(trainer)

        # Start Training
        trainer.train()

    train_loop_config = CONFIGURATIONS[config_id]

    # Must specify `max_steps` for Iterable Dataset
    train_loop_config["use_ray_data"] = True
    train_loop_config["max_steps"] = MAX_STEPS

    # Calculate the num of Ray training iterations
    num_iterations = MAX_STEPS // train_loop_config["save_steps"]

    train_df = pd.read_json(train_data)
    validation_df = pd.read_json(validation_data)

    ray_train_ds = ray.data.from_pandas(train_df)
    ray_eval_ds_1 = ray.data.from_pandas(validation_df)
    ray_eval_ds_2 = ray.data.from_pandas(validation_df)

    use_gpu = not train_loop_config["no_cuda"]

    trainer = TorchTrainer(
        train_func,
        train_loop_config=train_loop_config,
        scaling_config=ScalingConfig(num_workers=NUM_WORKERS, use_gpu=use_gpu),
        datasets={
            "train": ray_train_ds,
            "eval_1": ray_eval_ds_1,
            "eval_2": ray_eval_ds_2,
        },
    )
    result = trainer.fit()

    assert result.metrics["step"] == MAX_STEPS
    assert result.checkpoint
    assert isinstance(result.checkpoint, Checkpoint)
    assert len(result.best_checkpoints) == num_iterations
    assert "eval_eval_1_loss" in result.metrics
    assert "eval_eval_2_loss" in result.metrics
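
The doubled prefix in the asserted metric names comes from how transformers names metrics when eval_dataset is a dict: evaluation runs once per entry with the dataset key folded into the "eval" prefix, so "loss" for the split keyed "eval_1" surfaces as "eval_eval_1_loss". A simplified sketch of that naming (not the actual transformers code):

```python
def prefix_metrics(metrics, dataset_name=None, metric_key_prefix="eval"):
    """Simplified stand-in for transformers' metric naming: with a dict
    of eval datasets the prefix becomes 'eval_<name>', so 'loss' ends up
    as 'eval_eval_1_loss' for the dataset keyed 'eval_1'."""
    if dataset_name is not None:
        metric_key_prefix = f"{metric_key_prefix}_{dataset_name}"
    return {f"{metric_key_prefix}_{k}": v for k, v in metrics.items()}


print(prefix_metrics({"loss": 0.42}, dataset_name="eval_1"))
# {'eval_eval_1_loss': 0.42}
print(prefix_metrics({"loss": 1.0}))
# {'eval_loss': 1.0}
```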

Contributor


medium

This new test function test_e2e_dict_eval_ray_data duplicates a lot of code from test_e2e_hf_data and test_e2e_ray_data in the same file, especially the inlined train_func. This makes the test suite harder to maintain. Consider refactoring by extracting the common logic into a shared helper function. For example, the model and trainer setup logic is identical across these tests and could be moved to a helper.
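
One possible shape for that refactor, purely illustrative: the shared setup is injected with a callable for the one part that varies. Every name here (run_e2e_training, the stand-in datasets) is made up for the sketch; the real helper would build the model, TrainingArguments, and Trainer exactly as the tests do today:

```python
def run_e2e_training(config, build_eval_dataset):
    """Hypothetical shared helper: the model/trainer setup each test
    duplicates would live here; only the eval-dataset construction
    differs per test, so it is passed in as a callable."""
    train_dataset = ["batch_0", "batch_1"]  # stand-in for real shards
    eval_dataset = build_eval_dataset(config)
    return {"train": train_dataset, "eval": eval_dataset}


# The dict-of-datasets and single-dataset tests then differ only here:
dict_out = run_e2e_training({}, lambda cfg: {"eval_1": ["e1"], "eval_2": ["e2"]})
single_out = run_e2e_training({}, lambda cfg: ["e"])
print(sorted(dict_out["eval"]))  # ['eval_1', 'eval_2']
print(single_out["eval"])        # ['e']
```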

Contributor

@xinyuangui2 xinyuangui2 left a comment


Review is already done in #56046

@matthewdeng matthewdeng enabled auto-merge (squash) September 12, 2025 05:35
@github-actions github-actions bot added the go add ONLY when ready to merge, run all tests label Sep 12, 2025
@matthewdeng matthewdeng merged commit e9670ed into ray-project:master Sep 12, 2025
6 checks passed
ZacAttack pushed a commit to ZacAttack/ray that referenced this pull request Sep 24, 2025
…t#56484)

To provide support for evaluating model via HuggingFace trainer using a
dictionary of Ray datasets. This is to align with the `eval_dataset`
argument type in `get_eval_dataloader` in
[transformers](https://github.com/huggingface/transformers/blob/v4.56.1/src/transformers/trainer.py#L1196).


Signed-off-by: yenhong.wong <yenhong.wong@grabtaxi.com>
Co-authored-by: yenhong.wong <yenhong.wong@grabtaxi.com>
Signed-off-by: zac <zac@anyscale.com>
marcostephan pushed a commit to marcostephan/ray that referenced this pull request Sep 24, 2025
dstrodtman pushed a commit that referenced this pull request Oct 6, 2025
justinyeh1995 pushed a commit to justinyeh1995/ray that referenced this pull request Oct 20, 2025
landscapepainter pushed a commit to landscapepainter/ray that referenced this pull request Nov 17, 2025

Labels

go add ONLY when ready to merge, run all tests

3 participants