diff --git a/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp b/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp index 912b7a844984..3a6337e766a5 100644 --- a/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp +++ b/compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp @@ -387,12 +387,21 @@ class LLVMCPUTargetBackend final : public TargetBackend { llvmFunc->addParamAttr(i, align16); } - // Optionally entry points may specify that they require workgroup local + LibraryBuilder::DispatchAttrs dispatchAttrs = {0}; + + // Entry points may optionally specify that they require workgroup local // memory. We fetch that value here and plumb it through so the runtime // knows how much memory to reserve and pass in. - int64_t localMemorySize = exportOp.getWorkgroupLocalMemory() - .value_or(APInt(64, 0)) - .getSExtValue(); + dispatchAttrs.localMemorySize = exportOp.getWorkgroupLocalMemory() + .value_or(APInt(64, 0)) + .getSExtValue(); + + // Specify the constant and binding information used to validate + // dispatches. + // TODO(#18189): pack per-binding information bitfields. + dispatchAttrs.constantCount = exportOp.getLayout().getPushConstants(); + dispatchAttrs.bindingCount = + exportOp.getLayout().getSetLayout(0).getBindings().size(); LibraryBuilder::SourceLocation sourceLocation; if (options.debugLevel >= 1) { @@ -417,8 +426,7 @@ class LLVMCPUTargetBackend final : public TargetBackend { } libraryBuilder.addExport(exportOp.getName(), std::move(sourceLocation), std::move(stageLocations), /*tag=*/"", - LibraryBuilder::DispatchAttrs{localMemorySize}, - llvmFunc); + dispatchAttrs, llvmFunc); } // Embed source files (if present). 
diff --git a/compiler/plugins/target/LLVMCPU/LibraryBuilder.cpp b/compiler/plugins/target/LLVMCPU/LibraryBuilder.cpp index 3c39849d8c54..21621b9b39ec 100644 --- a/compiler/plugins/target/LLVMCPU/LibraryBuilder.cpp +++ b/compiler/plugins/target/LLVMCPU/LibraryBuilder.cpp @@ -111,19 +111,22 @@ makeDispatchFunctionType(llvm::LLVMContext &context) { // %struct.iree_hal_executable_dispatch_attrs_v0_t = type { // i16, -// i16 +// i8, +// i8 // } static llvm::StructType *makeDispatchAttrsType(llvm::LLVMContext &context) { if (auto *existingType = llvm::StructType::getTypeByName( context, "iree_hal_executable_dispatch_attrs_v0_t")) { return existingType; } + auto *i8Type = llvm::IntegerType::getInt8Ty(context); auto *i16Type = llvm::IntegerType::getInt16Ty(context); auto *type = llvm::StructType::create(context, { i16Type, - i16Type, + i8Type, + i8Type, }, "iree_hal_executable_dispatch_attrs_v0_t", /*isPacked=*/false); @@ -502,7 +505,7 @@ LibraryBuilder::buildLibraryV0ExportTable(std::string libraryName) { bool hasNonDefaultAttrs = llvm::any_of(exports, [](const auto &dispatch) { return !dispatch.attrs.isDefault(); }); - if (!hasNonDefaultAttrs) { + if (hasNonDefaultAttrs) { SmallVector exportAttrValues; for (auto dispatch : exports) { exportAttrValues.push_back(llvm::ConstantStruct::get( @@ -513,8 +516,10 @@ LibraryBuilder::buildLibraryV0ExportTable(std::string libraryName) { i16Type, roundUpToAlignment(dispatch.attrs.localMemorySize, kWorkgroupLocalMemoryPageSize) / kWorkgroupLocalMemoryPageSize), - // reserved= - llvm::ConstantInt::get(i16Type, 0), + // constant_count= + llvm::ConstantInt::get(i8Type, dispatch.attrs.constantCount), + // binding_count= + llvm::ConstantInt::get(i8Type, dispatch.attrs.bindingCount), })); } exportAttrs = createArrayConstant(libraryName + "_attrs", dispatchAttrsType, diff --git a/compiler/plugins/target/LLVMCPU/LibraryBuilder.h b/compiler/plugins/target/LLVMCPU/LibraryBuilder.h index fd3416b7e73b..6b1ee87d313a 100644 --- 
a/compiler/plugins/target/LLVMCPU/LibraryBuilder.h +++ b/compiler/plugins/target/LLVMCPU/LibraryBuilder.h @@ -74,16 +74,22 @@ class LibraryBuilder { UNDEFINED = 4u, }; - // IREE_HAL_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE + // IREE_HAL_EXECUTABLE_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE static const int64_t kWorkgroupLocalMemoryPageSize = 4096; // iree_hal_executable_dispatch_attrs_v0_t struct DispatchAttrs { // Required workgroup local memory size, in bytes. int64_t localMemorySize = 0; + // Total number of 32-bit constants used by the dispatch. + uint8_t constantCount = 0; + // Total number of bindings used by the dispatch. + uint8_t bindingCount = 0; // True if all values are default and the attributes may be omitted. - constexpr bool isDefault() const { return localMemorySize == 0; } + constexpr bool isDefault() const { + return localMemorySize == 0 && constantCount == 0 && bindingCount == 0; + } }; // iree_hal_executable_source_location_v0_t diff --git a/compiler/plugins/target/VMVX/VMVXTarget.cpp b/compiler/plugins/target/VMVX/VMVXTarget.cpp index b87844df2363..831eb8cd66a4 100644 --- a/compiler/plugins/target/VMVX/VMVXTarget.cpp +++ b/compiler/plugins/target/VMVX/VMVXTarget.cpp @@ -116,7 +116,9 @@ class VMVXTargetBackend final : public TargetBackend { IREE::HAL::ExecutableVariantOp variantOp, OpBuilder &executableBuilder) override { // Add reflection information used at runtime specific to the HAL interface. - SymbolTable symbolTable(variantOp.getInnerModule()); + auto vmModule = + *variantOp.getInnerModule().getOps().begin(); + SymbolTable symbolTable(vmModule); for (auto exportOp : variantOp.getBlock().getOps()) { auto funcOp = symbolTable.lookup(exportOp.getName()); @@ -127,6 +129,24 @@ class VMVXTargetBackend final : public TargetBackend { if (localMemorySizeAttr) { funcOp.setReflectionAttr("local_memory", localMemorySizeAttr); } + + // Specify the constant and binding information used to validate + // dispatches. 
+ // TODO(#18189): pack per-binding information bitfields. + if (auto layoutAttr = exportOp.getLayout()) { + int64_t constantCount = layoutAttr.getPushConstants(); + if (constantCount > 0) { + funcOp.setReflectionAttr("constant_count", + executableBuilder.getI8IntegerAttr( + static_cast(constantCount))); + } + size_t bindingCount = layoutAttr.getSetLayout(0).getBindings().size(); + if (bindingCount > 0) { + funcOp.setReflectionAttr("binding_count", + executableBuilder.getI8IntegerAttr( + static_cast(bindingCount))); + } + } } // Serialize the VM module to bytes and embed it directly. diff --git a/compiler/src/iree/compiler/Codegen/WGSL/WGSLReplacePushConstants.cpp b/compiler/src/iree/compiler/Codegen/WGSL/WGSLReplacePushConstants.cpp index 495a1a592976..9b418c321746 100644 --- a/compiler/src/iree/compiler/Codegen/WGSL/WGSLReplacePushConstants.cpp +++ b/compiler/src/iree/compiler/Codegen/WGSL/WGSLReplacePushConstants.cpp @@ -98,7 +98,7 @@ addSet3IfNeeded(IREE::HAL::PipelineLayoutAttr originalAttr) { SmallVector bindingAttrs; bindingAttrs.push_back(IREE::HAL::DescriptorSetBindingAttr::get( originalAttr.getContext(), 0, IREE::HAL::DescriptorType::UniformBuffer, - std::nullopt)); + IREE::HAL::DescriptorFlags::None)); setLayoutAttrs.push_back(IREE::HAL::DescriptorSetLayoutAttr::get( originalAttr.getContext(), 3, bindingAttrs, std::nullopt)); return IREE::HAL::PipelineLayoutAttr::get(originalAttr.getContext(), diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp index caf47bd1d891..39b5ec8abb64 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/BindingLayout.cpp @@ -62,8 +62,7 @@ assumeExportLayout(IREE::HAL::PipelineLayoutAttr layoutAttr) { DescriptorSetLayoutBinding setBinding; setBinding.ordinal = bindingAttr.getOrdinal(); setBinding.type = bindingAttr.getType(); - setBinding.flags = - 
bindingAttr.getFlags().value_or(IREE::HAL::DescriptorFlags::None); + setBinding.flags = bindingAttr.getFlags(); setLayout.bindings[setBinding.ordinal] = setBinding; pipelineLayout.resourceMap.emplace_back(setLayout.ordinal, setBinding.ordinal); @@ -123,7 +122,6 @@ deriveStreamExportLayout(IREE::Stream::ExecutableExportOp exportOp, // Check the usage of each binding at each dispatch site. struct DescriptorInfo { - bool isIndirect = false; DescriptorFlags flags = DescriptorFlags::None; }; SmallVector descriptorInfos(bindingCount); @@ -142,12 +140,18 @@ deriveStreamExportLayout(IREE::Stream::ExecutableExportOp exportOp, // Opt into indirect descriptors when dynamic values are used from // execution regions that may be executed more than once. if (!isRegionExecutedOnce) { - auto resource = dispatchOp.getResources()[i]; + Value resource = dispatchOp.getResources()[i]; + if (auto blockArg = dyn_cast(resource)) { + if (blockArg.getOwner()->getParentOp() == parentOp) { + resource = parentOp.getResourceOperands()[blockArg.getArgNumber()]; + } + } switch (categorizeValue(resource)) { default: case ValueOrigin::Unknown: case ValueOrigin::MutableGlobal: - descriptorInfo.isIndirect |= true; + descriptorInfo.flags = + descriptorInfo.flags | IREE::HAL::DescriptorFlags::Indirect; break; case ValueOrigin::LocalConstant: case ValueOrigin::ImmutableGlobal: @@ -173,74 +177,27 @@ deriveStreamExportLayout(IREE::Stream::ExecutableExportOp exportOp, pipelineLayout.pushConstantCount = operandCount; pipelineLayout.resourceMap.resize(bindingCount); - // Today we use one or two sets based on the composition of bindings we have: - // we try to put everything in a directly referenced set 0 and spill over any - // indirectly referenced values into the second set. - // - // HACK: the Vulkan HAL implementation currently cannot handle multiple - // descriptor sets. Ouch. To preserve existing behavior we only use a single - // set and mark the whole thing as indirect if any bindings are indirect. 
- const bool forceSingleSet = true; - if (forceSingleSet) { - DescriptorSetLayout setLayout; - setLayout.ordinal = 0; - setLayout.flags = IREE::HAL::DescriptorSetLayoutFlags::None; - setLayout.bindings.reserve(bindingCount); - for (unsigned i = 0; i < bindingCount; ++i) { - const auto &descriptorInfo = descriptorInfos[i]; - if (descriptorInfo.isIndirect) { - setLayout.flags = - setLayout.flags | IREE::HAL::DescriptorSetLayoutFlags::Indirect; - } - DescriptorSetLayoutBinding setBinding; - setBinding.ordinal = setLayout.bindings.size(); - setBinding.type = IREE::HAL::DescriptorType::StorageBuffer; - setBinding.flags = descriptorInfo.flags; - setLayout.bindings.push_back(setBinding); - pipelineLayout.resourceMap[i] = - std::make_pair(setLayout.ordinal, setBinding.ordinal); - } - pipelineLayout.setLayouts.push_back(setLayout); - } else { - DescriptorSetLayout directSetLayout; - directSetLayout.flags = IREE::HAL::DescriptorSetLayoutFlags::None; - directSetLayout.bindings.reserve(bindingCount); - DescriptorSetLayout indirectSetLayout; - indirectSetLayout.flags = IREE::HAL::DescriptorSetLayoutFlags::Indirect; - indirectSetLayout.bindings.reserve(bindingCount); - - // Ordinals relative to the owning set. - SmallVector bindingSetOrdinals(bindingCount); - for (unsigned i = 0; i < bindingCount; ++i) { - const auto &descriptorInfo = descriptorInfos[i]; - auto &setLayout = - descriptorInfo.isIndirect ? 
indirectSetLayout : directSetLayout; - DescriptorSetLayoutBinding setBinding; - setBinding.ordinal = setLayout.bindings.size(); - setBinding.type = IREE::HAL::DescriptorType::StorageBuffer; - setBinding.flags = descriptorInfo.flags; - setLayout.bindings.push_back(setBinding); - bindingSetOrdinals[i] = setBinding.ordinal; - } - unsigned nextSetOrdinal = 0; - if (!directSetLayout.bindings.empty()) { - directSetLayout.ordinal = nextSetOrdinal++; - pipelineLayout.setLayouts.push_back(directSetLayout); - } - if (!indirectSetLayout.bindings.empty()) { - indirectSetLayout.ordinal = nextSetOrdinal++; - pipelineLayout.setLayouts.push_back(indirectSetLayout); - } - - // Map each resource to its set/binding ordinals. - for (unsigned i = 0; i < bindingCount; ++i) { - const auto &descriptorInfo = descriptorInfos[i]; - auto &setLayout = - descriptorInfo.isIndirect ? indirectSetLayout : directSetLayout; - pipelineLayout.resourceMap[i] = - std::make_pair(setLayout.ordinal, bindingSetOrdinals[i]); + // TODO(#18154): simplify binding setup. 
+ DescriptorSetLayout setLayout; + setLayout.ordinal = 0; + setLayout.flags = IREE::HAL::DescriptorSetLayoutFlags::None; + setLayout.bindings.reserve(bindingCount); + for (unsigned i = 0; i < bindingCount; ++i) { + const auto &descriptorInfo = descriptorInfos[i]; + if (allEnumBitsSet(descriptorInfo.flags, + IREE::HAL::DescriptorFlags::Indirect)) { + setLayout.flags = + setLayout.flags | IREE::HAL::DescriptorSetLayoutFlags::Indirect; } + DescriptorSetLayoutBinding setBinding; + setBinding.ordinal = setLayout.bindings.size(); + setBinding.type = IREE::HAL::DescriptorType::StorageBuffer; + setBinding.flags = descriptorInfo.flags; + setLayout.bindings.push_back(setBinding); + pipelineLayout.resourceMap[i] = + std::make_pair(setLayout.ordinal, setBinding.ordinal); } + pipelineLayout.setLayouts.push_back(setLayout); LLVM_DEBUG({ auto executableOp = exportOp->getParentOfType(); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Analysis/Captures.cpp b/compiler/src/iree/compiler/Dialect/HAL/Analysis/Captures.cpp index 07bb527e264a..16f1aeede6bd 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Analysis/Captures.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Analysis/Captures.cpp @@ -12,16 +12,6 @@ namespace mlir::iree_compiler::IREE::HAL { ValueOrigin categorizeValue(Value value) { - // If this is a captured argument of an execution region then look up to the - // SSA value that was captured. - if (auto blockArg = dyn_cast(value)) { - if (auto closureOp = dyn_cast( - blockArg.getOwner()->getParentOp())) { - return categorizeValue( - closureOp.getClosureOperands()[blockArg.getArgNumber()]); - } - } - // If we wanted to pull in entire IR slices this would have to use a // worklist (selects of globals based on globals, etc). For now this analysis // only looks at the value provided. 
diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp index cb4179f60ed2..72716cb0f175 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp @@ -425,6 +425,162 @@ class CommandBufferDispatchIndirectOpConversion mutable IREE::VM::ImportOp importOp; }; +class CommandBufferDispatch2OpConversion + : public OpConversionPattern { +public: + CommandBufferDispatch2OpConversion(MLIRContext *context, + SymbolTable &importSymbols, + TypeConverter &typeConverter, + StringRef importName) + : OpConversionPattern(typeConverter, context) { + importOp = importSymbols.lookup(importName); + assert(importOp); + } + + LogicalResult + matchAndRewrite(IREE::HAL::CommandBufferDispatch2Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto importType = importOp.getFunctionType(); + + auto i32Type = rewriter.getI32Type(); + auto i64Type = rewriter.getI64Type(); + Value zeroI32 = rewriter.create(op.getLoc()); + + auto flags = adaptor.getFlagsAttr() + ? 
rewriter + .create( + op.getLoc(), adaptor.getFlagsAttr().getInt()) + .getResult() + : rewriter.create(op.getLoc()) + .getResult(); + SmallVector callOperands = { + adaptor.getCommandBuffer(), + adaptor.getExecutable(), + castToImportType(adaptor.getEntryPoint(), i32Type, rewriter), + castToImportType(adaptor.getWorkgroupX(), i32Type, rewriter), + castToImportType(adaptor.getWorkgroupY(), i32Type, rewriter), + castToImportType(adaptor.getWorkgroupZ(), i32Type, rewriter), + flags, + }; + SmallVector segmentSizes = { + /*command_buffer=*/-1, + /*executable=*/-1, + /*entry_point=*/-1, + /*workgroup_x=*/-1, + /*workgroup_y=*/-1, + /*workgroup_z=*/-1, + /*flags=*/-1, + /*constants=*/static_cast(adaptor.getConstants().size()), + /*bindings=*/ + static_cast(adaptor.getBindingBuffers().size()), + }; + llvm::append_range(callOperands, adaptor.getConstants()); + for (auto [bindingBufferOrSlot, bindingOffset, bindingLength] : + llvm::zip_equal(adaptor.getBindingBuffers(), + adaptor.getBindingOffsets(), + adaptor.getBindingLengths())) { + callOperands.push_back(zeroI32); + auto [bindingBufferSlot, bindingBuffer] = + splitBufferSlot(op.getLoc(), bindingBufferOrSlot, rewriter); + callOperands.push_back(bindingBufferSlot); + callOperands.push_back(bindingBuffer); + callOperands.push_back( + castToImportType(bindingOffset, i64Type, rewriter)); + callOperands.push_back( + castToImportType(bindingLength, i64Type, rewriter)); + } + + auto callOp = rewriter.replaceOpWithNewOp( + op, SymbolRefAttr::get(importOp), importType.getResults(), segmentSizes, + importType.getInputs(), callOperands); + copyImportAttrs(importOp, callOp); + return success(); + } + +private: + mutable IREE::VM::ImportOp importOp; +}; + +class CommandBufferDispatch2IndirectOpConversion + : public OpConversionPattern { +public: + CommandBufferDispatch2IndirectOpConversion(MLIRContext *context, + SymbolTable &importSymbols, + TypeConverter &typeConverter, + StringRef importName) + : OpConversionPattern(typeConverter, 
context) { + importOp = importSymbols.lookup(importName); + assert(importOp); + } + + LogicalResult + matchAndRewrite(IREE::HAL::CommandBufferDispatch2IndirectOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto importType = importOp.getFunctionType(); + + auto i32Type = rewriter.getI32Type(); + auto i64Type = rewriter.getI64Type(); + Value zeroI32 = rewriter.create(op.getLoc()); + + auto [workgroupsBufferSlot, workgroupsBuffer] = + splitBufferSlot(op.getLoc(), adaptor.getWorkgroupsBuffer(), rewriter); + auto flags = adaptor.getFlagsAttr() + ? rewriter + .create( + op.getLoc(), adaptor.getFlagsAttr().getInt()) + .getResult() + : rewriter.create(op.getLoc()) + .getResult(); + SmallVector callOperands = { + adaptor.getCommandBuffer(), + adaptor.getExecutable(), + castToImportType(adaptor.getEntryPoint(), i32Type, rewriter), + workgroupsBufferSlot, + workgroupsBuffer, + castToImportType(adaptor.getWorkgroupsOffset(), i64Type, rewriter), + flags, + }; + SmallVector segmentSizes = { + /*command_buffer=*/-1, + /*executable=*/-1, + /*entry_point=*/-1, + /*workgroups_buffer_slot=*/-1, + /*workgroups_buffer=*/-1, + /*workgroups_offset=*/-1, + /*flags=*/-1, + /*constants=*/static_cast(adaptor.getConstants().size()), + /*bindings=*/ + static_cast(adaptor.getBindingBuffers().size()), + }; + llvm::append_range(callOperands, adaptor.getConstants()); + for (auto [bindingBufferOrSlot, bindingOffset, bindingLength] : + llvm::zip_equal(adaptor.getBindingBuffers(), + adaptor.getBindingOffsets(), + adaptor.getBindingLengths())) { + callOperands.push_back(zeroI32); + auto [bindingBufferSlot, bindingBuffer] = + splitBufferSlot(op.getLoc(), bindingBufferOrSlot, rewriter); + callOperands.push_back(bindingBufferSlot); + callOperands.push_back(bindingBuffer); + callOperands.push_back( + castToImportType(bindingOffset, i64Type, rewriter)); + callOperands.push_back( + castToImportType(bindingLength, i64Type, rewriter)); + } + + auto callOp = 
rewriter.replaceOpWithNewOp( + op, SymbolRefAttr::get(importOp), importType.getResults(), segmentSizes, + importType.getInputs(), callOperands); + copyImportAttrs(importOp, callOp); + return success(); + } + +private: + mutable IREE::VM::ImportOp importOp; +}; + } // namespace void populateHALCommandBufferToVMPatterns(MLIRContext *context, @@ -468,6 +624,11 @@ void populateHALCommandBufferToVMPatterns(MLIRContext *context, patterns.insert( context, importSymbols, typeConverter, "hal.command_buffer.dispatch.indirect"); + patterns.insert( + context, importSymbols, typeConverter, "hal.command_buffer.dispatch2"); + patterns.insert( + context, importSymbols, typeConverter, + "hal.command_buffer.dispatch2.indirect"); } } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp index a911de8b2830..7b0372d42266 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertExecutableOps.cpp @@ -150,6 +150,62 @@ class ExecutableCreateOpConversion mutable IREE::VM::ImportOp importOp; }; +class ExecutableCreate2OpConversion + : public OpConversionPattern { +public: + ExecutableCreate2OpConversion(MLIRContext *context, + SymbolTable &importSymbols, + TypeConverter &typeConverter, + StringRef importName) + : OpConversionPattern(context) { + importOp = importSymbols.lookup(importName); + assert(importOp); + } + + LogicalResult + matchAndRewrite(IREE::HAL::ExecutableCreate2Op createOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Materialize vm.rodata for the binary. 
+ auto executableBinaryOp = + SymbolTable::lookupNearestSymbolFrom( + createOp, createOp.getExecutableTarget()); + auto executableOp = executableBinaryOp.getOperation() + ->getParentOfType(); + std::string rodataName = sanitizeSymbolName( + (executableOp.getName() + "_" + executableBinaryOp.getName()).str()); + auto rodataOp = rewriter.create( + executableBinaryOp.getLoc(), + IREE::VM::RefType::get(rewriter.getType()), + rewriter.getStringAttr(rodataName), executableBinaryOp.getData(), + rewriter.getI64IntegerAttr(16), executableBinaryOp.getMimeTypeAttr()); + + // Get format string as a rodata blob. + auto executableFormatStr = rewriter.create( + createOp.getLoc(), executableBinaryOp.getFormatAttr()); + + // Pack constants, if any. + auto constantBuffer = createPackedConstantBuffer( + createOp.getLoc(), adaptor.getConstants(), rewriter); + + SmallVector callOperands = { + adaptor.getDevice(), + executableFormatStr, + rodataOp, + constantBuffer, + }; + auto importType = importOp.getFunctionType(); + auto callOp = rewriter.replaceOpWithNewOp( + createOp, SymbolRefAttr::get(importOp), importType.getResults(), + callOperands); + copyImportAttrs(importOp, callOp); + + return success(); + } + +private: + mutable IREE::VM::ImportOp importOp; +}; + } // namespace void populateHALExecutableToVMPatterns(MLIRContext *context, @@ -162,6 +218,8 @@ void populateHALExecutableToVMPatterns(MLIRContext *context, patterns.insert( context, importSymbols, typeConverter, "hal.executable.create"); + patterns.insert( + context, importSymbols, typeConverter, "hal.executable.create2"); patterns.insert>( context, importSymbols, typeConverter, diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir index 2df69590c1c8..f005985a4490 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir +++ 
b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir @@ -395,3 +395,103 @@ util.func public @command_buffer_dispatch_indirect_indirect( flags(None) util.return } + +// ----- + +// CHECK-LABEL: @command_buffer_dispatch2 +// CHECK-SAME: (%[[CMD:.+]]: !vm.ref, +// CHECK-SAME: %[[EXECUTABLE:.+]]: !vm.ref, +// CHECK-SAME: %[[BUFFER:.+]]: !vm.ref, +// CHECK-SAME: %[[SLOT:.+]]: i32) +util.func public @command_buffer_dispatch2( + %cmd: !hal.command_buffer, + %executable: !hal.executable, + %buffer: !hal.buffer, + %slot: index +) { + // CHECK-DAG: %[[ORDINAL:.+]] = vm.const.i32 123 + // CHECK-DAG: %[[C0:.+]] = vm.const.i32.zero + %ordinal = arith.constant 123 : index + // CHECK-DAG: %[[X:.+]] = vm.const.i32 100 + %x = arith.constant 100 : index + // CHECK-DAG: %[[Y:.+]] = vm.const.i32 200 + %y = arith.constant 200 : index + // CHECK-DAG: %[[Z:.+]] = vm.const.i32 300 + %z = arith.constant 300 : index + // CHECK-DAG: %[[CONSTANT0:.+]] = vm.const.i32 31 + %constant0 = arith.constant 31 : i32 + // CHECK-DAG: %[[CONSTANT1:.+]] = vm.const.i32 32 + %constant1 = arith.constant 32 : i32 + %c4 = arith.constant 4 : index + %c4096 = arith.constant 4096 : index + %c8000 = arith.constant 8000 : index + // CHECK-DAG: %[[NULL_BUFFER:.+]] = vm.const.ref.zero : !vm.ref + // CHECK-DAG: %[[FLAGS:.+]] = vm.const.i64.zero + // CHECK: vm.call.variadic @hal.command_buffer.dispatch2 + // CHECK-SAME: %[[CMD]], + // CHECK-SAME: %[[EXECUTABLE]], %[[ORDINAL]], + // CHECK-SAME: %[[X]], %[[Y]], %[[Z]], + // CHECK-SAME: %[[FLAGS]], + // CHECK-SAME: [%[[CONSTANT0]], %[[CONSTANT1]]], + // CHECK-SAME: [(%[[C0]], %[[C0]], %[[BUFFER]], %c4096, %c8000), + // CHECK-SAME: (%[[C0]], %[[SLOT]], %[[NULL_BUFFER]], %c4, %c4096)] + hal.command_buffer.dispatch2<%cmd : !hal.command_buffer> + target(%executable : !hal.executable)[%ordinal] + workgroups([%x, %y, %z]) + constants([%constant0, %constant1]) + bindings([ + (%buffer : !hal.buffer)[%c4096, %c8000], + (%slot : index)[%c4, 
%c4096] + ]) + flags(None) + util.return +} + +// ----- + +// CHECK-LABEL: vm.func private @command_buffer_dispatch2 +// CHECK-SAME: (%[[CMD:[a-z0-9]+]]: !vm.ref, +// CHECK-SAME: %[[EXECUTABLE:[a-z0-9]+]]: !vm.ref, +// CHECK-SAME: %[[WORKGROUPS_SLOT:[a-z0-9]+]]: i32, +// CHECK-SAME: %[[BUFFER:[a-z0-9]+]]: !vm.ref, +// CHECK-SAME: %[[SLOT:[a-z0-9]+]]: i32) +util.func public @command_buffer_dispatch2( + %cmd: !hal.command_buffer, + %executable: !hal.executable, + %workgroups_slot: index, + %buffer: !hal.buffer, + %slot: index +) { + // CHECK-DAG: %[[ORDINAL:.+]] = vm.const.i32 123 + // CHECK-DAG: %[[C0:.+]] = vm.const.i32.zero + %ordinal = arith.constant 123 : index + // CHECK-DAG: %[[WORKGROUPS_OFFSET:.+]] = vm.const.i64 100 + %workgroups_offset = arith.constant 100 : index + // CHECK-DAG: %[[CONSTANT0:.+]] = vm.const.i32 31 + %constant0 = arith.constant 31 : i32 + // CHECK-DAG: %[[CONSTANT1:.+]] = vm.const.i32 32 + %constant1 = arith.constant 32 : i32 + %c4 = arith.constant 4 : index + %c4096 = arith.constant 4096 : index + %c8000 = arith.constant 8000 : index + // CHECK-DAG: %[[NULL_BUFFER:.+]] = vm.const.ref.zero : !vm.ref + // CHECK-DAG: %[[FLAGS:.+]] = vm.const.i64.zero + // CHECK: vm.call.variadic @hal.command_buffer.dispatch2.indirect + // CHECK-SAME: %[[CMD]], + // CHECK-SAME: %[[EXECUTABLE]], %[[ORDINAL]], + // CHECK-SAME: %[[WORKGROUPS_SLOT]], %[[NULL_BUFFER]], %[[WORKGROUPS_OFFSET]], + // CHECK-SAME: %[[FLAGS]], + // CHECK-SAME: [%[[CONSTANT0]], %[[CONSTANT1]]], + // CHECK-SAME: [(%[[C0]], %[[C0]], %[[BUFFER]], %c4096, %c8000), + // CHECK-SAME: (%[[C0]], %[[SLOT]], %[[NULL_BUFFER]], %c4, %c4096)] + hal.command_buffer.dispatch2.indirect<%cmd : !hal.command_buffer> + target(%executable : !hal.executable)[%ordinal] + workgroups(%workgroups_slot : index)[%workgroups_offset] + constants([%constant0, %constant1]) + bindings([ + (%buffer : !hal.buffer)[%c4096, %c8000], + (%slot : index)[%c4, %c4096] + ]) + flags(None) + util.return +} diff --git 
a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/executable_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/executable_ops.mlir index 5dd534142627..292cb47c6278 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/executable_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/executable_ops.mlir @@ -43,6 +43,45 @@ util.func public @executableCreate( // ----- +hal.executable @exe { + hal.executable.binary @binary1 attributes { + data = dense<[0, 1, 2, 3]> : vector<4xi8>, + format = "format1" + } + hal.executable.binary @binary2 attributes { + data = dense<[4, 5, 6, 7]> : vector<4xi8>, + format = "format2" + } +} + +// CHECK-LABEL: @executableCreate2 +util.func public @executableCreate2( + // CHECK-SAME: %[[DEV:.+]]: !vm.ref + %device: !hal.device +) -> (!hal.executable, !hal.executable) { + + // CHECK-DAG: %[[FORMAT1:.+]] = vm.rodata.inline "_utf8_format1_ + // CHECK-DAG: %[[BINARY1:.+]] = vm.rodata.inline "exe_binary1" {alignment = 16 : i64} : !vm.buffer = dense<[0, 1, 2, 3]> : vector<4xi8> + // CHECK-DAG: %[[NULL1:.+]] = vm.const.ref.zero : !vm.buffer + // CHECK: %[[EXE1:.+]] = vm.call @hal.executable.create2( + // CHECK-SAME: %[[DEV]], %[[FORMAT1]], %[[BINARY1]], %[[NULL1]] + // CHECK-SAME: ) {nosideeffects} : (!vm.ref, !vm.buffer, !vm.buffer, !vm.buffer) -> !vm.ref + %0 = hal.executable.create2 device(%device : !hal.device) target(@exe::@binary1) : !hal.executable + + // CHECK-DAG: %[[FORMAT2:.+]] = vm.rodata.inline "_utf8_format2_ + // CHECK-DAG: %[[BINARY2:.+]] = vm.rodata.inline "exe_binary2" {alignment = 16 : i64} : !vm.buffer = dense<[4, 5, 6, 7]> : vector<4xi8> + // CHECK-DAG: %[[NULL2:.+]] = vm.const.ref.zero : !vm.buffer + // CHECK: %[[EXE2:.+]] = vm.call @hal.executable.create2( + // CHECK-SAME: %[[DEV]], %[[FORMAT2]], %[[BINARY2]], %[[NULL2]] + // CHECK-SAME: ) {nosideeffects} : (!vm.ref, !vm.buffer, !vm.buffer, !vm.buffer) -> !vm.ref + %1 = 
hal.executable.create2 device(%device : !hal.device) target(@exe::@binary2) : !hal.executable + + // CHECK: vm.return %[[EXE1]], %[[EXE2]] + util.return %0, %1 : !hal.executable, !hal.executable +} + +// ----- + hal.executable @exe1 { hal.executable.binary @binary1 attributes { data = dense<[0, 1, 2, 3]> : vector<4xi8>, diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp index e21a626e1ac5..109701a1d465 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp @@ -30,6 +30,13 @@ static llvm::cl::opt clIndirectCommandBuffers{ llvm::cl::init(false), }; +// TODO(#18154): switch default to true and then remove. +static llvm::cl::opt clExperimentalDispatch2{ + "iree-hal-experimental-dispatch2", + llvm::cl::desc("Whether to emit iree_hal_command_buffer_dispatch2 ops."), + llvm::cl::init(false), +}; + struct ContextResolveOpPattern : public StreamConversionPattern { using StreamConversionPattern::StreamConversionPattern; @@ -623,8 +630,8 @@ struct CmdCollectiveOpPattern ConversionPatternRewriter &rewriter) const override { auto commandBufferMapping = mapping->lookupCommandBufferFor(op); - IREE::HAL::BindingTableValue sendBinding; - IREE::HAL::BindingTableValue recvBinding; + IREE::HAL::BindingValue sendBinding; + IREE::HAL::BindingValue recvBinding; switch (adaptor.getOp().getKind()) { default: assert(adaptor.getResources().size() == 2 && "should have verified"); @@ -663,6 +670,7 @@ struct CmdCollectiveOpPattern } }; +// TODO(#18154): switch to dispatch2. 
struct CmdDispatchOpPattern : public StreamConversionPattern { using StreamConversionPattern::StreamConversionPattern; @@ -845,6 +853,145 @@ struct CmdDispatchOpPattern } }; +struct CmdDispatch2OpPattern + : public StreamConversionPattern { + using StreamConversionPattern::StreamConversionPattern; + LogicalResult + matchAndRewrite(IREE::Stream::CmdDispatchOp dispatchOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = dispatchOp.getLoc(); + auto commandBufferMapping = mapping->lookupCommandBufferFor(dispatchOp); + + // TODO(multi-device): reusable command buffers done at the stream level may + // make this difficult. For now we assume each stream region being lowered + // has a singular affinity that may itself reference multiple devices in the + // future but currently uniquely identifies a device. + auto affinityAttr = IREE::Stream::AffinityAttr::lookupOrDefault(dispatchOp); + + // Get the device handle we're executing against in this execution region. + // Note that this is a dynamic value: we have to treat the device as unknown + // here. + Value device = rewriter.create( + loc, rewriter.getType(), + commandBufferMapping.getHandle()); + + // Prepare for variant switch table by gathering the conditions selecting + // each variant. + SmallVector caseIndices; + SmallVector> + caseExportOps; + dispatchOp.forEachEntryPointAttr([&](SymbolRefAttr entryPointAttr) { + // NOTE: slow lookup! + auto exportOp = + SymbolTable::lookupNearestSymbolFrom( + dispatchOp, entryPointAttr); + assert(exportOp && "dispatch target export not found"); + caseIndices.push_back(caseIndices.size()); + caseExportOps.push_back(std::make_pair(entryPointAttr, exportOp)); + }); + + // If there is only one variant we can emit that directly without a + // conditional check. The same result should occur later on but it saves + // a lot of IR during generation if we know we can avoid it. 
+ if (caseExportOps.size() == 1) { + auto [entryPointAttr, exportOp] = caseExportOps.front(); + rewriter.replaceOp(dispatchOp, + emitDispatchOp(loc, affinityAttr, device, + commandBufferMapping, exportOp, + entryPointAttr, dispatchOp, adaptor, + rewriter)); + } else { + // Select the variant index. + Value selectedIndex = buildIfElseTree( + loc, caseExportOps.size(), + [&](Location loc, size_t i, OpBuilder &builder) { + auto exportOp = caseExportOps[i].second; + auto variantOp = + exportOp->getParentOfType(); + return variantOp.buildCondition(device, rewriter); + }, + rewriter); + + // Allow each variant to define how it is dispatched. + auto switchOp = rewriter.create( + loc, TypeRange{}, selectedIndex, caseIndices, caseIndices.size()); + for (size_t i = 0; i < caseExportOps.size(); ++i) { + auto [entryPointAttr, exportOp] = caseExportOps[i]; + auto &caseBlock = switchOp.getCaseRegions()[i].emplaceBlock(); + auto caseBuilder = OpBuilder::atBlockBegin(&caseBlock); + emitDispatchOp(loc, affinityAttr, device, commandBufferMapping, + exportOp, entryPointAttr, dispatchOp, adaptor, + caseBuilder); + caseBuilder.create(loc); + } + + // Fallback for no available variant. Today we just no-op as executable + // loading should have already failed. 
+ auto &defaultBlock = switchOp.getDefaultRegion().emplaceBlock(); + auto defaultBuilder = OpBuilder::atBlockBegin(&defaultBlock); + defaultBuilder.create(loc); + + rewriter.replaceOp(dispatchOp, switchOp); + } + + return success(); + } + + Operation *emitDispatchOp( + Location loc, IREE::Stream::AffinityAttr affinityAttr, Value device, + CommandBufferConversionMapping &commandBufferMapping, + IREE::HAL::ExecutableExportOp exportOp, SymbolRefAttr entryPointAttr, + IREE::Stream::CmdDispatchOp dispatchOp, OpAdaptor adaptor, + OpBuilder &builder) const { + auto workgroupCount = exportOp.calculateWorkgroupCount( + loc, device, adaptor.getWorkload(), builder); + + Value executable = builder.create( + loc, builder.getType(), device, + entryPointAttr.getRootReference().getValue()); + Value ordinal = builder.create( + loc, builder.getIndexType(), entryPointAttr); + + // TODO(#18154): simplify bindings by removing descriptor sets. + auto layoutAttr = exportOp.getLayout(); + auto bindingAttrs = IREE::HAL::getInterfaceBindingAttrs( + exportOp, dispatchOp.getResources().size()); + SmallVector bindings; + for (auto [i, bindingAttr] : llvm::enumerate(bindingAttrs)) { + auto descriptorFlags = layoutAttr.getSetLayout(bindingAttr.getSet()) + .getBinding(i) + .getFlags(); + IREE::HAL::BindingValue binding; + if (bitEnumContainsAll(descriptorFlags, + IREE::HAL::DescriptorFlags::Indirect)) { + // Indirect binding resolved through the cached command buffer binding + // table. The buffer recorded in the descriptor is a slot ordinal into + // the binding table. Note that the range may be adjusted based on the + // range bound to the slot in the table. 
+ auto resolvedBinding = commandBufferMapping.resolveBinding( + loc, dispatchOp.getResources()[i], adaptor.getResources()[i], + adaptor.getResourceOffsets()[i], adaptor.getResourceLengths()[i], + builder); + binding.buffer = resolvedBinding.buffer; + binding.byteOffset = resolvedBinding.byteOffset; + binding.byteLength = resolvedBinding.byteLength; + } else { + // Direct binding referencing the buffer and range provided on the op. + binding.buffer = adaptor.getResources()[i]; + binding.byteOffset = adaptor.getResourceOffsets()[i]; + binding.byteLength = adaptor.getResourceLengths()[i]; + } + bindings.push_back(binding); + } + + auto flags = IREE::HAL::DispatchFlags::None; + + return builder.create( + loc, commandBufferMapping.getHandle(), executable, ordinal, + workgroupCount, adaptor.getUniformOperands(), bindings, flags); + } +}; + struct CmdFuncOpPattern : public StreamConversionPattern { using StreamConversionPattern::StreamConversionPattern; @@ -1408,9 +1555,15 @@ void populateStreamToHALPatterns(MLIRContext *context, patterns .insert( + CmdFuncOpPattern, CmdCallOpPattern, CmdExecuteOpPattern, + CmdSerialOpPattern, CmdConcurrentOpPattern>( mapping, typeConverter, context); + // TODO(#18154): drop existing pattern. + if (clExperimentalDispatch2) { + patterns.insert(mapping, typeConverter, context); + } else { + patterns.insert(mapping, typeConverter, context); + } patterns.insert BindingTable::lookupResourceSlot(Value resourceValue) { return std::nullopt; } -IREE::HAL::BindingTableValue CommandBufferConversionMapping::resolveBinding( +IREE::HAL::BindingValue CommandBufferConversionMapping::resolveBinding( Location loc, Value resourceValue, Value bufferValue, Value useOffset, Value useLength, OpBuilder &builder) { - IREE::HAL::BindingTableValue bindingTableValue; + IREE::HAL::BindingValue bindingValue; // Try to resolve the resource to a slot. If not found then it's a direct // reference and we use the buffer provided. 
auto slot = bindingTable.lookupResourceSlot(resourceValue); if (slot.has_value()) { - bindingTableValue.buffer = slot.value(); + bindingValue.buffer = slot.value(); } else { - bindingTableValue.buffer = bufferValue; + bindingValue.buffer = bufferValue; } // TODO(benvanik): adjust range by the binding table base index. Today all // binding table entries are the full buffers starting at zero. - bindingTableValue.byteOffset = useOffset; - bindingTableValue.byteLength = useLength; + bindingValue.byteOffset = useOffset; + bindingValue.byteLength = useLength; - return bindingTableValue; + return bindingValue; } void StreamConversionMapping::mapCommandBuffer( diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.h b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.h index 925220d3c19f..50f0b0d2da1e 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.h +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Utils.h @@ -84,7 +84,7 @@ class BindingTable { size_t size() const { return indirectBuffers.size(); } // Builds a binding table (buffer, offset, length) based on the analysis. - ArrayRef getValues() { return indirectBuffers; } + ArrayRef getValues() { return indirectBuffers; } // Returns the binding table slot for the given resource, if it's used // indirectly. @@ -94,7 +94,7 @@ class BindingTable { // True if any ops are nested that may prevent binding table usage. bool hasUnsupportedOps = false; // Buffer binding table with . - SmallVector indirectBuffers; + SmallVector indirectBuffers; // A mapping of resources to binding table slot ordinals. DenseMap indirectSlots; }; @@ -111,10 +111,9 @@ class CommandBufferConversionMapping { // The returned range may differ from the provided used range in cases where // an indirect binding table reference may have already factored in the // offset. 
- IREE::HAL::BindingTableValue resolveBinding(Location loc, Value resourceValue, - Value bufferValue, - Value useOffset, Value useLength, - OpBuilder &builder); + IREE::HAL::BindingValue resolveBinding(Location loc, Value resourceValue, + Value bufferValue, Value useOffset, + Value useLength, OpBuilder &builder); private: Value handle; diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/BUILD.bazel index 2d6f7779493a..12dc5203254f 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/BUILD.bazel @@ -17,6 +17,7 @@ iree_lit_test_suite( srcs = enforce_glob( [ "channel_ops.mlir", + "cmd_dispatch2_ops.mlir", "cmd_ops.mlir", "context_ops.mlir", "debug_ops.mlir", diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/CMakeLists.txt index b273190ba86a..0aeea90e8242 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/CMakeLists.txt @@ -15,6 +15,7 @@ iree_lit_test_suite( lit SRCS "channel_ops.mlir" + "cmd_dispatch2_ops.mlir" "cmd_ops.mlir" "context_ops.mlir" "debug_ops.mlir" diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_dispatch2_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_dispatch2_ops.mlir new file mode 100644 index 000000000000..ce9a4ade5dca --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_dispatch2_ops.mlir @@ -0,0 +1,114 @@ +// RUN: iree-opt --split-input-file --iree-hal-conversion --cse --iree-hal-indirect-command-buffers=true --iree-hal-experimental-dispatch2=true %s | FileCheck %s + +#executable_target_aarch64 = 
#hal.executable.target<"llvm-cpu", "embedded-elf-aarch64"> +#executable_target_x86_64 = #hal.executable.target<"llvm-cpu", "embedded-elf-x86_64"> +#pipeline_layout = #hal.pipeline.layout, + #hal.descriptor_set.binding<1, storage_buffer, Indirect> + ]> +]> +hal.executable private @ex { + hal.executable.variant public @aarch64 target(#executable_target_aarch64) { + hal.executable.condition(%device: !hal.device) -> i1 { + %ok, %selected = hal.device.query<%device : !hal.device> key("some" :: "feature") : i1, i1 + hal.return %selected : i1 + } + hal.executable.export public @dispatch ordinal(0) layout(#pipeline_layout) attributes { + translation_info = #iree_codegen.translation_info + } { + ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index): // no predecessors + %c1 = arith.constant 1 : index + %0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%arg0] + hal.return %0, %c1, %c1 : index, index, index + } + builtin.module { + // Opaque at this point (in some target-specific dialects). + } + } + hal.executable.variant public @x86_64 target(#executable_target_x86_64) { + hal.executable.export public @dispatch ordinal(0) layout(#pipeline_layout) attributes { + translation_info = #iree_codegen.translation_info + } { + ^bb0(%device: !hal.device, %arg0: index, %arg1: index, %arg2: index): // no predecessors + %c1 = arith.constant 1 : index + %0 = affine.apply affine_map<()[s0] -> (s0 ceildiv 4)>()[%arg0] + hal.return %0, %c1, %c1 : index, index, index + } + builtin.module { + // Opaque at this point (in some target-specific dialects). 
+ } + } +} + +util.global private @device : !hal.device +util.global private @constant_resource : !stream.resource +util.global private @constant_size : index + +// CHECK-LABEL: @cmdDispatch +// CHECK-SAME: (%[[ARG_RESOURCE:.+]]: !hal.buffer, %[[ARG_SIZE:.+]]: index) +util.func public @cmdDispatch(%arg_resource: !stream.resource, %arg_size: index) -> !stream.timepoint { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c4_i32 = arith.constant 4 : i32 + %c5_i32 = arith.constant 5 : i32 + %c128 = arith.constant 128 : index + // CHECK-DAG: %[[CONSTANT_RESOURCE:.+]] = util.global.load immutable @constant_resource + %constant_resource = util.global.load immutable @constant_resource : !stream.resource + %constant_size = util.global.load immutable @constant_size : index + // CHECK-DAG: %[[DEVICE:.+]] = util.global.load immutable @device + // CHECK: %[[MEMOIZED_CMD:.+]] = hal.device.memoize + // CHECK: %[[CMD:.+]] = hal.command_buffer.create + %0 = stream.cmd.execute on(#hal.device.affinity<@device>) with(%constant_resource as %constant_capture: !stream.resource{%constant_size}, %arg_resource as %arg_capture: !stream.resource{%arg_size}) { + // Switch for each executable variant by checking conditions and ranking: + // CHECK: %[[CMD_DEVICE:.+]] = hal.command_buffer.device<%[[CMD]] : !hal.command_buffer> + // CHECK-DAG: %{{.+}}, %[[AARCH64_FORMAT:.+]] = hal.device.query<%[[CMD_DEVICE]] : !hal.device> key("hal.executable.format" :: "embedded-elf-aarch64") + // CHECK-DAG: %[[AARCH64_FEATURE:.+]] = scf.execute_region -> i1 { + // CHECK-NEXT: %{{.+}}, %[[FEATURE:.+]] = hal.device.query<%[[CMD_DEVICE]] : !hal.device> key("some" :: "feature") + // CHECK-NEXT: scf.yield %[[FEATURE]] + // CHECK-NEXT: } + // CHECK-DAG: %[[AARCH64_SELECTED:.+]] = arith.andi %[[AARCH64_FORMAT]], %[[AARCH64_FEATURE]] + // CHECK-DAG: %{{.+}}, %[[X86_64_SELECTED:.+]] = hal.device.query<%[[CMD_DEVICE]] : !hal.device> 
key("hal.executable.format" :: "embedded-elf-x86_64") + // CHECK: %[[VARIANT1:.+]] = arith.select %[[X86_64_SELECTED]], %c1 + // CHECK: %[[VARIANT0:.+]] = arith.select %[[AARCH64_SELECTED]], %c0, %[[VARIANT1]] + // CHECK: scf.index_switch %[[VARIANT0]] + // CHECK-NEXT: case 0 { + + // Inlined workgroup count calculation: + // CHECK: %[[X:.+]] = affine.apply #map()[%c1] + + // Target executable/export: + // CHECK-DAG: %[[EXECUTABLE_0:.+]] = hal.executable.lookup + // CHECK-SAME: device(%[[CMD_DEVICE]] : !hal.device) + // CHECK-SAME: executable(@ex) : !hal.executable + // CHECK-DAG: %[[ORDINAL_0:.+]] = hal.executable.export.ordinal + // CHECK-SAME: target(@ex::@aarch64::@dispatch) : index + + // Dispatch: + // CHECK: hal.command_buffer.dispatch2<%[[CMD]] + // CHECK-SAME: target(%[[EXECUTABLE_0]] : !hal.executable)[%[[ORDINAL_0]]] + // CHECK-SAME: workgroups([%[[X]], %c1, %c1]) + // CHECK-SAME: constants([%c4_i32, %c5_i32]) + // CHECK-SAME: bindings([ + // CHECK-NEXT: (%[[CONSTANT_RESOURCE]] : !hal.buffer)[%c0, %c128], + // CHECK-NEXT: (%c0 : index)[%c0, %c128] + + // Other variant, when selected: + // CHECK: case 1 { + // CHECK-DAG: %[[ORDINAL_1:.+]] = hal.executable.export.ordinal target(@ex::@x86_64::@dispatch) + // CHECK: hal.command_buffer.dispatch2<%[[CMD]] + // CHECK-SAME: target({{.+}})[%[[ORDINAL_1]]] + stream.cmd.dispatch {@ex::@aarch64::@dispatch, @ex::@x86_64::@dispatch}[%c1, %c2, %c3](%c4_i32, %c5_i32 : i32, i32) { + ro %constant_capture[%c0 for %c128] : !stream.resource{%constant_size}, + wo %arg_capture[%c0 for %c128] : !stream.resource{%arg_size} + } + // CHECK: hal.command_buffer.execution_barrier<%[[CMD]] + } => !stream.timepoint + // CHECK-NEXT: hal.command_buffer.finalize<%[[CMD]] + // CHECK: hal.device.queue.execute.indirect<%[[DEVICE]] : !hal.device> {{.+}} commands(%[[MEMOIZED_CMD]]) bindings([ + // CHECK-NEXT: (%[[ARG_RESOURCE]] : !hal.buffer)[%c0, %[[ARG_SIZE]]] + // CHECK-NEXT: ]) + util.return %0 : !stream.timepoint +} diff --git 
a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td index 2933f27c801b..78fdad307cc6 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td @@ -168,10 +168,12 @@ def HAL_DescriptorTypeAttr : def HAL_DescriptorFlags_None : I32BitEnumAttrCase<"None", 0x0000>; def HAL_DescriptorFlags_ReadOnly : I32BitEnumAttrCase<"ReadOnly", 0x0001>; +def HAL_DescriptorFlags_Indirect : I32BitEnumAttrCase<"Indirect", 0x0002>; def HAL_DescriptorFlagsAttr : I32BitEnumAttr<"DescriptorFlags", "valid Descriptor flags", [ HAL_DescriptorFlags_None, HAL_DescriptorFlags_ReadOnly, + HAL_DescriptorFlags_Indirect, ]> { let cppNamespace = "::mlir::iree_compiler::IREE::HAL"; } @@ -387,7 +389,7 @@ def HAL_DescriptorSetBindingAttr : let parameters = (ins AttrParameter<"int64_t", "">:$ordinal, AttrParameter<"DescriptorType", "">:$type, - OptionalParameter<"std::optional">:$flags + OptionalParameter<"DescriptorFlags", "DescriptorFlags::None">:$flags ); let assemblyFormat = [{ `<` $ordinal `,` $type (`,` $flags^)? 
`>` diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALDialect.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALDialect.cpp index fcf5ae403f85..811789ccc782 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALDialect.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALDialect.cpp @@ -104,9 +104,7 @@ class HALToVMConversionInterface : public VMConversionDialectInterface { fn(IntegerAttr::get(IndexType::get(context), APInt(64, bindingAttr.getOrdinal()))); fn(IREE::HAL::DescriptorTypeAttr::get(context, bindingAttr.getType())); - fn(IREE::HAL::DescriptorFlagsAttr::get( - context, - bindingAttr.getFlags().value_or(IREE::HAL::DescriptorFlags::None))); + fn(IREE::HAL::DescriptorFlagsAttr::get(context, bindingAttr.getFlags())); return success(); } if (auto dtAttr = llvm::dyn_cast(attr)) { diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp index e90d0e400c81..6b787a76a410 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp @@ -111,6 +111,61 @@ static void printDescriptorSetBindings(OpAsmPrinter &p, Operation *op, p.printNewline(); } +//===----------------------------------------------------------------------===// +// custom($binding_buffers, +// type($binding_buffers), +// $binding_offsets, +// $binding_lengths) +//===----------------------------------------------------------------------===// + +static ParseResult +parseBindings(OpAsmParser &parser, + SmallVectorImpl &buffers, + SmallVectorImpl &bufferTypes, + SmallVectorImpl &bufferOffsets, + SmallVectorImpl &bufferLengths) { + do { + OpAsmParser::UnresolvedOperand buffer; + Type bufferType; + OpAsmParser::UnresolvedOperand bufferOffset; + OpAsmParser::UnresolvedOperand bufferLength; + if (failed(parser.parseLParen()) || failed(parser.parseOperand(buffer)) || + failed(parser.parseColonType(bufferType)) || + failed(parser.parseRParen()) || 
failed(parser.parseLSquare()) || + failed(parser.parseOperand(bufferOffset)) || + failed(parser.parseComma()) || + failed(parser.parseOperand(bufferLength)) || + failed(parser.parseRSquare())) { + return failure(); + } + buffers.push_back(buffer); + bufferTypes.push_back(bufferType); + bufferOffsets.push_back(bufferOffset); + bufferLengths.push_back(bufferLength); + } while (succeeded(parser.parseOptionalComma())); + return success(); +} + +static void printBindings(OpAsmPrinter &p, Operation *op, ValueRange buffers, + TypeRange bufferTypes, ValueRange bufferOffsets, + ValueRange bufferLengths) { + llvm::interleaveComma( + llvm::zip_equal(buffers, bufferTypes, bufferOffsets, bufferLengths), p, + [&](std::tuple it) { + p.printNewline(); + p << " ("; + p.printOperand(std::get<0>(it)); + p << " : "; + p.printType(std::get<1>(it)); + p << ")["; + p.printOperand(std::get<2>(it)); + p << ", "; + p.printOperand(std::get<3>(it)); + p << "]"; + }); + p.printNewline(); +} + //===----------------------------------------------------------------------===// // custom($binding_buffers, // type($binding_buffers), @@ -1054,6 +1109,108 @@ void CommandBufferPushDescriptorSetOp::build( state.addOperands(bindingLengths); } +//===----------------------------------------------------------------------===// +// hal.command_buffer.dispatch2 + .indirect +//===----------------------------------------------------------------------===// + +void CommandBufferDispatch2Op::build(OpBuilder &builder, OperationState &state, + Value commandBuffer, Value executable, + Value entryPoint, ValueRange workgroups, + ValueRange constants, + ArrayRef bindings, + IREE::HAL::DispatchFlags flags) { + state.addOperands({commandBuffer, executable, entryPoint}); + state.addOperands(workgroups); + state.addOperands(constants); + SmallVector bindingBuffers; + SmallVector bindingOffsets; + SmallVector bindingLengths; + for (auto binding : bindings) { + bindingBuffers.push_back(binding.buffer); + 
bindingOffsets.push_back(binding.byteOffset); + bindingLengths.push_back(binding.byteLength); + } + state.addOperands(bindingBuffers); + state.addOperands(bindingOffsets); + state.addOperands(bindingLengths); + state.addAttribute("flags", + builder.getAttr(flags)); + state.addAttribute(getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr({ + 1, + 1, + 1, + 1, + 1, + 1, + static_cast(constants.size()), + static_cast(bindingBuffers.size()), + static_cast(bindingOffsets.size()), + static_cast(bindingLengths.size()), + })); +} + +void CommandBufferDispatch2IndirectOp::build( + OpBuilder &builder, OperationState &state, Value commandBuffer, + Value executable, Value entryPoint, Value workgroupsBuffer, + Value workgroupsOffset, ValueRange constants, + ArrayRef bindings, IREE::HAL::DispatchFlags flags) { + state.addOperands({commandBuffer, executable, entryPoint, workgroupsBuffer, + workgroupsOffset}); + state.addOperands(constants); + SmallVector bindingBuffers; + SmallVector bindingOffsets; + SmallVector bindingLengths; + for (auto binding : bindings) { + bindingBuffers.push_back(binding.buffer); + bindingOffsets.push_back(binding.byteOffset); + bindingLengths.push_back(binding.byteLength); + } + state.addOperands(bindingBuffers); + state.addOperands(bindingOffsets); + state.addOperands(bindingLengths); + state.addAttribute("flags", + builder.getAttr(flags)); + state.addAttribute(getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr({ + 1, + 1, + 1, + 1, + 1, + static_cast(constants.size()), + static_cast(bindingBuffers.size()), + static_cast(bindingOffsets.size()), + static_cast(bindingLengths.size()), + })); +} + +static LogicalResult verifyDispatch2Bindings(Operation *op, + ValueRange bindingBuffers, + ValueRange bindingOffsets, + ValueRange bindingLengths) { + if (bindingBuffers.size() != bindingOffsets.size() || + bindingBuffers.size() != bindingLengths.size()) { + return op->emitOpError() << "requires that binding fields all have the " + "same number 
of elements"; + } + return success(); +} + +LogicalResult CommandBufferDispatch2Op::verify() { + CommandBufferDispatch2Op op = *this; + return verifyDispatch2Bindings(op, op.getBindingBuffers(), + op.getBindingOffsets(), + op.getBindingLengths()); +} + +LogicalResult CommandBufferDispatch2IndirectOp::verify() { + CommandBufferDispatch2IndirectOp op = *this; + return verifyDispatch2Bindings(op, op.getBindingBuffers(), + op.getBindingOffsets(), + op.getBindingLengths()); +} + //===----------------------------------------------------------------------===// // hal.descriptor_set_layout.create //===----------------------------------------------------------------------===// @@ -1165,7 +1322,7 @@ void DeviceQueueExecuteIndirectOp::build(OpBuilder &builder, OperationState &state, Value device, Value queueAffinity, Value waitFence, Value signalFence, Value commandBuffer, - ArrayRef bindings) { + ArrayRef bindings) { state.addOperands( {device, queueAffinity, waitFence, signalFence, commandBuffer}); SmallVector bindingBuffers; @@ -1748,6 +1905,16 @@ void ExecutableCreateOp::getAsmResultNames( setNameFn(getResult(), StringRef("exe")); } +//===----------------------------------------------------------------------===// +// hal.executable.create2 +//===----------------------------------------------------------------------===// + +void ExecutableCreate2Op::getAsmResultNames( + function_ref setNameFn) { + // TODO(benvanik): name after sanitized symbol. 
+ setNameFn(getResult(), StringRef("executable")); +} + //===----------------------------------------------------------------------===// // hal.executable.lookup //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td index a268dbcd244b..a6fe1f26116e 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td @@ -1469,6 +1469,7 @@ def HAL_CommandBufferCollectiveOp : HAL_Op<"command_buffer.collective", [ }]; } +// TODO(#18154): deprecated and will be replaced with simplified bindings. def HAL_CommandBufferPushConstantsOp : HAL_Op<"command_buffer.push_constants"> { let summary = [{command buffer push constants operation}]; let description = [{ @@ -1496,6 +1497,7 @@ def HAL_CommandBufferPushConstantsOp : HAL_Op<"command_buffer.push_constants"> { }]; } +// TODO(#18154): deprecated and will be replaced with simplified bindings. def HAL_CommandBufferPushDescriptorSetOp : HAL_Op<"command_buffer.push_descriptor_set", [ SameVariadicOperandSize, ]> { @@ -1541,6 +1543,7 @@ def HAL_CommandBufferPushDescriptorSetOp : HAL_Op<"command_buffer.push_descripto let hasCanonicalizer = 1; } +// TODO(#18154): deprecated and will be replaced with simplified bindings. def HAL_CommandBufferDispatchOp : HAL_Op<"command_buffer.dispatch"> { let summary = [{command buffer dispatch recording operation}]; let description = [{ @@ -1571,6 +1574,7 @@ def HAL_CommandBufferDispatchOp : HAL_Op<"command_buffer.dispatch"> { }]; } +// TODO(#18154): deprecated and will be replaced with simplified bindings. 
def HAL_CommandBufferDispatchIndirectOp : HAL_Op<"command_buffer.dispatch.indirect"> { let summary = [{command buffer indirect dispatch recording operation}]; let description = [{ @@ -1598,6 +1602,139 @@ def HAL_CommandBufferDispatchIndirectOp : HAL_Op<"command_buffer.dispatch.indire }]; } +def HAL_CommandBufferDispatch2Op : HAL_Op<"command_buffer.dispatch2", [ + AttrSizedOperandSegments, +]> { + let summary = [{command buffer dispatch recording operation}]; + let description = [{ + Dispatches an execution request. + The request may execute overlapped with any other transfer operation or + dispatch made within the same barrier-defined sequence. + + The provided constant data and binding list will be recorded into the + command buffer and need not remain live beyond the call. Push constants are + always 4-byte values and treated as opaque, meaning that they may be + bit-casted floats, bit-packed booleans, etc. The provided buffers may either + be HAL buffers or indirect references into the command buffer binding table. + }]; + + let arguments = (ins + HAL_CommandBuffer:$command_buffer, + HAL_Executable:$executable, + HAL_Ordinal:$entry_point, + HAL_Dim:$workgroup_x, + HAL_Dim:$workgroup_y, + HAL_Dim:$workgroup_z, + Variadic:$constants, + Variadic>:$binding_buffers, + Variadic:$binding_offsets, + Variadic:$binding_lengths, + HAL_DispatchFlagsAttr:$flags + ); + + let assemblyFormat = [{ + `<` $command_buffer `:` type($command_buffer) `>` + `target` `(` $executable `:` type($executable) `)` + `` `[` $entry_point `]` + `workgroups` `(` `[` + $workgroup_x `,` + $workgroup_y `,` + $workgroup_z + `]` `)` + (`constants` `(` `[` $constants^ `]` `)`)? 
+ `bindings` `(` `[` + custom($binding_buffers, + type($binding_buffers), + $binding_offsets, + $binding_lengths) + `]` `)` + `flags` `(` $flags `)` + attr-dict-with-keyword + }]; + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<(ins + "Value":$commandBuffer, + "Value":$executable, + "Value":$entryPoint, + "ValueRange":$workgroups, + "ValueRange":$constants, + "ArrayRef":$bindings, + "IREE::HAL::DispatchFlags":$flags + )>, + ]; + + let hasVerifier = 1; +} + +def HAL_CommandBufferDispatch2IndirectOp : HAL_Op<"command_buffer.dispatch2.indirect", [ + AttrSizedOperandSegments, +]> { + let summary = [{command buffer indirect dispatch recording operation}]; + let description = [{ + Dispatches an execution request with a deferred workgroup count. + This is the same as hal.command_buffer.dispatch2 but the workgroup count + is read from the given |workgroups_buffer| at |workgroups_offset| as + 3 uint32_t XYZ values immediately before performing the dispatch. This + allows prior dispatches within the command sequence to populate the + workgroup count or the workgroup count to change across submissions of the + same reusable command buffer. + + The provided constant data and binding list will be recorded into the + command buffer and need not remain live beyond the call. Push constants are + always 4-byte values and treated as opaque, meaning that they may be + bit-casted floats, bit-packed booleans, etc. The provided buffers may either + be HAL buffers or indirect references into the command buffer binding table.
+ }]; + + let arguments = (ins + HAL_CommandBuffer:$command_buffer, + HAL_Executable:$executable, + HAL_Ordinal:$entry_point, + AnyTypeOf<[Index, HAL_BufferType]>:$workgroups_buffer, + HAL_DeviceSize:$workgroups_offset, + Variadic:$constants, + Variadic>:$binding_buffers, + Variadic:$binding_offsets, + Variadic:$binding_lengths, + HAL_DispatchFlagsAttr:$flags + ); + + let assemblyFormat = [{ + `<` $command_buffer `:` type($command_buffer) `>` + `target` `(` $executable `:` type($executable) `)` + `` `[` $entry_point `]` + `workgroups` `(` $workgroups_buffer `:` type($workgroups_buffer) `)` + `` `[` $workgroups_offset `]` + (`constants` `(` `[` $constants^ `]` `)`)? + `bindings` `(` `[` + custom($binding_buffers, + type($binding_buffers), + $binding_offsets, + $binding_lengths) + `]` `)` + `flags` `(` $flags `)` + attr-dict-with-keyword + }]; + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<(ins + "Value":$commandBuffer, + "Value":$executable, + "Value":$entryPoint, + "Value":$workgroupsBuffer, + "Value":$workgroupsOffset, + "ValueRange":$constants, + "ArrayRef":$bindings, + "IREE::HAL::DispatchFlags":$flags + )>, + ]; + + let hasVerifier = 1; +} + } // OpGroupCommandBufferOps //===----------------------------------------------------------------------===// @@ -2060,7 +2197,7 @@ def HAL_DeviceQueueExecuteIndirectOp : HAL_Op<"device.queue.execute.indirect", [ "Value":$waitFence, "Value":$signalFence, "Value":$commandBuffer, - "ArrayRef":$bindings + "ArrayRef":$bindings )>, ]; @@ -2663,6 +2800,7 @@ def HAL_ExecutableBinaryOp : HAL_Op<"executable.binary", [ ]; } +// TODO(#18154): deprecated and will be replaced with simplified bindings. 
def HAL_ExecutableCreateOp : HAL_PureOp<"executable.create", [ DeclareOpInterfaceMethods, AttrSizedOperandSegments, @@ -2702,6 +2840,42 @@ def HAL_ExecutableCreateOp : HAL_PureOp<"executable.create", [ }]; } +def HAL_ExecutableCreate2Op : HAL_PureOp<"executable.create2", [ + DeclareOpInterfaceMethods, +]> { + let summary = [{creates an executable}]; + let description = [{ + Creates a target-dependent executable cached on the provided device. Entry + points contained within the executable can be dispatched using the resulting + executable handle. + + Depending on the driver creation may take a non-trivial amount of time + (such as when JITing/etc). As the cache is internally synchronized callers + can issue preparation requests from multiple threads - even for the same + executables - and calls will block until preparation completes. + + Optional constants provide for specialization of the executable based on + runtime-derived parameters. + }]; + + let arguments = (ins + HAL_Device:$device, + SymbolRefAttr:$executable_target, + Variadic:$constants + ); + let results = (outs + HAL_Executable:$result + ); + + let assemblyFormat = [{ + `device` `(` $device `:` type($device) `)` + `target` `(` $executable_target `)` + (`constants` `(` `[` $constants^ `]` `)`)? + `:` type($result) + attr-dict-with-keyword + }]; +} + def HAL_ExecutableLookupOp : HAL_PureOp<"executable.lookup", [ DeclareOpInterfaceMethods, ]> { diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h index ef6417dac598..cdb29e032a01 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.h @@ -185,7 +185,10 @@ struct DescriptorSetBindingValue { Value byteLength; }; -struct BindingTableValue { +// A tuple containing runtime values for a binding. +// The buffer specified may be either a !hal.buffer or an index of a binding +// table slot to source the buffer from. 
+struct BindingValue { Value buffer; Value byteOffset; Value byteLength; diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir index ec0cbdd1d60b..d04c7d681407 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/attributes.mlir @@ -5,8 +5,8 @@ "descriptor_set_layout_binding.basic"() { // CHECK: dslb0 = #hal.descriptor_set.binding<0, uniform_buffer> dslb0 = #hal.descriptor_set.binding<0, uniform_buffer>, - // CHECK: dslb1 = #hal.descriptor_set.binding<1, storage_buffer> - dslb1 = #hal.descriptor_set.binding<1, storage_buffer> + // CHECK: dslb1 = #hal.descriptor_set.binding<1, storage_buffer, "ReadOnly|Indirect"> + dslb1 = #hal.descriptor_set.binding<1, storage_buffer, "ReadOnly|Indirect"> } : () -> () // ----- diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir index 7408ad9edbef..c3ee5543fbe5 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/executable_ops.mlir @@ -160,6 +160,19 @@ util.func public @executable_create( // ----- +// CHECK-LABEL: @executable_create2 +// CHECK-SAME: %[[DEVICE:.+]]: !hal.device +util.func public @executable_create2(%device: !hal.device) { + // CHECK: = hal.executable.create + // CHECK-SAME: device(%[[DEVICE]] : !hal.device) + // CHECK-SAME: target(@exe::@binary1) : !hal.executable + %0 = hal.executable.create2 device(%device : !hal.device) + target(@exe::@binary1) : !hal.executable + util.return +} + +// ----- + // CHECK-LABEL: @pipeline_layout_create // CHECK-SAME: %[[DEVICE:.+]]: !hal.device, // CHECK-SAME: %[[LAYOUT0:.+]]: !hal.descriptor_set_layout, diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp 
b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp index 7cc0471bc7ef..37d960307d21 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeInterfaces.cpp @@ -268,10 +268,7 @@ makePipelineLayoutAttr(const PipelineLayout &pipelineLayout, SmallVector bindingAttrs; for (const auto &binding : setLayout.bindings) { bindingAttrs.push_back(IREE::HAL::DescriptorSetBindingAttr::get( - builder.getContext(), binding.ordinal, binding.type, - binding.flags != IREE::HAL::DescriptorFlags::None - ? binding.flags - : std::optional{})); + builder.getContext(), binding.ordinal, binding.type, binding.flags)); } setLayoutAttrs.push_back(IREE::HAL::DescriptorSetLayoutAttr::get( builder.getContext(), setLayout.ordinal, bindingAttrs, diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp index de22093e4e29..16c57e23a671 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/MaterializeResourceCaches.cpp @@ -32,6 +32,13 @@ namespace mlir::iree_compiler::IREE::HAL { namespace { +// TODO(#18154): switch default to true and then remove. 
+static llvm::cl::opt clExperimentalExecutableCreate2{ + "iree-hal-experimental-executable-create2", + llvm::cl::desc("Whether to emit iree_hal_executable_create2 ops."), + llvm::cl::init(false), +}; + //===----------------------------------------------------------------------===// // --iree-hal-materialize-resource-caches //===----------------------------------------------------------------------===// @@ -248,15 +255,6 @@ static Value initializeExecutable(DeviceResources &deviceResources, auto &caseBlock = switchOp.getCaseRegions()[i].emplaceBlock(); auto caseBuilder = OpBuilder::atBlockBegin(&caseBlock); - // Gather each of the pipeline layouts needed for each entry point in - // the executable. - SmallVector pipelineLayoutValues; - for (auto exportOp : variantOp.getExportOps()) { - auto &pipelineLayout = - deviceResources.pipelineLayouts[exportOp.getLayoutAttr()]; - pipelineLayoutValues.push_back(pipelineLayout.initializerValue); - } - // Inline constant initializer from the variant. // We want these to all happen inside of this device switch case; they'll // get deduplicated/hoisted if possible in future canonicalization passes. @@ -270,13 +268,31 @@ static Value initializeExecutable(DeviceResources &deviceResources, blockName, blockOp, moduleBuilder, caseBuilder, initializerDevice)); } - Value executableValue = - caseBuilder.createOrFold( - loc, executableType, initializerDevice, - SymbolRefAttr::get( - executable.executableOp.getSymNameAttr(), - {SymbolRefAttr::get(variantOp.getSymNameAttr())}), - pipelineLayoutValues, constantValues); + Value executableValue; + if (clExperimentalExecutableCreate2) { + executableValue = + caseBuilder.createOrFold( + loc, executableType, initializerDevice, + SymbolRefAttr::get( + executable.executableOp.getSymNameAttr(), + {SymbolRefAttr::get(variantOp.getSymNameAttr())}), + constantValues); + } else { + // Gather each of the pipeline layouts needed for each entry point in + // the executable. 
+ SmallVector pipelineLayoutValues; + for (auto exportOp : variantOp.getExportOps()) { + auto &pipelineLayout = + deviceResources.pipelineLayouts[exportOp.getLayoutAttr()]; + pipelineLayoutValues.push_back(pipelineLayout.initializerValue); + } + + executableValue = caseBuilder.createOrFold( + loc, executableType, initializerDevice, + SymbolRefAttr::get(executable.executableOp.getSymNameAttr(), + {SymbolRefAttr::get(variantOp.getSymNameAttr())}), + pipelineLayoutValues, constantValues); + } caseBuilder.create(loc, executableValue); } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir index d350e0e038c0..5623e7fccf0e 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/materialize_interfaces.mlir @@ -13,9 +13,9 @@ util.global private @default_device = #hal.device.target<"cpu", [ // CHECK-SAME: push_constants = 1 // CHECK-SAME: sets = [ // CHECK-SAME: <0, bindings = [ -// CHECK-SAME: <0, storage_buffer, ReadOnly> -// CHECK-SAME: <1, storage_buffer, ReadOnly> -// CHECK-SAME: <2, storage_buffer> +// CHECK-SAME: <0, storage_buffer, "ReadOnly|Indirect"> +// CHECK-SAME: <1, storage_buffer, "ReadOnly|Indirect"> +// CHECK-SAME: <2, storage_buffer, Indirect> // CHECK: hal.executable private @ex // CHECK: hal.executable.variant public @arm_64 target(#executable_target_arm_64 diff --git a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir index 66f8dd7af602..9cd824d6dadf 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir @@ -286,6 +286,8 @@ vm.import private @command_buffer.collective( %element_count : i64 ) +// TODO(#18154): remove this in favor of inlined constants. 
+// // Pushes constants for consumption by dispatches. vm.import private @command_buffer.push_constants( %command_buffer : !vm.ref, @@ -294,6 +296,8 @@ vm.import private @command_buffer.push_constants( %values : i32 ... ) +// TODO(#18154): remove this in favor of inlined bindings. +// // Pushes a descriptor set to the given set number. vm.import private @command_buffer.push_descriptor_set( %command_buffer : !vm.ref, @@ -326,6 +330,45 @@ vm.import private @command_buffer.dispatch.indirect( %flags : i64 ) +// TODO(#18154): replace @command_buffer.dispatch. +// +// Dispatches an execution request. +vm.import private @command_buffer.dispatch2( + %command_buffer : !vm.ref, + %executable : !vm.ref, + %entry_point : i32, + %workgroup_x : i32, + %workgroup_y : i32, + %workgroup_z : i32, + %flags : i64, + %constants : i32 ..., + // + %bindings : tuple, i64, i64>... +) +attributes { + minimum_version = 4 : i32 +} + +// TODO(#18154): replace @command_buffer.dispatch.indirect. +// +// Dispatches an execution request with the dispatch parameters loaded from the +// given buffer. +vm.import private @command_buffer.dispatch2.indirect( + %command_buffer : !vm.ref, + %executable : !vm.ref, + %entry_point : i32, + %workgroups_buffer_slot : i32, + %workgroups_buffer : !vm.ref, + %workgroups_offset : i64, + %flags : i64, + %constants : i32 ..., + // + %bindings : tuple, i64, i64>... +) +attributes { + minimum_version = 4 : i32 +} + //===----------------------------------------------------------------------===// // iree_hal_descriptor_set_layout_t //===----------------------------------------------------------------------===// @@ -468,6 +511,19 @@ vm.import private @executable.create( ) -> !vm.ref attributes {nosideeffects} +// TODO(#18154): replace @executable.create. +// Creates an executable for use with the specified device. 
+vm.import private @executable.create2( + %device : !vm.ref, + %executable_format : !vm.buffer, + %executable_data : !vm.buffer, + %constants : !vm.buffer +) -> !vm.ref +attributes { + minimum_version = 4 : i32, + nosideeffects +} + //===----------------------------------------------------------------------===// // iree_hal_fence_t //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp index 3f9d68001747..266b1e1835ac 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.cpp @@ -147,9 +147,11 @@ void FuncOp::setReflectionAttr(StringRef name, Attribute value) { // TODO(benvanik): remove reflection attrs as a concept and use something more // MLIRish like an attribute interface/dialect interface. // DictionaryAttr is not very friendly for modification :/ - auto existingAttr = - getOperation()->getAttrOfType("iree.reflection"); - SmallVector attrs(existingAttr.begin(), existingAttr.end()); + SmallVector attrs; + if (auto existingAttr = + getOperation()->getAttrOfType("iree.reflection")) { + llvm::append_range(attrs, existingAttr); + } bool didFind = false; for (size_t i = 0; i < attrs.size(); ++i) { if (attrs[i].getName() == name) { diff --git a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp index ea5cb95a4b64..fbf8876c2df2 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp @@ -202,12 +202,23 @@ static iree_vm_AttrDef_vec_ref_t makeAttrDefs(DictionaryAttr attrs, SmallVector attrRefs; for (auto attr : attrs) { auto key = attr.getName().strref(); - auto value = llvm::dyn_cast(attr.getValue()); - if (!value || key.empty()) + if 
(key.empty()) { continue; + } + std::string value; + if (auto stringAttr = dyn_cast(attr.getValue())) { + value = stringAttr.getValue().str(); + } else if (auto integerAttr = dyn_cast(attr.getValue())) { + SmallVector str; + integerAttr.getValue().toStringSigned(str); + value.append(str.data(), str.size()); + } else { + assert(false && "expected string or integer reflection attr"); + continue; + } // NOTE: if we actually want to keep these we should dedupe them (as the // keys and likely several of the values are shared across all functions). - auto valueRef = fbb.createString(value.getValue()); + auto valueRef = fbb.createString(value); auto keyRef = fbb.createString(key); attrRefs.push_back(iree_vm_AttrDef_create(fbb, keyRef, valueRef)); } diff --git a/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp b/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp index 6e8ecc0bcbb0..f3b04384394e 100644 --- a/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp +++ b/compiler/src/iree/compiler/InputConversion/Common/IREEImportPublic.cpp @@ -78,17 +78,16 @@ convertDescriptorType(IREE::Input::DescriptorType src) { } } -static std::optional +static IREE::HAL::DescriptorFlags convertDescriptorFlags(std::optional src) { if (!src.has_value()) - return std::nullopt; + return IREE::HAL::DescriptorFlags::None; switch (*src) { + default: case IREE::Input::DescriptorFlags::None: return IREE::HAL::DescriptorFlags::None; case IREE::Input::DescriptorFlags::ReadOnly: return IREE::HAL::DescriptorFlags::ReadOnly; - default: - return std::nullopt; } } diff --git a/experimental/webgpu/command_buffer.c b/experimental/webgpu/command_buffer.c index de89e4feabbe..2a7047bfa1dd 100644 --- a/experimental/webgpu/command_buffer.c +++ b/experimental/webgpu/command_buffer.c @@ -270,6 +270,16 @@ static void iree_hal_webgpu_command_buffer_reset( iree_hal_webgpu_command_segment_list_reset(&command_buffer->segments); 
iree_arena_reset(&command_buffer->arena); + // Pad up to IREE_HAL_WEBGPU_PARAMS_BIND_GROUP_INDEX with empty bind groups. + WGPUBindGroup empty_handle = command_buffer->staging_buffer->empty_bind_group; + for (iree_host_size_t i = 0; i < IREE_HAL_WEBGPU_PARAMS_BIND_GROUP_INDEX; + ++i) { + wgpuComputePassEncoderSetBindGroup(compute_pass, (uint32_t)i, empty_handle, + 0, NULL); + command_buffer->state.bind_groups[i].handle = empty_handle; + command_buffer->state.bind_groups_empty |= 1ull << i; + } + IREE_TRACE_ZONE_END(z0); } @@ -802,7 +812,8 @@ static iree_status_t iree_hal_webgpu_command_buffer_push_descriptor_set( static iree_status_t iree_hal_webgpu_command_buffer_prepare_dispatch( iree_hal_webgpu_command_buffer_t* command_buffer, iree_hal_executable_t* executable, uint32_t ordinal, - WGPUComputePassEncoder* out_compute_pass) { + iree_const_byte_span_t constants, iree_hal_buffer_ref_list_t bindings, + iree_hal_dispatch_flags_t flags, WGPUComputePassEncoder* out_compute_pass) { const iree_hal_webgpu_entry_point_t* entry_point = iree_hal_webgpu_executable_lookup_entry_point(executable, ordinal); @@ -915,6 +926,111 @@ static iree_status_t iree_hal_webgpu_command_buffer_dispatch_indirect( return iree_ok_status(); } +static iree_status_t iree_hal_webgpu_command_buffer_prepare_dispatch2( + iree_hal_webgpu_command_buffer_t* command_buffer, + iree_hal_executable_t* executable, uint32_t ordinal, + iree_const_byte_span_t constants, iree_hal_buffer_ref_list_t bindings, + iree_hal_dispatch_flags_t flags, WGPUComputePassEncoder* out_compute_pass) { + const iree_hal_webgpu_entry_point_t* entry_point = + iree_hal_webgpu_executable_lookup_entry_point(executable, ordinal); + + // Upload push constant data - this may incur a segment flush if the staging + // buffer is exhausted. 
+ uint32_t params_offset = 0; + if (!iree_const_byte_span_is_empty(constants)) { + IREE_RETURN_IF_ERROR(iree_hal_webgpu_command_buffer_append_parameters( + command_buffer, constants, &params_offset)); + } + + // Acquire the compute pass we'll encode the dispatch into - this may be + // fresh or reused from prior commands. + WGPUComputePassEncoder compute_pass = NULL; + IREE_RETURN_IF_ERROR(iree_hal_webgpu_command_buffer_acquire_compute_pass( + command_buffer, &compute_pass)); + wgpuComputePassEncoderSetPipeline(compute_pass, entry_point->pipeline); + + if (!iree_const_byte_span_is_empty(constants)) { + // Bind the push constant emulation bind group at the staging buffer + // relative offset for this dispatch. + wgpuComputePassEncoderSetBindGroup( + compute_pass, IREE_HAL_WEBGPU_PARAMS_BIND_GROUP_INDEX, + command_buffer->staging_buffer->bind_group, 1, &params_offset); + } + + // Set all bindings. + const iree_hal_webgpu_set_binding_info_t* binding_info = + iree_hal_webgpu_pipeline_layout_set_binding_info(entry_point->layout); + + // TODO: change the bind group cache to take the bindings list directly and + // avoid this copy. + iree_hal_webgpu_bind_group_binding_t* group_bindings = + (iree_hal_webgpu_bind_group_binding_t*)iree_alloca( + bindings.count * sizeof(iree_hal_webgpu_bind_group_binding_t)); + iree_hal_webgpu_binding_mask_t binding_mask = 0; + for (iree_host_size_t i = 0; i < bindings.count; ++i) { + binding_mask |= 1u << i; + group_bindings[i].type = WGPUBufferBindingType_Storage; + group_bindings[i].buffer = + bindings.values[i].buffer ? iree_hal_webgpu_buffer_handle(bindings.values[i].buffer) + : NULL; + group_bindings[i].offset = bindings.values[i].offset; + group_bindings[i].length = bindings.values[i].length; + } + + // Acquire the bind group to use for the current descriptor set. 
+ WGPUBindGroup handle = iree_hal_webgpu_bind_group_cache_acquire( + command_buffer->bind_group_cache, binding_info->set_layout, + group_bindings, binding_mask); + + // NOTE: today we don't support dynamic offsets for push descriptor sets. + // This will be a larger change we'll need to handle in the compiler. If we + // wanted to improve caching we could make all the bindings dynamic and then + // always cache the base offsets, however + // maxDynamicStorageBuffersPerPipelineLayout is minimally 4 and that's not + // a lot of bindings. + wgpuComputePassEncoderSetBindGroup(compute_pass, 0, handle, 0, NULL); + + *out_compute_pass = compute_pass; + return iree_ok_status(); +} + +static iree_status_t iree_hal_webgpu_command_buffer_dispatch2( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + const uint32_t workgroup_count[3], iree_const_byte_span_t constants, + iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) { + iree_hal_webgpu_command_buffer_t* command_buffer = + iree_hal_webgpu_command_buffer_cast(base_command_buffer); + + WGPUComputePassEncoder compute_pass = NULL; + IREE_RETURN_IF_ERROR(iree_hal_webgpu_command_buffer_prepare_dispatch2( + command_buffer, executable, entry_point, constants, bindings, flags, + &compute_pass)); + wgpuComputePassEncoderDispatchWorkgroups( + compute_pass, workgroup_count[0], workgroup_count[1], workgroup_count[2]); + + return iree_ok_status(); +} + +static iree_status_t iree_hal_webgpu_command_buffer_dispatch2_indirect( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants, + iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) { + iree_hal_webgpu_command_buffer_t* command_buffer = + iree_hal_webgpu_command_buffer_cast(base_command_buffer); + + WGPUComputePassEncoder compute_pass = NULL; + 
IREE_RETURN_IF_ERROR(iree_hal_webgpu_command_buffer_prepare_dispatch2( + command_buffer, executable, entry_point, constants, bindings, flags, + &compute_pass)); + wgpuComputePassEncoderDispatchWorkgroupsIndirect( + compute_pass, iree_hal_webgpu_buffer_handle(workgroups_ref.buffer), + workgroups_ref.offset); + + return iree_ok_status(); +} + const iree_hal_command_buffer_vtable_t iree_hal_webgpu_command_buffer_vtable = { .destroy = iree_hal_webgpu_command_buffer_destroy, .begin = iree_hal_webgpu_command_buffer_begin, @@ -933,4 +1049,6 @@ const iree_hal_command_buffer_vtable_t iree_hal_webgpu_command_buffer_vtable = { .push_descriptor_set = iree_hal_webgpu_command_buffer_push_descriptor_set, .dispatch = iree_hal_webgpu_command_buffer_dispatch, .dispatch_indirect = iree_hal_webgpu_command_buffer_dispatch_indirect, + .dispatch2 = iree_hal_webgpu_command_buffer_dispatch2, + .dispatch2_indirect = iree_hal_webgpu_command_buffer_dispatch2_indirect, }; diff --git a/runtime/src/iree/base/internal/threading_darwin.c b/runtime/src/iree/base/internal/threading_darwin.c index 8f611b8bccdc..537f705c95ba 100644 --- a/runtime/src/iree/base/internal/threading_darwin.c +++ b/runtime/src/iree/base/internal/threading_darwin.c @@ -26,7 +26,7 @@ struct iree_thread_t { iree_atomic_ref_count_t ref_count; iree_allocator_t allocator; - char name[16]; + char name[32]; pthread_t handle; mach_port_t mach_port; diff --git a/runtime/src/iree/base/internal/threading_pthreads.c b/runtime/src/iree/base/internal/threading_pthreads.c index ec0f1076b634..0d5c0167419b 100644 --- a/runtime/src/iree/base/internal/threading_pthreads.c +++ b/runtime/src/iree/base/internal/threading_pthreads.c @@ -33,7 +33,7 @@ struct iree_thread_t { iree_atomic_ref_count_t ref_count; iree_allocator_t allocator; - char name[16]; + char name[32]; pthread_t handle; iree_thread_entry_t entry; diff --git a/runtime/src/iree/base/internal/threading_win32.c b/runtime/src/iree/base/internal/threading_win32.c index 
944c24a604a1..0091af146e64 100644 --- a/runtime/src/iree/base/internal/threading_win32.c +++ b/runtime/src/iree/base/internal/threading_win32.c @@ -25,7 +25,7 @@ struct iree_thread_t { iree_atomic_ref_count_t ref_count; iree_allocator_t allocator; - char name[16]; + char name[32]; HANDLE handle; DWORD id; diff --git a/runtime/src/iree/hal/command_buffer.c b/runtime/src/iree/hal/command_buffer.c index 7f3785ddbd8c..802330fd5ff9 100644 --- a/runtime/src/iree/hal/command_buffer.c +++ b/runtime/src/iree/hal/command_buffer.c @@ -619,6 +619,77 @@ IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch_indirect( return status; } +IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch2( + iree_hal_command_buffer_t* command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + const uint32_t workgroup_count[3], iree_const_byte_span_t constants, + iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) { + IREE_ASSERT_ARGUMENT(command_buffer); + IREE_ASSERT_ARGUMENT(executable); + + if ((workgroup_count[0] | workgroup_count[1] | workgroup_count[2]) == 0) { + // No-op dispatch. All implementations are expected to do this but we ensure + // it happens here to avoid the overhead of going all the way down into the + // device layer for something we know should have no (intentional) + // side-effects. Note that this does mean that validation is skipped and + // the executable/etc could be bogus but that's fine. + return iree_ok_status(); + } + + IREE_TRACE_ZONE_BEGIN(z0); +#if IREE_HAL_VERBOSE_TRACING_ENABLE + // TODO(benvanik): add a tracing.h helper that does the snprintf directly + // into a tracy_malloc buffer so that we can avoid the memcpy. Today this can + // take 4-5us which adds too much overhead when trying to get accurate timings + // with tracing enabled. Because benchmarks shouldn't be run with asserts + // enabled we only enable these when assertions are enabled. 
Ideally we'd + // slice off a much larger allocation and then suballocate from that ourselves + // so that we could avoid the tracy_malloc overheads per-dispatch. + IREE_TRACE({ + char xyz_string[32]; + int xyz_string_length = + snprintf(xyz_string, IREE_ARRAYSIZE(xyz_string), "%ux%ux%u", + workgroup_count[0], workgroup_count[1], workgroup_count[2]); + IREE_TRACE_ZONE_APPEND_TEXT_STRING_VIEW(z0, xyz_string, xyz_string_length); + }); +#endif // IREE_HAL_VERBOSE_TRACING_ENABLE + + IF_VALIDATING(command_buffer, { + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_command_buffer_dispatch2_validation( + command_buffer, VALIDATION_STATE(command_buffer), executable, + entry_point, workgroup_count, constants, bindings, flags)); + }); + + iree_status_t status = _VTABLE_DISPATCH(command_buffer, dispatch2)( + command_buffer, executable, entry_point, workgroup_count, constants, + bindings, flags); + + IREE_TRACE_ZONE_END(z0); + return status; +} + +IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch2_indirect( + iree_hal_command_buffer_t* command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants, + iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) { + IREE_ASSERT_ARGUMENT(command_buffer); + IREE_ASSERT_ARGUMENT(executable); + IREE_TRACE_ZONE_BEGIN(z0); + IF_VALIDATING(command_buffer, { + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_command_buffer_dispatch2_indirect_validation( + command_buffer, VALIDATION_STATE(command_buffer), executable, + entry_point, workgroups_ref, constants, bindings, flags)); + }); + iree_status_t status = _VTABLE_DISPATCH(command_buffer, dispatch2_indirect)( + command_buffer, executable, entry_point, workgroups_ref, constants, + bindings, flags); + IREE_TRACE_ZONE_END(z0); + return status; +} + //===----------------------------------------------------------------------===// // Validation support 
//===----------------------------------------------------------------------===// diff --git a/runtime/src/iree/hal/command_buffer.h b/runtime/src/iree/hal/command_buffer.h index 5cd30c64ced7..43a876feef52 100644 --- a/runtime/src/iree/hal/command_buffer.h +++ b/runtime/src/iree/hal/command_buffer.h @@ -91,6 +91,7 @@ typedef uint32_t iree_hal_command_category_t; // // Roughly maps to VkDescriptorSetBinding. typedef struct iree_hal_buffer_ref_t { + // TODO(#18154): change ordinal to `reserved` after binding simplification. // The binding number of this entry and corresponds to a resource of the // same binding number in the executable interface. Only used by certain // calls. @@ -125,6 +126,12 @@ static inline iree_hal_buffer_ref_t iree_hal_make_indirect_buffer_ref( return (iree_hal_buffer_ref_t){0, buffer_slot, NULL, offset, length}; } +// A list of buffer references. +typedef struct iree_hal_buffer_ref_list_t { + iree_host_size_t count; + const iree_hal_buffer_ref_t* values; +} iree_hal_buffer_ref_list_t; + // Bitfield specifying which execution stage a barrier should start/end at. // // Maps to VkPipelineStageFlagBits. @@ -714,6 +721,8 @@ IREE_API_EXPORT iree_status_t iree_hal_command_buffer_collective( iree_hal_collective_op_t op, uint32_t param, iree_hal_buffer_ref_t send_ref, iree_hal_buffer_ref_t recv_ref, iree_device_size_t element_count); +// TODO(#18154): deprecated and will be replaced with simplified bindings. +// // Pushes an inline set of constants that can be accessed by subsequent // dispatches using a compatible pipeline layout. // @@ -725,6 +734,8 @@ IREE_API_EXPORT iree_status_t iree_hal_command_buffer_push_constants( iree_hal_pipeline_layout_t* pipeline_layout, iree_host_size_t offset, const void* values, iree_host_size_t values_length); +// TODO(#18154): deprecated and will be replaced with simplified bindings. +// // Pushes descriptor set bindings and associates them with |set|. 
// This uses an internal ringbuffer inside of the command buffer to avoid the // need for creating and binding descriptor sets and managing their lifetime. @@ -745,6 +756,8 @@ IREE_API_EXPORT iree_status_t iree_hal_command_buffer_push_descriptor_set( iree_hal_pipeline_layout_t* pipeline_layout, uint32_t set, iree_host_size_t binding_count, const iree_hal_buffer_ref_t* bindings); +// TODO(#18154): deprecated and will be replaced with simplified bindings. +// // Dispatches an execution request. // The request may execute overlapped with any other transfer operation or // dispatch made within the same barrier-defined sequence. @@ -761,6 +774,8 @@ IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch( uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, iree_hal_dispatch_flags_t flags); +// TODO(#18154): deprecated and will be replaced with simplified bindings. +// // Dispatches an execution request with deferred workgroup counts. // This is the same as iree_hal_command_buffer_dispatch but the workgroup counts // are read from the given |workgroups_buffer| at offset |workgroups_offset| as @@ -775,6 +790,40 @@ IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch_indirect( iree_hal_executable_t* executable, int32_t entry_point, iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags); +// Dispatches an execution request. +// The request may execute overlapped with any other transfer operation or +// dispatch made within the same barrier-defined sequence. The executable +// specified must be registered for use with the device driver owning this +// queue. +// +// The provided constant data and binding list will be recorded into the command +// buffer and need not remain live beyond the call. +// +// Fails if the queue does not support dispatch operations (as indicated by +// can_dispatch). 
+IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch2( + iree_hal_command_buffer_t* command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + const uint32_t workgroup_count[3], iree_const_byte_span_t constants, + iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags); + +// Dispatches an execution request with a deferred workgroup count. +// This is the same as iree_hal_command_buffer_dispatch2 but the workgroup count +// is read from the given |workgroups_ref| buffer at the specified offset as +// 3 uint32_t XYZ values immediately before performing the dispatch. This allows +// prior dispatches within the command sequence to populate the workgroup +// count or the workgroup count to change across submissions of the same +// reusable command buffer. +// +// The buffer must have been allocated with +// IREE_HAL_BUFFER_USAGE_DISPATCH_INDIRECT_PARAMS and be of +// IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE. +IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch2_indirect( + iree_hal_command_buffer_t* command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants, + iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags); + //===----------------------------------------------------------------------===// // Validation support //===----------------------------------------------------------------------===// @@ -937,6 +986,18 @@ typedef struct iree_hal_command_buffer_vtable_t { iree_hal_command_buffer_t* command_buffer, iree_hal_executable_t* executable, int32_t entry_point, iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags); + + iree_status_t(IREE_API_PTR* dispatch2)( + iree_hal_command_buffer_t* command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + const uint32_t workgroup_count[3], iree_const_byte_span_t constants, + iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags); + 
+ iree_status_t(IREE_API_PTR* dispatch2_indirect)( + iree_hal_command_buffer_t* command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants, + iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags); } iree_hal_command_buffer_vtable_t; IREE_HAL_ASSERT_VTABLE_LAYOUT(iree_hal_command_buffer_vtable_t); diff --git a/runtime/src/iree/hal/command_buffer_validation.c b/runtime/src/iree/hal/command_buffer_validation.c index b27433c3f35d..0c5b0dc39b80 100644 --- a/runtime/src/iree/hal/command_buffer_validation.c +++ b/runtime/src/iree/hal/command_buffer_validation.c @@ -651,6 +651,88 @@ iree_status_t iree_hal_command_buffer_dispatch_indirect_validation( return iree_ok_status(); } +static iree_status_t iree_hal_command_buffer_dispatch2_validation_base( + iree_hal_command_buffer_t* command_buffer, + iree_hal_command_buffer_validation_state_t* validation_state, + iree_hal_executable_t* executable, int32_t entry_point, + iree_const_byte_span_t constants, iree_hal_buffer_ref_list_t bindings, + iree_hal_dispatch_flags_t flags) { + IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories( + command_buffer, validation_state, IREE_HAL_COMMAND_CATEGORY_DISPATCH)); + + if (IREE_UNLIKELY((constants.data_length % 4) != 0)) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "invalid alignment %" PRIhsz + ", must be 4-byte aligned", + constants.data_length); + } + + // For now we conservatively say _any_ access may be performed (read/write). 
+ iree_hal_buffer_binding_requirements_t requirements = { + .required_compatibility = IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH, + .usage = IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE, + .access = IREE_HAL_MEMORY_ACCESS_ANY, + .type = IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE, + }; + for (iree_host_size_t i = 0; i < bindings.count; ++i) { + requirements.max_byte_offset = + bindings.values[i].offset + bindings.values[i].length; + IREE_RETURN_IF_ERROR( + iree_hal_command_buffer_validate_buffer_requirements( + command_buffer, validation_state, bindings.values[i], requirements), + "binding[%u] (arg[%" PRIhsz "])", bindings.values[i].ordinal, i); + } + + return iree_ok_status(); +} + +iree_status_t iree_hal_command_buffer_dispatch2_validation( + iree_hal_command_buffer_t* command_buffer, + iree_hal_command_buffer_validation_state_t* validation_state, + iree_hal_executable_t* executable, int32_t entry_point, + const uint32_t workgroup_count[3], iree_const_byte_span_t constants, + iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) { + return iree_hal_command_buffer_dispatch2_validation_base( + command_buffer, validation_state, executable, entry_point, constants, + bindings, flags); +} + +iree_status_t iree_hal_command_buffer_dispatch2_indirect_validation( + iree_hal_command_buffer_t* command_buffer, + iree_hal_command_buffer_validation_state_t* validation_state, + iree_hal_executable_t* executable, int32_t entry_point, + iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants, + iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) { + if ((workgroups_ref.offset % sizeof(uint32_t)) != 0) { + return iree_make_status( + IREE_STATUS_INVALID_ARGUMENT, + "workgroup count offset does not match the required natural alignment " + "of uint32_t (offset=%" PRIdsz ", min_byte_alignment=%" PRIhsz ")", + workgroups_ref.offset, sizeof(uint32_t)); + } else if (workgroups_ref.length < 3 * sizeof(uint32_t)) { + return 
iree_make_status(IREE_STATUS_OUT_OF_RANGE, + "workgroup count buffer does not have the capacity " + "to store the required 3 uint32_t values " + "(length=%" PRIdsz ", min_length=%" PRIhsz ")", + workgroups_ref.length, 3 * sizeof(uint32_t)); + } + + const iree_hal_buffer_binding_requirements_t workgroups_reqs = { + .required_compatibility = IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH, + .usage = IREE_HAL_BUFFER_USAGE_DISPATCH_INDIRECT_PARAMS, + .access = IREE_HAL_MEMORY_ACCESS_READ, + .type = IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE, + .max_byte_offset = workgroups_ref.offset + workgroups_ref.length, + .min_byte_alignment = sizeof(uint32_t), + }; + IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_buffer_requirements( + command_buffer, validation_state, workgroups_ref, workgroups_reqs)); + + return iree_hal_command_buffer_dispatch2_validation_base( + command_buffer, validation_state, executable, entry_point, constants, + bindings, flags); +} + iree_status_t iree_hal_command_buffer_binding_table_validation( iree_hal_command_buffer_t* command_buffer, const iree_hal_command_buffer_validation_state_t* validation_state, diff --git a/runtime/src/iree/hal/command_buffer_validation.h b/runtime/src/iree/hal/command_buffer_validation.h index 82ab1c5c7ad6..505982f0d0e8 100644 --- a/runtime/src/iree/hal/command_buffer_validation.h +++ b/runtime/src/iree/hal/command_buffer_validation.h @@ -126,18 +126,21 @@ iree_status_t iree_hal_command_buffer_collective_validation( iree_hal_buffer_ref_t send_ref, iree_hal_buffer_ref_t recv_ref, iree_device_size_t element_count); +// TODO(#18154): deprecated and will be replaced with simplified bindings. 
iree_status_t iree_hal_command_buffer_push_constants_validation( iree_hal_command_buffer_t* command_buffer, iree_hal_command_buffer_validation_state_t* validation_state, iree_hal_pipeline_layout_t* pipeline_layout, iree_host_size_t offset, const void* values, iree_host_size_t values_length); +// TODO(#18154): deprecated and will be replaced with simplified bindings. iree_status_t iree_hal_command_buffer_push_descriptor_set_validation( iree_hal_command_buffer_t* command_buffer, iree_hal_command_buffer_validation_state_t* validation_state, iree_hal_pipeline_layout_t* pipeline_layout, uint32_t set, iree_host_size_t binding_count, const iree_hal_buffer_ref_t* bindings); +// TODO(#18154): deprecated and will be replaced with simplified bindings. iree_status_t iree_hal_command_buffer_dispatch_validation( iree_hal_command_buffer_t* command_buffer, iree_hal_command_buffer_validation_state_t* validation_state, @@ -145,12 +148,27 @@ iree_status_t iree_hal_command_buffer_dispatch_validation( uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, iree_hal_dispatch_flags_t flags); +// TODO(#18154): deprecated and will be replaced with simplified bindings. 
iree_status_t iree_hal_command_buffer_dispatch_indirect_validation( iree_hal_command_buffer_t* command_buffer, iree_hal_command_buffer_validation_state_t* validation_state, iree_hal_executable_t* executable, int32_t entry_point, iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags); +iree_status_t iree_hal_command_buffer_dispatch2_validation( + iree_hal_command_buffer_t* command_buffer, + iree_hal_command_buffer_validation_state_t* validation_state, + iree_hal_executable_t* executable, int32_t entry_point, + const uint32_t workgroup_count[3], iree_const_byte_span_t constants, + iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags); + +iree_status_t iree_hal_command_buffer_dispatch2_indirect_validation( + iree_hal_command_buffer_t* command_buffer, + iree_hal_command_buffer_validation_state_t* validation_state, + iree_hal_executable_t* executable, int32_t entry_point, + iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants, + iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags); + iree_status_t iree_hal_command_buffer_binding_table_validation( iree_hal_command_buffer_t* command_buffer, const iree_hal_command_buffer_validation_state_t* validation_state, diff --git a/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c b/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c index c53428af1ce6..68d4d34668bb 100644 --- a/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c +++ b/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c @@ -59,9 +59,8 @@ typedef struct iree_hal_cuda_graph_command_buffer_t { // Iteratively constructed batch of collective operations. iree_hal_collective_batch_t collective_batch; + // TODO(#18189): drop state used by legacy bindings mechanism. int32_t push_constants[IREE_HAL_CUDA_MAX_PUSH_CONSTANT_COUNT]; - - // The current bound descriptor sets. 
struct { CUdeviceptr bindings[IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_BINDING_COUNT]; } descriptor_sets[IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_COUNT]; @@ -879,6 +878,132 @@ static iree_status_t iree_hal_cuda_graph_command_buffer_dispatch_indirect( "indirect dispatch not yet implemented"); } +static iree_status_t iree_hal_cuda_graph_command_buffer_dispatch2( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + const uint32_t workgroup_count[3], iree_const_byte_span_t constants, + iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) { + iree_hal_cuda_graph_command_buffer_t* command_buffer = + iree_hal_cuda_graph_command_buffer_cast(base_command_buffer); + IREE_TRACE_ZONE_BEGIN(z0); + + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_cuda_graph_command_buffer_flush_collectives(command_buffer)); + + // Lookup kernel parameters used for side-channeling additional launch + // information from the compiler. + iree_hal_cuda_kernel_info_t kernel_info; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_cuda_native_executable_entry_point_kernel_info( + executable, entry_point, &kernel_info)); + + IREE_CUDA_GRAPH_COMMAND_BUFFER_TRACE_ZONE_BEGIN_EXTERNAL( + command_buffer, kernel_info.source_filename.data, + kernel_info.source_filename.size, kernel_info.source_line, + kernel_info.function_name.data, kernel_info.function_name.size, + /*name=*/NULL, 0); + + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, + &executable)); + // We append push constants to the end of descriptors to form a linear chain + // of kernel arguments. + iree_host_size_t kernel_params_count = + kernel_info.binding_count + kernel_info.constant_count; + iree_host_size_t kernel_params_length = kernel_params_count * sizeof(void*); + + // TODO: use packed parameters instead of the indirection mechanism - this + // would avoid additional driver overhead to reflect and repack them all. 
+ // + // Per CUDA API requirements, we need two levels of indirection for passing + // kernel arguments in. + // "If the kernel has N parameters, then kernelParams needs to be an array + // of N pointers. Each pointer, from kernelParams[0] to kernelParams[N-1], + // points to the region of memory from which the actual parameter will be + // copied." + // + // (From the cuGraphAddKernelNode API doc in + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g50d871e3bd06c1b835e52f2966ef366b) + // + // It means each kernel_params[i] is itself a pointer to the corresponding + // element at the *second* inline allocation at the end of the current + // segment. + iree_host_size_t total_size = kernel_params_length * 2; + uint8_t* storage_base = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_arena_allocate(&command_buffer->arena, total_size, + (void**)&storage_base)); + void** params_ptr = (void**)storage_base; + CUdeviceptr* payload_ptr = + (CUdeviceptr*)((uint8_t*)params_ptr + kernel_params_length); + for (size_t i = 0; i < kernel_params_count; i++) { + params_ptr[i] = &payload_ptr[i]; + } + for (iree_host_size_t i = 0; i < bindings.count; i++) { + const iree_hal_buffer_ref_t* binding = &bindings.values[i]; + CUdeviceptr device_ptr = 0; + if (binding->buffer) { + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, + &binding->buffer)); + CUdeviceptr device_buffer = iree_hal_cuda_buffer_device_pointer( + iree_hal_buffer_allocated_buffer(binding->buffer)); + iree_device_size_t offset = iree_hal_buffer_byte_offset(binding->buffer); + device_ptr = device_buffer + offset + binding->offset; + } + payload_ptr[i] = device_ptr; + } + + // As commented in the above, what each kernel parameter points to is a + // CUdeviceptr, which has the size of a pointer on the target machine. We are + // just storing a 32-bit value for the push constant here instead.
So we must + // process one element each time, for 64-bit machines. + for (iree_host_size_t i = 0; i < kernel_info.constant_count; i++) { + *((uint32_t*)params_ptr[kernel_info.binding_count + i]) = + ((const uint32_t*)constants.data)[i]; + } + + CUDA_KERNEL_NODE_PARAMS params = { + .func = kernel_info.function, + .blockDimX = kernel_info.block_size[0], + .blockDimY = kernel_info.block_size[1], + .blockDimZ = kernel_info.block_size[2], + .gridDimX = workgroup_count[0], + .gridDimY = workgroup_count[1], + .gridDimZ = workgroup_count[2], + .kernelParams = params_ptr, + .sharedMemBytes = kernel_info.shared_memory_size, + }; + + if (command_buffer->graph_node_count >= + IREE_HAL_CUDA_MAX_CONCURRENT_GRAPH_NODE_COUNT) { + return iree_make_status(IREE_STATUS_OUT_OF_RANGE, + "exceeded max concurrent node limit"); + } + + size_t dependency_count = command_buffer->cu_barrier_node ? 1 : 0; + IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR( + z0, command_buffer->symbols, + cuGraphAddKernelNode( + &command_buffer->cu_graph_nodes[command_buffer->graph_node_count++], + command_buffer->cu_graph, &command_buffer->cu_barrier_node, + dependency_count, &params), + "cuGraphAddKernelNode"); + + IREE_CUDA_GRAPH_COMMAND_BUFFER_TRACE_ZONE_END(command_buffer); + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +static iree_status_t iree_hal_cuda_graph_command_buffer_dispatch2_indirect( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants, + iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "indirect dispatch not yet implemented"); +} + static const iree_hal_command_buffer_vtable_t iree_hal_cuda_graph_command_buffer_vtable = { .destroy = iree_hal_cuda_graph_command_buffer_destroy, @@ -903,4 +1028,7 @@ static const iree_hal_command_buffer_vtable_t .dispatch =
iree_hal_cuda_graph_command_buffer_dispatch, .dispatch_indirect = iree_hal_cuda_graph_command_buffer_dispatch_indirect, + .dispatch2 = iree_hal_cuda_graph_command_buffer_dispatch2, + .dispatch2_indirect = + iree_hal_cuda_graph_command_buffer_dispatch2_indirect, }; diff --git a/runtime/src/iree/hal/drivers/cuda/native_executable.c b/runtime/src/iree/hal/drivers/cuda/native_executable.c index 00b7216e3152..06a7ffc1bbbe 100644 --- a/runtime/src/iree/hal/drivers/cuda/native_executable.c +++ b/runtime/src/iree/hal/drivers/cuda/native_executable.c @@ -224,16 +224,33 @@ iree_status_t iree_hal_cuda_native_executable_create( } if (!iree_status_is_ok(status)) break; + // TODO(#18189): embed all of this on a single flatbuffer table + // per-export. + // // Package required parameters for kernel launches for each entry point. iree_hal_cuda_kernel_info_t* info = &executable->entry_points[i]; info->layout = executable_params->pipeline_layouts[i]; iree_hal_pipeline_layout_retain(info->layout); info->function = function; + info->constant_count = + iree_hal_cuda_pipeline_layout_push_constant_count(info->layout); + info->binding_count = + iree_hal_cuda_pipeline_layout_total_binding_count(info->layout); info->block_size[0] = block_sizes_vec[i].x; info->block_size[1] = block_sizes_vec[i].y; info->block_size[2] = block_sizes_vec[i].z; info->shared_memory_size = shared_memory_sizes[i]; + if (info->binding_count > + IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_BINDING_COUNT) { + status = iree_make_status( + IREE_STATUS_RESOURCE_EXHAUSTED, + "exceeded available binding slots; requested %u of maximum %d", + info->binding_count, + IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_BINDING_COUNT); + } + if (!iree_status_is_ok(status)) break; + // Stash the entry point name in the string table for use when tracing. 
IREE_TRACE({ iree_host_size_t entry_name_length = flatbuffers_string_len(entry_name); diff --git a/runtime/src/iree/hal/drivers/cuda/native_executable.h b/runtime/src/iree/hal/drivers/cuda/native_executable.h index 1dee84dc4e3a..226cedaa6b44 100644 --- a/runtime/src/iree/hal/drivers/cuda/native_executable.h +++ b/runtime/src/iree/hal/drivers/cuda/native_executable.h @@ -20,8 +20,12 @@ extern "C" { #endif // __cplusplus typedef struct iree_hal_cuda_kernel_info_t { + // TODO(#18189): remove when using simplified bindings. iree_hal_pipeline_layout_t* layout; CUfunction function; + uint32_t constant_count; + uint32_t binding_count; + // TODO(#18189): add bitfield indicating indirect bindings. uint32_t block_size[3]; uint32_t shared_memory_size; diff --git a/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.c b/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.c index 5b64dfaf704e..17af9684c179 100644 --- a/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.c +++ b/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.c @@ -538,13 +538,13 @@ iree_status_t iree_hal_cuda_pending_queue_actions_create( // Create the ready-list processing worker itself. 
iree_thread_create_params_t params; memset(&params, 0, sizeof(params)); - params.name = IREE_SV("deferque_worker"); + params.name = IREE_SV("iree-cuda-queue-worker"); params.create_suspended = false; iree_status_t status = iree_thread_create( (iree_thread_entry_t)iree_hal_cuda_worker_execute, actions, params, actions->host_allocator, &actions->worker_thread); - params.name = IREE_SV("done_worker"); + params.name = IREE_SV("iree-cuda-queue-completion"); params.create_suspended = false; if (iree_status_is_ok(status)) { status = iree_thread_create( diff --git a/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c b/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c index 3369f3b405cb..a9b50fc19f4a 100644 --- a/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c +++ b/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c @@ -39,10 +39,8 @@ typedef struct iree_hal_cuda_stream_command_buffer_t { // Iteratively constructed batch of collective operations. iree_hal_collective_batch_t collective_batch; - // The current set push constants. + // TODO(#18189): drop state used by legacy bindings mechanism. int32_t push_constants[IREE_HAL_CUDA_MAX_PUSH_CONSTANT_COUNT]; - - // The current bound descriptor sets.
struct { CUdeviceptr bindings[IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_BINDING_COUNT]; } descriptor_sets[IREE_HAL_CUDA_MAX_DESCRIPTOR_SET_COUNT]; @@ -652,6 +650,120 @@ static iree_status_t iree_hal_cuda_stream_command_buffer_dispatch_indirect( "need cuda implementation of dispatch indirect"); } +static iree_status_t iree_hal_cuda_stream_command_buffer_dispatch2( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + const uint32_t workgroup_count[3], iree_const_byte_span_t constants, + iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) { + iree_hal_cuda_stream_command_buffer_t* command_buffer = + iree_hal_cuda_stream_command_buffer_cast(base_command_buffer); + IREE_TRACE_ZONE_BEGIN(z0); + + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, + iree_hal_cuda_stream_command_buffer_flush_collectives(command_buffer)); + + // Lookup kernel parameters used for side-channeling additional launch + // information from the compiler. + iree_hal_cuda_kernel_info_t kernel_info; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_cuda_native_executable_entry_point_kernel_info( + executable, entry_point, &kernel_info)); + + IREE_CUDA_STREAM_TRACE_ZONE_BEGIN_EXTERNAL( + command_buffer->tracing_context, &command_buffer->tracing_event_list, + command_buffer->cu_stream, kernel_info.source_filename.data, + kernel_info.source_filename.size, kernel_info.source_line, + kernel_info.function_name.data, kernel_info.function_name.size, + /*name=*/NULL, 0); + + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, + &executable)); + + // We append push constants to the end of descriptors to form a linear chain + // of kernel arguments. 
+ iree_host_size_t kernel_params_count = + kernel_info.binding_count + kernel_info.constant_count; + iree_host_size_t kernel_params_length = kernel_params_count * sizeof(void*); + + // TODO: use packed parameters instead of the indirection mechanism - this + // would avoid additional driver overhead to reflect and repack them all. + // + // Per CUDA API requirements, we need two levels of indirection for passing + // kernel arguments in. + // "If the kernel has N parameters, then kernelParams needs to be an array + // of N pointers. Each pointer, from kernelParams[0] to kernelParams[N-1], + // points to the region of memory from which the actual parameter will be + // copied." + // + // (From the cuGraphAddKernelNode API doc in + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g50d871e3bd06c1b835e52f2966ef366b) + // + // It means each kernel_params[i] is itself a pointer to the corresponding + // element at the *second* inline allocation at the end of the current + // segment. 
+ iree_host_size_t total_size = kernel_params_length * 2; + uint8_t* storage_base = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_arena_allocate(&command_buffer->arena, total_size, + (void**)&storage_base)); + void** params_ptr = (void**)storage_base; + CUdeviceptr* payload_ptr = + (CUdeviceptr*)((uint8_t*)params_ptr + kernel_params_length); + for (size_t i = 0; i < kernel_params_count; i++) { + params_ptr[i] = &payload_ptr[i]; + } + for (iree_host_size_t i = 0; i < bindings.count; i++) { + const iree_hal_buffer_ref_t* binding = &bindings.values[i]; + CUdeviceptr device_ptr = 0; + if (binding->buffer) { + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, + &binding->buffer)); + CUdeviceptr device_buffer = iree_hal_cuda_buffer_device_pointer( + iree_hal_buffer_allocated_buffer(binding->buffer)); + iree_device_size_t offset = iree_hal_buffer_byte_offset(binding->buffer); + device_ptr = device_buffer + offset + binding->offset; + } + payload_ptr[i] = device_ptr; + } + + // As commented in the above, what each kernel parameter points to is a + // CUdeviceptr, which has the size of a pointer on the target machine. We are + // just storing a 32-bit value for the push constant here instead. So we must + // process one element each time, for 64-bit machines.
+ for (iree_host_size_t i = 0; i < kernel_info.constant_count; i++) { + *((uint32_t*)params_ptr[kernel_info.binding_count + i]) = + ((const uint32_t*)constants.data)[i]; + } + + IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR( + z0, command_buffer->cuda_symbols, + cuLaunchKernel(kernel_info.function, workgroup_count[0], + workgroup_count[1], workgroup_count[2], + kernel_info.block_size[0], kernel_info.block_size[1], + kernel_info.block_size[2], kernel_info.shared_memory_size, + command_buffer->cu_stream, params_ptr, NULL), + "cuLaunchKernel"); + + IREE_CUDA_STREAM_TRACE_ZONE_END(command_buffer->tracing_context, + &command_buffer->tracing_event_list, + command_buffer->cu_stream); + + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +static iree_status_t iree_hal_cuda_stream_command_buffer_dispatch2_indirect( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants, + iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "indirect dispatch not yet implemented"); +} + static const iree_hal_command_buffer_vtable_t iree_hal_cuda_stream_command_buffer_vtable = { .destroy = iree_hal_cuda_stream_command_buffer_destroy, @@ -676,4 +788,7 @@ static const iree_hal_command_buffer_vtable_t .dispatch = iree_hal_cuda_stream_command_buffer_dispatch, .dispatch_indirect = iree_hal_cuda_stream_command_buffer_dispatch_indirect, + .dispatch2 = iree_hal_cuda_stream_command_buffer_dispatch2, + .dispatch2_indirect = + iree_hal_cuda_stream_command_buffer_dispatch2_indirect, }; diff --git a/runtime/src/iree/hal/drivers/hip/graph_command_buffer.c b/runtime/src/iree/hal/drivers/hip/graph_command_buffer.c index ae66cfd2110a..99b3538caf77 100644 --- a/runtime/src/iree/hal/drivers/hip/graph_command_buffer.c +++ b/runtime/src/iree/hal/drivers/hip/graph_command_buffer.c @@ -60,9 +60,8 @@ typedef struct 
iree_hal_hip_graph_command_buffer_t { // Iteratively constructed batch of collective operations. iree_hal_collective_batch_t collective_batch; + // TODO(#18189): drop state used by legacy bindings mechanism. int32_t push_constants[IREE_HAL_HIP_MAX_PUSH_CONSTANT_COUNT]; - - // The current bound descriptor sets. struct { hipDeviceptr_t bindings[IREE_HAL_HIP_MAX_DESCRIPTOR_SET_BINDING_COUNT]; } descriptor_sets[IREE_HAL_HIP_MAX_DESCRIPTOR_SET_COUNT]; @@ -888,6 +887,123 @@ static iree_status_t iree_hal_hip_graph_command_buffer_dispatch_indirect( "indirect dispatch not yet implemented"); } +static iree_status_t iree_hal_hip_graph_command_buffer_dispatch2( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + const uint32_t workgroup_count[3], iree_const_byte_span_t constants, + iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) { + iree_hal_hip_graph_command_buffer_t* command_buffer = + iree_hal_hip_graph_command_buffer_cast(base_command_buffer); + IREE_TRACE_ZONE_BEGIN(z0); + + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_hip_graph_command_buffer_flush_collectives(command_buffer)); + + // Lookup kernel parameters used for side-channeling additional launch + // information from the compiler. + iree_hal_hip_kernel_info_t kernel_info; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_hip_native_executable_entry_point_kernel_info( + executable, entry_point, &kernel_info)); + + IREE_HIP_GRAPH_COMMAND_BUFFER_TRACE_ZONE_BEGIN_EXTERNAL( + command_buffer, kernel_info.source_filename.data, + kernel_info.source_filename.size, kernel_info.source_line, + kernel_info.function_name.data, kernel_info.function_name.size, + /*name=*/NULL, 0); + + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, + &executable)); + + // We append push constants to the end of descriptors to form a linear chain + // of kernel arguments. 
+ iree_host_size_t kernel_params_count = + kernel_info.binding_count + kernel_info.constant_count; + iree_host_size_t kernel_params_length = kernel_params_count * sizeof(void*); + + // TODO: use packed parameters instead of the indirection mechanism - this + // would avoid additional driver overhead to reflect and repack them all. + // + // Each kernel_params[i] is itself a pointer to the corresponding + // element at the *second* inline allocation at the end of the current + // segment. + iree_host_size_t total_size = kernel_params_length * 2; + uint8_t* storage_base = NULL; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_arena_allocate(&command_buffer->arena, total_size, + (void**)&storage_base)); + void** params_ptr = (void**)storage_base; + hipDeviceptr_t* payload_ptr = + (hipDeviceptr_t*)((uint8_t*)params_ptr + kernel_params_length); + for (size_t i = 0; i < kernel_params_count; i++) { + params_ptr[i] = &payload_ptr[i]; + } + for (iree_host_size_t i = 0; i < bindings.count; i++) { + const iree_hal_buffer_ref_t* binding = &bindings.values[i]; + hipDeviceptr_t device_ptr = NULL; + if (binding->buffer) { + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, + &binding->buffer)); + hipDeviceptr_t device_buffer = iree_hal_hip_buffer_device_pointer( + iree_hal_buffer_allocated_buffer(binding->buffer)); + iree_device_size_t offset = iree_hal_buffer_byte_offset(binding->buffer); + device_ptr = (uint8_t*)device_buffer + offset + binding->offset; + } + payload_ptr[i] = device_ptr; + } + + // Each kernel parameter points to a hipDeviceptr_t, which has the size of a + // pointer on the target machine. We are just storing a 32-bit value for the + // push constant here instead. So we must process one element each time, for + // 64-bit machines.
+ for (iree_host_size_t i = 0; i < kernel_info.constant_count; i++) { + *((uint32_t*)params_ptr[kernel_info.binding_count + i]) = + ((const uint32_t*)constants.data)[i]; + } + + hipKernelNodeParams params = { + .blockDim.x = kernel_info.block_size[0], + .blockDim.y = kernel_info.block_size[1], + .blockDim.z = kernel_info.block_size[2], + .gridDim.x = workgroup_count[0], + .gridDim.y = workgroup_count[1], + .gridDim.z = workgroup_count[2], + .func = kernel_info.function, + .kernelParams = params_ptr, + .sharedMemBytes = kernel_info.shared_memory_size, + }; + + if (command_buffer->graph_node_count >= + IREE_HAL_HIP_MAX_CONCURRENT_GRAPH_NODE_COUNT) { + return iree_make_status(IREE_STATUS_OUT_OF_RANGE, + "exceeded max concurrent node limit"); + } + + size_t dependency_count = command_buffer->hip_barrier_node ? 1 : 0; + IREE_HIP_RETURN_AND_END_ZONE_IF_ERROR( + z0, command_buffer->symbols, + hipGraphAddKernelNode( + &command_buffer->hip_graph_nodes[command_buffer->graph_node_count++], + command_buffer->hip_graph, &command_buffer->hip_barrier_node, + dependency_count, ¶ms), + "hipGraphAddKernelNode"); + + IREE_HIP_GRAPH_COMMAND_BUFFER_TRACE_ZONE_END(command_buffer); + IREE_TRACE_ZONE_END(z0); + return iree_ok_status(); +} + +static iree_status_t iree_hal_hip_graph_command_buffer_dispatch2_indirect( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants, + iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) { + return iree_make_status(IREE_STATUS_UNIMPLEMENTED, + "indirect dispatch not yet implemented"); +} + static const iree_hal_command_buffer_vtable_t iree_hal_hip_graph_command_buffer_vtable = { .destroy = iree_hal_hip_graph_command_buffer_destroy, @@ -912,4 +1028,7 @@ static const iree_hal_command_buffer_vtable_t .dispatch = iree_hal_hip_graph_command_buffer_dispatch, .dispatch_indirect = 
iree_hal_hip_graph_command_buffer_dispatch_indirect, + .dispatch2 = iree_hal_hip_graph_command_buffer_dispatch2, + .dispatch2_indirect = + iree_hal_hip_graph_command_buffer_dispatch2_indirect, }; diff --git a/runtime/src/iree/hal/drivers/hip/native_executable.c b/runtime/src/iree/hal/drivers/hip/native_executable.c index 10b5aacff7c2..19caae90aa22 100644 --- a/runtime/src/iree/hal/drivers/hip/native_executable.c +++ b/runtime/src/iree/hal/drivers/hip/native_executable.c @@ -10,6 +10,7 @@ #include "iree/base/api.h" #include "iree/hal/drivers/hip/dynamic_symbols.h" +#include "iree/hal/drivers/hip/pipeline_layout.h" #include "iree/hal/drivers/hip/status_util.h" // flatcc schemas: @@ -242,16 +243,33 @@ iree_status_t iree_hal_hip_native_executable_create( } if (!iree_status_is_ok(status)) break; + // TODO(#18189): embed all of this on a single flatbuffer table + // per-export. + // // Package required parameters for kernel launches for each entry point. iree_hal_hip_kernel_info_t* kernel_info = &executable->entry_points[i]; kernel_info->layout = executable_params->pipeline_layouts[i]; iree_hal_pipeline_layout_retain(kernel_info->layout); kernel_info->function = function; + iree_hal_hip_dispatch_layout_t dispatch_params = + iree_hal_hip_pipeline_layout_dispatch_layout(kernel_info->layout); + kernel_info->constant_count = dispatch_params.push_constant_count; + kernel_info->binding_count = dispatch_params.total_binding_count; kernel_info->block_size[0] = block_sizes_vec[i].x; kernel_info->block_size[1] = block_sizes_vec[i].y; kernel_info->block_size[2] = block_sizes_vec[i].z; kernel_info->shared_memory_size = shared_memory_sizes_vec[i]; + if (kernel_info->binding_count > + IREE_HAL_HIP_MAX_DESCRIPTOR_SET_BINDING_COUNT) { + status = iree_make_status( + IREE_STATUS_RESOURCE_EXHAUSTED, + "exceeded available binding slots; requested %u of maximum %d", + kernel_info->binding_count, + IREE_HAL_HIP_MAX_DESCRIPTOR_SET_BINDING_COUNT); + } + if (!iree_status_is_ok(status)) break; + 
// Stash the entry point name in the string table for use when tracing. IREE_TRACE({ iree_host_size_t entry_name_length = flatbuffers_string_len(entry_name); diff --git a/runtime/src/iree/hal/drivers/hip/native_executable.h b/runtime/src/iree/hal/drivers/hip/native_executable.h index 922f343e87b2..d2b1a319de5c 100644 --- a/runtime/src/iree/hal/drivers/hip/native_executable.h +++ b/runtime/src/iree/hal/drivers/hip/native_executable.h @@ -20,8 +20,12 @@ extern "C" { #endif // __cplusplus typedef struct iree_hal_hip_kernel_info_t { + // TODO(#18189): remove when using simplified bindings. iree_hal_pipeline_layout_t* layout; hipFunction_t function; + uint32_t constant_count; + uint32_t binding_count; + // TODO(#18189): add bitfield indicating indirect bindings. uint32_t block_size[3]; uint32_t shared_memory_size; diff --git a/runtime/src/iree/hal/drivers/hip/pending_queue_actions.c b/runtime/src/iree/hal/drivers/hip/pending_queue_actions.c index 6d7233096b05..88a2e830537d 100644 --- a/runtime/src/iree/hal/drivers/hip/pending_queue_actions.c +++ b/runtime/src/iree/hal/drivers/hip/pending_queue_actions.c @@ -537,13 +537,13 @@ iree_status_t iree_hal_hip_pending_queue_actions_create( // Create the ready-list processing worker itself. 
iree_thread_create_params_t params; memset(&params, 0, sizeof(params)); - params.name = IREE_SV("deferque_worker"); + params.name = IREE_SV("iree-hip-queue-worker"); params.create_suspended = false; iree_status_t status = iree_thread_create( (iree_thread_entry_t)iree_hal_hip_worker_execute, actions, params, actions->host_allocator, &actions->worker_thread); - params.name = IREE_SV("done_worker"); + params.name = IREE_SV("iree-hip-queue-completion"); params.create_suspended = false; if (iree_status_is_ok(status)) { status = iree_thread_create( diff --git a/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c b/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c index 0f087275524d..e4ffac2200a9 100644 --- a/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c +++ b/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c @@ -41,9 +41,8 @@ typedef struct iree_hal_hip_stream_command_buffer_t { // Iteratively constructed batch of collective operations. iree_hal_collective_batch_t collective_batch; + // TODO(#18189): drop state used by legacy bindings mechanism. int32_t push_constants[IREE_HAL_HIP_MAX_PUSH_CONSTANT_COUNT]; - - // The current bound descriptor sets.
struct {
     hipDeviceptr_t bindings[IREE_HAL_HIP_MAX_DESCRIPTOR_SET_BINDING_COUNT];
   } descriptor_sets[IREE_HAL_HIP_MAX_DESCRIPTOR_SET_COUNT];
@@ -632,6 +631,135 @@ static iree_status_t iree_hal_hip_stream_command_buffer_dispatch_indirect(
                           "need hip implementation of dispatch indirect");
 }
 
+// Records a direct dispatch of |entry_point| from |executable| by launching the
+// kernel on the command buffer's HIP stream.
+//
+// |bindings| are resolved to device pointers and followed by the 32-bit
+// |constants| to form the kernel argument list. The executable and all bound
+// buffers are added to the command buffer resource set.
+static iree_status_t iree_hal_hip_stream_command_buffer_dispatch2(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_t* executable, int32_t entry_point,
+    const uint32_t workgroup_count[3], iree_const_byte_span_t constants,
+    iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
+  iree_hal_hip_stream_command_buffer_t* command_buffer =
+      iree_hal_hip_stream_command_buffer_cast(base_command_buffer);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_hip_stream_command_buffer_flush_collectives(command_buffer));
+
+  // Lookup kernel parameters used for side-channeling additional launch
+  // information from the compiler.
+  iree_hal_hip_kernel_info_t kernel_info;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_hip_native_executable_entry_point_kernel_info(
+              executable, entry_point, &kernel_info));
+
+  // The parameter tables built below are sized from the counts declared by the
+  // executable; reject argument lists that would read or write past them.
+  if (bindings.count > kernel_info.binding_count) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "binding count mismatch, expected at most %u but was provided %" PRIhsz,
+        kernel_info.binding_count, bindings.count);
+  }
+  if (constants.data_length < kernel_info.constant_count * sizeof(uint32_t)) {
+    IREE_TRACE_ZONE_END(z0);
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "constant count mismatch, expected %u but was provided %" PRIhsz,
+        kernel_info.constant_count, constants.data_length / sizeof(uint32_t));
+  }
+
+  IREE_HIP_STREAM_TRACE_ZONE_BEGIN_EXTERNAL(
+      command_buffer->tracing_context, &command_buffer->tracing_event_list,
+      command_buffer->hip_stream, kernel_info.source_filename.data,
+      kernel_info.source_filename.size, kernel_info.source_line,
+      kernel_info.function_name.data, kernel_info.function_name.size,
+      /*name=*/NULL, 0);
+
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1,
+                                       &executable));
+
+  // We append push constants to the end of descriptors to form a linear chain
+  // of kernel arguments.
+  iree_host_size_t kernel_params_count =
+      kernel_info.binding_count + kernel_info.constant_count;
+  iree_host_size_t kernel_params_length = kernel_params_count * sizeof(void*);
+
+  // TODO: use packed parameters instead of the indirection mechanism - this
+  // would avoid additional driver overhead to reflect and repack them all.
+  //
+  // Each kernel_params[i] is itself a pointer to the corresponding
+  // element at the *second* inline allocation at the end of the current
+  // segment.
+  iree_host_size_t total_size = kernel_params_length * 2;
+  uint8_t* storage_base = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_arena_allocate(&command_buffer->arena, total_size,
+                              (void**)&storage_base));
+  void** params_ptr = (void**)storage_base;
+  hipDeviceptr_t* payload_ptr =
+      (hipDeviceptr_t*)((uint8_t*)params_ptr + kernel_params_length);
+  for (size_t i = 0; i < kernel_params_count; i++) {
+    params_ptr[i] = &payload_ptr[i];
+  }
+  for (iree_host_size_t i = 0; i < bindings.count; i++) {
+    const iree_hal_buffer_ref_t* binding = &bindings.values[i];
+    hipDeviceptr_t device_ptr = NULL;
+    if (binding->buffer) {
+      IREE_RETURN_AND_END_ZONE_IF_ERROR(
+          z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1,
+                                           &binding->buffer));
+      hipDeviceptr_t device_buffer = iree_hal_hip_buffer_device_pointer(
+          iree_hal_buffer_allocated_buffer(binding->buffer));
+      iree_device_size_t offset = iree_hal_buffer_byte_offset(binding->buffer);
+      device_ptr = (uint8_t*)device_buffer + offset + binding->offset;
+    }
+    payload_ptr[i] = device_ptr;
+  }
+
+  // As commented above, what each kernel parameter points to is a
+  // hipDeviceptr_t, which has the size of a pointer on the target machine. We
+  // are just storing a 32-bit value for the push constant here instead, so on
+  // 64-bit machines we must write each element individually.
+  for (iree_host_size_t i = 0; i < kernel_info.constant_count; i++) {
+    *((uint32_t*)params_ptr[kernel_info.binding_count + i]) =
+        ((const uint32_t*)constants.data)[i];
+  }
+
+  iree_status_t status = IREE_HIP_RESULT_TO_STATUS(
+      command_buffer->hip_symbols,
+      hipModuleLaunchKernel(
+          kernel_info.function, workgroup_count[0], workgroup_count[1],
+          workgroup_count[2], kernel_info.block_size[0],
+          kernel_info.block_size[1], kernel_info.block_size[2],
+          kernel_info.shared_memory_size, command_buffer->hip_stream,
+          params_ptr, NULL),
+      "hipModuleLaunchKernel");
+
+  IREE_HIP_STREAM_TRACE_ZONE_END(command_buffer->tracing_context,
+                                 &command_buffer->tracing_event_list,
+                                 command_buffer->hip_stream);
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+// Indirect dispatches are not yet supported by the HIP stream command buffer;
+// always returns IREE_STATUS_UNIMPLEMENTED.
+static iree_status_t iree_hal_hip_stream_command_buffer_dispatch2_indirect(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_t* executable, int32_t entry_point,
+    iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants,
+    iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
+  return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
+                          "indirect dispatch not yet implemented");
+}
+
 static const iree_hal_command_buffer_vtable_t
     iree_hal_hip_stream_command_buffer_vtable = {
         .destroy = iree_hal_hip_stream_command_buffer_destroy,
@@ -656,4 +784,7 @@ static const iree_hal_command_buffer_vtable_t
         .dispatch = iree_hal_hip_stream_command_buffer_dispatch,
         .dispatch_indirect =
             iree_hal_hip_stream_command_buffer_dispatch_indirect,
+        .dispatch2 = iree_hal_hip_stream_command_buffer_dispatch2,
+        .dispatch2_indirect =
+            iree_hal_hip_stream_command_buffer_dispatch2_indirect,
 };
