Conversation
…nsion before it might get split by legalisation. Masked load/store/gathers often need to bitcast the mask from a bitcasted integer. On pre-AVX512 targets this can lead to some rather nasty scalarization if we don't custom expand the mask first. This patch uses the combineToExtendBoolVectorInReg helper function to canonicalise the masks, similar to what we already do for vselect expansion. Alternative to llvm#175385. Fixes llvm#59789.
|
CC @folkertdev |
|
@llvm/pr-subscribers-backend-x86 Author: Simon Pilgrim (RKSimon) Changes: Masked load/store/gathers often need to bitcast the mask from a bitcasted integer. On pre-AVX512 targets this can lead to some rather nasty scalarization if we don't custom expand the mask first. This patch uses the combineToExtendBoolVectorInReg helper function to canonicalise the masks, similar to what we already do for vselect expansion. Alternative to #175385. Fixes #59789. Patch is 35.10 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/175769.diff 4 Files Affected:
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index a354704c5958b..91ced77438492 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -53679,12 +53679,31 @@ static SDValue combineMaskedLoad(SDNode *N, SelectionDAG &DAG,
return Blend;
}
+ EVT VT = Mld->getValueType(0);
+ SDValue Mask = Mld->getMask();
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+
+ // Attempt to convert a (vXi1 bitcast(iX Mask)) mask before it might get split
+ // by legalization.
+ if (DCI.isBeforeLegalizeOps() && Mask.getOpcode() == ISD::BITCAST &&
+ Mask.getScalarValueSizeInBits() == 1 && !Subtarget.hasAVX512() &&
+ TLI.isOperationLegalOrCustom(ISD::MLOAD, VT)) {
+ SDLoc DL(N);
+ EVT MaskVT = Mask.getValueType();
+ EVT ExtMaskVT = VT.changeVectorElementTypeToInteger();
+ if (SDValue NewMask = combineToExtendBoolVectorInReg(
+ ISD::SIGN_EXTEND, DL, ExtMaskVT, Mask, DAG, DCI, Subtarget)) {
+ NewMask = DAG.getNode(ISD::TRUNCATE, DL, MaskVT, NewMask);
+ return DAG.getMaskedLoad(
+ VT, DL, Mld->getChain(), Mld->getBasePtr(), Mld->getOffset(), NewMask,
+ Mld->getPassThru(), Mld->getMemoryVT(), Mld->getMemOperand(),
+ Mld->getAddressingMode(), Mld->getExtensionType());
+ }
+ }
+
// If the mask value has been legalized to a non-boolean vector, try to
// simplify ops leading up to it. We only demand the MSB of each lane.
- SDValue Mask = Mld->getMask();
if (Mask.getScalarValueSizeInBits() != 1) {
- EVT VT = Mld->getValueType(0);
- const TargetLowering &TLI = DAG.getTargetLoweringInfo();
APInt DemandedBits(APInt::getSignMask(VT.getScalarSizeInBits()));
if (TLI.SimplifyDemandedBits(Mask, DemandedBits, DCI)) {
if (N->getOpcode() != ISD::DELETED_NODE)
@@ -53785,6 +53804,24 @@ static SDValue combineMaskedStore(SDNode *N, SelectionDAG &DAG,
return SDValue();
}
+ // Attempt to convert a (vXi1 bitcast(iX Mask)) mask before it might get split
+ // by legalization.
+ if (DCI.isBeforeLegalizeOps() && Mask.getOpcode() == ISD::BITCAST &&
+ Mask.getScalarValueSizeInBits() == 1 && !Subtarget.hasAVX512() &&
+ TLI.isOperationLegalOrCustom(ISD::MSTORE, VT)) {
+ EVT MaskVT = Mask.getValueType();
+ EVT ExtMaskVT = VT.changeVectorElementTypeToInteger();
+ if (SDValue NewMask = combineToExtendBoolVectorInReg(
+ ISD::SIGN_EXTEND, DL, ExtMaskVT, Mask, DAG, DCI, Subtarget)) {
+ NewMask = DAG.getNode(ISD::TRUNCATE, DL, MaskVT, NewMask);
+ return DAG.getMaskedStore(Mst->getChain(), SDLoc(N), Mst->getValue(),
+ Mst->getBasePtr(), Mst->getOffset(), NewMask,
+ Mst->getMemoryVT(), Mst->getMemOperand(),
+ Mst->getAddressingMode());
+ }
+ }
+
+
// If the mask value has been legalized to a non-boolean vector, try to
// simplify ops leading up to it. We only demand the MSB of each lane.
if (Mask.getScalarValueSizeInBits() != 1) {
@@ -57398,35 +57435,35 @@ static SDValue combineX86GatherScatter(SDNode *N, SelectionDAG &DAG,
}
static SDValue rebuildGatherScatter(MaskedGatherScatterSDNode *GorS,
- SDValue Index, SDValue Base, SDValue Scale,
- SelectionDAG &DAG) {
+ SDValue Index, SDValue Base, SDValue Mask,
+ SDValue Scale, SelectionDAG &DAG) {
SDLoc DL(GorS);
if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) {
- SDValue Ops[] = { Gather->getChain(), Gather->getPassThru(),
- Gather->getMask(), Base, Index, Scale } ;
- return DAG.getMaskedGather(Gather->getVTList(),
- Gather->getMemoryVT(), DL, Ops,
- Gather->getMemOperand(),
+ SDValue Ops[] = {
+ Gather->getChain(), Gather->getPassThru(), Mask, Base, Index, Scale};
+ return DAG.getMaskedGather(Gather->getVTList(), Gather->getMemoryVT(), DL,
+ Ops, Gather->getMemOperand(),
Gather->getIndexType(),
Gather->getExtensionType());
}
auto *Scatter = cast<MaskedScatterSDNode>(GorS);
- SDValue Ops[] = { Scatter->getChain(), Scatter->getValue(),
- Scatter->getMask(), Base, Index, Scale };
- return DAG.getMaskedScatter(Scatter->getVTList(),
- Scatter->getMemoryVT(), DL,
+ SDValue Ops[] = {
+ Scatter->getChain(), Scatter->getValue(), Mask, Base, Index, Scale};
+ return DAG.getMaskedScatter(Scatter->getVTList(), Scatter->getMemoryVT(), DL,
Ops, Scatter->getMemOperand(),
Scatter->getIndexType(),
Scatter->isTruncatingStore());
}
static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
- TargetLowering::DAGCombinerInfo &DCI) {
+ TargetLowering::DAGCombinerInfo &DCI,
+ const X86Subtarget &Subtarget) {
SDLoc DL(N);
auto *GorS = cast<MaskedGatherScatterSDNode>(N);
SDValue Index = GorS->getIndex();
SDValue Base = GorS->getBasePtr();
+ SDValue Mask = GorS->getMask();
SDValue Scale = GorS->getScale();
EVT IndexVT = Index.getValueType();
EVT IndexSVT = IndexVT.getVectorElementType();
@@ -57460,7 +57497,8 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
Index.getOperand(0), NewShAmt);
SDValue NewScale =
DAG.getConstant(ScaleAmt * 2, DL, Scale.getValueType());
- return rebuildGatherScatter(GorS, NewIndex, Base, NewScale, DAG);
+ return rebuildGatherScatter(GorS, NewIndex, Base, Mask, NewScale,
+ DAG);
}
}
}
@@ -57478,7 +57516,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
// a split.
if (SDValue TruncIndex =
DAG.FoldConstantArithmetic(ISD::TRUNCATE, DL, NewVT, Index))
- return rebuildGatherScatter(GorS, TruncIndex, Base, Scale, DAG);
+ return rebuildGatherScatter(GorS, TruncIndex, Base, Mask, Scale, DAG);
// Shrink any sign/zero extends from 32 or smaller to larger than 32 if
// there are sufficient sign bits. Only do this before legalize types to
@@ -57487,13 +57525,13 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
Index.getOpcode() == ISD::ZERO_EXTEND) &&
Index.getOperand(0).getScalarValueSizeInBits() <= 32) {
Index = DAG.getNode(ISD::TRUNCATE, DL, NewVT, Index);
- return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
+ return rebuildGatherScatter(GorS, Index, Base, Mask, Scale, DAG);
}
// Shrink if we remove an illegal type.
if (!TLI.isTypeLegal(Index.getValueType()) && TLI.isTypeLegal(NewVT)) {
Index = DAG.getNode(ISD::TRUNCATE, DL, NewVT, Index);
- return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
+ return rebuildGatherScatter(GorS, Index, Base, Mask, Scale, DAG);
}
}
}
@@ -57518,13 +57556,15 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
SDValue NewBase = DAG.getNode(ISD::ADD, DL, PtrVT, Base,
DAG.getConstant(Adder, DL, PtrVT));
SDValue NewIndex = Index.getOperand(1 - I);
- return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG);
+ return rebuildGatherScatter(GorS, NewIndex, NewBase, Mask, Scale,
+ DAG);
}
// For non-constant cases, limit this to non-scaled cases.
if (ScaleAmt == 1) {
SDValue NewBase = DAG.getNode(ISD::ADD, DL, PtrVT, Base, Splat);
SDValue NewIndex = Index.getOperand(1 - I);
- return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG);
+ return rebuildGatherScatter(GorS, NewIndex, NewBase, Mask, Scale,
+ DAG);
}
}
}
@@ -57539,7 +57579,8 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
SDValue NewIndex = DAG.getNode(ISD::ADD, DL, IndexVT,
Index.getOperand(1 - I), Splat);
SDValue NewBase = DAG.getConstant(0, DL, PtrVT);
- return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG);
+ return rebuildGatherScatter(GorS, NewIndex, NewBase, Mask, Scale,
+ DAG);
}
}
}
@@ -57550,12 +57591,25 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
MVT EltVT = IndexWidth > 32 ? MVT::i64 : MVT::i32;
IndexVT = IndexVT.changeVectorElementType(*DAG.getContext(), EltVT);
Index = DAG.getSExtOrTrunc(Index, DL, IndexVT);
- return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
+ return rebuildGatherScatter(GorS, Index, Base, Mask, Scale, DAG);
+ }
+
+ // Attempt to convert a (vXi1 bitcast(iX Mask)) mask before it might get
+ // split by legalization.
+ if (GorS->getOpcode() == ISD::MGATHER && Mask.getOpcode() == ISD::BITCAST &&
+ Mask.getScalarValueSizeInBits() == 1 && !Subtarget.hasAVX512() &&
+ TLI.isOperationLegalOrCustom(ISD::MGATHER, N->getValueType(0))) {
+ EVT MaskVT = Mask.getValueType();
+ EVT ExtMaskVT = N->getValueType(0).changeVectorElementTypeToInteger();
+ if (SDValue ExtMask = combineToExtendBoolVectorInReg(
+ ISD::SIGN_EXTEND, DL, ExtMaskVT, Mask, DAG, DCI, Subtarget)) {
+ ExtMask = DAG.getNode(ISD::TRUNCATE, DL, MaskVT, ExtMask);
+ return rebuildGatherScatter(GorS, Index, Base, ExtMask, Scale, DAG);
+ }
}
}
// With vector masks we only demand the upper bit of the mask.
- SDValue Mask = GorS->getMask();
if (Mask.getScalarValueSizeInBits() != 1) {
APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits()));
if (TLI.SimplifyDemandedBits(Mask, DemandedMask, DCI)) {
@@ -61700,7 +61754,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
case X86ISD::MGATHER:
case X86ISD::MSCATTER: return combineX86GatherScatter(N, DAG, DCI);
case ISD::MGATHER:
- case ISD::MSCATTER: return combineGatherScatter(N, DAG, DCI);
+ case ISD::MSCATTER: return combineGatherScatter(N, DAG, DCI, Subtarget);
case X86ISD::PCMPEQ:
case X86ISD::PCMPGT: return combineVectorCompare(N, DAG, Subtarget);
case X86ISD::PMULDQ:
diff --git a/llvm/test/CodeGen/X86/masked_gather.ll b/llvm/test/CodeGen/X86/masked_gather.ll
index 962ff66b072a6..2913fe13095ca 100644
--- a/llvm/test/CodeGen/X86/masked_gather.ll
+++ b/llvm/test/CodeGen/X86/masked_gather.ll
@@ -312,27 +312,11 @@ define <4 x float> @masked_gather_v4f32_ptr_v4i32(<4 x ptr> %ptr, i32 %trigger,
;
; AVX2-GATHER-LABEL: masked_gather_v4f32_ptr_v4i32:
; AVX2-GATHER: # %bb.0:
-; AVX2-GATHER-NEXT: movl %edi, %eax
-; AVX2-GATHER-NEXT: andl $1, %eax
-; AVX2-GATHER-NEXT: negl %eax
-; AVX2-GATHER-NEXT: vmovd %eax, %xmm2
-; AVX2-GATHER-NEXT: movl %edi, %eax
-; AVX2-GATHER-NEXT: shrb %al
-; AVX2-GATHER-NEXT: movzbl %al, %eax
-; AVX2-GATHER-NEXT: andl $1, %eax
-; AVX2-GATHER-NEXT: negl %eax
-; AVX2-GATHER-NEXT: vpinsrd $1, %eax, %xmm2, %xmm2
-; AVX2-GATHER-NEXT: movl %edi, %eax
-; AVX2-GATHER-NEXT: shrb $2, %al
-; AVX2-GATHER-NEXT: movzbl %al, %eax
-; AVX2-GATHER-NEXT: andl $1, %eax
-; AVX2-GATHER-NEXT: negl %eax
-; AVX2-GATHER-NEXT: vpinsrd $2, %eax, %xmm2, %xmm2
-; AVX2-GATHER-NEXT: andb $8, %dil
-; AVX2-GATHER-NEXT: shrb $3, %dil
-; AVX2-GATHER-NEXT: movzbl %dil, %eax
-; AVX2-GATHER-NEXT: negl %eax
-; AVX2-GATHER-NEXT: vpinsrd $3, %eax, %xmm2, %xmm2
+; AVX2-GATHER-NEXT: vmovd %edi, %xmm2
+; AVX2-GATHER-NEXT: vpbroadcastd %xmm2, %xmm2
+; AVX2-GATHER-NEXT: vpmovsxbd {{.*#+}} xmm3 = [1,2,4,8]
+; AVX2-GATHER-NEXT: vpand %xmm3, %xmm2, %xmm2
+; AVX2-GATHER-NEXT: vpcmpeqd %xmm3, %xmm2, %xmm2
; AVX2-GATHER-NEXT: vgatherqps %xmm2, (,%ymm0), %xmm1
; AVX2-GATHER-NEXT: vmovaps %xmm1, %xmm0
; AVX2-GATHER-NEXT: vzeroupper
@@ -2575,51 +2559,11 @@ define <8 x i32> @masked_gather_v8i32_v8i32(i8 %trigger) {
;
; AVX2-GATHER-LABEL: masked_gather_v8i32_v8i32:
; AVX2-GATHER: # %bb.0:
-; AVX2-GATHER-NEXT: movl %edi, %eax
-; AVX2-GATHER-NEXT: shrb $5, %al
-; AVX2-GATHER-NEXT: movzbl %al, %eax
-; AVX2-GATHER-NEXT: andl $1, %eax
-; AVX2-GATHER-NEXT: negl %eax
-; AVX2-GATHER-NEXT: movl %edi, %ecx
-; AVX2-GATHER-NEXT: shrb $4, %cl
-; AVX2-GATHER-NEXT: movzbl %cl, %ecx
-; AVX2-GATHER-NEXT: andl $1, %ecx
-; AVX2-GATHER-NEXT: negl %ecx
-; AVX2-GATHER-NEXT: vmovd %ecx, %xmm0
-; AVX2-GATHER-NEXT: vpinsrd $1, %eax, %xmm0, %xmm0
-; AVX2-GATHER-NEXT: movl %edi, %eax
-; AVX2-GATHER-NEXT: shrb $6, %al
-; AVX2-GATHER-NEXT: movzbl %al, %eax
-; AVX2-GATHER-NEXT: andl $1, %eax
-; AVX2-GATHER-NEXT: negl %eax
-; AVX2-GATHER-NEXT: vpinsrd $2, %eax, %xmm0, %xmm0
-; AVX2-GATHER-NEXT: movl %edi, %eax
-; AVX2-GATHER-NEXT: shrb $7, %al
-; AVX2-GATHER-NEXT: movzbl %al, %eax
-; AVX2-GATHER-NEXT: negl %eax
-; AVX2-GATHER-NEXT: vpinsrd $3, %eax, %xmm0, %xmm0
-; AVX2-GATHER-NEXT: movl %edi, %eax
-; AVX2-GATHER-NEXT: andl $1, %eax
-; AVX2-GATHER-NEXT: negl %eax
-; AVX2-GATHER-NEXT: vmovd %eax, %xmm1
-; AVX2-GATHER-NEXT: movl %edi, %eax
-; AVX2-GATHER-NEXT: shrb %al
-; AVX2-GATHER-NEXT: movzbl %al, %eax
-; AVX2-GATHER-NEXT: andl $1, %eax
-; AVX2-GATHER-NEXT: negl %eax
-; AVX2-GATHER-NEXT: vpinsrd $1, %eax, %xmm1, %xmm1
-; AVX2-GATHER-NEXT: movl %edi, %eax
-; AVX2-GATHER-NEXT: shrb $2, %al
-; AVX2-GATHER-NEXT: movzbl %al, %eax
-; AVX2-GATHER-NEXT: andl $1, %eax
-; AVX2-GATHER-NEXT: negl %eax
-; AVX2-GATHER-NEXT: vpinsrd $2, %eax, %xmm1, %xmm1
-; AVX2-GATHER-NEXT: shrb $3, %dil
-; AVX2-GATHER-NEXT: movzbl %dil, %eax
-; AVX2-GATHER-NEXT: andl $1, %eax
-; AVX2-GATHER-NEXT: negl %eax
-; AVX2-GATHER-NEXT: vpinsrd $3, %eax, %xmm1, %xmm1
-; AVX2-GATHER-NEXT: vinserti128 $1, %xmm0, %ymm1, %ymm0
+; AVX2-GATHER-NEXT: vmovd %edi, %xmm0
+; AVX2-GATHER-NEXT: vpbroadcastb %xmm0, %ymm0
+; AVX2-GATHER-NEXT: vpmovzxbd {{.*#+}} ymm1 = [1,2,4,8,16,32,64,128]
+; AVX2-GATHER-NEXT: vpand %ymm1, %ymm0, %ymm0
+; AVX2-GATHER-NEXT: vpcmpeqd %ymm1, %ymm0, %ymm0
; AVX2-GATHER-NEXT: vpxor %xmm1, %xmm1, %xmm1
; AVX2-GATHER-NEXT: vmovdqa %ymm0, %ymm2
; AVX2-GATHER-NEXT: vpxor %xmm3, %xmm3, %xmm3
diff --git a/llvm/test/CodeGen/X86/masked_load.ll b/llvm/test/CodeGen/X86/masked_load.ll
index 672ec4038d235..99a8918fef93f 100644
--- a/llvm/test/CodeGen/X86/masked_load.ll
+++ b/llvm/test/CodeGen/X86/masked_load.ll
@@ -113,21 +113,27 @@ define <2 x double> @load_v2f64_i2(i2 %trigger, ptr %addr, <2 x double> %dst) {
; SSE-NEXT: movhps {{.*#+}} xmm0 = xmm0[0,1],mem[0,1]
; SSE-NEXT: retq
;
-; AVX1OR2-LABEL: load_v2f64_i2:
-; AVX1OR2: ## %bb.0:
-; AVX1OR2-NEXT: movl %edi, %eax
-; AVX1OR2-NEXT: andl $1, %eax
-; AVX1OR2-NEXT: negq %rax
-; AVX1OR2-NEXT: vmovq %rax, %xmm1
-; AVX1OR2-NEXT: andb $2, %dil
-; AVX1OR2-NEXT: shrb %dil
-; AVX1OR2-NEXT: movzbl %dil, %eax
-; AVX1OR2-NEXT: negq %rax
-; AVX1OR2-NEXT: vmovq %rax, %xmm2
-; AVX1OR2-NEXT: vpunpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm2[0]
-; AVX1OR2-NEXT: vmaskmovpd (%rsi), %xmm1, %xmm2
-; AVX1OR2-NEXT: vblendvpd %xmm1, %xmm2, %xmm0, %xmm0
-; AVX1OR2-NEXT: retq
+; AVX1-LABEL: load_v2f64_i2:
+; AVX1: ## %bb.0:
+; AVX1-NEXT: vmovd %edi, %xmm1
+; AVX1-NEXT: vpshufd {{.*#+}} xmm1 = xmm1[0,1,0,1]
+; AVX1-NEXT: vpmovsxbq {{.*#+}} xmm2 = [1,2]
+; AVX1-NEXT: vpand %xmm2, %xmm1, %xmm1
+; AVX1-NEXT: vpcmpeqq %xmm2, %xmm1, %xmm1
+; AVX1-NEXT: vmaskmovpd (%rsi), %xmm1, %xmm2
+; AVX1-NEXT: vblendvpd %xmm1, %xmm2, %xmm0, %xmm0
+; AVX1-NEXT: retq
+;
+; AVX2-LABEL: load_v2f64_i2:
+; AVX2: ## %bb.0:
+; AVX2-NEXT: vmovd %edi, %xmm1
+; AVX2-NEXT: vpbroadcastd %xmm1, %xmm1
+; AVX2-NEXT: vpmovsxbq {{.*#+}} xmm2 = [1,2]
+; AVX2-NEXT: vpand %xmm2, %xmm1, %xmm1
+; AVX2-NEXT: vpcmpeqq %xmm2, %xmm1, %xmm1
+; AVX2-NEXT: vmaskmovpd (%rsi), %xmm1, %xmm2
+; AVX2-NEXT: vblendvpd %xmm1, %xmm2, %xmm0, %xmm0
+; AVX2-NEXT: retq
;
; AVX512F-LABEL: load_v2f64_i2:
; AVX512F: ## %bb.0:
@@ -281,29 +287,14 @@ define <4 x double> @load_v4f64_i4(i4 %trigger, ptr %addr, <4 x double> %dst) {
;
; AVX1-LABEL: load_v4f64_i4:
; AVX1: ## %bb.0:
-; AVX1-NEXT: movl %edi, %eax
-; AVX1-NEXT: andl $1, %eax
-; AVX1-NEXT: negl %eax
-; AVX1-NEXT: vmovd %eax, %xmm1
-; AVX1-NEXT: movl %edi, %eax
-; AVX1-NEXT: shrb %al
-; AVX1-NEXT: movzbl %al, %eax
-; AVX1-NEXT: andl $1, %eax
-; AVX1-NEXT: negl %eax
-; AVX1-NEXT: vpinsrd $1, %eax, %xmm1, %xmm1
-; AVX1-NEXT: vpmovsxdq %xmm1, %xmm2
-; AVX1-NEXT: movl %edi, %eax
-; AVX1-NEXT: shrb $2, %al
-; AVX1-NEXT: movzbl %al, %eax
-; AVX1-NEXT: andl $1, %eax
-; AVX1-NEXT: negl %eax
-; AVX1-NEXT: vpinsrd $2, %eax, %xmm1, %xmm1
-; AVX1-NEXT: andb $8, %dil
-; AVX1-NEXT: shrb $3, %dil
-; AVX1-NEXT: movzbl %dil, %eax
-; AVX1-NEXT: negl %eax
-; AVX1-NEXT: vpinsrd $3, %eax, %xmm1, %xmm1
-; AVX1-NEXT: vpshufd {{.*#+}} xmm1 = xmm1[2,2,3,3]
+; AVX1-NEXT: vmovd %edi, %xmm1
+; AVX1-NEXT: vpshufd {{.*#+}} xmm1 = xmm1[0,1,0,1]
+; AVX1-NEXT: vinsertf128 $1, %xmm1, %ymm1, %ymm1
+; AVX1-NEXT: vmovaps {{.*#+}} ymm2 = [1,2,4,8]
+; AVX1-NEXT: vandps %ymm2, %ymm1, %ymm1
+; AVX1-NEXT: vpcmpeqq %xmm2, %xmm1, %xmm2
+; AVX1-NEXT: vextractf128 $1, %ymm1, %xmm1
+; AVX1-NEXT: vpcmpeqq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
; AVX1-NEXT: vinsertf128 $1, %xmm1, %ymm2, %ymm1
; AVX1-NEXT: vmaskmovpd (%rsi), %ymm1, %ymm2
; AVX1-NEXT: vblendvpd %ymm1, %ymm2, %ymm0, %ymm0
@@ -311,30 +302,11 @@ define <4 x double> @load_v4f64_i4(i4 %trigger, ptr %addr, <4 x double> %dst) {
;
; AVX2-LABEL: load_v4f64_i4:
; AVX2: ## %bb.0:
-; AVX2-NEXT: movl %edi, %eax
-; AVX2-NEXT: andb $8, %al
-; AVX2-NEXT: shrb $3, %al
-; AVX2-NEXT: movzbl %al, %eax
-; AVX2-NEXT: negq %rax
-; AVX2-NEXT: vmovq %rax, %xmm1
-; AVX2-NEXT: movl %edi, %eax
-; AVX2-NEXT: shrb $2, %al
-; AVX2-NEXT: movzbl %al, %eax
-; AVX2-NEXT: andl $1, %eax
-; AVX2-NEXT: negq %rax
-; AVX2-NEXT: vmovq %rax, %xmm2
-; AVX2-NEXT: vpunpcklqdq {{.*#+}} xmm1 = xmm2[0],xmm1[0]
-; AVX2-NEXT: movl %edi, %eax
-; AVX2-NEXT: andl $1, %eax
-; AVX2-NEXT: negq %rax
-; AVX2-NEXT: vmovq %rax, %xmm2
-; AVX2-NEXT: shrb %dil
-; AVX2-NEXT: movzbl %dil, %eax
-; AVX2-NEXT: andl $1, %eax
-; AVX2-NEXT: negq %rax
-; AVX2-NEXT: vmovq %rax, %xmm3
-; AVX2-NEXT: vpunpcklqdq {{.*#+}} xmm2 = xmm2[0],xmm3[0]
-; AVX2-NEXT: vinserti128 $1, %xmm1, %ymm2, %ymm1
+; AVX2-NEXT: vmovd %edi, %xmm1
+; AVX2-NEXT: vpbroadcastd %xmm1, %ymm1
+; AVX2-NEXT: vpmovsxbq {{.*#+}} ymm2 = [1,2,4,8]
+; AVX2-NEXT: vpand %ymm2, %ymm1, %ymm1
+; AVX2-NEXT: vpcmpeqq %ymm2, %ymm1, %ymm1
; AVX2-NEXT: vmaskmovpd (%rsi), %ymm1, %ymm2
; AVX2-NEXT: vblendvpd %ymm1, %ymm2, %ymm0, %ymm0
; AVX2-NEXT: retq
@@ -1552,32 +1524,27 @@ define <4 x float> @load_v4f32_i4(i4 %trigger, ptr %addr, <4 x float> %dst) {
; SSE42-NEXT: insertps {{.*#+}} xmm0 = xmm0[0,1,2],mem[0]
; SSE42-NEXT: retq
;
-; AVX1OR2-LABEL: load_v4f32_i4:
-; AVX1OR2: ## %bb.0:
-; AVX1OR2-NEXT: movl %edi, %eax
-; AVX1OR2-NEXT: andl $1, %eax
-; AVX1OR2-NEXT: negl %eax
-; AVX1OR2-NEXT: vmovd %eax, %xmm1
-; AVX1OR2-NEXT: movl %edi, %eax
-; AVX1OR2-NEXT: shrb %al
-; AVX1OR2-NEXT: movzbl %al, %eax
-; AVX1OR2-NEXT: andl $1, %eax
-; AVX1OR2-NEXT: negl %eax
-; AVX1OR2-NEXT: vpinsrd $1, %eax, %xmm1, %xmm1
-; AVX1OR2-NEXT: movl %edi, %eax
-; AVX1OR2-NEXT: shrb $2, %al
-; AVX1OR2-NEXT: movzbl %al, %eax
-; AVX1OR2-NEXT: andl $1, %eax
-; AVX1OR2-NEXT: negl %eax
-; AVX1OR2-NEXT: vpinsrd $2, %eax, %xmm1, %xmm1
-; AVX1OR2-NEXT:...
[truncated]
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
|
|
||
| // Attempt to convert a (vXi1 bitcast(iX Mask)) mask before it might get | ||
| // split by legalization. | ||
| if (GorS->getOpcode() == ISD::MGATHER && Mask.getOpcode() == ISD::BITCAST && |
There was a problem hiding this comment.
Why only limit to gather here?
There was a problem hiding this comment.
We're only targeting AVX2, as it has no mask registers or scatter instruction
There was a problem hiding this comment.
I've removed the explicit Opcode check for ISD::MGATHER and now rely on canonicalizeBoolMask to early-out because ISD::MSCATTER isn't legal/custom on pre-AVX512 targets
|
|
||
| // Attempt to convert a (vXi1 bitcast(iX Mask)) mask before it might get split | ||
| // by legalization. | ||
| if (DCI.isBeforeLegalizeOps() && Mask.getOpcode() == ISD::BITCAST && |
There was a problem hiding this comment.
The conditions and below code are similar, is it possible to move them into a common function?
There was a problem hiding this comment.
Sure - I've added a canonicalizeBoolMask common helper
| // by legalization. | ||
| if (SDValue NewMask = | ||
| canonicalizeBoolMask(ISD::MLOAD, VT, Mask, DL, DAG, DCI, Subtarget)) { | ||
| NewMask = DAG.getNode(ISD::TRUNCATE, DL, Mask.getValueType(), NewMask); |
There was a problem hiding this comment.
This can be moved into canonicalizeBoolMask too.
| assert(ExtMaskVT.bitsGT(MaskVT) && "Unexpected extension type"); | ||
| SDValue NewMask = combineToExtendBoolVectorInReg( | ||
| ISD::SIGN_EXTEND, DL, ExtMaskVT, Mask, DAG, DCI, Subtarget); | ||
| return DAG.getNode(ISD::TRUNCATE, DL, MaskVT, NewMask); |
There was a problem hiding this comment.
Should we check NewMask before truncate?
…or extension before it might get split by legalisation (llvm#175769) Masked load/store/gathers often need to bitcast the mask from a bitcasted integer. On pre-AVX512 targets this can lead to some rather nasty scalarization if we don't custom expand the mask first. This patch uses the canonicalizeBoolMask /combineToExtendBoolVectorInReg helper functions to canonicalise the masks, similar to what we already do for vselect expansion. Alternative to llvm#175385 Fixes llvm#59789
…or extension before it might get split by legalisation (llvm#175769) Masked load/store/gathers often need to bitcast the mask from a bitcasted integer. On pre-AVX512 targets this can lead to some rather nasty scalarization if we don't custom expand the mask first. This patch uses the canonicalizeBoolMask /combineToExtendBoolVectorInReg helper functions to canonicalise the masks, similar to what we already do for vselect expansion. Alternative to llvm#175385 Fixes llvm#59789
Using code/ideas from the x86 backend to optimize a select on a bitcast integer. The previous aarch64 approach was to individually extract the bits from the mask, which is kind of terrible. https://rust.godbolt.org/z/576sndT66 ```llvm define void @if_then_else8(ptr %out, i8 %mask, ptr %if_true, ptr %if_false) { start: %t = load <8 x i32>, ptr %if_true, align 4 %f = load <8 x i32>, ptr %if_false, align 4 %m = bitcast i8 %mask to <8 x i1> %s = select <8 x i1> %m, <8 x i32> %t, <8 x i32> %f store <8 x i32> %s, ptr %out, align 4 ret void } ``` turned into ```asm if_then_else8: // @if_then_else8 sub sp, sp, #16 ubfx w8, w1, #4, #1 and w11, w1, #0x1 ubfx w9, w1, #5, #1 fmov s1, w11 ubfx w10, w1, #1, #1 fmov s0, w8 ubfx w8, w1, #6, #1 ldp q5, q2, [x3] mov v1.h[1], w10 ldp q4, q3, [x2] mov v0.h[1], w9 ubfx w9, w1, #2, #1 mov v1.h[2], w9 ubfx w9, w1, #3, #1 mov v0.h[2], w8 ubfx w8, w1, #7, #1 mov v1.h[3], w9 mov v0.h[3], w8 ushll v1.4s, v1.4h, #0 ushll v0.4s, v0.4h, #0 shl v1.4s, v1.4s, #31 shl v0.4s, v0.4s, #31 cmlt v1.4s, v1.4s, #0 cmlt v0.4s, v0.4s, #0 bsl v1.16b, v4.16b, v5.16b bsl v0.16b, v3.16b, v2.16b stp q1, q0, [x0] add sp, sp, #16 ret ``` With this PR that instead emits ```asm if_then_else8: adrp x8, .LCPI0_1 dup v0.4s, w1 ldr q1, [x8, :lo12:.LCPI0_1] adrp x8, .LCPI0_0 ldr q2, [x8, :lo12:.LCPI0_0] ldp q4, q3, [x2] and v1.16b, v0.16b, v1.16b and v0.16b, v0.16b, v2.16b ldp q5, q2, [x3] cmeq v1.4s, v1.4s, #0 cmeq v0.4s, v0.4s, #0 bsl v1.16b, v2.16b, v3.16b bsl v0.16b, v5.16b, v4.16b stp q0, q1, [x0] ret ``` So substantially shorter. Instead of building the mask element-by-element, this approach (by virtue of not splitting) instead splats the mask value into all vector lanes, performs a bitwise and with powers of 2, and compares with zero to construct the mask vector. cc rust-lang/rust#122376 cc #175769
Using code/ideas from the x86 backend to optimize a select on a bitcast integer. The previous aarch64 approach was to individually extract the bits from the mask, which is kind of terrible. https://rust.godbolt.org/z/576sndT66 ```llvm define void @if_then_else8(ptr %out, i8 %mask, ptr %if_true, ptr %if_false) { start: %t = load <8 x i32>, ptr %if_true, align 4 %f = load <8 x i32>, ptr %if_false, align 4 %m = bitcast i8 %mask to <8 x i1> %s = select <8 x i1> %m, <8 x i32> %t, <8 x i32> %f store <8 x i32> %s, ptr %out, align 4 ret void } ``` turned into ```asm if_then_else8: // @if_then_else8 sub sp, sp, #16 ubfx w8, w1, #4, #1 and w11, w1, #0x1 ubfx w9, w1, #5, #1 fmov s1, w11 ubfx w10, w1, #1, #1 fmov s0, w8 ubfx w8, w1, #6, #1 ldp q5, q2, [x3] mov v1.h[1], w10 ldp q4, q3, [x2] mov v0.h[1], w9 ubfx w9, w1, #2, #1 mov v1.h[2], w9 ubfx w9, w1, #3, #1 mov v0.h[2], w8 ubfx w8, w1, #7, #1 mov v1.h[3], w9 mov v0.h[3], w8 ushll v1.4s, v1.4h, #0 ushll v0.4s, v0.4h, #0 shl v1.4s, v1.4s, #31 shl v0.4s, v0.4s, #31 cmlt v1.4s, v1.4s, #0 cmlt v0.4s, v0.4s, #0 bsl v1.16b, v4.16b, v5.16b bsl v0.16b, v3.16b, v2.16b stp q1, q0, [x0] add sp, sp, #16 ret ``` With this PR that instead emits ```asm if_then_else8: adrp x8, .LCPI0_1 dup v0.4s, w1 ldr q1, [x8, :lo12:.LCPI0_1] adrp x8, .LCPI0_0 ldr q2, [x8, :lo12:.LCPI0_0] ldp q4, q3, [x2] and v1.16b, v0.16b, v1.16b and v0.16b, v0.16b, v2.16b ldp q5, q2, [x3] cmeq v1.4s, v1.4s, #0 cmeq v0.4s, v0.4s, #0 bsl v1.16b, v2.16b, v3.16b bsl v0.16b, v5.16b, v4.16b stp q0, q1, [x0] ret ``` So substantially shorter. Instead of building the mask element-by-element, this approach (by virtue of not splitting) instead splats the mask value into all vector lanes, performs a bitwise and with powers of 2, and compares with zero to construct the mask vector. cc rust-lang/rust#122376 cc llvm/llvm-project#175769
Using code/ideas from the x86 backend to optimize a select on a bitcast integer. The previous aarch64 approach was to individually extract the bits from the mask, which is kind of terrible. https://rust.godbolt.org/z/576sndT66 ```llvm define void @if_then_else8(ptr %out, i8 %mask, ptr %if_true, ptr %if_false) { start: %t = load <8 x i32>, ptr %if_true, align 4 %f = load <8 x i32>, ptr %if_false, align 4 %m = bitcast i8 %mask to <8 x i1> %s = select <8 x i1> %m, <8 x i32> %t, <8 x i32> %f store <8 x i32> %s, ptr %out, align 4 ret void } ``` turned into ```asm if_then_else8: // @if_then_else8 sub sp, sp, #16 ubfx w8, w1, #4, #1 and w11, w1, #0x1 ubfx w9, w1, #5, #1 fmov s1, w11 ubfx w10, w1, #1, #1 fmov s0, w8 ubfx w8, w1, #6, #1 ldp q5, q2, [x3] mov v1.h[1], w10 ldp q4, q3, [x2] mov v0.h[1], w9 ubfx w9, w1, #2, #1 mov v1.h[2], w9 ubfx w9, w1, #3, #1 mov v0.h[2], w8 ubfx w8, w1, #7, #1 mov v1.h[3], w9 mov v0.h[3], w8 ushll v1.4s, v1.4h, #0 ushll v0.4s, v0.4h, #0 shl v1.4s, v1.4s, #31 shl v0.4s, v0.4s, #31 cmlt v1.4s, v1.4s, #0 cmlt v0.4s, v0.4s, #0 bsl v1.16b, v4.16b, v5.16b bsl v0.16b, v3.16b, v2.16b stp q1, q0, [x0] add sp, sp, #16 ret ``` With this PR that instead emits ```asm if_then_else8: adrp x8, .LCPI0_1 dup v0.4s, w1 ldr q1, [x8, :lo12:.LCPI0_1] adrp x8, .LCPI0_0 ldr q2, [x8, :lo12:.LCPI0_0] ldp q4, q3, [x2] and v1.16b, v0.16b, v1.16b and v0.16b, v0.16b, v2.16b ldp q5, q2, [x3] cmeq v1.4s, v1.4s, #0 cmeq v0.4s, v0.4s, #0 bsl v1.16b, v2.16b, v3.16b bsl v0.16b, v5.16b, v4.16b stp q0, q1, [x0] ret ``` So substantially shorter. Instead of building the mask element-by-element, this approach (by virtue of not splitting) instead splats the mask value into all vector lanes, performs a bitwise and with powers of 2, and compares with zero to construct the mask vector. cc rust-lang/rust#122376 cc llvm#175769
Masked load/store/gathers often need to bitcast the mask from a bitcasted integer.
On pre-AVX512 targets this can lead to some rather nasty scalarization if we don't custom expand the mask first.
This patch uses the canonicalizeBoolMask /combineToExtendBoolVectorInReg helper functions to canonicalise the masks, similar to what we already do for vselect expansion.
Alternative to #175385
Fixes #59789