Skip to content

[LV][NFC] Refactor code for extracting first active element#131118

Merged
david-arm merged 3 commits into
llvm:mainfrom
david-arm:first_active
Mar 14, 2025
Merged

[LV][NFC] Refactor code for extracting first active element#131118
david-arm merged 3 commits into
llvm:mainfrom
david-arm:first_active

Conversation

@david-arm

Copy link
Copy Markdown
Contributor

Refactor the code to extract the first active element of a
vector in the early exit block, in preparation for PR #130766.
I've replaced the VPInstruction::ExtractFirstActive nodes with
a combination of a new VPInstruction::FirstActiveLane node and
a Instruction::ExtractElement node.

@llvmbot

llvmbot commented Mar 13, 2025

Copy link
Copy Markdown
Member

@llvm/pr-subscribers-vectorizers

@llvm/pr-subscribers-llvm-transforms

Author: David Sherwood (david-arm)

Changes

Refactor the code to extract the first active element of a
vector in the early exit block, in preparation for PR #130766.
I've replaced the VPInstruction::ExtractFirstActive nodes with
a combination of a new VPInstruction::FirstActiveLane node and
a Instruction::ExtractElement node.


Full diff: https://github.com/llvm/llvm-project/pull/131118.diff

4 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+2-3)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp (+8-1)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+15-11)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp (+8-4)
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 43dc30c40bb53..0b63dba76ff9b 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -879,9 +879,8 @@ class VPInstruction : public VPRecipeWithIRFlags,
     // Returns a scalar boolean value, which is true if any lane of its (only
     // boolean) vector operand is true.
     AnyOf,
