
[X86] Ensure a (vXi1 bitcast(iX Mask)) memory mask is canonicalised for extension before it might get split by legalisation#175769

Merged
RKSimon merged 6 commits into llvm:main from RKSimon:x86-bool-mask-memory-ext
Jan 14, 2026
Conversation

@RKSimon
Collaborator

@RKSimon RKSimon commented Jan 13, 2026

Masked load/store/gathers often need to bitcast the mask from a bitcasted integer.

On pre-AVX512 targets this can lead to some rather nasty scalarization if we don't custom expand the mask first.

This patch uses the canonicalizeBoolMask / combineToExtendBoolVectorInReg helper functions to canonicalise the masks, similar to what we already do for vselect expansion.

Alternative to #175385

Fixes #59789

…nsion before it might get split by legalisation

Masked load/store/gathers often need to bitcast the mask from a bitcasted integer.

On pre-AVX512 targets this can lead to some rather nasty scalarization if we don't custom expand the mask first.

This patch uses the combineToExtendBoolVectorInReg helper function to canonicalise the masks, similar to what we already do for vselect expansion.

Alternative to llvm#175385

Fixes llvm#59789
@RKSimon
Collaborator Author

RKSimon commented Jan 13, 2026

CC @folkertdev

@llvmbot
Member

llvmbot commented Jan 13, 2026

@llvm/pr-subscribers-backend-x86

Author: Simon Pilgrim (RKSimon)

Changes



Patch is 35.10 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/175769.diff

4 Files Affected:

  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+79-25)
  • (modified) llvm/test/CodeGen/X86/masked_gather.ll (+10-66)
  • (modified) llvm/test/CodeGen/X86/masked_load.ll (+66-178)
  • (modified) llvm/test/CodeGen/X86/masked_store.ll (+62-176)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index a354704c5958b..91ced77438492 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -53679,12 +53679,31 @@ static SDValue combineMaskedLoad(SDNode *N, SelectionDAG &DAG,
         return Blend;
   }
 
+  EVT VT = Mld->getValueType(0);
+  SDValue Mask = Mld->getMask();
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+
+  // Attempt to convert a (vXi1 bitcast(iX Mask)) mask before it might get split
+  // by legalization.
+  if (DCI.isBeforeLegalizeOps() && Mask.getOpcode() == ISD::BITCAST &&
+      Mask.getScalarValueSizeInBits() == 1 && !Subtarget.hasAVX512() &&
+      TLI.isOperationLegalOrCustom(ISD::MLOAD, VT)) {
+    SDLoc DL(N);
+    EVT MaskVT = Mask.getValueType();
+    EVT ExtMaskVT = VT.changeVectorElementTypeToInteger();
+    if (SDValue NewMask = combineToExtendBoolVectorInReg(
+            ISD::SIGN_EXTEND, DL, ExtMaskVT, Mask, DAG, DCI, Subtarget)) {
+      NewMask = DAG.getNode(ISD::TRUNCATE, DL, MaskVT, NewMask);
+      return DAG.getMaskedLoad(
+          VT, DL, Mld->getChain(), Mld->getBasePtr(), Mld->getOffset(), NewMask,
+          Mld->getPassThru(), Mld->getMemoryVT(), Mld->getMemOperand(),
+          Mld->getAddressingMode(), Mld->getExtensionType());
+    }
+  }
+
   // If the mask value has been legalized to a non-boolean vector, try to
   // simplify ops leading up to it. We only demand the MSB of each lane.
