Skip to content

cummax: raises an error if `dim` is a 0-sized dimension. #8610

@ysiraichi

Description

@ysiraichi

🐛 Bug

The following snippet runs successfully on CUDA, but aborts when run on a PyTorch/XLA device.

>>> torch.cummax(torch.rand(4, 4, 0, device="cuda"), -1)
torch.return_types.cummax(
values=tensor([], device='cuda:0', size=(4, 4, 0)),
indices=tensor([], device='cuda:0', size=(4, 4, 0), dtype=torch.int64))

>>> torch.cummax(torch.rand(4, 4, 0, device=xm.xla_device()), -1)
Aborted (core dumped)
Full Error
F0000 00:00:1737580672.300866   85730 debug_macros.h:21] Non-OK-status: status.status()
Status: INVALID_ARGUMENT: Window dimensions {
  stride: 1
  padding_low: -1
  window_dilation: 1
  base_dilation: 1
}
 has a non-positive dimension.
*** Begin stack trace ***
        tsl::CurrentStackTrace[abi:cxx11]()
        xla::Shape const* ConsumeValue<xla::Shape const*>(absl::lts_20230802::StatusOr<xla::Shape const*>&&)
        torch_xla::ShapeHelper::ShapeOfXlaOp(xla::XlaOp)
        torch_xla::InferOutputShape(absl::lts_20230802::Span<xla::Shape const>, std::function<xla::XlaOp (absl::lts_20230802::Span<xla::XlaOp const>)> const&)





        std::function<xla::Shape ()>::operator()() const
        torch_xla::XlaNode::GetOpShape(std::function<xla::Shape ()> const&) const
        torch_xla::XlaNode::XlaNode(torch::lazy::OpKind, c10::ArrayRef<torch::lazy::Value>, std::function<xla::Shape ()> const&, unsigned long, torch::lazy::hash_t)
        torch_xla::CumMax::CumMax(torch::lazy::Value const&, long)
        void __gnu_cxx::new_allocator<torch_xla::CumMax>::construct<torch_xla::CumMax, torch::lazy::Value, long>(torch_xla::CumMax*, torch::lazy::Value&&, long&&)
        void std::allocator_traits<std::allocator<torch_xla::CumMax> >::construct<torch_xla::CumMax, torch::lazy::Value, long>(std::allocator<torch_xla::CumMax>&, torch_xla::CumMax*, torch::lazy::Value&&, long&&)
        std::_Sp_counted_ptr_inplace<torch_xla::CumMax, std::allocator<torch_xla::CumMax>, (__gnu_cxx::_Lock_policy)2>::_Sp_counted_ptr_inplace<torch::lazy::Value, long>(std::allocator<torch_xla::CumMax>, torch::lazy::Value&&, long&&)
        std::__shared_count<(__gnu_cxx::_Lock_policy)2>::__shared_count<torch_xla::CumMax, std::allocator<torch_xla::CumMax>, torch::lazy::Value, long>(torch_xla::CumMax*&, std::_Sp_alloc_shared_tag<std::allocator<torch_xla::CumMax> >, torch::lazy::Value&&, long&&)
        std::__shared_ptr<torch_xla::CumMax, (__gnu_cxx::_Lock_policy)2>::__shared_ptr<std::allocator<torch_xla::CumMax>, torch::lazy::Value, long>(std::_Sp_alloc_shared_tag<std::allocator<torch_xla::CumMax> >, torch::lazy::Value&&, long&&)
        std::shared_ptr<torch_xla::CumMax>::shared_ptr<std::allocator<torch_xla::CumMax>, torch::lazy::Value, long>(std::_Sp_alloc_shared_tag<std::allocator<torch_xla::CumMax> >, torch::lazy::Value&&, long&&)
        std::shared_ptr<torch_xla::CumMax> std::allocate_shared<torch_xla::CumMax, std::allocator<torch_xla::CumMax>, torch::lazy::Value, long>(std::allocator<torch_xla::CumMax> const&, torch::lazy::Value&&, long&&)
        std::shared_ptr<torch_xla::CumMax> std::make_shared<torch_xla::CumMax, torch::lazy::Value, long>(torch::lazy::Value&&, long&&)
        std::shared_ptr<torch::lazy::Node> torch_xla::MakeNode<torch_xla::CumMax, torch::lazy::Value, long>(torch::lazy::Value&&, long&&)
        torch_xla::tensor_methods::cummax(c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > const&, long)
        torch_xla::XLANativeFunctions::cummax(at::Tensor const&, long)





        c10::BoxedKernel::callBoxed(c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const
        c10::KernelFunction::callBoxed(c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const
        c10::Dispatcher::callBoxed(c10::OperatorHandle const&, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const



        c10::BoxedKernel::callBoxed(c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const


        at::_ops::cummax::redispatch(c10::DispatchKeySet, at::Tensor const&, long)





        at::_ops::cummax::call(at::Tensor const&, long)
        at::Tensor::cummax(long) const
        


        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault

        PyEval_EvalCode



        _PyRun_SimpleFileObject
        _PyRun_AnyFileObject
        Py_RunMain
        Py_BytesMain
        __libc_start_main
        _start
*** End stack trace ***

*** Check failure stack trace: ***
    @     0x781d8d99484d  absl::lts_20230802::log_internal::LogMessage::PrepareToDie()
    @     0x781d8d9948bd  absl::lts_20230802::log_internal::LogMessage::SendToLog()
    @     0x781d8d994340  absl::lts_20230802::log_internal::LogMessage::Flush()
    @     0x781d8d994b8c  absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
    @     0x781d79c5b192  ConsumeValue<>()
    @     0x781d79c5afb4  torch_xla::ShapeHelper::ShapeOfXlaOp()
    @     0x781d79b3a38c  torch_xla::InferOutputShape()
    @     0x781d79ae2888  torch_xla::(anonymous namespace)::NodeOutputShape()
    @     0x781d79ae2969  torch_xla::CumMax::CumMax()::{lambda()#1}::operator()()
    @     0x781d79ae3c9c  std::__invoke_impl<>()
    @     0x781d79ae3a9e  std::__invoke_r<>()
    @     0x781d79ae3812  std::_Function_handler<>::_M_invoke()
    @     0x781d79c4132f  std::function<>::operator()()
    @     0x781d79c3f3a7  torch_xla::XlaNode::GetOpShape()
    @     0x781d79c3e3f6  torch_xla::XlaNode::XlaNode()
    @     0x781d79ae2a82  torch_xla::CumMax::CumMax()
    @     0x781d795e7d5d  __gnu_cxx::new_allocator<>::construct<>()
    @     0x781d795da692  std::allocator_traits<>::construct<>()
    @     0x781d795c7bf8  std::_Sp_counted_ptr_inplace<>::_Sp_counted_ptr_inplace<>()
    @     0x781d795b6c5b  std::__shared_count<>::__shared_count<>()
    @     0x781d795ad762  std::__shared_ptr<>::__shared_ptr<>()
    @     0x781d795a6f83  std::shared_ptr<>::shared_ptr<>()
    @     0x781d7959d90b  std::allocate_shared<>()
    @     0x781d795936d0  std::make_shared<>()
    @     0x781d7958838a  torch_xla::MakeNode<>()
    @     0x781d79562eef  torch_xla::tensor_methods::cummax()
    @     0x781d793c0606  torch_xla::XLANativeFunctions::cummax()
    @     0x781d798999e1  at::(anonymous namespace)::(anonymous namespace)::wrapper_XLA__cummax()
    @     0x781d7995fca2  c10::impl::wrap_kernel_functor_unboxed_<>::call()
    @     0x781d799b234d  c10::impl::call_functor_with_args_from_stack_<>()
    @     0x781d79999c88  c10::impl::call_functor_with_args_from_stack<>()
    @     0x781d7995fcef  c10::impl::make_boxed_from_unboxed_functor<>::call()
    @     0x781f4f146c81  c10::BoxedKernel::callBoxed()
    @     0x781f4f1479b0  c10::KernelFunction::callBoxed()
    @     0x781f4fb53516  c10::Dispatcher::callBoxed()
    @     0x781f37607ceb  c10::OperatorHandle::callBoxed()
    @     0x781f375ff36e  (anonymous namespace)::functionalizeFallback()
    @     0x781f376031f3  c10::BoxedKernel::make_boxed_function<>()
    @     0x781f4f146c81  c10::BoxedKernel::callBoxed()
    @     0x781f388d9f4b  c10::impl::BoxedKernelWrapper<>::call()
    @     0x781f38792591  c10::Dispatcher::redispatch<>()
    @     0x781f38d24ce1  at::_ops::cummax::redispatch()
    @     0x781f3c15fd1b  at::redispatch::cummax()
    @     0x781f3c049620  torch::autograd::VariableType::(anonymous namespace)::cummax()::{lambda()#1}::operator()()
    @     0x781f3c04992d  torch::autograd::VariableType::(anonymous namespace)::cummax()
    @     0x781f3c1265d3  c10::impl::wrap_kernel_functor_unboxed_<>::call()
    @     0x781f388d9ed5  c10::callUnboxedKernelFunction<>()
    @     0x781f38d24ab3  at::_ops::cummax::call()
    @     0x781f4ee43107  at::Tensor::cummax()
    @     0x781f4ee73ecf  torch::autograd::THPVariable_cummax()::{lambda()#1}::operator()()
    @     0x781f4ee7440b  torch::autograd::THPVariable_cummax()
    @     0x781f51cbf425  cfunction_call

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: CUDA
  • torch_xla version: 1c89675

Metadata

Metadata

Assignees

Labels

No labels
No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions