copying from CPU -> XLA doesn't work if the XLA tensor is resized first #2881

@bdhirsh

Description

🐛 Bug

If I perform a resize on an XLA tensor and then try to copy the contents of an equally sized cpu tensor into that XLA tensor, it looks like the copy doesn't respect the resize. See below for a better example with IR graph output.

To Reproduce

Here's a minimal repro, which I reproduced on Colab:

import torch
import torch_xla
import torch_xla.core.xla_model as xm
dev = xm.xla_device()

a = torch.tensor([1])
b = torch.tensor([], device=dev)
b.resize_(a.size()) # resize from (0,) to (1,)
b.copy_(a) # doesn't actually copy anything! (?)
print(b.cpu())
# prints tensor([0.]); it should print tensor([1.])

Expected behavior

The contents of the CPU tensor should be copied into the (resized) XLA tensor; I'd expect tensor([1.]).
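
For reference, the same resize-then-copy sequence on plain CPU tensors behaves as expected, which is the behavior the XLA backend should match (a minimal sketch, CPU-only):

```python
import torch

# The same sequence as the repro above, but on CPU tensors only.
a = torch.tensor([1])
b = torch.tensor([])          # empty float tensor of shape (0,)
b.resize_(a.size())           # resize from (0,) to (1,)
b.copy_(a)                    # copies (and casts) the contents of a
print(b)                      # tensor([1.])
```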

Environment

Reproduced on the Colab notebook linked from the getting-started page (https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/getting-started.ipynb?pli=1&authuser=1#scrollTo=l50-R2kwFY7Z).

Additional context

Here's an example with slightly better contrast, that also includes the IR:

Good version:

import torch
import torch_xla
import torch_xla.core.xla_model as xm
dev = xm.xla_device()

a = torch.tensor([4, 4, 4], dtype=torch.int64)
b = torch.tensor([0, 0, 0], device=dev, dtype=torch.int64)
b.resize_(a.size()) # shouldn't actually do anything
b.copy_(a)
print(torch_xla._XLAC._get_xla_tensors_text([b]))
print(b.cpu())

That prints:

IR {
  %0 = s64[3]{0} xla::device_data(), device=TPU:0
  %1 = s64[3]{0} aten::reshape(%0), size=(3)
  %2 = s64[3]{0} xla::device_data(), device=TPU:0
  %3 = s64[] prim::Constant(), value=0
  %4 = s64[3]{0} aten::erfinv(%3), size=(3)
  %5 = s64[3]{0} xla::as_strided_view_update(%4, %2), size=(3), stride=(1), storage_offset=0
  %6 = s64[3]{0} xla::as_strided_view_update(%5, %1), size=(3), stride=(1), storage_offset=0
  %7 = s64[3]{0} aten::argmax(%6), size=(3), stride=(1), storage_offset=0
  %8 = s64[3]{0} aten::reshape(%7), size=(3), ROOT=0
}

tensor([4, 4, 4])

Contrast that to identical code where I initialize the XLA tensor to start off with a smaller size:

dev = xm.xla_device()

a = torch.tensor([4, 4, 4], dtype=torch.int64)
b = torch.tensor([0, 0], device=dev, dtype=torch.int64)
b.resize_(a.size()) # This resizes b from (2,) to (3,)
b.copy_(a)
print(torch_xla._XLAC._get_xla_tensors_text([b]))
print(b.cpu())

That prints:

IR {
  %0 = s64[3]{0} xla::device_data(), device=TPU:0
  %1 = s64[2]{0} aten::reshape(%0), size=(2)
  %2 = s64[2]{0} xla::device_data(), device=TPU:0
  %3 = s64[] prim::Constant(), value=0
  %4 = s64[2]{0} aten::erfinv(%3), size=(2)
  %5 = s64[2]{0} xla::as_strided_view_update(%4, %2), size=(2), stride=(1), storage_offset=0
  %6 = s64[2]{0} xla::as_strided_view_update(%5, %1), size=(2), stride=(1), storage_offset=0
  %7 = s64[2]{0} aten::argmax(%6), size=(2), stride=(1), storage_offset=0
  %8 = s64[3]{0} aten::reshape(%7), size=(3), ROOT=0
}

tensor([4, 4, 0]) # Wrong! Should be [4, 4, 4]

My first observation is that as_strided_view_update, which is the instruction that's meant to perform the in-place write (thanks @ailzhang), is using size=(2) instead of size=(3), because that instruction runs before the aten::reshape instruction. Is it possible that the reshape instruction needs to come earlier?
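
As a possible workaround until the IR ordering is fixed, allocating the destination tensor at its final size up front (instead of resizing a smaller tensor in place before the copy) should sidestep the issue entirely. A minimal sketch, shown on CPU, with the assumption that the same pattern avoids the stale-size as_strided_view_update on XLA devices:

```python
import torch

# Workaround sketch: allocate the destination at the final size rather than
# calling resize_ on a smaller tensor before copy_. (CPU shown here; the
# assumption is the same pattern applies with device=xm.xla_device().)
a = torch.tensor([4, 4, 4], dtype=torch.int64)
b = torch.empty(a.size(), dtype=torch.int64)  # allocated at size (3,) up front
b.copy_(a)
print(b)  # tensor([4, 4, 4])
```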