-  SDValue Mask = Mld->getMask();
   if (Mask.getScalarValueSizeInBits() != 1) {
-    EVT VT = Mld->getValueType(0);
-    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
     APInt DemandedBits(APInt::getSignMask(VT.getScalarSizeInBits()));
     if (TLI.SimplifyDemandedBits(Mask, DemandedBits, DCI)) {
       if (N->getOpcode() != ISD::DELETED_NODE)
@@ -53785,6 +53804,24 @@ static SDValue combineMaskedStore(SDNode *N, SelectionDAG &DAG,
     return SDValue();
   }
 
+  // Attempt to convert a (vXi1 bitcast(iX Mask)) mask before it might get split
+  // by legalization.
+  if (DCI.isBeforeLegalizeOps() && Mask.getOpcode() == ISD::BITCAST &&
+      Mask.getScalarValueSizeInBits() == 1 && !Subtarget.hasAVX512() &&
+      TLI.isOperationLegalOrCustom(ISD::MSTORE, VT)) {
+    EVT MaskVT = Mask.getValueType();
+    EVT ExtMaskVT = VT.changeVectorElementTypeToInteger();
+    if (SDValue NewMask = combineToExtendBoolVectorInReg(
+            ISD::SIGN_EXTEND, DL, ExtMaskVT, Mask, DAG, DCI, Subtarget)) {
+      NewMask = DAG.getNode(ISD::TRUNCATE, DL, MaskVT, NewMask);
+      return DAG.getMaskedStore(Mst->getChain(), SDLoc(N), Mst->getValue(),
+                                Mst->getBasePtr(), Mst->getOffset(), NewMask,
+                                Mst->getMemoryVT(), Mst->getMemOperand(),
+                                Mst->getAddressingMode());
+    }
+  }
+
+
   // If the mask value has been legalized to a non-boolean vector, try to
   // simplify ops leading up to it. We only demand the MSB of each lane.
   if (Mask.getScalarValueSizeInBits() != 1) {
@@ -57398,35 +57435,35 @@ static SDValue combineX86GatherScatter(SDNode *N, SelectionDAG &DAG,
 }
 
 static SDValue rebuildGatherScatter(MaskedGatherScatterSDNode *GorS,
-                                    SDValue Index, SDValue Base, SDValue Scale,
-                                    SelectionDAG &DAG) {
+                                    SDValue Index, SDValue Base, SDValue Mask,
+                                    SDValue Scale, SelectionDAG &DAG) {
   SDLoc DL(GorS);
 
   if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) {
-    SDValue Ops[] = { Gather->getChain(), Gather->getPassThru(),
-                      Gather->getMask(), Base, Index, Scale } ;
-    return DAG.getMaskedGather(Gather->getVTList(),
-                               Gather->getMemoryVT(), DL, Ops,
-                               Gather->getMemOperand(),
+    SDValue Ops[] = {
+        Gather->getChain(), Gather->getPassThru(), Mask, Base, Index, Scale};
+    return DAG.getMaskedGather(Gather->getVTList(), Gather->getMemoryVT(), DL,
+                               Ops, Gather->getMemOperand(),
                                Gather->getIndexType(),
                                Gather->getExtensionType());
   }
   auto *Scatter = cast<MaskedScatterSDNode>(GorS);
-  SDValue Ops[] = { Scatter->getChain(), Scatter->getValue(),
-                    Scatter->getMask(), Base, Index, Scale };
-  return DAG.getMaskedScatter(Scatter->getVTList(),
-                              Scatter->getMemoryVT(), DL,
+  SDValue Ops[] = {
+      Scatter->getChain(), Scatter->getValue(), Mask, Base, Index, Scale};
+  return DAG.getMaskedScatter(Scatter->getVTList(), Scatter->getMemoryVT(), DL,
                               Ops, Scatter->getMemOperand(),
                               Scatter->getIndexType(),
                               Scatter->isTruncatingStore());
 }
 
 static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
-                                    TargetLowering::DAGCombinerInfo &DCI) {
+                                    TargetLowering::DAGCombinerInfo &DCI,
+                                    const X86Subtarget &Subtarget) {
   SDLoc DL(N);
   auto *GorS = cast<MaskedGatherScatterSDNode>(N);
   SDValue Index = GorS->getIndex();
   SDValue Base = GorS->getBasePtr();
+  SDValue Mask = GorS->getMask();
   SDValue Scale = GorS->getScale();
   EVT IndexVT = Index.getValueType();
   EVT IndexSVT = IndexVT.getVectorElementType();
@@ -57460,7 +57497,8 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
                                          Index.getOperand(0), NewShAmt);
           SDValue NewScale =
               DAG.getConstant(ScaleAmt * 2, DL, Scale.getValueType());
-          return rebuildGatherScatter(GorS, NewIndex, Base, NewScale, DAG);
+          return rebuildGatherScatter(GorS, NewIndex, Base, Mask, NewScale,
+                                      DAG);
         }
       }
     }
@@ -57478,7 +57516,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
       // a split.
       if (SDValue TruncIndex =
               DAG.FoldConstantArithmetic(ISD::TRUNCATE, DL, NewVT, Index))
-        return rebuildGatherScatter(GorS, TruncIndex, Base, Scale, DAG);
+        return rebuildGatherScatter(GorS, TruncIndex, Base, Mask, Scale, DAG);
 
       // Shrink any sign/zero extends from 32 or smaller to larger than 32 if
       // there are sufficient sign bits. Only do this before legalize types to
