XLA2 does not support maxpool #8241

@zmelumian972

Description

🐛 Bug

The max-pool operator in torch_xla2 crashes.

To Reproduce

  1. Download the MNIST toy example I prepared from here
  2. Move it to a Trillium machine
  3. Run it

Traceback:

```
Traceback (most recent call last):
  File "/home/zmelumian/xla2_trainer_poc//google_bugs/mnist_with_maxpool.py", line 174, in
    jax_weights, opt_state, loss = training_step(jax_weights, jax_buffers, opt_state, x_j, target_j)
  File "/home/zmelumian/xla2_trainer_poc//google_bugs/mnist_with_maxpool.py", line 102, in training_step
    loss, grads = jax.value_and_grad(forward)(jax_weights, buffers, x, target)
  File "/home/zmelumian/xla2_trainer_poc//google_bugs/mnist_with_maxpool.py", line 89, in forward
    pred = jittable_model.functional_call('forward', weights, buffers, x)
  File "/home/zmelumian/xla/experimental/torch_xla2/torch_xla2/interop.py", line 73, in functional_call
    res = getattr(self._model, method_name)(*args, **kwargs)
  File "/home/zmelumian/xla2_trainer_poc//google_bugs/mnist_with_maxpool.py", line 68, in forward
    x = F.max_pool2d(x, 2, stride=2, return_indices=False)
  File "/home/zmelumian/venv/lib/python3.10/site-packages/torch/_jit_internal.py", line 503, in fn
    return if_false(*args, **kwargs)
  File "/home/zmelumian/venv/lib/python3.10/site-packages/torch/nn/functional.py", line 783, in _max_pool2d
    return handle_torch_function(
  File "/home/zmelumian/venv/lib/python3.10/site-packages/torch/overrides.py", line 1630, in handle_torch_function
    result = mode.torch_function(public_api, types, args, kwargs)
  File "/home/zmelumian/xla/experimental/torch_xla2/torch_xla2/tensor.py", line 212, in torch_function
    return func(*args, **(kwargs or {}))
  File "/home/zmelumian/venv/lib/python3.10/site-packages/torch/_jit_internal.py", line 503, in fn
    return if_false(*args, **kwargs)
  File "/home/zmelumian/venv/lib/python3.10/site-packages/torch/nn/functional.py", line 796, in _max_pool2d
    return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
  File "/home/zmelumian/xla/experimental/torch_xla2/torch_xla2/tensor.py", line 227, in torch_dispatch
    return self.env.dispatch(func, types, args, kwargs)
  File "/home/zmelumian/xla/experimental/torch_xla2/torch_xla2/tensor.py", line 430, in dispatch
    res = op.func(args, **kwargs)
  File "/home/zmelumian/xla/experimental/torch_xla2/torch_xla2/ops/jaten.py", line 1136, in _aten_max_pool2d_with_indices
    indices, _ = jax.lax.reduce_window(
ValueError: Operands must have the same tree structure as init_values: PyTreeDef([
, *, CustomNode(Zero[ShapedArray(float0[1024,64,24,24])], []), ]) vs. PyTreeDef([, *, *, *])
```
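The ValueError comes from the pytree-structure check that jax.lax.reduce_window runs on its operands against its init_values before tracing the reducer. As a minimal illustration in pure JAX (not the torch_xla2 code; in the real failure the mismatch is caused by a symbolic Zero tangent during autodiff, not by a hand-made mismatch), passing a tuple of operands with a single init value trips the same check:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(16.0).reshape(4, 4)

# Tuple of two operands but a single init value: the pytree
# structures differ, so reduce_window rejects the call before
# the reducer is ever traced.
err = None
try:
    jax.lax.reduce_window(
        (x, x),                          # operands: PyTreeDef((*, *))
        -jnp.inf,                        # init_values: PyTreeDef(*)
        lambda a, b: jnp.maximum(a, b),
        window_dimensions=(2, 2),
        window_strides=(2, 2),
        padding=((0, 0), (0, 0)))
except ValueError as e:
    err = str(e)

print(err)
```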

Expected behavior

max_pool2d should behave identically under torch_xla2 and eager PyTorch, and should be fully wrapped by the dispatcher.

Environment

  • Reproducible on the XLA TPU backend
  • JAX version 0.4.43
  • Trillium machine
  • torch_xla2 version: 0.0.1

Additional context

After digging in, I noticed that the operator implemented in torch_xla2 has two code paths: one that computes only the max-pool values, and one that computes both the values and the indices.

It is written this way so that, when the indices are not used, the XLA compiler can drop the index computation entirely, which gives better performance; the indices are only computed when they are actually consumed.
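The cheap values-only path corresponds to the standard JAX max-pool recipe over jax.lax.reduce_window, roughly like this (a sketch of mine, not the torch_xla2 source):

```python
import jax
import jax.numpy as jnp

x = jnp.arange(16.0).reshape(4, 4)

# Values-only max pool: 2x2 window, stride 2, no padding.
# No index computation exists in this graph at all, so the
# compiler has nothing extra to carry.
vals = jax.lax.reduce_window(
    x, -jnp.inf, jax.lax.max,
    window_dimensions=(2, 2),
    window_strides=(2, 2),
    padding='VALID')

print(vals.tolist())  # [[5.0, 7.0], [13.0, 15.0]]
```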

For the combined path, a custom reduction function was written for JAX's windowing primitive (jax.lax.reduce_window); it receives, per element of each window, a pair holding the value and its index.

The logic in that reducer is accurate and should work; however, this multi-operand form of the windowing function is not supported by JAX, which keeps some odd internal state per input (pytree structure metadata) to track its relationship across devices, and that check is what raises the ValueError above.
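The combined values-plus-indices path can be sketched in pure JAX as below (names are mine, not torch_xla2's). Forward evaluation of the multi-operand form works; the failure reported above only appears once jax.value_and_grad pushes a symbolic Zero tangent through the index operand:

```python
import jax
import jax.numpy as jnp

def max_pool_with_argmax(x, window=2, stride=2):
    # Pair every element with its flat index so the window
    # reduction can report where each maximum came from.
    idx = jnp.arange(x.size, dtype=jnp.int32).reshape(x.shape)

    def reducer(a, b):
        # Each argument is a (value, index) pair; keep whichever
        # side holds the larger value, and carry its index along.
        av, ai = a
        bv, bi = b
        take_b = bv > av
        return (jnp.where(take_b, bv, av), jnp.where(take_b, bi, ai))

    init = (jnp.array(-jnp.inf, x.dtype), jnp.int32(-1))
    return jax.lax.reduce_window(
        (x, idx), init, reducer,
        window_dimensions=(window, window),
        window_strides=(stride, stride),
        padding=((0, 0), (0, 0)))

vals, idxs = max_pool_with_argmax(jnp.arange(16.0).reshape(4, 4))
print(vals.tolist())  # maxima of each 2x2 window
print(idxs.tolist())  # flat index of each maximum
```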
