upsample_bilinear2d HLO returns unexpected data-type. #7095

@ysiraichi

Description

🐛 Bug

At first, upsample_bilinear2d seems to work: the returned data-type is torch.float16, as expected. However, when the result is combined with another torch.float16 tensor, the computation breaks unexpectedly.

In the example below, foo stacks the result of an upsample_bilinear call with another torch.float16 tensor. The function fails under PyTorch/XLA because stack (lowered to concatenate) requires all inputs to have the same data-type (note that this behavior is being fixed in #7091). However, as the error message shows, we end up calling concatenate(f32[...], f16[...]), meaning the result of upsample_bilinear was not really f16.
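For comparison, eager PyTorch resolves mixed data-types in stack/cat through type promotion rather than erroring (the behavior that #7091 brings to the XLA lowering); a minimal CPU check:

```python
import torch

# Eager torch.stack promotes mixed floating dtypes:
# result_type(float16, float32) == float32.
x = torch.zeros(2, dtype=torch.float16)
y = torch.zeros(2, dtype=torch.float32)
print(torch.stack([x, y]).dtype)  # torch.float32
```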

In summary: upsample_bilinear2d returns a tensor that claims to be torch.float16, even though its HLO representation is f32. The expected HLO data-type is f16.

import torch
import torch_xla.core.xla_model as xm

def foo(x, y):
    return torch.stack([torch.nn.functional.upsample_bilinear(x, scale_factor=2), y])

a = torch.rand(1, 3, 10, 10, dtype=torch.half)
b = torch.rand(1, 3, 20, 20, dtype=torch.half)

Xa = a.to(xm.xla_device())
Xb = b.to(xm.xla_device())

# Eager CPU baseline: works and preserves the dtype.
out = foo(a, b)
print(out.dtype)  # torch.float16

Xout = foo(Xa, Xb)
print(Xout.dtype)  # torch.float16

# Materializing the result fails with the error below.
Xout.cpu()
Non-OK-status: status.status() status: INTERNAL: during context [Unknown]: Seen floating point types of different precisions in %concatenate.82 = f32[2,1,3,20,20]{4,3,2,1,0} concatenate(f32[1,1,3,20,20]{4,3,2,1,0} %reshape.80, f16[1,1,3,20,20]{4,3,2,1,0} %reshape.81), dimensions={0}, but mixed precision is disallowed.
*** Begin stack trace ***
        tsl::CurrentStackTrace[abi:cxx11]()
        std::unique_ptr<xla::PjRtLoadedExecutable, std::default_delete<xla::PjRtLoadedExecutable> > ConsumeValue<std::unique_ptr<xla::PjRtLoadedExecutable, std::default_delete<xla::PjRtLoadedExecutable> > >(absl::lts_20230802::StatusOr<std::unique_ptr<xla::PjRtLoadedExecutable, std::default_delete<xla::PjRtLoadedExecutable> > >&&)
        torch_xla::runtime::PjRtComputationClient::Compile(std::vector<torch_xla::runtime::ComputationClient::CompileInstance, std::allocator<torch_xla::runtime::ComputationClient::CompileInstance> >)
        torch_xla::XLAGraphExecutor::Compile(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >&, absl::lts_20230802::Span<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const>, torch::lazy::LazyGraphExecutor::SyncTensorCollection const&, torch::lazy::LazyGraphExecutor::PostOrderData*, std::vector<torch::lazy::Value, std::allocator<torch::lazy::Value> > const&)
        torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230802::Span<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const>, torch::lazy::LazyGraphExecutor::SyncTensorsConfig const&, bool)
        torch_xla::XLAGraphExecutor::SyncTensorsGraph(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230802::Span<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const>, bool, bool, bool)
        torch_xla::XLATensor::ApplyPendingGraph()
        torch_xla::XLATensor::GetXlaData()
        torch_xla::XLATensor::ToTensor(bool)
        torch_xla::XLANativeFunctions::_to_copy(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>)
        at::_ops::_to_copy::redispatch(c10::DispatchKeySet, at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>)
        at::_ops::_to_copy::call(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>)
        at::_ops::_to_copy::redispatch(c10::DispatchKeySet, at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>)
        at::_ops::_to_copy::call(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>)
        at::native::to(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, bool, std::optional<c10::MemoryFormat>)
        at::_ops::to_dtype_layout::call(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, bool, std::optional<c10::MemoryFormat>)
        at::Tensor::to(c10::TensorOptions, bool, bool, std::optional<c10::MemoryFormat>) const
        _PyEval_EvalFrameDefault
        PyEval_EvalCode
        _PyRun_SimpleFileObject
        _PyRun_AnyFileObject
        Py_RunMain
        Py_BytesMain
        __libc_start_main
        _start
*** End stack trace ***
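Until the lowering is fixed, one possible user-side workaround (a sketch, not an official fix; foo_workaround is a hypothetical name) is to perform the whole mixed expression in float32 and cast back once at the end, so the lowered concatenate only ever sees f32 operands:

```python
import torch
import torch.nn.functional as F

def foo_workaround(x, y):
    # Upcast both inputs to float32 so the stack/concatenate operands agree,
    # then cast the final result back to the original dtype in one place.
    up = F.interpolate(x.float(), scale_factor=2, mode="bilinear")
    return torch.stack([up, y.float()]).to(x.dtype)

a = torch.rand(1, 3, 10, 10, dtype=torch.half)
b = torch.rand(1, 3, 20, 20, dtype=torch.half)
out = foo_workaround(a, b)
print(out.dtype)  # torch.float16
```

Casting the upsample output alone would not help here, since the tensor's logical dtype is already float16 and the `.to` call would be a no-op.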

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: CUDA
  • torch_xla version: 8d35eb0

Additional context

This seems to happen because the lowering computes in F32 regardless of the original input data-type:

if (is_kernel_bilinear || xla::primitive_util::IsIntegralType(input_type)) {
  input = xla::ConvertElementType(input, xla::F32);
  input_type = xla::F32;
}

cc @miladm @JackCaoG
