Skip to content

Commit 89521cc

Browse files
author
Naoya Maruyama
committed
Small refactoring
1 parent b0f92fb commit 89521cc

5 files changed

Lines changed: 6 additions & 4 deletions

File tree

torch/csrc/jit/codegen/cuda/ir_iostream.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -698,7 +698,7 @@ void IRPrinter::handle(const kir::GridReduction* gr) {
698698
indent();
699699
// Since block-level reduction is already done, those dimensions
700700
// with tidx/y/z being true do not participate in the grid reduction.
701-
os << kir::getPredicateFlagName(out->view()) << " = "
701+
os << kir::GridReduction::getPredicateFlagName(out->view()) << " = "
702702
<< "reduction::gridReduce< " << (bidx ? "true" : "false") << ", "
703703
<< (bidy ? "true" : "false") << ", " << (bidz ? "true" : "false") << ", "
704704
<< (!tidx ? "true" : "false") << ", " << (!tidy ? "true" : "false") << ", "

torch/csrc/jit/codegen/cuda/kernel_ir.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ GridReduction::GridReduction(const GridReduction* src, IrCloner* ir_cloner)
361361
reduction_buffer_(ir_cloner->clone(src->reduction_buffer_)),
362362
sync_buffer_(ir_cloner->clone(src->sync_buffer_)) {}
363363

364-
std::string getPredicateFlagName(const TensorView* val) {
364+
std::string GridReduction::getPredicateFlagName(const TensorView* val) {
365365
std::stringstream ss;
366366
ss << "T" << val->name() << "pred";
367367
return ss.str();

torch/csrc/jit/codegen/cuda/kernel_ir.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,8 @@ class TORCH_CUDA_API GridReduction : public Expr {
529529
return sync_buffer_;
530530
}
531531

532+
static std::string getPredicateFlagName(const TensorView* val);
533+
532534
private:
533535
ReductionOp* reduction_op_ = nullptr;
534536
Allocate* reduction_buffer_ = nullptr;

torch/csrc/jit/codegen/cuda/lower_index.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ void IndexLowering::handle(TernaryOp* top) {
138138
namespace {
139139

140140
kir::Allocate* allocateGridReductionFlag(TensorView* out_tv) {
141-
auto flag_name = kir::getPredicateFlagName(out_tv);
141+
auto flag_name = kir::GridReduction::getPredicateFlagName(out_tv);
142142
auto flag_var = new kir::NamedScalar(flag_name, DataType::Bool);
143143
return new kir::Allocate(flag_var, MemoryType::Local, new kir::Int(1));
144144
}

torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ Val* getPredicatePerParallelType(
1818
TORCH_INTERNAL_ASSERT(!sources.empty(), "No predicate source found");
1919
TORCH_INTERNAL_ASSERT(sources.size() == 1, "Multiple sources detected");
2020
auto src = *sources.begin();
21-
auto flag_name = kir::getPredicateFlagName(src);
21+
auto flag_name = kir::GridReduction::getPredicateFlagName(src);
2222
return new kir::NamedScalar(flag_name, DataType::Bool);
2323
} else {
2424
return kir::eqExpr(kir::NamedScalar::getParallelIndex(pt), new kir::Int(0));

0 commit comments

Comments
 (0)