Skip to content

Commit

Permalink
Adding simplified HAL dispatch methods. (#18189)
Browse files Browse the repository at this point in the history
These combine push constants and push descriptor sets into the dispatch
calls, as in practice we have a near 1:1 relationship anyway. Pipeline
layouts are still used in HAL interfaces to allow the compiler to map
the information but are otherwise not used by the new ops.

The `--iree-hal-experimental-dispatch2` flag enables emitting the new
ops. Since executables no longer require pipeline layouts in this
simplified model, the `--iree-hal-experimental-executable-create2` flag
can be used to stop passing them; targets that support dispatch2 will
ignore them if provided. Future changes will start to add support on
targets for the simplified bindings and then remove the existing
pipeline layout-based binding model as a breaking ABI change.

Current target status:
* [x] Local/CPU: executable-create2 and executable-dispatch2 supported
(backward compat)
* [x] CUDA: executable-dispatch2 supported (backward compat)
* [x] HIP: executable-dispatch2 supported (backward compat)
* [x] Metal: executable-dispatch2 supported (backward compat)
* [x] Vulkan: executable-dispatch2 supported (backward compat)
* [x] WebGPU: executable-dispatch2 supported (backward compat)

Reworking the CUDA/HIP/Metal/Vulkan/WebGPU flatbuffers to support
executable-create2 will be done in a follow-up.

Progress on #18154.
  • Loading branch information
benvanik authored Aug 12, 2024
1 parent 76eb9c1 commit 8dc6820
Show file tree
Hide file tree
Showing 68 changed files with 3,335 additions and 210 deletions.
20 changes: 14 additions & 6 deletions compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -387,12 +387,21 @@ class LLVMCPUTargetBackend final : public TargetBackend {
llvmFunc->addParamAttr(i, align16);
}

// Optionally entry points may specify that they require workgroup local
LibraryBuilder::DispatchAttrs dispatchAttrs = {0};

// Entry points may optionally specify that they require workgroup local
// memory. We fetch that value here and plumb it through so the runtime
// knows how much memory to reserve and pass in.
int64_t localMemorySize = exportOp.getWorkgroupLocalMemory()
.value_or(APInt(64, 0))
.getSExtValue();
dispatchAttrs.localMemorySize = exportOp.getWorkgroupLocalMemory()
.value_or(APInt(64, 0))
.getSExtValue();

// Specify the constant and binding information used to validate
// dispatches.
// TODO(#18189): pack per-binding information bitfields.
dispatchAttrs.constantCount = exportOp.getLayout().getPushConstants();
dispatchAttrs.bindingCount =
exportOp.getLayout().getSetLayout(0).getBindings().size();

LibraryBuilder::SourceLocation sourceLocation;
if (options.debugLevel >= 1) {
Expand All @@ -417,8 +426,7 @@ class LLVMCPUTargetBackend final : public TargetBackend {
}
libraryBuilder.addExport(exportOp.getName(), std::move(sourceLocation),
std::move(stageLocations), /*tag=*/"",
LibraryBuilder::DispatchAttrs{localMemorySize},
llvmFunc);
dispatchAttrs, llvmFunc);
}