diff --git a/runtime/src/iree/hal/drivers/local_task/task_command_buffer.c b/runtime/src/iree/hal/drivers/local_task/task_command_buffer.c
index de8642ec8b90..3b0a9ba18617 100644
--- a/runtime/src/iree/hal/drivers/local_task/task_command_buffer.c
+++ b/runtime/src/iree/hal/drivers/local_task/task_command_buffer.c
@@
-78,6 +78,7 @@ typedef struct iree_hal_task_command_buffer_t {
   // All execution tasks emitted that must execute after |open_barrier|.
   iree_task_list_t open_tasks;
 
+  // TODO(#18189): remove legacy binding state.
   // A flattened list of all available descriptor set bindings.
   // As descriptor sets are pushed/bound the bindings will be updated to
   // represent the fully-translated binding data pointer.
@@ -89,6 +90,7 @@ typedef struct iree_hal_task_command_buffer_t {
       binding_lengths[IREE_HAL_LOCAL_MAX_DESCRIPTOR_SET_COUNT *
                       IREE_HAL_LOCAL_MAX_DESCRIPTOR_BINDING_COUNT];
 
+  // TODO(#18189): remove legacy push constant state.
   // All available push constants updated each time push_constants is called.
   // Reset only with the command buffer and otherwise will maintain its values
   // during recording to allow for partial push_constants updates.
@@ -930,7 +932,7 @@ static iree_status_t iree_hal_task_command_buffer_build_dispatch(
   cmd->task.local_memory_size =
       local_executable->dispatch_attrs
           ? local_executable->dispatch_attrs[entry_point].local_memory_pages *
-                IREE_HAL_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE
+                IREE_HAL_EXECUTABLE_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE
           : 0;
 
   // Copy only the push constant range used by the executable.
@@ -1012,6 +1014,249 @@ static iree_status_t iree_hal_task_command_buffer_dispatch_indirect(
   return iree_ok_status();
 }
 
+//===----------------------------------------------------------------------===//
+// iree_hal_command_buffer_dispatch2
+//===----------------------------------------------------------------------===//
+
+typedef struct iree_hal_task_cmd_dispatch2_t {
+  iree_task_dispatch_t task;
+  iree_hal_local_executable_t* executable;
+  int32_t ordinal;
+
+  // Total number of available 4 byte push constant values in |push_constants|.
+  uint16_t push_constant_count;
+
+  // Total number of binding base pointers in |binding_ptrs| and
+  // |binding_lengths|. The set is packed densely based on which bindings are
+  // used (known at compile-time).
+  uint16_t binding_count;
+
+  // Following this structure in memory there are 3 tables:
+  // - const uint32_t push_constants[push_constant_count];
+  // - void* binding_ptrs[binding_count];
+  // - const size_t binding_lengths[binding_count];
+  // NOTE(review): |binding_ptrs| directly follows |push_constants|; with an odd
+  // push_constant_count on 64-bit hosts it is only 4-byte aligned - confirm
+  // that is acceptable or pad the constant table.
+} iree_hal_task_cmd_dispatch2_t;
+
+// Tile function invoked by the task system for each workgroup of a dispatch2
+// command. Rebuilds the executable dispatch state from the tables stored
+// inline after |user_context| (an iree_hal_task_cmd_dispatch2_t) and issues
+// the call for the workgroup described by |tile_context|.
+static iree_status_t iree_hal_task_cmd_dispatch2_tile(
+    void* user_context, const iree_task_tile_context_t* tile_context,
+    iree_task_submission_t* pending_submission) {
+  const iree_hal_task_cmd_dispatch2_t* cmd =
+      (const iree_hal_task_cmd_dispatch2_t*)user_context;
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // We could share this across all workgroups in a dispatch and reduce cache
+  // pressure as all cores would be hitting the same hot read-only cache line.
+  // It'd grow the size of iree_hal_task_cmd_dispatch2_t by a few dozen bytes,
+  // though, and so we'd need some profiling to see if it's worth it (fixed
+  // command buffer cost vs potential for saving a cache miss or two).
+  iree_alignas(64) iree_hal_executable_dispatch_state_v0_t dispatch_state = {
+      .workgroup_size_x = tile_context->workgroup_size[0],
+      .workgroup_size_y = tile_context->workgroup_size[1],
+      .workgroup_size_z = tile_context->workgroup_size[2],
+      .push_constant_count = cmd->push_constant_count,
+      .workgroup_count_x = tile_context->workgroup_count[0],
+      .workgroup_count_y = tile_context->workgroup_count[1],
+      .workgroup_count_z = tile_context->workgroup_count[2],
+      .max_concurrency =
+          iree_task_affinity_set_count_ones(cmd->task.header.affinity_set),
+      .binding_count = cmd->binding_count,
+  };
+  uint8_t* cmd_ptr = (uint8_t*)cmd + sizeof(*cmd);
+  dispatch_state.push_constants = (uint32_t*)cmd_ptr;
+  cmd_ptr += cmd->push_constant_count * sizeof(*dispatch_state.push_constants);
+  dispatch_state.binding_ptrs = (void**)cmd_ptr;
+  cmd_ptr += cmd->binding_count * sizeof(*dispatch_state.binding_ptrs);
+  dispatch_state.binding_lengths = (size_t*)cmd_ptr;
+  cmd_ptr += cmd->binding_count * sizeof(*dispatch_state.binding_lengths);
+
+  const iree_alignas(64)
+      iree_hal_executable_workgroup_state_v0_t workgroup_state = {
+          .workgroup_id_x = tile_context->workgroup_xyz[0],
+          .workgroup_id_y = tile_context->workgroup_xyz[1],
+          .workgroup_id_z = tile_context->workgroup_xyz[2],
+          .reserved = 0,
+          .processor_id = tile_context->processor_id,
+          .local_memory = tile_context->local_memory.data,
+          .local_memory_size = (size_t)tile_context->local_memory.data_length,
+      };
+  iree_status_t status = iree_hal_local_executable_issue_call(
+      cmd->executable, cmd->ordinal, &dispatch_state, &workgroup_state,
+      tile_context->worker_id);
+
+  IREE_TRACE_ZONE_END(z0);
+  return status;
+}
+
+// Allocates and populates an iree_hal_task_cmd_dispatch2_t (plus its trailing
+// constant/binding tables) from the command buffer arena and emits it as an
+// execution task. |constants| and |bindings| must exactly match the counts
+// declared in the executable's dispatch attributes. On success |out_cmd|
+// receives the command so that callers can further customize the task.
+static iree_status_t iree_hal_task_command_buffer_build_dispatch2(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_t* executable, int32_t entry_point,
+    const uint32_t workgroup_count[3], iree_const_byte_span_t constants,
+    iree_hal_buffer_ref_list_t bindings,
+    iree_hal_task_cmd_dispatch2_t** out_cmd) {
+  iree_hal_task_command_buffer_t* command_buffer =
+      iree_hal_task_command_buffer_cast(base_command_buffer);
+
+  iree_hal_local_executable_t* local_executable =
+      iree_hal_local_executable_cast(executable);
+  iree_hal_executable_dispatch_attrs_v0_t dispatch_attrs = {0};
+  if (local_executable->dispatch_attrs) {
+    dispatch_attrs = local_executable->dispatch_attrs[entry_point];
+  }
+
+  iree_hal_task_cmd_dispatch2_t* cmd = NULL;
+  iree_host_size_t total_cmd_size =
+      sizeof(*cmd) + dispatch_attrs.constant_count * sizeof(uint32_t) +
+      dispatch_attrs.binding_count * sizeof(void*) +
+      dispatch_attrs.binding_count * sizeof(size_t);
+  IREE_RETURN_IF_ERROR(iree_arena_allocate(&command_buffer->arena,
+                                           total_cmd_size, (void**)&cmd));
+
+  cmd->executable = local_executable;
+  cmd->ordinal = entry_point;
+  cmd->push_constant_count = dispatch_attrs.constant_count;
+  cmd->binding_count = dispatch_attrs.binding_count;
+
+  // TODO(benvanik): expose on API or keep fixed on executable.
+  const uint32_t workgroup_size[3] = {1, 1, 1};
+  iree_task_dispatch_initialize(
+      command_buffer->scope,
+      iree_task_make_dispatch_closure(iree_hal_task_cmd_dispatch2_tile,
+                                      (void*)cmd),
+      workgroup_size, workgroup_count, &cmd->task);
+
+  // Tell the task system how much workgroup local memory is required for the
+  // dispatch; each invocation of the entry point will have at least as much
+  // scratch memory available during execution.
+  cmd->task.local_memory_size =
+      dispatch_attrs.local_memory_pages *
+      IREE_HAL_EXECUTABLE_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE;
+
+  // Push constants are pulled directly from the args and copied into the
+  // command buffer. Note that we require 4 byte alignment and if the input
+  // buffer is not aligned we have to fail.
+  if (IREE_UNLIKELY((constants.data_length % sizeof(uint32_t)) != 0)) {
+    return iree_make_status(IREE_STATUS_FAILED_PRECONDITION,
+                            "constants must be 4-byte aligned");
+  } else if (IREE_UNLIKELY(constants.data_length !=
+                           dispatch_attrs.constant_count * sizeof(uint32_t))) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "constant count mismatch, expected %u but was provided %" PRIhsz,
+        (uint32_t)dispatch_attrs.constant_count,
+        constants.data_length / sizeof(uint32_t));
+  }
+  uint8_t* cmd_ptr = (uint8_t*)cmd + sizeof(*cmd);
+  uint32_t* push_constants = (uint32_t*)cmd_ptr;
+  memcpy(push_constants, constants.data,
+         dispatch_attrs.constant_count * sizeof(*push_constants));
+  cmd_ptr += dispatch_attrs.constant_count * sizeof(*push_constants);
+
+  // Produce the dense binding list based on the declared bindings used.
+  //
+  // Note that we are just directly setting the binding data pointers here with
+  // no ownership/retaining/etc - it's part of the HAL contract that buffers are
+  // kept valid for the duration they may be in use.
+  if (IREE_UNLIKELY(bindings.count != dispatch_attrs.binding_count)) {
+    return iree_make_status(
+        IREE_STATUS_INVALID_ARGUMENT,
+        "binding count mismatch, expected %u but was provided %" PRIhsz,
+        (uint32_t)dispatch_attrs.binding_count, bindings.count);
+  }
+  void** binding_ptrs = (void**)cmd_ptr;
+  cmd_ptr += bindings.count * sizeof(*binding_ptrs);
+  size_t* binding_lengths = (size_t*)cmd_ptr;
+  cmd_ptr += bindings.count * sizeof(*binding_lengths);
+  for (iree_host_size_t i = 0; i < bindings.count; ++i) {
+    // TODO(benvanik): track mapping so we can properly map/unmap/flush/etc.
+    iree_hal_buffer_mapping_t buffer_mapping = {{0}};
+    if (IREE_LIKELY(bindings.values[i].buffer)) {
+      // TODO(benvanik): batch insert by getting the resources in their own
+      // list.
+      const iree_hal_buffer_ref_t binding = bindings.values[i];
+      IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range(
+          binding.buffer, IREE_HAL_MAPPING_MODE_PERSISTENT,
+          IREE_HAL_MEMORY_ACCESS_ANY, binding.offset, binding.length,
+          &buffer_mapping));
+    } else {
+      return iree_make_status(
+          IREE_STATUS_FAILED_PRECONDITION,
+          "required binding %" PRIhsz
+          " is NULL; all bindings must have a valid pointer",
+          i);
+    }
+    binding_ptrs[i] = buffer_mapping.contents.data;
+    binding_lengths[i] = buffer_mapping.contents.data_length;
+  }
+  IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert_strided(
+      command_buffer->resource_set, bindings.count, bindings.values,
+      offsetof(iree_hal_buffer_ref_t, buffer), sizeof(iree_hal_buffer_ref_t)));
+
+  *out_cmd = cmd;
+  return iree_hal_task_command_buffer_emit_execution_task(command_buffer,
+                                                          &cmd->task.header);
+}
+
+// Records a direct dispatch with a fixed |workgroup_count|.
+static iree_status_t iree_hal_task_command_buffer_dispatch2(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_t* executable, int32_t entry_point,
+    const uint32_t workgroup_count[3], iree_const_byte_span_t constants,
+    iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
+  iree_hal_task_command_buffer_t* command_buffer =
+      iree_hal_task_command_buffer_cast(base_command_buffer);
+
+  IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert(
+      command_buffer->resource_set, 1, &executable));
+
+  iree_hal_task_cmd_dispatch2_t* cmd = NULL;
+  return iree_hal_task_command_buffer_build_dispatch2(
+      base_command_buffer, executable, entry_point, workgroup_count, constants,
+      bindings, &cmd);
+}
+
+// Records a dispatch whose workgroup count is read at execution time from the
+// 3 x uint32_t values mapped from |workgroups_ref|.
+static iree_status_t iree_hal_task_command_buffer_dispatch2_indirect(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_t* executable, int32_t entry_point,
+    iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants,
+    iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
+  iree_hal_task_command_buffer_t* command_buffer =
+      iree_hal_task_command_buffer_cast(base_command_buffer);
+
+  const void* resources[2] = {executable, workgroups_ref.buffer};
+  IREE_RETURN_IF_ERROR(
+      iree_hal_resource_set_insert(command_buffer->resource_set, 2, resources));
+
+  // TODO(benvanik): track mapping so we can properly map/unmap/flush/etc.
+  iree_hal_buffer_mapping_t buffer_mapping = {{0}};
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range(
+      workgroups_ref.buffer, IREE_HAL_MAPPING_MODE_PERSISTENT,
+      IREE_HAL_MEMORY_ACCESS_READ, workgroups_ref.offset, 3 * sizeof(uint32_t),
+      &buffer_mapping));
+
+  uint32_t workgroup_count[3] = {0};  // unused with the indirect flag
+  iree_hal_task_cmd_dispatch2_t* cmd = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_task_command_buffer_build_dispatch2(
+      base_command_buffer, executable, entry_point, workgroup_count, constants,
+      bindings, &cmd));
+  cmd->task.workgroup_count.ptr = (const uint32_t*)buffer_mapping.contents.data;
+  cmd->task.header.flags |= IREE_TASK_FLAG_DISPATCH_INDIRECT;
+  return iree_ok_status();
+}
+
 //===----------------------------------------------------------------------===//
 // iree_hal_command_buffer_vtable_t
 //===----------------------------------------------------------------------===//
@@ -1036,4 +1281,6 @@ static const iree_hal_command_buffer_vtable_t
         .push_descriptor_set = iree_hal_task_command_buffer_push_descriptor_set,
         .dispatch = iree_hal_task_command_buffer_dispatch,
         .dispatch_indirect = iree_hal_task_command_buffer_dispatch_indirect,
+        .dispatch2 = iree_hal_task_command_buffer_dispatch2,
+        .dispatch2_indirect = iree_hal_task_command_buffer_dispatch2_indirect,
 };
diff --git a/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m b/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m
index eaed4f539652..fbf6374e7fcc 100644
--- a/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m
+++ b/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m
@@ -49,6 +49,7 @@ typedef enum iree_hal_metal_command_segment_action_e {
   IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_BARRIER,  // Execution/memory barrier command
   IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_DISPATCH,  // Dispatch command
+  IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_DISPATCH2,  // Dispatch command
   IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_FILL_BUFFER,  // Fill buffer command
IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_COPY_BUFFER,  // Copy buffer command
 } iree_hal_metal_command_segment_action_t;
 
@@ -94,6 +95,30 @@
 // + Additional inline allocation for holding all bound descriptors.
 // + Additional inline allocation for holding all push constants.
 
+// API data for dispatch2 command segments.
+typedef struct iree_hal_metal_dispatch2_segment_t {
+  // Compute kernel information--kernel object, pipeline layout, threadgroup size, etc.
+  iree_hal_metal_kernel_params_t kernel_params;
+
+  // Workgroup count information--if |workgroups_buffer| is not nil, then indirect dispatch;
+  // otherwise uses |workgroup_count| for direct dispatch.
+  id<MTLBuffer> workgroups_buffer;
+  iree_device_size_t workgroups_offset;
+  uint32_t workgroup_count[3];
+
+  // The number of descriptors bound for this dispatch.
+  iree_host_size_t descriptor_count;
+  // The list of bound descriptors, pointing to the end of the segment allocation.
+  iree_hal_metal_descriptor_t* descriptors;
+
+  // The number of push constant values.
+  iree_host_size_t push_constant_count;
+  // The list of push constants, pointing to the end of the segment allocation.
+  int32_t* push_constants;
+} iree_hal_metal_dispatch2_segment_t;
+// + Additional inline allocation for holding all bound descriptors.
+// + Additional inline allocation for holding all push constants.
+
 // API data for fill buffer command segments.
 typedef struct iree_hal_metal_fill_buffer_segment_t {
   id<MTLBuffer> target_buffer;
@@ -121,6 +146,7 @@
   union {
     iree_hal_metal_barrier_segment_t barrier;
     iree_hal_metal_dispatch_segment_t dispatch;
+    iree_hal_metal_dispatch2_segment_t dispatch2;
     iree_hal_metal_fill_buffer_segment_t fill_buffer;
     iree_hal_metal_copy_buffer_segment_t copy_buffer;
   };
@@ -1105,6 +1131,200 @@ static iree_status_t iree_hal_metal_command_buffer_prepare_dispatch_indirect(
   return iree_ok_status();
 }
 
+// Allocates a dispatch2 command segment from the command buffer arena and copies
+// |bindings| and |constants| into inline storage after it for use when
+// recording. The executable and bound buffers are added to the resource set.
+// Callers fill in the workgroup count (direct or indirect) on |out_segment|.
+static iree_status_t iree_hal_metal_command_segment_create_dispatch2(
+    iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable,
+    int32_t entry_point, iree_const_byte_span_t constants, iree_hal_buffer_ref_list_t bindings,
+    iree_hal_dispatch_flags_t flags, iree_hal_metal_dispatch2_segment_t** out_segment) {
+  iree_hal_metal_command_buffer_t* command_buffer =
+      iree_hal_metal_command_buffer_cast(base_command_buffer);
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, &executable));
+
+  iree_hal_metal_kernel_params_t kernel_params;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_metal_kernel_library_entry_point_kernel_params(
+                                            executable, entry_point, &kernel_params));
+
+  // Allocate the command segment and keep track of all necessary API data.
+  uint8_t* storage_base = NULL;
+  iree_hal_metal_command_segment_t* segment = NULL;
+  iree_host_size_t descriptor_length = bindings.count * sizeof(iree_hal_metal_descriptor_t);
+  iree_host_size_t total_size = sizeof(*segment) + descriptor_length + constants.data_length;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_arena_allocate(&command_buffer->arena, total_size, (void**)&storage_base));
+
+  // Compose and push the dispatch segment.
+  segment = (iree_hal_metal_command_segment_t*)storage_base;
+  memset(segment, 0, sizeof(*segment));
+  segment->action = IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_DISPATCH2;
+  iree_hal_metal_command_segment_list_push_back(&command_buffer->segments, segment);
+
+  segment->dispatch2.kernel_params = kernel_params;
+
+  // Copy descriptors from all sets to the end of the current segment for later access.
+  const iree_hal_descriptor_set_layout_t* set_layout =
+      iree_hal_metal_pipeline_layout_descriptor_set_layout(kernel_params.layout, 0);
+  segment->dispatch2.descriptor_count = bindings.count;
+  segment->dispatch2.descriptors = (iree_hal_metal_descriptor_t*)(storage_base + sizeof(*segment));
+  for (iree_host_size_t i = 0; i < bindings.count; ++i) {
+    iree_hal_metal_descriptor_t* descriptor = &segment->dispatch2.descriptors[i];
+
+    descriptor->set = 0;
+    descriptor->binding = i;
+    descriptor->buffer = bindings.values[i].buffer;
+    descriptor->offset = bindings.values[i].offset;
+
+    const iree_hal_descriptor_set_layout_binding_t* binding_params =
+        iree_hal_metal_descriptor_set_layout_binding(set_layout, descriptor->binding);
+    descriptor->usage = iree_hal_metal_get_metal_resource_usage(binding_params);
+
+    if (descriptor->buffer) {
+      IREE_RETURN_AND_END_ZONE_IF_ERROR(
+          z0, iree_hal_resource_set_insert(command_buffer->resource_set, 1, &descriptor->buffer));
+    }
+  }
+
+  // Copy push constants to the end of the current segment for later access.
+  segment->dispatch2.push_constant_count = constants.data_length / sizeof(uint32_t);
+  uint8_t* push_constant_ptr = storage_base + sizeof(*segment) + descriptor_length;
+  segment->dispatch2.push_constants = (int32_t*)push_constant_ptr;
+  memcpy(push_constant_ptr, constants.data, constants.data_length);
+
+  *out_segment = &segment->dispatch2;
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+// Encodes a dispatch2 |segment| into the current compute encoder: push
+// constants, an argument buffer holding all bound descriptors, and the direct or
+// indirect threadgroup dispatch.
+static iree_status_t iree_hal_metal_command_segment_record_dispatch2(
+    iree_hal_metal_command_buffer_t* command_buffer, iree_hal_metal_dispatch2_segment_t* segment) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  // Set the compute kernel to dispatch.
+  id<MTLComputeCommandEncoder> compute_encoder =
+      iree_hal_metal_get_or_begin_compute_encoder(command_buffer);
+  [compute_encoder setComputePipelineState:segment->kernel_params.pso];
+
+  // Record push constants.
+  if (segment->push_constant_count != 0) {
+    [compute_encoder setBytes:(void*)segment->push_constants
+                       length:segment->push_constant_count * sizeof(int32_t)
+                      atIndex:IREE_HAL_METAL_PUSH_CONSTANT_BUFFER_INDEX];
+  }
+
+  // Record argument buffers for all descriptors and record buffer usages.
+  iree_hal_metal_descriptor_t* descriptors = segment->descriptors;
+
+  // Build argument encoder and argument buffer for the current descriptor set.
+  // TODO(antiagainst): Use a cache layer to cache and reuse argument buffers with the same
+  // content, to avoid duplicating overhead.
+  id<MTLBuffer> argument_buffer = command_buffer->staging_buffer->metal_buffer;
+  id<MTLArgumentEncoder> argument_encoder =
+      [segment->kernel_params.function newArgumentEncoderWithBufferIndex:0];  // +1
+  IREE_ASSERT(argument_encoder != nil);
+
+  // Reserve space for the argument buffer from shared staging buffer.
+  iree_byte_span_t reservation = iree_byte_span_empty();
+  uint32_t argument_buffer_offset = 0;
+  iree_status_t status = iree_hal_metal_staging_buffer_reserve(
+      command_buffer->staging_buffer, argument_encoder.encodedLength,
+      argument_encoder.alignment, &reservation, &argument_buffer_offset);
+  if (!iree_status_is_ok(status)) {
+    // Balance the +1 above; the early return would otherwise leak the encoder.
+    [argument_encoder release];  // -1
+    IREE_TRACE_ZONE_END(z0);
+    return status;
+  }
+  [argument_encoder setArgumentBuffer:argument_buffer offset:argument_buffer_offset];
+
+  // Now record all bound buffers belonging to the current set into the argument buffer.
+  for (iree_host_size_t i = 0; i < segment->descriptor_count; ++i) {
+    uint32_t current_binding = descriptors[i].binding;
+    id<MTLBuffer> current_buffer =
+        iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(descriptors[i].buffer));
+    iree_host_size_t offset =
+        iree_hal_buffer_byte_offset(descriptors[i].buffer) + descriptors[i].offset;
+    [argument_encoder setBuffer:current_buffer offset:offset atIndex:current_binding];
+
+    // Also record buffer usages.
+    [compute_encoder useResource:current_buffer usage:descriptors[i].usage];
+  }
+  // Record the argument buffer.
+  [compute_encoder setBuffer:argument_buffer offset:argument_buffer_offset atIndex:0];
+
+  [argument_encoder release];  // -1
+
+  // Record the dispatch, either direct or indirect.
+  uint32_t* workgroup_size = segment->kernel_params.threadgroup_size;
+  if (segment->workgroups_buffer == nil) {
+    // Direct dispatch of a fixed workgroup count.
+    uint32_t* workgroup_count = segment->workgroup_count;
+    [compute_encoder
+         dispatchThreadgroups:MTLSizeMake(workgroup_count[0], workgroup_count[1],
+                                          workgroup_count[2])
+        threadsPerThreadgroup:MTLSizeMake(workgroup_size[0], workgroup_size[1], workgroup_size[2])];
+  } else {
+    // Indirect dispatch using a workgroup count from buffers.
+    [compute_encoder
+        dispatchThreadgroupsWithIndirectBuffer:segment->workgroups_buffer
+                          indirectBufferOffset:segment->workgroups_offset
+                         threadsPerThreadgroup:MTLSizeMake(workgroup_size[0], workgroup_size[1],
+                                                           workgroup_size[2])];
+  }
+
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+// Records a direct dispatch with a fixed |workgroup_count|.
+static iree_status_t iree_hal_metal_command_buffer_prepare_dispatch2(
+    iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable,
+    int32_t entry_point, const uint32_t workgroup_count[3], iree_const_byte_span_t constants,
+    iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_metal_dispatch2_segment_t* segment = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_metal_command_segment_create_dispatch2(
+              base_command_buffer, executable, entry_point, constants, bindings, flags, &segment));
+  segment->workgroup_count[0] = workgroup_count[0];
+  segment->workgroup_count[1] = workgroup_count[1];
+  segment->workgroup_count[2] = workgroup_count[2];
+
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
+// Records an indirect dispatch that reads its workgroup count from
+// |workgroups_ref| at execution time.
+static iree_status_t iree_hal_metal_command_buffer_prepare_dispatch2_indirect(
+    iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable,
+    int32_t entry_point, iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants,
+    iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
+  IREE_TRACE_ZONE_BEGIN(z0);
+
+  iree_hal_metal_dispatch2_segment_t* segment = NULL;
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, iree_hal_metal_command_segment_create_dispatch2(
+              base_command_buffer, executable, entry_point, constants, bindings, flags, &segment));
+  segment->workgroups_buffer =
+      iree_hal_metal_buffer_handle(iree_hal_buffer_allocated_buffer(workgroups_ref.buffer));
+  // The Metal handle refers to the allocated buffer, so include the byte offset
+  // of |workgroups_ref.buffer| within it, as is done for bound descriptors.
+  segment->workgroups_offset =
+      iree_hal_buffer_byte_offset(workgroups_ref.buffer) + workgroups_ref.offset;
+
+  IREE_TRACE_ZONE_END(z0);
+  return iree_ok_status();
+}
+
 static iree_status_t iree_hal_metal_command_segment_record(
     iree_hal_metal_command_buffer_t* command_buffer) {
   IREE_ASSERT_ARGUMENT(command_buffer);
@@ -1121,6 +1341,10 @@ static iree_status_t iree_hal_metal_command_segment_record(
         IREE_RETURN_AND_END_ZONE_IF_ERROR(
             z0, iree_hal_metal_command_segment_record_dispatch(command_buffer, &segment->dispatch));
       } break;
+      case IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_DISPATCH2: {
+        IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_metal_command_segment_record_dispatch2(
+                                                  command_buffer, &segment->dispatch2));
+      } break;
       case IREE_HAL_METAL_COMMAND_SEGMENT_ACTION_FILL_BUFFER: {
         IREE_RETURN_AND_END_ZONE_IF_ERROR(z0, iree_hal_metal_command_segment_record_fill_buffer(
                                                   command_buffer, &segment->fill_buffer));
@@ -1180,4 +1404,6 @@ static iree_status_t iree_hal_metal_command_buffer_end(
     .push_descriptor_set = iree_hal_metal_command_buffer_push_descriptor_set,
     .dispatch = iree_hal_metal_command_buffer_prepare_dispatch,
     .dispatch_indirect = iree_hal_metal_command_buffer_prepare_dispatch_indirect,
+    .dispatch2 = iree_hal_metal_command_buffer_prepare_dispatch2,
+    .dispatch2_indirect = iree_hal_metal_command_buffer_prepare_dispatch2_indirect,
 };
diff --git a/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc b/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc
index 8bb94139f3fa..03dac80b60c9 100644
---
a/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc +++ b/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc @@ -749,7 +749,7 @@ static iree_status_t iree_hal_vulkan_direct_command_buffer_dispatch( VkPipeline pipeline_handle = VK_NULL_HANDLE; IREE_RETURN_IF_ERROR( iree_hal_vulkan_native_executable_pipeline_for_entry_point( - executable, entry_point, &pipeline_handle)); + executable, entry_point, &pipeline_handle, NULL)); command_buffer->syms->vkCmdBindPipeline( command_buffer->handle, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline_handle); @@ -787,7 +787,7 @@ static iree_status_t iree_hal_vulkan_direct_command_buffer_dispatch_indirect( VkPipeline pipeline_handle = VK_NULL_HANDLE; IREE_RETURN_IF_ERROR( iree_hal_vulkan_native_executable_pipeline_for_entry_point( - executable, entry_point, &pipeline_handle)); + executable, entry_point, &pipeline_handle, NULL)); command_buffer->syms->vkCmdBindPipeline( command_buffer->handle, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline_handle); @@ -805,6 +805,120 @@ static iree_status_t iree_hal_vulkan_direct_command_buffer_dispatch_indirect( return iree_ok_status(); } +static iree_status_t iree_hal_vulkan_direct_command_buffer_dispatch2_bind( + iree_hal_vulkan_direct_command_buffer_t* command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + iree_const_byte_span_t constants, iree_hal_buffer_ref_list_t bindings, + iree_hal_dispatch_flags_t flags) { + // Get the compiled and linked pipeline for the specified entry point. + VkPipeline pipeline_handle = VK_NULL_HANDLE; + iree_hal_pipeline_layout_t* pipeline_layout = NULL; + IREE_RETURN_IF_ERROR( + iree_hal_vulkan_native_executable_pipeline_for_entry_point( + executable, entry_point, &pipeline_handle, &pipeline_layout)); + + // Update push constants. 
+ if (!iree_const_byte_span_is_empty(constants)) { + VkPipelineLayout pipeline_layout_handle = + iree_hal_vulkan_native_pipeline_layout_handle(pipeline_layout); + command_buffer->syms->vkCmdPushConstants( + command_buffer->handle, pipeline_layout_handle, + VK_SHADER_STAGE_COMPUTE_BIT, (uint32_t)0, + (uint32_t)constants.data_length, constants.data); + } + + // Retain bound buffers until the command buffer is reset. + IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert_strided( + command_buffer->resource_set, bindings.count, bindings.values, + offsetof(iree_hal_buffer_ref_t, buffer), sizeof(iree_hal_buffer_ref_t))); + + // Either allocate, update, and bind a descriptor set or use push descriptor + // sets to use the command buffer pool when supported. + IREE_RETURN_IF_ERROR(command_buffer->descriptor_set_arena.BindDescriptorSet( + command_buffer->handle, pipeline_layout, 0, bindings.count, + bindings.values)); + + command_buffer->syms->vkCmdBindPipeline( + command_buffer->handle, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline_handle); + + return iree_ok_status(); +} + +static iree_status_t iree_hal_vulkan_direct_command_buffer_dispatch2( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + const uint32_t workgroup_count[3], iree_const_byte_span_t constants, + iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) { + iree_hal_vulkan_direct_command_buffer_t* command_buffer = + iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); + + IREE_TRACE({ + iree_hal_vulkan_source_location_t source_location; + iree_hal_vulkan_native_executable_entry_point_source_location( + executable, entry_point, &source_location); + IREE_VULKAN_TRACE_ZONE_BEGIN_EXTERNAL( + command_buffer->tracing_context, command_buffer->handle, + source_location.file_name.data, source_location.file_name.size, + source_location.line, source_location.func_name.data, + source_location.func_name.size, /*name=*/NULL, 0); + }); + + 
IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert( + command_buffer->resource_set, 1, &executable)); + + IREE_RETURN_IF_ERROR(iree_hal_vulkan_direct_command_buffer_dispatch2_bind( + command_buffer, executable, entry_point, constants, bindings, flags)); + + command_buffer->syms->vkCmdDispatch(command_buffer->handle, + workgroup_count[0], workgroup_count[1], + workgroup_count[2]); + + IREE_VULKAN_TRACE_ZONE_END(command_buffer->tracing_context, + command_buffer->handle); + + return iree_ok_status(); +} + +static iree_status_t iree_hal_vulkan_direct_command_buffer_dispatch2_indirect( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants, + iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) { + iree_hal_vulkan_direct_command_buffer_t* command_buffer = + iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); + + IREE_TRACE({ + iree_hal_vulkan_source_location_t source_location; + iree_hal_vulkan_native_executable_entry_point_source_location( + executable, entry_point, &source_location); + IREE_VULKAN_TRACE_ZONE_BEGIN_EXTERNAL( + command_buffer->tracing_context, command_buffer->handle, + source_location.file_name.data, source_location.file_name.size, + source_location.line, source_location.func_name.data, + source_location.func_name.size, /*name=*/NULL, 0); + }); + + const void* resources[2] = {executable, workgroups_ref.buffer}; + IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert( + command_buffer->resource_set, IREE_ARRAYSIZE(resources), resources)); + + IREE_RETURN_IF_ERROR(iree_hal_vulkan_direct_command_buffer_dispatch2_bind( + command_buffer, executable, entry_point, constants, bindings, flags)); + + VkBuffer workgroups_device_buffer = + iree_hal_vulkan_buffer_handle(workgroups_ref.buffer); + iree_device_size_t workgroups_offset = + iree_hal_buffer_byte_offset(workgroups_ref.buffer) + + workgroups_ref.offset; + 
command_buffer->syms->vkCmdDispatchIndirect( + command_buffer->handle, workgroups_device_buffer, workgroups_offset); + + IREE_VULKAN_TRACE_ZONE_END(command_buffer->tracing_context, + command_buffer->handle); + + return iree_ok_status(); +} + namespace { const iree_hal_command_buffer_vtable_t iree_hal_vulkan_direct_command_buffer_vtable = { @@ -836,5 +950,8 @@ const iree_hal_command_buffer_vtable_t /*.dispatch=*/iree_hal_vulkan_direct_command_buffer_dispatch, /*.dispatch_indirect=*/ iree_hal_vulkan_direct_command_buffer_dispatch_indirect, + /*.dispatch2=*/iree_hal_vulkan_direct_command_buffer_dispatch2, + /*.dispatch2_indirect=*/ + iree_hal_vulkan_direct_command_buffer_dispatch2_indirect, }; } // namespace diff --git a/runtime/src/iree/hal/drivers/vulkan/native_executable.cc b/runtime/src/iree/hal/drivers/vulkan/native_executable.cc index b6d8dc67f7b4..ebfd00626532 100644 --- a/runtime/src/iree/hal/drivers/vulkan/native_executable.cc +++ b/runtime/src/iree/hal/drivers/vulkan/native_executable.cc @@ -26,6 +26,7 @@ using namespace iree::hal::vulkan; typedef struct iree_hal_vulkan_entry_point_t { VkPipeline pipeline; + iree_hal_pipeline_layout_t* layout; iree_string_view_t name; // Optional debug information. 
@@ -107,6 +108,11 @@ static iree_status_t iree_hal_vulkan_create_pipelines( iree_hal_spirv_ExecutableDef_subgroup_sizes_get(executable_def); for (iree_host_size_t entry_ordinal = 0; entry_ordinal < pipeline_count; ++entry_ordinal) { + iree_hal_pipeline_layout_t* pipeline_layout = + executable_params->pipeline_layouts[entry_ordinal]; + iree_hal_pipeline_layout_retain(pipeline_layout); + out_entry_points[entry_ordinal].layout = pipeline_layout; + VkComputePipelineCreateInfo* create_info = &create_infos[entry_ordinal]; create_info->sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO; create_info->pNext = NULL; @@ -121,8 +127,8 @@ static iree_status_t iree_hal_vulkan_create_pipelines( } else { create_info->flags |= VK_PIPELINE_CREATE_DERIVATIVE_BIT; } - create_info->layout = iree_hal_vulkan_native_pipeline_layout_handle( - executable_params->pipeline_layouts[entry_ordinal]); + create_info->layout = + iree_hal_vulkan_native_pipeline_layout_handle(pipeline_layout); create_info->basePipelineHandle = VK_NULL_HANDLE; create_info->basePipelineIndex = 0; @@ -472,6 +478,7 @@ static void iree_hal_vulkan_native_executable_destroy( for (iree_host_size_t i = 0; i < executable->entry_point_count; ++i) { iree_hal_vulkan_destroy_pipeline(executable->logical_device, executable->entry_points[i].pipeline); + iree_hal_pipeline_layout_release(executable->entry_points[i].layout); } iree_allocator_free(host_allocator, executable); @@ -528,7 +535,8 @@ void iree_hal_vulkan_native_executable_entry_point_source_location( iree_status_t iree_hal_vulkan_native_executable_pipeline_for_entry_point( iree_hal_executable_t* base_executable, iree_host_size_t entry_ordinal, - VkPipeline* out_pipeline_handle) { + VkPipeline* out_pipeline_handle, + iree_hal_pipeline_layout_t** out_pipeline_layout) { iree_hal_vulkan_native_executable_t* executable = iree_hal_vulkan_native_executable_cast(base_executable); if (entry_ordinal >= executable->entry_point_count) { @@ -537,6 +545,9 @@ iree_status_t 
iree_hal_vulkan_native_executable_pipeline_for_entry_point( entry_ordinal); } *out_pipeline_handle = executable->entry_points[entry_ordinal].pipeline; + if (out_pipeline_layout) { + *out_pipeline_layout = executable->entry_points[entry_ordinal].layout; + } return iree_ok_status(); } diff --git a/runtime/src/iree/hal/drivers/vulkan/native_executable.h b/runtime/src/iree/hal/drivers/vulkan/native_executable.h index da6a845bf221..248db1dde5b4 100644 --- a/runtime/src/iree/hal/drivers/vulkan/native_executable.h +++ b/runtime/src/iree/hal/drivers/vulkan/native_executable.h @@ -43,7 +43,8 @@ void iree_hal_vulkan_native_executable_entry_point_source_location( // Returns the cached VkPipeline for the given executable |entry_ordinal|. iree_status_t iree_hal_vulkan_native_executable_pipeline_for_entry_point( iree_hal_executable_t* executable, iree_host_size_t entry_ordinal, - VkPipeline* out_pipeline_handle); + VkPipeline* out_pipeline_handle, + iree_hal_pipeline_layout_t** out_pipeline_layout); #ifdef __cplusplus } // extern "C" diff --git a/runtime/src/iree/hal/executable_cache.h b/runtime/src/iree/hal/executable_cache.h index bee9bf9f2a03..435f01d5e584 100644 --- a/runtime/src/iree/hal/executable_cache.h +++ b/runtime/src/iree/hal/executable_cache.h @@ -92,6 +92,9 @@ typedef struct iree_hal_executable_params_t { // to any executable created using it still held by the caller. iree_const_byte_span_t executable_data; + // TODO(#18154): drop pipeline layouts with simplified bindings. Allowed to be + // empty for now on targets that support simplified bindings. + // // A set of pipeline layouts for each entry point in the executable. // The order matches that produced by the compiler. 
As multiple entry points // may share the same layout some entries in this list may reference the same diff --git a/runtime/src/iree/hal/local/executable_library.h b/runtime/src/iree/hal/local/executable_library.h index d917f0f35ce9..d45b477d5573 100644 --- a/runtime/src/iree/hal/local/executable_library.h +++ b/runtime/src/iree/hal/local/executable_library.h @@ -372,19 +372,27 @@ typedef int (*iree_hal_executable_dispatch_v0_t)( // Bytes per page of workgroup local memory. // This is chosen to match the common page size of devices. -#define IREE_HAL_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE 4096 +#define IREE_HAL_EXECUTABLE_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE 4096 + +// Maximum number of constants that can be used by a single dispatch. +#define IREE_HAL_EXECUTABLE_MAX_CONSTANT_COUNT 64 +// Maximum number of bindings that can be used by a single dispatch. +#define IREE_HAL_EXECUTABLE_MAX_BINDING_COUNT 64 // Attributes for exported dispatch functions defining how they are to be // executed. 0 defaults are well-specified and the entire attributes table may // be omitted if no dispatch functions require these fields. typedef struct iree_hal_executable_dispatch_attrs_v0_t { - // Number of IREE_HAL_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE byte pages (or 0) - // indicating how much workgroup local memory is required for the dispatch. - // This is the size of the buffer referenced by the `local_memory` argument. + // Number of IREE_HAL_EXECUTABLE_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE byte pages + // (or 0) indicating how much workgroup local memory is required for the + // dispatch. This is the size of the buffer referenced by the `local_memory` + // argument. uint16_t local_memory_pages; - // Must be 0. May be used in the future for flags controlling the dispatch - // behavior/synchronization requirements. - uint16_t reserved; + // Total number of 32-bit constants used by the dispatch. + uint8_t constant_count; + // Total number of bindings used by the dispatch. 
+ uint8_t binding_count; + // TODO(#18189): add ~8 uint64_t fields for binding bits (readonly/indirect). } iree_hal_executable_dispatch_attrs_v0_t; static_assert(sizeof(iree_hal_executable_dispatch_attrs_v0_t) == 4, "uint32_t"); diff --git a/runtime/src/iree/hal/local/executable_library_benchmark.c b/runtime/src/iree/hal/local/executable_library_benchmark.c index 95403ded3b56..d87149dfbeca 100644 --- a/runtime/src/iree/hal/local/executable_library_benchmark.c +++ b/runtime/src/iree/hal/local/executable_library_benchmark.c @@ -186,7 +186,7 @@ static iree_status_t iree_hal_executable_library_run( local_executable->dispatch_attrs ? local_executable->dispatch_attrs[FLAG_entry_point] .local_memory_pages * - IREE_HAL_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE + IREE_HAL_EXECUTABLE_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE : 0; if (local_memory_size > 0) { IREE_RETURN_IF_ERROR(iree_allocator_malloc( diff --git a/runtime/src/iree/hal/local/executable_library_demo.c b/runtime/src/iree/hal/local/executable_library_demo.c index af18875d7fae..300d645120ae 100644 --- a/runtime/src/iree/hal/local/executable_library_demo.c +++ b/runtime/src/iree/hal/local/executable_library_demo.c @@ -66,9 +66,13 @@ static const iree_hal_executable_dispatch_v0_t entry_points[2] = { static const iree_hal_executable_dispatch_attrs_v0_t entry_attrs[2] = { { .local_memory_pages = 0, + .constant_count = 1, + .binding_count = 2, }, { .local_memory_pages = 0, + .constant_count = 0, + .binding_count = 0, }, }; // Names for each entry point. diff --git a/runtime/src/iree/hal/local/executable_library_util.c b/runtime/src/iree/hal/local/executable_library_util.c index 5dbe51c9552b..b2d1165fe1f7 100644 --- a/runtime/src/iree/hal/local/executable_library_util.c +++ b/runtime/src/iree/hal/local/executable_library_util.c @@ -39,6 +39,30 @@ iree_status_t iree_hal_executable_library_verify( executable_params->constant_count); } + // If dispatch attributes are present validate they are in range. 
+ if (library->exports.attrs) { + for (uint32_t i = 0; i < library->exports.count; ++i) { + const iree_hal_executable_dispatch_attrs_v0_t dispatch_attrs = + library->exports.attrs[i]; + if (dispatch_attrs.constant_count > + IREE_HAL_EXECUTABLE_MAX_CONSTANT_COUNT) { + return iree_make_status( + IREE_STATUS_OUT_OF_RANGE, + "dispatch requiring %u constants exceeds limit of %d", + dispatch_attrs.constant_count, + IREE_HAL_EXECUTABLE_MAX_CONSTANT_COUNT); + } + if (dispatch_attrs.binding_count > + IREE_HAL_EXECUTABLE_MAX_BINDING_COUNT) { + return iree_make_status( + IREE_STATUS_OUT_OF_RANGE, + "dispatch requiring %u bindings exceeds limit of %d", + dispatch_attrs.binding_count, + IREE_HAL_EXECUTABLE_MAX_BINDING_COUNT); + } + } + } + return iree_ok_status(); } diff --git a/runtime/src/iree/hal/local/inline_command_buffer.c b/runtime/src/iree/hal/local/inline_command_buffer.c index 3de7c601e312..2e0465c0bcf7 100644 --- a/runtime/src/iree/hal/local/inline_command_buffer.c +++ b/runtime/src/iree/hal/local/inline_command_buffer.c @@ -28,24 +28,18 @@ typedef struct iree_hal_inline_command_buffer_t { iree_allocator_t host_allocator; struct { + // TODO(#18189): remove legacy bindings state. + // // A flattened list of all available descriptor set bindings. // As descriptor sets are pushed/bound the bindings will be updated to // represent the fully-translated binding data pointer. - // - // TODO(benvanik): support proper mapping semantics and track the - // iree_hal_buffer_mapping_t and map/unmap where appropriate. void* full_bindings[IREE_HAL_LOCAL_MAX_DESCRIPTOR_SET_COUNT * IREE_HAL_LOCAL_MAX_DESCRIPTOR_BINDING_COUNT]; size_t full_binding_lengths[IREE_HAL_LOCAL_MAX_DESCRIPTOR_SET_COUNT * IREE_HAL_LOCAL_MAX_DESCRIPTOR_BINDING_COUNT]; - // Packed bindings scratch space used during dispatch. Executable bindings - // are packed into a dense list with unused bindings removed. 
- void* packed_bindings[IREE_HAL_LOCAL_MAX_DESCRIPTOR_SET_COUNT * - IREE_HAL_LOCAL_MAX_DESCRIPTOR_BINDING_COUNT]; - size_t packed_binding_lengths[IREE_HAL_LOCAL_MAX_DESCRIPTOR_SET_COUNT * - IREE_HAL_LOCAL_MAX_DESCRIPTOR_BINDING_COUNT]; - + // TODO(#18189): remove legacy push constant state. + // // All available push constants updated each time push_constants is called. // Reset only with the command buffer and otherwise will maintain its values // during recording to allow for partial push_constants updates. @@ -55,6 +49,10 @@ typedef struct iree_hal_inline_command_buffer_t { // Individual dispatches must populate the dynamically changing fields like // push_constant_count and binding_count. iree_alignas(64) iree_hal_executable_dispatch_state_v0_t dispatch_state; + // Persistent storage for binding pointers used by dispatch_state. + void* binding_ptr_storage[IREE_HAL_EXECUTABLE_MAX_BINDING_COUNT]; + // Persistent storage for binding lengths used by dispatch_state. + size_t binding_length_storage[IREE_HAL_EXECUTABLE_MAX_BINDING_COUNT]; // An opaque tag used to reduce the cost of processor ID queries. iree_cpu_processor_tag_t processor_tag; @@ -80,9 +78,9 @@ static void iree_hal_inline_command_buffer_reset( iree_hal_executable_dispatch_state_v0_t* dispatch_state = &command_buffer->state.dispatch_state; dispatch_state->push_constants = command_buffer->state.push_constants; - dispatch_state->binding_ptrs = command_buffer->state.packed_bindings; + dispatch_state->binding_ptrs = command_buffer->state.binding_ptr_storage; dispatch_state->binding_lengths = - command_buffer->state.packed_binding_lengths; + command_buffer->state.binding_length_storage; } iree_host_size_t iree_hal_inline_command_buffer_size( @@ -461,7 +459,7 @@ static iree_status_t iree_hal_inline_command_buffer_dispatch( iree_host_size_t local_memory_size = local_executable->dispatch_attrs ? 
local_executable->dispatch_attrs[entry_point].local_memory_pages * - IREE_HAL_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE + IREE_HAL_EXECUTABLE_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE : 0; // Update the ID of the processor we are running on. @@ -489,6 +487,7 @@ static iree_status_t iree_hal_inline_command_buffer_dispatch( // only allow the dispatch to read what we know is initialized based on the // layout. dispatch_state->push_constant_count = local_layout->push_constants; + dispatch_state->push_constants = command_buffer->state.push_constants; // Produce the dense binding list based on the declared bindings used. // This allows us to change the descriptor sets and bindings counts supported @@ -548,6 +547,123 @@ static iree_status_t iree_hal_inline_command_buffer_dispatch( return status; } +static iree_status_t iree_hal_inline_command_buffer_dispatch2( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + const uint32_t workgroup_count[3], iree_const_byte_span_t constants, + iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) { + iree_hal_inline_command_buffer_t* command_buffer = + iree_hal_inline_command_buffer_cast(base_command_buffer); + + iree_hal_local_executable_t* local_executable = + iree_hal_local_executable_cast(executable); + + iree_hal_executable_dispatch_attrs_v0_t dispatch_attrs = {0}; + if (local_executable->dispatch_attrs) { + dispatch_attrs = local_executable->dispatch_attrs[entry_point]; + } + const iree_host_size_t local_memory_size = + dispatch_attrs.local_memory_pages * + IREE_HAL_EXECUTABLE_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE; + + // Update the ID of the processor we are running on. + // We don't know how much time has passed since we last updated as we are + // running inline with the user program; if we knew we were going to be + // handling a batch of dispatches we could reduce the amount of times we call + // this - but that's what the task system is for. 
+ iree_hal_inline_command_buffer_update_processor_id(command_buffer); + + iree_hal_executable_dispatch_state_v0_t* dispatch_state = + &command_buffer->state.dispatch_state; + + // TODO(benvanik): expose on API or keep fixed on executable. + dispatch_state->workgroup_size_x = 1; + dispatch_state->workgroup_size_y = 1; + dispatch_state->workgroup_size_z = 1; + dispatch_state->workgroup_count_x = workgroup_count[0]; + dispatch_state->workgroup_count_y = workgroup_count[1]; + dispatch_state->workgroup_count_z = workgroup_count[2]; + + // Single-threaded. + dispatch_state->max_concurrency = 1; + + // Push constants are pulled directly from the args. Note that we require 4 + // byte alignment and if the input buffer is not aligned we have to fail. + if (IREE_UNLIKELY((constants.data_length % sizeof(uint32_t)) != 0)) { + return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, + "constants must be 4-byte aligned"); + } else if (IREE_UNLIKELY(constants.data_length != + dispatch_attrs.constant_count * sizeof(uint32_t))) { + return iree_make_status( + IREE_STATUS_INVALID_ARGUMENT, + "constant count mismatch, expected %u but was provided %" PRIhsz, + (uint32_t)dispatch_attrs.constant_count, + constants.data_length / sizeof(uint32_t)); + } + dispatch_state->push_constant_count = dispatch_attrs.constant_count; + dispatch_state->push_constants = (const uint32_t*)constants.data; + + // Produce the dense binding list based on the declared bindings used. + // + // Note that we are just directly setting the binding data pointers here with + // no ownership/retaining/etc - it's part of the HAL contract that buffers are + // kept valid for the duration they may be in use. 
+ if (IREE_UNLIKELY(bindings.count != dispatch_attrs.binding_count)) { + return iree_make_status( + IREE_STATUS_INVALID_ARGUMENT, + "binding count mismatch, expected %u but was provided %" PRIhsz, + (uint32_t)dispatch_attrs.binding_count, bindings.count); + } + dispatch_state->binding_count = bindings.count; + for (iree_host_size_t i = 0; i < bindings.count; ++i) { + // TODO(benvanik): track mapping so we can properly map/unmap/flush/etc. + iree_hal_buffer_mapping_t buffer_mapping = {{0}}; + if (IREE_LIKELY(bindings.values[i].buffer)) { + IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( + bindings.values[i].buffer, IREE_HAL_MAPPING_MODE_PERSISTENT, + IREE_HAL_MEMORY_ACCESS_ANY, bindings.values[i].offset, + bindings.values[i].length, &buffer_mapping)); + } else { + return iree_make_status( + IREE_STATUS_FAILED_PRECONDITION, + "required binding %" PRIhsz + " is NULL; all bindings must have a valid pointer", + i); + } + command_buffer->state.binding_ptr_storage[i] = buffer_mapping.contents.data; + command_buffer->state.binding_length_storage[i] = + buffer_mapping.contents.data_length; + } + + // TODO(benvanik): plumb through an arena or fixed-size reservation to use. + // For now when deploying to devices where you want something like the + // inline command buffer you probably don't want 256KB of transient memory + // getting allocated and retained implicitly - this should be a compiler + // option. For now we just malloc here to make things work and strongly + // encourage the kind of user who wants synchronous inline execution to not + // also want tons of scratch memory. + iree_byte_span_t local_memory = iree_make_byte_span(NULL, local_memory_size); + if (local_memory_size > 0) { + IREE_RETURN_IF_ERROR(iree_allocator_malloc(command_buffer->host_allocator, + local_memory_size, + (void**)&local_memory.data)); + } + + // Since we are running on a borrowed thread, we know nothing about the + // floating point state. Reset it. 
+ iree_fpu_state_t fpu_state = + iree_fpu_state_push(IREE_FPU_STATE_FLAG_FLUSH_DENORMALS_TO_ZERO); + iree_status_t status = iree_hal_local_executable_issue_dispatch_inline( + local_executable, entry_point, dispatch_state, + command_buffer->state.processor_id, local_memory); + iree_fpu_state_pop(fpu_state); + + if (local_memory.data) { + iree_allocator_free(command_buffer->host_allocator, local_memory.data); + } + return status; +} + typedef union iree_hal_vec3_t { struct { uint32_t x; @@ -574,6 +690,24 @@ static iree_status_t iree_hal_inline_command_buffer_dispatch_indirect( workgroup_count.y, workgroup_count.z, flags); } +static iree_status_t iree_hal_inline_command_buffer_dispatch2_indirect( + iree_hal_command_buffer_t* base_command_buffer, + iree_hal_executable_t* executable, int32_t entry_point, + iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants, + iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) { + // TODO(benvanik): track mapping so we can properly map/unmap/flush/etc. 
+ iree_hal_buffer_mapping_t buffer_mapping = {{0}}; + IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( + workgroups_ref.buffer, IREE_HAL_MAPPING_MODE_PERSISTENT, + IREE_HAL_MEMORY_ACCESS_READ, workgroups_ref.offset, 3 * sizeof(uint32_t), + &buffer_mapping)); + iree_hal_vec3_t workgroup_count = + *(const iree_hal_vec3_t*)buffer_mapping.contents.data; + return iree_hal_inline_command_buffer_dispatch2( + base_command_buffer, executable, entry_point, workgroup_count.value, + constants, bindings, flags); +} + //===----------------------------------------------------------------------===// // iree_hal_command_buffer_vtable_t //===----------------------------------------------------------------------===// @@ -599,4 +733,6 @@ static const iree_hal_command_buffer_vtable_t iree_hal_inline_command_buffer_push_descriptor_set, .dispatch = iree_hal_inline_command_buffer_dispatch, .dispatch_indirect = iree_hal_inline_command_buffer_dispatch_indirect, + .dispatch2 = iree_hal_inline_command_buffer_dispatch2, + .dispatch2_indirect = iree_hal_inline_command_buffer_dispatch2_indirect, }; diff --git a/runtime/src/iree/hal/local/loaders/vmvx_module_loader.c b/runtime/src/iree/hal/local/loaders/vmvx_module_loader.c index 265c0b6bad1c..2675f8e78be0 100644 --- a/runtime/src/iree/hal/local/loaders/vmvx_module_loader.c +++ b/runtime/src/iree/hal/local/loaders/vmvx_module_loader.c @@ -302,6 +302,7 @@ static iree_status_t iree_hal_vmvx_executable_create( .linkage = IREE_VM_FUNCTION_LINKAGE_EXPORT, .ordinal = executable->entry_fn_ordinals[i], }; + iree_string_view_t local_memory_str = iree_vm_function_lookup_attr_by_name( &entry_fn, iree_make_cstring_view("local_memory")); @@ -309,8 +310,26 @@ static iree_status_t iree_hal_vmvx_executable_create( if (!iree_string_view_is_empty(local_memory_str)) { iree_string_view_atoi_uint32(local_memory_str, &local_memory_size); } - local_memory_size /= IREE_HAL_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE; + local_memory_size /= 
IREE_HAL_EXECUTABLE_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE; dispatch_attrs[i].local_memory_pages = (uint16_t)local_memory_size; + + iree_string_view_t constant_count_str = + iree_vm_function_lookup_attr_by_name( + &entry_fn, iree_make_cstring_view("constant_count")); + uint32_t constant_count = 0; + if (!iree_string_view_is_empty(constant_count_str)) { + iree_string_view_atoi_uint32(constant_count_str, &constant_count); + } + dispatch_attrs[i].constant_count = (uint8_t)constant_count; + + iree_string_view_t binding_count_str = + iree_vm_function_lookup_attr_by_name( + &entry_fn, iree_make_cstring_view("binding_count")); + uint32_t binding_count = 0; + if (!iree_string_view_is_empty(binding_count_str)) { + iree_string_view_atoi_uint32(binding_count_str, &binding_count); + } + dispatch_attrs[i].binding_count = (uint8_t)binding_count; } } diff --git a/runtime/src/iree/hal/local/local_executable.h b/runtime/src/iree/hal/local/local_executable.h index 6eeb038d169a..b6e24454db89 100644 --- a/runtime/src/iree/hal/local/local_executable.h +++ b/runtime/src/iree/hal/local/local_executable.h @@ -31,8 +31,8 @@ typedef struct iree_hal_local_executable_t { // Defines per-entry point how much workgroup local memory is required. // Contains entries with 0 to indicate no local memory is required or >0 in - // units of IREE_HAL_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE for the minimum amount - // of memory required by the function. + // units of IREE_HAL_EXECUTABLE_WORKGROUP_LOCAL_MEMORY_PAGE_SIZE for the + // minimum amount of memory required by the function. const iree_hal_executable_dispatch_attrs_v0_t* dispatch_attrs; // Execution environment. diff --git a/runtime/src/iree/hal/pipeline_layout.h b/runtime/src/iree/hal/pipeline_layout.h index bd15bb1a51c5..d090868a5083 100644 --- a/runtime/src/iree/hal/pipeline_layout.h +++ b/runtime/src/iree/hal/pipeline_layout.h @@ -43,9 +43,15 @@ typedef enum iree_hal_descriptor_type_e { // A bitmask of flags controlling the behavior of a descriptor. 
enum iree_hal_descriptor_flag_bits_t { IREE_HAL_DESCRIPTOR_FLAG_NONE = 0u, + // Indicates that the binding is treated as immutable within all dispatches // using it. IREE_HAL_DESCRIPTOR_FLAG_READ_ONLY = 1u << 0, + + // Indicates the descriptor is 'bindless' and passed via implementation- + // specific parameter buffers stored in memory instead of API-level calls. + // Ignored by implementations that don't have a concept of indirect bindings. + IREE_HAL_DESCRIPTOR_FLAG_INDIRECT = 1u << 1, }; typedef uint32_t iree_hal_descriptor_flags_t; diff --git a/runtime/src/iree/hal/utils/deferred_command_buffer.c b/runtime/src/iree/hal/utils/deferred_command_buffer.c index 49ec3347a584..a1c92bf9e054 100644 --- a/runtime/src/iree/hal/utils/deferred_command_buffer.c +++ b/runtime/src/iree/hal/utils/deferred_command_buffer.c @@ -27,6 +27,8 @@ typedef enum iree_hal_command_type_e { IREE_HAL_CMD_PUSH_DESCRIPTOR_SET, IREE_HAL_CMD_DISPATCH, IREE_HAL_CMD_DISPATCH_INDIRECT, + IREE_HAL_CMD_DISPATCH2, + IREE_HAL_CMD_DISPATCH2_INDIRECT, } iree_hal_cmd_type_t; // Header prefixed to all commands, forming a linked-list. 
@@ -854,6 +856,196 @@ static iree_status_t iree_hal_deferred_command_buffer_apply_dispatch_indirect(
       cmd->flags);
 }
 
+//===----------------------------------------------------------------------===//
+// IREE_HAL_CMD_DISPATCH2
+//===----------------------------------------------------------------------===//
+
+// Recorded form of a direct dispatch. The constant data and binding references
+// are copied into the command allocation directly after this struct and
+// |constants| and |bindings| point at those copies.
+typedef struct iree_hal_cmd_dispatch2_t {
+  iree_hal_cmd_header_t header;
+  iree_hal_executable_t* executable;
+  int32_t entry_point;
+  uint32_t workgroup_count[3];
+  iree_const_byte_span_t constants;
+  iree_hal_buffer_ref_list_t bindings;
+  iree_hal_dispatch_flags_t flags;
+} iree_hal_cmd_dispatch2_t;
+
+// Records a direct dispatch for later replay. |executable| and the buffers
+// referenced by |bindings| are inserted into the command buffer resource set
+// and |constants| and |bindings| are copied into the appended command.
+static iree_status_t iree_hal_deferred_command_buffer_dispatch2(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_t* executable, int32_t entry_point,
+    const uint32_t workgroup_count[3], iree_const_byte_span_t constants,
+    iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
+  iree_hal_deferred_command_buffer_t* command_buffer =
+      iree_hal_deferred_command_buffer_cast(base_command_buffer);
+  IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert(
+      command_buffer->resource_set, 1, &executable));
+
+  // Trailing storage layout: [cmd][constants + padding][binding refs].
+  iree_hal_cmd_dispatch2_t* cmd = NULL;
+  iree_host_size_t total_size =
+      sizeof(*cmd) + iree_host_align(constants.data_length, iree_max_align_t) +
+      bindings.count * sizeof(bindings.values[0]);
+  IREE_RETURN_IF_ERROR(iree_hal_cmd_list_append_command(
+      &command_buffer->cmd_list, IREE_HAL_CMD_DISPATCH2, total_size,
+      (void**)&cmd));
+  cmd->executable = executable;
+  cmd->entry_point = entry_point;
+  memcpy(cmd->workgroup_count, workgroup_count, sizeof(cmd->workgroup_count));
+  cmd->flags = flags;
+
+  uint8_t* cmd_ptr = (uint8_t*)cmd;
+  cmd_ptr += sizeof(*cmd);
+
+  // An empty span may carry a NULL data pointer and passing NULL to memcpy is
+  // undefined behavior even when the length is zero, so skip empty copies.
+  if (constants.data_length > 0) {
+    memcpy(cmd_ptr, constants.data, constants.data_length);
+  }
+  cmd->constants = iree_make_const_byte_span(cmd_ptr, constants.data_length);
+  cmd_ptr += iree_host_align(constants.data_length, iree_max_align_t);
+
+  cmd->bindings.count = bindings.count;
+  if (bindings.count > 0) {
+    memcpy(cmd_ptr, bindings.values,
+           bindings.count * sizeof(bindings.values[0]));
+  }
+  cmd->bindings.values = (iree_hal_buffer_ref_t*)cmd_ptr;
+  cmd_ptr += bindings.count * sizeof(bindings.values[0]);
+  IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert_strided(
+      command_buffer->resource_set, bindings.count, bindings.values,
+      offsetof(iree_hal_buffer_ref_t, buffer), sizeof(iree_hal_buffer_ref_t)));
+
+  return iree_ok_status();
+}
+
+// Replays a recorded direct dispatch into |target_command_buffer| after
+// resolving each recorded binding through |binding_table|.
+static iree_status_t iree_hal_deferred_command_buffer_apply_dispatch2(
+    iree_hal_command_buffer_t* target_command_buffer,
+    iree_hal_buffer_binding_table_t binding_table,
+    const iree_hal_cmd_dispatch2_t* cmd) {
+  iree_hal_buffer_ref_t* binding_refs = (iree_hal_buffer_ref_t*)iree_alloca(
+      cmd->bindings.count * sizeof(iree_hal_buffer_ref_t));
+  for (iree_host_size_t i = 0; i < cmd->bindings.count; ++i) {
+    IREE_RETURN_IF_ERROR(iree_hal_buffer_binding_table_resolve_ref(
+        binding_table, cmd->bindings.values[i], &binding_refs[i]));
+  }
+  const iree_hal_buffer_ref_list_t binding_ref_list = {
+      .count = cmd->bindings.count,
+      .values = binding_refs,
+  };
+  return iree_hal_command_buffer_dispatch2(
+      target_command_buffer, cmd->executable, cmd->entry_point,
+      cmd->workgroup_count, cmd->constants, binding_ref_list, cmd->flags);
+}
+
+//===----------------------------------------------------------------------===//
+// IREE_HAL_CMD_DISPATCH2_INDIRECT
+//===----------------------------------------------------------------------===//
+
+// Recorded form of an indirect dispatch. Uses the same trailing storage as the
+// direct form; |workgroups_ref| is resolved on replay and passed through as
+// the source of the workgroup count.
+typedef struct iree_hal_cmd_dispatch2_indirect_t {
+  iree_hal_cmd_header_t header;
+  iree_hal_executable_t* executable;
+  int32_t entry_point;
+  iree_hal_buffer_ref_t workgroups_ref;
+  iree_const_byte_span_t constants;
+  iree_hal_buffer_ref_list_t bindings;
+  iree_hal_dispatch_flags_t flags;
+} iree_hal_cmd_dispatch2_indirect_t;
+
+// Records an indirect dispatch for later replay. |executable|, the optional
+// |workgroups_ref| buffer, and the buffers referenced by |bindings| are
+// inserted into the command buffer resource set and |constants| and
+// |bindings| are copied into the appended command.
+static iree_status_t iree_hal_deferred_command_buffer_dispatch2_indirect(
+    iree_hal_command_buffer_t* base_command_buffer,
+    iree_hal_executable_t* executable, int32_t entry_point,
+    iree_hal_buffer_ref_t workgroups_ref, iree_const_byte_span_t constants,
+    iree_hal_buffer_ref_list_t bindings, iree_hal_dispatch_flags_t flags) {
+  iree_hal_deferred_command_buffer_t* command_buffer =
+      iree_hal_deferred_command_buffer_cast(base_command_buffer);
+
+  // |workgroups_ref.buffer| may be NULL; only insert it when present.
+  iree_host_size_t resource_count = 0;
+  const void* resources[2] = {NULL, NULL};
+  resources[resource_count++] = executable;
+  if (workgroups_ref.buffer) {
+    resources[resource_count++] = workgroups_ref.buffer;
+  }
+  IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert(
+      command_buffer->resource_set, resource_count, resources));
+
+  // Trailing storage layout: [cmd][constants + padding][binding refs].
+  iree_hal_cmd_dispatch2_indirect_t* cmd = NULL;
+  iree_host_size_t total_size =
+      sizeof(*cmd) + iree_host_align(constants.data_length, iree_max_align_t) +
+      bindings.count * sizeof(bindings.values[0]);
+  IREE_RETURN_IF_ERROR(iree_hal_cmd_list_append_command(
+      &command_buffer->cmd_list, IREE_HAL_CMD_DISPATCH2_INDIRECT, total_size,
+      (void**)&cmd));
+  cmd->executable = executable;
+  cmd->entry_point = entry_point;
+  cmd->workgroups_ref = workgroups_ref;
+  cmd->flags = flags;
+
+  uint8_t* cmd_ptr = (uint8_t*)cmd;
+  cmd_ptr += sizeof(*cmd);
+
+  // An empty span may carry a NULL data pointer and passing NULL to memcpy is
+  // undefined behavior even when the length is zero, so skip empty copies.
+  if (constants.data_length > 0) {
+    memcpy(cmd_ptr, constants.data, constants.data_length);
+  }
+  cmd->constants = iree_make_const_byte_span(cmd_ptr, constants.data_length);
+  cmd_ptr += iree_host_align(constants.data_length, iree_max_align_t);
+
+  cmd->bindings.count = bindings.count;
+  if (bindings.count > 0) {
+    memcpy(cmd_ptr, bindings.values,
+           bindings.count * sizeof(bindings.values[0]));
+  }
+  cmd->bindings.values = (iree_hal_buffer_ref_t*)cmd_ptr;
+  cmd_ptr += bindings.count * sizeof(bindings.values[0]);
+  IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert_strided(
+      command_buffer->resource_set, bindings.count, bindings.values,
+      offsetof(iree_hal_buffer_ref_t, buffer), sizeof(iree_hal_buffer_ref_t)));
+
+  return iree_ok_status();
+}
+
+// Replays a recorded indirect dispatch into |target_command_buffer| after
+// resolving the workgroups reference and each binding via |binding_table|.
+static iree_status_t iree_hal_deferred_command_buffer_apply_dispatch2_indirect(
+    iree_hal_command_buffer_t* target_command_buffer,
+    iree_hal_buffer_binding_table_t binding_table,
+    const iree_hal_cmd_dispatch2_indirect_t* cmd) {
+  iree_hal_buffer_ref_t workgroups_ref;
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_binding_table_resolve_ref(
+      binding_table, cmd->workgroups_ref, &workgroups_ref));
+  iree_hal_buffer_ref_t* binding_refs = (iree_hal_buffer_ref_t*)iree_alloca(
+      cmd->bindings.count * sizeof(iree_hal_buffer_ref_t));
+  for (iree_host_size_t i = 0; i < cmd->bindings.count; ++i) {
+    IREE_RETURN_IF_ERROR(iree_hal_buffer_binding_table_resolve_ref(
+        binding_table, cmd->bindings.values[i], &binding_refs[i]));
+  }
+  const iree_hal_buffer_ref_list_t binding_ref_list = {
+      .count = cmd->bindings.count,
+      .values = binding_refs,
+  };
+  return iree_hal_command_buffer_dispatch2_indirect(
+      target_command_buffer, cmd->executable, cmd->entry_point, workgroups_ref,
+      cmd->constants, binding_ref_list, cmd->flags);
+}
+
 //===----------------------------------------------------------------------===//
 // Dynamic replay dispatch
 //===----------------------------------------------------------------------===//
@@ -885,6 +1077,10 @@ static const iree_hal_cmd_apply_fn_t iree_hal_cmd_apply_table[] = {
         iree_hal_deferred_command_buffer_apply_dispatch,
     [IREE_HAL_CMD_DISPATCH_INDIRECT] = (iree_hal_cmd_apply_fn_t)
         iree_hal_deferred_command_buffer_apply_dispatch_indirect,
+    [IREE_HAL_CMD_DISPATCH2] = (iree_hal_cmd_apply_fn_t)
+        iree_hal_deferred_command_buffer_apply_dispatch2,
+    [IREE_HAL_CMD_DISPATCH2_INDIRECT] = (iree_hal_cmd_apply_fn_t)
+        iree_hal_deferred_command_buffer_apply_dispatch2_indirect,
 };
 
 IREE_API_EXPORT iree_status_t iree_hal_deferred_command_buffer_apply(
@@ -943,4 +1139,7 @@ static const iree_hal_command_buffer_vtable_t
         iree_hal_deferred_command_buffer_push_descriptor_set,
     .dispatch = iree_hal_deferred_command_buffer_dispatch,
     .dispatch_indirect = iree_hal_deferred_command_buffer_dispatch_indirect,
+    .dispatch2 = iree_hal_deferred_command_buffer_dispatch2,
+    .dispatch2_indirect =
+        iree_hal_deferred_command_buffer_dispatch2_indirect,
 };