@@ -57487,13 +57525,13 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
            Index.getOpcode() == ISD::ZERO_EXTEND) &&
           Index.getOperand(0).getScalarValueSizeInBits() <= 32) {
         Index = DAG.getNode(ISD::TRUNCATE, DL, NewVT, Index);
-        return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
+        return rebuildGatherScatter(GorS, Index, Base, Mask, Scale, DAG);
       }
 
       // Shrink if we remove an illegal type.
       if (!TLI.isTypeLegal(Index.getValueType()) && TLI.isTypeLegal(NewVT)) {
         Index = DAG.getNode(ISD::TRUNCATE, DL, NewVT, Index);
-        return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
+        return rebuildGatherScatter(GorS, Index, Base, Mask, Scale, DAG);
       }
     }
   }
@@ -57518,13 +57556,15 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
               SDValue NewBase = DAG.getNode(ISD::ADD, DL, PtrVT, Base,
                                             DAG.getConstant(Adder, DL, PtrVT));
               SDValue NewIndex = Index.getOperand(1 - I);
-              return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG);
+              return rebuildGatherScatter(GorS, NewIndex, NewBase, Mask, Scale,
+                                          DAG);
             }
             // For non-constant cases, limit this to non-scaled cases.
             if (ScaleAmt == 1) {
               SDValue NewBase = DAG.getNode(ISD::ADD, DL, PtrVT, Base, Splat);
               SDValue NewIndex = Index.getOperand(1 - I);
-              return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG);
+              return rebuildGatherScatter(GorS, NewIndex, NewBase, Mask, Scale,
+                                          DAG);
             }
           }
         }
@@ -57539,7 +57579,8 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
           SDValue NewIndex = DAG.getNode(ISD::ADD, DL, IndexVT,
                                          Index.getOperand(1 - I), Splat);
           SDValue NewBase = DAG.getConstant(0, DL, PtrVT);
-          return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG);
+          return rebuildGatherScatter(GorS, NewIndex, NewBase, Mask, Scale,
+                                      DAG);
         }
       }
   }
@@ -57550,12 +57591,25 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
       MVT EltVT = IndexWidth > 32 ? MVT::i64 : MVT::i32;
       IndexVT = IndexVT.changeVectorElementType(*DAG.getContext(), EltVT);
       Index = DAG.getSExtOrTrunc(Index, DL, IndexVT);
-      return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
+      return rebuildGatherScatter(GorS, Index, Base, Mask, Scale, DAG);
+    }
+
+    // Attempt to convert a (vXi1 bitcast(iX Mask)) mask before it might get
+    // split by legalization.
+    if (GorS->getOpcode() == ISD::MGATHER && Mask.getOpcode() == ISD::BITCAST &&
+        Mask.getScalarValueSizeInBits() == 1 && !Subtarget.hasAVX512() &&
+        TLI.isOperationLegalOrCustom(ISD::MGATHER, N->getValueType(0))) {
+      EVT MaskVT = Mask.getValueType();
+      EVT ExtMaskVT = N->getValueType(0).changeVectorElementTypeToInteger();
+      if (SDValue ExtMask = combineToExtendBoolVectorInReg(
+              ISD::SIGN_EXTEND, DL, ExtMaskVT, Mask, DAG, DCI, Subtarget)) {
+        ExtMask = DAG.getNode(ISD::TRUNCATE, DL, MaskVT, ExtMask);
+        return rebuildGatherScatter(GorS, Index, Base, ExtMask, Scale, DAG);
+      }
     }
   }
 
   // With vector masks we only demand the upper bit of the mask.