// Embed source files (if present).
Expand Down
15 changes: 10 additions & 5 deletions compiler/plugins/target/LLVMCPU/LibraryBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,19 +111,22 @@ makeDispatchFunctionType(llvm::LLVMContext &context) {

// %struct.iree_hal_executable_dispatch_attrs_v0_t = type {
// i16,
// i16
// i8,
// i8
// }
static llvm::StructType *makeDispatchAttrsType(llvm::LLVMContext &context) {
if (auto *existingType = llvm::StructType::getTypeByName(
context, "iree_hal_executable_dispatch_attrs_v0_t")) {
return existingType;
}
auto *i8Type = llvm::IntegerType::getInt8Ty(context);
auto *i16Type = llvm::IntegerType::getInt16Ty(context);
auto *type =
llvm::StructType::create(context,
{
i16Type,
i16Type,
i8Type,
i8Type,
},
"iree_hal_executable_dispatch_attrs_v0_t",
/*isPacked=*/false);
Expand Down Expand Up @@ -502,7 +505,7 @@ LibraryBuilder::buildLibraryV0ExportTable(std::string libraryName) {
bool hasNonDefaultAttrs = llvm::any_of(exports, [](const auto &dispatch) {
return !dispatch.attrs.isDefault();
});
if (!hasNonDefaultAttrs) {
if (hasNonDefaultAttrs) {
SmallVector<llvm::Constant *> exportAttrValues;
for (auto dispatch : exports) {
exportAttrValues.push_back(llvm::ConstantStruct::get(
Expand All @@ -513,8 +516,10 @@ LibraryBuilder::buildLibraryV0ExportTable(std::string libraryName) {
i16Type, roundUpToAlignment(dispatch.attrs.localMemorySize,
kWorkgroupLocalMemoryPageSize) /
kWorkgroupLocalMemoryPageSize),
// reserved=
llvm::ConstantInt::get(i16Type, 0),
// constant_count=
llvm::ConstantInt::get(i8Type, dispatch.attrs.constantCount),
// binding_count=
llvm::ConstantInt::get(i8Type, dispatch.attrs.bindingCount),
}));
}
exportAttrs = createArrayConstant(libraryName + "_attrs", dispatchAttrsType,
Expand Down
10 changes: 8 additions & 2 deletions compiler/plugins/target/LLVMCPU/LibraryBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,22 @@ class LibraryBuilder {
UNDEFINED = 4u,
};

// IREE_HAL_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE
// IREE_HAL_EXECUTABLE_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE
static const int64_t kWorkgroupLocalMemoryPageSize = 4096;

// iree_hal_executable_dispatch_attrs_v0_t
struct DispatchAttrs {
// Required workgroup local memory size, in bytes.
int64_t localMemorySize = 0;
// Total number of 32-bit constants used by the dispatch.
uint8_t constantCount = 0;
// Total number of bindings used by the dispatch.
uint8_t bindingCount = 0;

// True if all values are default and the attributes may be omitted.
constexpr bool isDefault() const { return localMemorySize == 0; }
constexpr bool isDefault() const {
return localMemorySize == 0 && constantCount == 0 && bindingCount == 0;
}
};

// iree_hal_executable_source_location_v0_t
Expand Down
22 changes: 21 additions & 1 deletion compiler/plugins/target/VMVX/VMVXTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ class VMVXTargetBackend final : public TargetBackend {
IREE::HAL::ExecutableVariantOp variantOp,
OpBuilder &executableBuilder) override {
// Add reflection information used at runtime specific to the HAL interface.
SymbolTable symbolTable(variantOp.getInnerModule());
auto vmModule =
*variantOp.getInnerModule().getOps<IREE::VM::ModuleOp>().begin();
SymbolTable symbolTable(vmModule);
for (auto exportOp : variantOp.getBlock().getOps<ExecutableExportOp>()) {
auto funcOp = symbolTable.lookup<IREE::VM::FuncOp>(exportOp.getName());

Expand All @@ -127,6 +129,24 @@ class VMVXTargetBackend final : public TargetBackend {
if (localMemorySizeAttr) {
funcOp.setReflectionAttr("local_memory", localMemorySizeAttr);
}

// Specify the constant and binding information used to validate
// dispatches.
// TODO(#18189): pack per-binding information bitfields.
if (auto layoutAttr = exportOp.getLayout()) {
int64_t constantCount = layoutAttr.getPushConstants();
if (constantCount > 0) {
funcOp.setReflectionAttr("constant_count",
executableBuilder.getI8IntegerAttr(
static_cast<uint8_t>(constantCount)));
}
size_t bindingCount = layoutAttr.getSetLayout(0).getBindings().size();
if (bindingCount > 0) {
funcOp.setReflectionAttr("binding_count",
executableBuilder.getI8IntegerAttr(
static_cast<uint8_t>(bindingCount)));
}
}
}

// Serialize the VM module to bytes and embed it directly.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ addSet3IfNeeded(IREE::HAL::PipelineLayoutAttr originalAttr) {
SmallVector<IREE::HAL::DescriptorSetBindingAttr> bindingAttrs;
bindingAttrs.push_back(IREE::HAL::DescriptorSetBindingAttr::get(
originalAttr.getContext(), 0, IREE::HAL::DescriptorType::UniformBuffer,
std::nullopt));
IREE::HAL::DescriptorFlags::None));
setLayoutAttrs.push_back(IREE::HAL::DescriptorSetLayoutAttr::get(
originalAttr.getContext(), 3, bindingAttrs, std::nullopt));
return IREE::HAL::PipelineLayoutAttr::get(originalAttr.getContext(),
Expand Down
99 changes: 28 additions & 71 deletions compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ assumeExportLayout(IREE::HAL::PipelineLayoutAttr layoutAttr) {
DescriptorSetLayoutBinding setBinding;
setBinding.ordinal = bindingAttr.getOrdinal();
setBinding.type = bindingAttr.getType();
setBinding.flags =
bindingAttr.getFlags().value_or(IREE::HAL::DescriptorFlags::None);
setBinding.flags = bindingAttr.getFlags();
setLayout.bindings[setBinding.ordinal] = setBinding;
pipelineLayout.resourceMap.emplace_back(setLayout.ordinal,
setBinding.ordinal);
Expand Down Expand Up @@ -123,7 +122,6 @@ deriveStreamExportLayout(IREE::Stream::ExecutableExportOp exportOp,

// Check the usage of each binding at each dispatch site.
struct DescriptorInfo {
bool isIndirect = false;
DescriptorFlags flags = DescriptorFlags::None;
};
SmallVector<DescriptorInfo> descriptorInfos(bindingCount);
Expand All @@ -142,12 +140,18 @@ deriveStreamExportLayout(IREE::Stream::ExecutableExportOp exportOp,
// Opt into indirect descriptors when dynamic values are used from
// execution regions that may be executed more than once.
if (!isRegionExecutedOnce) {
auto resource = dispatchOp.getResources()[i];
Value resource = dispatchOp.getResources()[i];
if (auto blockArg = dyn_cast<BlockArgument>(resource)) {
if (blockArg.getOwner()->getParentOp() == parentOp) {
resource = parentOp.getResourceOperands()[blockArg.getArgNumber()];
}
}
switch (categorizeValue(resource)) {
default:
case ValueOrigin::Unknown:
case ValueOrigin::MutableGlobal:
descriptorInfo.isIndirect |= true;
descriptorInfo.flags =
descriptorInfo.flags | IREE::HAL::DescriptorFlags::Indirect;
break;
case ValueOrigin::LocalConstant:
case ValueOrigin::ImmutableGlobal:
Expand All @@ -173,74 +177,27 @@ deriveStreamExportLayout(IREE::Stream::ExecutableExportOp exportOp,
pipelineLayout.pushConstantCount = operandCount;
pipelineLayout.resourceMap.resize(bindingCount);

// Today we use one or two sets based on the composition of bindings we have:
// we try to put everything in a directly referenced set 0 and spill over any
// indirectly referenced values into the second set.
//
// HACK: the Vulkan HAL implementation currently cannot handle multiple
// descriptor sets. Ouch. To preserve existing behavior we only use a single
// set and mark the whole thing as indirect if any bindings are indirect.
const bool forceSingleSet = true;
if (forceSingleSet) {
DescriptorSetLayout setLayout;
setLayout.ordinal = 0;
setLayout.flags = IREE::HAL::DescriptorSetLayoutFlags::None;
setLayout.bindings.reserve(bindingCount);
for (unsigned i = 0; i < bindingCount; ++i) {
const auto &descriptorInfo = descriptorInfos[i];
if (descriptorInfo.isIndirect) {
setLayout.flags =
setLayout.flags | IREE::HAL::DescriptorSetLayoutFlags::Indirect;
}
DescriptorSetLayoutBinding setBinding;
setBinding.ordinal = setLayout.bindings.size();
setBinding.type = IREE::HAL::DescriptorType::StorageBuffer;
setBinding.flags = descriptorInfo.flags;
setLayout.bindings.push_back(setBinding);
pipelineLayout.resourceMap[i] =
std::make_pair(setLayout.ordinal, setBinding.ordinal);
}
pipelineLayout.setLayouts.push_back(setLayout);
} else {
DescriptorSetLayout directSetLayout;
directSetLayout.flags = IREE::HAL::DescriptorSetLayoutFlags::None;
directSetLayout.bindings.reserve(bindingCount);
DescriptorSetLayout indirectSetLayout;
indirectSetLayout.flags = IREE::HAL::DescriptorSetLayoutFlags::Indirect;
indirectSetLayout.bindings.reserve(bindingCount);

// Ordinals relative to the owning set.
SmallVector<unsigned> bindingSetOrdinals(bindingCount);
for (unsigned i = 0; i < bindingCount; ++i) {
const auto &descriptorInfo = descriptorInfos[i];
auto &setLayout =
descriptorInfo.isIndirect ? indirectSetLayout : directSetLayout;
DescriptorSetLayoutBinding setBinding;
setBinding.ordinal = setLayout.bindings.size();
setBinding.type = IREE::HAL::DescriptorType::StorageBuffer;
setBinding.flags = descriptorInfo.flags;
setLayout.bindings.push_back(setBinding);
bindingSetOrdinals[i] = setBinding.ordinal;
}
unsigned nextSetOrdinal = 0;
if (!directSetLayout.bindings.empty()) {
directSetLayout.ordinal = nextSetOrdinal++;
pipelineLayout.setLayouts.push_back(directSetLayout);
}
if (!indirectSetLayout.bindings.empty()) {
indirectSetLayout.ordinal = nextSetOrdinal++;
pipelineLayout.setLayouts.push_back(indirectSetLayout);
}

// Map each resource to its set/binding ordinals.
for (unsigned i = 0; i < bindingCount; ++i) {
const auto &descriptorInfo = descriptorInfos[i];
auto &setLayout =
descriptorInfo.isIndirect ? indirectSetLayout : directSetLayout;
pipelineLayout.resourceMap[i] =
std::make_pair(setLayout.ordinal, bindingSetOrdinals[i]);
// TODO(#18154): simplify binding setup.
DescriptorSetLayout setLayout;
setLayout.ordinal = 0;
setLayout.flags = IREE::HAL::DescriptorSetLayoutFlags::None;
setLayout.bindings.reserve(bindingCount);
for (unsigned i = 0; i < bindingCount; ++i) {
const auto &descriptorInfo = descriptorInfos[i];
if (allEnumBitsSet(descriptorInfo.flags,
IREE::HAL::DescriptorFlags::Indirect)) {
setLayout.flags =
setLayout.flags | IREE::HAL::DescriptorSetLayoutFlags::Indirect;
}
DescriptorSetLayoutBinding setBinding;
setBinding.ordinal = setLayout.bindings.size();
setBinding.type = IREE::HAL::DescriptorType::StorageBuffer;
setBinding.flags = descriptorInfo.flags;
setLayout.bindings.push_back(setBinding);
pipelineLayout.resourceMap[i] =
std::make_pair(setLayout.ordinal, setBinding.ordinal);
}
pipelineLayout.setLayouts.push_back(setLayout);

LLVM_DEBUG({
auto executableOp = exportOp->getParentOfType<IREE::Stream::ExecutableOp>();
Expand Down
10 changes: 0 additions & 10 deletions compiler/src/iree/compiler/Dialect/HAL/Analysis/Captures.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,6 @@
namespace mlir::iree_compiler::IREE::HAL {

ValueOrigin categorizeValue(Value value) {
// If this is a captured argument of an execution region then look up to the
// SSA value that was captured.
if (auto blockArg = dyn_cast<BlockArgument>(value)) {
if (auto closureOp = dyn_cast<IREE::Util::ClosureOpInterface>(
blockArg.getOwner()->getParentOp())) {
return categorizeValue(
closureOp.getClosureOperands()[blockArg.getArgNumber()]);
}
}

// If we wanted to pull in entire IR slices this would have to use a
// worklist (selects of globals based on globals, etc). For now this analysis
// only looks at the value provided.
Expand Down
Loading

0 comments on commit 8dc6820

Please sign in to comment.