[BUG] Transformer predict does not use path from parameters #570

@bschifferer

Bug description

trainer.predict(test_paths) does not evaluate on the files passed to the function.

predict calls get_test_dataloader with the test_dataset argument:

trainer.predict(
    test_dataset: torch.utils.data.dataset.Dataset,
    ignore_keys: Union[List[str], NoneType] = None,
    metric_key_prefix: str = 'test',
) -> transformers.trainer_utils.PredictionOutput
        self._memory_tracker.start()
        test_dataloader = self.get_test_dataloader(test_dataset) 

get_test_dataloader then appears to initialize the test dataloader from self.test_dataset_or_path rather than from the test_dataset argument,
see https://github.com/NVIDIA-Merlin/Transformers4Rec/blob/main/transformers4rec/torch/trainer.py#L204-L209

        test_dataset = test_dataset if test_dataset is not None else self.test_dataset
        assert self.schema is not None, "schema is required to generate Test Dataloader"
        return T4RecDataLoader.parse(self.args.data_loader_engine).from_schema(
            self.schema,
            self.test_dataset_or_path, ### <- This is incorrect?
            self.args.per_device_eval_batch_size,
            max_sequence_length=self.args.max_sequence_length,
            drop_last=self.args.dataloader_drop_last,
            shuffle=False,
            shuffle_buffer_size=self.args.shuffle_buffer_size,
        )
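
The pattern above can be reproduced in isolation. The following is a minimal sketch with toy stand-ins, not the Transformers4Rec code: `ToyTrainer`, `build_loader`, and both method names are hypothetical, and the "suggested" variant is only one possible fix, assuming the intent is for an explicit argument to take precedence over the stored attribute.

```python
def build_loader(source):
    # Hypothetical stand-in for the real dataloader factory; it only records its input.
    return {"source": source}


class ToyTrainer:
    """Toy model of the reported pattern; not the real Trainer class."""

    def __init__(self, test_dataset_or_path=None):
        self.test_dataset_or_path = test_dataset_or_path

    def get_test_dataloader_reported(self, test_dataset=None):
        # The argument is resolved here ...
        test_dataset = (
            test_dataset if test_dataset is not None else self.test_dataset_or_path
        )
        # ... but the stored attribute is passed on, so the argument is ignored.
        return build_loader(self.test_dataset_or_path)

    def get_test_dataloader_suggested(self, test_dataset=None):
        test_dataset = (
            test_dataset if test_dataset is not None else self.test_dataset_or_path
        )
        # Pass the resolved value, so an explicit argument takes precedence.
        return build_loader(test_dataset)


trainer = ToyTrainer(test_dataset_or_path="configured.parquet")
reported = trainer.get_test_dataloader_reported("requested.parquet")
suggested = trainer.get_test_dataloader_suggested("requested.parquet")
fallback = trainer.get_test_dataloader_suggested()
```

Here `reported` is built from the configured path even though a different one was requested, while `suggested` uses the requested path and `fallback` still falls back to the configured one when no argument is given.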

A workaround could be to set the attribute before calling predict:

trainer.test_dataset_or_path = paths
prediction = trainer.predict(paths)
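
Since the workaround mutates the trainer, it may be worth restoring the previous value afterwards. A small sketch of that idea follows; `override_test_path` is a hypothetical helper (not part of Transformers4Rec), and it only assumes the trainer exposes a `test_dataset_or_path` attribute as shown above. The `SimpleNamespace` object stands in for a real trainer purely for illustration.

```python
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def override_test_path(trainer, paths):
    """Temporarily point trainer.test_dataset_or_path at `paths` (hypothetical helper)."""
    previous = trainer.test_dataset_or_path
    trainer.test_dataset_or_path = paths
    try:
        yield trainer
    finally:
        # Restore the original value even if prediction raises.
        trainer.test_dataset_or_path = previous


# Illustration with a stand-in object instead of a real trainer.
fake_trainer = SimpleNamespace(test_dataset_or_path="configured.parquet")
with override_test_path(fake_trainer, "requested.parquet") as t:
    seen_inside = t.test_dataset_or_path
seen_after = fake_trainer.test_dataset_or_path
```

With a real trainer, the body of the `with` block would be `prediction = t.predict(paths)`.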
