Conversation
Thanks @bhavya01! Is this ready to be reviewed? Also, can you add a corresponding cpp unit test, since this is a new op?

@wonjoolee95 This PR should be ready for review for #5982

Do we already have a cpp test for this op? Also wil

We do have a cpp test for the op (https://github.com/pytorch/xla/blob/master/test/cpp/test_aten_xla_tensor_5.cpp#L250). I think I will create a separate PR for embedding bag.

@bhavya01, can we ensure that the lowering gets invoked correctly by adding the metric check in the cpp test? Thanks, let's also work on the EmbeddingBag lowering.

Added the check for metrics. EmbeddingBag is still WIP.
wonjoo-wj left a comment
Thanks!
I left a few comments as follow-up items. Not a blocker, so feel free to merge this PR. But let's follow up on the items in the next PR (preferably the EmbeddingBag PR).
/*sparse=*/false);
AllClose(b, xla_b);
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::embedding_symint",
nit: xla::embedding_symint -> xla::embedding*. Not a blocker, we can update in the EmbeddingBag PR.
Also want to confirm one thing -- before implementing this lowering (i.e. without the changes in this PR), does this metric assertion fail as expected? Just wanted to make sure that the _symint variant is getting recognized properly. It should, but just wanted to confirm.
But not a blocker, we can follow up in the EmbeddingBag PR.
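For context on the test snippet quoted above, here is a minimal sketch of what the full test could look like, in the style of the other tests in test_aten_xla_tensor_5.cpp. The test name, tensor shapes, and argument values are illustrative assumptions, not copied from the PR:

```cpp
// Sketch of an embedding test in the style of pytorch/xla's cpp tests.
// Shapes and the test name are assumptions for illustration.
TEST_F(AtenXlaTensorTest, TestEmbedding) {
  torch::Tensor a = torch::rand({32, 4}, torch::TensorOptions(torch::kFloat));
  torch::Tensor i =
      torch::randint(0, 31, {3, 4}, torch::TensorOptions(torch::kLong));
  torch::Tensor b = torch::embedding(a, i, /*padding_idx=*/0,
                                     /*scale_grad_by_freq=*/false,
                                     /*sparse=*/false);
  ForEachDevice([&](const torch::Device& device) {
    torch::Tensor xla_a = CopyToDevice(a, device);
    torch::Tensor xla_i = CopyToDevice(i, device);
    torch::Tensor xla_b = torch::embedding(xla_a, xla_i, /*padding_idx=*/0,
                                           /*scale_grad_by_freq=*/false,
                                           /*sparse=*/false);
    AllClose(b, xla_b);
  });
  // No aten:: fallback should fire, and the XLA lowering should be hit.
  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
  ExpectCounterChanged("xla::embedding_symint", cpp_test::GetIgnoredCounters());
}
```

The two counter checks at the end are what the review comments refer to: the first asserts that no op fell back to the aten:: CPU path, and the second asserts that the xla:: lowering was actually invoked.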
scale_grad_by_freq, sparse);
}
// TODO: for now route to native, which dispatches supported XLA operations.
// We need to make use of the TPU embedding core here eventually.
This comment (it seems to have existed for a while), "We need to make use of the TPU embedding core here eventually.", makes me a bit cautious about the Embedding op. I assume it means that to lower this op in a performant way, we somehow want to make use of the TPU embedding core? Just leaving this comment as a future reference. Also, let's bring this comment back in the code so we are aware. Again, not a blocker; we can include it as a follow-up in the next PR.
Sounds good. Will add it back in a follow-up PR.
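For reference, the routing pattern the TODO refers to looks roughly like the sketch below. The exact signature, the counter macro, and the at::native call are assumptions based on how other decomposed ops are typically written in aten_xla_type.cpp, not code copied from this PR:

```cpp
// Sketch: route the op to the native decomposition, which in turn dispatches
// primitives that XLA already lowers. Signature and macro are assumptions.
at::Tensor XLANativeFunctions::embedding_symint(const at::Tensor& weight,
                                                const at::Tensor& indices,
                                                c10::SymInt padding_idx,
                                                bool scale_grad_by_freq,
                                                bool sparse) {
  TORCH_LAZY_FN_COUNTER("xla::");
  // TODO: for now route to native, which dispatches supported XLA operations.
  // We need to make use of the TPU embedding core here eventually.
  return at::native::embedding_symint(weight, indices, padding_idx,
                                      scale_grad_by_freq, sparse);
}
```

Routing through the native decomposition is functionally correct but, as the TODO notes, does not take advantage of the TPU embedding core.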

Reference: https://github.com/pytorch/pytorch/blob/113138aa5575301d914c18bd882d6ab3735aa18a/aten/src/ATen/native/Embedding.cpp#L37
The native PyTorch implementation also uses just the indices and the weight matrix; the other options appear to be ignored in the forward pass.
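Concretely, the linked implementation reduces the forward pass to an index_select over the rows of the weight matrix. A small standalone sketch of that equivalence (shapes are arbitrary examples chosen for illustration):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // 10-row embedding table with embedding dimension 4; arbitrary index shape.
  torch::Tensor weight = torch::rand({10, 4});
  torch::Tensor indices = torch::randint(0, 10, {3, 5}, torch::kLong);

  // Forward pass of embedding; scale_grad_by_freq/sparse only affect backward.
  torch::Tensor out = torch::embedding(weight, indices);

  // Equivalent: gather rows with index_select, then restore the index shape.
  torch::Tensor ref =
      weight.index_select(0, indices.reshape({-1})).view({3, 5, 4});

  std::cout << torch::allclose(out, ref) << std::endl;  // prints 1
  return 0;
}
```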