Issue description
In PyTorch/XLA, we're trying to register a custom op with PyTorch to make it traceable by Dynamo. While testing, however, we saw that XLAShardedTensor (which is a subclass of torch.Tensor, defined at https://github.com/pytorch/xla/blob/master/torch_xla/distributed/spmd/xla_sharded_tensor.py#L61) could not be traced by Dynamo:
torch._dynamo.exc.Unsupported: call_function UserDefinedClassVariable() [TensorVariable()] {}
...
x86_64.egg/torch_xla/distributed/spmd/xla_sharding.py", line 558, in wrap_as_sharded_tensor
return XLAShardedTensor(t)
IIRC, PyTorch's DTensor, which is traceable by Dynamo, is also implemented as a subclass of torch.Tensor. Should we expect XLAShardedTensor to be automatically traceable by Dynamo as well, or do we need to do something to make this happen?
Code example
We have a draft PR at pytorch/xla#6161. The relevant code is at distributed/spmd/mark_sharding.py, in the mark_sharding function as it returns a type of XLAShardedTensor.
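For context, the failing pattern boils down to constructing a torch.Tensor subclass inside a compiled region. A minimal sketch of that pattern is below; the class and helper names are illustrative stand-ins, not the actual PyTorch/XLA implementation, and the wrapping via torch.Tensor._make_subclass mirrors what XLAShardedTensor.__new__ does:

```python
import torch

class MyShardedTensor(torch.Tensor):
    # Illustrative stand-in for XLAShardedTensor: wraps an existing
    # tensor as a subclass instance without copying its storage.
    @staticmethod
    def __new__(cls, t, *args, **kwargs):
        return torch.Tensor._make_subclass(cls, t, t.requires_grad)

def wrap_as_sharded_tensor(t):
    # Stand-in for the helper in xla_sharding.py; calling the subclass
    # constructor here is the call Dynamo reports as
    # "call_function UserDefinedClassVariable() [TensorVariable()] {}".
    return MyShardedTensor(t)

x = torch.ones(2, 2)
s = wrap_as_sharded_tensor(x)
assert isinstance(s, torch.Tensor) and type(s) is MyShardedTensor
```

Eagerly this works fine; the error only appears when Dynamo traces through the constructor call, which is what we observe in mark_sharding.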
cc @bdhirsh