From e840baec19387e935188215cc2c3e76e22b9e118 Mon Sep 17 00:00:00 2001 From: hanbeom Date: Tue, 16 Jul 2024 14:14:40 +0900 Subject: [PATCH] [RISCV] Support saturated truncate Add support for `ISD::TRUNCATE_[US]SAT`. --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 24 +++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index d40d4997d76149f..ead79ec47becee3 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -853,7 +853,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, // Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR_VL" // nodes which truncate by one power of two at a time. - setOperationAction(ISD::TRUNCATE, VT, Custom); + setOperationAction({ISD::TRUNCATE, ISD::TRUNCATE_SSAT_S, + ISD::TRUNCATE_SSAT_U, ISD::TRUNCATE_USAT_U}, + VT, Custom); // Custom-lower insert/extract operations to simplify patterns. setOperationAction({ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT}, VT, @@ -1168,7 +1170,9 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setOperationAction(ISD::SELECT, VT, Custom); - setOperationAction(ISD::TRUNCATE, VT, Custom); + setOperationAction({ISD::TRUNCATE, ISD::TRUNCATE_SSAT_S, + ISD::TRUNCATE_SSAT_U, ISD::TRUNCATE_USAT_U}, + VT, Custom); setOperationAction(ISD::BITCAST, VT, Custom); @@ -6395,6 +6399,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, return DAG.getNode(RISCVISD::BREV8, DL, VT, BSwap); } case ISD::TRUNCATE: + case ISD::TRUNCATE_SSAT_S: + case ISD::TRUNCATE_SSAT_U: + case ISD::TRUNCATE_USAT_U: // Only custom-lower vector truncates if (!Op.getSimpleValueType().isVector()) return Op; @@ -8234,7 +8241,8 @@ SDValue RISCVTargetLowering::lowerVectorMaskTruncLike(SDValue Op, SDValue RISCVTargetLowering::lowerVectorTruncLike(SDValue Op, SelectionDAG &DAG) const { - bool IsVPTrunc = Op.getOpcode() == ISD::VP_TRUNCATE; + unsigned Opc = Op.getOpcode(); + bool IsVPTrunc = Opc == ISD::VP_TRUNCATE; SDLoc DL(Op); MVT VT = Op.getSimpleValueType(); @@ -8279,10 +8287,18 @@ SDValue RISCVTargetLowering::lowerVectorTruncLike(SDValue Op, getDefaultVLOps(SrcVT, ContainerVT, DL, DAG, Subtarget); } + unsigned NewOpc; + if (Opc == ISD::TRUNCATE_SSAT_S) + NewOpc = RISCVISD::TRUNCATE_VECTOR_VL_SSAT; + else if (Opc == ISD::TRUNCATE_SSAT_U || Opc == ISD::TRUNCATE_USAT_U) + NewOpc = RISCVISD::TRUNCATE_VECTOR_VL_USAT; + else + NewOpc = RISCVISD::TRUNCATE_VECTOR_VL; + do { SrcEltVT = MVT::getIntegerVT(SrcEltVT.getSizeInBits() / 2); MVT ResultVT = ContainerVT.changeVectorElementType(SrcEltVT); - Result = DAG.getNode(RISCVISD::TRUNCATE_VECTOR_VL, DL, ResultVT, Result, + Result = DAG.getNode(NewOpc, DL, ResultVT, Result, Mask, VL); } while (SrcEltVT != DstEltVT);