diff --git 
a/runtime/src/iree/modules/hal/exports.inl b/runtime/src/iree/modules/hal/exports.inl
index f6f96f2b810d..8d444456d598 100644
--- a/runtime/src/iree/modules/hal/exports.inl
+++ b/runtime/src/iree/modules/hal/exports.inl
@@ -50,8 +50,11 @@ EXPORT_FN("command_buffer.begin_debug_group", iree_hal_module_command_buffer_beg
 EXPORT_FN("command_buffer.collective", iree_hal_module_command_buffer_collective, rriiiirrIIIII, v)
 EXPORT_FN("command_buffer.copy_buffer", iree_hal_module_command_buffer_copy_buffer, riirIrII, v)
 EXPORT_FN("command_buffer.create", iree_hal_module_command_buffer_create, riiIi, r)
+// TODO(#18154): replace base dispatch with new `2` versions.
 EXPORT_FN("command_buffer.dispatch", iree_hal_module_command_buffer_dispatch, rriiiiI, v)
 EXPORT_FN("command_buffer.dispatch.indirect", iree_hal_module_command_buffer_dispatch_indirect, rriirII, v)
+EXPORT_FN_CUSTOM("command_buffer.dispatch2", iree_hal_module_command_buffer_dispatch2, rriiiiICiDCiirIID, v)
+EXPORT_FN_CUSTOM("command_buffer.dispatch2.indirect", iree_hal_module_command_buffer_dispatch2_indirect, rriirIICiDCiirIID, v)
 EXPORT_FN("command_buffer.end_debug_group", iree_hal_module_command_buffer_end_debug_group, r, v)
 EXPORT_FN("command_buffer.execution_barrier", iree_hal_module_command_buffer_execution_barrier, riii, v)
 EXPORT_FN("command_buffer.fill_buffer", iree_hal_module_command_buffer_fill_buffer, rrIIiii, v)
@@ -77,7 +80,9 @@ EXPORT_FN("devices.get", iree_hal_module_devices_get, i, r)
 
 EXPORT_FN("ex.file.from_memory", iree_hal_module_ex_file_from_memory, rIirIIi, r)
 
+// TODO(#18154): replace base executable create with new `2` versions.
 EXPORT_FN("executable.create", iree_hal_module_executable_create, rrrrCrD, r)
+EXPORT_FN("executable.create2", iree_hal_module_executable_create2, rrrr, r)
 
 EXPORT_FN("fence.await", iree_hal_module_fence_await, iCrD, i)
 EXPORT_FN("fence.create", iree_hal_module_fence_create, ri, r)
diff --git a/runtime/src/iree/modules/hal/module.c b/runtime/src/iree/modules/hal/module.c
index e599d7740423..f3cac5ef9697 100644
--- a/runtime/src/iree/modules/hal/module.c
+++ b/runtime/src/iree/modules/hal/module.c
@@ -32,8 +32,8 @@
 // Module type definitions
 //===----------------------------------------------------------------------===//
 
-#define IREE_HAL_MODULE_VERSION_0_3 0x00000003u
-#define IREE_HAL_MODULE_VERSION_LATEST IREE_HAL_MODULE_VERSION_0_3
+#define IREE_HAL_MODULE_VERSION_0_4 0x00000004u
+#define IREE_HAL_MODULE_VERSION_LATEST IREE_HAL_MODULE_VERSION_0_4
 
 typedef struct iree_hal_module_t {
   iree_allocator_t host_allocator;
@@ -945,6 +945,244 @@ IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_dispatch_indirect,  //
       command_buffer, executable, entry_point, workgroups_ref, flags);
 }
 
