diff --git a/llvm/lib/Analysis/CodeMetrics.cpp b/llvm/lib/Analysis/CodeMetrics.cpp index ea67b526423bf..d7b7d65974860 100644 --- a/llvm/lib/Analysis/CodeMetrics.cpp +++ b/llvm/lib/Analysis/CodeMetrics.cpp @@ -16,6 +16,7 @@ #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/Function.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/Support/Debug.h" #include "llvm/Support/InstructionCost.h" @@ -112,13 +113,39 @@ void CodeMetrics::collectEphemeralValues( completeEphemeralValues(Visited, Worklist, EphValues); } +/// Check if a block was previously marked dead by setting the terminator to +/// `unreachable` and there is a statically evaluated conditional branch to NOT +/// branch to the block. This is done for instance within the unroll pass, +/// between unrolling inner/outer loops. +static bool isBlockMarkedDead(const BasicBlock *BB) { + if (!isa(BB->getTerminator())) + return false; + for (const BasicBlock *Pred : predecessors(BB)) { + auto *CondBr = dyn_cast(Pred->getTerminator()); + if (!CondBr) + return false; + auto *Cond = dyn_cast(CondBr->getCondition()); + if (!Cond) + return false; + // Check that the dead block is on the not-taken edge. + BasicBlock *TakenSucc = + Cond->isOne() ? 
CondBr->getSuccessor(0) : CondBr->getSuccessor(1); + if (TakenSucc == BB) + return false; + } + return true; +} + static bool extendsConvergenceOutsideLoop(const Instruction &I, const Loop *L) { if (!L) return false; if (!isa(I)) return false; for (const auto *U : I.users()) { - if (!L->contains(cast(U))) + const auto *UserInst = cast(U); + if (isBlockMarkedDead(UserInst->getParent())) + continue; + if (!L->contains(UserInst)) return true; } return false; diff --git a/llvm/test/Transforms/LoopUnroll/convergent.controlled.ll b/llvm/test/Transforms/LoopUnroll/convergent.controlled.ll index 5dc613e733f00..7bdceb4a0a3a3 100644 --- a/llvm/test/Transforms/LoopUnroll/convergent.controlled.ll +++ b/llvm/test/Transforms/LoopUnroll/convergent.controlled.ll @@ -555,6 +555,124 @@ exit: ret i32 0 } +; The input represents the state after an inner loop has been fully unrolled +; inside an outer loop. The old inner loop body becomes dead (marked +; unreachable and conditional branch is statically known) but still has a +; predecessor from the unrolled code. A convergence token defined in the outer +; loop is used in the dead block, which should not prevent unrolling of the +; outer loop. 
+define i32 @extended_loop_dead_branch(i32 %n) { +; CHECK-LABEL: @extended_loop_dead_branch( +; CHECK-NEXT: entry: +; CHECK-NEXT: br label [[OUTER:%.*]] +; CHECK: outer: +; CHECK-NEXT: [[X_0:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[INC_1:%.*]], [[OUTER_LATCH_1:%.*]] ] +; CHECK-NEXT: [[TOK_LOOP:%.*]] = call token @llvm.experimental.convergence.anchor() +; CHECK-NEXT: [[TOK_INNER_0:%.*]] = call token @llvm.experimental.convergence.anchor() +; CHECK-NEXT: call void @f() [ "convergencectrl"(token [[TOK_INNER_0]]) ] +; CHECK-NEXT: [[TOK_INNER_1:%.*]] = call token @llvm.experimental.convergence.anchor() +; CHECK-NEXT: call void @f() [ "convergencectrl"(token [[TOK_INNER_1]]) ] +; CHECK-NEXT: br i1 false, label [[DEAD:%.*]], label [[OUTER_LATCH:%.*]] +; CHECK: outer.latch: +; CHECK-NEXT: [[INC:%.*]] = add nuw nsw i32 [[X_0]], 1 +; CHECK-NEXT: [[EXITCOND:%.*]] = icmp eq i32 [[INC]], [[N:%.*]] +; CHECK-NEXT: br i1 [[EXITCOND]], label [[EXIT:%.*]], label [[OUTER_1:%.*]] +; CHECK: outer.1: +; CHECK-NEXT: [[TOK_INNER_0_1:%.*]] = call token @llvm.experimental.convergence.anchor() +; CHECK-NEXT: call void @f() [ "convergencectrl"(token [[TOK_INNER_0_1]]) ] +; CHECK-NEXT: [[TOK_INNER_1_1:%.*]] = call token @llvm.experimental.convergence.anchor() +; CHECK-NEXT: call void @f() [ "convergencectrl"(token [[TOK_INNER_1_1]]) ] +; CHECK-NEXT: br i1 false, label [[DEAD]], label [[OUTER_LATCH_1]] +; CHECK: outer.latch.1: +; CHECK-NEXT: [[INC_1]] = add nsw i32 [[X_0]], 2 +; CHECK-NEXT: [[EXITCOND_1:%.*]] = icmp eq i32 [[INC_1]], [[N]] +; CHECK-NEXT: br i1 [[EXITCOND_1]], label [[EXIT]], label [[OUTER]], !llvm.loop [[LOOP10:![0-9]+]] +; CHECK: exit: +; CHECK-NEXT: ret i32 0 +; CHECK: dead: +; CHECK-NEXT: call void @f() [ "convergencectrl"(token [[TOK_LOOP]]) ] +; CHECK-NEXT: unreachable +; +entry: + br label %outer + +outer: + %x.0 = phi i32 [ 0, %entry ], [ %inc, %outer.latch ] + %tok.loop = call token @llvm.experimental.convergence.anchor() + ; Unrolled inner iteration 0 + %tok.inner.0 = 
call token @llvm.experimental.convergence.anchor() + call void @f() [ "convergencectrl"(token %tok.inner.0) ] + ; Unrolled inner iteration 1, exit condition folded + %tok.inner.1 = call token @llvm.experimental.convergence.anchor() + call void @f() [ "convergencectrl"(token %tok.inner.1) ] + br i1 false, label %dead, label %outer.latch + +outer.latch: + %inc = add nsw i32 %x.0, 1 + %exitcond = icmp eq i32 %inc, %n + br i1 %exitcond, label %exit, label %outer, !llvm.loop !1 + +exit: + ret i32 0 + +dead: + call void @f() [ "convergencectrl"(token %tok.loop) ] + unreachable +} + +; Similar to extended_loop_dead_branch, but the branch to the block with the +; convergence token use is NOT statically known to be dead (the condition is +; dynamic). The outer loop should NOT be unrolled because the convergence token +; extends outside the loop. +define i32 @extended_loop_not_dead_branch(i32 %n, i1 %cond) { +; CHECK-LABEL: @extended_loop_not_dead_branch( +; CHECK-NEXT: entry: +; CHECK-NEXT: br label [[OUTER:%.*]] +; CHECK: outer: +; CHECK-NEXT: [[X_0:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[INC:%.*]], [[OUTER_LATCH:%.*]] ] +; CHECK-NEXT: [[TOK_LOOP:%.*]] = call token @llvm.experimental.convergence.anchor() +; CHECK-NEXT: [[TOK_INNER_0:%.*]] = call token @llvm.experimental.convergence.anchor() +; CHECK-NEXT: call void @f() [ "convergencectrl"(token [[TOK_INNER_0]]) ] +; CHECK-NEXT: [[TOK_INNER_1:%.*]] = call token @llvm.experimental.convergence.anchor() +; CHECK-NEXT: call void @f() [ "convergencectrl"(token [[TOK_INNER_1]]) ] +; CHECK-NEXT: br i1 [[COND:%.*]], label [[DEAD:%.*]], label [[OUTER_LATCH]] +; CHECK: outer.latch: +; CHECK-NEXT: [[INC]] = add nsw i32 [[X_0]], 1 +; CHECK-NEXT: [[EXITCOND:%.*]] = icmp eq i32 [[INC]], [[N:%.*]] +; CHECK-NEXT: br i1 [[EXITCOND]], label [[EXIT:%.*]], label [[OUTER]], !llvm.loop [[LOOP4]] +; CHECK: exit: +; CHECK-NEXT: ret i32 0 +; CHECK: dead: +; CHECK-NEXT: call void @f() [ "convergencectrl"(token [[TOK_LOOP]]) ] +; CHECK-NEXT: 
unreachable +; +entry: + br label %outer + +outer: + %x.0 = phi i32 [ 0, %entry ], [ %inc, %outer.latch ] + %tok.loop = call token @llvm.experimental.convergence.anchor() + ; Unrolled inner iteration 0 + %tok.inner.0 = call token @llvm.experimental.convergence.anchor() + call void @f() [ "convergencectrl"(token %tok.inner.0) ] + ; Unrolled inner iteration 1 + %tok.inner.1 = call token @llvm.experimental.convergence.anchor() + call void @f() [ "convergencectrl"(token %tok.inner.1) ] + br i1 %cond, label %dead, label %outer.latch + +outer.latch: + %inc = add nsw i32 %x.0, 1 + %exitcond = icmp eq i32 %inc, %n + br i1 %exitcond, label %exit, label %outer, !llvm.loop !1 + +exit: + ret i32 0 + +dead: + call void @f() [ "convergencectrl"(token %tok.loop) ] + unreachable +} + declare token @llvm.experimental.convergence.anchor() declare token @llvm.experimental.convergence.loop()