🐛 Describe the bug
I am working on the primTorch reference for `where`, and I am getting a segmentation fault in one of the nvFuser tests. I narrowed it down to this failing case:
```python
import torch
from torch._C._nvfuser import Fusion, FusionDefinition, DataType

fusion = Fusion()

with FusionDefinition(fusion) as fd:
    # Two 1-D tensor inputs and two scalar inputs
    t0 = fd.define_tensor(1)
    t1 = fd.define_tensor(1)
    s0 = fd.define_scalar()
    s1 = fd.define_scalar()

    fd.add_input(t0)
    fd.add_input(t1)
    fd.add_input(s0)
    fd.add_input(s1)

    # Tensor predicate, scalar branches
    t2 = fd.Ops.ge(t0, t1)
    t3 = fd.Ops.where(t2, s0, s1)

    fd.add_output(t3)

fusion.print_ir()

input1 = torch.randint(0, 2, (5,), dtype=torch.bool, device='cuda')
input2 = torch.randint(0, 2, (5,), dtype=torch.bool, device='cuda')

for _ in range(5):
    outputs = fusion.execute([input1, input2, 1.0, 2.0])
```
The IR printed by `fusion.print_ir()` is:
```
Inputs:
  T0_g[ iS0{i0} ], float
  T1_g[ iS1{i2} ], float
  d3, double
  d4, double
Outputs:
  d7, double

%kernel_math {
T2_l[ iS2{i0} ]
   = T0_g[ iS0{i0} ]
   >= T1_g[ iS1{i2} ];
d7 = where(T2_l[ iS2{i0} ], d3, d4);
}
```
I guess it assumes that the output of `where` is a scalar `double` (`d7`), even though its predicate `T2_l` is a tensor 🤔
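For comparison, here is a minimal sketch of the result kind I would expect, assuming `where` follows the usual NumPy/PyTorch broadcasting rules (eager `torch.where(cond, 1.0, 2.0)` returns a tensor shaped like `cond`). The result should be a scalar only when the predicate and both branches are scalars; with a tensor predicate like `T2_l` it should be a tensor. The helpers `broadcast_tensor_shapes` and `where_output_shape` are hypothetical and are not part of nvFuser.

```python
from itertools import zip_longest
from typing import Optional, Tuple

# None marks a scalar operand; a tuple marks a tensor of that shape.
Shape = Optional[Tuple[int, ...]]


def broadcast_tensor_shapes(*shapes: Tuple[int, ...]) -> Tuple[int, ...]:
    """Hypothetical helper: right-aligned broadcasting of tensor shapes."""
    result = []
    # Walk dimensions from the innermost outward; missing dims count as size 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"shapes {shapes} are not broadcastable")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


def where_output_shape(pred: Shape, a: Shape, b: Shape) -> Shape:
    """Hypothetical helper: expected result kind of where(pred, a, b).

    Returns None (scalar) only if all three operands are scalars;
    otherwise returns the broadcast shape of the tensor operands.
    """
    tensor_shapes = [s for s in (pred, a, b) if s is not None]
    if not tensor_shapes:
        return None
    return broadcast_tensor_shapes(*tensor_shapes)


# The case from this issue: 1-D tensor predicate, two scalar branches.
print(where_output_shape((5,), None, None))  # (5,), i.e. a tensor, not a scalar
```

With this rule, `d7` in the IR above would be a 1-D tensor with the same extent as `T2_l` rather than a `double` scalar.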
Versions
PyTorch-upstream master.