
Part 2: (NEW PREDICTION TASK API) Refactor all prediction tasks to follow the convention required by the Trainer class, i.e. compute the loss inside the forward method and return a dict with the keys "loss", "labels" and "predictions" in training/evaluation mode. #544

@nzarif

Description

This part of the refactoring consists of four steps:

  • Update the base PredictionTask class:
    ⁃ Add a targets argument to the forward method
    ⁃ Move the compute_loss() logic inside the forward method
    ⁃ Return either a dict with the three keys ("loss", "labels", "predictions") in training/evaluation mode or a plain torch.Tensor otherwise, depending on the training and testing flags
    ⁃ Update calculate_metrics() to pass the targets to the forward call

  • Update the Head and Model classes to support the new convention in their forward() and calculate_metrics() methods

  • Update the fit method [Done]: the loss is now computed inside the forward call, and a new compute_metrics=True flag controls whether metrics are computed during training. Replace the call loss = self.compute_loss(x, y) with:

outputs = self(x, y, training=True)
loss = outputs['loss']
if compute_metrics:
    self.calculate_metrics(outputs['predictions'], outputs['labels'])
  • Update the failing unit tests
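
A minimal sketch of step 1, assuming PyTorch. SketchBinaryTask is a hypothetical stand-in rather than the actual PredictionTask class; the linear layer, BCE loss and accuracy metric are illustrative choices. It shows the targets argument, the loss computed inside forward(), a return type that depends on the training/testing flags, and a calculate_metrics() that passes the targets through forward():

```python
from typing import Dict, Optional, Union

import torch
from torch import nn


class SketchBinaryTask(nn.Module):
    """Hypothetical binary task following the new convention (illustrative only)."""

    def __init__(self, input_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, 1)
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(
        self,
        inputs: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
        training: bool = False,
        testing: bool = False,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        logits = self.linear(inputs).squeeze(-1)
        predictions = torch.sigmoid(logits)
        if not (training or testing):
            # Inference mode: return only the prediction tensor.
            return predictions
        if targets is None:
            raise ValueError("targets are required when training or testing")
        # The loss is computed here instead of in a separate compute_loss() call.
        loss = self.loss_fn(logits, targets.float())
        return {"loss": loss, "labels": targets, "predictions": predictions}

    def calculate_metrics(
        self, inputs: torch.Tensor, targets: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        # The targets are passed to forward(), which then returns the dict form.
        outputs = self(inputs, targets, testing=True)
        predicted = (outputs["predictions"] > 0.5).long()
        accuracy = (predicted == outputs["labels"].long()).float().mean()
        return {"loss": outputs["loss"], "accuracy": accuracy}


torch.manual_seed(0)
task = SketchBinaryTask(input_dim=4)
x = torch.randn(8, 4)
y = torch.randint(0, 2, (8,))
train_out = task(x, y, training=True)  # dict with "loss", "labels", "predictions"
inference_out = task(x)                # plain tensor of shape (8,)
metrics = task.calculate_metrics(x, y)
```

Keeping the plain-tensor return for inference means callers that only need predictions are unaffected, while the Trainer gets the dict it expects.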
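
One plausible way to carry the same convention up to the Head level (step 2), again assuming PyTorch. SketchRegressionTask and SketchHead are hypothetical: the head sums the per-task losses without weights and groups labels and predictions by task name. The real Head and Model classes may combine tasks differently, for example with loss weights.

```python
from typing import Dict, Optional

import torch
from torch import nn


class SketchRegressionTask(nn.Module):
    """Hypothetical regression task using forward(inputs, targets, training, testing)."""

    def __init__(self, input_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, 1)

    def forward(self, inputs, targets=None, training=False, testing=False):
        predictions = self.linear(inputs).squeeze(-1)
        if not (training or testing):
            return predictions
        loss = nn.functional.mse_loss(predictions, targets)
        return {"loss": loss, "labels": targets, "predictions": predictions}


class SketchHead(nn.Module):
    """Hypothetical head that delegates to its tasks and aggregates their outputs."""

    def __init__(self, tasks: Dict[str, nn.Module]):
        super().__init__()
        self.tasks = nn.ModuleDict(tasks)

    def forward(
        self,
        inputs: torch.Tensor,
        targets: Optional[Dict[str, torch.Tensor]] = None,
        training: bool = False,
        testing: bool = False,
    ):
        if not (training or testing):
            return {name: task(inputs) for name, task in self.tasks.items()}
        outputs = {
            name: task(inputs, targets[name], training=training, testing=testing)
            for name, task in self.tasks.items()
        }
        # Total loss is the unweighted sum of per-task losses (assumption of this sketch).
        return {
            "loss": torch.stack([out["loss"] for out in outputs.values()]).sum(),
            "labels": {name: out["labels"] for name, out in outputs.items()},
            "predictions": {name: out["predictions"] for name, out in outputs.items()},
        }

    def calculate_metrics(self, inputs, targets):
        # Targets go through forward(); metric here is mean absolute error per task.
        outputs = self(inputs, targets, testing=True)
        return {
            name: (outputs["predictions"][name] - outputs["labels"][name]).abs().mean()
            for name in self.tasks
        }


torch.manual_seed(0)
head = SketchHead({"rating": SketchRegressionTask(4), "watch_time": SketchRegressionTask(4)})
x = torch.randn(8, 4)
y = {"rating": torch.randn(8), "watch_time": torch.randn(8)}
head_out = head(x, y, training=True)
```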
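
The updated fit method (step 3) can be sketched as a plain PyTorch training loop. SketchModel, fit_sketch and the toy regression data are hypothetical; only the sequence outputs = self(x, y, training=True), loss = outputs['loss'], and the optional calculate_metrics(predictions, labels) call come from the snippet above. In this sketch calculate_metrics() takes predictions and labels directly and returns the mean absolute error.

```python
from typing import Dict, List

import torch
from torch import nn


class SketchModel(nn.Module):
    """Hypothetical model whose forward() computes the loss, as in the new convention."""

    def __init__(self, input_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, 1)

    def forward(self, inputs, targets=None, training=False, testing=False):
        predictions = self.linear(inputs).squeeze(-1)
        if not (training or testing):
            return predictions
        loss = nn.functional.mse_loss(predictions, targets)
        return {"loss": loss, "labels": targets, "predictions": predictions}

    def calculate_metrics(self, predictions, labels) -> Dict[str, float]:
        return {"mae": (predictions - labels).abs().mean().item()}

    def fit_sketch(self, batches, optimizer, epochs=1, compute_metrics=True) -> List[Dict[str, float]]:
        history = []
        self.train()
        for _ in range(epochs):
            for x, y in batches:
                optimizer.zero_grad()
                # Replaces the old `loss = self.compute_loss(x, y)` call.
                outputs = self(x, y, training=True)
                loss = outputs["loss"]
                loss.backward()
                optimizer.step()
                step = {"loss": loss.item()}
                if compute_metrics:
                    # Metrics reuse the predictions/labels already returned by forward().
                    step.update(
                        self.calculate_metrics(outputs["predictions"].detach(), outputs["labels"])
                    )
                history.append(step)
        return history


torch.manual_seed(0)
model = SketchModel(input_dim=3)
true_w = torch.tensor([1.0, -2.0, 0.5])
batches = []
for _ in range(10):
    features = torch.randn(16, 3)
    batches.append((features, features @ true_w))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
history = model.fit_sketch(batches, optimizer, epochs=20)
no_metrics = model.fit_sketch(batches, optimizer, epochs=1, compute_metrics=False)
```

Because metrics are computed from the dict that forward() already returned, turning compute_metrics off skips only the metric step and never a second forward pass.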
