assertEqual doesn't work with DTensor #167549

@ezyang

🐛 Describe the bug

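For example, passing a DTensor to assertEqual fails with the traceback below, raised while the comparison is building its tensor mismatch message: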

  File "/data/users/ezyang/a/pytorch/torch/testing/_comparison.py", line 749, in compare
    self._compare_values(actual, expected)
  File "/data/users/ezyang/a/pytorch/torch/testing/_comparison.py", line 907, in _compare_values
    compare_fn(
  File "/data/users/ezyang/a/pytorch/torch/testing/_comparison.py", line 1101, in _compare_regular_values_close
    msg = make_tensor_mismatch_msg(
  File "/data/users/ezyang/a/pytorch/torch/testing/_comparison.py", line 324, in make_tensor_mismatch_msg
    abs_diff[matches_flat] = 0
  File "/data/users/ezyang/a/pytorch/torch/_compile.py", line 54, in inner
    return disable_fn(*args, **kwargs)
  File "/data/users/ezyang/a/pytorch/torch/_dynamo/eval_frame.py", line 1129, in _fn
    return fn(*args, **kwargs)
  File "/data/users/ezyang/a/pytorch/torch/distributed/tensor/_api.py", line 347, in __torch_dispatch__
    return DTensor._op_dispatcher.dispatch(
  File "/data/users/ezyang/a/pytorch/torch/distributed/tensor/_dispatch.py", line 187, in dispatch
    op_info = self.unwrap_to_op_info(op_call, args, kwargs)
  File "/data/users/ezyang/a/pytorch/torch/distributed/tensor/_dispatch.py", line 457, in unwrap_to_op_info
    self._try_replicate_spec_for_scalar_tensor(
  File "/data/users/ezyang/a/pytorch/torch/distributed/tensor/_dispatch.py", line 566, in _try_replicate_spec_for_scalar_tensor
    raise RuntimeError(
RuntimeError: aten.index_put_.default: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators! Please see https://docs.pytorch.org/docs/main/distributed.tensor.html#mixed-tensor-and-dtensor-operations for more details.

It's easy to work around: just don't pass a DTensor to assertEqual; call full_tensor() on it first.
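
A minimal sketch of that workaround, in case it helps anyone hitting this. The helper names `materialize` and `assert_equal_full` are hypothetical (they are not part of PyTorch); the only real API assumed is `DTensor.full_tensor()`. The check is duck-typed, so plain `torch.Tensor` values, which have no `full_tensor` method, pass through unchanged.

```python
def materialize(value):
    """Return value.full_tensor() if value exposes that method, else value unchanged.

    Hypothetical helper: a DTensor has full_tensor(), a plain torch.Tensor does not,
    so only distributed tensors are gathered into a regular tensor here.
    """
    full_tensor = getattr(value, "full_tensor", None)
    return full_tensor() if callable(full_tensor) else value


def assert_equal_full(test_case, actual, expected, **kwargs):
    """Hypothetical wrapper: gather both sides first, then defer to assertEqual.

    test_case is any object with an assertEqual method, e.g. a TestCase instance.
    """
    test_case.assertEqual(materialize(actual), materialize(expected), **kwargs)
```

Inside a test this would be called as `assert_equal_full(self, dtensor_out, expected)`. Since `full_tensor()` gathers the whole tensor, this is presumably only reasonable for test-sized inputs.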

Versions

main

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @tianyu-l @XilunWu @SherlockNoMad

Labels: module: dtensor, oncall: distributed
