[Pallas] PoC Integration#6340
Conversation
46b8228 to
d43df7f
Compare
|
|
Lol, fair enough. Will guard that. |
| # def add_vectors_kernel(x_ref, y_ref, o_ref): | ||
| # x, y = x_ref[...], y_ref[...] | ||
| # o_ref[...] = x + y | ||
| payload = "{\"custom_call_config\": {\"body\": \"TUzvUgFNTElSMTguMC4wZ2l0AAErCwEDBQcJAQMLAwUDDQcFDxEJBRMVA2lNDQFLBw8LEw8PDwsPMwsLCwtlCwsLCwsPCw8PEwsTDwsTDwsPDxMLDwUDYQENGwcTDxsPAsICHx0rLQUXAwMnKRURNx1HSRELAQUZHTM1AwsVFxkbHw0hDSMlBRsBAQUdDQlhZmZpbmVfbWFwPChkMCkgLT4gKGQwKT4ABR8FIQUjBSUFJxEDAQUpFS8JHQ8xFwUTAQUrFwUdAR05OwUtFwUlAR0/QQUvFUMJHQ9FFwUVAQUxFREJI3RwdS5tZW1vcnlfc3BhY2U8dm1lbT4AF0sDIQcdAycDIQcBAgIFBwEBAQEBAgQEpwUBEAEHAwEFAxEBEwcDFScHAQEBAQEBBwMDBwMDCwYDAwUFAQcHAwMHAwMLBgMDBQUDCwkGPQMFBQkNBwMLBwMDCwYLAwUFBRENBAsHDwURBQABBgMBBQEAdgcz2wsTGdkNCxMjIR0pJ0MNCwsTDw8PDQkLEWJ1aWx0aW4AZnVuYwB0cHUAYXJpdGgAdmVjdG9yAG1vZHVsZQByZXR1cm4AY29uc3RhbnQAYWRkaQBsb2FkAHN0b3JlAC9ob21lL2p3dGFuL3BhbGxhcy9wYWxsYXNfYWRkLnB5AGFkZF92ZWN0b3JzX2tlcm5lbABkaW1lbnNpb25fc2VtYW50aWNzAGZ1bmN0aW9uX3R5cGUAc2NhbGFyX3ByZWZldGNoAHNjcmF0Y2hfb3BlcmFuZHMAc3ltX25hbWUAbWFpbgB2YWx1ZQAvZ2V0W3RyZWU9UHlUcmVlRGVmKChDdXN0b21Ob2RlKE5ESW5kZXhlclsoUHlUcmVlRGVmKChDdXN0b21Ob2RlKFNsaWNlWygwLCA4KV0sIFtdKSwpKSwgKDgsKSwgKCkpXSwgW10pLCkpXQBhZGRfdmVjdG9ycwA8bW9kdWxlPgAvYWRkAC9zd2FwW3RyZWU9UHlUcmVlRGVmKChDdXN0b21Ob2RlKE5ESW5kZXhlclsoUHlUcmVlRGVmKChDdXN0b21Ob2RlKFNsaWNlWygwLCA4KV0sIFtdKSwpKSwgKDgsKSwgKCkpXSwgW10pLCkpXQA=\", \"needs_layout_passes\": true}}" |
There was a problem hiding this comment.
does this test needs to change everytime we update libtpu/openxla?
There was a problem hiding this comment.
Hard to tell. It depends on if the Mosaic API is good or not.
Later on once I have developed the way to extract the payload systematically. We could import JAX to do the lowering instead.
JackCaoG
left a comment
There was a problem hiding this comment.
mostly lgtm, if you can update the test device guard and all CI can pass, feel free to merge
37ed6bb to
29a73c7
Compare
|
Thanks, Jack. Kicked off TPU CI as well. |
|
@alanwaketan is the design publicly available? wdyt we put out a RFC? |
Will do that once the design is fully fledge. |
Summary: This is PoC for Pallas integration. Currently, it can run Pallas kernels that take arbitrary tensors as inputs and output a single tensor. The design doc is here: go/pytorch-xla-pallas. Test Plan: PJRT_DEVICE=TPU python test/test_operations.py -v -k test_tpu_custom_call
Summary:
This is PoC for Pallas integration. Currently, it can run Pallas kernels that take arbitrary tensors as inputs and output a single tensor. The design doc is here: go/pytorch-xla-pallas.
Test Plan:
PJRT_DEVICE=TPU python test/test_operations.py -v -k test_tpu_custom_call