Skip to content

Commit

Permalink
Add infrastructure for translating ExecutionModeId (#2297)
Browse files Browse the repository at this point in the history
This functionality was added in SPIR-V 1.2 and allows using an <id> to
set the execution modes SubgroupsPerWorkgroupId, LocalSizeId, and
LocalSizeHintI, and others.
  • Loading branch information
vmaksimo authored Jan 16, 2024
1 parent 4dfbc85 commit 10b0aab
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 38 deletions.
31 changes: 17 additions & 14 deletions lib/SPIRV/SPIRVWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5657,19 +5657,20 @@ bool LLVMToSPIRVBase::transExecutionMode() {
auto AddSingleArgExecutionMode = [&](ExecutionMode EMode) {
uint32_t Arg = ~0u;
N.get(Arg);
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(BF, EMode, Arg)));
BF->addExecutionMode(
BM->add(new SPIRVExecutionMode(OpExecutionMode, BF, EMode, Arg)));
};

switch (EMode) {
case spv::ExecutionModeContractionOff:
BF->addExecutionMode(BM->add(
new SPIRVExecutionMode(BF, static_cast<ExecutionMode>(EMode))));
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode))));
break;
case spv::ExecutionModeInitializer:
case spv::ExecutionModeFinalizer:
if (BM->isAllowedToUseVersion(VersionNumber::SPIRV_1_1)) {
BF->addExecutionMode(BM->add(
new SPIRVExecutionMode(BF, static_cast<ExecutionMode>(EMode))));
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode))));
} else {
getErrorLog().checkError(false, SPIRVEC_Requires1_1,
"Initializer/Finalizer Execution Mode");
Expand All @@ -5681,15 +5682,16 @@ bool LLVMToSPIRVBase::transExecutionMode() {
unsigned X = 0, Y = 0, Z = 0;
N.get(X).get(Y).get(Z);
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
BF, static_cast<ExecutionMode>(EMode), X, Y, Z)));
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode), X, Y, Z)));
} break;
case spv::ExecutionModeMaxWorkgroupSizeINTEL: {
if (BM->isAllowedToUseExtension(
ExtensionID::SPV_INTEL_kernel_attributes)) {
unsigned X = 0, Y = 0, Z = 0;
N.get(X).get(Y).get(Z);
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
BF, static_cast<ExecutionMode>(EMode), X, Y, Z)));
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode), X, Y,
Z)));
BM->addExtension(ExtensionID::SPV_INTEL_kernel_attributes);
BM->addCapability(CapabilityKernelAttributesINTEL);
}
Expand All @@ -5698,8 +5700,8 @@ bool LLVMToSPIRVBase::transExecutionMode() {
if (!BM->isAllowedToUseExtension(
ExtensionID::SPV_INTEL_kernel_attributes))
break;
BF->addExecutionMode(BM->add(
new SPIRVExecutionMode(BF, static_cast<ExecutionMode>(EMode))));
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode))));
BM->addExtension(ExtensionID::SPV_INTEL_kernel_attributes);
BM->addCapability(CapabilityKernelAttributesINTEL);
} break;
Expand Down Expand Up @@ -5743,7 +5745,8 @@ bool LLVMToSPIRVBase::transExecutionMode() {
unsigned NBarrierCnt = 0;
N.get(NBarrierCnt);
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
BF, static_cast<ExecutionMode>(EMode), NBarrierCnt)));
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode),
NBarrierCnt)));
BM->addExtension(ExtensionID::SPV_INTEL_vector_compute);
BM->addCapability(CapabilityVectorComputeINTEL);
} break;
Expand Down Expand Up @@ -5773,8 +5776,8 @@ bool LLVMToSPIRVBase::transExecutionMode() {
} break;
case spv::internal::ExecutionModeFastCompositeKernelINTEL: {
if (BM->isAllowedToUseExtension(ExtensionID::SPV_INTEL_fast_composite))
BF->addExecutionMode(BM->add(
new SPIRVExecutionMode(BF, static_cast<ExecutionMode>(EMode))));
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode))));
} break;
default:
llvm_unreachable("invalid execution mode");
Expand Down Expand Up @@ -5819,8 +5822,8 @@ void LLVMToSPIRVBase::transFPContract() {
}

