Fix the predictions returned by Trainer.predict(..) #641

Merged
rnyak merged 4 commits into main from fix/trainer
Mar 15, 2023

@sararb sararb commented Mar 10, 2023

Fixes #636

Goals ⚽

  • Fix the predictions returned by Trainer.predict().
  • Ensure predict_top_k can only be applied to NextItemPredictionTask.
  • Fix the output shapes of the model's predictions for the sequential binary classification and regression tasks.

Implementation Details 🚧

Testing Details 🔍

  • Add test_sequential_binary_classification_model to test a sequential model trained with a per-item binary classification task.
  • Check the shapes of the predictions in test_trainer_music_streaming.
  • Add test_trainer_trop_k_with_wrong_task to ensure the binary classification task cannot be used with predict_top_k > 0.

@sararb sararb added the bug and area/pytorch labels Mar 10, 2023
@sararb sararb added this to the Merlin 23.03 milestone Mar 10, 2023
@sararb sararb requested a review from rnyak March 10, 2023 23:22
@sararb sararb self-assigned this Mar 10, 2023
- def forward(
-     self, inputs: TensorOrTabularData, targets=None, training=False, testing=False, **kwargs
- ):
+ def forward(self, inputs: TabularData, targets=None, training=False, testing=False, **kwargs):

T4Rec only supports a dictionary of tensors (i.e. TabularData) as inputs.
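
As a rough illustration of that distinction (a sketch, not the library's own definitions): the alias names below mirror the ones in the diff, but how they are defined here, and the helper `check_tabular`, are assumptions made for the example.

```python
from typing import Dict, Union

import torch

# Assumed definitions, written out for illustration only.
TabularData = Dict[str, torch.Tensor]
TensorOrTabularData = Union[torch.Tensor, TabularData]


def check_tabular(inputs) -> TabularData:
    """Hypothetical helper: accept only a dict mapping feature names to tensors."""
    if not isinstance(inputs, dict):
        raise TypeError(f"expected a dict of tensors, got {type(inputs).__name__}")
    for name, value in inputs.items():
        if not isinstance(value, torch.Tensor):
            raise TypeError(f"feature {name!r} is not a tensor")
    return inputs


# A dict of tensors is accepted; a bare tensor would raise TypeError.
batch = {"item_id": torch.tensor([[1, 2, 3]]), "category": torch.tensor([[4, 5, 6]])}
check_tabular(batch)
```

Narrowing the annotation from the union to the dictionary type documents the same constraint at the signature level, without a runtime check.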

  torch.nn.Linear(input_size[-1], 1, bias=False),
  torch.nn.Sigmoid(),
- LambdaModule(lambda x: x.view(-1)),
+ LambdaModule(lambda x: torch.squeeze(x, -1)),

@sararb sararb Mar 10, 2023

In the case of sequential binary classification, we need to keep the second dimension (i.e., the sequence length).
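
To make the shape difference concrete, here is a minimal sketch in plain PyTorch (not the library's code). It assumes the per-item head outputs a tensor of shape (batch_size, sequence_length, 1) and the non-sequential head outputs (batch_size, 1); the sizes are made up for the example.

```python
import torch

batch_size, seq_len = 4, 20  # illustrative sizes, not taken from the PR

# Assumed output of Linear(..., 1) followed by Sigmoid for a per-item task.
per_item_scores = torch.rand(batch_size, seq_len, 1)

# Previous behaviour: view(-1) merges the batch and sequence axes.
flattened = per_item_scores.view(-1)
# New behaviour: squeeze(-1) drops only the trailing size-1 axis.
squeezed = torch.squeeze(per_item_scores, -1)

print(flattened.shape)  # torch.Size([80])
print(squeezed.shape)   # torch.Size([4, 20])

# For a non-sequential head both calls give one score per example.
per_session_scores = torch.rand(batch_size, 1)
print(torch.squeeze(per_session_scores, -1).shape)  # torch.Size([4])
```

With `squeeze`, row i of the result still holds the per-item scores of example i, so the predictions can be aligned with per-item targets.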

rnyak commented Mar 14, 2023

rerun tests

@rnyak rnyak merged commit abc7a49 into main Mar 15, 2023
@rnyak rnyak deleted the fix/trainer branch March 15, 2023 17:55
Development

Successfully merging this pull request may close these issues.

[BUG] BinaryClassificationTask prediction does not return one score per session with summary_type='mean'
