[RISCV] Match strided load via DAG combine #66800
Conversation
This change matches a masked.strided.load from an mgather node whose index operand is a strided sequence. We can reuse the VID matching from build_vector lowering for this purpose. Note that this duplicates the matching done at the IR level by RISCVGatherScatterLowering.cpp. Now that we can widen gathers to a wider SEW, I don't see a good way to remove this duplication. The only obvious alternative is to move the widening transform to IR, but that's a no-go as I want other DAG combines to run first. I think we should just live with the duplication - particularly since the reuse of isSimpleVIDSequence means the duplication is somewhat minimal.
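For context, here is a minimal sketch (illustrative only; the function name, offsets, and alignment are not taken from the patch or its tests) of the shape of gather this combine targets: the index operand is a strided build_vector, so isSimpleVIDSequence recognizes it and the gather can be emitted as a strided load (vlse) instead of an indexed load (vluxei).

```llvm
; Sketch: byte offsets 0, 8, 16, 24 form a VID sequence with step 8 and
; addend 0, so the gather is equivalent to a stride-8 load from %base.
define <4 x i32> @strided_gather_sketch(ptr %base) {
  %ptrs = getelementptr i8, ptr %base, <4 x i64> <i64 0, i64 8, i64 16, i64 24>
  %v = call <4 x i32> @llvm.masked.gather.v4i32.v4p0(
           <4 x ptr> %ptrs, i32 4,
           <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x i32> poison)
  ret <4 x i32> %v
}
declare <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr>, i32, <4 x i1>, <4 x i32>)
```

A gather written directly like this is normally already handled at the IR level by RISCVGatherScatterLowering; the point of the DAG combine is to catch equivalent patterns that only materialize during SelectionDAG, for example after a gather has been widened to a wider SEW.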
@llvm/pr-subscribers-backend-risc-v

Full diff: https://github.com/llvm/llvm-project/pull/66800.diff

2 Files Affected:
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 0214bd1d7dda326..f1cea6c6756f4fc 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -14055,6 +14055,35 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
MGN->getBasePtr(), Index, ScaleOp},
MGN->getMemOperand(), IndexType, MGN->getExtensionType());
+ if (Index.getOpcode() == ISD::BUILD_VECTOR &&
+ MGN->getExtensionType() == ISD::NON_EXTLOAD) {
+ if (std::optional<VIDSequence> SimpleVID = isSimpleVIDSequence(Index);
+ SimpleVID && SimpleVID->StepDenominator == 1) {
+ const int64_t StepNumerator = SimpleVID->StepNumerator;
+ const int64_t Addend = SimpleVID->Addend;
+
+ // Note: We don't need to check alignment here since (by assumption
+ // from the existence of the gather), our offsets must be sufficiently
+ // aligned.
+
+ const EVT PtrVT = getPointerTy(DAG.getDataLayout());
+ assert(MGN->getBasePtr()->getValueType(0) == PtrVT);
+ assert(IndexType == ISD::UNSIGNED_SCALED);
+ SDValue BasePtr = DAG.getNode(ISD::ADD, DL, PtrVT, MGN->getBasePtr(),
+ DAG.getConstant(Addend, DL, PtrVT));
+
+ SDVTList VTs = DAG.getVTList({VT, MVT::Other});
+ SDValue IntID =
+ DAG.getTargetConstant(Intrinsic::riscv_masked_strided_load, DL,
+ XLenVT);
+ SDValue Ops[] =
+ {MGN->getChain(), IntID, MGN->getPassThru(), BasePtr,
+ DAG.getConstant(StepNumerator, DL, XLenVT), MGN->getMask()};
+ return DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs,
+ Ops, VT, MGN->getMemOperand());
+ }
+ }
+
SmallVector<int> ShuffleMask;
if (MGN->getExtensionType() == ISD::NON_EXTLOAD &&
matchIndexAsShuffle(VT, Index, MGN->getMask(), ShuffleMask)) {
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
index 813e16952eca33c..49724cbc4418252 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
@@ -13567,20 +13567,16 @@ define <8 x i16> @mgather_strided_unaligned(ptr %base) {
define <8 x i16> @mgather_strided_2xSEW(ptr %base) {
; RV32-LABEL: mgather_strided_2xSEW:
; RV32: # %bb.0:
-; RV32-NEXT: vsetivli zero, 4, e8, mf4, ta, ma
-; RV32-NEXT: vid.v v8
-; RV32-NEXT: vsll.vi v9, v8, 3
-; RV32-NEXT: vsetvli zero, zero, e32, m1, ta, ma
-; RV32-NEXT: vluxei8.v v8, (a0), v9
+; RV32-NEXT: li a1, 8
+; RV32-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; RV32-NEXT: vlse32.v v8, (a0), a1
; RV32-NEXT: ret
;
; RV64V-LABEL: mgather_strided_2xSEW:
; RV64V: # %bb.0:
-; RV64V-NEXT: vsetivli zero, 4, e8, mf4, ta, ma
-; RV64V-NEXT: vid.v v8
-; RV64V-NEXT: vsll.vi v9, v8, 3
-; RV64V-NEXT: vsetvli zero, zero, e32, m1, ta, ma
-; RV64V-NEXT: vluxei8.v v8, (a0), v9
+; RV64V-NEXT: li a1, 8
+; RV64V-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; RV64V-NEXT: vlse32.v v8, (a0), a1
; RV64V-NEXT: ret
;
; RV64ZVE32F-LABEL: mgather_strided_2xSEW:
@@ -13684,22 +13680,18 @@ define <8 x i16> @mgather_strided_2xSEW(ptr %base) {
define <8 x i16> @mgather_strided_2xSEW_with_offset(ptr %base) {
; RV32-LABEL: mgather_strided_2xSEW_with_offset:
; RV32: # %bb.0:
-; RV32-NEXT: vsetivli zero, 4, e8, mf4, ta, ma
-; RV32-NEXT: vid.v v8
-; RV32-NEXT: vsll.vi v8, v8, 3
-; RV32-NEXT: vadd.vi v9, v8, 4
-; RV32-NEXT: vsetvli zero, zero, e32, m1, ta, ma
-; RV32-NEXT: vluxei8.v v8, (a0), v9
+; RV32-NEXT: addi a0, a0, 4
+; RV32-NEXT: li a1, 8
+; RV32-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; RV32-NEXT: vlse32.v v8, (a0), a1
; RV32-NEXT: ret
;
; RV64V-LABEL: mgather_strided_2xSEW_with_offset:
; RV64V: # %bb.0:
-; RV64V-NEXT: vsetivli zero, 4, e8, mf4, ta, ma
-; RV64V-NEXT: vid.v v8
-; RV64V-NEXT: vsll.vi v8, v8, 3
-; RV64V-NEXT: vadd.vi v9, v8, 4
-; RV64V-NEXT: vsetvli zero, zero, e32, m1, ta, ma
-; RV64V-NEXT: vluxei8.v v8, (a0), v9
+; RV64V-NEXT: addi a0, a0, 4
+; RV64V-NEXT: li a1, 8
+; RV64V-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; RV64V-NEXT: vlse32.v v8, (a0), a1
; RV64V-NEXT: ret
;
; RV64ZVE32F-LABEL: mgather_strided_2xSEW_with_offset:
@@ -13804,20 +13796,18 @@ define <8 x i16> @mgather_strided_2xSEW_with_offset(ptr %base) {
define <8 x i16> @mgather_reverse_unit_strided_2xSEW(ptr %base) {
; RV32-LABEL: mgather_reverse_unit_strided_2xSEW:
; RV32: # %bb.0:
-; RV32-NEXT: lui a1, 65858
-; RV32-NEXT: addi a1, a1, -2020
+; RV32-NEXT: addi a0, a0, 28
+; RV32-NEXT: li a1, -4
; RV32-NEXT: vsetivli zero, 4, e32, m1, ta, ma
-; RV32-NEXT: vmv.s.x v9, a1
-; RV32-NEXT: vluxei8.v v8, (a0), v9
+; RV32-NEXT: vlse32.v v8, (a0), a1
; RV32-NEXT: ret
;
; RV64V-LABEL: mgather_reverse_unit_strided_2xSEW:
; RV64V: # %bb.0:
-; RV64V-NEXT: lui a1, 65858
-; RV64V-NEXT: addiw a1, a1, -2020
+; RV64V-NEXT: addi a0, a0, 28
+; RV64V-NEXT: li a1, -4
; RV64V-NEXT: vsetivli zero, 4, e32, m1, ta, ma
-; RV64V-NEXT: vmv.s.x v9, a1
-; RV64V-NEXT: vluxei8.v v8, (a0), v9
+; RV64V-NEXT: vlse32.v v8, (a0), a1
; RV64V-NEXT: ret
;
; RV64ZVE32F-LABEL: mgather_reverse_unit_strided_2xSEW:
@@ -13922,20 +13912,18 @@ define <8 x i16> @mgather_reverse_unit_strided_2xSEW(ptr %base) {
define <8 x i16> @mgather_reverse_strided_2xSEW(ptr %base) {
; RV32-LABEL: mgather_reverse_strided_2xSEW:
; RV32: # %bb.0:
-; RV32-NEXT: lui a1, 16577
-; RV32-NEXT: addi a1, a1, 1052
+; RV32-NEXT: addi a0, a0, 28
+; RV32-NEXT: li a1, -8
; RV32-NEXT: vsetivli zero, 4, e32, m1, ta, ma
-; RV32-NEXT: vmv.s.x v9, a1
-; RV32-NEXT: vluxei8.v v8, (a0), v9
+; RV32-NEXT: vlse32.v v8, (a0), a1
; RV32-NEXT: ret
;
; RV64V-LABEL: mgather_reverse_strided_2xSEW:
; RV64V: # %bb.0:
-; RV64V-NEXT: lui a1, 16577
-; RV64V-NEXT: addiw a1, a1, 1052
+; RV64V-NEXT: addi a0, a0, 28
+; RV64V-NEXT: li a1, -8
; RV64V-NEXT: vsetivli zero, 4, e32, m1, ta, ma
-; RV64V-NEXT: vmv.s.x v9, a1
-; RV64V-NEXT: vluxei8.v v8, (a0), v9
+; RV64V-NEXT: vlse32.v v8, (a0), a1
; RV64V-NEXT: ret
;
; RV64ZVE32F-LABEL: mgather_reverse_strided_2xSEW:
@@ -14386,20 +14374,16 @@ define <8 x i16> @mgather_gather_2xSEW_unaligned2(ptr %base) {
define <8 x i16> @mgather_gather_4xSEW(ptr %base) {
; RV32V-LABEL: mgather_gather_4xSEW:
; RV32V: # %bb.0:
-; RV32V-NEXT: vsetivli zero, 2, e8, mf8, ta, ma
-; RV32V-NEXT: vid.v v8
-; RV32V-NEXT: vsll.vi v9, v8, 4
-; RV32V-NEXT: vsetvli zero, zero, e64, m1, ta, ma
-; RV32V-NEXT: vluxei8.v v8, (a0), v9
+; RV32V-NEXT: li a1, 16
+; RV32V-NEXT: vsetivli zero, 2, e64, m1, ta, ma
+; RV32V-NEXT: vlse64.v v8, (a0), a1
; RV32V-NEXT: ret
;
; RV64V-LABEL: mgather_gather_4xSEW:
; RV64V: # %bb.0:
-; RV64V-NEXT: vsetivli zero, 2, e8, mf8, ta, ma
-; RV64V-NEXT: vid.v v8
-; RV64V-NEXT: vsll.vi v9, v8, 4
-; RV64V-NEXT: vsetvli zero, zero, e64, m1, ta, ma
-; RV64V-NEXT: vluxei8.v v8, (a0), v9
+; RV64V-NEXT: li a1, 16
+; RV64V-NEXT: vsetivli zero, 2, e64, m1, ta, ma
+; RV64V-NEXT: vlse64.v v8, (a0), a1
; RV64V-NEXT: ret
;
; RV32ZVE32F-LABEL: mgather_gather_4xSEW:
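For reference, the 2xSEW tests above are driven by gathers of roughly the following shape (a sketch inferred from the checked assembly, not copied from the test file; names and alignment are illustrative). Adjacent i16 pairs sit 8 bytes apart, SEW widening turns the access into a <4 x i32> gather whose index is a strided build_vector, and this combine then folds it into the stride-8 `vlse32.v` seen in the new checks.

```llvm
; Sketch modeled on mgather_strided_2xSEW: i16 elements (0,1), (4,5), (8,9),
; (12,13) are contiguous pairs starting at byte offsets 0, 8, 16, 24, i.e. a
; stride of 8 bytes once widened to e32. Alignment 4 is assumed so the
; widening to i32 is legal.
define <8 x i16> @strided_2xSEW_sketch(ptr %base) {
  %ptrs = getelementptr inbounds i16, ptr %base,
              <8 x i64> <i64 0, i64 1, i64 4, i64 5, i64 8, i64 9, i64 12, i64 13>
  %v = call <8 x i16> @llvm.masked.gather.v8i16.v8p0(
           <8 x ptr> %ptrs, i32 4,
           <8 x i1> <i1 true, i1 true, i1 true, i1 true,
                     i1 true, i1 true, i1 true, i1 true>,
           <8 x i16> poison)
  ret <8 x i16> %v
}
declare <8 x i16> @llvm.masked.gather.v8i16.v8p0(<8 x ptr>, i32, <8 x i1>, <8 x i16>)
```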
LGTM. I presume we don't need to worry about scalable vectors if we don't handle them in #66694