Skip to content

Add Python binding for xla::DotGeneral#7863

Merged
lsy323 merged 3 commits intomasterfrom
lsiyuan/dot-general
Aug 16, 2024
Merged

Add Python binding for xla::DotGeneral#7863
lsy323 merged 3 commits intomasterfrom
lsiyuan/dot-general

Conversation

@lsy323
Copy link
Copy Markdown
Collaborator

@lsy323 lsy323 commented Aug 15, 2024

Add a python binding for xla::DotGeneral, the motivation is to represent the semantic of matmul/einsum(dtype1, dtype2)->dtype3, in which dtype3 can be configurable. Now torch.matmul would always use the inferred output dtype.

Function signature is the same as jax.lax.dot_general.

The only limitation is precision config is not supported, it can be extended easily if needed in the future.

Test:
Added unit tests

@lsy323 lsy323 requested a review from JackCaoG August 15, 2024 23:03
@lsy323 lsy323 force-pushed the lsiyuan/dot-general branch from 55ca18a to ccb8459 Compare August 15, 2024 23:37
@lsy323 lsy323 added the tpuci label Aug 15, 2024
@lsy323 lsy323 merged commit 37312c1 into master Aug 16, 2024
@lsy323 lsy323 deleted the lsiyuan/dot-general branch August 16, 2024 22:20
@lsy323 lsy323 self-assigned this Aug 16, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants