Skip to content

[Pallas] PoC Integration#6340

Merged
alanwaketan merged 9 commits intomasterfrom
alanwaketan/pallas
Jan 30, 2024
Merged

[Pallas] PoC Integration#6340
alanwaketan merged 9 commits intomasterfrom
alanwaketan/pallas

Conversation

@alanwaketan
Copy link
Copy Markdown
Collaborator

@alanwaketan alanwaketan commented Jan 20, 2024

Summary:
This is PoC for Pallas integration. Currently, it can run Pallas kernels that take arbitrary tensors as inputs and output a single tensor. The design doc is here: go/pytorch-xla-pallas.

Test Plan:
PJRT_DEVICE=TPU python test/test_operations.py -v -k test_tpu_custom_call

@alanwaketan alanwaketan self-assigned this Jan 20, 2024
@alanwaketan alanwaketan changed the title [WIP] Pallas Integration PoC [Pallas] PoC Integration Jan 26, 2024
@alanwaketan alanwaketan marked this pull request as ready for review January 26, 2024 23:50
@JackCaoG
Copy link
Copy Markdown
Collaborator

test_tpu_custom_call_pallas_add crashed on CPU, if it is a TPU only test I think you need to guard it

@alanwaketan
Copy link
Copy Markdown
Collaborator Author

test_tpu_custom_call_pallas_add crashed on CPU, if it is a TPU only test I think you need to guard it

Lol, fair enough. Will guard that.

Comment thread test/test_operations.py
# def add_vectors_kernel(x_ref, y_ref, o_ref):
# x, y = x_ref[...], y_ref[...]
# o_ref[...] = x + y
payload = "{\"custom_call_config\": {\"body\": \"TUzvUgFNTElSMTguMC4wZ2l0AAErCwEDBQcJAQMLAwUDDQcFDxEJBRMVA2lNDQFLBw8LEw8PDwsPMwsLCwtlCwsLCwsPCw8PEwsTDwsTDwsPDxMLDwUDYQENGwcTDxsPAsICHx0rLQUXAwMnKRURNx1HSRELAQUZHTM1AwsVFxkbHw0hDSMlBRsBAQUdDQlhZmZpbmVfbWFwPChkMCkgLT4gKGQwKT4ABR8FIQUjBSUFJxEDAQUpFS8JHQ8xFwUTAQUrFwUdAR05OwUtFwUlAR0/QQUvFUMJHQ9FFwUVAQUxFREJI3RwdS5tZW1vcnlfc3BhY2U8dm1lbT4AF0sDIQcdAycDIQcBAgIFBwEBAQEBAgQEpwUBEAEHAwEFAxEBEwcDFScHAQEBAQEBBwMDBwMDCwYDAwUFAQcHAwMHAwMLBgMDBQUDCwkGPQMFBQkNBwMLBwMDCwYLAwUFBRENBAsHDwURBQABBgMBBQEAdgcz2wsTGdkNCxMjIR0pJ0MNCwsTDw8PDQkLEWJ1aWx0aW4AZnVuYwB0cHUAYXJpdGgAdmVjdG9yAG1vZHVsZQByZXR1cm4AY29uc3RhbnQAYWRkaQBsb2FkAHN0b3JlAC9ob21lL2p3dGFuL3BhbGxhcy9wYWxsYXNfYWRkLnB5AGFkZF92ZWN0b3JzX2tlcm5lbABkaW1lbnNpb25fc2VtYW50aWNzAGZ1bmN0aW9uX3R5cGUAc2NhbGFyX3ByZWZldGNoAHNjcmF0Y2hfb3BlcmFuZHMAc3ltX25hbWUAbWFpbgB2YWx1ZQAvZ2V0W3RyZWU9UHlUcmVlRGVmKChDdXN0b21Ob2RlKE5ESW5kZXhlclsoUHlUcmVlRGVmKChDdXN0b21Ob2RlKFNsaWNlWygwLCA4KV0sIFtdKSwpKSwgKDgsKSwgKCkpXSwgW10pLCkpXQBhZGRfdmVjdG9ycwA8bW9kdWxlPgAvYWRkAC9zd2FwW3RyZWU9UHlUcmVlRGVmKChDdXN0b21Ob2RlKE5ESW5kZXhlclsoUHlUcmVlRGVmKChDdXN0b21Ob2RlKFNsaWNlWygwLCA4KV0sIFtdKSwpKSwgKDgsKSwgKCkpXSwgW10pLCkpXQA=\", \"needs_layout_passes\": true}}"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this test needs to change everytime we update libtpu/openxla?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hard to tell. It depends on if the Mosaic API is good or not.

Later on once I have developed the way to extract the payload systematically. We could import JAX to do the lowering instead.

Copy link
Copy Markdown
Collaborator

@JackCaoG JackCaoG left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mostly lgtm, if you can update the test device guard and all CI can pass, feel free to merge

@alanwaketan
Copy link
Copy Markdown
Collaborator Author

Thanks, Jack. Kicked off TPU CI as well.

@alanwaketan alanwaketan merged commit 56db8f2 into master Jan 30, 2024
@miladm
Copy link
Copy Markdown
Collaborator

miladm commented Feb 2, 2024

@alanwaketan is the design publicly available? wdyt we put out a RFC?

@alanwaketan
Copy link
Copy Markdown
Collaborator Author

@alanwaketan is the design publicly available? wdyt we put out a RFC?

Will do that once the design is fully fledge.

bhavya01 pushed a commit that referenced this pull request Apr 22, 2024
Summary:
This is PoC for Pallas integration. Currently, it can run Pallas kernels that take arbitrary tensors as inputs and output a single tensor. The design doc is here: go/pytorch-xla-pallas.

Test Plan:
PJRT_DEVICE=TPU python test/test_operations.py -v -k test_tpu_custom_call
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants