🐛 Bug
I ran into an issue with a fusion that has two branches which join later. When I naively call computeAt from every input toward the output, the generated code gets the dependency order wrong and reads an intermediate before it is written.
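The relevant ops form a diamond; excerpting from the test below:

```cpp
auto tv3 = add(tv0, new Float(1.0)); // branch point: shared intermediate
auto tv4 = add(tv3, tv1);            // branch A reads tv3
auto tv5 = add(tv3, tv2);            // branch B reads tv3
auto tv6 = add(tv4, tv5);            // the branches join here
```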
Here's the full example:
```cpp
void testGPU_FusionBranches() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Set up input tensor views
  TensorView* tv0 = makeDummyTensor(2);
  TensorView* tv1 = makeDummyTensor(2);
  TensorView* tv2 = makeDummyTensor(2);
  fusion.addInput(tv0);
  fusion.addInput(tv1);
  fusion.addInput(tv2);

  // Diamond: tv3 is consumed by both tv4 and tv5, which rejoin at tv6
  auto tv3 = add(tv0, new Float(1.0));
  auto tv4 = add(tv3, tv1);
  auto tv5 = add(tv3, tv2);
  auto tv6 = add(tv4, tv5);
  fusion.addOutput(tv6);

  constexpr int x = 63, y = 33;
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({x, y}, options);
  at::Tensor t1 = at::randn({x, y}, options);
  at::Tensor t2 = at::randn({x, y}, options);

  torch::jit::fuser::cuda::FusionExecutor fe;
  //fuser::cuda::scheduleFusion(&fusion, {t0, t1, t2});

  tv6->merge(0);
  tv6->split(0, 128);
  tv6->split(0, 4);
  tv6->axis(0)->parallelize(ParallelType::BIDx);

  // Naively computeAt from every input into the output
  tv0->computeAt(tv6, 1);
  tv1->computeAt(tv6, 1);
  tv2->computeAt(tv6, 1);

  tv3->axis(-2)->parallelize(ParallelType::Unroll);
  tv3->axis(-1)->parallelize(ParallelType::TIDx);
  tv4->axis(-2)->parallelize(ParallelType::Unroll);
  tv4->axis(-1)->parallelize(ParallelType::TIDx);
  tv5->axis(-2)->parallelize(ParallelType::Unroll);
  tv5->axis(-1)->parallelize(ParallelType::TIDx);

  fe.compileFusion(&fusion);
  auto outputs = fe.runFusion({t0, t1, t2});

  // Eager-mode reference
  auto t3 = t0.add(1.0);
  auto t4 = t3.add(t1);
  auto t5 = t3.add(t2);
  auto t6 = t4.add(t5);

  TORCH_CHECK(t6.allclose(outputs[0]));
}
```
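For contrast, here is a minimal hand-written sketch (my own simplification, not fuser output, and ignoring the tiling/unrolling/predication of the actual schedule) of the ordering the generated kernel needs: T3 has to be materialized before either branch consumes it.

```cuda
// Hand-written sketch of a correctly ordered kernel for the same diamond,
// one element per thread for simplicity. Names and signature are mine.
__global__ void fused_diamond_sketch(const float* t0, const float* t1,
                                     const float* t2, float* t6, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    float r3 = t0[i] + 1.0f;  // T3 is computed first...
    float r4 = r3 + t1[i];    // ...then both consumers read it...
    float r5 = r3 + t2[i];
    t6[i] = r4 + r5;          // ...and only then do the branches join.
  }
}
```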
Generated code, where T3 is read (by the T4 computation) before it is declared and written:
```cuda
__global__ void kernel1(Tensor<float, 2> T0, Tensor<float, 2> T1, Tensor<float, 2> T2, Tensor<float, 2> T6){
  float T4[4];
  if ( ( ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + ( 4 - 1 ) ) / T6.size[1] ) < T6.size[0] ) && ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + ( 4 - 1 ) ) % T6.size[1] ) < T6.size[1] ) ) ) {
    for(size_t i47 = 0; i47 < 4; ++i47 ) {
      T4[ i47 ]
        = T3[ i47 ] // <-- T3 is read here, but declared and written only further down
        + T1[ ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i47 ) / T6.size[1] ) * T1.stride[0] ) + ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i47 ) % T6.size[1] ) * T1.stride[1] ) ];
    }
  } else {
    for(size_t i47 = 0; i47 < 4; ++i47 ) {
      if ( ( ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i47 ) / T6.size[1] ) < T6.size[0] ) && ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i47 ) % T6.size[1] ) < T6.size[1] ) ) ) {
        T4[ i47 ]
          = T3[ i47 ]
          + T1[ ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i47 ) / T6.size[1] ) * T1.stride[0] ) + ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i47 ) % T6.size[1] ) * T1.stride[1] ) ];
      }
    }
  }
  float T3[4]; // <-- T3 is declared and computed only here, after T4 has already consumed it
  if ( ( ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + ( 4 - 1 ) ) / T6.size[1] ) < T6.size[0] ) && ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + ( 4 - 1 ) ) % T6.size[1] ) < T6.size[1] ) ) ) {
    for(size_t i48 = 0; i48 < 4; ++i48 ) {
      T3[ i48 ]
        = T0[ ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i48 ) / T6.size[1] ) * T0.stride[0] ) + ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i48 ) % T6.size[1] ) * T0.stride[1] ) ]
        + float(1);
    }
  } else {
    for(size_t i48 = 0; i48 < 4; ++i48 ) {
      if ( ( ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i48 ) / T6.size[1] ) < T6.size[0] ) && ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i48 ) % T6.size[1] ) < T6.size[1] ) ) ) {
        T3[ i48 ]
          = T0[ ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i48 ) / T6.size[1] ) * T0.stride[0] ) + ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i48 ) % T6.size[1] ) * T0.stride[1] ) ]
          + float(1);
      }
    }
  }
  float T5[4];
  if ( ( ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + ( 4 - 1 ) ) / T6.size[1] ) < T6.size[0] ) && ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + ( 4 - 1 ) ) % T6.size[1] ) < T6.size[1] ) ) ) {
    for(size_t i49 = 0; i49 < 4; ++i49 ) {
      T5[ i49 ]
        = T3[ i49 ]
        + T2[ ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i49 ) / T6.size[1] ) * T2.stride[0] ) + ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i49 ) % T6.size[1] ) * T2.stride[1] ) ];
    }
  } else {
    for(size_t i49 = 0; i49 < 4; ++i49 ) {
      if ( ( ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i49 ) / T6.size[1] ) < T6.size[0] ) && ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i49 ) % T6.size[1] ) < T6.size[1] ) ) ) {
        T5[ i49 ]
          = T3[ i49 ]
          + T2[ ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i49 ) / T6.size[1] ) * T2.stride[0] ) + ( ( ( ( ( ( blockIdx.x * 128 ) + threadIdx.x ) * 4 ) + i49 ) % T6.size[1] ) * T2.stride[1] ) ];
      }
    }
  }
  for(size_t i50 = 0; i50 < 128; ++i50 ) {
    for(size_t i51 = 0; i51 < 4; ++i51 ) {
      if ( ( ( ( ( ( blockIdx.x * 128 ) + i50 ) * 4 ) + i51 ) < ( T6.size[0] * T6.size[1] ) ) ) {
        T6[ ( ( ( ( blockIdx.x * 128 ) + i50 ) * 4 ) + i51 ) ]
          = T4[ ( i50 * 4 ) + i51 ]
          + T5[ ( i50 * 4 ) + i51 ];
      }
    }
  }
}
```
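Distilled, the generated kernel reduces to the pattern below, a plain use-before-declaration (the `1.0f` stands in for the T1 load):

```cuda
float T4[4];
T4[0] = T3[0] + 1.0f;  // T3 is read while computing T4...
float T3[4];           // ...but T3 is declared and written only afterwards
```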
To Reproduce
I have pushed the repro here:
https://github.com/csarofeen/pytorch/tree/failed_branching_in_scheduling
You should just build it and run the cpp test with:

```sh
CUDA_VISIBLE_DEVICES=1 PYTORCH_CUDA_FUSER_DEBUG=1 ./test_jit --gtest_filter="*GPU_FusionBranches*"
```