
Segmentation fault when calling "where(tensor, scalar, scalar)" #1770

@Aidyn-A


🐛 Describe the bug

I am working on the primTorch reference for where and I am getting a segmentation fault in one of the nvFuser tests. I narrowed it down to this failing case:

import torch
from torch._C._nvfuser import Fusion, FusionDefinition, DataType

fusion = Fusion()

with FusionDefinition(fusion) as fd:
    t0 = fd.define_tensor(1)
    t1 = fd.define_tensor(1)
    s0 = fd.define_scalar()
    s1 = fd.define_scalar()

    fd.add_input(t0)
    fd.add_input(t1)
    fd.add_input(s0)
    fd.add_input(s1)

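    # T2 is a 1-D boolean tensor
    t2 = fd.Ops.ge(t0, t1)
    # where(tensor, scalar, scalar): the scalars should broadcast against T2,
    # so t3 is expected to be a 1-D tensor
    t3 = fd.Ops.where(t2, s0, s1)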

    fd.add_output(t3)

fusion.print_ir()

input1 = torch.randint(0, 2, (5,), dtype=torch.bool, device='cuda')
input2 = torch.randint(0, 2, (5,), dtype=torch.bool, device='cuda')

for _ in range(5):
    outputs = fusion.execute([input1, input2, 1.0, 2.0])

The output that it prints to the screen is:

Inputs:
  T0_g[ iS0{i0} ], float
  T1_g[ iS1{i2} ], float
  d3, double
  d4, double
Outputs:
  d7, double

%kernel_math {
T2_l[ iS2{i0} ]
   = T0_g[ iS0{i0} ]
   >= T1_g[ iS1{i2} ];
d7 = where(T2_l[ iS2{i0} ], d3, d4);
}

My guess is that where infers its output from the two scalar arguments and produces a scalar double (d7), even though the predicate T2 is a tensor 🤔
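To make the expected behavior concrete, here is a minimal plain-Python sketch of the result-kind rule one would expect where to follow: if any operand is a tensor, the scalars broadcast against it and the result is a tensor of the largest rank involved; only an all-scalar call should yield a scalar. TensorVal, ScalarVal and where_result_kind are hypothetical names for illustration only, not part of the nvFuser API, and dtype promotion is deliberately left out.

```python
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class TensorVal:
    """Hypothetical stand-in for a fusion tensor value (only its rank matters here)."""
    ndims: int


@dataclass(frozen=True)
class ScalarVal:
    """Hypothetical stand-in for a fusion scalar value such as d3 or d4."""
    name: str


Val = Union[TensorVal, ScalarVal]


def where_result_kind(pred: Val, a: Val, b: Val) -> Val:
    """Return the kind of value where(pred, a, b) is expected to produce.

    If any operand is a tensor, the scalar operands broadcast against it and
    the result is a tensor whose rank is the largest rank among the tensor
    operands. Only when all three operands are scalars is the result a scalar.
    """
    ranks = [v.ndims for v in (pred, a, b) if isinstance(v, TensorVal)]
    if ranks:
        return TensorVal(ndims=max(ranks))
    return ScalarVal(name="out")


# The failing case from the report: a 1-D boolean predicate and two scalars.
t2 = TensorVal(ndims=1)
s0, s1 = ScalarVal("d3"), ScalarVal("d4")
print(where_result_kind(t2, s0, s1))  # TensorVal(ndims=1)
```

Under this rule the output of the repro would be a 1-D tensor like T2, not the scalar d7 shown in the IR above.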

Versions

PyTorch upstream master.
