Conversation
I was able to compile by adding the following patch to OpenXLA:

Thank you @ysiraichi! I added this patch for now.
The persistent cache test is failing on GPU due to a deserialization issue. Skipping the test for now; will file a GitHub issue for this.
ManfeiBai left a comment
Thanks for the amazing work, it's a really huge change to adopt in one PR. LGTM.
Thanks @lsy323 for updating the pin. Regarding the paged_attention hang, could you update this line (`xla/torch_xla/experimental/custom_kernel.py`, line 1212 at c044c69): `step = torch.ones((1,), dtype=torch.int32).to("xla")`? It should make the test pass. I tested locally.
Thanks @vanbasten23! Updated the PR. Also, do you mind elaborating a bit on this?
#8908 accidentally enabled some pallas tests on CPU, which is not supported
Yeah, jax-ml/jax@8c73799 made a change (it's not a bug but a valid change). As a result, the torch_xla wrapper needs to change accordingly.
Accommodate the following changes:
- `xla::Shape::rank()` is renamed to `xla::Shape::dimensions_size()`
- `xla::Shape` ctor