Skip to content

Commit 52c7ded

Browse files
committed
ForLoop::isGrouped()
1 parent cef2966 commit 52c7ded

3 files changed

Lines changed: 52 additions & 46 deletions

File tree

torch/csrc/jit/codegen/cuda/codegen.cpp

Lines changed: 1 addition & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -115,38 +115,6 @@ std::string genCall(
115115
return ss.str();
116116
}
117117

118-
//! A utility class to check if an expression of a particular type exists
119-
class ExprFinder : kir::ConstIrVisitor {
120-
public:
121-
//! True if expr or any of its nested expressions is included in
122-
//! expr_types
123-
static bool exists(
124-
const Expr* expr,
125-
const std::unordered_set<std::type_index>& expr_types) {
126-
ExprFinder finder(expr_types);
127-
finder.handle(std::vector<const Expr*>{expr});
128-
return finder.is_found_;
129-
}
130-
131-
private:
132-
ExprFinder(const std::unordered_set<std::type_index>& expr_types)
133-
: expr_types_(expr_types) {}
134-
135-
using kir::ConstIrVisitor::handle;
136-
137-
void handle(const Expr* expr) final {
138-
if (expr_types_.find(typeid(*expr)) != expr_types_.end()) {
139-
is_found_ = true;
140-
return;
141-
}
142-
kir::ConstIrVisitor::handle(expr);
143-
}
144-
145-
private:
146-
const std::unordered_set<std::type_index>& expr_types_;
147-
bool is_found_ = false;
148-
};
149-
150118
class CudaKernelGenerator : private OptOutConstDispatch {
151119
static constexpr const char* kTab = " ";
152120

@@ -2482,19 +2450,6 @@ class CudaKernelGenerator : private OptOutConstDispatch {
24822450
" which is handled by its own handler");
24832451
}
24842452

2485-
//! True if loop is grouped. The IterDomain of the loop must have
2486-
//! ParallelType::Group, but it isn't sufficient as the loop may be
2487-
//! for an initialization expression, for which the loop should not
2488-
//! be grouped. Make sure a GroupedGridReduction or GroupedGridWelford is found.
2489-
bool isGroupedLoop(const kir::ForLoop* loop) {
2490-
if (loop->iter_domain()->getParallelType() != ParallelType::Group) {
2491-
return false;
2492-
}
2493-
return ExprFinder::exists(
2494-
loop,
2495-
{typeid(kir::GroupedGridReduction), typeid(kir::GroupedGridWelford)});
2496-
}
2497-
24982453
void handle(const kir::ForLoop* loop) final {
24992454
if (loop->isTrivial()) {
25002455
handleTrivialLoop(loop);
@@ -2503,7 +2458,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
25032458

25042459
// If a loop is grouped, no loop is created, but it isn't
25052460
// considered trivial as the loop trip count is not one.
2506-
if (isGroupedLoop(loop)) {
2461+
if (loop->isGrouped()) {
25072462
grouped_loops_.push_back(loop);
25082463
handleScope(loop->body());
25092464
grouped_loops_.pop_back();

torch/csrc/jit/codegen/cuda/kernel_ir.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,54 @@ bool ForLoop::isTrivial() const {
475475
return false;
476476
}
477477

478+
namespace {
479+
480+
//! A utility class to check if an expression of a particular type exists
481+
class ExprFinder : kir::ConstIrVisitor {
482+
public:
483+
//! True if expr or any of its nested expressions is included in
484+
//! expr_types
485+
static bool exists(
486+
const Expr* expr,
487+
const std::unordered_set<std::type_index>& expr_types) {
488+
ExprFinder finder(expr_types);
489+
finder.handle(std::vector<const Expr*>{expr});
490+
return finder.is_found_;
491+
}
492+
493+
private:
494+
ExprFinder(const std::unordered_set<std::type_index>& expr_types)
495+
: expr_types_(expr_types) {}
496+
497+
using kir::ConstIrVisitor::handle;
498+
499+
void handle(const Expr* expr) final {
500+
if (expr_types_.find(typeid(*expr)) != expr_types_.end()) {
501+
is_found_ = true;
502+
return;
503+
}
504+
kir::ConstIrVisitor::handle(expr);
505+
}
506+
507+
private:
508+
const std::unordered_set<std::type_index>& expr_types_;
509+
bool is_found_ = false;
510+
};
511+
512+
} // namespace
513+
514+
bool ForLoop::isGrouped() const {
515+
//! The IterDomain of the loop must have ParallelType::Group, but it isn't
516+
//! sufficient as the loop may be for an initialization expression, for which
517+
//! the loop should not be grouped. Make sure a GroupedGridReduction or GroupedGridWelford is found.
518+
if (iter_domain()->getParallelType() != ParallelType::Group) {
519+
return false;
520+
}
521+
return ExprFinder::exists(
522+
this,
523+
{typeid(kir::GroupedGridReduction), typeid(kir::GroupedGridWelford)});
524+
}
525+
478526
IfThenElse::IfThenElse(IrBuilderPasskey passkey, Predicate* cond)
479527
: Expr(passkey) {
480528
setPredicate(cond);

torch/csrc/jit/codegen/cuda/kernel_ir.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,9 @@ class TORCH_CUDA_CU_API ForLoop final : public Expr {
502502
//! True if no actual for-loop is materialized
503503
bool isTrivial() const;
504504

505+
//! True if loop is grouped.
506+
bool isGrouped() const;
507+
505508
//! Returns the stage of a double buffered iterdomain
506509
//! that this for loop materializes.
507510
auto doubleBufferLoopStage() const {

0 commit comments

Comments
 (0)