From 36db60039462a097d04cab99cbe91d424fef3ea8 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Wed, 14 Apr 2021 12:31:15 -0700 Subject: [PATCH 1/9] Adding support for !vm.list variant types. This was only half plumbed through before and now it's about 85% plumbed. Probably. --- iree/compiler/Dialect/VM/IR/VMBase.td | 9 +- iree/compiler/Dialect/VM/IR/VMDialect.cpp | 15 ++- iree/compiler/Dialect/VM/IR/VMOps.cpp | 48 +++---- iree/compiler/Dialect/VM/IR/VMTypes.cpp | 5 +- .../compiler/Dialect/VM/IR/test/list_ops.mlir | 42 +++++- .../VM/Target/Bytecode/BytecodeEncoder.cpp | 14 +- .../Target/Bytecode/BytecodeModuleTarget.cpp | 7 +- iree/vm/bytecode_dispatch.c | 43 ++++--- iree/vm/bytecode_module.c | 55 ++++---- iree/vm/test/BUILD | 8 ++ iree/vm/test/CMakeLists.txt | 13 ++ iree/vm/test/list_ops.mlir | 11 -- iree/vm/test/list_variant_ops.mlir | 120 ++++++++++++++++++ 13 files changed, 298 insertions(+), 92 deletions(-) create mode 100644 iree/vm/test/list_variant_ops.mlir diff --git a/iree/compiler/Dialect/VM/IR/VMBase.td b/iree/compiler/Dialect/VM/IR/VMBase.td index 3f404a5bab4a..bbf9a10a69b7 100644 --- a/iree/compiler/Dialect/VM/IR/VMBase.td +++ b/iree/compiler/Dialect/VM/IR/VMBase.td @@ -590,9 +590,12 @@ def VM_AnyList : DialectType< class VM_ListOf : Type().getObjectType().isa()">, - SubstLeaves<"$_self", - "$_self.cast().getObjectType().cast().getElementType()", - type.predicate> + Or<[ + CPred<"$_self.cast().getObjectType().cast().getElementType().isa()">, + SubstLeaves<"$_self", + "$_self.cast().getObjectType().cast().getElementType()", + type.predicate> + ]>, ]>, "list<" # type.summary # ">"> { // Set the builder call if the base type has a builder call. string builderCall = !if(!empty(type.builderCall), diff --git a/iree/compiler/Dialect/VM/IR/VMDialect.cpp b/iree/compiler/Dialect/VM/IR/VMDialect.cpp index 38d5c1416a61..d2d796dce419 100644 --- a/iree/compiler/Dialect/VM/IR/VMDialect.cpp +++ b/iree/compiler/Dialect/VM/IR/VMDialect.cpp @@ -236,7 +236,12 @@ Type VMDialect::parseType(DialectAsmParser &parser) const { << "malformed list type '" << parser.getFullSymbolSpec() << "'"; return Type(); } - auto elementType = mlir::parseType(spec, getContext()); + Type elementType; + if (spec == "?") { + elementType = OpaqueType::get(getContext()); + } else { + elementType = mlir::parseType(spec, getContext()); + } if (!elementType) { parser.emitError(parser.getCurrentLocation()) << "invalid list element type specification: '" @@ -284,7 +289,13 @@ void VMDialect::printType(Type type, DialectAsmPrinter &os) const { } else if (type.isa()) { os << "opaque"; } else if (auto listType = type.dyn_cast()) { - os << "list<" << listType.getElementType() << ">"; + os << "list<"; + if (listType.getElementType().isa()) { + os << "?"; + } else { + os << listType.getElementType(); + } + os << ">"; } else { llvm_unreachable("unhandled VM type"); } diff --git a/iree/compiler/Dialect/VM/IR/VMOps.cpp b/iree/compiler/Dialect/VM/IR/VMOps.cpp index d85a8b3ca996..765330cb1d93 100644 --- a/iree/compiler/Dialect/VM/IR/VMOps.cpp +++ b/iree/compiler/Dialect/VM/IR/VMOps.cpp @@ -693,17 +693,19 @@ static LogicalResult verifyListGetRefOp(ListGetRefOp &op) { .cast(); auto elementType = listType.getElementType(); auto resultType = op.result().getType(); - if (elementType.isa() != - resultType.isa()) { - // Attempting to go between a primitive type and ref type. - return op.emitError() << "cannot convert between list type " << elementType - << " and result type " << resultType; - } else if (auto refType = elementType.dyn_cast()) { - if (!refType.getObjectType().isa() && - elementType != resultType) { - // List has a concrete type, verify it matches. - return op.emitError() << "list contains " << elementType - << " that cannot be accessed as " << resultType; + if (!elementType.isa()) { + if (elementType.isa() != + resultType.isa()) { + // Attempting to go between a primitive type and ref type. + return op.emitError() << "cannot convert between list type " + << elementType << " and result type " << resultType; + } else if (auto refType = elementType.dyn_cast()) { + if (!refType.getObjectType().isa() && + elementType != resultType) { + // List has a concrete type, verify it matches. + return op.emitError() << "list contains " << elementType + << " that cannot be accessed as " << resultType; + } } } return success(); @@ -717,17 +719,19 @@ static LogicalResult verifyListSetRefOp(ListSetRefOp &op) { .cast(); auto elementType = listType.getElementType(); auto valueType = op.value().getType(); - if (elementType.isa() != - valueType.isa()) { - // Attempting to go between a primitive type and ref type. - return op.emitError() << "cannot convert between list type " << elementType - << " and value type " << valueType; - } else if (auto refType = elementType.dyn_cast()) { - if (!refType.getObjectType().isa() && - elementType != valueType) { - // List has a concrete type, verify it matches. - return op.emitError() << "list contains " << elementType - << " that cannot be mutated as " << valueType; + if (!elementType.isa()) { + if (elementType.isa() != + valueType.isa()) { + // Attempting to go between a primitive type and ref type. + return op.emitError() << "cannot convert between list type " + << elementType << " and value type " << valueType; + } else if (auto refType = elementType.dyn_cast()) { + if (!refType.getObjectType().isa() && + elementType != valueType) { + // List has a concrete type, verify it matches. + return op.emitError() << "list contains " << elementType + << " that cannot be mutated as " << valueType; + } } } return success(); diff --git a/iree/compiler/Dialect/VM/IR/VMTypes.cpp b/iree/compiler/Dialect/VM/IR/VMTypes.cpp index fb6d8ec1cce5..5963e5e6ce3d 100644 --- a/iree/compiler/Dialect/VM/IR/VMTypes.cpp +++ b/iree/compiler/Dialect/VM/IR/VMTypes.cpp @@ -56,7 +56,10 @@ struct ListTypeStorage : public TypeStorage { // static bool ListType::isCompatible(Type type) { - if (type.isa()) { + if (type.isa()) { + // Allow all types (variant). + return true; + } else if (type.isa()) { // Allow all ref types. return true; } else if (type.isIntOrFloat()) { diff --git a/iree/compiler/Dialect/VM/IR/test/list_ops.mlir b/iree/compiler/Dialect/VM/IR/test/list_ops.mlir index 8741558c63d7..de2b61aef263 100644 --- a/iree/compiler/Dialect/VM/IR/test/list_ops.mlir +++ b/iree/compiler/Dialect/VM/IR/test/list_ops.mlir @@ -28,7 +28,7 @@ vm.module @module { // Typed accessors for lists with i32 elements. vm.module @module { // CHECK: @list_i32 - vm.func @list_i32(%arg0 : !vm.list) { + vm.func @list_i32(%arg0: !vm.list) { %c100 = vm.const.i32 100 : i32 // CHECK: vm.list.get.i32 %arg0, %c100 : (!vm.list, i32) -> i32 @@ -42,7 +42,7 @@ vm.module @module { } // CHECK: @list_i8_coerce - vm.func @list_i8_coerce(%arg0 : !vm.list) { + vm.func @list_i8_coerce(%arg0: !vm.list) { %c100 = vm.const.i32 100 : i32 // CHECK: = vm.list.get.i32 %arg0, %c100 : (!vm.list, i32) -> i32 @@ -60,7 +60,7 @@ vm.module @module { // Typed accessors for lists with opaque ref elements. vm.module @module { // CHECK: @list_ref_any - vm.func @list_ref_any(%arg0 : !vm.list>) { + vm.func @list_ref_any(%arg0: !vm.list>) { %c100 = vm.const.i32 100 : i32 // CHECK: %ref = vm.list.get.ref %arg0, %c100 : (!vm.list>, i32) -> !vm.ref @@ -78,7 +78,7 @@ vm.module @module { // Typed accessors for lists with strongly-typed ref elements. vm.module @module { // CHECK: @list_ref_typed - vm.func @list_ref_typed(%arg0 : !vm.list>) { + vm.func @list_ref_typed(%arg0: !vm.list>) { %c100 = vm.const.i32 100 : i32 // CHECK: %ref = vm.list.get.ref %arg0, %c100 : (!vm.list>, i32) -> !vm.ref @@ -90,3 +90,37 @@ vm.module @module { vm.return } } + +// ----- + +// Variant access allows any type access. +vm.module @module { + // CHECK: @list_create_variant + vm.func @list_create_variant() { + %c42 = vm.const.i32 42 : i32 + // CHECK: %list = vm.list.alloc %c42 : (i32) -> !vm.list + %list = vm.list.alloc %c42 : (i32) -> !vm.list + + vm.return + } + + // CHECK: @list_access_variant + vm.func @list_access_variant(%arg0: !vm.list) { + %c100 = vm.const.i32 100 : i32 + %c101 = vm.const.i32 101 : i32 + + // CHECK: = vm.list.get.i32 %arg0, %c100 : (!vm.list, i32) -> i32 + %0 = vm.list.get.i32 %arg0, %c100 : (!vm.list, i32) -> i32 + + // CHECK: vm.list.set.i32 %arg0, %c100, %0 : (!vm.list, i32, i32) + vm.list.set.i32 %arg0, %c100, %0 : (!vm.list, i32, i32) + + // CHECK: %ref = vm.list.get.ref %arg0, %c101 : (!vm.list, i32) -> !vm.ref + %ref = vm.list.get.ref %arg0, %c101 : (!vm.list, i32) -> !vm.ref + + // CHECK: vm.list.set.ref %arg0, %c101, %ref : (!vm.list, i32, !vm.ref) + vm.list.set.ref %arg0, %c101, %ref : (!vm.list, i32, !vm.ref) + + vm.return + } +} diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp b/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp index 3e4cee0899a8..1da05094de1b 100644 --- a/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp +++ b/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeEncoder.cpp @@ -92,7 +92,19 @@ class V0BytecodeEncoder : public BytecodeEncoder { } LogicalResult encodeType(Type type) override { - int typeOrdinal = typeTable_->lookup(type); + // HACK: it'd be nice to remove the implicit ref wrapper hiding. + if (auto refType = type.dyn_cast()) { + if (refType.getObjectType().isa()) { + type = refType.getObjectType(); + } + } + auto it = typeTable_->find(type); + if (it == typeTable_->end()) { + return currentOp_->emitOpError() + << "type " << type + << " cannot be encoded; not registered in type table"; + } + int typeOrdinal = it->second; return writeUint32(typeOrdinal); } diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp b/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp index 821ba43ec517..13261cca816b 100644 --- a/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp +++ b/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp @@ -73,9 +73,8 @@ static std::vector buildTypeTable(IREE::VM::ModuleOp moduleOp) { sstream.flush(); typeMap.try_emplace(type, str); if (auto listType = type.dyn_cast()) { - if (listType.getElementType()) { - tryInsertType(listType.getElementType()); - } + assert(listType.getElementType()); + tryInsertType(listType.getElementType()); } }; for (auto funcOp : moduleOp.getBlock().getOps()) { @@ -90,7 +89,7 @@ static std::vector buildTypeTable(IREE::VM::ModuleOp moduleOp) { for (const auto &typeString : typeMap) { table.push_back(TypeDef{typeString.first, typeString.second}); } - llvm::sort( + llvm::stable_sort( table, +[](const TypeDef &lhs, const TypeDef &rhs) { // Always sort builtins above custom types. if (lhs.full_name[0] != '!' && rhs.full_name[0] == '!') { diff --git a/iree/vm/bytecode_dispatch.c b/iree/vm/bytecode_dispatch.c index 7ab130156f11..c741706704ea 100644 --- a/iree/vm/bytecode_dispatch.c +++ b/iree/vm/bytecode_dispatch.c @@ -872,7 +872,7 @@ iree_status_t iree_vm_bytecode_dispatch( bool list_is_move; iree_vm_ref_t* list_ref = VM_DecOperandRegRef("list", &list_is_move); iree_vm_list_t* list = iree_vm_list_deref(*list_ref); - if (!list) { + if (IREE_UNLIKELY(!list)) { return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "list is null"); } uint32_t index = VM_DecOperandRegI32("index"); @@ -882,26 +882,33 @@ iree_status_t iree_vm_bytecode_dispatch( }); DISPATCH_OP(CORE, ListGetRef, { - // bool list_is_move; - // iree_vm_ref_t* list_ref = VM_DecOperandRegRef("list", &list_is_move); - // iree_vm_list_t* list = iree_vm_list_deref(list_ref); - // if (!list) return iree_make_status(IREE_STATUS_INVALID_ARGUMENT); - // uint32_t index = VM_DecOperandRegI32("index"); - // iree_vm_ref_t* result = VM_DecResultRegRef("result"); - return iree_make_status(IREE_STATUS_UNIMPLEMENTED, - "vm.list.get.ref not implemented"); + bool list_is_move; + iree_vm_ref_t* list_ref = VM_DecOperandRegRef("list", &list_is_move); + iree_vm_list_t* list = iree_vm_list_deref(*list_ref); + if (IREE_UNLIKELY(!list)) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "list is null"); + } + uint32_t index = VM_DecOperandRegI32("index"); + bool result_is_move; + iree_vm_ref_t* result = VM_DecResultRegRef("result", &result_is_move); + return iree_vm_list_get_ref_retain(list, index, result); }); DISPATCH_OP(CORE, ListSetRef, { - // bool list_is_move; - // iree_vm_ref_t* list_ref = VM_DecOperandRegRef("list", &list_is_move); - // iree_vm_list_t* list = iree_vm_list_deref(list_ref); - // if (!list) return iree_make_status(IREE_STATUS_INVALID_ARGUMENT); - // uint32_t index = VM_DecOperandRegI32("index"); - // bool operand_is_move = VM_DecOperandRegRefIsMove("value"); - // iree_vm_ref_t* operand = VM_DecOperandRegRef("value"); - return iree_make_status(IREE_STATUS_UNIMPLEMENTED, - "vm.list.set.ref not implemented"); + bool list_is_move; + iree_vm_ref_t* list_ref = VM_DecOperandRegRef("list", &list_is_move); + iree_vm_list_t* list = iree_vm_list_deref(*list_ref); + if (IREE_UNLIKELY(!list)) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "list is null"); + } + uint32_t index = VM_DecOperandRegI32("index"); + bool operand_is_move; + iree_vm_ref_t* operand = VM_DecOperandRegRef("value", &operand_is_move); + if (operand_is_move) { + return iree_vm_list_set_ref_move(list, index, operand); + } else { + return iree_vm_list_set_ref_retain(list, index, operand); + } }); //===------------------------------------------------------------------===// diff --git a/iree/vm/bytecode_module.c b/iree/vm/bytecode_module.c index 812ccca3b9a0..9a4b8e3859df 100644 --- a/iree/vm/bytecode_module.c +++ b/iree/vm/bytecode_module.c @@ -28,38 +28,40 @@ static bool iree_vm_flatbuffer_strcmp(flatbuffers_string_t lhs, return x != 0 ? x : lhs_size < rhs.size ? -1 : lhs_size > rhs.size; } -// Returns true if the given |type_def| is valid, meaning that the type it was -// resolved from is registered or known to the system as a builtin. -static bool iree_vm_type_def_is_valid(iree_vm_type_def_t type_def) { - return type_def.value_type != IREE_VM_VALUE_TYPE_NONE || - type_def.ref_type != IREE_VM_REF_TYPE_NULL; -} - // Resolves a type through either builtin rules or the ref registered types. -static iree_vm_type_def_t iree_vm_bytecode_module_resolve_type( - iree_vm_TypeDef_table_t type_def) { - iree_vm_type_def_t result; - memset(&result, 0, sizeof(result)); +static bool iree_vm_bytecode_module_resolve_type( + iree_vm_TypeDef_table_t type_def, iree_vm_type_def_t* out_type) { + memset(out_type, 0, sizeof(*out_type)); flatbuffers_string_t full_name = iree_vm_TypeDef_full_name(type_def); if (!flatbuffers_string_len(full_name)) { - return result; + return false; } else if (iree_vm_flatbuffer_strcmp(full_name, iree_make_cstring_view("i8")) == 0) { - result.value_type = IREE_VM_VALUE_TYPE_I8; + out_type->value_type = IREE_VM_VALUE_TYPE_I8; + return true; } else if (iree_vm_flatbuffer_strcmp(full_name, iree_make_cstring_view("i16")) == 0) { - result.value_type = IREE_VM_VALUE_TYPE_I16; + out_type->value_type = IREE_VM_VALUE_TYPE_I16; + return true; } else if (iree_vm_flatbuffer_strcmp(full_name, iree_make_cstring_view("i32")) == 0) { - result.value_type = IREE_VM_VALUE_TYPE_I32; + out_type->value_type = IREE_VM_VALUE_TYPE_I32; + return true; } else if (iree_vm_flatbuffer_strcmp(full_name, iree_make_cstring_view("i64")) == 0) { - result.value_type = IREE_VM_VALUE_TYPE_I64; + out_type->value_type = IREE_VM_VALUE_TYPE_I64; + return true; + } else if (iree_vm_flatbuffer_strcmp( + full_name, iree_make_cstring_view("!vm.opaque")) == 0) { + out_type->value_type = IREE_VM_VALUE_TYPE_NONE; + out_type->ref_type = IREE_VM_REF_TYPE_NULL; + return true; } else if (full_name[0] == '!') { // Note that we drop the ! prefix: iree_string_view_t type_name = {full_name + 1, flatbuffers_string_len(full_name) - 1}; - if (strncmp(type_name.data, "vm.list<", strlen("vm.list<")) == 0) { + if (iree_string_view_starts_with(type_name, + iree_make_cstring_view("vm.list"))) { // This is a !vm.list<...> type. We don't actually care about the type as // we allow list types to be widened. Rewrite to just vm.list as that's // all we have registered. @@ -68,10 +70,11 @@ static iree_vm_type_def_t iree_vm_bytecode_module_resolve_type( const iree_vm_ref_type_descriptor_t* type_descriptor = iree_vm_ref_lookup_registered_type(type_name); if (type_descriptor) { - result.ref_type = type_descriptor->type; + out_type->ref_type = type_descriptor->type; } + return true; } - return result; + return false; } // Resolves all types through either builtin rules or the ref registered types. @@ -80,18 +83,18 @@ static iree_vm_type_def_t iree_vm_bytecode_module_resolve_type( static iree_status_t iree_vm_bytecode_module_resolve_types( iree_vm_TypeDef_vec_t type_defs, iree_vm_type_def_t* type_table) { IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = iree_ok_status(); for (size_t i = 0; i < iree_vm_TypeDef_vec_len(type_defs); ++i) { iree_vm_TypeDef_table_t type_def = iree_vm_TypeDef_vec_at(type_defs, i); - type_table[i] = iree_vm_bytecode_module_resolve_type(type_def); - if (!iree_vm_type_def_is_valid(type_table[i])) { - IREE_TRACE_ZONE_END(z0); - return iree_make_status(IREE_STATUS_NOT_FOUND, - "no type registered with name '%s'", - iree_vm_TypeDef_full_name(type_def)); + if (!iree_vm_bytecode_module_resolve_type(type_def, &type_table[i])) { + status = iree_make_status(IREE_STATUS_NOT_FOUND, + "no type registered with name '%s'", + iree_vm_TypeDef_full_name(type_def)); + break; } } IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); + return status; } // Verifies the structure of the flatbuffer so that we can avoid doing so during diff --git a/iree/vm/test/BUILD b/iree/vm/test/BUILD index 4068d6e696df..af6820d0dfdd 100644 --- a/iree/vm/test/BUILD +++ b/iree/vm/test/BUILD @@ -45,6 +45,7 @@ cc_embed_data( ":conversion_ops_i64.vmfb", ":global_ops.vmfb", ":list_ops.vmfb", + ":list_variant_ops.vmfb", ":shift_ops.vmfb", ":shift_ops_i64.vmfb", ], @@ -123,6 +124,13 @@ iree_bytecode_module( flags = ["-iree-vm-ir-to-bytecode-module"], ) +iree_bytecode_module( + name = "list_variant_ops", + src = "list_variant_ops.mlir", + cc_namespace = "iree::vm::test", + flags = ["-iree-vm-ir-to-bytecode-module"], +) + iree_bytecode_module( name = "shift_ops", src = "shift_ops.mlir", diff --git a/iree/vm/test/CMakeLists.txt b/iree/vm/test/CMakeLists.txt index 84789c552558..3154ac24020d 100644 --- a/iree/vm/test/CMakeLists.txt +++ b/iree/vm/test/CMakeLists.txt @@ -29,6 +29,7 @@ iree_cc_embed_data( "conversion_ops_i64.vmfb" "global_ops.vmfb" "list_ops.vmfb" + "list_variant_ops.vmfb" "shift_ops.vmfb" "shift_ops_i64.vmfb" CC_FILE_OUTPUT @@ -157,6 +158,18 @@ iree_bytecode_module( PUBLIC ) +iree_bytecode_module( + NAME + list_variant_ops + SRC + "list_variant_ops.mlir" + CC_NAMESPACE + "iree::vm::test" + FLAGS + "-iree-vm-ir-to-bytecode-module" + PUBLIC +) + iree_bytecode_module( NAME shift_ops diff --git a/iree/vm/test/list_ops.mlir b/iree/vm/test/list_ops.mlir index b336b6fe7f4c..a6d3c303989d 100644 --- a/iree/vm/test/list_ops.mlir +++ b/iree/vm/test/list_ops.mlir @@ -58,15 +58,4 @@ vm.module @list_ops { // TODO(benvanik): test vm.list with ref types. vm.return } - - //===--------------------------------------------------------------------===// - // vm.list.* with variant types - //===--------------------------------------------------------------------===// - - vm.export @test_variant - vm.func @test_variant() { - // TODO(benvanik): test vm.list with variant types. - vm.return - } - } diff --git a/iree/vm/test/list_variant_ops.mlir b/iree/vm/test/list_variant_ops.mlir new file mode 100644 index 000000000000..5f50d036dfa0 --- /dev/null +++ b/iree/vm/test/list_variant_ops.mlir @@ -0,0 +1,120 @@ +vm.module @list_variant_ops { + + //===--------------------------------------------------------------------===// + // vm.list.* with list types (nesting) + //===--------------------------------------------------------------------===// + + vm.export @test_listception + vm.func @test_listception() { + %c0 = vm.const.i32 0 : i32 + %c1 = vm.const.i32 1 : i32 + %c2 = vm.const.i32 2 : i32 + %c3 = vm.const.i32 3 : i32 + %c100 = vm.const.i32 100 : i32 + %c101 = vm.const.i32 101 : i32 + %c102 = vm.const.i32 102 : i32 + + // [100, 101, 102] + %inner0 = vm.list.alloc %c3 : (i32) -> !vm.list + vm.list.resize %inner0, %c3 : (!vm.list, i32) + vm.list.set.i32 %inner0, %c0, %c100 : (!vm.list, i32, i32) + vm.list.set.i32 %inner0, %c1, %c101 : (!vm.list, i32, i32) + vm.list.set.i32 %inner0, %c2, %c102 : (!vm.list, i32, i32) + + // [102, 101, 100] + %inner1 = vm.list.alloc %c3 : (i32) -> !vm.list + vm.list.resize %inner1, %c3 : (!vm.list, i32) + vm.list.set.i32 %inner1, %c0, %c102 : (!vm.list, i32, i32) + vm.list.set.i32 %inner1, %c1, %c101 : (!vm.list, i32, i32) + vm.list.set.i32 %inner1, %c2, %c100 : (!vm.list, i32, i32) + + // [ [100, 101, 102], [102, 101, 100] ] + %capacity = vm.const.i32 8 : i32 + %outer = vm.list.alloc %capacity : (i32) -> !vm.list> + vm.list.resize %outer, %c2 : (!vm.list>, i32) + vm.list.set.ref %outer, %c0, %inner0 : (!vm.list>, i32, !vm.list) + vm.list.set.ref %outer, %c1, %inner1 : (!vm.list>, i32, !vm.list) + + %inner0_ret = vm.list.get.ref %outer, %c0 : (!vm.list>, i32) -> !vm.list + vm.check.eq %inner0_ret, %inner0 : !vm.list + %inner0_e2 = vm.list.get.i32 %inner0_ret, %c2 : (!vm.list, i32) -> i32 + vm.check.eq %inner0_e2, %c102 : i32 + + %inner1_ret = vm.list.get.ref %outer, %c0 : (!vm.list>, i32) -> !vm.list + vm.check.eq %inner1_ret, %inner1 : !vm.list + %inner1_e2 = vm.list.get.i32 %inner1_ret, %c2 : (!vm.list, i32) -> i32 + vm.check.eq %inner1_e2, %c100 : i32 + + vm.return + } + + //===--------------------------------------------------------------------===// + // vm.list.* with variant types + //===--------------------------------------------------------------------===// + + vm.rodata @byte_buffer dense<[1, 2, 3]> : tensor<3xi32> + + vm.export @test_variant + vm.func @test_variant() { + %capacity = vm.const.i32 42 : i32 + %list = vm.list.alloc %capacity : (i32) -> !vm.list + vm.list.resize %list, %capacity : (!vm.list, i32) + + // Access element 10 as an i32. + %c10 = vm.const.i32 10 : i32 + %v10_i32 = vm.const.i32 1234 : i32 + vm.list.set.i32 %list, %c10, %v10_i32 : (!vm.list, i32, i32) + %e10_i32 = vm.list.get.i32 %list, %c10 : (!vm.list, i32) -> i32 + vm.check.eq %e10_i32, %v10_i32 : i32 + + // Access element 10 as an i64. + %v10_i64 = vm.const.i64 1234 : i64 + vm.list.set.i64 %list, %c10, %v10_i64 : (!vm.list, i32, i64) + %e10_i64 = vm.list.get.i64 %list, %c10 : (!vm.list, i32) -> i64 + vm.check.eq %e10_i64, %v10_i64 : i64 + + // Access element 11 as a ref object. + %c11 = vm.const.i32 11 : i32 + %v11_buf = vm.const.ref.rodata @byte_buffer : !vm.ref + vm.list.set.ref %list, %c11, %v11_buf : (!vm.list, i32, !vm.ref) + %e11_buf = vm.list.get.ref %list, %c11 : (!vm.list, i32) -> !vm.ref + vm.check.eq %e11_buf, %v11_buf : !vm.ref + + // Access element 11 as a different kind of ref object (incompatible). + // Should return null. + %e11_bad = vm.list.get.ref %list, %c11 : (!vm.list, i32) -> !vm.list + %null = vm.const.ref.zero : !vm.list + vm.check.eq %e11_bad, %null : !vm.list + + vm.return + } + + vm.export @test_variant_slot_change + vm.func @test_variant_slot_change() { + %capacity = vm.const.i32 42 : i32 + %list = vm.list.alloc %capacity : (i32) -> !vm.list + vm.list.resize %list, %capacity : (!vm.list, i32) + + %c10 = vm.const.i32 10 : i32 + + // Access element 10 as an i32. + %v10_i32 = vm.const.i32 1234 : i32 + vm.list.set.i32 %list, %c10, %v10_i32 : (!vm.list, i32, i32) + %e10_i32 = vm.list.get.i32 %list, %c10 : (!vm.list, i32) -> i32 + vm.check.eq %e10_i32, %v10_i32 : i32 + + // Access element 10 as a ref object. + %v10_buf = vm.const.ref.rodata @byte_buffer : !vm.ref + vm.list.set.ref %list, %c10, %v10_buf : (!vm.list, i32, !vm.ref) + %e10_buf = vm.list.get.ref %list, %c10 : (!vm.list, i32) -> !vm.ref + vm.check.eq %e10_buf, %v10_buf : !vm.ref + + // Accessing it as an i32 now that it stores the ref should return a + // default (until we support type queries). + %e10_any = vm.list.get.i32 %list, %c10 : (!vm.list, i32) -> i32 + %zero = vm.const.i32.zero : i32 + vm.check.eq %e10_any, %zero : i32 + + vm.return + } +} From 0d6cade583ce2041559f3b5df14bf8549e0d8ee2 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Thu, 15 Apr 2021 11:12:12 -0700 Subject: [PATCH 2/9] Adding iree.list support and variant type conversion. --- .../test/materialize_shape_support.mlir | 36 ++++---- iree/compiler/Dialect/IREE/IR/IREEDialect.cpp | 23 ++++- iree/compiler/Dialect/IREE/IR/IREEOps.cpp | 79 ++++++++++++----- iree/compiler/Dialect/IREE/IR/IREEOps.td | 20 ++++- iree/compiler/Dialect/IREE/IR/IREETypes.cpp | 9 +- iree/compiler/Dialect/IREE/IR/IREETypes.h | 10 +++ iree/compiler/Dialect/IREE/IR/test/BUILD | 1 + .../Dialect/IREE/IR/test/CMakeLists.txt | 1 + .../Dialect/IREE/IR/test/list_ops.mlir | 85 +++++++++++++++++++ .../Conversion/IREEToVM/ConvertIREEToVM.cpp | 27 +++++- .../Dialect/VM/Conversion/IREEToVM/test/BUILD | 1 + .../Conversion/IREEToVM/test/CMakeLists.txt | 1 + .../VM/Conversion/IREEToVM/test/list_ops.mlir | 37 ++++++++ .../Dialect/VM/Conversion/TypeConverter.cpp | 6 ++ 14 files changed, 287 insertions(+), 49 deletions(-) create mode 100644 iree/compiler/Dialect/IREE/IR/test/list_ops.mlir create mode 100644 iree/compiler/Dialect/VM/Conversion/IREEToVM/test/list_ops.mlir diff --git a/iree/compiler/Bindings/TFLite/Transforms/test/materialize_shape_support.mlir b/iree/compiler/Bindings/TFLite/Transforms/test/materialize_shape_support.mlir index f832e8867a60..c547de236acb 100644 --- a/iree/compiler/Bindings/TFLite/Transforms/test/materialize_shape_support.mlir +++ b/iree/compiler/Bindings/TFLite/Transforms/test/materialize_shape_support.mlir @@ -38,10 +38,10 @@ // CHECK-NEXT: %[[IN0_SHAPE:.+]] = flow.variable.load @_tflite_dynamicEntry_input0_shape : !shapex.ranked_shape<[?,8,8,3]> // CHECK-NEXT: iree.list.resize %[[LIST]], %c4 : !iree.list // CHECK-NEXT: %[[IN0_D0:.+]] = shapex.ranked_dim %[[IN0_SHAPE]][0] : !shapex.ranked_shape<[?,8,8,3]> -> index -// CHECK-NEXT: iree.list.set %[[LIST]], %c0, %[[IN0_D0]] : !iree.list -// CHECK-NEXT: iree.list.set %[[LIST]], %c1, %c8 : !iree.list -// CHECK-NEXT: iree.list.set %[[LIST]], %c2, %c8 : !iree.list -// CHECK-NEXT: iree.list.set %[[LIST]], %c3, %c3 : !iree.list +// CHECK-NEXT: iree.list.set %[[LIST]][%c0], %[[IN0_D0]] : !iree.list +// CHECK-NEXT: iree.list.set %[[LIST]][%c1], %c8 : !iree.list +// CHECK-NEXT: iree.list.set %[[LIST]][%c2], %c8 : !iree.list +// CHECK-NEXT: iree.list.set %[[LIST]][%c3], %c3 : !iree.list // CHECK-NEXT: br ^bb4 // CHECK-NEXT: ^bb2: // CHECK-NEXT: %[[IS_1:.+]] = cmpi eq, %[[INDEX]], %c1 : index @@ -50,10 +50,10 @@ // CHECK-NEXT: %[[IN1_SHAPE:.+]] = flow.variable.load @_tflite_dynamicEntry_input1_shape : !shapex.ranked_shape<[?,8,8,3]> // CHECK-NEXT: iree.list.resize %[[LIST]], %c4 : !iree.list // CHECK-NEXT: %[[IN1_D0:.+]] = shapex.ranked_dim %[[IN1_SHAPE]][0] : !shapex.ranked_shape<[?,8,8,3]> -> index -// CHECK-NEXT: iree.list.set %[[LIST]], %c0, %[[IN1_D0]] : !iree.list -// CHECK-NEXT: iree.list.set %[[LIST]], %c1, %c8 : !iree.list -// CHECK-NEXT: iree.list.set %[[LIST]], %c2, %c8 : !iree.list -// CHECK-NEXT: iree.list.set %[[LIST]], %c3, %c3 : !iree.list +// CHECK-NEXT: iree.list.set %[[LIST]][%c0], %[[IN1_D0]] : !iree.list +// CHECK-NEXT: iree.list.set %[[LIST]][%c1], %c8 : !iree.list +// CHECK-NEXT: iree.list.set %[[LIST]][%c2], %c8 : !iree.list +// CHECK-NEXT: iree.list.set %[[LIST]][%c3], %c3 : !iree.list // CHECK-NEXT: br ^bb4 // CHECK-NEXT: ^bb4: // CHECK-NEXT: return @@ -64,7 +64,7 @@ // CHECK: %[[IS_0:.+]] = cmpi eq, %[[INDEX]], %c0 : index // CHECK-NEXT: cond_br %[[IS_0]], ^bb1, ^bb2 // CHECK-NEXT: ^bb1: -// CHECK-NEXT: %[[IN0_D0:.+]] = iree.list.get %[[LIST]], %c0 : !iree.list +// CHECK-NEXT: %[[IN0_D0:.+]] = iree.list.get %[[LIST]][%c0] : !iree.list // CHECK-NEXT: %[[IN0_SHAPE:.+]] = shapex.make_ranked_shape %[[IN0_D0]] : (index) -> !shapex.ranked_shape<[?,8,8,3]> // CHECK-NEXT: flow.variable.store %[[IN0_SHAPE]], @_tflite_dynamicEntry_input0_shape : !shapex.ranked_shape<[?,8,8,3]> // CHECK-NEXT: br ^bb4 @@ -72,7 +72,7 @@ // CHECK-NEXT: %[[IS_1:.+]] = cmpi eq, %[[INDEX]], %c1 : index // CHECK-NEXT: cond_br %[[IS_1]], ^bb3, ^bb4 // CHECK-NEXT: ^bb3: -// CHECK-NEXT: %[[IN1_D0:.+]] = iree.list.get %[[LIST]], %c0 : !iree.list +// CHECK-NEXT: %[[IN1_D0:.+]] = iree.list.get %[[LIST]][%c0] : !iree.list // CHECK-NEXT: %[[IN1_SHAPE:.+]] = shapex.make_ranked_shape %[[IN1_D0]] : (index) -> !shapex.ranked_shape<[?,8,8,3]> // CHECK-NEXT: flow.variable.store %[[IN1_SHAPE]], @_tflite_dynamicEntry_input1_shape : !shapex.ranked_shape<[?,8,8,3]> // CHECK-NEXT: br ^bb4 @@ -90,10 +90,10 @@ // CHECK-NEXT: %[[OUT0_SHAPE:.+]] = flow.variable.load @_tflite_dynamicEntry_output0_shape : !shapex.ranked_shape<[?,8,8,3]> // CHECK-NEXT: iree.list.resize %[[LIST]], %c4 : !iree.list // CHECK-NEXT: %[[OUT0_D0:.+]] = shapex.ranked_dim %[[OUT0_SHAPE]][0] : !shapex.ranked_shape<[?,8,8,3]> -> index -// CHECK-NEXT: iree.list.set %[[LIST]], %c0, %[[OUT0_D0]] : !iree.list -// CHECK-NEXT: iree.list.set %[[LIST]], %c1, %c8 : !iree.list -// CHECK-NEXT: iree.list.set %[[LIST]], %c2, %c8 : !iree.list -// CHECK-NEXT: iree.list.set %[[LIST]], %c3, %c3 : !iree.list +// CHECK-NEXT: iree.list.set %[[LIST]][%c0], %[[OUT0_D0]] : !iree.list +// CHECK-NEXT: iree.list.set %[[LIST]][%c1], %c8 : !iree.list +// CHECK-NEXT: iree.list.set %[[LIST]][%c2], %c8 : !iree.list +// CHECK-NEXT: iree.list.set %[[LIST]][%c3], %c3 : !iree.list // CHECK-NEXT: br ^bb4 // CHECK-NEXT: ^bb2: // CHECK-NEXT: %[[IS_1:.+]] = cmpi eq, %[[INDEX]], %c1 : index @@ -102,10 +102,10 @@ // CHECK-NEXT: %[[OUT1_SHAPE:.+]] = flow.variable.load @_tflite_dynamicEntry_output1_shape : !shapex.ranked_shape<[?,8,8,3]> // CHECK-NEXT: iree.list.resize %[[LIST]], %c4 : !iree.list // CHECK-NEXT: %[[OUT1_D0:.+]] = shapex.ranked_dim %[[OUT1_SHAPE]][0] : !shapex.ranked_shape<[?,8,8,3]> -> index -// CHECK-NEXT: iree.list.set %[[LIST]], %c0, %[[OUT1_D0]] : !iree.list -// CHECK-NEXT: iree.list.set %[[LIST]], %c1, %c8 : !iree.list -// CHECK-NEXT: iree.list.set %[[LIST]], %c2, %c8 : !iree.list -// CHECK-NEXT: iree.list.set %[[LIST]], %c3, %c3 : !iree.list +// CHECK-NEXT: iree.list.set %[[LIST]][%c0], %[[OUT1_D0]] : !iree.list +// CHECK-NEXT: iree.list.set %[[LIST]][%c1], %c8 : !iree.list +// CHECK-NEXT: iree.list.set %[[LIST]][%c2], %c8 : !iree.list +// CHECK-NEXT: iree.list.set %[[LIST]][%c3], %c3 : !iree.list // CHECK-NEXT: br ^bb4 // CHECK-NEXT: ^bb4: // CHECK-NEXT: return diff --git a/iree/compiler/Dialect/IREE/IR/IREEDialect.cpp b/iree/compiler/Dialect/IREE/IR/IREEDialect.cpp index 31d41a5f8ece..e8797e44ebcd 100644 --- a/iree/compiler/Dialect/IREE/IR/IREEDialect.cpp +++ b/iree/compiler/Dialect/IREE/IR/IREEDialect.cpp @@ -39,7 +39,9 @@ IREEDialect::IREEDialect(MLIRContext* context) Type IREEDialect::parseType(DialectAsmParser& parser) const { Location loc = parser.getEncodedSourceLoc(parser.getNameLoc()); llvm::StringRef spec = parser.getFullSymbolSpec(); - if (spec.consume_front("ptr")) { + if (spec == "variant") { + return IREE::VariantType::get(getContext()); + } else if (spec.consume_front("ptr")) { if (!spec.consume_front("<") || !spec.consume_back(">")) { parser.emitError(parser.getCurrentLocation()) << "malformed ptr type '" << parser.getFullSymbolSpec() << "'"; @@ -63,7 +65,12 @@ Type IREEDialect::parseType(DialectAsmParser& parser) const { << "malformed list type '" << parser.getFullSymbolSpec() << "'"; return Type(); } - auto elementType = mlir::parseType(spec, getContext()); + Type elementType; + if (spec == "?") { + elementType = IREE::VariantType::get(getContext()); + } else { + elementType = mlir::parseType(spec, getContext()); + } if (!elementType) { parser.emitError(parser.getCurrentLocation()) << "invalid list element type specification: '" @@ -77,14 +84,22 @@ Type IREEDialect::parseType(DialectAsmParser& parser) const { } void IREEDialect::printType(Type type, DialectAsmPrinter& os) const { - if (auto ptrType = type.dyn_cast()) { + if (type.isa()) { + os << "variant"; + } else if (auto ptrType = type.dyn_cast()) { os << "ptr<" << ptrType.getTargetType() << ">"; } else if (type.isa()) { os << "byte_buffer"; } else if (type.isa()) { os << "mutable_byte_buffer"; } else if (auto listType = type.dyn_cast()) { - os << "list<" << listType.getElementType() << ">"; + os << "list<"; + if (listType.getElementType().isa()) { + os << "?"; + } else { + os << listType.getElementType(); + } + os << ">"; } else { llvm_unreachable("unhandled IREE type"); } diff --git a/iree/compiler/Dialect/IREE/IR/IREEOps.cpp b/iree/compiler/Dialect/IREE/IR/IREEOps.cpp index b71d02ec23e0..42e49c7c6c4a 100644 --- a/iree/compiler/Dialect/IREE/IR/IREEOps.cpp +++ b/iree/compiler/Dialect/IREE/IR/IREEOps.cpp @@ -155,43 +155,82 @@ void UnfoldableConstantOp::getCanonicalizationPatterns( // Lists //===----------------------------------------------------------------------===// -static ParseResult parseListType(OpAsmParser &parser, Type &listType, - Type &elementType) { +static ParseResult parseListTypeGet(OpAsmParser &parser, Type &listType, + Type &elementType) { if (failed(parser.parseType(listType))) { return parser.emitError(parser.getCurrentLocation(), - "expected !iree.list<> type"); + "expected !iree.list type"); + } + auto listElementType = listType.cast().getElementType(); + if (succeeded(parser.parseOptionalArrow())) { + // Use overridden type - required for variants only. + if (failed(parser.parseType(elementType))) { + return parser.emitError( + parser.getCurrentLocation(), + "expected an element type when specifying list access types"); + } + if (!ListType::canImplicitlyCast(listElementType, elementType)) { + return parser.emitError( + parser.getCurrentLocation(), + "list access types must match the same base type as the list element " + "type (when not variant)"); + } + } else { + // Use list element type as the result element type. + elementType = listElementType; } - elementType = listType.cast().getElementType(); return success(); } -static ParseResult parseListType(OpAsmParser &parser, Type &listType, - SmallVectorImpl &elementTypes) { - if (failed(parser.parseType(listType))) { +static void printListTypeGet(OpAsmPrinter &printer, Operation *, Type listType, + Type elementType) { + printer.printType(listType); + auto listElementType = listType.cast().getElementType(); + if (listElementType != elementType) { + printer.printArrowTypeList(ArrayRef{elementType}); + } +} + +static ParseResult parseListTypeSet(OpAsmParser &parser, Type &listType, + Type &elementType) { + Type leadingType; + if (failed(parser.parseType(leadingType))) { return parser.emitError(parser.getCurrentLocation(), - "expected !iree.list<> type"); + "expected element type or !iree.list type"); } - for (size_t i = 0; i < elementTypes.size(); ++i) { - elementTypes[i] = listType.cast().getElementType(); + if (succeeded(parser.parseOptionalArrow())) { + elementType = leadingType; + if (failed(parser.parseType(listType)) || !listType.isa()) { + return parser.emitError(parser.getCurrentLocation(), + "expected an !iree.list type"); + } + } else { + if (!leadingType.isa()) { + return parser.emitError(parser.getCurrentLocation(), + "expected an !iree.list type"); + } + listType = leadingType; + elementType = listType.cast().getElementType(); } return success(); } -static void printListType(OpAsmPrinter &printer, Operation *, Type listType, - Type elementType) { - printer.printType(listType); -} - -static void printListType(OpAsmPrinter &printer, Operation *, Type listType, - TypeRange elementTypes) { - printer.printType(listType); +static void printListTypeSet(OpAsmPrinter &printer, Operation *, Type listType, + Type elementType) { + auto listElementType = listType.cast().getElementType(); + if (listElementType != elementType) { + printer.printType(elementType); + printer.printArrowTypeList(ArrayRef{listType}); + } else { + printer.printType(listType); + } } static LogicalResult verifyListGetOp(ListGetOp &op) { auto listType = op.list().getType().cast(); auto elementType = listType.getElementType(); auto resultType = op.result().getType(); - if (resultType != elementType) { + if (!ListType::canImplicitlyCast(elementType, resultType)) { return op.emitError() << "list contains " << elementType << " and cannot be accessed as " << resultType; } @@ -202,7 +241,7 @@ static LogicalResult verifyListSetOp(ListSetOp &op) { auto listType = op.list().getType().cast(); auto elementType = listType.getElementType(); auto valueType = op.value().getType(); - if (valueType != elementType) { + if (!ListType::canImplicitlyCast(valueType, elementType)) { return op.emitError() << "list contains " << elementType << " and cannot be mutated as " << valueType; } diff --git a/iree/compiler/Dialect/IREE/IR/IREEOps.td b/iree/compiler/Dialect/IREE/IR/IREEOps.td index 7bf01022f515..0d13d3823372 100644 --- a/iree/compiler/Dialect/IREE/IR/IREEOps.td +++ b/iree/compiler/Dialect/IREE/IR/IREEOps.td @@ -187,6 +187,22 @@ def IREE_UnreachableOp : IREE_Op<"unreachable", [NoSideEffect, Terminator]> { // new SSA values. This would make optimizing the list usage much easier and // enable hoisting/CSE of list access/mutation. +def IREE_ListCreateOp : IREE_PureOp<"list.create"> { + let summary = [{creates a new empty list}]; + let description = [{ + Creates a new empty list with an optional initial capacity. + }]; + + let arguments = (ins + Optional:$initial_capacity + ); + let results = (outs + AnyList:$result + ); + + let assemblyFormat = "($initial_capacity^)? attr-dict `:` type($result)"; +} + def IREE_ListSizeOp : IREE_Op<"list.size"> { let summary = [{the size of the list in elements}]; let description = [{ @@ -234,7 +250,7 @@ def IREE_ListGetOp : IREE_Op<"list.get"> { AnyType:$result ); - let assemblyFormat = "operands attr-dict `:` custom(type($list), type($result))"; + let assemblyFormat = "$list `[` $index `]` attr-dict `:` custom(type($list), type($result))"; let verifier = [{ return verify$cppClass(*this); }]; } @@ -251,7 +267,7 @@ def IREE_ListSetOp : IREE_Op<"list.set"> { AnyType:$value ); - let assemblyFormat = "operands attr-dict `:` custom(type($list), type($value))"; + let assemblyFormat = "$list `[` $index `]` `,` $value attr-dict `:` custom(type($list), type($value))"; let verifier = [{ return verify$cppClass(*this); }]; } diff --git a/iree/compiler/Dialect/IREE/IR/IREETypes.cpp b/iree/compiler/Dialect/IREE/IR/IREETypes.cpp index 19b0e3745da8..ed748a3b9d3b 100644 --- a/iree/compiler/Dialect/IREE/IR/IREETypes.cpp +++ b/iree/compiler/Dialect/IREE/IR/IREETypes.cpp @@ -52,6 +52,13 @@ struct ListTypeStorage : public TypeStorage { // static bool ListType::isCompatible(Type type) { return true; } +// static +bool ListType::canImplicitlyCast(Type from, Type to) { + if (from.isa() || to.isa()) return true; + if (from.isa() && to.isa()) return true; + return from == to; +} + ListType ListType::get(Type elementType) { return Base::get(elementType.getContext(), elementType); } @@ -244,7 +251,7 @@ void excludeTiedOperandAndResultIndices( void IREEDialect::registerTypes() { addTypes(); + IREE::PtrType, IREE::VariantType>(); } } // namespace iree_compiler diff --git a/iree/compiler/Dialect/IREE/IR/IREETypes.h b/iree/compiler/Dialect/IREE/IR/IREETypes.h index 73e81305bdec..ae03c40ef37b 100644 --- a/iree/compiler/Dialect/IREE/IR/IREETypes.h +++ b/iree/compiler/Dialect/IREE/IR/IREETypes.h @@ -58,6 +58,12 @@ enum class StatusCode : int32_t { DoNotUseReservedForFutureExpansionUseDefaultInSwitchInstead_ = 20 }; +/// Placeholder for a variant type (`?`). +class VariantType : public Type::TypeBase { + public: + using Base::Base; +}; + /// A list containing an optional element type. class ListType : public Type::TypeBase { @@ -67,6 +73,10 @@ class ListType /// Returns true if the given type can be wrapped in a list. static bool isCompatible(Type type); + /// Returns true if |from| can be implicitly cast to |to| as part of a list + /// access operation. Example: tensor<*xf32> -> tensor<4xf32>. + static bool canImplicitlyCast(Type from, Type to); + /// Gets or creates a ListType with the provided element type. static ListType get(Type elementType); diff --git a/iree/compiler/Dialect/IREE/IR/test/BUILD b/iree/compiler/Dialect/IREE/IR/test/BUILD index aacb0e0f1def..3935fda86241 100644 --- a/iree/compiler/Dialect/IREE/IR/test/BUILD +++ b/iree/compiler/Dialect/IREE/IR/test/BUILD @@ -27,6 +27,7 @@ iree_lit_test_suite( [ "byte_buffer_ops.mlir", "do_not_optimize.mlir", + "list_ops.mlir", "parse_print.mlir", ], include = ["*.mlir"], diff --git a/iree/compiler/Dialect/IREE/IR/test/CMakeLists.txt b/iree/compiler/Dialect/IREE/IR/test/CMakeLists.txt index 34fbf727c1c8..63b007972a1f 100644 --- a/iree/compiler/Dialect/IREE/IR/test/CMakeLists.txt +++ b/iree/compiler/Dialect/IREE/IR/test/CMakeLists.txt @@ -16,6 +16,7 @@ iree_lit_test_suite( SRCS "byte_buffer_ops.mlir" "do_not_optimize.mlir" + "list_ops.mlir" "parse_print.mlir" DATA iree::tools::IreeFileCheck diff --git a/iree/compiler/Dialect/IREE/IR/test/list_ops.mlir b/iree/compiler/Dialect/IREE/IR/test/list_ops.mlir new file mode 100644 index 000000000000..b3f046bb5363 --- /dev/null +++ b/iree/compiler/Dialect/IREE/IR/test/list_ops.mlir @@ -0,0 +1,85 @@ +// RUN: iree-opt -split-input-file %s | iree-opt -split-input-file | IreeFileCheck %s + +// CHECK-LABEL: @list_init_ops +func @list_init_ops() { + // CHECK: %[[CAPACITY:.+]] = constant 5 + %capacity = constant 5 : index + // CHECK: = iree.list.create %[[CAPACITY]] : !iree.list + %list_initial_capacity = iree.list.create %capacity : !iree.list + + // CHECK: %[[LIST:.+]] = iree.list.create : !iree.list + %list = iree.list.create : !iree.list + + // CHECK: %[[NEW_SIZE:.+]] = constant 100 + %new_size = constant 100 : index + // CHECK: iree.list.resize %[[LIST]], %[[NEW_SIZE]] : !iree.list + iree.list.resize %list, %new_size : !iree.list + + return +} + +// ----- + +// CHECK-LABEL: @list_access +// CHECK-SAME: (%[[LIST:.+]]: !iree.list) +func @list_access(%list: !iree.list) { + %c10 = constant 10 : index + + // CHECK: = iree.list.get %[[LIST]][%c10] : !iree.list + %0 = iree.list.get %list[%c10] : !iree.list + // CHECK: = iree.list.get %[[LIST]][%c10] : !iree.list + %1 = iree.list.get %list[%c10] : !iree.list -> i32 + + // CHECK: %[[NEW_VALUE:.+]] = constant 100 : i32 + %new_value = constant 100 : i32 + // CHECK: iree.list.set %[[LIST]][%c10], %[[NEW_VALUE]] : !iree.list + iree.list.set %list[%c10], %new_value : !iree.list + + return +} + +// ----- + +// CHECK-LABEL: @list_access_tensor +// CHECK-SAME: (%[[LIST:.+]]: !iree.list>) +func @list_access_tensor(%list: !iree.list>) { + %c10 = constant 10 : index + + // CHECK: = iree.list.get %[[LIST]][%c10] : !iree.list> -> tensor + %0 = iree.list.get %list[%c10] : !iree.list> -> tensor + + // CHECK: %[[NEW_VALUE:.+]] = constant dense<1> : tensor<5xi32> + %new_value = constant dense<1> : tensor<5xi32> + // CHECK: iree.list.set %[[LIST]][%c10], %[[NEW_VALUE]] : tensor<5xi32> -> !iree.list> + iree.list.set %list[%c10], %new_value : tensor<5xi32> -> !iree.list> + + return +} + +// ----- + +// CHECK-LABEL: @list_access_variant +// CHECK-SAME: (%[[LIST:.+]]: !iree.list) +func @list_access_variant(%list: !iree.list) { + %c10 = constant 10 : index + %c11 = constant 11 : index + + // CHECK: = iree.list.get %[[LIST]][%c10] : !iree.list -> i32 + %0 = iree.list.get %list[%c10] : !iree.list -> i32 + + // CHECK: %[[NEW_I32_VALUE:.+]] = constant 100 : i32 + %new_i32_value = constant 100 : i32 + // CHECK: iree.list.set %[[LIST]][%c10], %[[NEW_I32_VALUE]] : i32 -> !iree.list + iree.list.set %list[%c10], %new_i32_value : i32 -> !iree.list + + // CHECK: = iree.list.get %[[LIST]][%c11] : !iree.list -> tensor<5xf32> + %1 = iree.list.get %list[%c11] : !iree.list -> tensor<5xf32> + + // CHECK: %[[NEW_TENSOR_VALUE:.+]] = constant dense<1> : tensor<5xi32> + %new_tensor_value = constant dense<1> : tensor<5xi32> + // CHECK: iree.list.set %[[LIST]][%c11], %[[NEW_TENSOR_VALUE]] : tensor<5xi32> -> !iree.list + iree.list.set %list[%c11], %new_tensor_value : tensor<5xi32> -> !iree.list + + return +} + diff --git a/iree/compiler/Dialect/VM/Conversion/IREEToVM/ConvertIREEToVM.cpp b/iree/compiler/Dialect/VM/Conversion/IREEToVM/ConvertIREEToVM.cpp index 8f74279ecd1f..f49e8f26ac73 100644 --- a/iree/compiler/Dialect/VM/Conversion/IREEToVM/ConvertIREEToVM.cpp +++ b/iree/compiler/Dialect/VM/Conversion/IREEToVM/ConvertIREEToVM.cpp @@ -87,6 +87,19 @@ class UnreachableOpConversion // Lists //===----------------------------------------------------------------------===// +class ListCreateOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + IREE::ListCreateOp srcOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + IREE::ListCreateOpAdaptor srcOperands(operands); + rewriter.replaceOpWithNewOp( + srcOp, typeConverter->convertType(srcOp.result().getType()), + srcOperands.initial_capacity()); + return success(); + } +}; + class ListSizeOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( @@ -169,13 +182,19 @@ void populateIREEToVMPatterns(MLIRContext *context, typeConverter.addConversion( [&typeConverter](IREE::ListType type) -> Optional { - auto elementType = typeConverter.convertType(type.getElementType()); + Type elementType; + if (type.getElementType().isa()) { + elementType = IREE::VM::OpaqueType::get(type.getContext()); + } else { + elementType = typeConverter.convertType(type.getElementType()); + } if (!elementType) return llvm::None; return IREE::VM::RefType::get(IREE::VM::ListType::get(elementType)); }); - patterns.insert(typeConverter, - context); + patterns + .insert( + typeConverter, context); } } // namespace iree_compiler diff --git a/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/BUILD b/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/BUILD index 26a05ad4bb48..16ed188d57ad 100644 --- a/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/BUILD +++ b/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/BUILD @@ -27,6 +27,7 @@ iree_lit_test_suite( [ "byte_buffer_ops.mlir", "hint_ops.mlir", + "list_ops.mlir", ], include = ["*.mlir"], ), diff --git a/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/CMakeLists.txt b/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/CMakeLists.txt index 81db6954bef8..440353019b31 100644 --- a/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/CMakeLists.txt +++ b/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/CMakeLists.txt @@ -16,6 +16,7 @@ iree_lit_test_suite( SRCS "byte_buffer_ops.mlir" "hint_ops.mlir" + "list_ops.mlir" DATA iree::tools::IreeFileCheck iree::tools::iree-opt diff --git a/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/list_ops.mlir b/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/list_ops.mlir new file mode 100644 index 000000000000..ea4b92d367b8 --- /dev/null +++ b/iree/compiler/Dialect/VM/Conversion/IREEToVM/test/list_ops.mlir @@ -0,0 +1,37 @@ +// RUN: iree-opt -split-input-file -iree-vm-conversion %s | IreeFileCheck %s + +// CHECK-LABEL: @list_ops +module @list_ops { module { + // CHECK: vm.func @my_fn + // CHECK-SAME: (%[[BUFFER_VIEW:.+]]: !vm.ref) + func @my_fn(%buffer_view: !hal.buffer_view) { + // CHECK: %[[CAPACITY:.+]] = vm.const.i32 5 + %capacity = constant 5 : index + // CHECK: %[[LIST:.+]] = vm.list.alloc %[[CAPACITY]] : (i32) -> !vm.list + %list = iree.list.create %capacity : !iree.list + + // CHECK: %[[NEW_SIZE:.+]] = vm.const.i32 100 + %new_size = constant 100 : index + // CHECK: vm.list.resize %[[LIST]], %[[NEW_SIZE]] : (!vm.list, i32) + iree.list.resize %list, %new_size : !iree.list + + %c10 = constant 10 : index + %c11 = constant 11 : index + + // CHECK: = vm.list.get.i32 %[[LIST]], %c10 : (!vm.list, i32) -> i32 + %0 = iree.list.get %list[%c10] : !iree.list -> i32 + + // CHECK: %[[NEW_I32_VALUE:.+]] = vm.const.i32 101 + %new_i32_value = constant 101 : i32 + // CHECK: vm.list.set.i32 %[[LIST]], %c10, %[[NEW_I32_VALUE]] : (!vm.list, i32, i32) + iree.list.set %list[%c10], %new_i32_value : i32 -> !iree.list + + // CHECK: = vm.list.get.ref %[[LIST]], %c11 : (!vm.list, i32) -> !vm.ref + %1 = iree.list.get %list[%c11] : !iree.list -> !hal.buffer_view + + // CHECK: vm.list.set.ref %[[LIST]], %c11, %[[BUFFER_VIEW]] : (!vm.list, i32, !vm.ref) + iree.list.set %list[%c11], %buffer_view : !hal.buffer_view -> !iree.list + + return + } +} } diff --git a/iree/compiler/Dialect/VM/Conversion/TypeConverter.cpp b/iree/compiler/Dialect/VM/Conversion/TypeConverter.cpp index fccf14c56316..0a2797924d56 100644 --- a/iree/compiler/Dialect/VM/Conversion/TypeConverter.cpp +++ b/iree/compiler/Dialect/VM/Conversion/TypeConverter.cpp @@ -31,8 +31,14 @@ namespace VM { TypeConverter::TypeConverter(TargetOptions targetOptions) : targetOptions_(targetOptions) { + // Variant means opaque in VM. + addConversion([](IREE::VariantType type) { + return IREE::VM::OpaqueType::get(type.getContext()); + }); + // All ref types are passed through unmodified. addConversion([](IREE::VM::RefType type) { return type; }); + // Wrap ref types. addConversion([](Type type) -> Optional { if (RefType::isCompatible(type)) { From 20ce42b36ef5014b75c2e591e83c32b739755643 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Fri, 16 Apr 2021 13:15:02 -0700 Subject: [PATCH 3/9] Adding a static library loader. (#5480) This allows applications to inject a set of static libraries that can be resolved at runtime by name. This only includes the runtime portion: there's nothing in the compiler actually generating static libraries yet. Future improvements to the driver API as part of #4298 will make it easier to inject these loaders. Fixes #4679. --- iree/hal/local/loaders/BUILD | 15 ++ iree/hal/local/loaders/CMakeLists.txt | 17 ++ .../hal/local/loaders/static_library_loader.c | 250 ++++++++++++++++++ .../hal/local/loaders/static_library_loader.h | 54 ++++ 4 files changed, 336 insertions(+) create mode 100644 iree/hal/local/loaders/static_library_loader.c create mode 100644 iree/hal/local/loaders/static_library_loader.h diff --git a/iree/hal/local/loaders/BUILD b/iree/hal/local/loaders/BUILD index 618881dac987..fa4e1f10794e 100644 --- a/iree/hal/local/loaders/BUILD +++ b/iree/hal/local/loaders/BUILD @@ -42,6 +42,21 @@ cc_library( ], ) +cc_library( + name = "static_library_loader", + srcs = ["static_library_loader.c"], + hdrs = ["static_library_loader.h"], + defines = [ + "IREE_HAL_HAVE_STATIC_LIBRARY_LOADER=1", + ], + deps = [ + "//iree/base:api", + "//iree/base:tracing", + "//iree/hal:api", + "//iree/hal/local", + ], +) + cc_library( name = "system_library_loader", srcs = ["system_library_loader.c"], diff --git a/iree/hal/local/loaders/CMakeLists.txt b/iree/hal/local/loaders/CMakeLists.txt index d4168d44b05d..f440672e1dc7 100644 --- a/iree/hal/local/loaders/CMakeLists.txt +++ b/iree/hal/local/loaders/CMakeLists.txt @@ -31,6 +31,23 @@ iree_cc_library( PUBLIC ) +iree_cc_library( + NAME + static_library_loader + HDRS + "static_library_loader.h" + SRCS + "static_library_loader.c" + DEPS + iree::base::api + iree::base::tracing + iree::hal::api + iree::hal::local + DEFINES + "IREE_HAL_HAVE_STATIC_LIBRARY_LOADER=1" + PUBLIC +) + iree_cc_library( NAME system_library_loader diff --git a/iree/hal/local/loaders/static_library_loader.c b/iree/hal/local/loaders/static_library_loader.c new file mode 100644 index 000000000000..51572baff630 --- /dev/null +++ b/iree/hal/local/loaders/static_library_loader.c @@ -0,0 +1,250 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "iree/hal/local/loaders/static_library_loader.h" + +#include "iree/base/tracing.h" +#include "iree/hal/local/local_executable.h" + +//===----------------------------------------------------------------------===// +// iree_hal_static_executable_t +//===----------------------------------------------------------------------===// + +typedef struct { + iree_hal_local_executable_t base; + + // Name used for the file field in tracy and debuggers. + iree_string_view_t identifier; + + union { + const iree_hal_executable_library_header_t* header; + const iree_hal_executable_library_v0_t* v0; + } library; +} iree_hal_static_executable_t; + +static const iree_hal_local_executable_vtable_t + iree_hal_static_executable_vtable; + +static iree_status_t iree_hal_static_executable_create( + const iree_hal_executable_library_header_t* library_header, + iree_host_size_t executable_layout_count, + iree_hal_executable_layout_t* const* executable_layouts, + iree_allocator_t host_allocator, iree_hal_executable_t** out_executable) { + IREE_ASSERT_ARGUMENT(library_header); + IREE_ASSERT_ARGUMENT(!executable_layout_count || executable_layouts); + IREE_ASSERT_ARGUMENT(out_executable); + *out_executable = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_static_executable_t* executable = NULL; + iree_host_size_t total_size = + sizeof(*executable) + + executable_layout_count * sizeof(iree_hal_local_executable_layout_t); + iree_status_t status = + iree_allocator_malloc(host_allocator, total_size, (void**)&executable); + if (iree_status_is_ok(status)) { + iree_hal_local_executable_layout_t** executable_layouts_ptr = + (iree_hal_local_executable_layout_t**)(((uint8_t*)executable) + + sizeof(*executable)); + iree_hal_local_executable_initialize( + &iree_hal_static_executable_vtable, executable_layout_count, + executable_layouts, executable_layouts_ptr, host_allocator, + &executable->base); + executable->library.header = library_header; + executable->identifier = iree_make_cstring_view(library_header->name); + *out_executable = (iree_hal_executable_t*)executable; + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_static_executable_destroy( + iree_hal_executable_t* base_executable) { + iree_hal_static_executable_t* executable = + (iree_hal_static_executable_t*)base_executable; + iree_allocator_t host_allocator = executable->base.host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_hal_local_executable_deinitialize( + (iree_hal_local_executable_t*)base_executable); + iree_allocator_free(host_allocator, executable); + + IREE_TRACE_ZONE_END(z0); +} + +static iree_status_t iree_hal_static_executable_issue_call( + iree_hal_local_executable_t* base_executable, iree_host_size_t ordinal, + const iree_hal_executable_dispatch_state_v0_t* dispatch_state, + const iree_hal_vec3_t* workgroup_id) { + iree_hal_static_executable_t* executable = + (iree_hal_static_executable_t*)base_executable; + const iree_hal_executable_library_v0_t* library = executable->library.v0; + + if (IREE_UNLIKELY(ordinal >= library->entry_point_count)) { + return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, + "entry point ordinal out of bounds"); + } + +#if IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION + iree_string_view_t entry_point_name = iree_string_view_empty(); + if (library->entry_point_names != NULL) { + entry_point_name = + iree_make_cstring_view(library->entry_point_names[ordinal]); + } + if (iree_string_view_is_empty(entry_point_name)) { + entry_point_name = iree_make_cstring_view("unknown_dylib_call"); + } + IREE_TRACE_ZONE_BEGIN_EXTERNAL( + z0, executable->identifier.data, executable->identifier.size, ordinal, + entry_point_name.data, entry_point_name.size, NULL, 0); +#endif // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION + + int ret = library->entry_points[ordinal](dispatch_state, workgroup_id); + + IREE_TRACE_ZONE_END(z0); + + return ret == 0 ? iree_ok_status() + : iree_make_status( + IREE_STATUS_INTERNAL, + "executable entry point returned catastrophic error %d", + ret); +} + +static const iree_hal_local_executable_vtable_t + iree_hal_static_executable_vtable = { + .base = + { + .destroy = iree_hal_static_executable_destroy, + }, + .issue_call = iree_hal_static_executable_issue_call, +}; + +//===----------------------------------------------------------------------===// +// iree_hal_static_library_loader_t +//===----------------------------------------------------------------------===// + +typedef struct { + iree_hal_executable_loader_t base; + iree_allocator_t host_allocator; + iree_host_size_t library_count; + iree_hal_executable_library_header_t* const libraries[]; +} iree_hal_static_library_loader_t; + +static const iree_hal_executable_loader_vtable_t + iree_hal_static_library_loader_vtable; + +iree_status_t iree_hal_static_library_loader_create( + iree_host_size_t library_count, + const iree_hal_executable_library_header_t* const* libraries, + iree_allocator_t host_allocator, + iree_hal_executable_loader_t** out_executable_loader) { + IREE_ASSERT_ARGUMENT(out_executable_loader); + *out_executable_loader = NULL; + IREE_TRACE_ZONE_BEGIN(z0); + + // Verify the libraries provided all match our expected version. + // It's rare they won't, however static libraries generated with a newer + // version of the IREE compiler that are then linked with an older version of + // the runtime are difficult to spot otherwise. + for (iree_host_size_t i = 0; i < library_count; ++i) { + if (libraries[i]->version > IREE_HAL_EXECUTABLE_LIBRARY_LATEST_VERSION) { + IREE_TRACE_ZONE_END(z0); + return iree_make_status(IREE_STATUS_FAILED_PRECONDITION, + "executable does not support this version of the " + "runtime (executable: %d, runtime: %d)", + libraries[i]->version, + IREE_HAL_EXECUTABLE_LIBRARY_LATEST_VERSION); + } + } + + iree_hal_static_library_loader_t* executable_loader = NULL; + iree_host_size_t total_size = + sizeof(*executable_loader) + + sizeof(executable_loader->libraries[0]) * library_count; + iree_status_t status = iree_allocator_malloc(host_allocator, total_size, + (void**)&executable_loader); + if (iree_status_is_ok(status)) { + iree_hal_executable_loader_initialize( + &iree_hal_static_library_loader_vtable, &executable_loader->base); + executable_loader->host_allocator = host_allocator; + executable_loader->library_count = library_count; + memcpy((void*)executable_loader->libraries, libraries, + sizeof(libraries[0]) * library_count); + *out_executable_loader = (iree_hal_executable_loader_t*)executable_loader; + } + + IREE_TRACE_ZONE_END(z0); + return status; +} + +static void iree_hal_static_library_loader_destroy( + iree_hal_executable_loader_t* base_executable_loader) { + iree_hal_static_library_loader_t* executable_loader = + (iree_hal_static_library_loader_t*)base_executable_loader; + iree_allocator_t host_allocator = executable_loader->host_allocator; + IREE_TRACE_ZONE_BEGIN(z0); + + iree_allocator_free(host_allocator, executable_loader); + + IREE_TRACE_ZONE_END(z0); +} + +static bool iree_hal_static_library_loader_query_support( + iree_hal_executable_loader_t* base_executable_loader, + iree_hal_executable_caching_mode_t caching_mode, + iree_string_view_t executable_format) { + return iree_string_view_equal(executable_format, + iree_make_cstring_view("static")); +} + +static iree_status_t iree_hal_static_library_loader_try_load( + iree_hal_executable_loader_t* base_executable_loader, + const iree_hal_executable_spec_t* executable_spec, + iree_hal_executable_t** out_executable) { + iree_hal_static_library_loader_t* executable_loader = + (iree_hal_static_library_loader_t*)base_executable_loader; + + // The executable data is just the name of the library. + iree_string_view_t library_name = + iree_make_string_view((const char*)executable_spec->executable_data.data, + executable_spec->executable_data.data_length); + + // Linear scan of the registered libraries; there's usually only one per + // module (aka source model) and as such it's a small list and probably not + // worth optimizing. We could sort the libraries list by name on loader + // creation to perform a binary-search fairly easily, though, at the cost of + // the additional code size. + for (iree_host_size_t i = 0; i < executable_loader->library_count; ++i) { + if (iree_string_view_equal( + library_name, + iree_make_cstring_view(executable_loader->libraries[i]->name))) { + return iree_hal_static_executable_create( + executable_loader->libraries[i], + executable_spec->executable_layout_count, + executable_spec->executable_layouts, + executable_loader->host_allocator, out_executable); + } + } + return iree_make_status(IREE_STATUS_NOT_FOUND, + "no static library with the name '%.*s' registered", + (int)library_name.size, library_name.data); +} + +static const iree_hal_executable_loader_vtable_t + iree_hal_static_library_loader_vtable = { + .destroy = iree_hal_static_library_loader_destroy, + .query_support = iree_hal_static_library_loader_query_support, + .try_load = iree_hal_static_library_loader_try_load, +}; diff --git a/iree/hal/local/loaders/static_library_loader.h b/iree/hal/local/loaders/static_library_loader.h new file mode 100644 index 000000000000..185872995277 --- /dev/null +++ b/iree/hal/local/loaders/static_library_loader.h @@ -0,0 +1,54 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef IREE_HAL_LOCAL_LOADERS_STATIC_LIBRARY_LOADER_H_ +#define IREE_HAL_LOCAL_LOADERS_STATIC_LIBRARY_LOADER_H_ + +#include +#include + +#include "iree/base/api.h" +#include "iree/hal/local/executable_library.h" +#include "iree/hal/local/executable_loader.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Creates a library loader that exposes the provided libraries to the HAL for +// use as executables. +// +// This loader will handle executable formats of 'static'. Version checks will +// ensure that the IREE compiler-produced static library version is one that the +// runtime can support. +// +// The name defined on each library will be used to lookup the executables and +// must match with the names used during compilation exactly. The +// iree_hal_executable_spec_t used to reference the executables will contain the +// library name and be used to lookup the library in the list. +// +// Multiple static library loaders can be registered in cases when several +// independent sets of libraries are linked in however duplicate names both +// within and across loaders will result in undefined behavior. +iree_status_t iree_hal_static_library_loader_create( + iree_host_size_t library_count, + const iree_hal_executable_library_header_t* const* libraries, + iree_allocator_t host_allocator, + iree_hal_executable_loader_t** out_executable_loader); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // IREE_HAL_LOCAL_LOADERS_STATIC_LIBRARY_LOADER_H_ From 9a80e0dec19131a9e7b08bc1bf0064af4fca1124 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Fri, 16 Apr 2021 13:37:33 -0700 Subject: [PATCH 4/9] Force all constants to be outlined and stored in the constant pool. #5493 can return the inlining behavior once we have better ringbuffer support and can measure the benefits. --- .../Dialect/Flow/Transforms/OutlineLargeConstants.cpp | 3 ++- iree/compiler/Dialect/Flow/Transforms/Passes.h | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/iree/compiler/Dialect/Flow/Transforms/OutlineLargeConstants.cpp b/iree/compiler/Dialect/Flow/Transforms/OutlineLargeConstants.cpp index f4fcc023b1ad..e7dcb7a9e17e 100644 --- a/iree/compiler/Dialect/Flow/Transforms/OutlineLargeConstants.cpp +++ b/iree/compiler/Dialect/Flow/Transforms/OutlineLargeConstants.cpp @@ -130,7 +130,8 @@ static PassRegistration pass( "iree-flow-outline-large-constants", "Outlines large tensor constants into flow.variables at the module level.", [] { - return std::make_unique(kMinLargeConstantSize); + // TODO(#5493): add a flag for this. + return std::make_unique(256); }); } // namespace Flow diff --git a/iree/compiler/Dialect/Flow/Transforms/Passes.h b/iree/compiler/Dialect/Flow/Transforms/Passes.h index 1509791f8b74..54ab27d6546c 100644 --- a/iree/compiler/Dialect/Flow/Transforms/Passes.h +++ b/iree/compiler/Dialect/Flow/Transforms/Passes.h @@ -122,9 +122,9 @@ std::unique_ptr> createExportBenchmarkFuncsPass(); // Outlines large tensor constants into flow.variables at the module level. // -// NOTE: a total guess :) this feels like about the most per-dispatch-buffer -// data we'd want to embed in the command buffer. -static constexpr size_t kMinLargeConstantSize = 256; +// TODO(#5493): implement the support for inlining constants into the command +// buffer and raise this value to one that is measured to be good. +static constexpr size_t kMinLargeConstantSize = 1; std::unique_ptr> createOutlineLargeConstantsPass( size_t minLargeConstantSize = kMinLargeConstantSize); From 3525c5b0542ec3e850eab177f985e49200c0a0d9 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Fri, 16 Apr 2021 14:18:38 -0700 Subject: [PATCH 5/9] Fixing check test methods to use the buffer view byte length. Previously they were assuming the entire underlying storage buffer was equal to the valid byte length which caused problems now that we are packing multiple buffers together. --- iree/hal/buffer_view.h | 5 +++++ iree/modules/check/native_module.cc | 21 ++++++++++++++------- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/iree/hal/buffer_view.h b/iree/hal/buffer_view.h index fefe38375078..6d849feccbb4 100644 --- a/iree/hal/buffer_view.h +++ b/iree/hal/buffer_view.h @@ -147,6 +147,11 @@ iree_hal_buffer_view_release(iree_hal_buffer_view_t* buffer_view); // Returns the buffer underlying the buffer view. // The caller must retain the returned buffer if they want to continue using it. +// +// NOTE: the returned buffer length will almost always be larger than the valid +// bytes representing this buffer view due to padding. Always query the actual +// valid length with iree_hal_buffer_view_byte_length instead of assuming the +// buffer is already clamped. IREE_API_EXPORT iree_hal_buffer_t* IREE_API_CALL iree_hal_buffer_view_buffer(const iree_hal_buffer_view_t* buffer_view); diff --git a/iree/modules/check/native_module.cc b/iree/modules/check/native_module.cc index 64f9ac23d7a7..c68429643bf6 100644 --- a/iree/modules/check/native_module.cc +++ b/iree/modules/check/native_module.cc @@ -201,10 +201,11 @@ class CheckModuleState final { iree_hal_element_type_t element_type = iree_hal_buffer_view_element_type(view); iree_hal_buffer_t* buf = iree_hal_buffer_view_buffer(view); + iree_device_size_t size = iree_hal_buffer_view_byte_length(view); iree_hal_buffer_mapping_t mapped_memory; - IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( - buf, IREE_HAL_MEMORY_ACCESS_READ, - /*byte_offset=*/0, IREE_WHOLE_BUFFER, &mapped_memory)); + IREE_RETURN_IF_ERROR( + iree_hal_buffer_map_range(buf, IREE_HAL_MEMORY_ACCESS_READ, + /*byte_offset=*/0, size, &mapped_memory)); IREE_RETURN_IF_ERROR( ::iree::ExpectAllTrue(mapped_memory.contents, element_type)); iree_hal_buffer_unmap_range(&mapped_memory); @@ -215,6 +216,8 @@ class CheckModuleState final { vm::ref rhs_ref) { auto* lhs = lhs_ref.get(); auto* rhs = rhs_ref.get(); + + iree_device_size_t lhs_size = iree_hal_buffer_view_byte_length(lhs); size_t lhs_rank = iree_hal_buffer_view_shape_rank(lhs); std::vector lhs_shape(lhs_rank); if (lhs_rank > 0) { @@ -222,6 +225,7 @@ class CheckModuleState final { iree_hal_buffer_view_shape(lhs, lhs_rank, lhs_shape.data(), nullptr)); } + iree_device_size_t rhs_size = iree_hal_buffer_view_byte_length(rhs); size_t rhs_rank = iree_hal_buffer_view_shape_rank(rhs); std::vector rhs_shape(rhs_rank); if (rhs_rank > 0) { @@ -238,12 +242,12 @@ class CheckModuleState final { iree_hal_buffer_mapping_t lhs_mapped_memory; IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( lhs_buf, IREE_HAL_MEMORY_ACCESS_READ, - /*byte_offset=*/0, IREE_WHOLE_BUFFER, &lhs_mapped_memory)); + /*byte_offset=*/0, lhs_size, &lhs_mapped_memory)); iree_hal_buffer_t* rhs_buf = iree_hal_buffer_view_buffer(rhs); iree_hal_buffer_mapping_t rhs_mapped_memory; IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( rhs_buf, IREE_HAL_MEMORY_ACCESS_READ, - /*byte_offset=*/0, IREE_WHOLE_BUFFER, &rhs_mapped_memory)); + /*byte_offset=*/0, rhs_size, &rhs_mapped_memory)); bool element_types_eq = lhs_element_type == rhs_element_type; bool shape_eq = lhs_shape == rhs_shape; @@ -288,6 +292,8 @@ class CheckModuleState final { vm::ref rhs_ref) { auto* lhs = lhs_ref.get(); auto* rhs = rhs_ref.get(); + + iree_device_size_t lhs_size = iree_hal_buffer_view_byte_length(lhs); size_t lhs_rank = iree_hal_buffer_view_shape_rank(lhs); std::vector lhs_shape(lhs_rank); if (lhs_rank > 0) { @@ -295,6 +301,7 @@ class CheckModuleState final { iree_hal_buffer_view_shape(lhs, lhs_rank, lhs_shape.data(), nullptr)); } + iree_device_size_t rhs_size = iree_hal_buffer_view_byte_length(rhs); size_t rhs_rank = iree_hal_buffer_view_shape_rank(rhs); std::vector rhs_shape(rhs_rank); if (rhs_rank > 0) { @@ -311,12 +318,12 @@ class CheckModuleState final { iree_hal_buffer_mapping_t lhs_mapped_memory; IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( lhs_buf, IREE_HAL_MEMORY_ACCESS_READ, - /*byte_offset=*/0, IREE_WHOLE_BUFFER, &lhs_mapped_memory)); + /*byte_offset=*/0, lhs_size, &lhs_mapped_memory)); iree_hal_buffer_t* rhs_buf = iree_hal_buffer_view_buffer(rhs); iree_hal_buffer_mapping_t rhs_mapped_memory; IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( rhs_buf, IREE_HAL_MEMORY_ACCESS_READ, - /*byte_offset=*/0, IREE_WHOLE_BUFFER, &rhs_mapped_memory)); + /*byte_offset=*/0, rhs_size, &rhs_mapped_memory)); bool element_types_eq = lhs_element_type == rhs_element_type; bool shape_eq = lhs_shape == rhs_shape; From 6d0aae70da21054cf9e288d43a093c46b0d1b8f5 Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Fri, 16 Apr 2021 14:33:21 -0700 Subject: [PATCH 6/9] Revert "[ci] Allow parallel testing on NVIDIA GPUs" (#5496) I'm suspicious that this is the proximate cause of crashes we're seeing while destroying the Vulkan device. We think https://github.com/google/iree/issues/2659 and https://github.com/google/iree/issues/2900 are the root cause, but reverting for now to try to get us back to green. Example failure, starting at the PR being reverted: https://source.cloud.google.com/results/invocations/a3093172-ae55-43bb-956c-256bb726006f/targets/iree%2Fgcp_ubuntu%2Fcmake-bazel%2Flinux%2Fx86-turing%2Fmain/log Reverts google/iree#5467 --- .../kokoro/gcp_ubuntu/cmake-bazel/linux/x86-turing/build.sh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-turing/build.sh b/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-turing/build.sh index c98889ee52f5..7eb92c337e01 100755 --- a/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-turing/build.sh +++ b/build_tools/kokoro/gcp_ubuntu/cmake-bazel/linux/x86-turing/build.sh @@ -69,7 +69,9 @@ echo "Building with Ninja" cd "${CMAKE_BUILD_DIR?}" ninja -export CTEST_PARALLEL_LEVEL=${CTEST_PARALLEL_LEVEL:-$(nproc)} +# Limit parallelism dramatically to avoid exhausting GPU memory +# TODO(#5162): Handle this more robustly +export CTEST_PARALLEL_LEVEL=${CTEST_PARALLEL_LEVEL:-1} # Only test drivers that use the GPU, since we run all tests on non-GPU machines # as well. From 3c8880e9aa75c0d7363295702401b3705fd191f9 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Fri, 16 Apr 2021 15:16:26 -0700 Subject: [PATCH 7/9] Specify an alignment on `vm.rodata` and use it in flatbuffers. (#5494) This required a minor tweak to start_as_root the flatbuffer root tables prior to recording any data. This was discovered thanks to the gracious help (and patience) of of the flatcc author: https://github.com/dvidelabs/flatcc/issues/179 With this all of our binary blobs embedded into the flatbuffers are now 16-byte aligned with the option to tweak it further via a vm.rodata attribute. We don't need utf8 strings to be 16-byte aligned, for example. --- .../Dialect/HAL/Target/CUDA/CUDATarget.cpp | 7 +-- .../Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp | 5 ++- .../Target/MetalSPIRV/MetalSPIRVTarget.cpp | 2 +- .../Dialect/HAL/Target/VMLA/VMLATarget.cpp | 5 ++- .../Target/VulkanSPIRV/VulkanSPIRVTarget.cpp | 5 ++- .../Dialect/VM/Conversion/ImportUtils.cpp | 3 +- iree/compiler/Dialect/VM/IR/VMOps.td | 18 ++++++-- .../Target/Bytecode/BytecodeModuleTarget.cpp | 19 +++++++- .../VM/Target/Bytecode/ConstantEncoder.cpp | 43 ++++++++++--------- .../VM/Target/Bytecode/ConstantEncoder.h | 1 + .../VM/Transforms/HoistInlinedRodata.cpp | 3 ++ iree/compiler/Utils/FlatbufferUtils.cpp | 4 +- iree/compiler/Utils/FlatbufferUtils.h | 2 +- 13 files changed, 76 insertions(+), 41 deletions(-) diff --git a/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp b/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp index 02cb2ffbc708..e98d51765cac 100644 --- a/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp +++ b/iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp @@ -146,10 +146,12 @@ class CUDATargetBackend final : public TargetBackend { llvmModule->setDataLayout(targetMachine->createDataLayout()); - std::string targetISA = translateModuleToISA(*llvmModule, *targetMachine); + FlatbufferBuilder builder; + iree_CUDAExecutableDef_start_as_root(builder); + // Serialize cuda kernel into the binary that we will embed in the // final flatbuffer. - FlatbufferBuilder builder; + std::string targetISA = translateModuleToISA(*llvmModule, *targetMachine); auto ptxCudeRef = flatbuffers_uint8_vec_create( builder, reinterpret_cast(targetISA.c_str()), targetISA.size()); @@ -168,7 +170,6 @@ class CUDATargetBackend final : public TargetBackend { } auto blockSizesRef = iree_CUDABlockSizeDef_vec_end(builder); - iree_CUDAExecutableDef_start_as_root(builder); iree_CUDAExecutableDef_entry_points_add(builder, entryPointsRef); iree_CUDAExecutableDef_block_sizes_add(builder, blockSizesRef); iree_CUDAExecutableDef_ptx_image_add(builder, ptxCudeRef); diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp index 38cab2dda692..97ab76d13ee4 100644 --- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp +++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp @@ -305,9 +305,11 @@ class LLVMAOTTargetBackend final : public TargetBackend { linkArtifacts.keepAllFiles(); } + FlatbufferBuilder builder; + iree_DyLibExecutableDef_start_as_root(builder); + // Embed debug symbols at the end of the flatbuffer by adding first in the // bottoms-up builder. - FlatbufferBuilder builder; flatbuffers_uint8_vec_ref_t debugDatabaseRef = 0; flatbuffers_string_ref_t debugDatabaseFilenameRef = 0; if (options_.debugSymbols && linkArtifacts.debugFile.outputFile) { @@ -328,7 +330,6 @@ class LLVMAOTTargetBackend final : public TargetBackend { << linkArtifacts.libraryFile.path; } - iree_DyLibExecutableDef_start_as_root(builder); iree_DyLibExecutableDef_library_embedded_add(builder, libraryEmbeddedRef); iree_DyLibExecutableDef_debug_database_filename_add( builder, debugDatabaseFilenameRef); diff --git a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp index a9e5fe879c01..3bb9b7fafdf9 100644 --- a/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp +++ b/iree/compiler/Dialect/HAL/Target/MetalSPIRV/MetalSPIRVTarget.cpp @@ -121,6 +121,7 @@ class MetalSPIRVTargetBackend : public SPIRVTargetBackend { // 4. Pack the MTLLibrary and metadata into a flatbuffer. FlatbufferBuilder builder; + iree_MetalExecutableDef_start_as_root(builder); auto shaderSourcesRef = builder.createStringVec(llvm::map_range( mslShaders, [&](const MetalShader &shader) { return shader.source; })); @@ -135,7 +136,6 @@ class MetalSPIRVTargetBackend : public SPIRVTargetBackend { auto entryPointNamesRef = builder.createStringVec(entryPointNames); - iree_MetalExecutableDef_start_as_root(builder); iree_MetalExecutableDef_entry_points_add(builder, entryPointNamesRef); iree_MetalExecutableDef_threadgroup_sizes_add(builder, threadgroupSizesRef); iree_MetalExecutableDef_shader_sources_add(builder, shaderSourcesRef); diff --git a/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp b/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp index 8e00d15b2365..6cd7fc550f02 100644 --- a/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp +++ b/iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp @@ -102,8 +102,10 @@ class VMLATargetBackend final : public TargetBackend { LogicalResult serializeExecutable(IREE::HAL::ExecutableTargetOp targetOp, OpBuilder &executableBuilder) override { - // Serialize the VM module to bytes directly into a flatbuffer. FlatbufferBuilder builder; + iree_VMLAExecutableDef_start_as_root(builder); + + // Serialize the VM module to bytes directly into a flatbuffer. IREE::VM::BytecodeTargetOptions bytecodeOptions; auto dataRef = builder.streamUint8Vec([&](raw_ostream &stream) { return succeeded(translateModuleToBytecode(targetOp.getInnerModule(), @@ -115,7 +117,6 @@ class VMLATargetBackend final : public TargetBackend { // Pack the executable definition and get the bytes with the proper header. // The header is used to verify the contents at runtime. - iree_VMLAExecutableDef_start_as_root(builder); iree_VMLAExecutableDef_bytecode_module_add(builder, dataRef); iree_VMLAExecutableDef_end_as_root(builder); diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp index de7c23794aca..e1a556e25d40 100644 --- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp +++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp @@ -128,9 +128,11 @@ class VulkanSPIRVTargetBackend : public SPIRVTargetBackend { ModuleOp innerModuleOp = targetOp.getInnerModule(); auto spvModuleOp = *innerModuleOp.getOps().begin(); + FlatbufferBuilder builder; + iree_SpirVExecutableDef_start_as_root(builder); + // Serialize the spirv::ModuleOp into the binary that we will embed in the // final flatbuffer. - FlatbufferBuilder builder; SmallVector spvBinary; if (failed(spirv::serialize(spvModuleOp, spvBinary)) || spvBinary.empty()) { return targetOp.emitError() << "failed to serialize spv.module"; @@ -157,7 +159,6 @@ class VulkanSPIRVTargetBackend : public SPIRVTargetBackend { } auto entryPointsRef = builder.createStringVec(entryPointNames); - iree_SpirVExecutableDef_start_as_root(builder); iree_SpirVExecutableDef_entry_points_add(builder, entryPointsRef); iree_SpirVExecutableDef_code_add(builder, spvCodeRef); iree_SpirVExecutableDef_end_as_root(builder); diff --git a/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp b/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp index f5dac0305c55..a495ebfd807c 100644 --- a/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp +++ b/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp @@ -117,7 +117,8 @@ static Value createStringTableValue(Location loc, StringAttr attrValue, return rewriter.create( loc, IREE::VM::RefType::get(IREE::ByteBufferType::get(rewriter.getContext())), - rewriter.getStringAttr(safeIdentifier), utf8Bytes); + rewriter.getStringAttr(safeIdentifier), utf8Bytes, + /*alignment=*/rewriter.getI64IntegerAttr(1)); } size_t getSegmentSpanSize(Type spanType) { diff --git a/iree/compiler/Dialect/VM/IR/VMOps.td b/iree/compiler/Dialect/VM/IR/VMOps.td index 82a344febe7a..408f4e646e09 100644 --- a/iree/compiler/Dialect/VM/IR/VMOps.td +++ b/iree/compiler/Dialect/VM/IR/VMOps.td @@ -754,18 +754,26 @@ def VM_RodataOp : VM_Op<"rodata", [ value leaves the module. For example, returning rodata from an exported function must keep the data (possibly backed by mmap) valid for its entire lifetime. + + By default all rodata will be aligned in the final module output at a + 16-byte granularity. An optional alignment can be specified to override the + default for cases where larger or smaller alignments are needed. }]; let arguments = (ins StrAttr:$sym_name, ElementsAttr:$value, + OptionalAttr:$alignment, OptionalAttr:$ordinal ); let skipDefaultBuilders = 1; let builders = [ - OpBuilder<(ins "StringRef":$name, "ElementsAttr":$value, - CArg<"ArrayRef", "{}">:$attrs)>, + OpBuilder<(ins + "StringRef":$name, + "ElementsAttr":$value, + CArg<"ArrayRef", "{}">:$attrs + )>, ]; } @@ -810,12 +818,14 @@ def VM_RodataInlineOp : VM_PureOp<"rodata.inline", [ ]> { let summary = [{inlined constant rodata}]; let description = [{ - vm.rodata that can be embedded inline in functions. + vm.rodata that can be embedded inline in functions. See vm.rodata for more + information. }]; let arguments = (ins OptionalAttr:$name, - ElementsAttr:$value + ElementsAttr:$value, + OptionalAttr:$alignment ); let results = (outs diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp b/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp index 13261cca816b..a23e9d001335 100644 --- a/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp +++ b/iree/compiler/Dialect/VM/Target/Bytecode/BytecodeModuleTarget.cpp @@ -275,6 +275,11 @@ static iree_vm_FunctionSignatureDef_ref_t makeInternalFunctionSignatureDef( static LogicalResult buildFlatBufferModule(BytecodeTargetOptions targetOptions, IREE::VM::ModuleOp moduleOp, FlatbufferBuilder &fbb) { + // Start the buffer so that we can begin recording data prior to the root + // table (which we do at the very end). This does not change the layout of the + // file and is only used to prime the flatcc builder. + iree_vm_BytecodeModuleDef_start_as_root(fbb); + SymbolTable symbolTable(moduleOp); if (!moduleOp.ordinal_counts().hasValue()) { return moduleOp.emitError() << "ordinal_counts attribute not found. The " @@ -315,9 +320,20 @@ static LogicalResult buildFlatBufferModule(BytecodeTargetOptions targetOptions, // layout planning by preserving the order in the IR is useful. SmallVector rodataContentRefs; rodataContentRefs.reserve(rodataOps.size()); + + // All constants are defaulted to 16-byte aligned as that is the maximum + // (reasonable) alignment of all data types on all platforms. This can be + // overridden by creators of the rodata with the `alignment` attribute. + static constexpr int kDefaultRodataAlignment = 16; + for (auto rodataOp : llvm::reverse(rodataOps)) { + size_t alignment = + rodataOp.alignment() + ? static_cast(rodataOp.alignment().getValue()) + : 0; + if (alignment == 0) alignment = kDefaultRodataAlignment; auto rodataRef = - serializeConstant(rodataOp.getLoc(), rodataOp.value(), fbb); + serializeConstant(rodataOp.getLoc(), rodataOp.value(), alignment, fbb); if (!rodataRef) { return rodataOp.emitOpError() << "failed to encode"; } @@ -461,7 +477,6 @@ static LogicalResult buildFlatBufferModule(BytecodeTargetOptions targetOptions, auto moduleNameRef = fbb.createString( moduleOp.sym_name().empty() ? "module" : moduleOp.sym_name()); - iree_vm_BytecodeModuleDef_start_as_root(fbb); iree_vm_BytecodeModuleDef_name_add(fbb, moduleNameRef); iree_vm_BytecodeModuleDef_types_add(fbb, typesRef); iree_vm_BytecodeModuleDef_imported_functions_add(fbb, importFuncsRef); diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.cpp b/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.cpp index b33908bdfe6f..91a5fb26170a 100644 --- a/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.cpp +++ b/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.cpp @@ -26,11 +26,11 @@ namespace VM { // TODO(benvanik): switch to LLVM's BinaryStreamWriter to handle endianness. static flatbuffers_uint8_vec_ref_t serializeConstantI8Array( - DenseIntElementsAttr attr, FlatbufferBuilder &fbb) { + DenseIntElementsAttr attr, size_t alignment, FlatbufferBuilder &fbb) { // vm.rodata and other very large constants end up as this; since i8 is i8 // everywhere (endianness doesn't matter when you have one byte :) we can // directly access the data and memcpy. - flatbuffers_uint8_vec_start(fbb); + flatcc_builder_start_vector(fbb, 1, alignment, FLATBUFFERS_COUNT_MAX(1)); uint8_t *bytePtr = flatbuffers_uint8_vec_extend(fbb, attr.getNumElements() * sizeof(int8_t)); if (attr.isSplat()) { @@ -47,8 +47,8 @@ static flatbuffers_uint8_vec_ref_t serializeConstantI8Array( } static flatbuffers_uint8_vec_ref_t serializeConstantI16Array( - DenseIntElementsAttr attr, FlatbufferBuilder &fbb) { - flatbuffers_uint8_vec_start(fbb); + DenseIntElementsAttr attr, size_t alignment, FlatbufferBuilder &fbb) { + flatcc_builder_start_vector(fbb, 1, alignment, FLATBUFFERS_COUNT_MAX(1)); uint8_t *bytePtr = flatbuffers_uint8_vec_extend( fbb, attr.getNumElements() * sizeof(int16_t)); uint16_t *nativePtr = reinterpret_cast(bytePtr); @@ -59,8 +59,8 @@ static flatbuffers_uint8_vec_ref_t serializeConstantI16Array( } static flatbuffers_uint8_vec_ref_t serializeConstantI32Array( - DenseIntElementsAttr attr, FlatbufferBuilder &fbb) { - flatbuffers_uint8_vec_start(fbb); + DenseIntElementsAttr attr, size_t alignment, FlatbufferBuilder &fbb) { + flatcc_builder_start_vector(fbb, 1, alignment, FLATBUFFERS_COUNT_MAX(1)); uint8_t *bytePtr = flatbuffers_uint8_vec_extend( fbb, attr.getNumElements() * sizeof(int32_t)); uint32_t *nativePtr = reinterpret_cast(bytePtr); @@ -71,8 +71,8 @@ static flatbuffers_uint8_vec_ref_t serializeConstantI32Array( } static flatbuffers_uint8_vec_ref_t serializeConstantI64Array( - DenseIntElementsAttr attr, FlatbufferBuilder &fbb) { - flatbuffers_uint8_vec_start(fbb); + DenseIntElementsAttr attr, size_t alignment, FlatbufferBuilder &fbb) { + flatcc_builder_start_vector(fbb, 1, alignment, FLATBUFFERS_COUNT_MAX(1)); uint8_t *bytePtr = flatbuffers_uint8_vec_extend( fbb, attr.getNumElements() * sizeof(int64_t)); uint64_t *nativePtr = reinterpret_cast(bytePtr); @@ -83,8 +83,8 @@ static flatbuffers_uint8_vec_ref_t serializeConstantI64Array( } static flatbuffers_uint8_vec_ref_t serializeConstantF32Array( - DenseFPElementsAttr attr, FlatbufferBuilder &fbb) { - flatbuffers_uint8_vec_start(fbb); + DenseFPElementsAttr attr, size_t alignment, FlatbufferBuilder &fbb) { + flatcc_builder_start_vector(fbb, 1, alignment, FLATBUFFERS_COUNT_MAX(1)); uint8_t *bytePtr = flatbuffers_uint8_vec_extend(fbb, attr.getNumElements() * sizeof(float)); float *nativePtr = reinterpret_cast(bytePtr); @@ -95,8 +95,8 @@ static flatbuffers_uint8_vec_ref_t serializeConstantF32Array( } static flatbuffers_uint8_vec_ref_t serializeConstantF64Array( - DenseFPElementsAttr attr, FlatbufferBuilder &fbb) { - flatbuffers_uint8_vec_start(fbb); + DenseFPElementsAttr attr, size_t alignment, FlatbufferBuilder &fbb) { + flatcc_builder_start_vector(fbb, 1, alignment, FLATBUFFERS_COUNT_MAX(1)); uint8_t *bytePtr = flatbuffers_uint8_vec_extend(fbb, attr.getNumElements() * sizeof(double)); double *nativePtr = reinterpret_cast(bytePtr); @@ -107,8 +107,8 @@ static flatbuffers_uint8_vec_ref_t serializeConstantF64Array( } static flatbuffers_uint8_vec_ref_t serializeConstantF16Array( - DenseFPElementsAttr attr, FlatbufferBuilder &fbb) { - flatbuffers_uint8_vec_start(fbb); + DenseFPElementsAttr attr, size_t alignment, FlatbufferBuilder &fbb) { + flatcc_builder_start_vector(fbb, 1, alignment, FLATBUFFERS_COUNT_MAX(1)); uint8_t *bytePtr = flatbuffers_uint8_vec_extend( fbb, attr.getNumElements() * sizeof(uint16_t)); uint16_t *nativePtr = reinterpret_cast(bytePtr); @@ -121,17 +121,18 @@ static flatbuffers_uint8_vec_ref_t serializeConstantF16Array( flatbuffers_uint8_vec_ref_t serializeConstant(Location loc, ElementsAttr elementsAttr, + size_t alignment, FlatbufferBuilder &fbb) { if (auto attr = elementsAttr.dyn_cast()) { switch (attr.getType().getElementTypeBitWidth()) { case 8: - return serializeConstantI8Array(attr, fbb); + return serializeConstantI8Array(attr, alignment, fbb); case 16: - return serializeConstantI16Array(attr, fbb); + return serializeConstantI16Array(attr, alignment, fbb); case 32: - return serializeConstantI32Array(attr, fbb); + return serializeConstantI32Array(attr, alignment, fbb); case 64: - return serializeConstantI64Array(attr, fbb); + return serializeConstantI64Array(attr, alignment, fbb); default: emitError(loc) << "unhandled element bitwidth " << attr.getType().getElementTypeBitWidth(); @@ -140,11 +141,11 @@ flatbuffers_uint8_vec_ref_t serializeConstant(Location loc, } else if (auto attr = elementsAttr.dyn_cast()) { switch (attr.getType().getElementTypeBitWidth()) { case 16: - return serializeConstantF16Array(attr, fbb); + return serializeConstantF16Array(attr, alignment, fbb); case 32: - return serializeConstantF32Array(attr, fbb); + return serializeConstantF32Array(attr, alignment, fbb); case 64: - return serializeConstantF64Array(attr, fbb); + return serializeConstantF64Array(attr, alignment, fbb); default: emitError(loc) << "unhandled element bitwidth " << attr.getType().getElementTypeBitWidth(); diff --git a/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.h b/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.h index 56471a633fff..94dca8186ae7 100644 --- a/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.h +++ b/iree/compiler/Dialect/VM/Target/Bytecode/ConstantEncoder.h @@ -28,6 +28,7 @@ namespace VM { // Serializes a constant attribute to the FlatBuffer as a binary blob. flatbuffers_uint8_vec_ref_t serializeConstant(Location loc, ElementsAttr elementsAttr, + size_t alignment, FlatbufferBuilder &fbb); } // namespace VM diff --git a/iree/compiler/Dialect/VM/Transforms/HoistInlinedRodata.cpp b/iree/compiler/Dialect/VM/Transforms/HoistInlinedRodata.cpp index 240a507f5004..7aa741dd3bd0 100644 --- a/iree/compiler/Dialect/VM/Transforms/HoistInlinedRodata.cpp +++ b/iree/compiler/Dialect/VM/Transforms/HoistInlinedRodata.cpp @@ -60,6 +60,9 @@ class HoistInlinedRodataPass auto rodataOp = OpBuilder(moduleOp.getContext()) .create(inlineOp.getLoc(), name, inlineOp.value()); + if (inlineOp.alignmentAttr()) { + rodataOp.alignmentAttr(inlineOp.alignmentAttr()); + } moduleSymbolTable.insert(rodataOp, moduleBuilder.getInsertionPoint()); rodataOp.setPrivate(); replaceInlineOpWithRodataRef(inlineOp, rodataOp); diff --git a/iree/compiler/Utils/FlatbufferUtils.cpp b/iree/compiler/Utils/FlatbufferUtils.cpp index e98167b75d40..90d03be0f1d7 100644 --- a/iree/compiler/Utils/FlatbufferUtils.cpp +++ b/iree/compiler/Utils/FlatbufferUtils.cpp @@ -44,8 +44,8 @@ FlatbufferBuilder::FlatbufferBuilder() { flatcc_builder_init(&builder); } FlatbufferBuilder::~FlatbufferBuilder() { flatcc_builder_clear(&builder); } flatbuffers_uint8_vec_ref_t FlatbufferBuilder::streamUint8Vec( - std::function fn) { - flatbuffers_uint8_vec_start(*this); + std::function fn, size_t alignment) { + flatcc_builder_start_vector(*this, 1, alignment, FLATBUFFERS_COUNT_MAX(1)); raw_flatbuffer_uint8_vec_ostream stream(*this); if (!fn(stream)) { return 0; diff --git a/iree/compiler/Utils/FlatbufferUtils.h b/iree/compiler/Utils/FlatbufferUtils.h index 524cf63f7725..783bffb910ed 100644 --- a/iree/compiler/Utils/FlatbufferUtils.h +++ b/iree/compiler/Utils/FlatbufferUtils.h @@ -108,7 +108,7 @@ class FlatbufferBuilder { // my_type_uint8_vec_field_add(builder, ref); // use vec reference // ... flatbuffers_uint8_vec_ref_t streamUint8Vec( - std::function fn); + std::function fn, size_t alignment = 16); // Captures the current contents of the flatbuffer builder and returns them // as a shaped `vector` dense attr. The builder is left unmodified. From 94e547f48cce33b266863eb13eb29149a2b5954f Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Fri, 16 Apr 2021 15:28:21 -0700 Subject: [PATCH 8/9] Adding a note about where executable imports would be specified. (#5498) --- iree/hal/local/task_command_buffer.c | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/iree/hal/local/task_command_buffer.c b/iree/hal/local/task_command_buffer.c index 15d0b1fd7fad..d6312f5f64c3 100644 --- a/iree/hal/local/task_command_buffer.c +++ b/iree/hal/local/task_command_buffer.c @@ -763,6 +763,12 @@ static iree_status_t iree_hal_task_command_buffer_build_dispatch( workgroup_size, workgroup_count, &cmd->task); iree_hal_executable_dispatch_state_v0_t* state = &cmd->state; + + // When we support imports we can populate those here based on what the + // executable declared (as each executable may import a unique set of + // functions). + state->imports = NULL; + memcpy(&state->workgroup_size, workgroup_size, sizeof(iree_hal_vec3_t)); memcpy(&state->workgroup_count, workgroup_count, sizeof(iree_hal_vec3_t)); From afc0fa24601437cf154f5cbf4845ce315b645407 Mon Sep 17 00:00:00 2001 From: "Ahmed S. Taei" Date: Fri, 16 Apr 2021 17:30:13 -0700 Subject: [PATCH 9/9] Remove inliner from linalg on tensors pass (#5489) - This is currently a NOP --- iree/compiler/Conversion/LinalgToLLVM/Passes.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp index 99732bda062f..a1eadcd9ccc4 100644 --- a/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp +++ b/iree/compiler/Conversion/LinalgToLLVM/Passes.cpp @@ -63,7 +63,6 @@ void buildLLVMTransformPassPipeline(OpPassManager &passManager, if (options.usingLinalgOnTensors) { passManager.addPass(createMaterializeCPULaunchConfigurationPass()); OpPassManager &nestedModulePM = passManager.nest(); - nestedModulePM.addPass(createInlinerPass()); // TODO(ataei): We want to enable when tensor -> vector pass is fully // supported which requires first moving vector-tiling before this step. if (options.useLinalgOnTensorsToVectors) {