-  SDValue Mask = GorS->getMask();
   if (Mask.getScalarValueSizeInBits() != 1) {
     APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits()));
     if (TLI.SimplifyDemandedBits(Mask, DemandedMask, DCI)) {
@@ -61700,7 +61754,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
   case X86ISD::MGATHER:
   case X86ISD::MSCATTER:    return combineX86GatherScatter(N, DAG, DCI);
   case ISD::MGATHER:
-  case ISD::MSCATTER:       return combineGatherScatter(N, DAG, DCI);
+  case ISD::MSCATTER:       return combineGatherScatter(N, DAG, DCI, Subtarget);
   case X86ISD::PCMPEQ:
   case X86ISD::PCMPGT:      return combineVectorCompare(N, DAG, Subtarget);
   case X86ISD::PMULDQ:
diff --git a/llvm/test/CodeGen/X86/masked_gather.ll b/llvm/test/CodeGen/X86/masked_gather.ll
index 962ff66b072a6..2913fe13095ca 100644
--- a/llvm/test/CodeGen/X86/masked_gather.ll
+++ b/llvm/test/CodeGen/X86/masked_gather.ll
@@ -312,27 +312,11 @@ define <4 x float> @masked_gather_v4f32_ptr_v4i32(<4 x ptr> %ptr, i32 %trigger,
 ;
 ; AVX2-GATHER-LABEL: masked_gather_v4f32_ptr_v4i32:
 ; AVX2-GATHER:       # %bb.0:
-; AVX2-GATHER-NEXT:    movl %edi, %eax
-; AVX2-GATHER-NEXT:    andl $1, %eax
-; AVX2-GATHER-NEXT:    negl %eax
-; AVX2-GATHER-NEXT:    vmovd %eax, %xmm2
-; AVX2-GATHER-NEXT:    movl %edi, %eax
-; AVX2-GATHER-NEXT:    shrb %al
-; AVX2-GATHER-NEXT:    movzbl %al, %eax
-; AVX2-GATHER-NEXT:    andl $1, %eax
-; AVX2-GATHER-NEXT:    negl %eax
-; AVX2-GATHER-NEXT:    vpinsrd $1, %eax, %xmm2, %xmm2
-; AVX2-GATHER-NEXT:    movl %edi, %eax
-; AVX2-GATHER-NEXT:    shrb $2, %al
-; AVX2-GATHER-NEXT:    movzbl %al, %eax
-; AVX2-GATHER-NEXT:    andl $1, %eax
-; AVX2-GATHER-NEXT:    negl %eax
-; AVX2-GATHER-NEXT:    vpinsrd $2, %eax, %xmm2, %xmm2
-; AVX2-GATHER-NEXT:    andb $8, %dil
-; AVX2-GATHER-NEXT:    shrb $3, %dil
-; AVX2-GATHER-NEXT:    movzbl %dil, %eax
-; AVX2-GATHER-NEXT:    negl %eax
-; AVX2-GATHER-NEXT:    vpinsrd $3, %eax, %xmm2, %xmm2
+; AVX2-GATHER-NEXT:    vmovd %edi, %xmm2
+; AVX2-GATHER-NEXT:    vpbroadcastd %xmm2, %xmm2
+; AVX2-GATHER-NEXT:    vpmovsxbd {{.*#+}} xmm3 = [1,2,4,8]
+; AVX2-GATHER-NEXT:    vpand %xmm3, %xmm2, %xmm2
+; AVX2-GATHER-NEXT:    vpcmpeqd %xmm3, %xmm2, %xmm2
 ; AVX2-GATHER-NEXT:    vgatherqps %xmm2, (,%ymm0), %xmm1
 ; AVX2-GATHER-NEXT:    vmovaps %xmm1, %xmm0
 ; AVX2-GATHER-NEXT:    vzeroupper
@@ -2575,51 +2559,11 @@ define <8 x i32> @masked_gather_v8i32_v8i32(i8 %trigger) {
 ;
 ; AVX2-GATHER-LABEL: masked_gather_v8i32_v8i32:
 ; AVX2-GATHER:       # %bb.0:
-; AVX2-GATHER-NEXT:    movl %edi, %eax
-; AVX2-GATHER-NEXT:    shrb $5, %al
-; AVX2-GATHER-NEXT:    movzbl %al, %eax
-; AVX2-GATHER-NEXT:    andl $1, %eax
-; AVX2-GATHER-NEXT:    negl %eax
-; AVX2-GATHER-NEXT:    movl %edi, %ecx
-; AVX2-GATHER-NEXT:    shrb $4, %cl
-; AVX2-GATHER-NEXT:    movzbl %cl, %ecx
-; AVX2-GATHER-NEXT:    andl $1, %ecx
-; AVX2-GATHER-NEXT:    negl %ecx
-; AVX2-GATHER-NEXT:    vmovd %ecx, %xmm0
-; AVX2-GATHER-NEXT:    vpinsrd $1, %eax, %xmm0, %xmm0
-; AVX2-GATHER-NEXT:    movl %edi, %eax
-; AVX2-GATHER-NEXT:    shrb $6, %al
-; AVX2-GATHER-NEXT:    movzbl %al, %eax
-; AVX2-GATHER-NEXT:    andl $1, %eax
-; AVX2-GATHER-NEXT:    negl %eax
-; AVX2-GATHER-NEXT:    vpinsrd $2, %eax, %xmm0, %xmm0
-; AVX2-GATHER-NEXT:    movl %edi, %eax
-; AVX2-GATHER-NEXT:    shrb $7, %al
-; AVX2-GATHER-NEXT:    movzbl %al, %eax
-; AVX2-GATHER-NEXT:    negl %eax
-; AVX2-GATHER-NEXT:    vpinsrd $3, %eax, %xmm0, %xmm0
-; AVX2-GATHER-NEXT:    movl %edi, %eax
-; AVX2-GATHER-NEXT:    andl $1, %eax
-; AVX2-GATHER-NEXT:    negl %eax
-; AVX2-GATHER-NEXT:    vmovd %eax, %xmm1
-; AVX2-GATHER-NEXT:    movl %edi, %eax
-; AVX2-GATHER-NEXT:    shrb %al
-; AVX2-GATHER-NEXT:    movzbl %al, %eax
-; AVX2-GATHER-NEXT:    andl $1, %eax
-; AVX2-GATHER-NEXT:    negl %eax
-; AVX2-GATHER-NEXT:    vpinsrd $1, %eax, %xmm1, %xmm1
-; AVX2-GATHER-NEXT:    movl %edi, %eax
-; AVX2-GATHER-NEXT:    shrb $2, %al
-; AVX2-GATHER-NEXT:    movzbl %al, %eax
-; AVX2-GATHER-NEXT:    andl $1, %eax
-; AVX2-GATHER-NEXT:    negl %eax
-; AVX2-GATHER-NEXT:    vpinsrd $2, %eax, %xmm1, %xmm1
-; AVX2-GATHER-NEXT:    shrb $3, %dil
-; AVX2-GATHER-NEXT:    movzbl %dil, %eax
-; AVX2-GATHER-NEXT:    andl $1, %eax
-; AVX2-GATHER-NEXT:    negl %eax
-; AVX2-GATHER-NEXT:    vpinsrd $3, %eax, %xmm1, %xmm1
-; AVX2-GATHER-NEXT:    vinserti128 $1, %xmm0, %ymm1, %ymm0
+; AVX2-GATHER-NEXT:    vmovd %edi, %xmm0
+; AVX2-GATHER-NEXT:    vpbroadcastb %xmm0, %ymm0
+; AVX2-GATHER-NEXT:    vpmovzxbd {{.*#+}} ymm1 = [1,2,4,8,16,32,64,128]
+; AVX2-GATHER-NEXT:    vpand %ymm1, %ymm0, %ymm0
+; AVX2-GATHER-NEXT:    vpcmpeqd %ymm1, %ymm0, %ymm0
 ; AVX2-GATHER-NEXT:    vpxor %xmm1, %xmm1, %xmm1
 ; AVX2-GATHER-NEXT:    vmovdqa %ymm0, %ymm2
 ; AVX2-GATHER-NEXT:    vpxor %xmm3, %xmm3, %xmm3
diff --git a/llvm/test/CodeGen/X86/masked_load.ll b/llvm/test/CodeGen/X86/masked_load.ll
index 672ec4038d235..99a8918fef93f 100644
--- a/llvm/test/CodeGen/X86/masked_load.ll
+++ b/llvm/test/CodeGen/X86/masked_load.ll
@@ -113,21 +113,27 @@ define <2 x double> @load_v2f64_i2(i2 %trigger, ptr %addr, <2 x double> %dst) {
 ; SSE-NEXT:    movhps {{.*#+}} xmm0 = xmm0[0,1],mem[0,1]
 ; SSE-NEXT:    retq
 ;
-; AVX1OR2-LABEL: load_v2f64_i2:
-; AVX1OR2:       ## %bb.0:
-; AVX1OR2-NEXT:    movl %edi, %eax
-; AVX1OR2-NEXT:    andl $1, %eax
-; AVX1OR2-NEXT:    negq %rax
-; AVX1OR2-NEXT:    vmovq %rax, %xmm1
-; AVX1OR2-NEXT:    andb $2, %dil
-; AVX1OR2-NEXT:    shrb %dil
-; AVX1OR2-NEXT:    movzbl %dil, %eax
-; AVX1OR2-NEXT:    negq %rax
-; AVX1OR2-NEXT:    vmovq %rax, %xmm2
-; AVX1OR2-NEXT:    vpunpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm2[0]
-; AVX1OR2-NEXT:    vmaskmovpd (%rsi), %xmm1, %xmm2
-; AVX1OR2-NEXT:    vblendvpd %xmm1, %xmm2, %xmm0, %xmm0
-; AVX1OR2-NEXT:    retq
+; AVX1-LABEL: load_v2f64_i2:
+; AVX1:       ## %bb.0:
+; AVX1-NEXT:    vmovd %edi, %xmm1
+; AVX1-NEXT:    vpshufd {{.*#+}} xmm1 = xmm1[0,1,0,1]
+; AVX1-NEXT:    vpmovsxbq {{.*#+}} xmm2 = [1,2]
+; AVX1-NEXT:    vpand %xmm2, %xmm1, %xmm1
+; AVX1-NEXT:    vpcmpeqq %xmm2, %xmm1, %xmm1
+; AVX1-NEXT:    vmaskmovpd (%rsi), %xmm1, %xmm2
+; AVX1-NEXT:    vblendvpd %xmm1, %xmm2, %xmm0, %xmm0
+; AVX1-NEXT:    retq
+;
+; AVX2-LABEL: load_v2f64_i2:
+; AVX2:       ## %bb.0:
+; AVX2-NEXT:    vmovd %edi, %xmm1
+; AVX2-NEXT:    vpbroadcastd %xmm1, %xmm1
+; AVX2-NEXT:    vpmovsxbq {{.*#+}} xmm2 = [1,2]
+; AVX2-NEXT:    vpand %xmm2, %xmm1, %xmm1
+; AVX2-NEXT:    vpcmpeqq %xmm2, %xmm1, %xmm1
+; AVX2-NEXT:    vmaskmovpd (%rsi), %xmm1, %xmm2
+; AVX2-NEXT:    vblendvpd %xmm1, %xmm2, %xmm0, %xmm0
+; AVX2-NEXT:    retq
 ;
 ; AVX512F-LABEL: load_v2f64_i2:
 ; AVX512F:       ## %bb.0:
@@ -281,29 +287,14 @@ define <4 x double> @load_v4f64_i4(i4 %trigger, ptr %addr, <4 x double> %dst) {
 ;
 ; AVX1-LABEL: load_v4f64_i4:
 ; AVX1:       ## %bb.0:
-; AVX1-NEXT:    movl %edi, %eax
-; AVX1-NEXT:    andl $1, %eax
-; AVX1-NEXT:    negl %eax
-; AVX1-NEXT:    vmovd %eax, %xmm1
-; AVX1-NEXT:    movl %edi, %eax
-; AVX1-NEXT:    shrb %al
-; AVX1-NEXT:    movzbl %al, %eax
-; AVX1-NEXT:    andl $1, %eax
-; AVX1-NEXT:    negl %eax
-; AVX1-NEXT:    vpinsrd $1, %eax, %xmm1, %xmm1
-; AVX1-NEXT:    vpmovsxdq %xmm1, %xmm2
-; AVX1-NEXT:    movl %edi, %eax
-; AVX1-NEXT:    shrb $2, %al
-; AVX1-NEXT:    movzbl %al, %eax
-; AVX1-NEXT:    andl $1, %eax
-; AVX1-NEXT:    negl %eax
-; AVX1-NEXT:    vpinsrd $2, %eax, %xmm1, %xmm1
-; AVX1-NEXT:    andb $8, %dil
-; AVX1-NEXT:    shrb $3, %dil
-; AVX1-NEXT:    movzbl %dil, %eax
-; AVX1-NEXT:    negl %eax
-; AVX1-NEXT:    vpinsrd $3, %eax, %xmm1, %xmm1
-; AVX1-NEXT:    vpshufd {{.*#+}} xmm1 = xmm1[2,2,3,3]
+; AVX1-NEXT:    vmovd %edi, %xmm1
+; AVX1-NEXT:    vpshufd {{.*#+}} xmm1 = xmm1[0,1,0,1]
+; AVX1-NEXT:    vinsertf128 $1, %xmm1, %ymm1, %ymm1
+; AVX1-NEXT:    vmovaps {{.*#+}} ymm2 = [1,2,4,8]
+; AVX1-NEXT:    vandps %ymm2, %ymm1, %ymm1
+; AVX1-NEXT:    vpcmpeqq %xmm2, %xmm1, %xmm2
+; AVX1-NEXT:    vextractf128 $1, %ymm1, %xmm1
+; AVX1-NEXT:    vpcmpeqq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
 ; AVX1-NEXT:    vinsertf128 $1, %xmm1, %ymm2, %ymm1
 ; AVX1-NEXT:    vmaskmovpd (%rsi), %ymm1, %ymm2
 ; AVX1-NEXT:    vblendvpd %ymm1, %ymm2, %ymm0, %ymm0
@@ -311,30 +302,11 @@ define <4 x double> @load_v4f64_i4(i4 %trigger, ptr %addr, <4 x double> %dst) {
 ;
 ; AVX2-LABEL: load_v4f64_i4:
 ; AVX2:       ## %bb.0:
-; AVX2-NEXT:    movl %edi, %eax
-; AVX2-NEXT:    andb $8, %al
-; AVX2-NEXT:    shrb $3, %al
-; AVX2-NEXT:    movzbl %al, %eax
-; AVX2-NEXT:    negq %rax
-; AVX2-NEXT:    vmovq %rax, %xmm1
-; AVX2-NEXT:    movl %edi, %eax
-; AVX2-NEXT:    shrb $2, %al
-; AVX2-NEXT:    movzbl %al, %eax
-; AVX2-NEXT:    andl $1, %eax
-; AVX2-NEXT:    negq %rax
-; AVX2-NEXT:    vmovq %rax, %xmm2
-; AVX2-NEXT:    vpunpcklqdq {{.*#+}} xmm1 = xmm2[0],xmm1[0]
-; AVX2-NEXT:    movl %edi, %eax
-; AVX2-NEXT:    andl $1, %eax
-; AVX2-NEXT:    negq %rax
-; AVX2-NEXT:    vmovq %rax, %xmm2
-; AVX2-NEXT:    shrb %dil
-; AVX2-NEXT:    movzbl %dil, %eax
-; AVX2-NEXT:    andl $1, %eax
-; AVX2-NEXT:    negq %rax
-; AVX2-NEXT:    vmovq %rax, %xmm3
-; AVX2-NEXT:    vpunpcklqdq {{.*#+}} xmm2 = xmm2[0],xmm3[0]
-; AVX2-NEXT:    vinserti128 $1, %xmm1, %ymm2, %ymm1
+; AVX2-NEXT:    vmovd %edi, %xmm1
+; AVX2-NEXT:    vpbroadcastd %xmm1, %ymm1
+; AVX2-NEXT:    vpmovsxbq {{.*#+}} ymm2 = [1,2,4,8]
+; AVX2-NEXT:    vpand %ymm2, %ymm1, %ymm1
+; AVX2-NEXT:    vpcmpeqq %ymm2, %ymm1, %ymm1
 ; AVX2-NEXT:    vmaskmovpd (%rsi), %ymm1, %ymm2
 ; AVX2-NEXT:    vblendvpd %ymm1, %ymm2, %ymm0, %ymm0
 ; AVX2-NEXT:    retq
@@ -1552,32 +1524,27 @@ define <4 x float> @load_v4f32_i4(i4 %trigger, ptr %addr, <4 x float> %dst) {
 ; SSE42-NEXT:    insertps {{.*#+}} xmm0 = xmm0[0,1,2],mem[0]
 ; SSE42-NEXT:    retq
 ;
-; AVX1OR2-LABEL: load_v4f32_i4:
-; AVX1OR2:       ## %bb.0:
-; AVX1OR2-NEXT:    movl %edi, %eax
-; AVX1OR2-NEXT:    andl $1, %eax
-; AVX1OR2-NEXT:    negl %eax
-; AVX1OR2-NEXT:    vmovd %eax, %xmm1
-; AVX1OR2-NEXT:    movl %edi, %eax
-; AVX1OR2-NEXT:    shrb %al
-; AVX1OR2-NEXT:    movzbl %al, %eax
-; AVX1OR2-NEXT:    andl $1, %eax
-; AVX1OR2-NEXT:    negl %eax
-; AVX1OR2-NEXT:    vpinsrd $1, %eax, %xmm1, %xmm1
-; AVX1OR2-NEXT:    movl %edi, %eax
-; AVX1OR2-NEXT:    shrb $2, %al
-; AVX1OR2-NEXT:    movzbl %al, %eax
-; AVX1OR2-NEXT:    andl $1, %eax
-; AVX1OR2-NEXT:    negl %eax
-; AVX1OR2-NEXT:    vpinsrd $2, %eax, %xmm1, %xmm1
-; AVX1OR2-NEXT:...
[truncated]

@RKSimon RKSimon changed the title [X86] Ensure a (vXi1 bitcast(iX Mask)) mask is canonicalised for extension before it might get split by legalisation [X86] Ensure a (vXi1 bitcast(iX Mask)) memory mask is canonicalised for extension before it might get split by legalisation Jan 13, 2026
@github-actions

github-actions bot commented Jan 13, 2026

✅ With the latest revision this PR passed the C/C++ code formatter.


// Attempt to convert a (vXi1 bitcast(iX Mask)) mask before it might get
// split by legalization.
if (GorS->getOpcode() == ISD::MGATHER && Mask.getOpcode() == ISD::BITCAST &&
Contributor

Why only limit to gather here?

Collaborator Author

We're only targeting AVX2, as it has no mask registers or scatter instruction

Collaborator Author

I've removed the explicit Opcode check for ISD::MGATHER and now rely on canonicalizeBoolMask to early-out because ISD::MSCATTER isn't legal/custom on pre-AVX512 targets


// Attempt to convert a (vXi1 bitcast(iX Mask)) mask before it might get split
// by legalization.
if (DCI.isBeforeLegalizeOps() && Mask.getOpcode() == ISD::BITCAST &&
Contributor

The conditions and below code are similar, is it possible to move them into a common function?

Collaborator Author

Sure - I've added a canonicalizeBoolMask common helper

Contributor

@phoebewang phoebewang left a comment

LGTM with one nit.

// by legalization.
if (SDValue NewMask =
canonicalizeBoolMask(ISD::MLOAD, VT, Mask, DL, DAG, DCI, Subtarget)) {
NewMask = DAG.getNode(ISD::TRUNCATE, DL, Mask.getValueType(), NewMask);
Contributor

This can be moved into canonicalizeBoolMask too.

assert(ExtMaskVT.bitsGT(MaskVT) && "Unexpected extension type");
SDValue NewMask = combineToExtendBoolVectorInReg(
ISD::SIGN_EXTEND, DL, ExtMaskVT, Mask, DAG, DCI, Subtarget);
return DAG.getNode(ISD::TRUNCATE, DL, MaskVT, NewMask);
Contributor

Should we check NewMask before truncate?

Collaborator Author

nice catch

@RKSimon RKSimon enabled auto-merge (squash) January 14, 2026 13:32
@RKSimon RKSimon merged commit 83586be into llvm:main Jan 14, 2026
10 of 11 checks passed
@RKSimon RKSimon deleted the x86-bool-mask-memory-ext branch January 14, 2026 14:11
Priyanshu3820 pushed a commit to Priyanshu3820/llvm-project that referenced this pull request Jan 18, 2026
…or extension before it might get split by legalisation (llvm#175769)

BStott6 pushed a commit to BStott6/llvm-project that referenced this pull request Jan 22, 2026
…or extension before it might get split by legalisation (llvm#175769)

folkertdev added a commit that referenced this pull request Feb 27, 2026
Using code/ideas from the x86 backend to optimize a select on a bitcast
integer. The previous aarch64 approach was to individually extract the
bits from the mask, which is kind of terrible.

https://rust.godbolt.org/z/576sndT66

```llvm
define void @if_then_else8(ptr %out, i8 %mask, ptr %if_true, ptr %if_false) {
start:
  %t = load <8 x i32>, ptr %if_true, align 4
  %f = load <8 x i32>, ptr %if_false, align 4
  %m = bitcast i8 %mask to <8 x i1>
  %s = select <8 x i1> %m, <8 x i32> %t, <8 x i32> %f
  store <8 x i32> %s, ptr %out, align 4
  ret void
}
```

turned into 

```asm
if_then_else8:                          // @if_then_else8
        sub     sp, sp, #16
        ubfx    w8, w1, #4, #1
        and     w11, w1, #0x1
        ubfx    w9, w1, #5, #1
        fmov    s1, w11
        ubfx    w10, w1, #1, #1
        fmov    s0, w8
        ubfx    w8, w1, #6, #1
        ldp     q5, q2, [x3]
        mov     v1.h[1], w10
        ldp     q4, q3, [x2]
        mov     v0.h[1], w9
        ubfx    w9, w1, #2, #1
        mov     v1.h[2], w9
        ubfx    w9, w1, #3, #1
        mov     v0.h[2], w8
        ubfx    w8, w1, #7, #1
        mov     v1.h[3], w9
        mov     v0.h[3], w8
        ushll   v1.4s, v1.4h, #0
        ushll   v0.4s, v0.4h, #0
        shl     v1.4s, v1.4s, #31
        shl     v0.4s, v0.4s, #31
        cmlt    v1.4s, v1.4s, #0
        cmlt    v0.4s, v0.4s, #0
        bsl     v1.16b, v4.16b, v5.16b
        bsl     v0.16b, v3.16b, v2.16b
        stp     q1, q0, [x0]
        add     sp, sp, #16
        ret
```

With this PR that instead emits

```asm
if_then_else8:
   adrp x8, .LCPI0_1
   dup v0.4s, w1
   ldr q1, [x8, :lo12:.LCPI0_1]
   adrp x8, .LCPI0_0
   ldr q2, [x8, :lo12:.LCPI0_0]
   ldp q4, q3, [x2]
   and v1.16b, v0.16b, v1.16b
   and v0.16b, v0.16b, v2.16b
   ldp q5, q2, [x3]
   cmeq v1.4s, v1.4s, #0
   cmeq v0.4s, v0.4s, #0
   bsl v1.16b, v2.16b, v3.16b
   bsl v0.16b, v5.16b, v4.16b
   stp q0, q1, [x0]
   ret
```

So substantially shorter. Instead of building the mask
element-by-element, this approach (by virtue of not splitting) instead
splats the mask value into all vector lanes, performs a bitwise and with
powers of 2, and compares with zero to construct the mask vector.

cc rust-lang/rust#122376
cc #175769
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Feb 27, 2026
sujianIBM pushed a commit to sujianIBM/llvm-project that referenced this pull request Mar 5, 2026


Development

Successfully merging this pull request may close these issues.

Inefficient mask generation from integer bitmask for llvm.masked.gather on avx2 targets

3 participants