+// Argument signature: rriiiiICiDCiirIID
+typedef struct {
+  union {
+    struct {
+      iree_vm_ref_t command_buffer;
+      iree_vm_ref_t executable;
+      int32_t entry_point;
+      uint32_t workgroup_count[3];
+      iree_hal_dispatch_flags_t flags;
+    };
+    iree_vm_abi_rriiiiI_t params;
+  };
+  iree_vm_size_t constant_count;
+  const uint32_t* constants;
+  iree_vm_size_t binding_count;
+  const iree_vm_abi_iirII_t* bindings;
+} iree_hal_module_command_buffer_dispatch2_args_t;
+// Translates the VM call arguments into HAL types and records a direct
+// dispatch. Fails with OUT_OF_RANGE when more bindings are provided than
+// IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT.
+static iree_status_t iree_hal_module_command_buffer_dispatch2(
+    iree_vm_stack_t* IREE_RESTRICT stack, void* IREE_RESTRICT module,
+    iree_hal_module_state_t* IREE_RESTRICT state,
+    const iree_hal_module_command_buffer_dispatch2_args_t* IREE_RESTRICT args) {
+  iree_hal_command_buffer_t* command_buffer = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_check_deref(args->command_buffer,
+                                                           &command_buffer));
+  iree_hal_executable_t* executable = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_executable_check_deref(args->executable, &executable));
+
+  if (IREE_UNLIKELY(args->binding_count >
+                    IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "binding count %" PRIhsz " > %" PRIhsz,
+                            (iree_host_size_t)args->binding_count,
+                            IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT);
+  }
+  iree_hal_buffer_ref_list_t bindings = {
+      .count = (iree_host_size_t)args->binding_count,
+      .values = (iree_hal_buffer_ref_t*)iree_alloca(
+          args->binding_count * sizeof(iree_hal_buffer_ref_t)),
+  };
+  for (iree_host_size_t i = 0; i < bindings.count; ++i) {
+    iree_hal_buffer_ref_t* binding =
+        (iree_hal_buffer_ref_t*)&bindings.values[i];
+    binding->ordinal = 0;
+    binding->buffer_slot = (uint32_t)args->bindings[i].i1;
+    IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref_or_null(
+        args->bindings[i].r2, &binding->buffer));
+    binding->offset = iree_hal_cast_device_size(args->bindings[i].i3);
+    binding->length = iree_hal_cast_device_size(args->bindings[i].i4);
+  }
+
+  return iree_hal_command_buffer_dispatch2(
+      command_buffer, executable, args->entry_point, args->workgroup_count,
+      iree_make_const_byte_span(args->constants,
+                                args->constant_count * sizeof(uint32_t)),
+      bindings, (iree_hal_dispatch_flags_t)args->flags);
+}
+// Custom shim that unpacks the two variadic segments (constants, bindings)
+// from |args_storage|. Every read is bounds checked against the storage
+// length first so malformed argument lists fail with INVALID_ARGUMENT.
+static iree_status_t iree_hal_module_command_buffer_dispatch2_shim(
+    iree_vm_stack_t* IREE_RESTRICT stack, iree_vm_native_function_flags_t flags,
+    iree_byte_span_t args_storage, iree_byte_span_t rets_storage,
+    iree_vm_native_function_target2_t target_fn, void* IREE_RESTRICT module,
+    void* IREE_RESTRICT module_state) {
+  // TODO(benvanik): support multiple variadic segments in one call.
+  // For now we inline what it would do in a very painful way.
+  iree_hal_module_command_buffer_dispatch2_args_t args = {
+      .constant_count = 0,
+      .constants = NULL,
+      .binding_count = 0,
+      .bindings = NULL,
+  };
+  const uint8_t* cursor = args_storage.data;
+  const uint8_t* end_ptr = args_storage.data + args_storage.data_length;
+
+  // Fixed parameters followed by the constant count.
+  bool args_ok = args_storage.data_length >=
+                 sizeof(args.params) + sizeof(iree_vm_size_t);
+  if (args_ok) {
+    args.params = *(const iree_vm_abi_rriiiiI_t*)cursor;
+    cursor += sizeof(args.params);
+    args.constant_count = *(const iree_vm_size_t*)cursor;
+    cursor += sizeof(iree_vm_size_t);
+    args.constants = (const uint32_t*)cursor;
+    // Comparing as unsigned host sizes also rejects a negative count.
+    args_ok = (iree_host_size_t)args.constant_count <=
+              (iree_host_size_t)(end_ptr - cursor) / sizeof(args.constants[0]);
+  }
+  if (args_ok) {
+    // Constant values followed by the binding count.
+    cursor += args.constant_count * sizeof(args.constants[0]);
+    args_ok = (iree_host_size_t)(end_ptr - cursor) >= sizeof(iree_vm_size_t);
+  }
+  if (args_ok) {
+    args.binding_count = *(const iree_vm_size_t*)cursor;
+    cursor += sizeof(iree_vm_size_t);
+    args.bindings = (const iree_vm_abi_iirII_t*)cursor;
+    args_ok = (iree_host_size_t)args.binding_count <=
+              (iree_host_size_t)(end_ptr - cursor) / sizeof(args.bindings[0]);
+  }
+  if (IREE_UNLIKELY(!args_ok || rets_storage.data_length > 0)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "argument/result signature mismatch");
+  }
+  IREE_ASSERT(target_fn == (iree_vm_native_function_target2_t)
+                               iree_hal_module_command_buffer_dispatch2);
+  return iree_hal_module_command_buffer_dispatch2(stack, module, module_state,
+                                                  &args);
+}
+
+// Argument signature: rriirIICiDCiirIID
+typedef struct {
+  union {
+    struct {
+      iree_vm_ref_t command_buffer;
+      iree_vm_ref_t executable;
+      int32_t entry_point;
+      int32_t workgroups_buffer_slot;
+      iree_vm_ref_t workgroups_buffer;
+      int64_t workgroups_offset;
+      iree_hal_dispatch_flags_t flags;
+    };
+    iree_vm_abi_rriirII_t params;
+  };
+  iree_vm_size_t constant_count;
+  const uint32_t* constants;
+  iree_vm_size_t binding_count;
+  const iree_vm_abi_iirII_t* bindings;
+} iree_hal_module_command_buffer_dispatch2_indirect_args_t;
+// Translates the VM call arguments into HAL types and records an indirect
+// dispatch whose workgroup count is referenced by the workgroups buffer/slot.
+// Fails with OUT_OF_RANGE when too many bindings are provided.
+static iree_status_t iree_hal_module_command_buffer_dispatch2_indirect(
+    iree_vm_stack_t* IREE_RESTRICT stack, void* IREE_RESTRICT module,
+    iree_hal_module_state_t* IREE_RESTRICT state,
+    const iree_hal_module_command_buffer_dispatch2_indirect_args_t*
+        IREE_RESTRICT args) {
+  iree_hal_command_buffer_t* command_buffer = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_command_buffer_check_deref(args->command_buffer,
+                                                           &command_buffer));
+  iree_hal_executable_t* executable = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_executable_check_deref(args->executable, &executable));
+  iree_hal_buffer_ref_t workgroups_ref = iree_hal_make_indirect_buffer_ref(
+      args->workgroups_buffer_slot,
+      iree_hal_cast_device_size(args->workgroups_offset), 3 * sizeof(uint32_t));
+  IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref_or_null(
+      args->workgroups_buffer, &workgroups_ref.buffer));
+
+  if (IREE_UNLIKELY(args->binding_count >
+                    IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT)) {
+    return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
+                            "binding count %" PRIhsz " > %" PRIhsz,
+                            (iree_host_size_t)args->binding_count,
+                            IREE_HAL_MODULE_MAX_DESCRIPTOR_BINDING_COUNT);
+  }
+  iree_hal_buffer_ref_list_t bindings = {
+      .count = (iree_host_size_t)args->binding_count,
+      .values = (iree_hal_buffer_ref_t*)iree_alloca(
+          args->binding_count * sizeof(iree_hal_buffer_ref_t)),
+  };
+  for (iree_host_size_t i = 0; i < bindings.count; ++i) {
+    iree_hal_buffer_ref_t* binding =
+        (iree_hal_buffer_ref_t*)&bindings.values[i];
+    // iree_alloca storage is uninitialized so every field must be assigned.
+    binding->ordinal = 0;
+    binding->buffer_slot = (uint32_t)args->bindings[i].i1;
+    IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref_or_null(
+        args->bindings[i].r2, &binding->buffer));
+    binding->offset = iree_hal_cast_device_size(args->bindings[i].i3);
+    binding->length = iree_hal_cast_device_size(args->bindings[i].i4);
+  }
+
+  return iree_hal_command_buffer_dispatch2_indirect(
+      command_buffer, executable, args->entry_point, workgroups_ref,
+      iree_make_const_byte_span(args->constants,
+                                args->constant_count * sizeof(uint32_t)),
+      bindings, (iree_hal_dispatch_flags_t)args->flags);
+}
+// Custom shim that unpacks the two variadic segments (constants, bindings)
+// from |args_storage|. Every read is bounds checked against the storage
+// length first so malformed argument lists fail with INVALID_ARGUMENT.
+static iree_status_t iree_hal_module_command_buffer_dispatch2_indirect_shim(
+    iree_vm_stack_t* IREE_RESTRICT stack, iree_vm_native_function_flags_t flags,
+    iree_byte_span_t args_storage, iree_byte_span_t rets_storage,
+    iree_vm_native_function_target2_t target_fn, void* IREE_RESTRICT module,
+    void* IREE_RESTRICT module_state) {
+  // TODO(benvanik): support multiple variadic segments in one call.
+  // For now we inline what it would do in a very painful way.
+  iree_hal_module_command_buffer_dispatch2_indirect_args_t args = {
+      .constant_count = 0,
+      .constants = NULL,
+      .binding_count = 0,
+      .bindings = NULL,
+  };
+  const uint8_t* cursor = args_storage.data;
+  const uint8_t* end_ptr = args_storage.data + args_storage.data_length;
+
+  // Fixed parameters followed by the constant count.
+  bool args_ok = args_storage.data_length >=
+                 sizeof(args.params) + sizeof(iree_vm_size_t);
+  if (args_ok) {
+    args.params = *(const iree_vm_abi_rriirII_t*)cursor;
+    cursor += sizeof(args.params);
+    args.constant_count = *(const iree_vm_size_t*)cursor;
+    cursor += sizeof(iree_vm_size_t);
+    args.constants = (const uint32_t*)cursor;
+    // Comparing as unsigned host sizes also rejects a negative count.
+    args_ok = (iree_host_size_t)args.constant_count <=
+              (iree_host_size_t)(end_ptr - cursor) / sizeof(args.constants[0]);
+  }
+  if (args_ok) {
+    // Constant values followed by the binding count.
+    cursor += args.constant_count * sizeof(args.constants[0]);
+    args_ok = (iree_host_size_t)(end_ptr - cursor) >= sizeof(iree_vm_size_t);
+  }
+  if (args_ok) {
+    args.binding_count = *(const iree_vm_size_t*)cursor;
+    cursor += sizeof(iree_vm_size_t);
+    args.bindings = (const iree_vm_abi_iirII_t*)cursor;
+    args_ok = (iree_host_size_t)args.binding_count <=
+              (iree_host_size_t)(end_ptr - cursor) / sizeof(args.bindings[0]);
+  }
+  if (IREE_UNLIKELY(!args_ok || rets_storage.data_length > 0)) {
+    return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                            "argument/result signature mismatch");
+  }
+  IREE_ASSERT(target_fn ==
+              (iree_vm_native_function_target2_t)
+                  iree_hal_module_command_buffer_dispatch2_indirect);
+  return iree_hal_module_command_buffer_dispatch2_indirect(stack, module,
+                                                           module_state, &args);
+}
+
 //===----------------------------------------------------------------------===//
 // iree_hal_descriptor_set_layout
