[PyTorch/XLA] XLAShardedTensor untraceable by Dynamo #115970

@wonjoo-wj

Issue description

In PyTorch/XLA, we're trying to register a custom op with PyTorch to make it traceable by Dynamo. While testing, however, we saw that XLAShardedTensor (which is an extension of torch.Tensor, defined at https://github.com/pytorch/xla/blob/master/torch_xla/distributed/spmd/xla_sharded_tensor.py#L61) could not be traced by Dynamo:

torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable() [TensorVariable()] {}
...
x86_64.egg/torch_xla/distributed/spmd/xla_sharding.py", line 558, in wrap_as_sharded_tensor
    return XLAShardedTensor(t)

IIRC, PyTorch's DTensor, which is traceable by Dynamo, is also implemented as an extension of torch.Tensor. Should we expect XLAShardedTensor to be automatically traceable by Dynamo as well, or do we need to do something to make this happen?
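For reference, the failing pattern can be reduced to a small standalone sketch. Note that `MyShardedTensor` below is a hypothetical stand-in for XLAShardedTensor (the real class also carries sharding metadata), and `backend="eager"` is used only so the example does not require an XLA device:

```python
import torch

# Hypothetical minimal analogue of the failing pattern: a plain
# torch.Tensor subclass that wraps an existing tensor, mirroring the
# `return XLAShardedTensor(t)` call in wrap_as_sharded_tensor.
class MyShardedTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem, *args, **kwargs):
        # Re-wrap the same storage as the subclass without copying data.
        return torch.Tensor._make_subclass(cls, elem)

def wrap_as_sharded_tensor(t):
    return MyShardedTensor(t)

@torch.compile(backend="eager")
def f(x):
    # Dynamo has to trace through the subclass constructor here; on the
    # version reported in this issue, this is where
    # torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable()
    # was raised.
    return wrap_as_sharded_tensor(x) + 1

try:
    f(torch.randn(4))
    print("subclass traced by Dynamo")
except Exception as e:
    print(f"Dynamo could not trace the subclass: {type(e).__name__}")
```

Whether this traces cleanly depends on the torch version and on whether the subclass is registered as a traceable tensor subclass, which is exactly the question being asked here.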

Code example

We have a draft PR at pytorch/xla#6161. The relevant code is at distributed/spmd/mark_sharding.py, in the mark_sharding function, which returns an XLAShardedTensor.

cc @bdhirsh
