From 9fdddf081365a1bfb85d902d4b32309a1512d0ee Mon Sep 17 00:00:00 2001 From: Victor Perez Date: Tue, 19 Mar 2024 18:12:08 +0100 Subject: [PATCH] Allow specialization constant length arrays (#2396) SPIR-V arrays can have their length specified via a specialization constant. In case an `alloca` instruction uses a specialization constant `s` as length, generate an array variable of length `s`. Signed-off-by: Victor Perez --- lib/SPIRV/SPIRVWriter.cpp | 25 +++++- lib/SPIRV/libSPIRV/SPIRVModule.cpp | 4 +- lib/SPIRV/libSPIRV/SPIRVModule.h | 2 +- lib/SPIRV/libSPIRV/SPIRVType.cpp | 9 +-- lib/SPIRV/libSPIRV/SPIRVType.h | 4 +- .../spec-constant-length-array.ll | 77 +++++++++++++++++++ 6 files changed, 109 insertions(+), 12 deletions(-) create mode 100644 test/SpecConstants/spec-constant-length-array.ll diff --git a/lib/SPIRV/SPIRVWriter.cpp b/lib/SPIRV/SPIRVWriter.cpp index 6377172fad..0e2e94620c 100644 --- a/lib/SPIRV/SPIRVWriter.cpp +++ b/lib/SPIRV/SPIRVWriter.cpp @@ -2215,6 +2215,29 @@ LLVMToSPIRVBase::transValueWithoutDecoration(Value *V, SPIRVBasicBlock *BB, if (AllocaInst *Alc = dyn_cast(V)) { SPIRVType *TranslatedTy = transScavengedType(V); if (Alc->isArrayAllocation()) { + SPIRVValue *Length = transValue(Alc->getArraySize(), BB); + assert(Length && "Couldn't translate array size!"); + + if (isSpecConstantOpCode(Length->getOpCode())) { + // SPIR-V arrays length can be expressed using a specialization + // constant. + // + // Spec Constant Length Arrays need special treatment, as the allocation + // type will be 'OpTypePointer(Function, OpTypeArray(ElementType, + // Length))', we need to bitcast the obtained pointer to the expected + // type: 'OpTypePointer(Function, ElementType). + SPIRVType *AllocationType = BM->addPointerType( + StorageClassFunction, + BM->addArrayType(transType(Alc->getAllocatedType()), Length)); + SPIRVValue *Arr = BM->addVariable( + AllocationType, false, spv::internal::LinkageTypeInternal, nullptr, + Alc->getName().str() + "_alloca", StorageClassFunction, BB); + // Manually set alignment. OpBitcast created below will be decorated as + // that's the SPIR-V value mapped to the original LLVM one. + transAlign(Alc, Arr); + return mapValue(V, BM->addUnaryInst(OpBitcast, TranslatedTy, Arr, BB)); + } + if (!BM->checkExtension(ExtensionID::SPV_INTEL_variable_length_array, SPIRVEC_InvalidInstruction, toString(Alc) + @@ -2222,8 +2245,6 @@ LLVMToSPIRVBase::transValueWithoutDecoration(Value *V, SPIRVBasicBlock *BB, "SPV_INTEL_variable_length_array extension.")) return nullptr; - SPIRVValue *Length = transValue(Alc->getArraySize(), BB); - assert(Length && "Couldn't translate array size!"); return mapValue(V, BM->addInstTemplate(OpVariableLengthArrayINTEL, {Length->getId()}, BB, TranslatedTy)); diff --git a/lib/SPIRV/libSPIRV/SPIRVModule.cpp b/lib/SPIRV/libSPIRV/SPIRVModule.cpp index 40e9d3a067..46d1b3a7bc 100644 --- a/lib/SPIRV/libSPIRV/SPIRVModule.cpp +++ b/lib/SPIRV/libSPIRV/SPIRVModule.cpp @@ -245,7 +245,7 @@ class SPIRVModuleImpl : public SPIRVModule { // Type creation functions template T *addType(T *Ty); - SPIRVTypeArray *addArrayType(SPIRVType *, SPIRVConstant *) override; + SPIRVTypeArray *addArrayType(SPIRVType *, SPIRVValue *) override; SPIRVTypeBool *addBoolType() override; SPIRVTypeFloat *addFloatType(unsigned BitWidth) override; SPIRVTypeFunction *addFunctionType(SPIRVType *, @@ -968,7 +968,7 @@ SPIRVTypeVoid *SPIRVModuleImpl::addVoidType() { } SPIRVTypeArray *SPIRVModuleImpl::addArrayType(SPIRVType *ElementType, - SPIRVConstant *Length) { + SPIRVValue *Length) { return addType(new SPIRVTypeArray(this, getId(), ElementType, Length)); } diff --git a/lib/SPIRV/libSPIRV/SPIRVModule.h b/lib/SPIRV/libSPIRV/SPIRVModule.h index 4ec78631fc..5e3743fc8e 100644 --- a/lib/SPIRV/libSPIRV/SPIRVModule.h +++ b/lib/SPIRV/libSPIRV/SPIRVModule.h @@ -240,7 +240,7 @@ class SPIRVModule { virtual void eraseInstruction(SPIRVInstruction *, SPIRVBasicBlock *) = 0; // Type creation functions - virtual SPIRVTypeArray *addArrayType(SPIRVType *, SPIRVConstant *) = 0; + virtual SPIRVTypeArray *addArrayType(SPIRVType *, SPIRVValue *) = 0; virtual SPIRVTypeBool *addBoolType() = 0; virtual SPIRVTypeFloat *addFloatType(unsigned) = 0; virtual SPIRVTypeFunction * diff --git a/lib/SPIRV/libSPIRV/SPIRVType.cpp b/lib/SPIRV/libSPIRV/SPIRVType.cpp index 674003b0d7..6f634d284c 100644 --- a/lib/SPIRV/libSPIRV/SPIRVType.cpp +++ b/lib/SPIRV/libSPIRV/SPIRVType.cpp @@ -56,7 +56,7 @@ uint64_t SPIRVType::getArrayLength() const { const SPIRVTypeArray *AsArray = static_cast(this); assert(AsArray->getLength()->getOpCode() == OpConstant && "getArrayLength can only be called with constant array lengths"); - return AsArray->getLength()->getZExtIntValue(); + return static_cast(AsArray->getLength())->getZExtIntValue(); } SPIRVWord SPIRVType::getBitWidth() const { @@ -263,7 +263,7 @@ void SPIRVTypeStruct::setPacked(bool Packed) { } SPIRVTypeArray::SPIRVTypeArray(SPIRVModule *M, SPIRVId TheId, - SPIRVType *TheElemType, SPIRVConstant *TheLength) + SPIRVType *TheElemType, SPIRVValue *TheLength) : SPIRVType(M, 4, OpTypeArray, TheId), ElemType(TheElemType), Length(TheLength->getId()) { validate(); @@ -273,11 +273,10 @@ void SPIRVTypeArray::validate() const { SPIRVEntry::validate(); ElemType->validate(); assert(getValue(Length)->getType()->isTypeInt()); + assert(isConstantOpCode(getValue(Length)->getOpCode())); } -SPIRVConstant *SPIRVTypeArray::getLength() const { - return get(Length); -} +SPIRVValue *SPIRVTypeArray::getLength() const { return getValue(Length); } _SPIRV_IMP_ENCDEC3(SPIRVTypeArray, Id, ElemType, Length) diff --git a/lib/SPIRV/libSPIRV/SPIRVType.h b/lib/SPIRV/libSPIRV/SPIRVType.h index 894d7b339f..086fe1a675 100644 --- a/lib/SPIRV/libSPIRV/SPIRVType.h +++ b/lib/SPIRV/libSPIRV/SPIRVType.h @@ -390,13 +390,13 @@ class SPIRVTypeArray : public SPIRVType { public: // Complete constructor SPIRVTypeArray(SPIRVModule *M, SPIRVId TheId, SPIRVType *TheElemType, - SPIRVConstant *TheLength); + SPIRVValue *TheLength); // Incomplete constructor SPIRVTypeArray() : SPIRVType(OpTypeArray), ElemType(nullptr), Length(SPIRVID_INVALID) {} SPIRVType *getElementType() const { return ElemType; } - SPIRVConstant *getLength() const; + SPIRVValue *getLength() const; SPIRVCapVec getRequiredCapability() const override { return getElementType()->getRequiredCapability(); } diff --git a/test/SpecConstants/spec-constant-length-array.ll b/test/SpecConstants/spec-constant-length-array.ll new file mode 100644 index 0000000000..a88b2b8cb0 --- /dev/null +++ b/test/SpecConstants/spec-constant-length-array.ll @@ -0,0 +1,77 @@ +; RUN: llvm-as %s -o %t.bc +; RUN: llvm-spirv -spirv-text -o - %t.bc | FileCheck %s --check-prefix CHECK-SPV +; RUN: llvm-spirv -o %t.spv %t.bc +; RUN: spirv-val %t.spv +; RUN: llvm-spirv -r -o - %t.spv | llvm-dis | FileCheck %s --check-prefix CHECK-LLVM + +; CHECK-SPV-DAG: Decorate [[#I64_CONST:]] SpecId [[#]] +; CHECK-SPV-DAG: Decorate [[#I32_CONST:]] SpecId [[#]] +; CHECK-SPV-DAG: Decorate [[#I8_CONST:]] SpecId [[#]] +; CHECK-SPV-DAG: Decorate [[#SCLA_0:]] Alignment 4 +; CHECK-SPV-DAG: Decorate [[#SCLA_1:]] Alignment 2 +; CHECK-SPV-DAG: Decorate [[#SCLA_2:]] Alignment 16 + +; CHECK-SPV-DAG: TypeInt [[#I64_TY:]] 64 +; CHECK-SPV-DAG: TypeInt [[#I32_TY:]] 32 +; CHECK-SPV-DAG: TypeInt [[#I8_TY:]] 8 + +; CHECK-SPV-DAG: SpecConstant [[#I64_TY]] [[#LENGTH_0:]] +; CHECK-SPV-DAG: SpecConstant [[#I32_TY]] [[#LENGTH_1:]] +; CHECK-SPV-DAG: SpecConstant [[#I8_TY]] [[#LENGTH_2:]] + +; CHECK-SPV-DAG: TypeFloat [[#FLOAT_TY:]] 32 +; CHECK-SPV-DAG: TypePointer [[#FLOAT_PTR_TY:]] [[#FUNCTION_SC:]] [[#FLOAT_TY]] +; CHECK-SPV-DAG: TypeArray [[#ARR_TY_0:]] [[#FLOAT_TY]] [[#LENGTH_0]] +; CHECK-SPV-DAG: TypePointer [[#ARR_PTR_TY_0:]] [[#FUNCTION_SC]] [[#ARR_TY_0]] +; CHECK-SPV-DAG: TypePointer [[#I8_PTR_TY:]] [[#FUNCTION_SC]] [[#I8_TY]] +; CHECK-SPV-DAG: TypeArray [[#ARR_TY_1:]] [[#I8_TY]] [[#LENGTH_1]] +; CHECK-SPV-DAG: TypePointer [[#ARR_PTR_TY_1:]] [[#FUNCTION_SC]] [[#ARR_TY_1]] +; CHECK-SPV-DAG: TypeFloat [[#DOUBLE_TY:]] 64 +; CHECK-SPV-DAG: TypeStruct [[#STR_TY:]] [[#DOUBLE_TY]] [[#DOUBLE_TY]] +; CHECK-SPV-DAG: TypePointer [[#STR_PTR_TY:]] [[#FUNCTION_SC]] [[#STR_TY]] +; CHECK-SPV-DAG: TypeArray [[#ARR_TY_2:]] [[#STR_TY]] [[#LENGTH_2]] +; CHECK-SPV-DAG: TypePointer [[#ARR_PTR_TY_2:]] [[#FUNCTION_SC]] [[#ARR_TY_2:]] + +target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64" +target triple = "spir64-unknown-unknown" + +%struct_type = type { double, double } + +define spir_kernel void @test() { + entry: + %length0 = call i64 @_Z20__spirv_SpecConstantix(i32 0, i64 1), !SYCL_SPEC_CONST_SYM_ID !0 + %length1 = call i32 @_Z20__spirv_SpecConstantii(i32 1, i32 2), !SYCL_SPEC_CONST_SYM_ID !1 + %length2 = call i8 @_Z20__spirv_SpecConstantic(i32 2, i8 4), !SYCL_SPEC_CONST_SYM_ID !2 + + ; CHECK-SPV: Variable [[#ARR_PTR_TY_0]] [[#SCLA_0]] [[#FUNCTION_SC]] + ; CHECK-SPV: Variable [[#ARR_PTR_TY_1]] [[#SCLA_1]] [[#FUNCTION_SC]] + ; CHECK-SPV: Variable [[#ARR_PTR_TY_2]] [[#SCLA_2]] [[#FUNCTION_SC]] + + ; CHECK-LLVM: %[[ALLOCA0:.*]] = alloca [1 x float], align 4 + ; CHECK-LLVM: %[[ALLOCA1:.*]] = alloca [2 x i8], align 2 + ; CHECK-LLVM: %[[ALLOCA2:.*]] = alloca [4 x %struct_type], align 16 + + ; CHECK-SPV: Bitcast [[#FLOAT_PTR_TY]] [[#]] [[#SCLA_0]] + + ; CHECK-LLVM: %[[VAR0:.*]] = bitcast ptr %[[ALLOCA0]] to ptr + %scla0 = alloca float, i64 %length0, align 4 + + ; CHECK-SPV: Bitcast [[#I8_PTR_TY]] [[#]] [[#SCLA_1]] + + ; CHECK-LLVM: %[[VAR1:.*]] = bitcast ptr %[[ALLOCA1]] to ptr + %scla1 = alloca i8, i32 %length1, align 2 + + ; CHECK-SPV: Bitcast [[#STR_PTR_TY]] [[#]] [[#SCLA_2]] + + ; CHECK-LLVM: %[[VAR2:.*]] = bitcast ptr %[[ALLOCA2]] to ptr + %scla2 = alloca %struct_type, i8 %length2, align 16 + ret void +} + +declare i8 @_Z20__spirv_SpecConstantic(i32, i8) +declare i32 @_Z20__spirv_SpecConstantii(i32, i32) +declare i64 @_Z20__spirv_SpecConstantix(i32, i64) + +!0 = !{!"i64_spec_const", i32 0} +!1 = !{!"i32_spec_const", i32 1} +!2 = !{!"i8_spec_const", i32 2}