if (DisableContraction) {
BF->addExecutionMode(BF->getModule()->add(
new SPIRVExecutionMode(BF, spv::ExecutionModeContractionOff)));
BF->addExecutionMode(BF->getModule()->add(new SPIRVExecutionMode(
OpExecutionMode, BF, spv::ExecutionModeContractionOff)));
}
}
}
Expand Down
5 changes: 3 additions & 2 deletions lib/SPIRV/libSPIRV/SPIRVEntry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,7 @@ SPIRVEntryPoint::SPIRVEntryPoint(SPIRVModule *TheModule,
SPIRVExecutionModelKind TheExecModel,
SPIRVId TheId, const std::string &TheName,
std::vector<SPIRVId> Variables)
: SPIRVAnnotation(TheModule->get<SPIRVFunction>(TheId),
: SPIRVAnnotation(OpEntryPoint, TheModule->get<SPIRVFunction>(TheId),
getSizeInWords(TheName) + Variables.size() + 3),
ExecModel(TheExecModel), Name(TheName), Variables(Variables) {}

Expand Down Expand Up @@ -681,7 +681,8 @@ SPIRVForward *SPIRVAnnotationGeneric::getOrCreateTarget() const {
}

SPIRVName::SPIRVName(const SPIRVEntry *TheTarget, const std::string &TheStr)
: SPIRVAnnotation(TheTarget, getSizeInWords(TheStr) + 2), Str(TheStr) {}
: SPIRVAnnotation(OpName, TheTarget, getSizeInWords(TheStr) + 2),
Str(TheStr) {}

void SPIRVName::encode(spv_ostream &O) const { getEncoder(O) << Target << Str; }

Expand Down
69 changes: 47 additions & 22 deletions lib/SPIRV/libSPIRV/SPIRVEntry.h
Original file line number Diff line number Diff line change
Expand Up @@ -521,23 +521,24 @@ class SPIRVAnnotationGeneric : public SPIRVEntryNoIdGeneric {
SPIRVId Target;
};

template <Op OC> class SPIRVAnnotation : public SPIRVAnnotationGeneric {
class SPIRVAnnotation : public SPIRVAnnotationGeneric {
public:
// Complete constructor
SPIRVAnnotation(const SPIRVEntry *TheTarget, unsigned TheWordCount)
SPIRVAnnotation(Op OC, const SPIRVEntry *TheTarget, unsigned TheWordCount)
: SPIRVAnnotationGeneric(TheTarget->getModule(), TheWordCount, OC,
TheTarget->getId()) {}
// Incomplete constructor
SPIRVAnnotation() : SPIRVAnnotationGeneric(OC) {}
// Incomplete constructors
SPIRVAnnotation(Op OC) : SPIRVAnnotationGeneric(OC) {}
SPIRVAnnotation() : SPIRVAnnotationGeneric(OpNop) {}
};

class SPIRVEntryPoint : public SPIRVAnnotation<OpEntryPoint> {
class SPIRVEntryPoint : public SPIRVAnnotation {
public:
static const SPIRVWord FixedWC = 4;
SPIRVEntryPoint(SPIRVModule *TheModule, SPIRVExecutionModelKind,
SPIRVId TheId, const std::string &TheName,
std::vector<SPIRVId> Variables);
SPIRVEntryPoint() {}
SPIRVEntryPoint() : SPIRVAnnotation(OpEntryPoint) {}

_SPIRV_DCL_ENCDEC
protected:
Expand All @@ -548,12 +549,12 @@ class SPIRVEntryPoint : public SPIRVAnnotation<OpEntryPoint> {
std::vector<SPIRVId> Variables;
};

class SPIRVName : public SPIRVAnnotation<OpName> {
class SPIRVName : public SPIRVAnnotation {
public:
// Complete constructor
SPIRVName(const SPIRVEntry *TheTarget, const std::string &TheStr);
// Incomplete constructor
SPIRVName() {}
SPIRVName() : SPIRVAnnotation(OpName) {}

protected:
_SPIRV_DCL_ENCDEC
Expand All @@ -562,18 +563,18 @@ class SPIRVName : public SPIRVAnnotation<OpName> {
std::string Str;
};

class SPIRVMemberName : public SPIRVAnnotation<OpName> {
class SPIRVMemberName : public SPIRVAnnotation {
public:
static const SPIRVWord FixedWC = 3;
// Complete constructor
SPIRVMemberName(const SPIRVEntry *TheTarget, SPIRVWord TheMemberNumber,
const std::string &TheStr)
: SPIRVAnnotation(TheTarget, FixedWC + getSizeInWords(TheStr)),
: SPIRVAnnotation(OpName, TheTarget, FixedWC + getSizeInWords(TheStr)),
MemberNumber(TheMemberNumber), Str(TheStr) {
validate();
}
// Incomplete constructor
SPIRVMemberName() : MemberNumber(SPIRVWORD_MAX) {}
SPIRVMemberName() : SPIRVAnnotation(OpName), MemberNumber(SPIRVWORD_MAX) {}

protected:
_SPIRV_DCL_ENCDEC
Expand Down Expand Up @@ -649,31 +650,33 @@ class SPIRVLine : public SPIRVEntry {
SPIRVWord Column;
};

class SPIRVExecutionMode : public SPIRVAnnotation<OpExecutionMode> {
class SPIRVExecutionMode : public SPIRVAnnotation {
public:
// Complete constructor for LocalSize, LocalSizeHint
SPIRVExecutionMode(SPIRVEntry *TheTarget, SPIRVExecutionModeKind TheExecMode,
SPIRVWord X, SPIRVWord Y, SPIRVWord Z)
: SPIRVAnnotation(TheTarget, 6), ExecMode(TheExecMode) {
SPIRVExecutionMode(Op OC, SPIRVEntry *TheTarget,
SPIRVExecutionModeKind TheExecMode, SPIRVWord X,
SPIRVWord Y, SPIRVWord Z)
: SPIRVAnnotation(OC, TheTarget, 6), ExecMode(TheExecMode) {
WordLiterals.push_back(X);
WordLiterals.push_back(Y);
WordLiterals.push_back(Z);
updateModuleVersion();
}
// Complete constructor for VecTypeHint, SubgroupSize, SubgroupsPerWorkgroup
SPIRVExecutionMode(SPIRVEntry *TheTarget, SPIRVExecutionModeKind TheExecMode,
SPIRVWord Code)
: SPIRVAnnotation(TheTarget, 4), ExecMode(TheExecMode) {
SPIRVExecutionMode(Op OC, SPIRVEntry *TheTarget,
SPIRVExecutionModeKind TheExecMode, SPIRVWord Code)
: SPIRVAnnotation(OC, TheTarget, 4), ExecMode(TheExecMode) {
WordLiterals.push_back(Code);
updateModuleVersion();
}
// Complete constructor for ContractionOff
SPIRVExecutionMode(SPIRVEntry *TheTarget, SPIRVExecutionModeKind TheExecMode)
: SPIRVAnnotation(TheTarget, 3), ExecMode(TheExecMode) {
SPIRVExecutionMode(Op OC, SPIRVEntry *TheTarget,
SPIRVExecutionModeKind TheExecMode)
: SPIRVAnnotation(OC, TheTarget, 3), ExecMode(TheExecMode) {
updateModuleVersion();
}
// Incomplete constructor
SPIRVExecutionMode() : ExecMode(ExecutionModeInvocations) {}
SPIRVExecutionMode()
: SPIRVAnnotation(OpExecutionMode), ExecMode(ExecutionModeInvocations) {}
SPIRVExecutionModeKind getExecutionMode() const { return ExecMode; }
const std::vector<SPIRVWord> &getLiterals() const { return WordLiterals; }
SPIRVCapVec getRequiredCapability() const override {
Expand All @@ -699,6 +702,28 @@ class SPIRVExecutionMode : public SPIRVAnnotation<OpExecutionMode> {
std::vector<SPIRVWord> WordLiterals;
};

class SPIRVExecutionModeId : public SPIRVExecutionMode {
public:
// Complete constructor for LocalSizeId, LocalSizeHintId
SPIRVExecutionModeId(SPIRVEntry *TheTarget,
SPIRVExecutionModeKind TheExecMode, SPIRVWord X,
SPIRVWord Y, SPIRVWord Z)
: SPIRVExecutionMode(OpExecutionModeId, TheTarget, TheExecMode, X, Y, Z) {
updateModuleVersion();
}
// Complete constructor for SubgroupsPerWorkgroupId
SPIRVExecutionModeId(SPIRVEntry *TheTarget,
SPIRVExecutionModeKind TheExecMode, SPIRVWord Code)
: SPIRVExecutionMode(OpExecutionModeId, TheTarget, TheExecMode, Code) {
updateModuleVersion();
}
// Incomplete constructor
SPIRVExecutionModeId() : SPIRVExecutionMode() {}
SPIRVWord getRequiredSPIRVVersion() const override {
return static_cast<SPIRVWord>(VersionNumber::SPIRV_1_2);
}
};

class SPIRVComponentExecutionModes {
typedef std::multimap<SPIRVExecutionModeKind, SPIRVExecutionMode *>
SPIRVExecutionModeMap;
Expand Down
1 change: 1 addition & 0 deletions lib/SPIRV/libSPIRV/SPIRVOpCodeEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ _SPIRV_OP(TypePipeStorage, 322)
_SPIRV_OP(ConstantPipeStorage, 323)
_SPIRV_OP(CreatePipeFromPipeStorage, 324)
_SPIRV_OP(ModuleProcessed, 330)
_SPIRV_OP(ExecutionModeId, 331)
_SPIRV_OP(DecorateId, 332)
_SPIRV_OP(GroupNonUniformElect, 333)
_SPIRV_OP(GroupNonUniformAll, 334)
Expand Down

0 comments on commit 10b0aab

Please sign in to comment.