copying from CPU -> XLA doesn't work if the XLA tensor is resized first #2881
Description
🐛 Bug
If I resize an XLA tensor and then copy the contents of an equally sized CPU tensor into it, the copy appears to ignore the resize. See below for a better example with IR graph output.
To Reproduce
Here's a minimal repro, which I reproduced on Colab:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
dev = xm.xla_device()
a = torch.tensor([1])
b = torch.tensor([], device=dev)
b.resize_(a.size()) # resize from (0,) to (1,)
b.copy_(a) # doesn't actually copy anything! (?)
print(b.cpu())
# prints tensor([0.]). It should print tensor([1])
Expected behavior
The contents of the CPU tensor should be copied to the XLA tensor. I'd expect tensor([1]).
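For reference, the same resize-then-copy sequence on plain CPU tensors behaves as expected (a minimal sketch using only CPU PyTorch, no XLA involved):

```python
import torch

a = torch.tensor([4, 4, 4], dtype=torch.int64)
b = torch.tensor([0, 0], dtype=torch.int64)
b.resize_(a.size())  # resizes b from (2,) to (3,); new element is uninitialized
b.copy_(a)           # overwrites all three elements
print(b)             # tensor([4, 4, 4])
```

This is the behavior I'd expect the XLA tensor to match.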
Environment
Reproduced on the Colab link from the getting started page. (https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/getting-started.ipynb?pli=1&authuser=1#scrollTo=l50-R2kwFY7Z).
Additional context
Here's an example with slightly better contrast that also includes the IR:
Good version:
dev = xm.xla_device()
a = torch.tensor([4, 4, 4], dtype=torch.int64)
b = torch.tensor([0, 0, 0], device=dev, dtype=torch.int64)
b.resize_(a.size()) # shouldn't actually do anything
b.copy_(a)
print(torch_xla._XLAC._get_xla_tensors_text([b]))
print(b.cpu())
That prints:
IR {
%0 = s64[3]{0} xla::device_data(), device=TPU:0
%1 = s64[3]{0} aten::reshape(%0), size=(3)
%2 = s64[3]{0} xla::device_data(), device=TPU:0
%3 = s64[] prim::Constant(), value=0
%4 = s64[3]{0} aten::erfinv(%3), size=(3)
%5 = s64[3]{0} xla::as_strided_view_update(%4, %2), size=(3), stride=(1), storage_offset=0
%6 = s64[3]{0} xla::as_strided_view_update(%5, %1), size=(3), stride=(1), storage_offset=0
%7 = s64[3]{0} aten::argmax(%6), size=(3), stride=(1), storage_offset=0
%8 = s64[3]{0} aten::reshape(%7), size=(3), ROOT=0
}
tensor([4, 4, 4])
Contrast that to identical code where I initialize the XLA tensor to start off with a smaller size:
dev = xm.xla_device()
a = torch.tensor([4, 4, 4], dtype=torch.int64)
b = torch.tensor([0, 0], device=dev, dtype=torch.int64)
b.resize_(a.size()) # This resizes b from (2,) to (3,)
b.copy_(a)
print(torch_xla._XLAC._get_xla_tensors_text([b]))
print(b.cpu())
That prints:
IR {
%0 = s64[3]{0} xla::device_data(), device=TPU:0
%1 = s64[2]{0} aten::reshape(%0), size=(2)
%2 = s64[2]{0} xla::device_data(), device=TPU:0
%3 = s64[] prim::Constant(), value=0
%4 = s64[2]{0} aten::erfinv(%3), size=(2)
%5 = s64[2]{0} xla::as_strided_view_update(%4, %2), size=(2), stride=(1), storage_offset=0
%6 = s64[2]{0} xla::as_strided_view_update(%5, %1), size=(2), stride=(1), storage_offset=0
%7 = s64[2]{0} aten::argmax(%6), size=(2), stride=(1), storage_offset=0
%8 = s64[3]{0} aten::reshape(%7), size=(3), ROOT=0
}
tensor([4, 4, 0]) # Wrong! Should be [4, 4, 4]
My first observation is that as_strided_view_update, which is the instruction meant to perform the in-place write (thanks @ailzhang), uses size=(2) instead of size=(3), because it runs before the aten::reshape instruction. Is it possible that the reshape instruction needs to come earlier?
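As a possible user-side workaround (my assumption, not something confirmed by the maintainers), allocating the destination at its final size up front avoids the resize_ path entirely, so the copy never sees a stale shape. Sketched here on CPU; on XLA you would pass device=dev to torch.empty:

```python
import torch

a = torch.tensor([4, 4, 4], dtype=torch.int64)
# Allocate b at the target size directly instead of resizing a smaller tensor.
b = torch.empty(a.size(), dtype=torch.int64)  # add device=dev for XLA
b.copy_(a)
print(b)  # tensor([4, 4, 4])
```

This sidesteps the bug rather than fixing it; the resize_-then-copy_ sequence should still work as reported above.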