Skip to content

Commit

Permalink
Allow specialization constant length arrays (#2396)
Browse files Browse the repository at this point in the history
SPIR-V arrays can have their length specified via a specialization
constant. In case an `alloca` instruction uses a specialization
constant `s` as length, generate an array variable of length `s`.

Signed-off-by: Victor Perez <[email protected]>
  • Loading branch information
victor-eds authored Mar 19, 2024
1 parent 41dc967 commit 9fdddf0
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 12 deletions.
25 changes: 23 additions & 2 deletions lib/SPIRV/SPIRVWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2215,15 +2215,36 @@ LLVMToSPIRVBase::transValueWithoutDecoration(Value *V, SPIRVBasicBlock *BB,
if (AllocaInst *Alc = dyn_cast<AllocaInst>(V)) {
SPIRVType *TranslatedTy = transScavengedType(V);
if (Alc->isArrayAllocation()) {
SPIRVValue *Length = transValue(Alc->getArraySize(), BB);
assert(Length && "Couldn't translate array size!");

if (isSpecConstantOpCode(Length->getOpCode())) {
// SPIR-V arrays length can be expressed using a specialization
// constant.
//
// Spec Constant Length Arrays need special treatment, as the allocation
// type will be 'OpTypePointer(Function, OpTypeArray(ElementType,
// Length))', we need to bitcast the obtained pointer to the expected
// type: 'OpTypePointer(Function, ElementType).
SPIRVType *AllocationType = BM->addPointerType(
StorageClassFunction,
BM->addArrayType(transType(Alc->getAllocatedType()), Length));
SPIRVValue *Arr = BM->addVariable(
AllocationType, false, spv::internal::LinkageTypeInternal, nullptr,
Alc->getName().str() + "_alloca", StorageClassFunction, BB);
// Manually set alignment. OpBitcast created below will be decorated as
// that's the SPIR-V value mapped to the original LLVM one.
transAlign(Alc, Arr);
return mapValue(V, BM->addUnaryInst(OpBitcast, TranslatedTy, Arr, BB));
}

if (!BM->checkExtension(ExtensionID::SPV_INTEL_variable_length_array,
SPIRVEC_InvalidInstruction,
toString(Alc) +
"\nTranslation of dynamic alloca requires "
"SPV_INTEL_variable_length_array extension."))
return nullptr;

SPIRVValue *Length = transValue(Alc->getArraySize(), BB);
assert(Length && "Couldn't translate array size!");
return mapValue(V,
BM->addInstTemplate(OpVariableLengthArrayINTEL,
{Length->getId()}, BB, TranslatedTy));
Expand Down
4 changes: 2 additions & 2 deletions lib/SPIRV/libSPIRV/SPIRVModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ class SPIRVModuleImpl : public SPIRVModule {

// Type creation functions
template <class T> T *addType(T *Ty);
SPIRVTypeArray *addArrayType(SPIRVType *, SPIRVConstant *) override;
SPIRVTypeArray *addArrayType(SPIRVType *, SPIRVValue *) override;
SPIRVTypeBool *addBoolType() override;
SPIRVTypeFloat *addFloatType(unsigned BitWidth) override;
SPIRVTypeFunction *addFunctionType(SPIRVType *,
Expand Down Expand Up @@ -968,7 +968,7 @@ SPIRVTypeVoid *SPIRVModuleImpl::addVoidType() {
}

SPIRVTypeArray *SPIRVModuleImpl::addArrayType(SPIRVType *ElementType,
SPIRVConstant *Length) {
SPIRVValue *Length) {
return addType(new SPIRVTypeArray(this, getId(), ElementType, Length));
}

Expand Down
2 changes: 1 addition & 1 deletion lib/SPIRV/libSPIRV/SPIRVModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ class SPIRVModule {
virtual void eraseInstruction(SPIRVInstruction *, SPIRVBasicBlock *) = 0;

// Type creation functions
virtual SPIRVTypeArray *addArrayType(SPIRVType *, SPIRVConstant *) = 0;
virtual SPIRVTypeArray *addArrayType(SPIRVType *, SPIRVValue *) = 0;
virtual SPIRVTypeBool *addBoolType() = 0;
virtual SPIRVTypeFloat *addFloatType(unsigned) = 0;
virtual SPIRVTypeFunction *
Expand Down
9 changes: 4 additions & 5 deletions lib/SPIRV/libSPIRV/SPIRVType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ uint64_t SPIRVType::getArrayLength() const {
const SPIRVTypeArray *AsArray = static_cast<const SPIRVTypeArray *>(this);
assert(AsArray->getLength()->getOpCode() == OpConstant &&
"getArrayLength can only be called with constant array lengths");
return AsArray->getLength()->getZExtIntValue();
return static_cast<SPIRVConstant *>(AsArray->getLength())->getZExtIntValue();
}

SPIRVWord SPIRVType::getBitWidth() const {
Expand Down Expand Up @@ -263,7 +263,7 @@ void SPIRVTypeStruct::setPacked(bool Packed) {
}

SPIRVTypeArray::SPIRVTypeArray(SPIRVModule *M, SPIRVId TheId,
SPIRVType *TheElemType, SPIRVConstant *TheLength)
SPIRVType *TheElemType, SPIRVValue *TheLength)
: SPIRVType(M, 4, OpTypeArray, TheId), ElemType(TheElemType),
Length(TheLength->getId()) {
validate();
Expand All @@ -273,11 +273,10 @@ void SPIRVTypeArray::validate() const {
SPIRVEntry::validate();
ElemType->validate();
assert(getValue(Length)->getType()->isTypeInt());
assert(isConstantOpCode(getValue(Length)->getOpCode()));
}

SPIRVConstant *SPIRVTypeArray::getLength() const {
return get<SPIRVConstant>(Length);
}
SPIRVValue *SPIRVTypeArray::getLength() const { return getValue(Length); }

_SPIRV_IMP_ENCDEC3(SPIRVTypeArray, Id, ElemType, Length)

Expand Down
4 changes: 2 additions & 2 deletions lib/SPIRV/libSPIRV/SPIRVType.h
Original file line number Diff line number Diff line change
Expand Up @@ -390,13 +390,13 @@ class SPIRVTypeArray : public SPIRVType {
public:
// Complete constructor
SPIRVTypeArray(SPIRVModule *M, SPIRVId TheId, SPIRVType *TheElemType,
SPIRVConstant *TheLength);
SPIRVValue *TheLength);
// Incomplete constructor
SPIRVTypeArray()
: SPIRVType(OpTypeArray), ElemType(nullptr), Length(SPIRVID_INVALID) {}

SPIRVType *getElementType() const { return ElemType; }
SPIRVConstant *getLength() const;
SPIRVValue *getLength() const;
SPIRVCapVec getRequiredCapability() const override {
return getElementType()->getRequiredCapability();
}
Expand Down
77 changes: 77 additions & 0 deletions test/SpecConstants/spec-constant-length-array.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
; RUN: llvm-as %s -o %t.bc
; RUN: llvm-spirv -spirv-text -o - %t.bc | FileCheck %s --check-prefix CHECK-SPV
; RUN: llvm-spirv -o %t.spv %t.bc
; RUN: spirv-val %t.spv
; RUN: llvm-spirv -r -o - %t.spv | llvm-dis | FileCheck %s --check-prefix CHECK-LLVM

; CHECK-SPV-DAG: Decorate [[#I64_CONST:]] SpecId [[#]]
; CHECK-SPV-DAG: Decorate [[#I32_CONST:]] SpecId [[#]]
; CHECK-SPV-DAG: Decorate [[#I8_CONST:]] SpecId [[#]]
; CHECK-SPV-DAG: Decorate [[#SCLA_0:]] Alignment 4
; CHECK-SPV-DAG: Decorate [[#SCLA_1:]] Alignment 2
; CHECK-SPV-DAG: Decorate [[#SCLA_2:]] Alignment 16

; CHECK-SPV-DAG: TypeInt [[#I64_TY:]] 64
; CHECK-SPV-DAG: TypeInt [[#I32_TY:]] 32
; CHECK-SPV-DAG: TypeInt [[#I8_TY:]] 8

; CHECK-SPV-DAG: SpecConstant [[#I64_TY]] [[#LENGTH_0:]]
; CHECK-SPV-DAG: SpecConstant [[#I32_TY]] [[#LENGTH_1:]]
; CHECK-SPV-DAG: SpecConstant [[#I8_TY]] [[#LENGTH_2:]]

; CHECK-SPV-DAG: TypeFloat [[#FLOAT_TY:]] 32
; CHECK-SPV-DAG: TypePointer [[#FLOAT_PTR_TY:]] [[#FUNCTION_SC:]] [[#FLOAT_TY]]
; CHECK-SPV-DAG: TypeArray [[#ARR_TY_0:]] [[#FLOAT_TY]] [[#LENGTH_0]]
; CHECK-SPV-DAG: TypePointer [[#ARR_PTR_TY_0:]] [[#FUNCTION_SC]] [[#ARR_TY_0]]
; CHECK-SPV-DAG: TypePointer [[#I8_PTR_TY:]] [[#FUNCTION_SC]] [[#I8_TY]]
; CHECK-SPV-DAG: TypeArray [[#ARR_TY_1:]] [[#I8_TY]] [[#LENGTH_1]]
; CHECK-SPV-DAG: TypePointer [[#ARR_PTR_TY_1:]] [[#FUNCTION_SC]] [[#ARR_TY_1]]
; CHECK-SPV-DAG: TypeFloat [[#DOUBLE_TY:]] 64
; CHECK-SPV-DAG: TypeStruct [[#STR_TY:]] [[#DOUBLE_TY]] [[#DOUBLE_TY]]
; CHECK-SPV-DAG: TypePointer [[#STR_PTR_TY:]] [[#FUNCTION_SC]] [[#STR_TY]]
; CHECK-SPV-DAG: TypeArray [[#ARR_TY_2:]] [[#STR_TY]] [[#LENGTH_2]]
; CHECK-SPV-DAG: TypePointer [[#ARR_PTR_TY_2:]] [[#FUNCTION_SC]] [[#ARR_TY_2:]]

target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown"

%struct_type = type { double, double }

define spir_kernel void @test() {
entry:
%length0 = call i64 @_Z20__spirv_SpecConstantix(i32 0, i64 1), !SYCL_SPEC_CONST_SYM_ID !0
%length1 = call i32 @_Z20__spirv_SpecConstantii(i32 1, i32 2), !SYCL_SPEC_CONST_SYM_ID !1
%length2 = call i8 @_Z20__spirv_SpecConstantic(i32 2, i8 4), !SYCL_SPEC_CONST_SYM_ID !2

; CHECK-SPV: Variable [[#ARR_PTR_TY_0]] [[#SCLA_0]] [[#FUNCTION_SC]]
; CHECK-SPV: Variable [[#ARR_PTR_TY_1]] [[#SCLA_1]] [[#FUNCTION_SC]]
; CHECK-SPV: Variable [[#ARR_PTR_TY_2]] [[#SCLA_2]] [[#FUNCTION_SC]]

; CHECK-LLVM: %[[ALLOCA0:.*]] = alloca [1 x float], align 4
; CHECK-LLVM: %[[ALLOCA1:.*]] = alloca [2 x i8], align 2
; CHECK-LLVM: %[[ALLOCA2:.*]] = alloca [4 x %struct_type], align 16

; CHECK-SPV: Bitcast [[#FLOAT_PTR_TY]] [[#]] [[#SCLA_0]]

; CHECK-LLVM: %[[VAR0:.*]] = bitcast ptr %[[ALLOCA0]] to ptr
%scla0 = alloca float, i64 %length0, align 4

; CHECK-SPV: Bitcast [[#I8_PTR_TY]] [[#]] [[#SCLA_1]]

; CHECK-LLVM: %[[VAR1:.*]] = bitcast ptr %[[ALLOCA1]] to ptr
%scla1 = alloca i8, i32 %length1, align 2

; CHECK-SPV: Bitcast [[#STR_PTR_TY]] [[#]] [[#SCLA_2]]

; CHECK-LLVM: %[[VAR2:.*]] = bitcast ptr %[[ALLOCA2]] to ptr
%scla2 = alloca %struct_type, i8 %length2, align 16
ret void
}

declare i8 @_Z20__spirv_SpecConstantic(i32, i8)
declare i32 @_Z20__spirv_SpecConstantii(i32, i32)
declare i64 @_Z20__spirv_SpecConstantix(i32, i64)

!0 = !{!"i64_spec_const", i32 0}
!1 = !{!"i32_spec_const", i32 1}
!2 = !{!"i8_spec_const", i32 2}

0 comments on commit 9fdddf0

Please sign in to comment.