Support S32/U32 indices for BWD embedding & Neuron implicit downcast #8462

Merged
tengyifei merged 3 commits into pytorch:master from rpsilva-aws:rpsilva_downcast_v2
Dec 7, 2024
Conversation

@rpsilva-aws
Collaborator

In this PR, we extend embedding tensor operations to allow S32 indices. This follows suit with other operations and adds flexibility, and potentially performance benefits, for accelerator backends. Reference for the dense embedding backward: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Embedding.cpp#L117
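To illustrate why the index width does not change the result, here is a minimal pure-Python sketch of a dense embedding backward pass. It is not the code from this PR or from ATen; the function name, signature, and the use of `array` typecodes `"i"` and `"q"` as stand-ins for S32 and S64 are assumptions made for illustration only.

```python
from array import array


def embedding_dense_backward(grad_output, indices, num_weights, padding_idx=-1):
    """Hypothetical sketch: scatter-add each grad_output row into the
    weight-gradient row selected by the matching index."""
    dim = len(grad_output[0]) if grad_output else 0
    grad_weight = [[0.0] * dim for _ in range(num_weights)]
    for row, idx in zip(grad_output, indices):
        if idx == padding_idx:
            continue  # padding rows receive no gradient
        for j, g in enumerate(row):
            grad_weight[idx][j] += g
    return grad_weight


grad_out = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

# "i" is typically a 32-bit signed int and "q" a 64-bit signed int
# (platform-dependent; used here only as stand-ins for S32 and S64).
idx32 = array("i", [0, 2, 0])
idx64 = array("q", [0, 2, 0])

grad32 = embedding_dense_backward(grad_out, idx32, num_weights=3)
grad64 = embedding_dense_backward(grad_out, idx64, num_weights=3)
```

Both calls accumulate into the same rows, so as long as every index fits in 32 bits the narrower type gives the same gradient.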

In addition, we re-introduce the implicit downcast of S64/U64 types for Neuron, since the Neuron compiler does not support 64-bit types.
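The idea of an implicit downcast can be sketched as a small type-mapping function. This is only an illustration and not the logic in `torch_xla/csrc/dtype.cpp`; the function name, the `"NEURON"` device string, and the pass-through behavior for other devices are all assumptions.

```python
# Hypothetical mapping: 64-bit integer types and their 32-bit counterparts.
DOWNCAST_64_TO_32 = {"S64": "S32", "U64": "U32"}


def device_primitive_type(xla_type, device_type):
    """Return the type a device would actually use for `xla_type`.

    Assumption for this sketch: only the "NEURON" device downcasts
    64-bit integers; every other device keeps the requested type.
    """
    if device_type == "NEURON":
        return DOWNCAST_64_TO_32.get(xla_type, xla_type)
    return xla_type


neuron_signed = device_primitive_type("S64", "NEURON")
neuron_unsigned = device_primitive_type("U64", "NEURON")
neuron_float = device_primitive_type("F32", "NEURON")
other_device = device_primitive_type("S64", "CPU")
```

In this sketch, types that are already 32-bit or non-integer pass through untouched, so only the S64/U64 cases change.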

There is an ongoing effort to further extend this requirement to other tensor operations involving indices: pytorch/pytorch#142160. Once that is resolved, we will adapt it on XLA as well.

@rpsilva-aws rpsilva-aws changed the title Rpsilva downcast v2 Support S32/U32 indices for BWD embedding & Neuron implicit downcast Dec 6, 2024
@rpsilva-aws rpsilva-aws marked this pull request as ready for review December 6, 2024 00:28
@rpsilva-aws
Collaborator Author

FYI @miladm @ManfeiBai @tengyifei, I split the previous PR. This one is needed for 2.6; unfortunately, #8463 has a dependency on PT.

@tengyifei tengyifei added the tpuci label Dec 6, 2024
Collaborator

@tengyifei tengyifei left a comment

Is it possible to add a test at all?

Comment thread torch_xla/csrc/dtype.cpp Outdated
@rpsilva-aws rpsilva-aws force-pushed the rpsilva_downcast_v2 branch 2 times, most recently from c2fb7ef to 95d0f0c Compare December 6, 2024 02:26
@rpsilva-aws
Collaborator Author

@tengyifei Ran yapf over the test file. PTAL, thanks!

@tengyifei tengyifei merged commit 00c0e96 into pytorch:master Dec 7, 2024
@rpsilva-aws rpsilva-aws deleted the rpsilva_downcast_v2 branch December 9, 2024 19:03