Minimal support for calling JAX from PyTorch/XLA #8781
Conversation
The key is to use `as_serialized_hlo_module_proto`. Our hand-rolled `_xla_computation_text_to_proto` causes an undefined op reference error.
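For reference, a minimal sketch of getting that serialized proto out of JAX — the function and shapes are illustrative, and `compiler_ir('hlo')` may not be available in very new JAX versions:

```python
# Sketch: lower a JAX function to HLO and grab the serialized HloModule
# proto via as_serialized_hlo_module_proto (the API named in this PR).
import jax
import jax.numpy as jnp

def jax_fn(a, b):
    return jnp.sin(a) + b

lowered = jax.jit(jax_fn).lower(
    jax.ShapeDtypeStruct((4, 4), jnp.float32),
    jax.ShapeDtypeStruct((4, 4), jnp.float32),
)
computation = lowered.compiler_ir('hlo')  # an xla_client.XlaComputation
proto_bytes = computation.as_serialized_hlo_module_proto()
print(f"serialized HloModule: {len(proto_bytes)} bytes")
```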
Our CPU CI doesn't have JAX or torchax, so I'm installing them like this: https://github.com/pytorch/xla/pull/8781/files#diff-d30e144e1a94b9125c97915d4bd55eeb00d136257f96c9c753c255b66b1c00b4
What happens if we tie
I guess we are skipping this op :)
We'll have the option to either use this with AOTAutograd, or to avoid AOTAutograd. If we would like to use AOTAutograd, we can use AOTAutograd to turn a PyTorch function into an FX graph, then use torchax to turn the FX graph into a JAX function, then use this PR's mechanism to call that JAX function from the lazy tensor graph, as sketched below. If we would like to avoid AOTAutograd, we'll use torchax to turn a PyTorch function into a JAX function directly and call it the same way. In the latter approach, we'll be able to use powerful JAX remat features like `jax.checkpoint`.
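Roughly, the second path could look like this; `xb.call_jax` stands in for the helper added in this PR, and its exact name and signature are assumptions:

```python
# Sketch of calling a JAX function on lazy XLA tensors; call_jax is an
# assumed entry point, not necessarily this PR's exact API.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.core.xla_builder as xb
import jax.numpy as jnp

def jax_fn(a, b):
    # Pure JAX code: under tracing, a and b are JAX tracers.
    return jnp.sin(a) + jnp.cos(b)

device = xm.xla_device()
a = torch.randn(4, 4, device=device)
b = torch.randn(4, 4, device=device)

# The helper traces jax_fn to HLO and weaves it into the lazy tensor
# graph, so `c` is still an ordinary lazy XLA tensor.
c = xb.call_jax(jax_fn, (a, b))
print(c.cpu())
```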
In the general case, this won't support non-functional code. For example, if some PyTorch function inserts a tensor into a global list, we'll just end up inserting a JAX tracer into that global list. But AOTAutograd has the same constraint. This limitation does not exist when only using LazyTensor, because that framework blurs the boundary between tracing and execution, at the cost of accidental graph breaks. My hypothesis is that most high-performance models (esp. the ones in torchprime) won't do weird stuff like that. We'll need to find some other solution to cover the long tail, e.g. use Dynamo.
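To make that failure mode concrete, a hypothetical example (all names illustrative):

```python
# Hypothetical illustration: non-functional code leaks tracers.
saved_activations = []  # module-level (global) list

def forward(x):
    saved_activations.append(x)  # side effect: when traced through
                                 # torchax/AOTAutograd, x is a tracer,
                                 # so the list captures a tracer rather
                                 # than a real tensor.
    return x * 2
```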
We can't do that (e.g. a branch on a tensor's value won't trace). Similarly, my hypothesis is that high-performance models won't do a data-dependent branch; the ones that do will need to rewrite the branch, as sketched below.
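A hedged sketch of that rewrite on the JAX side, using `jax.lax.cond` (illustrative, not from this PR):

```python
# A data-dependent Python branch fails under tracing; jax.lax.cond keeps
# both branches in the traced graph and selects at run time.
import jax
import jax.numpy as jnp

def f_bad(x):
    if x.sum() > 0:  # data-dependent Python branch: errors under jax.jit
        return x * 2
    return x - 1

def f_good(x):
    return jax.lax.cond(x.sum() > 0,
                        lambda x: x * 2,   # true branch
                        lambda x: x - 1,   # false branch
                        x)

print(jax.jit(f_good)(jnp.arange(4.0)))
```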
Co-authored-by: Han Qi <hanq@google.com>
The trick is to turn the JAX function into HLO and weave that into an existing PyTorch/XLA lazy tensor graph.
So far we just test a simple numerical function.