[Pallas] Introduce make_kernel_from_pallas #6713
Conversation
| return None |
| def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype: |
@qihqi do we already have such a conversion somewhere in torchxla2?
It's generated by copilot, lol
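For reference, a minimal sketch of what such a shared dtype-conversion helper could look like; the dict-based mapping and the set of dtypes covered here are assumptions, not the final helper:

```python
import torch
import jax.numpy as jnp

# Hypothetical mapping; the supported set is an assumption for illustration.
_TORCH_TO_JAX_DTYPE = {
    torch.float32: jnp.float32,
    torch.float64: jnp.float64,
    torch.bfloat16: jnp.bfloat16,
    torch.float16: jnp.float16,
    torch.int32: jnp.int32,
    torch.int64: jnp.int64,
}


def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
  # Raises KeyError for dtypes that have no mapping yet.
  return _TORCH_TO_JAX_DTYPE[dtype]
```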
| (x.shape, x.dtype)) |
| dtypes = [torch.float32, torch.float |
| ] # TODO: torch.float64, torch.bfloat16, torch.float16 don't work. |
Mosaic complains. Need to dig more into it.
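As a rough illustration (not the actual test body) of how that dtype list might be exercised; `pt_kernel` below is a hypothetical kernel produced by make_kernel_from_pallas:

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
dtypes = [torch.float32]  # TODO: float64, bfloat16, float16 are rejected by Mosaic.
for dtype in dtypes:
  x = torch.arange(8, dtype=dtype, device=device)
  y = torch.arange(8, dtype=dtype, device=device)
  # output = pt_kernel(x, y)  # hypothetical wrapped Pallas kernel
  # torch.testing.assert_close(output.cpu(), torch.arange(8, dtype=dtype) * 2)
```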
| import jax |
| import jax.numpy as jnp |
| import jax._src.pallas.mosaic.pallas_call_registration |
Seems like this is repeated in multiple tests, maybe just move it to the top?
There is a compatibility issue where JAX will try to lock the TPU devices if we import it before any pt/xla computations... I will need to resolve that...
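A minimal sketch of the deferred-import workaround being described, assuming this test-class layout; the kernel body is illustrative only:

```python
import unittest

import torch_xla.runtime as xr


class PallasTest(unittest.TestCase):

  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
  def test_add_vectors(self):
    # Deferred imports: pulling JAX in at module scope can lock the TPU
    # before any pt/xla computation has run.
    import jax
    import jax.numpy as jnp
    import jax._src.pallas.mosaic.pallas_call_registration
    from jax.experimental import pallas as pl

    def add_vectors_kernel(x_ref, y_ref, o_ref):
      o_ref[...] = x_ref[...] + y_ref[...]

    # ... build the pallas_call and dispatch it as in the rest of this test file.
```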
JackCaoG left a comment
Do you need this PR in 2.3?
Yeah, and I will also need a couple of follow-ups for the TODOs.
| @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.") |
| # TODO: This test cannot be run individually, let's fix it. |
| def test_tpu_custom_call_pallas_wrap_add_payload(self): |
| import jax |
I am concerned that JAX-based tests will cause failures due to libtpu version inconsistencies, and in turn CI hiccups. How do we resolve this concern?
| def make_kernel_from_pallas(kernel: Callable, output_shape_dtype_fn: Callable): |
| # TODO: Maybe we can cache the payload for the same input. |
The payload may change if the input is dynamic. We need to confirm this with the Pallas folks.
Right, the cache itself should deal with the dynamism.
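For context, a rough skeleton of the wrapper shape with the caching question marked where it would apply; everything beyond the signature shown in the diff is an assumption about the implementation:

```python
from typing import Callable

import torch


def make_kernel_from_pallas(kernel: Callable, output_shape_dtype_fn: Callable):
  # TODO: Maybe we can cache the payload for the same input. If the payload
  # depends on (dynamic) input shapes, the cache key would need to include the
  # input shapes/dtypes rather than just the kernel identity.

  def wrapped_kernel(*args: torch.Tensor) -> torch.Tensor:
    # 1. Trace `kernel` (a pallas_call wrapper) with JAX abstract values that
    #    mirror the torch inputs to obtain the serialized Mosaic payload.
    # 2. Dispatch that payload through the TPU custom call, using
    #    output_shape_dtype_fn(*args) to describe the output buffer.
    raise NotImplementedError("illustrative skeleton only")

  return wrapped_kernel
```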
Force-pushed from 8b8be2e to ae6b62b
Can I get any reviews?
JackCaoG left a comment
I still think we should refactor convert_torch_dtype_to_jax and investigate bf16 (which I assume most people will use); approving to unblock.
Yeah, for sure. Let me follow up on that.
Summary:
This pull request introduces the make_kernel_from_pallas API, which is the top-level API to interact with the Pallas integration. It takes a pallas_call wrapper and turns it into a custom PyTorch op.
Test Plan:
python test/test_pallas.py
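A hedged end-to-end usage sketch based on the summary and the test snippets above; the import path of make_kernel_from_pallas and the exact return value of output_shape_dtype_fn are assumptions:

```python
import jax
from jax.experimental import pallas as pl

import torch
import torch_xla.core.xla_model as xm
# Assumed import path for the new API.
from torch_xla.experimental.custom_kernel import make_kernel_from_pallas


def add_vectors_kernel(x_ref, y_ref, o_ref):
  o_ref[...] = x_ref[...] + y_ref[...]


def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
  # The pallas_call wrapper that make_kernel_from_pallas consumes.
  return pl.pallas_call(
      add_vectors_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x, y)


# output_shape_dtype_fn maps the torch inputs to the output's (shape, dtype).
pt_kernel = make_kernel_from_pallas(add_vectors,
                                    lambda x, y: (x.shape, x.dtype))

device = xm.xla_device()
x = torch.arange(8, dtype=torch.float32, device=device)
y = torch.arange(8, dtype=torch.float32, device=device)
out = pt_kernel(x, y)  # runs the Pallas kernel as a custom PyTorch op on TPU
```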