//===----------------------------------------------------------------------===//
@@ -1289,6 +1495,64 @@ IREE_VM_ABI_EXPORT(iree_hal_module_executable_create,  //
   return status;
 }
 
+// Prepares an executable on device |r0| from the |r2| data buffer in the
+// format named by buffer |r1|. When |r3| is a buffer its contents are passed
+// as 32-bit executable constants; otherwise no constants are provided.
+IREE_VM_ABI_EXPORT(iree_hal_module_executable_create2,  //
+                   iree_hal_module_state_t,             //
+                   rrrr, r) {
+  iree_hal_device_t* device = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device));
+
+  iree_vm_buffer_t* format_buffer = NULL;
+  IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r1, &format_buffer));
+  const iree_string_view_t format = iree_vm_buffer_as_string(format_buffer);
+
+  iree_vm_buffer_t* data_buffer = NULL;
+  IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r2, &data_buffer));
+
+  const uint32_t* constant_values = NULL;
+  iree_host_size_t constant_total = 0;
+  if (iree_vm_buffer_isa(args->r3)) {
+    iree_vm_buffer_t* constant_buffer = NULL;
+    IREE_RETURN_IF_ERROR(
+        iree_vm_buffer_check_deref(args->r3, &constant_buffer));
+    const iree_host_size_t byte_length = constant_buffer->data.data_length;
+    if (byte_length % 4 != 0) {
+      return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
+                              "constant buffer data must contain 4-byte "
+                              "elements but data length is %" PRIhsz,
+                              byte_length);
+    }
+    constant_values = (const uint32_t*)constant_buffer->data.data;
+    constant_total = byte_length / sizeof(uint32_t);
+  }
+
+  iree_hal_executable_cache_t* cache = NULL;
+  IREE_RETURN_IF_ERROR(
+      iree_hal_module_state_lookup_executable_cache(state, device, &cache));
+
+  iree_hal_executable_params_t params;
+  iree_hal_executable_params_initialize(&params);
+  // Alias the provided data only when the buffer access is exactly
+  // IREE_VM_BUFFER_ACCESS_ORIGIN_MODULE.
+  if (data_buffer->access == IREE_VM_BUFFER_ACCESS_ORIGIN_MODULE) {
+    params.caching_mode |= IREE_HAL_EXECUTABLE_CACHING_MODE_ALIAS_PROVIDED_DATA;
+  }
+  params.executable_format = format;
+  params.executable_data = iree_make_const_byte_span(
+      data_buffer->data.data, data_buffer->data.data_length);
+  params.constant_count = constant_total;
+  params.constants = constant_values;
+
+  iree_hal_executable_t* executable = NULL;
+  IREE_RETURN_IF_ERROR(iree_hal_executable_cache_prepare_executable(
+      cache, &params, &executable));
+
+  rets->r0 = iree_hal_executable_move_ref(executable);
+  return iree_ok_status();
+}
+
 //===----------------------------------------------------------------------===//
 // iree_hal_fence_t
 //===----------------------------------------------------------------------===//
@@ -1652,8 +1909,14 @@ static const iree_vm_native_function_ptr_t iree_hal_module_funcs_[] = {
           iree_vm_shim_##arg_types##_##ret_types,              \
       .target = (iree_vm_native_function_target_t)(target_fn), \
   },
+#define EXPORT_FN_CUSTOM(name, target_fn, arg_types, ret_types)   \
+  {                                                               \
+      .shim = (iree_vm_native_function_shim_t)(target_fn##_shim), \
+      .target = (iree_vm_native_function_target_t)(target_fn),    \
+  },
 #include "iree/modules/hal/exports.inl"  // IWYU pragma: keep
 #undef EXPORT_FN
+#undef EXPORT_FN_CUSTOM
 };
 
 // NOTE: 0 length, but can't express that in C.
@@ -1668,8 +1931,10 @@ static const iree_vm_native_export_descriptor_t iree_hal_module_exports_[] = { .attr_count = 0, \ .attrs = NULL, \ }, +#define EXPORT_FN_CUSTOM EXPORT_FN #include "iree/modules/hal/exports.inl" // IWYU pragma: keep #undef EXPORT_FN +#undef EXPORT_FN_CUSTOM }; static_assert(IREE_ARRAYSIZE(iree_hal_module_funcs_) == IREE_ARRAYSIZE(iree_hal_module_exports_), diff --git a/runtime/src/iree/vm/shims.c b/runtime/src/iree/vm/shims.c index 5bd69a7dd5d4..2509ffa7dc12 100644 --- a/runtime/src/iree/vm/shims.c +++ b/runtime/src/iree/vm/shims.c @@ -46,6 +46,7 @@ IREE_VM_ABI_DEFINE_SHIM(riiIi, r); IREE_VM_ABI_DEFINE_SHIM(rIiiI, r); IREE_VM_ABI_DEFINE_SHIM(riIiirII, r); IREE_VM_ABI_DEFINE_SHIM(rriiiirrIIIII, v); +IREE_VM_ABI_DEFINE_SHIM(rrrr, r); IREE_VM_ABI_DEFINE_SHIM(rrrrCrD, r); IREE_VM_ABI_DEFINE_SHIM(ririi, v); IREE_VM_ABI_DEFINE_SHIM(rr, i); diff --git a/runtime/src/iree/vm/shims.h b/runtime/src/iree/vm/shims.h index b47428ced911..cd14a4645589 100644 --- a/runtime/src/iree/vm/shims.h +++ b/runtime/src/iree/vm/shims.h @@ -585,6 +585,13 @@ IREE_VM_ABI_VLA_STRUCT(rirCrD, a3_count, a3, { iree_vm_abi_r_t a3[0]; }); +IREE_VM_ABI_FIXED_STRUCT(rrrr, { + iree_vm_ref_t r0; + iree_vm_ref_t r1; + iree_vm_ref_t r2; + iree_vm_ref_t r3; +}); + IREE_VM_ABI_VLA_STRUCT(rrrrCrD, a4_count, a4, { iree_vm_ref_t r0; iree_vm_ref_t r1; @@ -697,6 +704,7 @@ IREE_VM_ABI_DECLARE_SHIM(riiIi, r); IREE_VM_ABI_DECLARE_SHIM(rIiiI, r); IREE_VM_ABI_DECLARE_SHIM(riIiirII, r); IREE_VM_ABI_DECLARE_SHIM(rriiiirrIIIII, v); +IREE_VM_ABI_DECLARE_SHIM(rrrr, r); IREE_VM_ABI_DECLARE_SHIM(rrrrCrD, r); IREE_VM_ABI_DECLARE_SHIM(ririi, v); IREE_VM_ABI_DECLARE_SHIM(rr, i);