Non-contiguous Tensor with Broadcast Bug #1741

@kevinstephano

Description

🐛 Describe the bug

This issue was originally reported by @ngimel; the repro below is her example using refs. It is exposed via the Python frontend. The problem is not seen in TorchScript because, in that integration, the inputs are permuted to be contiguous before execution, so a C++ test had to be written to reproduce it.
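The key ingredient can be inspected on any device: transposing the last two dimensions leaves the shapes broadcast-compatible but makes the strides non-contiguous. A minimal sketch with the same sizes as the repro (CPU `torch.empty` here purely to look at layout; the repro itself uses CUDA half tensors):

```python
import torch

# Same sizes as the repro; last two dims of `a` are both 112, so the
# transpose changes strides but not the shape.
a = torch.empty(4, 32, 16, 112, 112).transpose(-1, -2)
b = torch.empty(32, 1, 112, 1).transpose(-1, -2)  # shape becomes (32, 1, 1, 112)

print(a.is_contiguous())  # False: innermost stride is 112, not 1
print(a.stride()[-2:])    # (1, 112)
print(b.stride())         # (112, 112, 1, 1)
print((a + b).shape)      # broadcasts to torch.Size([4, 32, 16, 112, 112])
```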

Original Test:

import torch
from functools import partial
from torch._prims.executor import make_traced
import torch._refs

def main():
    device = "cuda"
    dtype = torch.half
    size = (4, 32, 16, 112, 112)
    size1 = (32, 1, 112, 1)
    # Transposing the last two dims makes both inputs non-contiguous
    # while keeping them broadcast-compatible.
    a = torch.randn(size, device=device, dtype=dtype).transpose(-1, -2)
    b = torch.randn(size1, device=device, dtype=dtype).transpose(-1, -2)

    a + b  # plain eager add; result discarded

    def fn_add(a, b):
        return torch._refs.add(a, b)

    traced = make_traced(fn_add)
    for executor in ('aten', 'nvfuser'):
        fn = partial(traced, executor=executor)
        # Run the traced function several times on the same inputs.
        result = fn(a, b)
        result = fn(a, b)
        result = fn(a, b)
        expected = torch.add(a, b)
        print((result - expected).abs().max())

    torch.cuda.synchronize()

if __name__ == "__main__":
    main()

C++ Test:

TEST_F(NVFuserTest, FusionReproNoncontigBroadcast_CUDA) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());
  
  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({4, 32, 16, 112, 112}, options).transpose(-1, -2);
  at::Tensor t1 = at::randn({32, 1, 112, 1}, options).transpose(-1, -2);

  auto tv0 = TensorViewBuilder()
                 .ndims(5)
                 .contiguity({true, true, false, false, false}) // ttfff
                 .shape({4, 32, 16, 112, 112})
                 .dtype(DataType::Half)
                 .build();
  auto tv1 = TensorViewBuilder()
                 .ndims(4)
                 .contiguity({true, false, false, true}) // tfft
                 .shape({32, 1, 1, 112})
                 .dtype(DataType::Half)
                 .build();

  fusion->addInput(tv0);
  fusion->addInput(tv1);

  auto tv2 = add(tv0, tv1);

  fusion->addOutput(tv2);

  std::vector<IValue> aten_inputs({t0, t1});

  FusionExecutorCache executor_cache(std::move(fusion));
  auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs);

  auto t2 = t0 + t1;

  testValidate(
      executor_cache.fusion(),
      cg_outputs,
      {t0, t1},
      {t2},
      __LINE__,
      __FILE__);
}
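The contiguity flags passed to `TensorViewBuilder` above can be derived from the ATen tensors' strides. A hypothetical helper (`contiguity_flags` is not an nvfuser API) illustrating the usual convention: a dimension counts as contiguous when its stride equals the next dimension's stride times that dimension's size, with the last dimension requiring stride 1. The actual nvfuser handling, particularly of broadcast dimensions, may differ:

```python
def contiguity_flags(sizes, strides):
    """Sketch: per-dimension contiguity flags of the kind passed to
    TensorViewBuilder::contiguity() (hypothetical reimplementation)."""
    flags = []
    for i, stride in enumerate(strides):
        if i == len(strides) - 1:
            flags.append(stride == 1)
        else:
            flags.append(stride == strides[i + 1] * sizes[i + 1])
    return flags

# t0 = at::randn({4,32,16,112,112}).transpose(-1,-2) has these strides:
print(contiguity_flags((4, 32, 16, 112, 112),
                       (6422528, 200704, 12544, 1, 112)))
# [True, True, False, False, False]  -> "ttfff"

# t1 = at::randn({32,1,112,1}).transpose(-1,-2) -> shape (32,1,1,112):
print(contiguity_flags((32, 1, 1, 112), (112, 112, 1, 1)))
# [True, False, False, True]  -> "tfft"
```

This reproduces the `ttfff` and `tfft` comments in the C++ test above.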

Error:

C++ exception with description "stride == cur_contig_stride || (still_rightmost && stride == 1) || (!still_rightmost && stride % word_size == 0) INTERNAL ASSERT FAILED at "../torch/csrc/jit/codegen/cuda/executor_utils.cpp":586, please report a bug to PyTorch. Vectorization of T1_g[ iS97{( ceilDiv(( ceilDiv(( 32 * ( 1 * ( 1 * 112 ) ) ), 4) ), blockDim.x) )}, iS96{4}, iS98{blockDim.x} ] with word size 4 not possible due to invalid stride. Domain: iS98{blockDim.x}, stride: 1
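The assert comes from nvfuser's runtime vectorization validation. A loose Python paraphrase of the three-way stride condition quoted in the message (the bookkeeping for `still_rightmost` and `cur_contig_stride` here is an assumption; the real logic in executor_utils.cpp is more involved):

```python
def can_vectorize(sizes, strides, word_size):
    # Walk dims right-to-left; each stride must either continue the
    # contiguous run, be 1 while still in the rightmost run, or be a
    # multiple of the vectorization word size once past it.
    still_rightmost = True
    cur_contig_stride = 1
    for size, stride in zip(reversed(sizes), reversed(strides)):
        if not (stride == cur_contig_stride
                or (still_rightmost and stride == 1)
                or (not still_rightmost and stride % word_size == 0)):
            return False
        cur_contig_stride = stride * size
        still_rightmost = still_rightmost and size == 1
    return True

print(can_vectorize((4, 3, 2), (6, 2, 1), word_size=2))
# True: a fully contiguous layout passes

print(can_vectorize((32, 1, 1, 112), (112, 112, 1, 1), word_size=4))
# False: t1's layout hits a stride of 1 past the rightmost run,
# and 1 % 4 != 0 -- matching the "stride: 1" in the error above
```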

Versions

TOT (top-of-tree)
