
Support as_strided() for all shapes #2964

@bdhirsh

Description

The current as_strided implementation supports only a limited set of shapes. For shapes it doesn't support, it falls back to the "default" implementation on CPU.

That's incorrect behavior. The fallback creates the view on CPU and copies the result back to the XLA device, which produces a brand-new tensor instead of one that shares storage with the original. For example:

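>>> import torch
>>> import torch_xla.core.xla_model as xm
>>> t = torch.ones(4, device=xm.xla_device())
>>> t2 = t.as_strided((2,), (2,))  # view over elements 0 and 2 of t
>>> t2[0] = -1
>>> t
tensor([1., 1., 1., 1.], device='xla:0')  # wrong: t2 aliases t, so this should be tensor([-1., 1., 1., 1.])
>>> t2
tensor([-1.,  1.], device='xla:0')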
>>>
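To make the aliasing requirement concrete, here is a minimal pure-Python model. It is not torch or torch_xla code: it assumes a flat list as the storage, and as_strided_indices, StridedView and fallback_as_strided are hypothetical names for illustration. The index map handles arbitrary sizes and strides (the "all shapes" case). The view writes through to the shared storage, while the fallback-style version copies and so loses the mutation, as in the session above.

```python
import itertools


def as_strided_indices(size, stride, storage_offset=0):
    """Flat storage indices read by an as_strided view, in row-major order.

    Element (i0, ..., ik) of the view lives at
    storage_offset + i0 * stride[0] + ... + ik * stride[k].
    (Hypothetical helper; no bounds checking against the storage.)
    """
    if len(size) != len(stride):
        raise ValueError("size and stride must have the same length")
    return [
        storage_offset + sum(i * s for i, s in zip(idx, stride))
        for idx in itertools.product(*(range(n) for n in size))
    ]


class StridedView:
    """Aliasing view: reads and writes go straight to the shared storage list."""

    def __init__(self, storage, size, stride, storage_offset=0):
        self.storage = storage  # shared reference, no copy
        self.indices = as_strided_indices(size, stride, storage_offset)

    def __getitem__(self, i):  # i indexes the flattened view
        return self.storage[self.indices[i]]

    def __setitem__(self, i, value):
        self.storage[self.indices[i]] = value

    def tolist(self):
        return [self.storage[j] for j in self.indices]


def fallback_as_strided(storage, size, stride, storage_offset=0):
    """Mimics the CPU fallback: materializes the view into a brand-new list."""
    return [storage[j] for j in as_strided_indices(size, stride, storage_offset)]


# View semantics: the write is visible through the base.
t = [1.0] * 4
t2 = StridedView(t, (2,), (2,))
t2[0] = -1.0
print(t)  # [-1.0, 1.0, 1.0, 1.0]

# Fallback semantics: the write only lands in the copy.
u = [1.0] * 4
u2 = fallback_as_strided(u, (2,), (2,))
u2[0] = -1.0
print(u, u2)  # [1.0, 1.0, 1.0, 1.0] [-1.0, 1.0]
```

The same index map covers non-contiguous and overlapping layouts: as_strided_indices((2, 2), (1, 2)) gives [0, 2, 1, 3] (a transposed 2x2 block), and as_strided_indices((3, 2), (1, 1)) gives [0, 1, 1, 2, 2, 3], where some storage elements are read more than once.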

In a future PR I'm changing the CPU fallback so that calling a view operator raises an error instead of silently doing a copy. Eventually, though, every view operator should have an XLA lowering.
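The planned fallback change can be sketched as a dispatch guard. This is a hypothetical illustration, not the torch_xla fallback code: VIEW_OPS, cpu_fallback and the listed op names are assumptions, and a real implementation would more likely derive "is this a view op" from the operator schema's alias annotations than from a hard-coded set.

```python
# Hypothetical, hard-coded set of view (aliasing) ops, for illustration only.
VIEW_OPS = frozenset({"as_strided", "view", "transpose", "select", "slice"})


def cpu_fallback(op_name, cpu_impl, *args):
    """Run cpu_impl(*args), but refuse ops whose output must alias an input."""
    if op_name in VIEW_OPS:
        # A CPU round trip would return a copy, silently breaking aliasing.
        raise NotImplementedError(
            f"{op_name} is a view op; it needs an XLA lowering, not a CPU fallback"
        )
    return cpu_impl(*args)


print(cpu_fallback("add", lambda a, b: a + b, 2, 3))  # 5
try:
    cpu_fallback("as_strided", lambda *a: None)
except NotImplementedError as e:
    print(e)
```

Raising loudly trades convenience for correctness: a missing lowering becomes a visible error instead of a tensor that quietly stops sharing storage with its base.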

Metadata

Assignees: No one assigned
Labels: nostale (Do not consider for staleness)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests