diff --git a/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-turing/build.sh b/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-turing/build.sh index c98889ee52f5..7eb92c337e01 100755 --- a/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-turing/build.sh +++ b/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-turing/build.sh @@ -69,7 +69,9 @@ echo "Building with Ninja" cd "${CMAKE_BUILD_DIR?}" ninja -export CTEST_PARALLEL_LEVEL=${CTEST_PARALLEL_LEVEL:-$(nproc)} +# Limit parallelism dramatically to avoid exhausting GPU memory +# TODO(#5162): Handle this more robustly +export CTEST_PARALLEL_LEVEL=${CTEST_PARALLEL_LEVEL:-1} # Only test drivers that use the GPU, since we run all tests on non-GPU machines # as well. diff --git a/iree/compiler/Bindings/TFLite/Transforms/test/materialize_shape_support.mlir b/iree/compiler/Bindings/TFLite/Transforms/test/materialize_shape_support.mlir index f832e8867a60..c547de236acb 100644 --- a/iree/compiler/Bindings/TFLite/Transforms/test/materialize_shape_support.mlir +++ b/iree/compiler/Bindings/TFLite/Transforms/test/materialize_shape_support.mlir @@ -38,10 +38,10 @@ // CHECK-NEXT: %[[IN0_SHAPE:.+]] = flow.variable.load @_tflite_dynamicEntry_input0_shape : !shapex.ranked_shape<[?,8,8,3]> // CHECK-NEXT: iree.list.resize %[[LIST]], %c4 : !iree.list // CHECK-NEXT: %[[IN0_D0:.+]] = shapex.ranked_dim %[[IN0_SHAPE]][0] : !shapex.ranked_shape<[?,8,8,3]> -> index -// CHECK-NEXT: iree.list.set %[[LIST]], %c0, %[[IN0_D0]] : !iree.list -// CHECK-NEXT: iree.list.set %[[LIST]], %c1, %c8 : !iree.list -// CHECK-NEXT: iree.list.set %[[LIST]], %c2, %c8 : !iree.list -// CHECK-NEXT: iree.list.set %[[LIST]], %c3, %c3 : !iree.list +// CHECK-NEXT: iree.list.set %[[LIST]][%c0], %[[IN0_D0]] : !iree.list +// CHECK-NEXT: iree.list.set %[[LIST]][%c1], %c8 : !iree.list +// CHECK-NEXT: iree.list.set %[[LIST]][%c2], %c8 : !iree.list +// CHECK-NEXT: iree.list.set %[[LIST]][%c3], %c3 : !iree.list // CHECK-NEXT: br ^bb4 // CHECK-NEXT: ^bb2: // CHECK-NEXT: %[[IS_1:.+]] = cmpi eq, %[[INDEX]], %c1 : index @@ -50,10 +50,10 @@ // CHECK-NEXT: %[[IN1_SHAPE:.+]] = flow.variable.load @_tflite_dynamicEntry_input1_shape : !shapex.ranked_shape<[?,8,8,3]> // CHECK-NEXT: iree.list.resize %[[LIST]], %c4 : !iree.list // CHECK-NEXT: %[[IN1_D0:.+]] = shapex.ranked_dim %[[IN1_SHAPE]][0] : !shapex.ranked_shape<[?,8,8,3]> -> index -// CHECK-NEXT: iree.list.set %[[LIST]], %c0, %[[IN1_D0]] : !iree.list -// CHECK-NEXT: iree.list.set %[[LIST]], %c1, %c8 : !iree.list -// CHECK-NEXT: iree.list.set %[[LIST]], %c2, %c8 : !iree.list -// CHECK-NEXT: iree.list.set %[[LIST]], %c3, %c3 : !iree.list +// CHECK-NEXT: iree.list.set %[[LIST]][%c0], %[[IN1_D0]] : !iree.list +// CHECK-NEXT: iree.list.set %[[LIST]][%c1], %c8 : !iree.list +// CHECK-NEXT: iree.list.set %[[LIST]][%c2], %c8 : !iree.list +// CHECK-NEXT: iree.list.set %[[LIST]][%c3], %c3 : !iree.list // CHECK-NEXT: br ^bb4 // CHECK-NEXT: ^bb4: // CHECK-NEXT: return @@ -64,7 +64,7 @@ // CHECK: %[[IS_0:.+]] = cmpi eq, %[[INDEX]], %c0 : index // CHECK-NEXT: cond_br %[[IS_0]], ^bb1, ^bb2 // CHECK-NEXT: ^bb1: -// CHECK-NEXT: %[[IN0_D0:.+]] = iree.list.get %[[LIST]], %c0 : !iree.list +// CHECK-NEXT: %[[IN0_D0:.+]] = iree.list.get %[[LIST]][%c0] : !iree.list // CHECK-NEXT: %[[IN0_SHAPE:.+]] = shapex.make_ranked_shape %[[IN0_D0]] : (index) -> !shapex.ranked_shape<[?,8,8,3]> // CHECK-NEXT: flow.variable.store %[[IN0_SHAPE]], @_tflite_dynamicEntry_input0_shape : !shapex.ranked_shape<[?,8,8,3]> // CHECK-NEXT: br ^bb4 @@ -72,7 +72,7 @@ // CHECK-NEXT: %[[IS_1:.+]] = cmpi eq, %[[INDEX]], %c1 : index // CHECK-NEXT: cond_br %[[IS_1]], ^bb3, ^bb4 // CHECK-NEXT: ^bb3: -// CHECK-NEXT: %[[IN1_D0:.+]] = iree.list.get %[[LIST]], %c0 : !iree.list +// CHECK-NEXT: %[[IN1_D0:.+]] = iree.list.get %[[LIST]][%c0] : !iree.list // CHECK-NEXT: %[[IN1_SHAPE:.+]] = shapex.make_ranked_shape %[[IN1_D0]] : (index) -> !shapex.ranked_shape<[?,8,8,3]> // CHECK-NEXT: flow.variable.store %[[IN1_SHAPE]], @_tflite_dynamicEntry_input1_shape : !shapex.ranked_shape<[?,8,8,3]> // CHECK-NEXT: br ^bb4 @@ -90,10 +90,10 @@ // CHECK-NEXT: %[[OUT0_SHAPE:.+]] = flow.variable.load @_tflite_dynamicEntry_output0_shape : !shapex.ranked_shape<[?,8,8,3]> // CHECK-NEXT: iree.list.resize %[[LIST]], %c4 : !iree.list // CHECK-NEXT: %[[OUT0_D0:.+]] = shapex.ranked_dim %[[OUT0_SHAPE]][0] : !shapex.ranked_shape<[?,8,8,3]> -> index -// CHECK-NEXT: iree.list.set %[[LIST]], %c0, %[[OUT0_D0]] : !iree.list -// CHECK-NEXT: iree.list.set %[[LIST]], %c1, %c8 : !iree.list -// CHECK-NEXT: iree.list.set %[[LIST]], %c2, %c8 : !iree.list -// CHECK-NEXT: iree.list.set %[[LIST]], %c3, %c3 : !iree.list +// CHECK-NEXT: iree.list.set %[[LIST]][%c0], %[[OUT0_D0]] : !iree.list +// CHECK-NEXT: iree.list.set %[[LIST]][%c1], %c8 : !iree.list +// CHECK-NEXT: iree.list.set %[[LIST]][%c2], %c8 : !iree.list +// CHECK-NEXT: iree.list.set %[[LIST]][%c3], %c3 : !iree.list // CHECK-NEXT: br ^bb4 // CHECK-NEXT: ^bb2: // CHECK-NEXT: %[[IS_1:.+]] = cmpi eq, %[[INDEX]], %c1 : index @@ -102,10 +102,10 @@ // CHECK-NEXT: %[[OUT1_SHAPE:.+]] = flow.variable.load @_tflite_dynamicEntry_output1_shape : !shapex.ranked_shape<[?,8,8,3]> // CHECK-NEXT: iree.list.resize %[[LIST]], %c4 : !iree.list // CHECK-NEXT: %[[OUT1_D0:.+]] = shapex.ranked_dim %[[OUT1_SHAPE]][0] : !shapex.ranked_shape<[?,8,8,3]> -> index -// CHECK-NEXT: iree.list.set %[[LIST]], %c0, %[[OUT1_D0]] : !iree.list -// CHECK-NEXT: iree.list.set %[[LIST]], %c1, %c8 : !iree.list -// CHECK-NEXT: iree.list.set %[[LIST]], %c2, %c8 : !iree.list -// CHECK-NEXT: iree.list.set %[[LIST]], %c3, %c3 : !iree.list +// CHECK-NEXT: iree.list.set %[[LIST]][%c0], %[[OUT1_D0]] : !iree.list +// CHECK-NEXT: iree.list.set %[[LIST]][%c1], %c8 : !iree.list +// CHECK-NEXT: iree.list.set %[[LIST]][%c2], %c8 : !iree.list +// CHECK-NEXT: iree.list.set %[[LIST]][%c3], %c3 : !iree.list // CHECK-NEXT: br ^bb4 // CHECK-NEXT: ^bb4: // CHECK-NEXT: return diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp index 99732bda062f..a1eadcd9ccc4 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp +++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp @@ -63,7 +63,6 @@ void buildLLVMTransformPassPipeline(OpPassManager &passManager, if (options.usingLinalgOnTensors) { passManager.addPass(createMaterializeCPULaunchConfigurationPass()); OpPassManager &nestedModulePM = passManager.nest(); - nestedModulePM.addPass(createInlinerPass()); // TODO(ataei): We want to enable when tensor -> vector pass is fully // supported which requires first moving vector-tiling before this step. if (options.useLinalgOnTensorsToVectors) { diff --git a/iree/compiler/Dialect/Flow/Transforms/OutlineLargeConstants.cpp b/iree/compiler/Dialect/Flow/Transforms/OutlineLargeConstants.cpp index f4fcc023b1ad..e7dcb7a9e17e 100644 --- a/iree/compiler/Dialect/Flow/Transforms/OutlineLargeConstants.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/OutlineLargeConstants.cpp @@ -130,7 +130,8 @@ static PassRegistration pass( "iree-flow-outline-large-constants", "Outlines large tensor constants into flow.variables at the module level.", [] { - return std::make_unique(kMinLargeConstantSize); + // TODO(#5493): add a flag for this. + return std::make_unique(256); }); } // namespace Flow diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.h b/iree/compiler/Dialect/Flow/Transforms/Passes.h index 1509791f8b74..54ab27d6546c 100644 --- a/iree/compiler/Dialect/Flow/Transforms/Passes.h +++ b/iree/compiler/Dialect/Flow/Transforms/Passes.h @@ -122,9 +122,9 @@ std::unique_ptr> createExportBenchmarkFuncsPass(); // Outlines large tensor constants into flow.variables at the module level. // -// NOTE: a total guess :) this feels like about the most per-dispatch-buffer -// data we'd want to embed in the command buffer. -static constexpr size_t kMinLargeConstantSize = 256; +// TODO(#5493): implement the support for inlining constants into the command +// buffer and raise this value to one that is measured to be good. +static constexpr size_t kMinLargeConstantSize = 1; std::unique_ptr> createOutlineLargeConstantsPass( size_t minLargeConstantSize = kMinLargeConstantSize); diff --git a/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp b/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp index 02cb2ffbc708..e98d51765cac 100644 --- a/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp +++ b/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp @@ -146,10 +146,12 @@ class CUDATargetBackend final : public TargetBackend { llvmModule->setDataLayout(targetMachine->createDataLayout()); - std::string targetISA = translateModuleToISA(*llvmModule, *targetMachine); + FlatbufferBuilder builder; + iree_CUDAExecutableDef_start_as_root(builder); + // Serialize cuda kernel into the binary that we will embed in the // final flatbuffer. - FlatbufferBuilder builder; + std::string targetISA = translateModuleToISA(*llvmModule, *targetMachine); auto ptxCudeRef = flatbuffers_uint8_vec_create( builder, reinterpret_cast(targetISA.c_str()), targetISA.size()); @@ -168,7 +170,6 @@ class CUDATargetBackend final : public TargetBackend { } auto blockSizesRef = iree_CUDABlockSizeDef_vec_end(builder); - iree_CUDAExecutableDef_start_as_root(builder); iree_CUDAExecutableDef_entry_points_add(builder, entryPointsRef); iree_CUDAExecutableDef_block_sizes_add(builder, blockSizesRef); iree_CUDAExecutableDef_ptx_image_add(builder, ptxCudeRef); diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp index 38cab2dda692..97ab76d13ee4 100644 --- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp +++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp @@ -305,9 +305,11 @@ class LLVMAOTTargetBackend final : public TargetBackend { linkArtifacts.keepAllFiles(); } + FlatbufferBuilder builder; + iree_DyLibExecutableDef_start_as_root(builder); + // Embed debug symbols at the end of the flatbuffer by adding first in the // bottoms-up builder. - FlatbufferBuilder builder; flatbuffers_uint8_vec_ref_t debugDatabaseRef = 0; flatbuffers_string_ref_t debugDatabaseFilenameRef = 0; if (options_.debugSymbols && linkArtifacts.debugFile.outputFile) { @@ -328,7 +330,6 @@ class LLVMAOTTargetBackend final : public TargetBackend { << linkArtifacts.libraryFile.path; } - iree_DyLibExecutableDef_start_as_root(builder); iree_DyLibExecutableDef_library_embedded_add(builder, libraryEmbeddedRef); iree_DyLibExecutableDef_debug_database_filename_add( builder, debugDatabaseFilenameRef); diff --git a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp index a9e5fe879c01..3bb9b7fafdf9 100644 --- a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp +++ b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp @@ -121,6 +121,7 @@ class MetalSPIRVTargetBackend : public SPIRVTargetBackend { // 4. Pack the MTLLibrary and metadata into a flatbuffer. FlatbufferBuilder builder; + iree_MetalExecutableDef_start_as_root(builder); auto shaderSourcesRef = builder.createStringVec(llvm::map_range( mslShaders, [&](const MetalShader &shader) { return shader.source; })); @@ -135,7 +136,6 @@ class MetalSPIRVTargetBackend : public SPIRVTargetBackend { auto entryPointNamesRef = builder.createStringVec(entryPointNames); - iree_MetalExecutableDef_start_as_root(builder); iree_MetalExecutableDef_entry_points_add(builder, entryPointNamesRef); iree_MetalExecutableDef_threadgroup_sizes_add(builder, threadgroupSizesRef); iree_MetalExecutableDef_shader_sources_add(builder, shaderSourcesRef); diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp b/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp index 8e00d15b2365..6cd7fc550f02 100644 --- a/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp +++ b/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp @@ -102,8 +102,10 @@ class VMLATargetBackend final : public TargetBackend { LogicalResult serializeExecutable(IREE::HAL::ExecutableTargetOp targetOp, OpBuilder &executableBuilder) override { - // Serialize the VM module to bytes directly into a flatbuffer. FlatbufferBuilder builder; + iree_VMLAExecutableDef_start_as_root(builder); + + // Serialize the VM module to bytes directly into a flatbuffer. IREE::VM::BytecodeTargetOptions bytecodeOptions; auto dataRef = builder.streamUint8Vec([&](raw_ostream &stream) { return succeeded(translateModuleToBytecode(targetOp.getInnerModule(), @@ -115,7 +117,6 @@ class VMLATargetBackend final : public TargetBackend { // Pack the executable definition and get the bytes with the proper header. // The header is used to verify the contents at runtime. - iree_VMLAExecutableDef_start_as_root(builder); iree_VMLAExecutableDef_bytecode_module_add(builder, dataRef); iree_VMLAExecutableDef_end_as_root(builder); diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp index de7c23794aca..e1a556e25d40 100644 --- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp +++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp @@ -128,9 +128,11 @@ class VulkanSPIRVTargetBackend : public SPIRVTargetBackend { ModuleOp innerModuleOp = targetOp.getInnerModule(); auto spvModuleOp = *innerModuleOp.getOps().begin(); + FlatbufferBuilder builder; + iree_SpirVExecutableDef_start_as_root(builder); + // Serialize the spirv::ModuleOp into the binary that we will embed in the // final flatbuffer. - FlatbufferBuilder builder; SmallVector spvBinary; if (failed(spirv::serialize(spvModuleOp, spvBinary)) || spvBinary.empty()) { return targetOp.emitError() << "failed to serialize spv.module"; @@ -157,7 +159,6 @@ class VulkanSPIRVTargetBackend : public SPIRVTargetBackend { } auto entryPointsRef = builder.createStringVec(entryPointNames); - iree_SpirVExecutableDef_start_as_root(builder); iree_SpirVExecutableDef_entry_points_add(builder, entryPointsRef); iree_SpirVExecutableDef_code_add(builder, spvCodeRef); iree_SpirVExecutableDef_end_as_root(builder); diff --git a/iree/compiler/Dialect/IREE/IR/IREEDialect.cpp b/iree/compiler/Dialect/IREE/IR/IREEDialect.cpp index 31d41a5f8ece..e8797e44ebcd 100644 --- a/iree/compiler/Dialect/IREE/IR/IREEDialect.cpp +++ b/iree/compiler/Dialect/IREE/IR/IREEDialect.cpp @@ -39,7 +39,9 @@ IREEDialect::IREEDialect(MLIRContext* context) Type IREEDialect::parseType(DialectAsmParser& parser) const { Location loc = parser.getEncodedSourceLoc(parser.getNameLoc()); llvm::StringRef spec = parser.getFullSymbolSpec(); - if (spec.consume_front("ptr")) { + if (spec == "variant") { + return IREE::VariantType::get(getContext()); + } else if (spec.consume_front("ptr")) { if (!spec.consume_front("<") || !spec.consume_back(">")) { parser.emitError(parser.getCurrentLocation()) << "malformed ptr type '" << parser.getFullSymbolSpec() << "'"; @@ -63,7 +65,12 @@ Type IREEDialect::parseType(DialectAsmParser& parser) const { << "malformed list type '" << parser.getFullSymbolSpec() << "'"; return Type(); } - auto elementType = mlir::parseType(spec, getContext()); + Type elementType; + if (spec == "?") { + elementType = IREE::VariantType::get(getContext()); + } else { + elementType = mlir::parseType(spec, getContext()); + } if (!elementType) { parser.emitError(parser.getCurrentLocation()) << "invalid list element type specification: '" @@ -77,14 +84,22 @@ Type IREEDialect::parseType(DialectAsmParser& parser) const { } void IREEDialect::printType(Type type, DialectAsmPrinter& os) const { - if (auto ptrType = type.dyn_cast()) { + if (type.isa()) { + os << "variant"; + } else if (auto ptrType = type.dyn_cast()) { os << "ptr<" << ptrType.getTargetType() << ">"; } else if (type.isa()) { os << "byte_buffer"; } else if (type.isa()) { os << "mutable_byte_buffer"; } else if (auto listType = type.dyn_cast()) { - os << "list<" << listType.getElementType() << ">"; + os << "list<"; + if (listType.getElementType().isa()) { + os << "?"; + } else { + os << listType.getElementType(); + } + os << ">"; } else { llvm_unreachable("unhandled IREE type"); } diff --git a/iree/compiler/Dialect/IREE/IR/IREEOps.cpp b/iree/compiler/Dialect/IREE/IR/IREEOps.cpp index b71d02ec23e0..42e49c7c6c4a 100644 --- a/iree/compiler/Dialect/IREE/IR/IREEOps.cpp +++ b/iree/compiler/Dialect/IREE/IR/IREEOps.cpp @@ -155,43 +155,82 @@ void UnfoldableConstantOp::getCanonicalizationPatterns( // Lists //===----------------------------------------------------------------------===// -static ParseResult parseListType(OpAsmParser &parser, Type &listType, - Type &elementType) { +static ParseResult parseListTypeGet(OpAsmParser &parser, Type &listType, + Type &elementType) { if (failed(parser.parseType(listType))) { return parser.emitError(parser.getCurrentLocation(), - "expected !iree.list<> type"); + "expected !iree.list type"); + } + auto listElementType = listType.cast().getElementType(); + if (succeeded(parser.parseOptionalArrow())) { + // Use overridden type - required for variants only. + if (failed(parser.parseType(elementType))) { + return parser.emitError( + parser.getCurrentLocation(), + "expected an element type when specifying list access types"); + } + if (!ListType::canImplicitlyCast(listElementType, elementType)) { + return parser.emitError( + parser.getCurrentLocation(), + "list access types must match the same base type as the list element " + "type (when not variant)"); + } + } else { + // Use list element type as the result element type. + elementType = listElementType; } - elementType = listType.cast().getElementType(); return success(); } -static ParseResult parseListType(OpAsmParser &parser, Type &listType, - SmallVectorImpl &elementTypes) { - if (failed(parser.parseType(listType))) { +static void printListTypeGet(OpAsmPrinter &printer, Operation *, Type listType, + Type elementType) { + printer.printType(listType); + auto listElementType = listType.cast().getElementType(); + if (listElementType != elementType) { + printer.printArrowTypeList(ArrayRef{elementType}); + } +} + +static ParseResult parseListTypeSet(OpAsmParser &parser, Type &listType, + Type &elementType) { + Type leadingType; + if (failed(parser.parseType(leadingType))) { return parser.emitError(parser.getCurrentLocation(), - "expected !iree.list<> type"); + "expected element type or !iree.list type"); } - for (size_t i = 0; i < elementTypes.size(); ++i) { - elementTypes[i] = listType.cast().getElementType(); + if (succeeded(parser.parseOptionalArrow())) { + elementType = leadingType; + if (failed(parser.parseType(listType)) || !listType.isa()) { + return parser.emitError(parser.getCurrentLocation(), + "expected an !iree.list type"); + } + } else { + if (!leadingType.isa()) { + return parser.emitError(parser.getCurrentLocation(), + "expected an !iree.list type"); + } + listType = leadingType; + elementType = listType.cast().getElementType(); } return success(); } -static void printListType(OpAsmPrinter &printer, Operation *, Type listType, - Type elementType) { - printer.printType(listType); -} - -static void printListType(OpAsmPrinter &printer, Operation *, Type listType, - TypeRange elementTypes) { - printer.printType(listType); +static void printListTypeSet(OpAsmPrinter &printer, Operation *, Type listType, + Type elementType) { + auto listElementType = listType.cast().getElementType(); + if (listElementType != elementType) { + printer.printType(elementType); + printer.printArrowTypeList(ArrayRef{listType}); + } else { + printer.printType(listType); + } } static LogicalResult verifyListGetOp(ListGetOp &op) { auto listType = op.list().getType().cast(); auto elementType = listType.getElementType(); auto resultType = op.result().getType(); - if (resultType != elementType) { + if (!ListType::canImplicitlyCast(elementType, resultType)) { return op.emitError() << "list contains " << elementType << " and cannot be accessed as " << resultType; } @@ -202,7 +241,7 @@ static LogicalResult verifyListSetOp(ListSetOp &op) { auto listType = op.list().getType().cast(); auto elementType = listType.getElementType(); auto valueType = op.value().getType(); - if (valueType != elementType) { + if (!ListType::canImplicitlyCast(valueType, elementType)) { return op.emitError() << "list contains " << elementType << " and cannot be mutated as " << valueType; } diff --git a/iree/compiler/Dialect/IREE/IR/IREEOps.td b/iree/compiler/Dialect/IREE/IR/IREEOps.td index 7bf01022f515..0d13d3823372 100644 --- a/iree/compiler/Dialect/IREE/IR/IREEOps.td +++ b/iree/compiler/Dialect/IREE/IR/IREEOps.td @@ -187,6 +187,22 @@ def IREE_UnreachableOp : IREE_Op<"unreachable", [NoSideEffect, Terminator]> { // new SSA values. This would make optimizing the list usage much easier and // enable hoisting/CSE of list access/mutation. +def IREE_ListCreateOp : IREE_PureOp<"list.create"> { + let summary = [{creates a new empty list}]; + let description = [{ + Creates a new empty list with an optional initial capacity. + }]; + + let arguments = (ins + Optional:$initial_capacity + ); + let results = (outs + AnyList:$result + ); + + let assemblyFormat = "($initial_capacity^)? attr-dict `:` type($result)"; +} + def IREE_ListSizeOp : IREE_Op<"list.size"> { let summary = [{the size of the list in elements}]; let description = [{ @@ -234,7 +250,7 @@ def IREE_ListGetOp : IREE_Op<"list.get"> { AnyType:$result ); - let assemblyFormat = "operands attr-dict `:` custom(type($list), type($result))"; + let assemblyFormat = "$list `[` $index `]` attr-dict `:` custom(type($list), type($result))"; let verifier = [{ return verify$cppClass(*this); }]; } @@ -251,7 +267,7 @@ def IREE_ListSetOp : IREE_Op<"list.set"> { AnyType:$value ); - let assemblyFormat = "operands attr-dict `:` custom(type($list), type($value))"; + let assemblyFormat = "$list `[` $index `]` `,` $value attr-dict `:` custom(type($list), type($value))"; let verifier = [{ return verify$cppClass(*this); }]; } diff --git a/iree/compiler/Dialect/IREE/IR/IREETypes.cpp b/iree/compiler/Dialect/IREE/IR/IREETypes.cpp index 19b0e3745da8..ed748a3b9d3b 100644 --- a/iree/compiler/Dialect/IREE/IR/IREETypes.cpp +++ b/iree/compiler/Dialect/IREE/IR/IREETypes.cpp @@ -52,6 +52,13 @@ struct ListTypeStorage : public TypeStorage { // static bool ListType::isCompatible(Type type) { return true; } +// static +bool ListType::canImplicitlyCast(Type from, Type to) { + if (from.isa() || to.isa()) return true; + if (from.isa() && to.isa()) return true; + return from == to; +} + ListType ListType::get(Type elementType) { return Base::get(elementType.getContext(), elementType); } @@ -244,7 +251,7 @@ void excludeTiedOperandAndResultIndices( void IREEDialect::registerTypes() { addTypes(); + IREE::PtrType, IREE::VariantType>(); } } // namespace iree_compiler diff --git a/iree/compiler/Dialect/IREE/IR/IREETypes.h b/iree/compiler/Dialect/IREE/IR/IREETypes.h index 73e81305bdec..ae03c40ef37b 100644 --- a/iree/compiler/Dialect/IREE/IR/IREETypes.h +++ b/iree/compiler/Dialect/IREE/IR/IREETypes.h @@ -58,6 +58,12 @@ enum class StatusCode : int32_t { DoNotUseReservedForFutureExpansionUseDefaultInSwitchInstead_ = 20 }; +/// Placeholder for a variant type (`?`). +class VariantType : public Type::TypeBase { + public: + using Base::Base; +}; + /// A list containing an optional element type. class ListType : public Type::TypeBase { @@ -67,6 +73,10 @@ class ListType /// Returns true if the given type can be wrapped in a list. static bool isCompatible(Type type); + /// Returns true if |from| can be implicitly cast to |to| as part of a list + /// access operation. Example: tensor<*xf32> -> tensor<4xf32>. + static bool canImplicitlyCast(Type from, Type to); + /// Gets or creates a ListType with the provided element type. static ListType get(Type elementType); diff --git a/iree/compiler/Dialect/IREE/IR/test/BUILD b/iree/compiler/Dialect/IREE/IR/test/BUILD index aacb0e0f1def..3935fda86241 100644 --- a/iree/compiler/Dialect/IREE/IR/test/BUILD +++ b/iree/compiler/Dialect/IREE/IR/test/BUILD @@ -27,6 +27,7 @@ iree_lit_test_suite( [ "byte_buffer_ops.mlir", "do_not_optimize.mlir", + "list_ops.mlir", "parse_print.mlir", ], include = ["*.mlir"], diff --git a/iree/compiler/Dialect/IREE/IR/test/CMakeLists.txt b/iree/compiler/Dialect/IREE/IR/test/CMakeLists.txt index 34fbf727c1c8..63b007972a1f 100644 --- a/iree/compiler/Dialect/IREE/IR/test/CMakeLists.txt +++ b/iree/compiler/Dialect/IREE/IR/test/CMakeLists.txt @@ -16,6 +16,7 @@ iree_lit_test_suite( SRCS "byte_buffer_ops.mlir" "do_not_optimize.mlir" + "list_ops.mlir" "parse_print.mlir" DATA iree::tools::IreeFileCheck diff --git a/iree/compiler/Dialect/IREE/IR/test/list_ops.mlir b/iree/compiler/Dialect/IREE/IR/test/list_ops.mlir new file mode 100644 index 000000000000..b3f046bb5363 --- /dev/null +++ b/iree/compiler/Dialect/IREE/IR/test/list_ops.mlir @@ -0,0 +1,85 @@ +// RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s + +// CHECK-LABEL: @list_init_ops +func @list_init_ops() { + // CHECK: %[[CAPACITY:.+]] = constant 5 + %capacity = constant 5 : index + // CHECK: = iree.list.create %[[CAPACITY]] : !iree.list + %list_initial_capacity = iree.list.create %capacity : !iree.list + + // CHECK: %[[LIST:.+]] = iree.list.create : !iree.list + %list = iree.list.create : !iree.list + + // CHECK: %[[NEW_SIZE:.+]] = constant 100 + %new_size = constant 100 : index + // CHECK: iree.list.resize %[[LIST]], %[[NEW_SIZE]] : !iree.list + iree.list.resize %list, %new_size : !iree.list + + return +} + +// ----- + +// CHECK-LABEL: @list_access +// CHECK-SAME: (%[[LIST:.+]]: !iree.list) +func @list_access(%list: !iree.list) { + %c10 = constant 10 : index + + // CHECK: = iree.list.get %[[LIST]][%c10] : !iree.list + %0 = iree.list.get %list[%c10] : !iree.list + // CHECK: = iree.list.get %[[LIST]][%c10] : !iree.list + %1 = iree.list.get %list[%c10] : !iree.list -> i32 + + // CHECK: %[[NEW_VALUE:.+]] = constant 100 : i32 + %new_value = constant 100 : i32 + // CHECK: iree.list.set %[[LIST]][%c10], %[[NEW_VALUE]] : !iree.list + iree.list.set %list[%c10], %new_value : !iree.list + + return +} + +// ----- + +// CHECK-LABEL: @list_access_tensor +// CHECK-SAME: (%[[LIST:.+]]: !iree.list>) +func @list_access_tensor(%list: !iree.list>) { + %c10 = constant 10 : index + + // CHECK: = iree.list.get %[[LIST]][%c10] : !iree.list> -> tensor + %0 = iree.list.get %list[%c10] : !iree.list> -> tensor + + // CHECK: %[[NEW_VALUE:.+]] = constant dense<1> : tensor<5xi32> + %new_value = constant dense<1> : tensor<5xi32> + // CHECK: iree.list.set %[[LIST]][%c10], %[[NEW_VALUE]] : tensor<5xi32> -> !iree.list> + iree.list.set %list[%c10], %new_value : tensor<5xi32> -> !iree.list> + + return +} + +// ----- + +// CHECK-LABEL: @list_access_variant +// CHECK-SAME: (%[[LIST:.+]]: !iree.list) +func @list_access_variant(%list: !iree.list) { + %c10 = constant 10 : index + %c11 = constant 11 : index + + // CHECK: = iree.list.get %[[LIST]][%c10] : !iree.list -> i32 + %0 = iree.list.get %list[%c10] : !iree.list -> i32 + + // CHECK: %[[NEW_I32_VALUE:.+]] = constant 100 : i32 + %new_i32_value = constant 100 : i32 + // CHECK: iree.list.set %[[LIST]][%c10], %[[NEW_I32_VALUE]] : i32 -> !iree.list + iree.list.set %list[%c10], %new_i32_value : i32 -> !iree.list + + // CHECK: = iree.list.get %[[LIST]][%c11] : !iree.list -> tensor<5xf32> + %1 = iree.list.get %list[%c11] : !iree.list -> tensor<5xf32> + + // CHECK: %[[NEW_TENSOR_VALUE:.+]] = constant dense<1> : tensor<5xi32> + %new_tensor_value = constant dense<1> : tensor<5xi32> + // CHECK: iree.list.set %[[LIST]][%c11], %[[NEW_TENSOR_VALUE]] : tensor<5xi32> -> !iree.list + iree.list.set %list[%c11], %new_tensor_value : tensor<5xi32> -> !iree.list + + return +} + diff --git a/iree/compiler/Dialect/VM/Conversion/IREEToVM/ConvertIREEToVM.cpp b/iree/compiler/Dialect/VM/Conversion/IREEToVM/ConvertIREEToVM.cpp index 8f74279ecd1f..f49e8f26ac73 100644 --- a/iree/compiler/Dialect/VM/Conversion/IREEToVM/ConvertIREEToVM.cpp +++ b/iree/compiler/Dialect/VM/Conversion/IREEToVM/ConvertIREEToVM.cpp @@ -87,6 +87,19 @@ class UnreachableOpConversion // Lists //===----------------------------------------------------------------------===// +class ListCreateOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::ListCreateOp srcOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + IREE::ListCreateOpAdaptor srcOperands(operands); + rewriter.replaceOpWithNewOp( + srcOp, typeConverter->convertType(srcOp.result().getType()), + srcOperands.initial_capacity()); + return success(); + } +}; + class ListSizeOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( @@ -169,13 +182,19 @@ void populateIREEToVMPatterns(MLIRContext *context, typeConverter.addConversion( [&typeConverter](IREE::ListType type) -> Optional { - auto elementType = typeConverter.convertType(type.getElementType()); + Type elementType; + if (type.getElementType().isa()) { + elementType = IREE::VM::OpaqueType::get(type.getContext()); + } else { + elementType = typeConverter.convertType(type.getElementType()); + } if (!elementType) return llvm::None; return IREE::VM::RefType::get(IREE::VM::ListType::get(elementType)); }); - patterns.insert(typeConverter, - context); + patterns + .insert( + typeConverter, context); } } // namespace iree_compiler diff --git a/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/BUILD b/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/BUILD index 26a05ad4bb48..16ed188d57ad 100644 --- a/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/BUILD +++ b/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/BUILD @@ -27,6 +27,7 @@ iree_lit_test_suite( [ "byte_buffer_ops.mlir", "hint_ops.mlir", + "list_ops.mlir", ], include = ["*.mlir"], ), diff --git a/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/CMakeLists.txt b/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/CMakeLists.txt index 81db6954bef8..440353019b31 100644 --- a/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/CMakeLists.txt +++ b/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/CMakeLists.txt @@ -16,6 +16,7 @@ iree_lit_test_suite( SRCS "byte_buffer_ops.mlir" "hint_ops.mlir" + "list_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/list_ops.mlir b/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/list_ops.mlir new file mode 100644 index 000000000000..ea4b92d367b8 --- /dev/null +++ b/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/list_ops.mlir @@ -0,0 +1,37 @@ +// RUN: iree-opt -split-input-file -iree-vm-conversion %s | IreeFileCheck %s + +// CHECK-LABEL: @list_ops +module @list_ops { module { + // CHECK: vm.func @my_fn + // CHECK-SAME: (%[[BUFFER_VIEW:.+]]: !vm.ref) + func @my_fn(%buffer_view: !hal.buffer_view) { + // CHECK: %[[CAPACITY:.+]] = vm.const.i32 5 + %capacity = constant 5 : index + // CHECK: %[[LIST:.+]] = vm.list.alloc %[[CAPACITY]] : (i32) -> !vm.list + %list = iree.list.create %capacity : !iree.list + + // CHECK: %[[NEW_SIZE:.+]] = vm.const.i32 100 + %new_size = constant 100 : index + // CHECK: vm.list.resize %[[LIST]], %[[NEW_SIZE]] : (!vm.list, i32) + iree.list.resize %list, %new_size : !iree.list + + %c10 = constant 10 : index + %c11 = constant 11 : index + + // CHECK: = vm.list.get.i32 %[[LIST]], %c10 : (!vm.list, i32) -> i32 + %0 = iree.list.get %list[%c10] : !iree.list -> i32 + + // CHECK: %[[NEW_I32_VALUE:.+]] = vm.const.i32 101 + %new_i32_value = constant 101 : i32 + // CHECK: vm.list.set.i32 %[[LIST]], %c10, %[[NEW_I32_VALUE]] : (!vm.list, i32, i32) + iree.list.set %list[%c10], %new_i32_value : i32 -> !iree.list + + // CHECK: = vm.list.get.ref %[[LIST]], %c11 : (!vm.list, i32) -> !vm.ref + %1 = iree.list.get %list[%c11] : !iree.list -> !hal.buffer_view + + // CHECK: vm.list.set.ref %[[LIST]], %c11, %[[BUFFER_VIEW]] : (!vm.list, i32, !vm.ref) + iree.list.set %list[%c11], %buffer_view : !hal.buffer_view -> !iree.list + + return + } +} } diff --git a/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp b/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp index f5dac0305c55..a495ebfd807c 100644 --- a/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp +++ b/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp @@ -117,7 +117,8 @@ static Value createStringTableValue(Location loc, StringAttr attrValue, return rewriter.create( loc, IREE::VM::RefType::get(IREE::ByteBufferType::get(rewriter.getContext())), - rewriter.getStringAttr(safeIdentifier), utf8Bytes); + rewriter.getStringAttr(safeIdentifier), utf8Bytes, + /*alignment=*/rewriter.getI64IntegerAttr(1)); } size_t getSegmentSpanSize(Type spanType) { diff --git a/iree/compiler/Dialect/VM/Conversion/TypeConverter.cpp b/iree/compiler/Dialect/VM/Conversion/TypeConverter.cpp index fccf14c56316..0a2797924d56 100644 --- a/iree/compiler/Dialect/VM/Conversion/TypeConverter.cpp +++ b/iree/compiler/Dialect/VM/Conversion/TypeConverter.cpp @@ -31,8 +31,14 @@ namespace VM { TypeConverter::TypeConverter(TargetOptions targetOptions) : targetOptions_(targetOptions) { + // Variant means opaque in VM. + addConversion([](IREE::VariantType type) { + return IREE::VM::OpaqueType::get(type.getContext()); + }); + // All ref types are passed through unmodified. addConversion([](IREE::VM::RefType type) { return type; }); + // Wrap ref types. addConversion([](Type type) -> Optional { if (RefType::isCompatible(type)) { diff --git a/iree/compiler/Dialect/VM/IR/VMBase.td b/iree/compiler/Dialect/VM/IR/VMBase.td index 3f404a5bab4a..bbf9a10a69b7 100644 --- a/iree/compiler/Dialect/VM/IR/VMBase.td +++ b/iree/compiler/Dialect/VM/IR/VMBase.td @@ -590,9 +590,12 @@ def VM_AnyList : DialectType< class VM_ListOf : Type().getObjectType().isa()">, - SubstLeaves<"$_self", - "$_self.cast().getObjectType().cast().getElementType()", - type.predicate> + Or<[ + CPred<"$_self.cast().getObjectType().cast().getElementType().isa()">, + SubstLeaves<"$_self", + "$_self.cast().getObjectType().cast().getElementType()", + type.predicate> + ]>, ]>, "list<" # type.summary # ">"> { // Set the builder call if the base type has a builder call. string builderCall = !if(!empty(type.builderCall), diff --git a/iree/compiler/Dialect/VM/IR/VMDialect.cpp b/iree/compiler/Dialect/VM/IR/VMDialect.cpp index 38d5c1416a61..d2d796dce419 100644 --- a/iree/compiler/Dialect/VM/IR/VMDialect.cpp +++ b/iree/compiler/Dialect/VM/IR/VMDialect.cpp @@ -236,7 +236,12 @@ Type VMDialect::parseType(DialectAsmParser &parser) const { << "malformed list type '" << parser.getFullSymbolSpec() << "'"; return Type(); } - auto elementType = mlir::parseType(spec, getContext()); + Type elementType; + if (spec == "?") { + elementType = OpaqueType::get(getContext()); + } else { + elementType = mlir::parseType(spec, getContext()); + } if (!elementType) { parser.emitError(parser.getCurrentLocation()) << "invalid list element type specification: '" @@ -284,7 +289,13 @@ void VMDialect::printType(Type type, DialectAsmPrinter &os) const { } else if (type.isa()) { os << "opaque"; } else if (auto listType = type.dyn_cast()) { - os << "list<" << listType.getElementType() << ">"; + os << "list<"; + if (listType.getElementType().isa()) { + os << "?"; + } else { + os << listType.getElementType(); + } + os << ">"; } else { llvm_unreachable("unhandled VM type"); } diff --git a/iree/compiler/Dialect/VM/IR/VMOps.cpp b/iree/compiler/Dialect/VM/IR/VMOps.cpp index d85a8b3ca996..765330cb1d93 100644 --- a/iree/compiler/Dialect/VM/IR/VMOps.cpp +++ b/iree/compiler/Dialect/VM/IR/VMOps.cpp @@ -693,17 +693,19 @@ static LogicalResult verifyListGetRefOp(ListGetRefOp &op) { .cast(); auto elementType = listType.getElementType(); auto resultType = op.result().getType(); - if (elementType.isa() != - resultType.isa()) { - // Attempting to go between a primitive type and ref type. - return op.emitError() << "cannot convert between list type " << elementType - << " and result type " << resultType; - } else if (auto refType = elementType.dyn_cast()) { - if (!refType.getObjectType().isa() && - elementType != resultType) { - // List has a concrete type, verify it matches. - return op.emitError() << "list contains " << elementType - << " that cannot be accessed as " << resultType; + if (!elementType.isa()) { + if (elementType.isa() != + resultType.isa()) { + // Attempting to go between a primitive type and ref type. + return op.emitError() << "cannot convert between list type " + << elementType << " and result type " << resultType; + } else if (auto refType = elementType.dyn_cast()) { + if (!refType.getObjectType().isa() && + elementType != resultType) { + // List has a concrete type, verify it matches. + return op.emitError() << "list contains " << elementType + << " that cannot be accessed as " << resultType; + } } } return success(); @@ -717,17 +719,19 @@ static LogicalResult verifyListSetRefOp(ListSetRefOp &op) { .cast(); auto elementType = listType.getElementType(); auto valueType = op.value().getType(); - if (elementType.isa() != - valueType.isa()) { - // Attempting to go between a primitive type and ref type. - return op.emitError() << "cannot convert between list type " << elementType - << " and value type " << valueType; - } else if (auto refType = elementType.dyn_cast()) { - if (!refType.getObjectType().isa() && - elementType != valueType) { - // List has a concrete type, verify it matches. - return op.emitError() << "list contains " << elementType - << " that cannot be mutated as " << valueType; + if (!elementType.isa()) { + if (elementType.isa() != + valueType.isa()) { + // Attempting to go between a primitive type and ref type. + return op.emitError() << "cannot convert between list type " + << elementType << " and value type " << valueType; + } else if (auto refType = elementType.dyn_cast()) { + if (!refType.getObjectType().isa() && + elementType != valueType) { + // List has a concrete type, verify it matches. + return op.emitError() << "list contains " << elementType + << " that cannot be mutated as " << valueType; + } } } return success(); diff --git a/iree/compiler/Dialect/VM/IR/VMOps.td b/iree/compiler/Dialect/VM/IR/VMOps.td index 82a344febe7a..408f4e646e09 100644 --- a/iree/compiler/Dialect/VM/IR/VMOps.td +++ b/iree/compiler/Dialect/VM/IR/VMOps.td @@ -754,18 +754,26 @@ def VM_RodataOp : VM_Op<"rodata", [ value leaves the module. For example, returning rodata from an exported function must keep the data (possibly backed by mmap) valid for its entire lifetime. + + By default all rodata will be aligned in the final module output at a + 16-byte granularity. An optional alignment can be specified to override the + default for cases where larger or smaller alignments are needed. }]; let arguments = (ins StrAttr:$sym_name, ElementsAttr:$value, + OptionalAttr:$alignment, OptionalAttr:$ordinal ); let skipDefaultBuilders = 1; let builders = [ - OpBuilder<(ins "StringRef":$name, "ElementsAttr":$value, - CArg<"ArrayRef", "{}">:$attrs)>, + OpBuilder<(ins + "StringRef":$name, + "ElementsAttr":$value, + CArg<"ArrayRef", "{}">:$attrs + )>, ]; } @@ -810,12 +818,14 @@ def VM_RodataInlineOp : VM_PureOp<"rodata.inline", [ ]> { let summary = [{inlined constant rodata}]; let description = [{ - vm.rodata that can be embedded inline in functions. + vm.rodata that can be embedded inline in functions. See vm.rodata for more + information. }]; let arguments = (ins OptionalAttr:$name, - ElementsAttr:$value + ElementsAttr:$value, + OptionalAttr:$alignment ); let results = (outs diff --git a/iree/compiler/Dialect/VM/IR/VMTypes.cpp b/iree/compiler/Dialect/VM/IR/VMTypes.cpp index fb6d8ec1cce5..5963e5e6ce3d 100644 --- a/iree/compiler/Dialect/VM/IR/VMTypes.cpp +++ b/iree/compiler/Dialect/VM/IR/VMTypes.cpp @@ -56,7 +56,10 @@ struct ListTypeStorage : public TypeStorage { // static bool ListType::isCompatible(Type type) { - if (type.isa()) { + if (type.isa()) { + // Allow all types (variant). + return true; + } else if (type.isa()) { // Allow all ref types. return true; } else if (type.isIntOrFloat()) { diff --git a/iree/compiler/Dialect/VM/IR/test/list_ops.mlir b/iree/compiler/Dialect/VM/IR/test/list_ops.mlir index 8741558c63d7..de2b61aef263 100644 --- a/iree/compiler/Dialect/VM/IR/test/list_ops.mlir +++ b/iree/compiler/Dialect/VM/IR/test/list_ops.mlir @@ -28,7 +28,7 @@ vm.module @module { // Typed accessors for lists with i32 elements. vm.module @module { // CHECK: @list_i32 - vm.func @list_i32(%arg0 : !vm.list) { + vm.func @list_i32(%arg0: !vm.list) { %c100 = vm.const.i32 100 : i32 // CHECK: vm.list.get.i32 %arg0, %c100 : (!vm.list, i32) -> i32 @@ -42,7 +42,7 @@ vm.module @module { } // CHECK: @list_i8_coerce - vm.func @list_i8_coerce(%arg0 : !vm.list) { + vm.func @list_i8_coerce(%arg0: !vm.list) { %c100 = vm.const.i32 100 : i32 // CHECK: = vm.list.get.i32 %arg0, %c100 : (!vm.list, i32) -> i32 @@ -60,7 +60,7 @@ vm.module @module { // Typed accessors for lists with opaque ref elements. vm.module @module { // CHECK: @list_ref_any - vm.func @list_ref_any(%arg0 : !vm.list>) { + vm.func @list_ref_any(%arg0: !vm.list>) { %c100 = vm.const.i32 100 : i32 // CHECK: %ref = vm.list.get.ref %arg0, %c100 : (!vm.list>, i32) -> !vm.ref @@ -78,7 +78,7 @@ vm.module @module { // Typed accessors for lists with strongly-typed ref elements. vm.module @module { // CHECK: @list_ref_typed - vm.func @list_ref_typed(%arg0 : !vm.list>) { + vm.func @list_ref_typed(%arg0: !vm.list>) { %c100 = vm.const.i32 100 : i32 // CHECK: %ref = vm.list.get.ref %arg0, %c100 : (!vm.list>, i32) -> !vm.ref @@ -90,3 +90,37 @@ vm.module @module { vm.return } } + +// ----- + +// Variant access allows any type access. +vm.module @module { + // CHECK: @list_create_variant + vm.func @list_create_variant() { + %c42 = vm.const.i32 42 : i32 + // CHECK: %list = vm.list.alloc %c42 : (i32) -> !vm.list + %list = vm.list.alloc %c42 : (i32) -> !vm.list + + vm.return + } + + // CHECK: @list_access_variant + vm.func @list_access_variant(%arg0: !vm.list) { + %c100 = vm.const.i32 100 : i32 + %c101 = vm.const.i32 101 : i32 + + // CHECK: = vm.list.get.i32 %arg0, %c100 : (!vm.list, i32) -> i32 + %0 = vm.list.get.i32 %arg0, %c100 : (!vm.list, i32) -> i32 + + // CHECK: vm.list.set.i32 %arg0, %c100, %0 : (!vm.list, i32, i32) + vm.list.set.i32 %arg0, %c100, %0 : (!vm.list, i32, i32) + + // CHECK: %ref = vm.list.get.ref %arg0, %c101 : (!vm.list, i32) -> !vm.ref + %ref = vm.list.get.ref %arg0, %c101 : (!vm.list, i32) -> !vm.ref + + // CHECK: vm.list.set.ref %arg0, %c101, %ref : (!vm.list, i32, !vm.ref) + vm.list.set.ref %arg0, %c101, %ref : (!vm.list, i32, !vm.ref) + + vm.return + } +} diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp b/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp index 3e4cee0899a8..1da05094de1b 100644 --- a/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp +++ b/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp @@ -92,7 +92,19 @@ class V0BytecodeEncoder : public BytecodeEncoder { } LogicalResult encodeType(Type type) override { - int typeOrdinal = typeTable_->lookup(type); + // HACK: it'd be nice to remove the implicit ref wrapper hiding. + if (auto refType = type.dyn_cast()) { + if (refType.getObjectType().isa()) { + type = refType.getObjectType(); + } + } + auto it = typeTable_->find(type); + if (it == typeTable_->end()) { + return currentOp_->emitOpError() + << "type " << type + << " cannot be encoded; not registered in type table"; + } + int typeOrdinal = it->second; return writeUint32(typeOrdinal); } diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp b/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp index 821ba43ec517..a23e9d001335 100644 --- a/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp +++ b/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp @@ -73,9 +73,8 @@ static std::vector buildTypeTable(IREE::VM::ModuleOp moduleOp) { sstream.flush(); typeMap.try_emplace(type, str); if (auto listType = type.dyn_cast()) { - if (listType.getElementType()) { - tryInsertType(listType.getElementType()); - } + assert(listType.getElementType()); + tryInsertType(listType.getElementType()); } }; for (auto funcOp : moduleOp.getBlock().getOps()) { @@ -90,7 +89,7 @@ static std::vector buildTypeTable(IREE::VM::ModuleOp moduleOp) { for (const auto &typeString : typeMap) { table.push_back(TypeDef{typeString.first, typeString.second}); } - llvm::sort( + llvm::stable_sort( table, +[](const TypeDef &lhs, const TypeDef &rhs) { // Always sort builtins above custom types. if (lhs.full_name[0] != '!' && rhs.full_name[0] == '!') { @@ -276,6 +275,11 @@ static iree_vm_FunctionSignatureDef_ref_t makeInternalFunctionSignatureDef( static LogicalResult buildFlatBufferModule(BytecodeTargetOptions targetOptions, IREE::VM::ModuleOp moduleOp, FlatbufferBuilder &fbb) { + // Start the buffer so that we can begin recording data prior to the root + // table (which we do at the very end). This does not change the layout of the + // file and is only used to prime the flatcc builder. + iree_vm_BytecodeModuleDef_start_as_root(fbb); + SymbolTable symbolTable(moduleOp); if (!moduleOp.ordinal_counts().hasValue()) { return moduleOp.emitError() << "ordinal_counts attribute not found. The " @@ -316,9 +320,20 @@ static LogicalResult buildFlatBufferModule(BytecodeTargetOptions targetOptions, // layout planning by preserving the order in the IR is useful. SmallVector rodataContentRefs; rodataContentRefs.reserve(rodataOps.size()); + + // All constants are defaulted to 16-byte aligned as that is the maximum + // (reasonable) alignment of all data types on all platforms. This can be + // overridden by creators of the rodata with the `alignment` attribute. + static constexpr int kDefaultRodataAlignment = 16; + for (auto rodataOp : llvm::reverse(rodataOps)) { + size_t alignment = + rodataOp.alignment() + ? static_cast(rodataOp.alignment().getValue()) + : 0; + if (alignment == 0) alignment = kDefaultRodataAlignment; auto rodataRef = - serializeConstant(rodataOp.getLoc(), rodataOp.value(), fbb); + serializeConstant(rodataOp.getLoc(), rodataOp.value(), alignment, fbb); if (!rodataRef) { return rodataOp.emitOpError() << "failed to encode"; } @@ -462,7 +477,6 @@ static LogicalResult buildFlatBufferModule(BytecodeTargetOptions targetOptions, auto moduleNameRef = fbb.createString( moduleOp.sym_name().empty() ? "module" : moduleOp.sym_name()); - iree_vm_BytecodeModuleDef_start_as_root(fbb); iree_vm_BytecodeModuleDef_name_add(fbb, moduleNameRef); iree_vm_BytecodeModuleDef_types_add(fbb, typesRef); iree_vm_BytecodeModuleDef_imported_functions_add(fbb, importFuncsRef); diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.cpp b/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.cpp index b33908bdfe6f..91a5fb26170a 100644 --- a/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.cpp +++ b/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.cpp @@ -26,11 +26,11 @@ namespace VM { // TODO(benvanik): switch to LLVM's BinaryStreamWriter to handle endianness. static flatbuffers_uint8_vec_ref_t serializeConstantI8Array( - DenseIntElementsAttr attr, FlatbufferBuilder &fbb) { + DenseIntElementsAttr attr, size_t alignment, FlatbufferBuilder &fbb) { // vm.rodata and other very large constants end up as this; since i8 is i8 // everywhere (endianness doesn't matter when you have one byte :) we can // directly access the data and memcpy. - flatbuffers_uint8_vec_start(fbb); + flatcc_builder_start_vector(fbb, 1, alignment, FLATBUFFERS_COUNT_MAX(1)); uint8_t *bytePtr = flatbuffers_uint8_vec_extend(fbb, attr.getNumElements() * sizeof(int8_t)); if (attr.isSplat()) { @@ -47,8 +47,8 @@ static flatbuffers_uint8_vec_ref_t serializeConstantI8Array( } static flatbuffers_uint8_vec_ref_t serializeConstantI16Array( - DenseIntElementsAttr attr, FlatbufferBuilder &fbb) { - flatbuffers_uint8_vec_start(fbb); + DenseIntElementsAttr attr, size_t alignment, FlatbufferBuilder &fbb) { + flatcc_builder_start_vector(fbb, 1, alignment, FLATBUFFERS_COUNT_MAX(1)); uint8_t *bytePtr = flatbuffers_uint8_vec_extend( fbb, attr.getNumElements() * sizeof(int16_t)); uint16_t *nativePtr = reinterpret_cast(bytePtr); @@ -59,8 +59,8 @@ static flatbuffers_uint8_vec_ref_t serializeConstantI16Array( } static flatbuffers_uint8_vec_ref_t serializeConstantI32Array( - DenseIntElementsAttr attr, FlatbufferBuilder &fbb) { - flatbuffers_uint8_vec_start(fbb); + DenseIntElementsAttr attr, size_t alignment, FlatbufferBuilder &fbb) { + flatcc_builder_start_vector(fbb, 1, alignment, FLATBUFFERS_COUNT_MAX(1)); uint8_t *bytePtr = flatbuffers_uint8_vec_extend( fbb, attr.getNumElements() * sizeof(int32_t)); uint32_t *nativePtr = reinterpret_cast(bytePtr); @@ -71,8 +71,8 @@ static flatbuffers_uint8_vec_ref_t serializeConstantI32Array( } static flatbuffers_uint8_vec_ref_t serializeConstantI64Array( - DenseIntElementsAttr attr, FlatbufferBuilder &fbb) { - flatbuffers_uint8_vec_start(fbb); + DenseIntElementsAttr attr, size_t alignment, FlatbufferBuilder &fbb) { + flatcc_builder_start_vector(fbb, 1, alignment, FLATBUFFERS_COUNT_MAX(1)); uint8_t *bytePtr = flatbuffers_uint8_vec_extend( fbb, attr.getNumElements() * sizeof(int64_t)); uint64_t *nativePtr = reinterpret_cast(bytePtr); @@ -83,8 +83,8 @@ static flatbuffers_uint8_vec_ref_t serializeConstantI64Array( } static flatbuffers_uint8_vec_ref_t serializeConstantF32Array( - DenseFPElementsAttr attr, FlatbufferBuilder &fbb) { - flatbuffers_uint8_vec_start(fbb); + DenseFPElementsAttr attr, size_t alignment, FlatbufferBuilder &fbb) { + flatcc_builder_start_vector(fbb, 1, alignment, FLATBUFFERS_COUNT_MAX(1)); uint8_t *bytePtr = flatbuffers_uint8_vec_extend(fbb, attr.getNumElements() * sizeof(float)); float *nativePtr = reinterpret_cast(bytePtr); @@ -95,8 +95,8 @@ static flatbuffers_uint8_vec_ref_t serializeConstantF32Array( } static flatbuffers_uint8_vec_ref_t serializeConstantF64Array( - DenseFPElementsAttr attr, FlatbufferBuilder &fbb) { - flatbuffers_uint8_vec_start(fbb); + DenseFPElementsAttr attr, size_t alignment, FlatbufferBuilder &fbb) { + flatcc_builder_start_vector(fbb, 1, alignment, FLATBUFFERS_COUNT_MAX(1)); uint8_t *bytePtr = flatbuffers_uint8_vec_extend(fbb, attr.getNumElements() * sizeof(double)); double *nativePtr = reinterpret_cast(bytePtr); @@ -107,8 +107,8 @@ static flatbuffers_uint8_vec_ref_t serializeConstantF64Array( } static flatbuffers_uint8_vec_ref_t serializeConstantF16Array( - DenseFPElementsAttr attr, FlatbufferBuilder &fbb) { - flatbuffers_uint8_vec_start(fbb); + DenseFPElementsAttr attr, size_t alignment, FlatbufferBuilder &fbb) { + flatcc_builder_start_vector(fbb, 1, alignment, FLATBUFFERS_COUNT_MAX(1)); uint8_t *bytePtr = flatbuffers_uint8_vec_extend( fbb, attr.getNumElements() * sizeof(uint16_t)); uint16_t *nativePtr = reinterpret_cast(bytePtr); @@ -121,17 +121,18 @@ static flatbuffers_uint8_vec_ref_t serializeConstantF16Array( flatbuffers_uint8_vec_ref_t serializeConstant(Location loc, ElementsAttr elementsAttr, + size_t alignment, FlatbufferBuilder &fbb) { if (auto attr = elementsAttr.dyn_cast()) { switch (attr.getType().getElementTypeBitWidth()) { case 8: - return serializeConstantI8Array(attr, fbb); + return serializeConstantI8Array(attr, alignment, fbb); case 16: - return serializeConstantI16Array(attr, fbb); + return serializeConstantI16Array(attr, alignment, fbb); case 32: - return serializeConstantI32Array(attr, fbb); + return serializeConstantI32Array(attr, alignment, fbb); case 64: - return serializeConstantI64Array(attr, fbb); + return serializeConstantI64Array(attr, alignment, fbb); default: emitError(loc) << "unhandled element bitwidth " << attr.getType().getElementTypeBitWidth(); @@ -140,11 +141,11 @@ flatbuffers_uint8_vec_ref_t serializeConstant(Location loc, } else if (auto attr = elementsAttr.dyn_cast()) { switch (attr.getType().getElementTypeBitWidth()) { case 16: - return serializeConstantF16Array(attr, fbb); + return serializeConstantF16Array(attr, alignment, fbb); case 32: - return serializeConstantF32Array(attr, fbb); + return serializeConstantF32Array(attr, alignment, fbb); case 64: - return serializeConstantF64Array(attr, fbb); + return serializeConstantF64Array(attr, alignment, fbb); default: emitError(loc) << "unhandled element bitwidth " << attr.getType().getElementTypeBitWidth(); diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.h b/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.h index 56471a633fff..94dca8186ae7 100644 --- a/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.h +++ b/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.h @@ -28,6 +28,7 @@ namespace VM { // Serializes a constant attribute to the FlatBuffer as a binary blob. flatbuffers_uint8_vec_ref_t serializeConstant(Location loc, ElementsAttr elementsAttr, + size_t alignment, FlatbufferBuilder &fbb); } // namespace VM diff --git a/iree/compiler/Dialect/VM/Transforms/HoistInlinedRodata.cpp b/iree/compiler/Dialect/VM/Transforms/HoistInlinedRodata.cpp index 240a507f5004..7aa741dd3bd0 100644 --- a/iree/compiler/Dialect/VM/Transforms/HoistInlinedRodata.cpp +++ b/iree/compiler/Dialect/VM/Transforms/HoistInlinedRodata.cpp @@ -60,6 +60,9 @@ class HoistInlinedRodataPass auto rodataOp = OpBuilder(moduleOp.getContext()) .create(inlineOp.getLoc(), name, inlineOp.value()); + if (inlineOp.alignmentAttr()) { + rodataOp.alignmentAttr(inlineOp.alignmentAttr()); + } moduleSymbolTable.insert(rodataOp, moduleBuilder.getInsertionPoint()); rodataOp.setPrivate(); replaceInlineOpWithRodataRef(inlineOp, rodataOp); diff --git a/iree/compiler/Utils/FlatbufferUtils.cpp b/iree/compiler/Utils/FlatbufferUtils.cpp index e98167b75d40..90d03be0f1d7 100644 --- a/iree/compiler/Utils/FlatbufferUtils.cpp +++ b/iree/compiler/Utils/FlatbufferUtils.cpp @@ -44,8 +44,8 @@ FlatbufferBuilder::FlatbufferBuilder() { flatcc_builder_init(&builder); } FlatbufferBuilder::~FlatbufferBuilder() { flatcc_builder_clear(&builder); } flatbuffers_uint8_vec_ref_t FlatbufferBuilder::streamUint8Vec( - std::function fn) { - flatbuffers_uint8_vec_start(*this); + std::function fn, size_t alignment) { + flatcc_builder_start_vector(*this, 1, alignment, FLATBUFFERS_COUNT_MAX(1)); raw_flatbuffer_uint8_vec_ostream stream(*this); if (!fn(stream)) { return 0; diff --git a/iree/compiler/Utils/FlatbufferUtils.h b/iree/compiler/Utils/FlatbufferUtils.h index 524cf63f7725..783bffb910ed 100644 --- a/iree/compiler/Utils/FlatbufferUtils.h +++ b/iree/compiler/Utils/FlatbufferUtils.h @@ -108,7 +108,7 @@ class FlatbufferBuilder { // my_type_uint8_vec_field_add(builder, ref); // use vec reference // ... flatbuffers_uint8_vec_ref_t streamUint8Vec( - std::function fn); + std::function fn, size_t alignment = 16); // Captures the current contents of the flatbuffer builder and returns them // as a shaped `vector` dense attr. The builder is left unmodified. diff --git a/iree/hal/buffer_view.h b/iree/hal/buffer_view.h index fefe38375078..6d849feccbb4 100644 --- a/iree/hal/buffer_view.h +++ b/iree/hal/buffer_view.h @@ -147,6 +147,11 @@ iree_hal_buffer_view_release(iree_hal_buffer_view_t* buffer_view); // Returns the buffer underlying the buffer view. // The caller must retain the returned buffer if they want to continue using it. +// +// NOTE: the returned buffer length will almost always be larger than the valid +// bytes representing this buffer view due to padding. Always query the actual +// valid length with iree_hal_buffer_view_byte_length instead of assuming the +// buffer is already clamped. IREE_API_EXPORT iree_hal_buffer_t* IREE_API_CALL iree_hal_buffer_view_buffer(const iree_hal_buffer_view_t* buffer_view); diff --git a/iree/hal/local/loaders/BUILD b/iree/hal/local/loaders/BUILD index 618881dac987..fa4e1f10794e 100644 --- a/iree/hal/local/loaders/BUILD +++ b/iree/hal/local/loaders/BUILD @@ -42,6 +42,21 @@ cc_library( ], ) +cc_library( + name = "static_library_loader", + srcs = ["static_library_loader.c"], + hdrs = ["static_library_loader.h"], + defines = [ + "IREE_HAL_HAVE_STATIC_LIBRARY_LOADER=1", + ], + deps = [ + "//iree/base:api", + "//iree/base:tracing", + "//iree/hal:api", + "//iree/hal/local", + ], +) + cc_library( name = "system_library_loader", srcs = ["system_library_loader.c"], diff --git a/iree/hal/local/loaders/CMakeLists.txt b/iree/hal/local/loaders/CMakeLists.txt index d4168d44b05d..f440672e1dc7 100644 --- a/iree/hal/local/loaders/CMakeLists.txt +++ b/iree/hal/local/loaders/CMakeLists.txt @@ -31,6 +31,23 @@ iree_cc_library( PUBLIC ) +iree_cc_library( + NAME + static_library_loader + HDRS + "static_library_loader.h" + SRCS + "static_library_loader.c" + DEPS + iree::base::api + iree::base::tracing + iree::hal::api + iree::hal::local + DEFINES + "IREE_HAL_HAVE_STATIC_LIBRARY_LOADER=1" + PUBLIC +) + iree_cc_library( NAME system_library_loader diff --git a/iree/hal/local/loaders/static_library_loader.c b/iree/hal/local/loaders/static_library_loader.c new file mode 100644 index 000000000000..51572baff630 --- /dev/null +++ b/iree/hal/local/loaders/static_library_loader.c @@ -0,0 +1,250 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/hal/local/loaders/static_library_loader.h" + +#include "iree/base/tracing.h" +#include "iree/hal/local/local_executable.h" + +//===----------------------------------------------------------------------===// +// iree_hal_static_executable_t +//===----------------------------------------------------------------------===// + +typedef struct { + iree_hal_local_executable_t base; + + // Name used for the file field in tracy and debuggers. + iree_string_view_t identifier; + + union { + const iree_hal_executable_library_header_t* header; + const iree_hal_executable_library_v0_t* v0; + } library; +} iree_hal_static_executable_t; + +static const iree_hal_local_executable_vtable_t + iree_hal_static_executable_vtable; + +static iree_status_t iree_hal_static_executable_create( + const iree_hal_executable_library_header_t* library_header, + iree_host_size_t executable_layout_count, + iree_hal_executable_layout_t* const* executable_layouts, + iree_allocator_t host_allocator, iree_hal_executable_t** out_executable) { + IREE_ASSERT_ARGUMENT(library_header); + IREE_ASSERT_ARGUMENT(!executable_layout_count || executable_layouts); + IREE_ASSERT_ARGUMENT(out_executable); + *out_executable = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_static_executable_t* executable = NULL; + iree_host_size_t total_size = + sizeof(*executable) + + executable_layout_count * sizeof(iree_hal_local_executable_layout_t); + iree_status_t status = + iree_allocator_malloc(host_allocator, total_size, (void**)&executable); + if (iree_status_is_ok(status)) { + iree_hal_local_executable_layout_t** executable_layouts_ptr = + (iree_hal_local_executable_layout_t**)(((uint8_t*)executable) + + sizeof(*executable)); + iree_hal_local_executable_initialize( + &iree_hal_static_executable_vtable, executable_layout_count, + executable_layouts, executable_layouts_ptr, host_allocator, + &executable->base); + executable->library.header = library_header; + executable->identifier = iree_make_cstring_view(library_header->name); + *out_executable = (iree_hal_executable_t*)executable; + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_static_executable_destroy( + iree_hal_executable_t* base_executable) { + iree_hal_static_executable_t* executable = + (iree_hal_static_executable_t*)base_executable; + iree_allocator_t host_allocator = executable->base.host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_local_executable_deinitialize( + (iree_hal_local_executable_t*)base_executable); + iree_allocator_free(host_allocator, executable); + + IREE_TRACE_ZONE_END(z0); +} + +static iree_status_t iree_hal_static_executable_issue_call( + iree_hal_local_executable_t* base_executable, iree_host_size_t ordinal, + const iree_hal_executable_dispatch_state_v0_t* dispatch_state, + const iree_hal_vec3_t* workgroup_id) { + iree_hal_static_executable_t* executable = + (iree_hal_static_executable_t*)base_executable; + const iree_hal_executable_library_v0_t* library = executable->library.v0; + + if (IREE_UNLIKELY(ordinal >= library->entry_point_count)) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "entry point ordinal out of bounds"); + } + +#if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION + iree_string_view_t entry_point_name = iree_string_view_empty(); + if (library->entry_point_names != NULL) { + entry_point_name = + iree_make_cstring_view(library->entry_point_names[ordinal]); + } + if (iree_string_view_is_empty(entry_point_name)) { + entry_point_name = iree_make_cstring_view("unknown_dylib_call"); + } + IREE_TRACE_ZONE_BEGIN_EXTERNAL( + z0, executable->identifier.data, executable->identifier.size, ordinal, + entry_point_name.data, entry_point_name.size, NULL, 0); +#endif // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION + + int ret = library->entry_points[ordinal](dispatch_state, workgroup_id); + + IREE_TRACE_ZONE_END(z0); + + return ret == 0 ? iree_ok_status() + : iree_make_status( + IREE_STATUS_INTERNAL, + "executable entry point returned catastrophic error %d", + ret); +} + +static const iree_hal_local_executable_vtable_t + iree_hal_static_executable_vtable = { + .base = + { + .destroy = iree_hal_static_executable_destroy, + }, + .issue_call = iree_hal_static_executable_issue_call, +}; + +//===----------------------------------------------------------------------===// +// iree_hal_static_library_loader_t +//===----------------------------------------------------------------------===// + +typedef struct { + iree_hal_executable_loader_t base; + iree_allocator_t host_allocator; + iree_host_size_t library_count; + iree_hal_executable_library_header_t* const libraries[]; +} iree_hal_static_library_loader_t; + +static const iree_hal_executable_loader_vtable_t + iree_hal_static_library_loader_vtable; + +iree_status_t iree_hal_static_library_loader_create( + iree_host_size_t library_count, + const iree_hal_executable_library_header_t* const* libraries, + iree_allocator_t host_allocator, + iree_hal_executable_loader_t** out_executable_loader) { + IREE_ASSERT_ARGUMENT(out_executable_loader); + *out_executable_loader = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + // Verify the libraries provided all match our expected version. + // It's rare they won't, however static libraries generated with a newer + // version of the IREE compiler that are then linked with an older version of + // the runtime are difficult to spot otherwise. + for (iree_host_size_t i = 0; i < library_count; ++i) { + if (libraries[i]->version > IREE_HAL_EXECUTABLE_LIBRARY_LATEST_VERSION) { + IREE_TRACE_ZONE_END(z0); + return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, + "executable does not support this version of the " + "runtime (executable: %d, runtime: %d)", + libraries[i]->version, + IREE_HAL_EXECUTABLE_LIBRARY_LATEST_VERSION); + } + } + + iree_hal_static_library_loader_t* executable_loader = NULL; + iree_host_size_t total_size = + sizeof(*executable_loader) + + sizeof(executable_loader->libraries[0]) * library_count; + iree_status_t status = iree_allocator_malloc(host_allocator, total_size, + (void**)&executable_loader); + if (iree_status_is_ok(status)) { + iree_hal_executable_loader_initialize( + &iree_hal_static_library_loader_vtable, &executable_loader->base); + executable_loader->host_allocator = host_allocator; + executable_loader->library_count = library_count; + memcpy((void*)executable_loader->libraries, libraries, + sizeof(libraries[0]) * library_count); + *out_executable_loader = (iree_hal_executable_loader_t*)executable_loader; + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_static_library_loader_destroy( + iree_hal_executable_loader_t* base_executable_loader) { + iree_hal_static_library_loader_t* executable_loader = + (iree_hal_static_library_loader_t*)base_executable_loader; + iree_allocator_t host_allocator = executable_loader->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_allocator_free(host_allocator, executable_loader); + + IREE_TRACE_ZONE_END(z0); +} + +static bool iree_hal_static_library_loader_query_support( + iree_hal_executable_loader_t* base_executable_loader, + iree_hal_executable_caching_mode_t caching_mode, + iree_string_view_t executable_format) { + return iree_string_view_equal(executable_format, + iree_make_cstring_view("static")); +} + +static iree_status_t iree_hal_static_library_loader_try_load( + iree_hal_executable_loader_t* base_executable_loader, + const iree_hal_executable_spec_t* executable_spec, + iree_hal_executable_t** out_executable) { + iree_hal_static_library_loader_t* executable_loader = + (iree_hal_static_library_loader_t*)base_executable_loader; + + // The executable data is just the name of the library. + iree_string_view_t library_name = + iree_make_string_view((const char*)executable_spec->executable_data.data, + executable_spec->executable_data.data_length); + + // Linear scan of the registered libraries; there's usually only one per + // module (aka source model) and as such it's a small list and probably not + // worth optimizing. We could sort the libraries list by name on loader + // creation to perform a binary-search fairly easily, though, at the cost of + // the additional code size. + for (iree_host_size_t i = 0; i < executable_loader->library_count; ++i) { + if (iree_string_view_equal( + library_name, + iree_make_cstring_view(executable_loader->libraries[i]->name))) { + return iree_hal_static_executable_create( + executable_loader->libraries[i], + executable_spec->executable_layout_count, + executable_spec->executable_layouts, + executable_loader->host_allocator, out_executable); + } + } + return iree_make_status(IREE_STATUS_NOT_FOUND, + "no static library with the name '%.*s' registered", + (int)library_name.size, library_name.data); +} + +static const iree_hal_executable_loader_vtable_t + iree_hal_static_library_loader_vtable = { + .destroy = iree_hal_static_library_loader_destroy, + .query_support = iree_hal_static_library_loader_query_support, + .try_load = iree_hal_static_library_loader_try_load, +}; diff --git a/iree/hal/local/loaders/static_library_loader.h b/iree/hal/local/loaders/static_library_loader.h new file mode 100644 index 000000000000..185872995277 --- /dev/null +++ b/iree/hal/local/loaders/static_library_loader.h @@ -0,0 +1,54 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_LOCAL_LOADERS_STATIC_LIBRARY_LOADER_H_ +#define IREE_HAL_LOCAL_LOADERS_STATIC_LIBRARY_LOADER_H_ + +#include +#include + +#include "iree/base/api.h" +#include "iree/hal/local/executable_library.h" +#include "iree/hal/local/executable_loader.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Creates a library loader that exposes the provided libraries to the HAL for +// use as executables. +// +// This loader will handle executable formats of 'static'. Version checks will +// ensure that the IREE compiler-produced static library version is one that the +// runtime can support. +// +// The name defined on each library will be used to lookup the executables and +// must match with the names used during compilation exactly. The +// iree_hal_executable_spec_t used to reference the executables will contain the +// library name and be used to lookup the library in the list. +// +// Multiple static library loaders can be registered in cases when several +// independent sets of libraries are linked in however duplicate names both +// within and across loaders will result in undefined behavior. +iree_status_t iree_hal_static_library_loader_create( + iree_host_size_t library_count, + const iree_hal_executable_library_header_t* const* libraries, + iree_allocator_t host_allocator, + iree_hal_executable_loader_t** out_executable_loader); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_LOCAL_LOADERS_STATIC_LIBRARY_LOADER_H_ diff --git a/iree/hal/local/task_command_buffer.c b/iree/hal/local/task_command_buffer.c index 15d0b1fd7fad..d6312f5f64c3 100644 --- a/iree/hal/local/task_command_buffer.c +++ b/iree/hal/local/task_command_buffer.c @@ -763,6 +763,12 @@ static iree_status_t iree_hal_task_command_buffer_build_dispatch( workgroup_size, workgroup_count, &cmd->task); iree_hal_executable_dispatch_state_v0_t* state = &cmd->state; + + // When we support imports we can populate those here based on what the + // executable declared (as each executable may import a unique set of + // functions). + state->imports = NULL; + memcpy(&state->workgroup_size, workgroup_size, sizeof(iree_hal_vec3_t)); memcpy(&state->workgroup_count, workgroup_count, sizeof(iree_hal_vec3_t)); diff --git a/iree/modules/check/native_module.cc b/iree/modules/check/native_module.cc index 64f9ac23d7a7..c68429643bf6 100644 --- a/iree/modules/check/native_module.cc +++ b/iree/modules/check/native_module.cc @@ -201,10 +201,11 @@ class CheckModuleState final { iree_hal_element_type_t element_type = iree_hal_buffer_view_element_type(view); iree_hal_buffer_t* buf = iree_hal_buffer_view_buffer(view); + iree_device_size_t size = iree_hal_buffer_view_byte_length(view); iree_hal_buffer_mapping_t mapped_memory; - IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( - buf, IREE_HAL_MEMORY_ACCESS_READ, - /*byte_offset=*/0, IREE_WHOLE_BUFFER, &mapped_memory)); + IREE_RETURN_IF_ERROR( + iree_hal_buffer_map_range(buf, IREE_HAL_MEMORY_ACCESS_READ, + /*byte_offset=*/0, size, &mapped_memory)); IREE_RETURN_IF_ERROR( ::iree::ExpectAllTrue(mapped_memory.contents, element_type)); iree_hal_buffer_unmap_range(&mapped_memory); @@ -215,6 +216,8 @@ class CheckModuleState final { vm::ref rhs_ref) { auto* lhs = lhs_ref.get(); auto* rhs = rhs_ref.get(); + + iree_device_size_t lhs_size = iree_hal_buffer_view_byte_length(lhs); size_t lhs_rank = iree_hal_buffer_view_shape_rank(lhs); std::vector lhs_shape(lhs_rank); if (lhs_rank > 0) { @@ -222,6 +225,7 @@ class CheckModuleState final { iree_hal_buffer_view_shape(lhs, lhs_rank, lhs_shape.data(), nullptr)); } + iree_device_size_t rhs_size = iree_hal_buffer_view_byte_length(rhs); size_t rhs_rank = iree_hal_buffer_view_shape_rank(rhs); std::vector rhs_shape(rhs_rank); if (rhs_rank > 0) { @@ -238,12 +242,12 @@ class CheckModuleState final { iree_hal_buffer_mapping_t lhs_mapped_memory; IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( lhs_buf, IREE_HAL_MEMORY_ACCESS_READ, - /*byte_offset=*/0, IREE_WHOLE_BUFFER, &lhs_mapped_memory)); + /*byte_offset=*/0, lhs_size, &lhs_mapped_memory)); iree_hal_buffer_t* rhs_buf = iree_hal_buffer_view_buffer(rhs); iree_hal_buffer_mapping_t rhs_mapped_memory; IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( rhs_buf, IREE_HAL_MEMORY_ACCESS_READ, - /*byte_offset=*/0, IREE_WHOLE_BUFFER, &rhs_mapped_memory)); + /*byte_offset=*/0, rhs_size, &rhs_mapped_memory)); bool element_types_eq = lhs_element_type == rhs_element_type; bool shape_eq = lhs_shape == rhs_shape; @@ -288,6 +292,8 @@ class CheckModuleState final { vm::ref rhs_ref) { auto* lhs = lhs_ref.get(); auto* rhs = rhs_ref.get(); + + iree_device_size_t lhs_size = iree_hal_buffer_view_byte_length(lhs); size_t lhs_rank = iree_hal_buffer_view_shape_rank(lhs); std::vector lhs_shape(lhs_rank); if (lhs_rank > 0) { @@ -295,6 +301,7 @@ class CheckModuleState final { iree_hal_buffer_view_shape(lhs, lhs_rank, lhs_shape.data(), nullptr)); } + iree_device_size_t rhs_size = iree_hal_buffer_view_byte_length(rhs); size_t rhs_rank = iree_hal_buffer_view_shape_rank(rhs); std::vector rhs_shape(rhs_rank); if (rhs_rank > 0) { @@ -311,12 +318,12 @@ class CheckModuleState final { iree_hal_buffer_mapping_t lhs_mapped_memory; IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( lhs_buf, IREE_HAL_MEMORY_ACCESS_READ, - /*byte_offset=*/0, IREE_WHOLE_BUFFER, &lhs_mapped_memory)); + /*byte_offset=*/0, lhs_size, &lhs_mapped_memory)); iree_hal_buffer_t* rhs_buf = iree_hal_buffer_view_buffer(rhs); iree_hal_buffer_mapping_t rhs_mapped_memory; IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( rhs_buf, IREE_HAL_MEMORY_ACCESS_READ, - /*byte_offset=*/0, IREE_WHOLE_BUFFER, &rhs_mapped_memory)); + /*byte_offset=*/0, rhs_size, &rhs_mapped_memory)); bool element_types_eq = lhs_element_type == rhs_element_type; bool shape_eq = lhs_shape == rhs_shape; diff --git a/iree/vm/bytecode_dispatch.c b/iree/vm/bytecode_dispatch.c index 7ab130156f11..c741706704ea 100644 --- a/iree/vm/bytecode_dispatch.c +++ b/iree/vm/bytecode_dispatch.c @@ -872,7 +872,7 @@ iree_status_t iree_vm_bytecode_dispatch( bool list_is_move; iree_vm_ref_t* list_ref = VM_DecOperandRegRef("list", &list_is_move); iree_vm_list_t* list = iree_vm_list_deref(*list_ref); - if (!list) { + if (IREE_UNLIKELY(!list)) { return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "list is null"); } uint32_t index = VM_DecOperandRegI32("index"); @@ -882,26 +882,33 @@ iree_status_t iree_vm_bytecode_dispatch( }); DISPATCH_OP(CORE, ListGetRef, { - // bool list_is_move; - // iree_vm_ref_t* list_ref = VM_DecOperandRegRef("list", &list_is_move); - // iree_vm_list_t* list = iree_vm_list_deref(list_ref); - // if (!list) return iree_make_status(IREE_STATUS_INVALID_ARGUMENT); - // uint32_t index = VM_DecOperandRegI32("index"); - // iree_vm_ref_t* result = VM_DecResultRegRef("result"); - return iree_make_status(IREE_STATUS_UNIMPLEMENTED, - "vm.list.get.ref not implemented"); + bool list_is_move; + iree_vm_ref_t* list_ref = VM_DecOperandRegRef("list", &list_is_move); + iree_vm_list_t* list = iree_vm_list_deref(*list_ref); + if (IREE_UNLIKELY(!list)) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "list is null"); + } + uint32_t index = VM_DecOperandRegI32("index"); + bool result_is_move; + iree_vm_ref_t* result = VM_DecResultRegRef("result", &result_is_move); + return iree_vm_list_get_ref_retain(list, index, result); }); DISPATCH_OP(CORE, ListSetRef, { - // bool list_is_move; - // iree_vm_ref_t* list_ref = VM_DecOperandRegRef("list", &list_is_move); - // iree_vm_list_t* list = iree_vm_list_deref(list_ref); - // if (!list) return iree_make_status(IREE_STATUS_INVALID_ARGUMENT); - // uint32_t index = VM_DecOperandRegI32("index"); - // bool operand_is_move = VM_DecOperandRegRefIsMove("value"); - // iree_vm_ref_t* operand = VM_DecOperandRegRef("value"); - return iree_make_status(IREE_STATUS_UNIMPLEMENTED, - "vm.list.set.ref not implemented"); + bool list_is_move; + iree_vm_ref_t* list_ref = VM_DecOperandRegRef("list", &list_is_move); + iree_vm_list_t* list = iree_vm_list_deref(*list_ref); + if (IREE_UNLIKELY(!list)) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "list is null"); + } + uint32_t index = VM_DecOperandRegI32("index"); + bool operand_is_move; + iree_vm_ref_t* operand = VM_DecOperandRegRef("value", &operand_is_move); + if (operand_is_move) { + return iree_vm_list_set_ref_move(list, index, operand); + } else { + return iree_vm_list_set_ref_retain(list, index, operand); + } }); //===------------------------------------------------------------------===// diff --git a/iree/vm/bytecode_module.c b/iree/vm/bytecode_module.c index 812ccca3b9a0..9a4b8e3859df 100644 --- a/iree/vm/bytecode_module.c +++ b/iree/vm/bytecode_module.c @@ -28,38 +28,40 @@ static bool iree_vm_flatbuffer_strcmp(flatbuffers_string_t lhs, return x != 0 ? x : lhs_size < rhs.size ? -1 : lhs_size > rhs.size; } -// Returns true if the given |type_def| is valid, meaning that the type it was -// resolved from is registered or known to the system as a builtin. -static bool iree_vm_type_def_is_valid(iree_vm_type_def_t type_def) { - return type_def.value_type != IREE_VM_VALUE_TYPE_NONE || - type_def.ref_type != IREE_VM_REF_TYPE_NULL; -} - // Resolves a type through either builtin rules or the ref registered types. -static iree_vm_type_def_t iree_vm_bytecode_module_resolve_type( - iree_vm_TypeDef_table_t type_def) { - iree_vm_type_def_t result; - memset(&result, 0, sizeof(result)); +static bool iree_vm_bytecode_module_resolve_type( + iree_vm_TypeDef_table_t type_def, iree_vm_type_def_t* out_type) { + memset(out_type, 0, sizeof(*out_type)); flatbuffers_string_t full_name = iree_vm_TypeDef_full_name(type_def); if (!flatbuffers_string_len(full_name)) { - return result; + return false; } else if (iree_vm_flatbuffer_strcmp(full_name, iree_make_cstring_view("i8")) == 0) { - result.value_type = IREE_VM_VALUE_TYPE_I8; + out_type->value_type = IREE_VM_VALUE_TYPE_I8; + return true; } else if (iree_vm_flatbuffer_strcmp(full_name, iree_make_cstring_view("i16")) == 0) { - result.value_type = IREE_VM_VALUE_TYPE_I16; + out_type->value_type = IREE_VM_VALUE_TYPE_I16; + return true; } else if (iree_vm_flatbuffer_strcmp(full_name, iree_make_cstring_view("i32")) == 0) { - result.value_type = IREE_VM_VALUE_TYPE_I32; + out_type->value_type = IREE_VM_VALUE_TYPE_I32; + return true; } else if (iree_vm_flatbuffer_strcmp(full_name, iree_make_cstring_view("i64")) == 0) { - result.value_type = IREE_VM_VALUE_TYPE_I64; + out_type->value_type = IREE_VM_VALUE_TYPE_I64; + return true; + } else if (iree_vm_flatbuffer_strcmp( + full_name, iree_make_cstring_view("!vm.opaque")) == 0) { + out_type->value_type = IREE_VM_VALUE_TYPE_NONE; + out_type->ref_type = IREE_VM_REF_TYPE_NULL; + return true; } else if (full_name[0] == '!') { // Note that we drop the ! prefix: iree_string_view_t type_name = {full_name + 1, flatbuffers_string_len(full_name) - 1}; - if (strncmp(type_name.data, "vm.list<", strlen("vm.list<")) == 0) { + if (iree_string_view_starts_with(type_name, + iree_make_cstring_view("vm.list"))) { // This is a !vm.list<...> type. We don't actually care about the type as // we allow list types to be widened. Rewrite to just vm.list as that's // all we have registered. @@ -68,10 +70,11 @@ static iree_vm_type_def_t iree_vm_bytecode_module_resolve_type( const iree_vm_ref_type_descriptor_t* type_descriptor = iree_vm_ref_lookup_registered_type(type_name); if (type_descriptor) { - result.ref_type = type_descriptor->type; + out_type->ref_type = type_descriptor->type; } + return true; } - return result; + return false; } // Resolves all types through either builtin rules or the ref registered types. @@ -80,18 +83,18 @@ static iree_vm_type_def_t iree_vm_bytecode_module_resolve_type( static iree_status_t iree_vm_bytecode_module_resolve_types( iree_vm_TypeDef_vec_t type_defs, iree_vm_type_def_t* type_table) { IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = iree_ok_status(); for (size_t i = 0; i < iree_vm_TypeDef_vec_len(type_defs); ++i) { iree_vm_TypeDef_table_t type_def = iree_vm_TypeDef_vec_at(type_defs, i); - type_table[i] = iree_vm_bytecode_module_resolve_type(type_def); - if (!iree_vm_type_def_is_valid(type_table[i])) { - IREE_TRACE_ZONE_END(z0); - return iree_make_status(IREE_STATUS_NOT_FOUND, - "no type registered with name '%s'", - iree_vm_TypeDef_full_name(type_def)); + if (!iree_vm_bytecode_module_resolve_type(type_def, &type_table[i])) { + status = iree_make_status(IREE_STATUS_NOT_FOUND, + "no type registered with name '%s'", + iree_vm_TypeDef_full_name(type_def)); + break; } } IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); + return status; } // Verifies the structure of the flatbuffer so that we can avoid doing so during diff --git a/iree/vm/test/BUILD b/iree/vm/test/BUILD index 4068d6e696df..af6820d0dfdd 100644 --- a/iree/vm/test/BUILD +++ b/iree/vm/test/BUILD @@ -45,6 +45,7 @@ cc_embed_data( ":conversion_ops_i64.vmfb", ":global_ops.vmfb", ":list_ops.vmfb", + ":list_variant_ops.vmfb", ":shift_ops.vmfb", ":shift_ops_i64.vmfb", ], @@ -123,6 +124,13 @@ iree_bytecode_module( flags = ["-iree-vm-ir-to-bytecode-module"], ) +iree_bytecode_module( + name = "list_variant_ops", + src = "list_variant_ops.mlir", + cc_namespace = "iree::vm::test", + flags = ["-iree-vm-ir-to-bytecode-module"], +) + iree_bytecode_module( name = "shift_ops", src = "shift_ops.mlir", diff --git a/iree/vm/test/CMakeLists.txt b/iree/vm/test/CMakeLists.txt index 84789c552558..3154ac24020d 100644 --- a/iree/vm/test/CMakeLists.txt +++ b/iree/vm/test/CMakeLists.txt @@ -29,6 +29,7 @@ iree_cc_embed_data( "conversion_ops_i64.vmfb" "global_ops.vmfb" "list_ops.vmfb" + "list_variant_ops.vmfb" "shift_ops.vmfb" "shift_ops_i64.vmfb" CC_FILE_OUTPUT @@ -157,6 +158,18 @@ iree_bytecode_module( PUBLIC ) +iree_bytecode_module( + NAME + list_variant_ops + SRC + "list_variant_ops.mlir" + CC_NAMESPACE + "iree::vm::test" + FLAGS + "-iree-vm-ir-to-bytecode-module" + PUBLIC +) + iree_bytecode_module( NAME shift_ops diff --git a/iree/vm/test/list_ops.mlir b/iree/vm/test/list_ops.mlir index b336b6fe7f4c..a6d3c303989d 100644 --- a/iree/vm/test/list_ops.mlir +++ b/iree/vm/test/list_ops.mlir @@ -58,15 +58,4 @@ vm.module @list_ops { // TODO(benvanik): test vm.list with ref types. vm.return } - - //===--------------------------------------------------------------------===// - // vm.list.* with variant types - //===--------------------------------------------------------------------===// - - vm.export @test_variant - vm.func @test_variant() { - // TODO(benvanik): test vm.list with variant types. - vm.return - } - } diff --git a/iree/vm/test/list_variant_ops.mlir b/iree/vm/test/list_variant_ops.mlir new file mode 100644 index 000000000000..5f50d036dfa0 --- /dev/null +++ b/iree/vm/test/list_variant_ops.mlir @@ -0,0 +1,120 @@ +vm.module @list_variant_ops { + + //===--------------------------------------------------------------------===// + // vm.list.* with list types (nesting) + //===--------------------------------------------------------------------===// + + vm.export @test_listception + vm.func @test_listception() { + %c0 = vm.const.i32 0 : i32 + %c1 = vm.const.i32 1 : i32 + %c2 = vm.const.i32 2 : i32 + %c3 = vm.const.i32 3 : i32 + %c100 = vm.const.i32 100 : i32 + %c101 = vm.const.i32 101 : i32 + %c102 = vm.const.i32 102 : i32 + + // [100, 101, 102] + %inner0 = vm.list.alloc %c3 : (i32) -> !vm.list + vm.list.resize %inner0, %c3 : (!vm.list, i32) + vm.list.set.i32 %inner0, %c0, %c100 : (!vm.list, i32, i32) + vm.list.set.i32 %inner0, %c1, %c101 : (!vm.list, i32, i32) + vm.list.set.i32 %inner0, %c2, %c102 : (!vm.list, i32, i32) + + // [102, 101, 100] + %inner1 = vm.list.alloc %c3 : (i32) -> !vm.list + vm.list.resize %inner1, %c3 : (!vm.list, i32) + vm.list.set.i32 %inner1, %c0, %c102 : (!vm.list, i32, i32) + vm.list.set.i32 %inner1, %c1, %c101 : (!vm.list, i32, i32) + vm.list.set.i32 %inner1, %c2, %c100 : (!vm.list, i32, i32) + + // [ [100, 101, 102], [102, 101, 100] ] + %capacity = vm.const.i32 8 : i32 + %outer = vm.list.alloc %capacity : (i32) -> !vm.list> + vm.list.resize %outer, %c2 : (!vm.list>, i32) + vm.list.set.ref %outer, %c0, %inner0 : (!vm.list>, i32, !vm.list) + vm.list.set.ref %outer, %c1, %inner1 : (!vm.list>, i32, !vm.list) + + %inner0_ret = vm.list.get.ref %outer, %c0 : (!vm.list>, i32) -> !vm.list + vm.check.eq %inner0_ret, %inner0 : !vm.list + %inner0_e2 = vm.list.get.i32 %inner0_ret, %c2 : (!vm.list, i32) -> i32 + vm.check.eq %inner0_e2, %c102 : i32 + + %inner1_ret = vm.list.get.ref %outer, %c0 : (!vm.list>, i32) -> !vm.list + vm.check.eq %inner1_ret, %inner1 : !vm.list + %inner1_e2 = vm.list.get.i32 %inner1_ret, %c2 : (!vm.list, i32) -> i32 + vm.check.eq %inner1_e2, %c100 : i32 + + vm.return + } + + //===--------------------------------------------------------------------===// + // vm.list.* with variant types + //===--------------------------------------------------------------------===// + + vm.rodata @byte_buffer dense<[1, 2, 3]> : tensor<3xi32> + + vm.export @test_variant + vm.func @test_variant() { + %capacity = vm.const.i32 42 : i32 + %list = vm.list.alloc %capacity : (i32) -> !vm.list + vm.list.resize %list, %capacity : (!vm.list, i32) + + // Access element 10 as an i32. + %c10 = vm.const.i32 10 : i32 + %v10_i32 = vm.const.i32 1234 : i32 + vm.list.set.i32 %list, %c10, %v10_i32 : (!vm.list, i32, i32) + %e10_i32 = vm.list.get.i32 %list, %c10 : (!vm.list, i32) -> i32 + vm.check.eq %e10_i32, %v10_i32 : i32 + + // Access element 10 as an i64. + %v10_i64 = vm.const.i64 1234 : i64 + vm.list.set.i64 %list, %c10, %v10_i64 : (!vm.list, i32, i64) + %e10_i64 = vm.list.get.i64 %list, %c10 : (!vm.list, i32) -> i64 + vm.check.eq %e10_i64, %v10_i64 : i64 + + // Access element 11 as a ref object. + %c11 = vm.const.i32 11 : i32 + %v11_buf = vm.const.ref.rodata @byte_buffer : !vm.ref + vm.list.set.ref %list, %c11, %v11_buf : (!vm.list, i32, !vm.ref) + %e11_buf = vm.list.get.ref %list, %c11 : (!vm.list, i32) -> !vm.ref + vm.check.eq %e11_buf, %v11_buf : !vm.ref + + // Access element 11 as a different kind of ref object (incompatible). + // Should return null. + %e11_bad = vm.list.get.ref %list, %c11 : (!vm.list, i32) -> !vm.list + %null = vm.const.ref.zero : !vm.list + vm.check.eq %e11_bad, %null : !vm.list + + vm.return + } + + vm.export @test_variant_slot_change + vm.func @test_variant_slot_change() { + %capacity = vm.const.i32 42 : i32 + %list = vm.list.alloc %capacity : (i32) -> !vm.list + vm.list.resize %list, %capacity : (!vm.list, i32) + + %c10 = vm.const.i32 10 : i32 + + // Access element 10 as an i32. + %v10_i32 = vm.const.i32 1234 : i32 + vm.list.set.i32 %list, %c10, %v10_i32 : (!vm.list, i32, i32) + %e10_i32 = vm.list.get.i32 %list, %c10 : (!vm.list, i32) -> i32 + vm.check.eq %e10_i32, %v10_i32 : i32 + + // Access element 10 as a ref object. + %v10_buf = vm.const.ref.rodata @byte_buffer : !vm.ref + vm.list.set.ref %list, %c10, %v10_buf : (!vm.list, i32, !vm.ref) + %e10_buf = vm.list.get.ref %list, %c10 : (!vm.list, i32) -> !vm.ref + vm.check.eq %e10_buf, %v10_buf : !vm.ref + + // Accessing it as an i32 now that it stores the ref should return a + // default (until we support type queries). + %e10_any = vm.list.get.i32 %list, %c10 : (!vm.list, i32) -> i32 + %zero = vm.const.i32.zero : i32 + vm.check.eq %e10_any, %zero : i32 + + vm.return + } +}