Temporary array not initialized #64

@naoyam

Description

See testGPU_FusionReduction5 in https://github.com/naoyam/pytorch/tree/blockReduce_fail.

It fails at the final validation.

unknown file: Failure
C++ exception with description "Expected aten_output.allclose(cg_output) to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)
Exception raised from testGPU_FusionReduction5 at ../test/cpp/jit/test_gpu.cpp:2645 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x68 (0x7fd99738ec08 in /home/nmaruyama/pytorch/src/csarofeen/build/lib/libc10.so)
frame #1: torch::jit::testGPU_FusionReduction5() + 0x97a (0x55bb04c93e0a in ./build/bin/test_jit)
frame #2: void testing::internal::HandleExceptionsInMethodIfSupported<testing::Test, void>(testing::Test*, void (testing::Test::*)(), char const*) + 0x4a (0x55bb04cc277a in ./build/bin/test_jit)
frame #3: <unknown function> + 0x21f4de (0x55bb04cb84de in ./build/bin/test_jit)
frame #4: <unknown function> + 0x21f99d (0x55bb04cb899d in ./build/bin/test_jit)
frame #5: <unknown function> + 0x21fbbd (0x55bb04cb8bbd in ./build/bin/test_jit)
frame #6: testing::internal::UnitTestImpl::RunAllTests() + 0xc59 (0x55bb04cb9a99 in ./build/bin/test_jit)
frame #7: testing::UnitTest::Run() + 0x98 (0x55bb04cb9d98 in ./build/bin/test_jit)
frame #8: main + 0xc8 (0x55bb04b43148 in ./build/bin/test_jit)
frame #9: __libc_start_main + 0xe7 (0x7fd996bd5b97 in /lib/x86_64-linux-gnu/libc.so.6)
frame #10: _start + 0x2a (0x55bb04b4b83a in ./build/bin/test_jit)
" thrown in the test body.
[  FAILED  ] JitTest.GPU_FusionReduction5_CUDA (1424 ms)
[----------] 1 test from JitTest (1424 ms total)

The same test does pass when numel_z is larger, as in the version defined in https://github.com/naoyam/pytorch/tree/reduction3d.

Here's the generated code:

__global__ void CUDAGeneratedKernel(Tensor<float, 3> T0, Tensor<float, 1> T1){
  T1[ ( blockIdx.x * T1.stride[0] ) ]
     = float(0);
  float T3[1];
  if ( ( ( ( 0 * 8 ) + threadIdx.y ) < T0.size[1] ) ) {
    T3[ 0 ]
       = float(0);
  }
  for(size_t i38 = 0; i38 < ( ceilDiv(T0.size[1], 8) ); ++i38 ) {
    float T2[1];
    if ( ( ( ( ( i38 * 8 ) + threadIdx.y ) < T0.size[1] ) && ( ( ( 0 * 128 ) + threadIdx.x ) < T0.size[2] ) ) ) {
      T2[ 0 ]
         = float(0);
    }
    for(size_t i40 = 0; i40 < ( ceilDiv(T0.size[2], 128) ); ++i40 ) {
      if ( ( ( ( ( i38 * 8 ) + threadIdx.y ) < T0.size[1] ) && ( ( ( i40 * 128 ) + threadIdx.x ) < T0.size[2] ) ) ) {
        T2[ 0 ]
           = T2[ 0 ]
           + T0[ ( blockIdx.x * T0.stride[0] ) + ( ( ( i38 * 8 ) + threadIdx.y ) * T0.stride[1] ) + ( ( ( i40 * 128 ) + threadIdx.x ) * T0.stride[2] ) ];
      }
    }
    if ( ( ( ( i38 * 8 ) + threadIdx.y ) < T0.size[1] ) ) {
      T3[ 0 ]
         = T3[ 0 ]
         + T2[ 0 ];
    }
  }
  blockReduce< true, true, false > ( T1[ ( blockIdx.x * T1.stride[0] ) ], T3[ 0 ], reduction_add_float);
}

While I have not confirmed it, my suspicion is that T3 is not initialized for threads whose threadIdx.y is greater than or equal to T0.size[1].

The same initialization guard is generated in simpler tests using 2D tensors, but those do not show the validation error. Since T3 is just a local variable, it is probably zero most of the time, but it can occasionally hold a garbage value.
