
Support 0d tensors in reductionOp #1768

@IvanYashchuk


🚀 The feature, motivation and pitch

PyTorch allows reductions on 0-d tensors:

In [1]: import torch

In [2]: a = torch.randn(())

In [3]: a.ndim
Out[3]: 0

In [4]: torch.sum(a, 0, keepdim=True)
Out[4]: tensor(0.5233)

In [5]: torch.sum(a, 0, keepdim=False)
Out[5]: tensor(0.5233)
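Both calls above return a 0-d tensor. Below is a minimal pure-Python sketch of the shape rule this implies. It assumes PyTorch's convention of treating a 0-d tensor as 1-d when validating dim, so only 0 and -1 are accepted. wrap_dim and reduced_shape are hypothetical helpers for illustration, not PyTorch or nvFuser APIs.

```python
def wrap_dim(dim, ndim):
    """Normalize a possibly negative dim (hypothetical helper).

    As in PyTorch, a 0-d tensor is treated as 1-d here, so only 0 and -1
    are valid dims for it.
    """
    bound = max(ndim, 1)
    if not -bound <= dim < bound:
        raise IndexError(f"dim {dim} out of range for a {ndim}-d tensor")
    return dim % bound


def reduced_shape(shape, dims, keepdim):
    """Output shape of a reduction over `dims` (hypothetical helper)."""
    axes = {wrap_dim(d, len(shape)) for d in dims}
    if len(shape) == 0:
        # A scalar has no real axis to collapse: the result stays 0-d
        # whether or not keepdim is set, matching Out[4] and Out[5].
        return ()
    if keepdim:
        return tuple(1 if i in axes else s for i, s in enumerate(shape))
    return tuple(s for i, s in enumerate(shape) if i not in axes)


print(reduced_shape((), [0], keepdim=True))       # ()
print(reduced_shape((2, 3), [-1], keepdim=True))  # (2, 1)
```

Asking for dim 1 on a 0-d tensor still fails in this sketch, just as PyTorch raises an IndexError for it.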

But nvFuser rejects the same reduction with the error "Tried to reduce a 0-dim tensor".

Here's a Python reproducer:

import torch
from torch._C._nvfuser import Fusion, FusionDefinition, DataType

# Construct and define the fusion
fusion = Fusion()

with FusionDefinition(fusion) as fd:
    t0 = fd.define_tensor(0)  # symbolic input tensor with 0 dimensions
    fd.add_input(t0)
    # Sum over axis 0 with keepdim=False and output dtype Float
    t1 = fd.Ops.sum(t0, [0], False, DataType.Float)
    fd.add_output(t1)

fusion.print_ir()

# Execute the fusion on a 0-d CUDA tensor
input1 = torch.ones((), device='cuda')
outputs = fusion.execute([input1])

print(outputs)

Applying the following patch, which disables the 0-dim check and relaxes the axis bound checks from < to <=, makes my tests pass and the Python example above work. Is it safe to relax the dimensionality requirements for reduction?

index 3bab78b566..a6fc269d38 100644
--- a/torch/csrc/jit/codegen/cuda/arith.cpp
+++ b/torch/csrc/jit/codegen/cuda/arith.cpp
@@ -867,7 +867,7 @@ static TensorView* newForReduction(
       "Asked for ouput of reduction, but no reduction axis provided.");

   TORCH_INTERNAL_ASSERT(
-      (*(axes_set.rbegin())) < orig_domain.size(),
+      (*(axes_set.rbegin())) <= orig_domain.size(),
       "Error setting up reduction, reduction axis (",
       *(axes_set.rbegin()),
       ") is outside nDims (",
@@ -921,7 +921,7 @@ TensorView* reductionOp(
       TensorDomain::sameAs(tv->getMaybeRFactorDomain(), tv->domain()->domain()),
       "Reducing a tensor once it's gone under transformations is not permitted at this time. Please set reductions before calling split/merge/computeAt.");

-  TORCH_CHECK(tv->nDims() > 0, "Tried to reduce a 0-dim tensor");
+  // TORCH_CHECK(tv->nDims() > 0, "Tried to reduce a 0-dim tensor");

   TORCH_CHECK(axes.size() > 0, "No reduction axis specified");

@@ -933,7 +933,7 @@ TensorView* reductionOp(
     }

     TORCH_CHECK(
-        axis >= 0 && axis < ndims,
+        axis >= 0 && axis <= ndims,
         "Reduction on invalid axis, recieved: ",
         axis,
         " however tensor view only has ",
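Note that the patch relaxes the bound from axis < ndims to axis <= ndims for every tensor, in both newForReduction and reductionOp. As a result, a 2-d tensor would also pass the check with axis 2, which is out of range. A narrower alternative is to widen the bound only when the tensor is 0-d. The sketch below illustrates that idea in Python. check_reduction_axes is a hypothetical helper, not nvFuser code, and it assumes negative axes are wrapped before the bounds check.

```python
def check_reduction_axes(axes, ndims):
    """Validate reduction axes and return them sorted and de-duplicated.

    Hypothetical sketch of a narrower check, not nvFuser code: a 0-d tensor
    is treated as 1-d for bounds checking, so it accepts axis 0 (or -1),
    while tensors with ndims > 0 keep the strict `axis < ndims` bound.
    """
    if not axes:
        raise ValueError("No reduction axis specified")
    bound = max(ndims, 1)  # widen the bound only when ndims == 0
    normalized = set()
    for axis in axes:
        wrapped = axis + bound if axis < 0 else axis  # wrap negative axes
        if not 0 <= wrapped < bound:
            raise ValueError(
                f"Reduction on invalid axis {axis} for a {ndims}-d tensor")
        normalized.add(wrapped)
    return sorted(normalized)


print(check_reduction_axes([0], 0))     # [0]
print(check_reduction_axes([-1, 0], 2)) # [0, 1]
```

Under this rule, check_reduction_axes([0], 0) is accepted while check_reduction_axes([2], 2) is still rejected.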

Alternatives

No response

Additional context

No response
