Part 3: Add support for training regression and binary classification tasks to the Trainer class #554

@nzarif

Description

Tasks included in this part:

  • Analyze and verify the logical correctness of calling calculate_metrics from Model.fit() at this line.

  • Add a simple script that evaluates metrics and results when using the Trainer for BinaryClassification and Regression, so that it can be used for CI purposes, like the one currently used for NextItemPrediction.

  • Version control fixes:
    _ Rebase onto refactor_part2 and main to get the synthetic music data tests for BC and regression.
    _ Pull the changes from the fix-inference branch to get the latest predict_step and the updated CI tests.

  • Update the testing/training flags inside predict_step and eval_loop, and test that they work correctly.

  • Move the top-k logic (specific to the next-item prediction task) into a method of the Trainer class, and call it inside eval_loop only when the next-item prediction task is used.

  • Customize training_step to get the targets for the prediction tasks:
    _ Iterate over the dataloader in training_loop using inputs that have the labels/targets added to them.
    _ Call training_step inside training_loop and make sure the inputs, including the labels, are passed to it.
    _ Call compute_loss() from training_step() and make sure the labels arrive there together with the inputs.
    _ Make sure model() is called with the correct arguments.

  • Use HF nested_detach() inside T4Rec.Trainer.prediction_step() to handle cases where we have a dict instead of a Tensor. This ensures that corresponding inputs and targets are placed on the same device when using multiple GPUs.

  • Add a unit test for training the Trainer with binary classification and/or regression tasks.

  • Document the support for BC and regression tasks in the Trainer class documentation.
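For the CI evaluation script task, a minimal sketch of the pass/fail check is below. It assumes the script already has a flat dict of evaluation metrics; the function names and the metric names used in the example (`auc`, `rmse`) are hypothetical placeholders, and real thresholds would have to come from known-good runs.

```python
import sys
from typing import Dict, List, Optional


def check_metrics(
    metrics: Dict[str, float],
    min_thresholds: Dict[str, float],
    max_thresholds: Dict[str, float],
) -> List[str]:
    """Return a failure message for every metric that is missing or out of bounds."""
    failures = []
    # Higher-is-better metrics (e.g. a BC metric such as AUC) must stay above a floor.
    for name, floor in min_thresholds.items():
        value: Optional[float] = metrics.get(name)
        if value is None or value < floor:
            failures.append(f"{name}={value} is below the minimum {floor}")
    # Lower-is-better metrics (e.g. a regression error such as RMSE) must stay below a ceiling.
    for name, ceiling in max_thresholds.items():
        value = metrics.get(name)
        if value is None or value > ceiling:
            failures.append(f"{name}={value} is above the maximum {ceiling}")
    return failures


def main(
    metrics: Dict[str, float],
    min_thresholds: Dict[str, float],
    max_thresholds: Dict[str, float],
) -> int:
    """Print failures to stderr and return a process exit code (0 = pass, 1 = fail)."""
    failures = check_metrics(metrics, min_thresholds, max_thresholds)
    for failure in failures:
        print(f"FAIL: {failure}", file=sys.stderr)
    return 1 if failures else 0
```

A CI job would call `sys.exit(main(metrics, {"auc": ...}, {"rmse": ...}))` after training, so a metric regression fails the build. Treating a missing metric as a failure keeps a silently broken evaluation loop from passing.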
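The top-k refactor could look roughly like the plain-Python sketch below, which shows only the control flow: top-k lives in its own method and `eval_loop` calls it only for next-item prediction. `SketchTrainer`, `compute_top_k`, `task_type` and the `"next_item"` value are hypothetical names; the real Trainer works on tensors, where `torch.topk` would take the place of `heapq.nlargest`.

```python
import heapq
from typing import Dict, List, Optional


class SketchTrainer:
    """Hypothetical stand-in for the T4Rec Trainer, used only to show the control flow."""

    def __init__(self, task_type: str, k: int = 3):
        # Hypothetical flag; the real Trainer would inspect the model's prediction task(s).
        self.task_type = task_type
        self.k = k

    def compute_top_k(self, scores: List[float]) -> List[int]:
        # Indices of the k highest item scores, best first (next-item prediction only).
        return heapq.nlargest(self.k, range(len(scores)), key=lambda i: scores[i])

    def eval_loop(self, batches: List[List[float]]) -> Dict[str, Optional[List[List[int]]]]:
        # Collect top-k results only when the next-item prediction task is used.
        top_k: Optional[List[List[int]]] = [] if self.task_type == "next_item" else None
        for scores in batches:
            # ... shared evaluation logic (loss, metrics) would go here ...
            if top_k is not None:
                top_k.append(self.compute_top_k(scores))
        return {"top_k": top_k}
```

For example, `SketchTrainer("next_item", k=2).eval_loop([[0.1, 0.9, 0.5]])` returns `{"top_k": [[1, 2]]}`, while a binary classification trainer skips the computation and returns `{"top_k": None}`.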
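The label flow described in the training_step sub-tasks (training_loop → training_step → compute_loss → model()) can be sketched without any framework. The function names mirror the issue, but the signatures, the `features`/`targets` keys and the toy linear model are hypothetical; the real methods operate on PyTorch tensors and also handle device placement and backpropagation.

```python
from typing import Any, Callable, Dict, Iterable, List

# A batch keeps the labels/targets next to the features in one dict (hypothetical keys).
Batch = Dict[str, List[float]]


def model(features: List[float], targets: List[float], weight: float = 2.0) -> Dict[str, Any]:
    """Hypothetical regression model: predicts weight * x and returns predictions plus MSE loss."""
    predictions = [weight * x for x in features]
    loss = sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)
    return {"predictions": predictions, "loss": loss}


def compute_loss(model_fn: Callable[..., Dict[str, Any]], inputs: Batch) -> float:
    # The labels arrive here together with the inputs, so the model can compute the loss.
    outputs = model_fn(features=inputs["features"], targets=inputs["targets"])
    return outputs["loss"]


def training_step(model_fn: Callable[..., Dict[str, Any]], inputs: Batch) -> float:
    # A real trainer would also move inputs to the device and call loss.backward() here.
    return compute_loss(model_fn, inputs)


def training_loop(model_fn: Callable[..., Dict[str, Any]], dataloader: Iterable[Batch]) -> List[float]:
    losses = []
    for inputs in dataloader:
        # Each batch already contains the targets, so nothing is dropped on the way down.
        losses.append(training_step(model_fn, inputs))
    return losses
```

With `training_loop(model, [{"features": [1.0, 2.0], "targets": [2.0, 5.0]}])`, the model predicts `[2.0, 4.0]` and the returned loss list is `[0.5]`. Keeping the targets inside the batch dict means no intermediate method needs an extra `labels` argument.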
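Calling `.detach()` directly fails when prediction_step receives a dict of outputs (e.g. one entry per prediction task). Recent versions of HF's `nested_detach` (in `transformers.trainer_pt_utils`) solve this by recursing through lists, tuples and mappings and detaching each tensor it reaches. The sketch below reproduces that recursion pattern in plain Python so it runs without PyTorch; `FakeTensor` and `nested_detach_sketch` are hypothetical stand-ins, and the real Trainer would import the HF helper instead.

```python
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any


@dataclass
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor that only records whether it was detached."""
    value: float
    requires_grad: bool = True

    def detach(self) -> "FakeTensor":
        # Like torch.Tensor.detach(): same data, no longer tracked by autograd.
        return FakeTensor(self.value, requires_grad=False)


def nested_detach_sketch(tensors: Any) -> Any:
    """Detach every tensor in a (possibly nested) list, tuple or dict, keeping container types."""
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(nested_detach_sketch(t) for t in tensors)
    if isinstance(tensors, Mapping):
        return type(tensors)({k: nested_detach_sketch(v) for k, v in tensors.items()})
    # Leaf: assume it is a tensor-like object with a detach() method.
    return tensors.detach()
```

For instance, `nested_detach_sketch({"click": FakeTensor(0.8), "rating": [FakeTensor(4.5)]})` returns a dict with the same keys and a list under `"rating"`, where every leaf has `requires_grad=False`.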