
Sibling computeAt pos out of sync in inlining #1803

@shmsong

Description


🐛 Describe the bug

Recent changes in PR #1800 made it possible for multi-output siblings to run out of sync after the computeAt pass. Here is a minimal repro:


TEST_F(NVFuserTest, FusionInlineRepro_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);
  
  TensorView* tv0 = makeContigTensor(2);
  
  fusion.addInput(tv0);
  auto tv1 = set(tv0);
  auto tvs = Welford(tv1, {1});
  auto tvo = set(tvs.var_sum);
  fusion.addOutput(tvo);

  tvo->split(0, 16);
  tvo->axis(1)->parallelize(ParallelType::Unroll);

  tv0->computeAt(tvo,-1, ComputeAtMode::BestEffort);

  fusion.printMath();

  TORCH_INTERNAL_ASSERT(tvs.var_sum->getComputeAtPosition() == tvs.avg->getComputeAtPosition());
}

snippet from output:

T2_l[ iS21{( ceilDiv(i1, 16) )}, iS22{16}, rS5{i2} ] ca_pos( 2 ) produce_pos( 2)(Avg),
T3_l[ iS13{( ceilDiv(i1, 16) )}, iS14{16}, rS7{i2} ] ca_pos( 1 ) produce_pos( 2)(Var),
T4_l[ iS19{( ceilDiv(i1, 16) )}, iS20{16}, rS9{i2} ] ca_pos( 2 ) produce_pos( 2)(Count)
 = Welford ( T1_l[ iS15{( ceilDiv(i1, 16) )}, iS16{16}, iS3{i2} ] ca_pos( 2 )(Avg), 
  allreduce = 0 )

Note the difference in ca_pos on the three outputs.

Being out of sync is one problem; moreover, a ca_pos of 2 should not be allowed here at all, as it would lead to confusion at the expression sorting stage.

A possible fix, I'm guessing, would be to incorporate getMaxPosAll of all siblings when propagating among them:

https://github.com/csarofeen/pytorch/blob/devel/torch/csrc/jit/codegen/cuda/inline_propagator.cpp#L253
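As a rough illustration of the idea (not the actual inline_propagator code — SiblingTv, max_inline_pos, and inlineSiblings below are hypothetical names for this sketch), the siblings' individual max positions could be folded into one shared limit before any ca_pos is set, so all outputs land on the same position:

```cpp
#include <algorithm>
#include <vector>

// Toy model of one output of a multi-output expression (e.g. Welford's
// avg / var_sum / n): each sibling tracks its own computeAt state.
struct SiblingTv {
  int max_inline_pos;  // deepest position this sibling alone would allow
  int ca_pos = 0;      // computeAt position actually applied
};

// Sketch of the proposed fix: instead of inlining each sibling to its own
// max position, first fold every sibling's limit into a single shared
// position (the analogue of combining getMaxPosAll over all siblings),
// then apply that one position uniformly so the outputs stay in sync.
void inlineSiblings(std::vector<SiblingTv>& siblings, int requested_pos) {
  int shared_pos = requested_pos;
  for (const auto& tv : siblings) {
    // The group can only inline as deep as its most constrained member.
    shared_pos = std::min(shared_pos, tv.max_inline_pos);
  }
  for (auto& tv : siblings) {
    tv.ca_pos = shared_pos;
  }
}
```

With limits {2, 1, 2} (matching the Avg/Var/Count snippet above) and a requested position of 2, all three siblings would end up at ca_pos 1 instead of the divergent {2, 1, 2}.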

Versions

TO BE UPDATED
