
[Pallas] Introduce make_kernel_from_pallas #6713

Merged
alanwaketan merged 5 commits into master from alanwaketan/pallas_api
Mar 13, 2024

Conversation

@alanwaketan
Collaborator

Summary:
This pull request introduces the make_kernel_from_pallas API, the top-level API for interacting with the Pallas integration. It takes a pallas_call wrapper and turns it into a custom PyTorch op.

Test Plan:
python test/test_pallas.py
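As a rough illustration of the API's contract (not the actual implementation, which lowers the pallas_call wrapper to a Mosaic payload and registers a custom PyTorch op), a pure-Python mock of the (kernel, output_shape_dtype_fn) pairing might look like this; FakeTensor and the mock itself are illustrative stand-ins:

```python
# Pure-Python mock of the make_kernel_from_pallas contract; FakeTensor and
# make_kernel_from_pallas_mock are stand-ins, not the real torch_xla API.
class FakeTensor:
    def __init__(self, shape, dtype):
        self.shape, self.dtype = shape, dtype

def make_kernel_from_pallas_mock(kernel, output_shape_dtype_fn):
    # The real API compiles the pallas_call wrapper into a payload and wraps
    # it as a custom op; here we just thread the shape/dtype function through.
    def wrapped(*args):
        out_shape, out_dtype = output_shape_dtype_fn(*args)
        return kernel(*args), (out_shape, out_dtype)
    return wrapped

# Usage mirroring the lambda seen in test_pallas.py: (x.shape, x.dtype).
add = make_kernel_from_pallas_mock(
    lambda x, y: FakeTensor(x.shape, x.dtype),
    lambda x, y: (x.shape, x.dtype))
out, shape_dtype = add(FakeTensor((8,), "float32"),
                       FakeTensor((8,), "float32"))
assert shape_dtype == ((8,), "float32")
```

The output_shape_dtype_fn exists because the op's output metadata must be known before the kernel runs on device.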

@alanwaketan alanwaketan requested review from JackCaoG and qihqi March 11, 2024 19:04
return None


def convert_torch_dtype_to_jax(dtype: torch.dtype) -> jnp.dtype:
Collaborator

@qihqi do we already have such a conversion somewhere in torchxla2?

Collaborator Author

It's generated by copilot, lol
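For reference, a hedged sketch of the kind of mapping such a helper performs, written here over dtype names rather than the actual torch.dtype/jnp.dtype objects (the real helper operates on those objects directly):

```python
# Sketch of a torch->jax dtype mapping keyed by dtype name. The table and
# helper are illustrative assumptions, not the PR's convert_torch_dtype_to_jax.
_TORCH_TO_JAX_DTYPE = {
    "float32": "float32",
    "float": "float32",      # torch.float is an alias for torch.float32
    "bfloat16": "bfloat16",
    "float16": "float16",
    "int32": "int32",
    "int64": "int64",
}

def convert_torch_dtype_to_jax(torch_dtype_name: str) -> str:
    try:
        return _TORCH_TO_JAX_DTYPE[torch_dtype_name]
    except KeyError:
        raise ValueError(f"Unsupported dtype: {torch_dtype_name}")
```

An explicit table like this makes the unsupported cases (e.g. the dtypes flagged in the test TODO below) fail loudly instead of silently casting.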

Comment thread test/test_pallas.py
(x.shape, x.dtype))

dtypes = [torch.float32, torch.float
         ]  # TODO: torch.float64, torch.bfloat16, torch.float16 don't work.
Collaborator

why won't bf16 work?

Collaborator Author

Mosaic complains. Need to dig more into it.

Comment thread test/test_pallas.py
Comment on lines +141 to +143
import jax
import jax.numpy as jnp
import jax._src.pallas.mosaic.pallas_call_registration
Collaborator

seems like this is repeated on multiple tests, maybe just move to the top?

Collaborator Author

There is a compatibility issue where jax will try to lock tpu devices if we import them before any pt/xla computations... I will need to resolve that...

Collaborator

@JackCaoG JackCaoG left a comment

Do you need this pr in 2.3?

@alanwaketan
Collaborator Author

Do you need this pr in 2.3?

Yea, will also need a couple for the TODOs.

Comment thread test/test_pallas.py
@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
# TODO: This test cannot be run individually; let's fix it.
def test_tpu_custom_call_pallas_wrap_add_payload(self):
import jax
Collaborator

@miladm miladm Mar 12, 2024

I am concerned JAX-based tests cause failures due to libtpu version inconsistencies, and in turn CI hiccups. How do we resolve this concern?

Collaborator Author

That's resolved in the last PR: #6696



def make_kernel_from_pallas(kernel: Callable, output_shape_dtype_fn: Callable):
# TODO: Maybe we can cache the payload for the same input.
Collaborator

The payload may change if the input is dynamic. We need to confirm this with pallas folks.

Collaborator Author

Right, the cache itself should deal with the dynamism.

@alanwaketan alanwaketan force-pushed the alanwaketan/pallas_api branch from 8b8be2e to ae6b62b Compare March 12, 2024 23:38
@alanwaketan
Collaborator Author

Can I get any reviews?

Collaborator

@JackCaoG JackCaoG left a comment

I still think we should refactor convert_torch_dtype_to_jax and investigate bf16 (which I assume most people will use). Approving to unblock.

@alanwaketan
Collaborator Author

I still think we should refactor convert_torch_dtype_to_jax and investigate bf16 (which I assume most people will use). Approving to unblock.

Yea, for sure. Let me follow up with that.

@alanwaketan alanwaketan merged commit 1bbe333 into master Mar 13, 2024
@alanwaketan alanwaketan deleted the alanwaketan/pallas_api branch March 13, 2024 18:39
lsy323 pushed a commit that referenced this pull request Mar 13, 2024
Summary:
This pull request introduces the make_kernel_from_pallas API, the top-level API for interacting with the Pallas integration. It takes a pallas_call wrapper and turns it into a custom PyTorch op.

Test Plan:
python test/test_pallas.py
lsy323 added a commit that referenced this pull request Mar 13, 2024
Co-authored-by: Jiewen Tan <jwtan@google.com>
3 participants