The current as_strided implementation only supports a limited set of shapes, and falls back to the "default" CPU implementation for shapes it doesn't support.
That's technically incorrect behavior: the fallback creates the view on CPU and copies the result back to XLA, which produces a brand-new tensor rather than one sharing storage with the original. E.g.:
>>> t = torch.ones(4, device=xm.xla_device())
>>> t2 = t.as_strided((2,), (2,))
>>> t2[0] = -1
>>> t
tensor([1., 1., 1., 1.], device='xla:0') # should have been mutated
>>> t2
tensor([-1.,  1.], device='xla:0')
>>>
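For comparison, here's roughly what you'd expect from the same sequence under eager CPU semantics, where as_strided returns a true view sharing storage with t, so the write through t2 is visible in t (a quick sketch; outputs shown as expected):
>>> t = torch.ones(4)
>>> t2 = t.as_strided((2,), (2,))
>>> t2[0] = -1
>>> t
tensor([-1.,  1.,  1.,  1.])
>>> t2
tensor([-1.,  1.])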
In a future PR I'm changing the CPU fallback so that calls to view operators raise an error instead of silently performing a copy; eventually, though, all view operators should have lowerings.