Add support for ragged inputs to model#666

Merged
oliverholworthy merged 15 commits into NVIDIA-Merlin:main from oliverholworthy:pad-ragged-inputs-in-model
Apr 7, 2023

Conversation

@oliverholworthy (Contributor) commented Apr 5, 2023

Part of NVIDIA-Merlin/Merlin#255

Goals ⚽

  • Enable Transformers4Rec model to be called with ragged input representation.

Implementation Details 🚧

  • Adds a pre-processing step to the first part of the model's forward method that pads any tensors in the ragged representation.
    • A ragged feature is represented by two tensors named {feature}__values and {feature}__offsets.
    • Pads each feature to the minimum of the maximum sequence length in the batch and the model's max_sequence_length (if defined).
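As a rough sketch of the padding step described above (a hypothetical pad_ragged_inputs helper, not the PR's actual implementation, which lives in transformers4rec/torch/model/base.py):

```python
from typing import Dict, Optional

import torch


def pad_ragged_inputs(
    inputs: Dict[str, torch.Tensor], max_sequence_length: Optional[int] = None
) -> Dict[str, torch.Tensor]:
    """Pad any (values, offsets) ragged feature pairs into dense 2-D tensors."""
    padded = {}
    for name, tensor in inputs.items():
        if name.endswith("__offsets"):
            feature = name[: -len("__offsets")]
            values = inputs[f"{feature}__values"]
            offsets = tensor
            # per-row sequence lengths from consecutive offsets
            lengths = offsets[1:] - offsets[:-1]
            max_len = int(lengths.max())
            if max_sequence_length is not None:
                max_len = min(max_len, max_sequence_length)
            batch_size = lengths.size(0)
            dense = torch.zeros(
                batch_size, max_len, dtype=values.dtype, device=values.device
            )
            for i in range(batch_size):
                length = min(int(lengths[i]), max_len)
                dense[i, :length] = values[offsets[i] : offsets[i] + length]
            padded[feature] = dense
        elif not name.endswith("__values"):
            # non-ragged tensors pass through unchanged
            padded[name] = tensor
    return padded
```

Each row is padded (or truncated) to the batch maximum, capped by max_sequence_length when set.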

Testing Details 🔍

  • Adds a test that calls a model with sequence inputs using the ragged input representation.

@oliverholworthy oliverholworthy added the enhancement New feature or request label Apr 5, 2023
@oliverholworthy oliverholworthy added this to the Merlin 23.04 milestone Apr 5, 2023
@oliverholworthy oliverholworthy self-assigned this Apr 5, 2023
Comment thread on transformers4rec/torch/model/base.py (outdated)
Co-authored-by: Marc Romeyn <marcromeyn@gmail.com>
@oliverholworthy oliverholworthy marked this pull request as ready for review April 5, 2023 17:18
)
model_output = model(inference_inputs)

# if the model is traced with ragged inputs it can only be called with ragged inputs

Note that when tracing the model, the representation used as the example input determines what inputs the traced model expects (padded vs. ragged).
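To illustrate with a toy model (not the Transformers4Rec one): torch.jit.trace specializes the graph to the structure of its example inputs, so tracing with the ragged dict bakes in that key layout.

```python
from typing import Dict

import torch


class ToyModel(torch.nn.Module):
    """Toy stand-in for a model whose forward takes a dict of tensors."""

    def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        # sum all incoming tensors; just enough logic to trace through
        total = torch.zeros(())
        for value in inputs.values():
            total = total + value.float().sum()
        return total


ragged = {"a__values": torch.arange(5.0), "a__offsets": torch.tensor([0, 2, 5])}
# strict=False relaxes tracing around container types
traced = torch.jit.trace(ToyModel(), (ragged,), strict=False)
# the traced module now expects the same ragged key structure it was traced with;
# calling it with a padded dict (e.g. just {"a": ...}) would not match
out = traced(ragged)
```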

return batch

batch_padded = {}
for col_name, col in TensorTable(batch).items():

TensorTable is not currently compatible with torch.jit.script compilation.

Here's an example of one of the errors that shows up (I don't think it prints all the errors, only the first one it encounters, so there may be more unsupported constructs beyond this example):

E   torch.jit.frontend.UnsupportedNodeError: SetComp aren't supported:
E     File "/workspace/merlin/core/merlin/table/tensor_table.py", line 61
E       def _validate_columns(self, cols_dict):
E           col_types = {type(col_obj) for col_obj in cols_dict.values()}
E                       ~ <--- HERE
E           if len(col_types) >= 2:
E               raise TypeError(
E   '__torch__.merlin.table.tensor_table.TensorTable' is being compiled since it was called from 'pad_batch'
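The failure is easy to reproduce in isolation. A standalone snippet mimicking the set comprehension in _validate_columns (the exact error type may vary by torch version):

```python
from typing import Dict

import torch


def unique_value_types(cols: Dict[str, torch.Tensor]) -> int:
    # a set comprehension, like the one in TensorTable._validate_columns,
    # is not supported by the TorchScript frontend
    col_types = {str(v.dtype) for v in cols.values()}
    return len(col_types)


try:
    torch.jit.script(unique_value_types)
    print("scripted OK")
except Exception as exc:  # torch.jit.frontend.UnsupportedNodeError on most versions
    print(f"scripting failed: {type(exc).__name__}")
```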

]
),
)
assert torch.equal(

Dense sequence inputs are not currently padded as part of this pad_inputs function. We're assuming inputs will be either all ragged or all padded sequences, not a mix of both.

head_reduction: str = "mean",
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
name: str = None,
max_sequence_length: Optional[int] = None,

Added a max_sequence_length parameter to limit the size of the padding when receiving ragged inputs.



@torch.jit.script
def pad_inputs(inputs: Dict[str, torch.Tensor], max_sequence_length: Optional[int] = None):

Suggested change
def pad_inputs(inputs: Dict[str, torch.Tensor], max_sequence_length: Optional[int] = None):
def pad_inputs(
inputs: Dict[str, torch.Tensor], max_sequence_length: Optional[int] = None
) -> Dict[str, torch.Tensor]:

