🐛 Bug
At first, calling upsample_bilinear on a torch.float16 input seems to work, i.e. the returned data type is torch.float16, as expected. However, when using the result with another torch.float16 tensor, it breaks unexpectedly.
In the example below, foo stacks the result of an upsample_bilinear call with another torch.float16 tensor. The function fails when using PyTorch/XLA because stack (lowered to concatenate) expects all of its inputs to have the same data type (note that this behavior is being fixed in #7091). However, as the error message shows, we end up calling concatenate(f32[...], f16[...]), meaning that the result of upsample_bilinear wasn't really f16.
In summary: upsample_bilinear2d returns a torch.float16 tensor, even though its HLO representation is f32. The expected data type is f16.
import torch
import torch_xla.core.xla_model as xm

def foo(x, y):
    return torch.stack([torch.nn.functional.upsample_bilinear(x, scale_factor=2), y])

a = torch.rand(1, 3, 10, 10, dtype=torch.half)
b = torch.rand(1, 3, 20, 20, dtype=torch.half)
Xa = a.to(xm.xla_device())
Xb = b.to(xm.xla_device())
out = foo(a, b)
print(out.dtype) # torch.float16
Xout = foo(Xa, Xb)
print(Xout.dtype) # torch.float16
# Fails with the error below.
Xout.cpu()
Non-OK-status: status.status() status: INTERNAL: during context [Unknown]: Seen floating point types of different precisions in %concatenate.82 = f32[2,1,3,20,20]{4,3,2,1,0} concatenate(f32[1,1,3,20,20]{4,3,2,1,0} %reshape.80, f16[1,1,3,20,20]{4,3,2,1,0} %reshape.81), dimensions={0}, but mixed precision is disallowed.
*** Begin stack trace ***
tsl::CurrentStackTrace[abi:cxx11]()
std::unique_ptr<xla::PjRtLoadedExecutable, std::default_delete<xla::PjRtLoadedExecutable> > ConsumeValue<std::unique_ptr<xla::PjRtLoadedExecutable, std::default_delete<xla::PjRtLoadedExecutable> > >(absl::lts_20230802::StatusOr<std::unique_ptr<xla::PjRtLoadedExecutable, std::default_delete<xla::PjRtLoadedExecutable> > >&&)
torch_xla::runtime::PjRtComputationClient::Compile(std::vector<torch_xla::runtime::ComputationClient::CompileInstance, std::allocator<torch_xla::runtime::ComputationClient::CompileInstance> >)
torch_xla::XLAGraphExecutor::Compile(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >&, absl::lts_20230802::Span<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const>, torch::lazy::LazyGraphExecutor::SyncTensorCollection const&, torch::lazy::LazyGraphExecutor::PostOrderData*, std::vector<torch::lazy::Value, std::allocator<torch::lazy::Value> > const&)
torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230802::Span<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const>, torch::lazy::LazyGraphExecutor::SyncTensorsConfig const&, bool)
torch_xla::XLAGraphExecutor::SyncTensorsGraph(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230802::Span<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const>, bool, bool, bool)
torch_xla::XLATensor::ApplyPendingGraph()
torch_xla::XLATensor::GetXlaData()
torch_xla::XLATensor::ToTensor(bool)
torch_xla::XLANativeFunctions::_to_copy(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>)
at::_ops::_to_copy::redispatch(c10::DispatchKeySet, at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>)
at::_ops::_to_copy::call(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>)
at::_ops::_to_copy::redispatch(c10::DispatchKeySet, at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>)
at::_ops::_to_copy::call(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>)
at::native::to(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, bool, std::optional<c10::MemoryFormat>)
at::_ops::to_dtype_layout::call(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, bool, std::optional<c10::MemoryFormat>)
at::Tensor::to(c10::TensorOptions, bool, bool, std::optional<c10::MemoryFormat>) const
_PyEval_EvalFrameDefault
PyEval_EvalCode
_PyRun_SimpleFileObject
_PyRun_AnyFileObject
Py_RunMain
Py_BytesMain
__libc_start_main
_start
*** End stack trace ***
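For reference, the dtype mismatch can also be observed without triggering the crash by dumping the HLO of the upsampled tensor alone. The sketch below is mine and assumes the private debug helper torch_xla._XLAC._get_xla_tensors_hlo is available in this build:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

x = torch.rand(1, 3, 10, 10, dtype=torch.half).to(xm.xla_device())
up = torch.nn.functional.upsample_bilinear(x, scale_factor=2)
print(up.dtype)  # torch.float16 according to the tensor metadata
# The HLO root is expected to show f32[1,3,20,20] rather than f16,
# matching the mixed-precision error above.
print(torch_xla._XLAC._get_xla_tensors_hlo([up]))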
Environment
- Reproducible on XLA backend [CPU/TPU/CUDA]: CUDA
- torch_xla version: 8d35eb0
Additional context
This seems to happen because the lowering computes in F32 regardless of the original input data type (xla/torch_xla/csrc/resize_ops.cpp, lines 56 to 59 at f336317):
if (is_kernel_bilinear || xla::primitive_util::IsIntegralType(input_type)) {
  input = xla::ConvertElementType(input, xla::F32);
  input_type = xla::F32;
}
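A possible user-side workaround (a sketch only, untested; the helper name is mine) is to make the F32 computation explicit and cast back to half afterwards, so the HLO type and the tensor metadata stay consistent when stacking. Using the Xa/Xb tensors from the repro above:

def foo_workaround(x, y):
    # Upsample explicitly in float32, matching what the lowering does internally,
    # then cast back to float16 so the convert is part of the HLO and both
    # inputs to stack really are f16.
    up = torch.nn.functional.upsample_bilinear(x.float(), scale_factor=2).half()
    return torch.stack([up, y])

Xout = foo_workaround(Xa, Xb)
print(Xout.cpu().dtype)  # expected: torch.float16, without the crash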
cc @miladm @JackCaoG