[Task] Standardize the model output format.  #505

@sararb

Description


Currently, the T4Rec model outputs can be either:

  • A tensor representing the predictions returned by a single task.
  • A dictionary of 3 tensors ({'loss': ..., 'labels': ..., 'predictions': ...}) representing the output of NextItemPredictionTask with hf_format=True.
  • A dictionary where the key is the task name and the value is either the prediction tensor or the 3-tensor dictionary {'loss': ..., 'labels': ..., 'predictions': ...}.

The objective of this task is to unify the output format returned by the Model and make sure the right information is returned based on whether we are in training, testing, or inference mode.
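To see why unification matters: with three possible shapes, any downstream consumer currently needs dispatch logic along these lines. This is a hypothetical helper, not T4Rec code, and plain Python lists stand in for tensors:

```python
def extract_predictions(model_output, task_name=None):
    # Case 1: a bare prediction tensor from a single task.
    if not isinstance(model_output, dict):
        return model_output
    # Case 2: the hf_format dict {'loss', 'labels', 'predictions'}.
    if "predictions" in model_output:
        return model_output["predictions"]
    # Case 3: a per-task dict; each value may itself be either form.
    return extract_predictions(model_output[task_name])

out1 = extract_predictions([0.2, 0.8])
out2 = extract_predictions({"loss": 0.1, "labels": [1], "predictions": [0.9]})
out3 = extract_predictions({"next-item": [0.3, 0.7]}, task_name="next-item")
```

A single output convention removes the need for this kind of branching in callers.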

Additional context

Steps for refactoring:

Part 1 (three tasks)
  • Set the model to inference mode by default + ensure the training script is not impacted + update failing unit tests [Done]

  • Remove hf_format from the NextItemPredictionTask + ensure training / inference are not impacted + update failing unit tests [Done]

  • Replace ignore_masking with a testing flag + ensure training / inference are not impacted + update failing unit tests [Done]

Part 2 (four tasks)
  • Update base PredictionTask class: [Done]
    ⁃ Add the targets argument to the forward method
    ⁃ Move compute_loss() inside the forward method
    ⁃ Return output (dict with the three keys or torch.Tensor) based on training and testing flags
    ⁃ Update calculate_metrics to pass the targets to the forward call

  • Update Head and Model classes to support the new convention in their forward method call + calculate_metrics [Done]

  • Update the fit method [In progress]: the loss is now computed inside the forward call + add a compute_metrics=True flag to control whether metrics are computed during training. Replace the compute_loss call loss = self.compute_loss(x, y) with:

outputs = self(x, y, training=True)
loss = outputs['loss']
if compute_metrics:
    self.calculate_metrics(outputs['predictions'], outputs['labels'], mode='train', forward=False, call_body=False)
  • Update the failing unit tests [Open] [Blocked by Part 1]
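The refactored convention above (targets passed to forward, loss computed inside it, flag-dependent return type) can be sketched as follows. Class and argument names follow the bullets, but the scoring and loss logic are hypothetical stand-ins, with plain Python values in place of tensors:

```python
class PredictionTask:
    def compute_loss(self, predictions, targets):
        # Placeholder loss: mean squared error over paired values.
        return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)

    def forward(self, inputs, targets=None, training=False, testing=False):
        predictions = inputs  # stand-in for the task's real scoring logic
        if training or testing:
            # Loss is computed inside forward, so fit() no longer needs
            # a separate compute_loss call.
            loss = self.compute_loss(predictions, targets)
            return {"loss": loss, "labels": targets, "predictions": predictions}
        # Inference mode (the default): just the prediction tensor.
        return predictions

task = PredictionTask()
train_out = task.forward([0.5, 1.0], targets=[0.0, 1.0], training=True)
infer_out = task.forward([0.5, 1.0])
```

With this shape, the fit snippet above reads the loss and metric inputs straight out of the forward output.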
Part 3 (three tasks)
  • Ensure the testing/training flags inside predict_step and eval_loop work correctly. [Done]

  • Move the top-k logic (specific to the next item prediction task) into a method of the Trainer class and call it inside eval_loop only when the item-prediction task is used. [Open]

  • Add a unit test for testing the Trainer with binary and/or regression tasks. [In progress]
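The top-k relocation could look roughly like this. The Trainer attributes and helper name are illustrative assumptions, not the actual T4Rec API, and a plain-Python sort stands in for torch.topk:

```python
def topk_predictions(scores, k):
    # Keep the k highest-scoring item ids per row
    # (torch.topk in practice; a stable sort here for illustration).
    return [sorted(range(len(row)), key=row.__getitem__, reverse=True)[:k]
            for row in scores]

class Trainer:
    def __init__(self, is_item_prediction, top_k=None):
        self.is_item_prediction = is_item_prediction
        self.top_k = top_k

    def eval_loop(self, scores):
        # Apply top-k only when the item-prediction task is used;
        # binary/regression task outputs pass through untouched.
        if self.is_item_prediction and self.top_k is not None:
            return topk_predictions(scores, self.top_k)
        return scores

trainer = Trainer(is_item_prediction=True, top_k=2)
ids = trainer.eval_loop([[0.1, 0.7, 0.2], [0.9, 0.05, 0.05]])
```

Keeping the filter on the Trainer (rather than in the task) is what lets the eval loop decide per task type whether to apply it.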
