Skip to content

Remove torch.squeeze() step from the model's forward method. #659

Merged
sararb merged 1 commit intomainfrom
fix_inference
Mar 28, 2023
Merged

Remove torch.squeeze() step from the model's forward method. #659
sararb merged 1 commit intomainfrom
fix_inference

Conversation

@sararb
Copy link
Copy Markdown
Contributor

@sararb sararb commented Mar 27, 2023

  • The new data loader is now returning scalar inputs as a 1-D tensor instead of a 2-D so the torch.squeeze() op inside the model's forward method is no longer needed.
  • This torch.squeeze() op was also raising an issue at inference when we pass a sequence with one element only. In fact, by calling torch.squeeze() on all inputs we lose the information to separate between scalar features and list features with only one element. The new data loader is ensuring that list input with one element are kept as 2-D.
  • I removed the torch.squeeze() op and added a check for applying the model on input sequences with one element, at inference.

@sararb sararb added bug Something isn't working area/inference area/pytorch labels Mar 27, 2023
@sararb sararb added this to the Merlin 23.03 milestone Mar 27, 2023
@sararb sararb requested review from angmc and edknv March 27, 2023 21:25
@sararb sararb self-assigned this Mar 27, 2023
@github-actions
Copy link
Copy Markdown

@sararb sararb merged commit f5a0d42 into main Mar 28, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants