🐛 Bug
Max-pool operator in xla2 crashes under jax.value_and_grad
To Reproduce
- Download the MNIST toy example I prepared from here
- Move it to a Trillium machine
- Run it
Traceback:
Traceback (most recent call last):
File "/home/zmelumian/xla2_trainer_poc//google_bugs/mnist_with_maxpool.py", line 174, in
jax_weights, opt_state, loss = training_step(jax_weights, jax_buffers, opt_state, x_j, target_j)
File "/home/zmelumian/xla2_trainer_poc//google_bugs/mnist_with_maxpool.py", line 102, in training_step
loss, grads = jax.value_and_grad(forward)(jax_weights, buffers, x, target)
File "/home/zmelumian/xla2_trainer_poc//google_bugs/mnist_with_maxpool.py", line 89, in forward
pred = jittable_model.functional_call('forward', weights, buffers, x)
File "/home/zmelumian/xla/experimental/torch_xla2/torch_xla2/interop.py", line 73, in functional_call
res = getattr(self._model, method_name)(*args, **kwargs)
File "/home/zmelumian/xla2_trainer_poc//google_bugs/mnist_with_maxpool.py", line 68, in forward
x = F.max_pool2d(x, 2, stride=2, return_indices=False)
File "/home/zmelumian/venv/lib/python3.10/site-packages/torch/_jit_internal.py", line 503, in fn
return if_false(*args, **kwargs)
File "/home/zmelumian/venv/lib/python3.10/site-packages/torch/nn/functional.py", line 783, in _max_pool2d
return handle_torch_function(
File "/home/zmelumian/venv/lib/python3.10/site-packages/torch/overrides.py", line 1630, in handle_torch_function
result = mode.torch_function(public_api, types, args, kwargs)
File "/home/zmelumian/xla/experimental/torch_xla2/torch_xla2/tensor.py", line 212, in torch_function
return func(*args, **(kwargs or {}))
File "/home/zmelumian/venv/lib/python3.10/site-packages/torch/_jit_internal.py", line 503, in fn
return if_false(*args, **kwargs)
File "/home/zmelumian/venv/lib/python3.10/site-packages/torch/nn/functional.py", line 796, in _max_pool2d
return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
File "/home/zmelumian/xla/experimental/torch_xla2/torch_xla2/tensor.py", line 227, in torch_dispatch
return self.env.dispatch(func, types, args, kwargs)
File "/home/zmelumian/xla/experimental/torch_xla2/torch_xla2/tensor.py", line 430, in dispatch
res = op.func(args, **kwargs)
File "/home/zmelumian/xla/experimental/torch_xla2/torch_xla2/ops/jaten.py", line 1136, in _aten_max_pool2d_with_indices
indices, _ = jax.lax.reduce_window(
ValueError: Operands must have the same tree structure as init_values: PyTreeDef([, *, CustomNode(Zero[ShapedArray(float0[1024,64,24,24])], []), ]) vs. PyTreeDef([, *, *, *])
Expected behavior
Max pooling in xla2 should behave exactly like it does in PyTorch and be fully wrapped, including when it is differentiated (e.g. with jax.value_and_grad).
Environment
- Reproducible on: XLA backend, TPU
- JAX version: 0.4.43
- Machine: Trillium
- torch_xla2 version: 0.0.1
Additional context
After digging in, I noticed that the xla2 implementation of the operator has two separate computation paths (trees): one that computes only the max-pool values, and one that computes both the values and the indices.
This is done on purpose: when the indices are not used (for example with return_indices=False), the XLA compiler can drop the index computation as dead code, which improves performance by skipping the unused argmax work. The index path is only kept when the indices are actually used.
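To make the split concrete, here is a minimal sketch of the values-only path. It is my own illustration, not the jaten.py code: max_pool2d_values is a hypothetical name, and it assumes an NCHW float input, a square window and VALID padding. It is a single monoid reduction (jax.lax.max with a -inf init value) and builds no index tensor, so no argmax work is done when only the values are needed.

```python
import jax
import jax.numpy as jnp


def max_pool2d_values(x, kernel=2, stride=2):
    """Hypothetical values-only max pool for an NCHW float array.

    No index tensor is built, so no argmax work is done.
    """
    window = (1, 1, kernel, kernel)   # pool over H and W only
    strides = (1, 1, stride, stride)
    # -inf is the identity element of max, so each output is its window's maximum.
    return jax.lax.reduce_window(x, -jnp.inf, jax.lax.max, window, strides, "VALID")


x = jnp.array([[1., 9., 2., 0.],
               [3., 4., 8., 7.],
               [6., 5., 1., 2.],
               [0., 0., 3., 9.]]).reshape(1, 1, 4, 4)
y = max_pool2d_values(x)   # shape (1, 1, 2, 2)
```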
For the indices path, a custom reduction function is passed to jax.lax.reduce_window. It receives an (index, value) tuple for each element of the pooling window and keeps the pair with the larger value.
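A minimal sketch of such an (index, value) reduction, assuming an NCHW float input and VALID padding. max_pool2d_indices is a hypothetical name and the reducer is my own reconstruction, not copied from jaten.py. Note that jax.lax.reduce_window requires the operand tuple and the init-value tuple to have the same tree structure; that is the check reported in the ValueError above.

```python
import jax
import jax.numpy as jnp


def max_pool2d_indices(x, kernel=2, stride=2):
    """Hypothetical max pool returning (indices, values) for an NCHW float array.

    Indices are flattened per (H, W) plane, i.e. h * W + w.
    """
    _, _, h, w = x.shape
    # Integer operand: the flat spatial index of every element.
    idx = jnp.broadcast_to(
        jnp.arange(h * w, dtype=jnp.int32).reshape(1, 1, h, w), x.shape)

    def reducer(a, b):
        # Each argument is an (index, value) pair; keep the pair with the larger value.
        # On ties, the chosen index depends on XLA's reduction order.
        a_idx, a_val = a
        b_idx, b_val = b
        take_a = a_val > b_val
        return jnp.where(take_a, a_idx, b_idx), jnp.where(take_a, a_val, b_val)

    # Operands and init values must form the same tree: a 2-tuple of arrays each.
    init = (jnp.array(0, jnp.int32), jnp.array(-jnp.inf, x.dtype))
    return jax.lax.reduce_window(
        (idx, x), init, reducer,
        (1, 1, kernel, kernel), (1, 1, stride, stride), "VALID")


x = jnp.array([[1., 9., 2., 0.],
               [3., 4., 8., 7.],
               [6., 5., 1., 2.],
               [0., 0., 3., 9.]]).reshape(1, 1, 4, 4)
indices, values = max_pool2d_indices(x)
```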
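The logic in this reducer is correct and should work; however, JAX fails on this use of the windowing function. Under jax.value_and_grad, JAX attaches extra internal state to each input: the integer index operand gets a symbolic-zero tangent (the CustomNode(Zero[ShapedArray(float0[...])]) entry in the error), so the operand tree no longer matches the init_values tree and reduce_window raises the ValueError shown above.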
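One possible way around this, offered as an assumption rather than a verified fix for torch_xla2: the indices are never differentiated anyway, so the index reduction can be fed jax.lax.stop_gradient(x). Its tangent is then a symbolic zero on every operand, so JAX should skip differentiating the variadic reduce_window entirely, while the values still come from the differentiable lax.max path. The function name is hypothetical, and the sketch assumes an NCHW float input with a fixed 2x2 window, stride 2 and VALID padding.

```python
import jax
import jax.numpy as jnp

WINDOW = (1, 1, 2, 2)
STRIDES = (1, 1, 2, 2)


def _pick_max(a, b):
    # (index, value) reducer: keep the pair with the larger value.
    take_a = a[1] > b[1]
    return jnp.where(take_a, a[0], b[0]), jnp.where(take_a, a[1], b[1])


def max_pool2d_with_indices_sg(x):
    """Hypothetical 2x2 / stride-2 max pool on NCHW input that also returns indices."""
    _, _, h, w = x.shape
    # Differentiable values path: monoid max reduction.
    values = jax.lax.reduce_window(x, -jnp.inf, jax.lax.max, WINDOW, STRIDES, "VALID")
    # Non-differentiable indices path: stop_gradient gives x a zero tangent,
    # so autodiff never has to differentiate the variadic reduce_window.
    x_nograd = jax.lax.stop_gradient(x)
    idx = jnp.broadcast_to(
        jnp.arange(h * w, dtype=jnp.int32).reshape(1, 1, h, w), x.shape)
    init = (jnp.array(0, jnp.int32), jnp.array(-jnp.inf, x.dtype))
    indices, _ = jax.lax.reduce_window(
        (idx, x_nograd), init, _pick_max, WINDOW, STRIDES, "VALID")
    return values, indices


def loss(x):
    values, indices = max_pool2d_with_indices_sg(x)
    return values.sum(), indices   # indices returned as an auxiliary output


x = jnp.array([[1., 9., 2., 0.],
               [3., 4., 8., 7.],
               [6., 5., 1., 2.],
               [0., 0., 3., 9.]]).reshape(1, 1, 4, 4)
(val, indices), grad = jax.value_and_grad(loss, has_aux=True)(x)
```

The gradient is 1 at each window's maximum and 0 elsewhere, as for a plain max pool.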