-    // Extracts the first active lane of a vector, where the first operand is
-    // the predicate, and the second operand is the vector to extract.
-    ExtractFirstActive,
+    // Calculates the first active lane index of the vector predicate operand.
+    FirstActiveLane,
   };
 
 private:
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index 6f6875f0e5e0e..d3a379682ba07 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -50,6 +50,12 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
     return SetResultTyFromOp();
 
   switch (Opcode) {
+  case Instruction::ExtractElement: {
+    Type *ResTy = inferScalarType(R->getOperand(0));
+    VPValue *OtherV = R->getOperand(1);
+    CachedTypes[OtherV] = ResTy;
+    return ResTy;
+  }
   case Instruction::Select: {
     Type *ResTy = inferScalarType(R->getOperand(1));
     VPValue *OtherV = R->getOperand(2);
@@ -78,7 +84,8 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
   case VPInstruction::CanonicalIVIncrementForPart:
   case VPInstruction::AnyOf:
     return SetResultTyFromOp();
-  case VPInstruction::ExtractFirstActive:
+  case VPInstruction::FirstActiveLane:
+    return Type::getIntNTy(Ctx, 64);
   case VPInstruction::ExtractFromEnd: {
     Type *BaseTy = inferScalarType(R->getOperand(0));
     if (auto *VecTy = dyn_cast<VectorType>(BaseTy))
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index a8b304271f0da..7cc99896d42f4 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -461,6 +461,11 @@ Value *VPInstruction::generate(VPTransformState &State) {
     Value *A = State.get(getOperand(0));
     return Builder.CreateNot(A, Name);
   }
+  case Instruction::ExtractElement: {
+    Value *Vec = State.get(getOperand(0));
+    Value *Idx = State.get(getOperand(1), true);
+    return Builder.CreateExtractElement(Vec, Idx, Name);
+  }
   case Instruction::ICmp: {
     bool OnlyFirstLaneUsed = vputils::onlyFirstLaneUsed(this);
     Value *A = State.get(getOperand(0), OnlyFirstLaneUsed);
@@ -705,12 +710,10 @@ Value *VPInstruction::generate(VPTransformState &State) {
     Value *A = State.get(getOperand(0));
     return Builder.CreateOrReduce(A);
   }
-  case VPInstruction::ExtractFirstActive: {
-    Value *Vec = State.get(getOperand(0));
-    Value *Mask = State.get(getOperand(1));
-    Value *Ctz = Builder.CreateCountTrailingZeroElems(
-        Builder.getInt64Ty(), Mask, true, "first.active.lane");
-    return Builder.CreateExtractElement(Vec, Ctz, "early.exit.value");
+  case VPInstruction::FirstActiveLane: {
+    Value *Mask = State.get(getOperand(0));
+    return Builder.CreateCountTrailingZeroElems(Builder.getInt64Ty(), Mask,
+                                                true, Name);
   }
   default:
     llvm_unreachable("Unsupported opcode for instruction");
@@ -753,7 +756,8 @@ InstructionCost VPInstruction::computeCost(ElementCount VF,
 
 bool VPInstruction::isVectorToScalar() const {
   return getOpcode() == VPInstruction::ExtractFromEnd ||
-         getOpcode() == VPInstruction::ExtractFirstActive ||
+         getOpcode() == Instruction::ExtractElement ||
+         getOpcode() == VPInstruction::FirstActiveLane ||
          getOpcode() == VPInstruction::ComputeReductionResult ||
          getOpcode() == VPInstruction::AnyOf;
 }
@@ -812,13 +816,14 @@ bool VPInstruction::opcodeMayReadOrWriteFromMemory() const {
   if (Instruction::isBinaryOp(getOpcode()))
     return false;
   switch (getOpcode()) {
+  case Instruction::ExtractElement:
   case Instruction::ICmp:
   case Instruction::Select:
   case VPInstruction::AnyOf:
   case VPInstruction::CalculateTripCountMinusVF:
   case VPInstruction::CanonicalIVIncrementForPart:
   case VPInstruction::ExtractFromEnd:
-  case VPInstruction::ExtractFirstActive:
+  case VPInstruction::FirstActiveLane:
   case VPInstruction::FirstOrderRecurrenceSplice:
   case VPInstruction::LogicalAnd:
   case VPInstruction::Not:
@@ -927,7 +932,6 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
   case VPInstruction::Broadcast:
     O << "broadcast";
     break;
-
   case VPInstruction::ExtractFromEnd:
     O << "extract-from-end";
     break;
@@ -943,8 +947,8 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
   case VPInstruction::AnyOf:
     O << "any-of";
     break;
-  case VPInstruction::ExtractFirstActive:
-    O << "extract-first-active";
+  case VPInstruction::FirstActiveLane:
+    O << "first-active-lane";
     break;
   default:
     O << Instruction::getOpcodeName(getOpcode());
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 3b44e95a2471c..7a1a74c9dfaa8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -2157,10 +2157,14 @@ void VPlanTransforms::handleUncountableEarlyExit(
       ExitIRI->extractLastLaneOfOperand(MiddleBuilder);
     }
     // Add the incoming value from the early exit.
-    if (!IncomingFromEarlyExit->isLiveIn())
-      IncomingFromEarlyExit =
-          EarlyExitB.createNaryOp(VPInstruction::ExtractFirstActive,
-                                  {IncomingFromEarlyExit, EarlyExitTakenCond});
+    if (!IncomingFromEarlyExit->isLiveIn()) {
+      VPValue *FirstActiveLane = EarlyExitB.createNaryOp(
+          VPInstruction::FirstActiveLane, {EarlyExitTakenCond}, nullptr,
+          "first.active.lane");
+      IncomingFromEarlyExit = EarlyExitB.createNaryOp(
+          Instruction::ExtractElement, {IncomingFromEarlyExit, FirstActiveLane},
+          nullptr, "early.exit.value");
+    }
     ExitIRI->addOperand(IncomingFromEarlyExit);
   }
   MiddleBuilder.createNaryOp(VPInstruction::BranchOnCond, {IsEarlyExitTaken});

Comment thread llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp Outdated

@fhahn fhahn left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible that this isn't based on the latest main? I am not seeing updates to VPInstruction::computeCost, which should handle ExtractFirstActive?

@david-arm

Copy link
Copy Markdown
Contributor Author

Is it possible that this isn't based on the latest main? I am not seeing updates to VPInstruction::computeCost, which should handle ExtractFirstActive?

Yep, you're absolutely right! And it was me that landed the cost model patch for ExtractFirstActive. :face_palm I've rebased it now.

@lukel97 lukel97 left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

}
case Instruction::ExtractElement: {
Value *Vec = State.get(getOperand(0));
Value *Idx = State.get(getOperand(1), true);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ExtractElement should also be handled in onlyFirstLaneUsed?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure how to handle that though because ExtractElement doesn't necessarily extract the first lane of operand 0. Or does onlyFirstLaneUsed refer to the result?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh sorry, do you mean I should pass in onlyFirstLaneUsed as the second arg to State.get?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

/// Returns true if only the first lane of \p Def is used.
bool onlyFirstLaneUsed(const VPValue *Def);

I must admit I find this area of the code a bit difficult to understand. I can see that for icmp we do something like this:

  case Instruction::ICmp: {
    bool OnlyFirstLaneUsed = vputils::onlyFirstLaneUsed(this);
    Value *A = State.get(getOperand(0), OnlyFirstLaneUsed);
    Value *B = State.get(getOperand(1), OnlyFirstLaneUsed);
    return Builder.CreateCmp(getPredicate(), A, B, Name);
  }

but if I do the same thing for extractelement then the answer to vputils::onlyFirstLaneUsed(this) should always be false, right? Since we may use any lane of the vector passed as operand 0.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant updating VPInstruction::onlyFirstLaneUsed to return true for the second operand of ExtractElement, as we only use the first lane during execute?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe need to add

  case Instruction::ExtractElement:
    return Op == getOperand(1);

into VPInstruction::onlyFirstLaneUsed.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK I see, thanks @fhahn and @Mel-Chen!

@Mel-Chen Mel-Chen left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great to see extractelement added!
This helps enable tail folding when there are external users. :)

Comment on lines +464 to +466
case Instruction::ExtractElement: {
Value *Vec = State.get(getOperand(0));
Value *Idx = State.get(getOperand(1), true);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
case Instruction::ExtractElement: {
Value *Vec = State.get(getOperand(0));
Value *Idx = State.get(getOperand(1), true);
case Instruction::ExtractElement: {
assert(State.VF.isVector() && "Only extract elements from vectors");
Value *Vec = State.get(getOperand(0));
Value *Idx = State.get(getOperand(1), /*IsScalar*/ true);

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thanks!

}
case Instruction::ExtractElement: {
Value *Vec = State.get(getOperand(0));
Value *Idx = State.get(getOperand(1), true);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe need to add

  case Instruction::ExtractElement:
    return Op == getOperand(1);

into VPInstruction::onlyFirstLaneUsed.

Refactor the code to extract the first active element of a
vector in the early exit block, in preparation for PR llvm#130766.
I've replaced the VPInstruction::ExtractFirstActive nodes with
a combination of a new VPInstruction::FirstActiveLane node and
a Instruction::ExtractElement node.
@david-arm

Copy link
Copy Markdown
Contributor Author

Rebased to fix ridiculous git conflicts.

@fhahn fhahn left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

Comment thread llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

@Mel-Chen Mel-Chen left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@github-actions

github-actions Bot commented Mar 14, 2025

Copy link
Copy Markdown

✅ With the latest revision this PR passed the C/C++ code formatter.

@david-arm david-arm merged commit 3b6d009 into llvm:main Mar 14, 2025
@david-arm david-arm deleted the first_active branch April 7, 2025 16:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants