Skip to content

Commit

Permalink
[RISCV][GISEL] Add support for lowerFormalArguments that contain scal…
Browse files Browse the repository at this point in the history
…able vector types (#70882)

Scalable vector types from LLVM IR can be lowered to scalable vector
types in MIR according to the RISCVAssignFn.
  • Loading branch information
michaelmaitland authored Nov 14, 2023
1 parent 506a30d commit a7bbcc4
Show file tree
Hide file tree
Showing 9 changed files with 984 additions and 9 deletions.
3 changes: 2 additions & 1 deletion llvm/lib/CodeGen/GlobalISel/CallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ static void buildCopyFromRegs(MachineIRBuilder &B, ArrayRef<Register> OrigRegs,
if (PartLLT.isVector() == LLTy.isVector() &&
PartLLT.getScalarSizeInBits() > LLTy.getScalarSizeInBits() &&
(!PartLLT.isVector() ||
PartLLT.getNumElements() == LLTy.getNumElements()) &&
PartLLT.getElementCount() == LLTy.getElementCount()) &&
OrigRegs.size() == 1 && Regs.size() == 1) {
Register SrcReg = Regs[0];

Expand Down Expand Up @@ -406,6 +406,7 @@ static void buildCopyFromRegs(MachineIRBuilder &B, ArrayRef<Register> OrigRegs,
// If PartLLT is a mismatched vector in both number of elements and element
// size, e.g. PartLLT == v2s64 and LLTy is v3s32, then first coerce it to
// have the same elt type, i.e. v4s32.
// TODO: Extend this coersion to element multiples other than just 2.
if (PartLLT.getSizeInBits() > LLTy.getSizeInBits() &&
PartLLT.getScalarSizeInBits() == LLTy.getScalarSizeInBits() * 2 &&
Regs.size() == 1) {
Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1065,16 +1065,16 @@ void MachineIRBuilder::validateTruncExt(const LLT DstTy, const LLT SrcTy,
#ifndef NDEBUG
if (DstTy.isVector()) {
assert(SrcTy.isVector() && "mismatched cast between vector and non-vector");
assert(SrcTy.getNumElements() == DstTy.getNumElements() &&
assert(SrcTy.getElementCount() == DstTy.getElementCount() &&
"different number of elements in a trunc/ext");
} else
assert(DstTy.isScalar() && SrcTy.isScalar() && "invalid extend/trunc");

if (IsExtend)
assert(DstTy.getSizeInBits() > SrcTy.getSizeInBits() &&
assert(TypeSize::isKnownGT(DstTy.getSizeInBits(), SrcTy.getSizeInBits()) &&
"invalid narrowing extend");
else
assert(DstTy.getSizeInBits() < SrcTy.getSizeInBits() &&
assert(TypeSize::isKnownLT(DstTy.getSizeInBits(), SrcTy.getSizeInBits()) &&
"invalid widening trunc");
#endif
}
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/LowLevelType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ using namespace llvm;

LLT::LLT(MVT VT) {
if (VT.isVector()) {
bool asVector = VT.getVectorMinNumElements() > 1;
bool asVector = VT.getVectorMinNumElements() > 1 || VT.isScalableVector();
init(/*IsPointer=*/false, asVector, /*IsScalar=*/!asVector,
VT.getVectorElementCount(), VT.getVectorElementType().getSizeInBits(),
/*AddressSpace=*/0);
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/MachineVerifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,7 @@ bool MachineVerifier::verifyVectorElementMatch(LLT Ty0, LLT Ty1,
return false;
}

if (Ty0.isVector() && Ty0.getNumElements() != Ty1.getNumElements()) {
if (Ty0.isVector() && Ty0.getElementCount() != Ty1.getElementCount()) {
report("operand types must preserve number of vector elements", MI);
return false;
}
Expand Down
37 changes: 35 additions & 2 deletions llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "RISCVCallLowering.h"
#include "RISCVISelLowering.h"
#include "RISCVMachineFunctionInfo.h"
#include "RISCVSubtarget.h"
#include "llvm/CodeGen/Analysis.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
Expand Down Expand Up @@ -185,6 +186,9 @@ struct RISCVIncomingValueAssigner : public CallLowering::IncomingValueAssigner {
const DataLayout &DL = MF.getDataLayout();
const RISCVSubtarget &Subtarget = MF.getSubtarget<RISCVSubtarget>();

if (LocVT.isScalableVector())
MF.getInfo<RISCVMachineFunctionInfo>()->setIsVectorCall();

if (RISCVAssignFn(DL, Subtarget.getTargetABI(), ValNo, ValVT, LocVT,
LocInfo, Flags, State, /*IsFixed=*/true, IsRet, Info.Ty,
*Subtarget.getTargetLowering(),
Expand Down Expand Up @@ -301,8 +305,31 @@ struct RISCVCallReturnHandler : public RISCVIncomingValueHandler {
RISCVCallLowering::RISCVCallLowering(const RISCVTargetLowering &TLI)
: CallLowering(&TLI) {}

/// Return true if scalable vector with ScalarTy is legal for lowering.
static bool isLegalElementTypeForRVV(Type *EltTy,
const RISCVSubtarget &Subtarget) {
if (EltTy->isPointerTy())
return Subtarget.is64Bit() ? Subtarget.hasVInstructionsI64() : true;
if (EltTy->isIntegerTy(1) || EltTy->isIntegerTy(8) ||
EltTy->isIntegerTy(16) || EltTy->isIntegerTy(32))
return true;
if (EltTy->isIntegerTy(64))
return Subtarget.hasVInstructionsI64();
if (EltTy->isHalfTy())
return Subtarget.hasVInstructionsF16();
if (EltTy->isBFloatTy())
return Subtarget.hasVInstructionsBF16();
if (EltTy->isFloatTy())
return Subtarget.hasVInstructionsF32();
if (EltTy->isDoubleTy())
return Subtarget.hasVInstructionsF64();
return false;
}

// TODO: Support all argument types.
static bool isSupportedArgumentType(Type *T, const RISCVSubtarget &Subtarget) {
// TODO: Remove IsLowerArgs argument by adding support for vectors in lowerCall.
static bool isSupportedArgumentType(Type *T, const RISCVSubtarget &Subtarget,
bool IsLowerArgs = false) {
// TODO: Integers larger than 2*XLen are passed indirectly which is not
// supported yet.
if (T->isIntegerTy())
Expand All @@ -311,6 +338,11 @@ static bool isSupportedArgumentType(Type *T, const RISCVSubtarget &Subtarget) {
return true;
if (T->isPointerTy())
return true;
// TODO: Support fixed vector types.
if (IsLowerArgs && T->isVectorTy() && Subtarget.hasVInstructions() &&
T->isScalableTy() &&
isLegalElementTypeForRVV(T->getScalarType(), Subtarget))
return true;
return false;
}

Expand Down Expand Up @@ -398,7 +430,8 @@ bool RISCVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
const RISCVSubtarget &Subtarget =
MIRBuilder.getMF().getSubtarget<RISCVSubtarget>();
for (auto &Arg : F.args()) {
if (!isSupportedArgumentType(Arg.getType(), Subtarget))
if (!isSupportedArgumentType(Arg.getType(), Subtarget,
/*IsLowerArgs=*/true))
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ declare <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
<vscale x 1 x i8>,
i64)

; FALLBACK-WITH-REPORT-ERR: remark: <unknown>:0:0: unable to lower arguments{{.*}}scalable_arg
; FALLBACK_WITH_REPORT_ERR: <unknown>:0:0: unable to translate instruction: call:
; FALLBACK-WITH-REPORT-OUT-LABEL: scalable_arg
define <vscale x 1 x i8> @scalable_arg(<vscale x 1 x i8> %0, <vscale x 1 x i8> %1, i64 %2) nounwind {
entry:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
; RUN: not --crash llc -mtriple=riscv32 -mattr=+v -global-isel -stop-after=irtranslator \
; RUN: -verify-machineinstrs < %s 2>&1 | FileCheck %s
; RUN: not --crash llc -mtriple=riscv64 -mattr=+v -global-isel -stop-after=irtranslator \
; RUN: -verify-machineinstrs < %s 2>&1 | FileCheck %s

; The purpose of this test is to show that the compiler throws an error when
; there is no support for bf16 vectors. If the compiler did not throw an error,
; then it will try to scalarize the argument to an s32, which may drop elements.
define void @test_args_nxv1bf16(<vscale x 1 x bfloat> %a) {
entry:
ret void
}

; CHECK: LLVM ERROR: unable to lower arguments: ptr (in function: test_args_nxv1bf16)


Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
; RUN: not --crash llc -mtriple=riscv32 -mattr=+v -global-isel -stop-after=irtranslator \
; RUN: -verify-machineinstrs < %s 2>&1 | FileCheck %s
; RUN: not --crash llc -mtriple=riscv64 -mattr=+v -global-isel -stop-after=irtranslator \
; RUN: -verify-machineinstrs < %s 2>&1 | FileCheck %s

; The purpose of this test is to show that the compiler throws an error when
; there is no support for f16 vectors. If the compiler did not throw an error,
; then it will try to scalarize the argument to an s32, which may drop elements.
define void @test_args_nxv1f16(<vscale x 1 x half> %a) {
entry:
ret void
}

; CHECK: LLVM ERROR: unable to lower arguments: ptr (in function: test_args_nxv1f16)


Loading

0 comments on commit a7bbcc4

Please sign in to comment.