
[EvaluationContext/ExpressionEvaluator] Example for ExpressionEvaluation after graph lowering #87

@jjsjann123


🚀 Feature

I tried to modify the C++ test FusionExprEvalBasic so that it runs code lowering before using expression evaluation.

However, doing so triggers an assertion error. Judging from the comments in the example, this is probably a case we do not yet support.

Motivation

I am trying to use ExpressionEvaluator to infer the kernel launch configuration.
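For the schedule in the example in this issue (root extents 6 x 128, `split(0, 4)`, outer axis on `BIDx`, innermost axis on `TIDx`), the expected launch configuration is a grid of ceilDiv(6, 4) = 2 blocks with 128 threads each. The arithmetic can be sketched in standalone C++. `ceilDiv`, `LaunchDims`, and `inferLaunchDims` are hypothetical names for illustration only, not part of the nvFuser API, and the sketch hard-codes this one schedule instead of walking the IR.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not part of nvFuser): integer ceiling division,
// the arithmetic a split uses to compute its outer extent.
constexpr int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Hypothetical 1-D launch configuration.
struct LaunchDims {
  int64_t grid_x;   // number of blocks (axis bound to BIDx)
  int64_t block_x;  // threads per block (axis bound to TIDx)
};

// Sketch for one fixed schedule only: a 2-D output [root0, root1] split as
// split(0, factor), giving [ceilDiv(root0, factor), factor, root1], with
// axis 0 on BIDx, axis 1 unrolled, and axis 2 on TIDx.
LaunchDims inferLaunchDims(int64_t root0, int64_t root1, int64_t factor) {
  assert(factor > 0);
  LaunchDims dims{};
  dims.grid_x = ceilDiv(root0, factor);  // outer split axis -> BIDx
  dims.block_x = root1;                  // innermost axis -> TIDx
  return dims;
}
```

Calling `inferLaunchDims(6, 128, 4)` should give `grid_x == 2` and `block_x == 128`, matching the `checkIntValue` expectations on `tv3->axis(0)` and `tv3->axis(2)` in the test.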

Pitch

The following modified example should run without errors.

// Evaluate expressions in a simple IR
void testGPU_FusionExprEvalBasic() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Create a non-trivial IR
  TensorView* tv0 = makeDummyTensor(2);
  TensorView* tv1 = makeDummyTensor(2);

  fusion.addInput(tv0);
  fusion.addInput(tv1);

  TensorView* tv2 = add(tv1, new Float(2.0));
  TensorView* tv3 = add(tv0, tv2);

  fusion.addOutput(tv3);

  tv3->split(0, 4);

  tv0->computeAt(tv3, 1);
  tv1->computeAt(tv3, 1);

  tv3->axis(0)->parallelize(ParallelType::BIDx);
  tv2->axis(1)->parallelize(ParallelType::Unroll);
  tv3->axis(1)->parallelize(ParallelType::Unroll);
  tv2->axis(-1)->parallelize(ParallelType::TIDx);
  tv3->axis(-1)->parallelize(ParallelType::TIDx);

  // Expressions for the launch configuration: grid size (BIDx) and block size (TIDx)
  auto* bid_x = add(tv3->axis(0)->rawExtent(), new Int(0));
  auto* tid_x = add(tv3->axis(-1)->rawExtent(), new Int(0));

  // Lowering here appears to be what triggers the assertion failure
  GPULower gpulw(&fusion);
  std::stringstream cdg;
  gpulw.printKernel(cdg);
  std::cout << cdg.str() << std::endl;

  // 1. Create an evaluation context
  EvaluationContext eval_context(&fusion);
  // 2. Bind values
  //
  // IMPORTANT:
  // a. The bindings are only as stable as the Vals are in the fusion graph
  // b. You must use the original (rootDomain) extents
  //    (e.g. `tv0->getRootDomain()[0]->extent()`
  //    instead of `tv0->axis(0)->extent()`)
  eval_context.bind(tv0->getRootDomain()[0]->extent(), 6);
  eval_context.bind(tv0->getRootDomain()[1]->extent(), 128);
  eval_context.bind(tv1->getRootDomain()[0]->extent(), 6);
  eval_context.bind(tv1->getRootDomain()[1]->extent(), 128);

  // 3. Evaluate and check result values
  TORCH_CHECK(tv2->domain()->nDims() == 3);
  checkIntValue(&eval_context, tv2->axis(0)->rawExtent(), 2);
  checkIntValue(&eval_context, tv2->axis(1)->rawExtent(), 4);
  checkIntValue(&eval_context, tv2->axis(2)->rawExtent(), 128);

  TORCH_CHECK(tv3->domain()->nDims() == 3);
  checkIntValue(&eval_context, tv3->axis(0)->rawExtent(), 2);
  checkIntValue(&eval_context, tv3->axis(1)->rawExtent(), 4);
  checkIntValue(&eval_context, tv3->axis(2)->rawExtent(), 128);

  const auto bid_x_val = ExpressionEvaluator::evaluate(bid_x, &eval_context);
  std::cout << "bid x value " << bid_x_val.value() << std::endl;
  // Evaluate tid_x here (the original snippet evaluated bid_x twice)
  const auto tid_x_val = ExpressionEvaluator::evaluate(tid_x, &eval_context);
  std::cout << "tid x value " << tid_x_val.value() << std::endl;
}
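For context on what the evaluator is expected to do here, below is a self-contained miniature of the idea: bind concrete values to root extents, then evaluate a derived extent such as `bid_x` (a `ceilDiv` from the split plus `0`). The types `Expr` and `Bindings` and the functions `symbol`, `constant`, `binary`, and `evaluate` are hypothetical and are not nvFuser's `EvaluationContext`/`ExpressionEvaluator` API. One design choice in this sketch is that an unbound input yields an empty `std::optional` instead of an assertion.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <memory>
#include <optional>
#include <string>

// Hypothetical miniature IR (not nvFuser's): an extent is a named symbolic
// input, an integer constant, or a binary op over two sub-expressions.
struct Expr {
  enum class Kind { Symbol, Const, Add, CeilDiv };
  Kind kind = Kind::Const;
  std::string name;                // set when kind == Symbol
  int64_t value = 0;               // set when kind == Const
  std::shared_ptr<Expr> lhs, rhs;  // set for Add and CeilDiv
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr symbol(const std::string& name) {
  auto e = std::make_shared<Expr>();
  e->kind = Expr::Kind::Symbol;
  e->name = name;
  return e;
}

ExprPtr constant(int64_t v) {
  auto e = std::make_shared<Expr>();
  e->kind = Expr::Kind::Const;
  e->value = v;
  return e;
}

ExprPtr binary(Expr::Kind k, ExprPtr l, ExprPtr r) {
  assert(k == Expr::Kind::Add || k == Expr::Kind::CeilDiv);
  auto e = std::make_shared<Expr>();
  e->kind = k;
  e->lhs = std::move(l);
  e->rhs = std::move(r);
  return e;
}

// Hypothetical "evaluation context": concrete values bound to symbol names.
using Bindings = std::map<std::string, int64_t>;

// Recursively evaluates `e`. Returns std::nullopt (instead of asserting)
// when a reachable symbol is unbound or a CeilDiv divisor is not positive.
std::optional<int64_t> evaluate(const ExprPtr& e, const Bindings& b) {
  switch (e->kind) {
    case Expr::Kind::Const:
      return e->value;
    case Expr::Kind::Symbol: {
      auto it = b.find(e->name);
      if (it == b.end()) return std::nullopt;
      return it->second;
    }
    case Expr::Kind::Add:
    case Expr::Kind::CeilDiv: {
      auto l = evaluate(e->lhs, b);
      auto r = evaluate(e->rhs, b);
      if (!l || !r) return std::nullopt;
      if (e->kind == Expr::Kind::Add) return *l + *r;
      if (*r <= 0) return std::nullopt;
      return (*l + *r - 1) / *r;
    }
  }
  return std::nullopt;
}
```

Building `binary(Add, binary(CeilDiv, symbol("i0"), constant(4)), constant(0))` and evaluating it with `i0` bound to 6 gives 2, the expected grid size; evaluating it with no bindings returns an empty optional, so a caller can tell "not yet inferable" apart from a hard failure.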
