From 80beae81800229102da60667f324e62a9ea199ba Mon Sep 17 00:00:00 2001
From: Ettore Tiotto
Date: Tue, 1 Feb 2022 21:05:51 -0500
Subject: [PATCH] Category mapper: codegen (#1130)

Codegen for CategoryMapper

Signed-off-by: Ettore Tiotto
---
 src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp      | 413 +++++++++++-------
 .../ONNXToKrnl/ML/CategoryMapper.cpp          |   4 +-
 src/Dialect/ONNX/ONNXOps.cpp                  |   3 -
 src/Runtime/OMIndexLookup.inc                 |   8 +-
 src/Runtime/PyExecutionSession.cpp            |  63 ++-
 test/mlir/krnl/krnl_category_mapper.mlir      | 111 +++--
 .../onnx/onnx_lowering_category_mapper.mlir   |   8 +-
 7 files changed, 376 insertions(+), 234 deletions(-)

diff --git a/src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp b/src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp
index 7c9fa3499a50..0add8b209f2f 100644
--- a/src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp
+++ b/src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp
@@ -53,47 +53,60 @@ using namespace mlir;
 
 namespace {
 
-static onnx::TensorProto::DataType llvmTypeToOnnxType(mlir::Type elemType) {
-  if (elemType.isa<Float32Type>())
-    return onnx::TensorProto::FLOAT;
-  if (elemType.isUnsignedInteger(8))
-    return onnx::TensorProto::UINT8;
-  if (elemType.isSignedInteger(8))
-    return onnx::TensorProto::INT8;
-  if (elemType.isUnsignedInteger(16))
-    return onnx::TensorProto::UINT16;
-  if (elemType.isSignedInteger(16))
-    return onnx::TensorProto::INT16;
-  if (elemType.isSignedInteger(32))
-    return onnx::TensorProto::INT32;
-  if (elemType.isSignedInteger(64))
-    return onnx::TensorProto::INT64;
-  if (elemType.isa<StringType>())
-    return onnx::TensorProto::STRING;
-  if (elemType.isa<Float16Type>())
-    return onnx::TensorProto::FLOAT16;
-  if (elemType.isa<Float64Type>())
-    return onnx::TensorProto::DOUBLE;
-  if (elemType.isUnsignedInteger(32))
-    return onnx::TensorProto::UINT32;
-  if (elemType.isUnsignedInteger(64))
-    return onnx::TensorProto::INT64;
-  // LLVM Dialect does not have signed/unsigned int, only signless int
-  if (auto llvmIntType = elemType.dyn_cast<IntegerType>()) {
-    if (llvmIntType.getWidth() == 1)
-      return onnx::TensorProto::BOOL;
-    if (llvmIntType.getWidth() == 8)
-      return onnx::TensorProto::INT8;
-    if (llvmIntType.getWidth() == 16)
-      return onnx::TensorProto::INT16;
-    if (llvmIntType.getWidth() == 32)
-      return onnx::TensorProto::INT32;
-    if (llvmIntType.getWidth() == 64)
-      return onnx::TensorProto::INT64;
+// Convert an MLIR type to the corresponding ONNX type.
+static onnx::TensorProto::DataType mlirTypeToOnnxType(mlir::Type elemType) {
+  onnx::TensorProto::DataType onnxType = onnx::TensorProto::UNDEFINED;
+
+  TypeSwitch<Type>(elemType)
+      .Case<mlir::BFloat16Type>(
+          [&](mlir::BFloat16Type) { onnxType = onnx::TensorProto::BFLOAT16; })
+      .Case<mlir::ComplexType>([&](mlir::ComplexType type) {
+        if (type.getElementType().isa<Float32Type>())
+          onnxType = onnx::TensorProto::COMPLEX64;
+        else if (type.getElementType().isa<Float64Type>())
+          onnxType = onnx::TensorProto::COMPLEX128;
+      })
+      .Case<mlir::Float16Type>(
+          [&](mlir::Float16Type) { onnxType = onnx::TensorProto::FLOAT16; })
+      .Case<mlir::Float32Type>(
+          [&](mlir::Float32Type) { onnxType = onnx::TensorProto::FLOAT; })
+      .Case<mlir::Float64Type>(
+          [&](mlir::Float64Type) { onnxType = onnx::TensorProto::DOUBLE; })
+      .Case<mlir::IntegerType>([&](mlir::IntegerType type) {
+        switch (type.getWidth()) {
+        case 1:
+          // Only a signless type can be a bool.
+          onnxType = (type.isSigned() || type.isUnsigned())
+                         ? onnx::TensorProto::UNDEFINED
+                         : onnx::TensorProto::BOOL;
+          break;
+        case 8:
+          onnxType = type.isUnsigned() ? onnx::TensorProto::UINT8
+                                       : onnx::TensorProto::INT8;
+          break;
+        case 16:
+          onnxType = type.isUnsigned() ? onnx::TensorProto::UINT16
+                                       : onnx::TensorProto::INT16;
+          break;
+        case 32:
+          onnxType = type.isUnsigned() ? onnx::TensorProto::UINT32
+                                       : onnx::TensorProto::INT32;
+          break;
+        case 64:
+          onnxType = type.isUnsigned() ? onnx::TensorProto::UINT64
+                                       : onnx::TensorProto::INT64;
+          break;
+        }
+      })
+      .Case<mlir::StringType>(
+          [&](mlir::StringType) { onnxType = onnx::TensorProto::STRING; });
+
+  if (onnxType == onnx::TensorProto::UNDEFINED) {
+    elemType.dump();
+    llvm_unreachable("MLIR type cannot be converted to ONNX type");
   }
-  // Complex types don't seem to exist in LLVM Dialect.
-  elemType.dump();
-  llvm_unreachable("Unexpected LLVM type, cannot be converted to ONNX type.");
+
+  return onnxType;
 }
 
 static FlatSymbolRefAttr getOrInsertExternFunc(StringRef funcName,
@@ -418,92 +431,145 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern {
 
   LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const override {
-    MLIRContext *context = op->getContext();
-    Location loc = op->getLoc();
     auto krnlGlobalOp = llvm::dyn_cast<KrnlGlobalOp>(op);
-    IntegerAttr alignmentAttr = krnlGlobalOp.alignmentAttr();
-
-    ModuleOp module = op->getParentOfType<ModuleOp>();
-    StringRef name = krnlGlobalOp.name();
-
-    // Compute total number of elements.
-    auto shape = (krnlGlobalOp.shape()).dyn_cast<ArrayAttr>();
-    int64_t numElements = 1;
-    for (unsigned int i = 0; i < shape.size(); ++i)
-      numElements *= ArrayAttrIntVal(shape, i);
-
-    // Create the global at the entry of the module.
-    LLVM::GlobalOp global;
-    auto type = op->getResult(0).getType();
-    auto memRefTy = type.cast<MemRefType>();
 
     // The element type of the array.
-    auto constantElementType =
+    const auto type = op->getResult(0).getType();
+    const auto memRefTy = type.cast<mlir::MemRefType>();
+    const auto constantElementType =
         typeConverter->convertType(memRefTy.getElementType());
     auto globalType = constantElementType;
 
-    // The llvm type of the global (example: [2 x [8 x float]])
-    if (shape.empty()) {
+    // The llvm type of the global (example: [2 x [8 x float]]).
+    const auto shape = (krnlGlobalOp.shape()).dyn_cast<ArrayAttr>();
+    if (shape.empty())
       globalType = LLVM::LLVMArrayType::get(globalType.cast<Type>(), 1);
-    } else {
+    else {
       for (int i = shape.size() - 1; i >= 0; i--)
         globalType = LLVM::LLVMArrayType::get(
            globalType.cast<Type>(), ArrayAttrIntVal(shape, i));
     }
-    auto llvmGlobalType = globalType.cast<Type>();
 
-    if (!krnlGlobalOp.value().hasValue())
-      llvm_unreachable("Krnl Global must always have a value");
+    // Create the global at the entry of the module.
+    assert(krnlGlobalOp.value().hasValue() &&
+           "Krnl Global must always have a value");
+    auto value = krnlGlobalOp.value().getValue();
+    LLVM::GlobalOp global;
+    TypeSwitch<Attribute>(value)
+        .Case<OpaqueElementsAttr>([&](OpaqueElementsAttr attr) {
+          global = lowerOpaqueConstant(krnlGlobalOp, globalType, rewriter);
+        })
+        .Case<DenseElementsAttr>([&](DenseElementsAttr attr) {
+          global = lowerDenseConstant(krnlGlobalOp, globalType, rewriter);
+        })
+        .Default([&](Attribute attr) {
+          llvm_unreachable("Unsupported attribute type");
+        });
+
+    // Set the global alignment based on the alignment attribute if it exists,
+    // otherwise use the module datalayout info.
+    setAlignment(global, krnlGlobalOp.alignmentAttr(),
+        krnlGlobalOp->getParentOfType<ModuleOp>(), rewriter);
+
+    // Prepare data to be inserted into a MemRefDescriptor (a struct).
+    Value globalOpAddr =
+        rewriter.create<LLVM::AddressOfOp>(krnlGlobalOp.getLoc(), global);
+    MemRefDescriptor memRefDescr = createMemRefDescriptor(
+        globalOpAddr, memRefTy, krnlGlobalOp.getLoc(), rewriter);
+
+    rewriter.replaceOp(op, {memRefDescr});
+
+    return success();
+  }
+
+private:
+  static int64_t ArrayAttrIntVal(ArrayAttr a, int i) {
+    return (a.getValue()[i]).cast<IntegerAttr>().getInt();
+  }
+
+  // LLVM::GlobalOp does not support OpaqueElementsAttr.
+  // Both StringAttr and OpaqueElementsAttr use StringRef for the internal
+  // data array, thus it looks safe to use StringAttr instead of
+  // OpaqueElementsAttr.
+  LLVM::GlobalOp lowerOpaqueConstant(KrnlGlobalOp &krnlGlobalOp,
+      Type globalType, ConversionPatternRewriter &rewriter) const {
+    assert(krnlGlobalOp.value().hasValue() &&
+           "Expecting KrnlGlobalOp with a valid value");
+    assert(krnlGlobalOp.value().getValue().isa<OpaqueElementsAttr>() &&
+           "Expecting a global with an opaque elements attribute");
+
+    MLIRContext *context = krnlGlobalOp.getContext();
+    Location loc = krnlGlobalOp.getLoc();
+    ModuleOp module = krnlGlobalOp->getParentOfType<ModuleOp>();
+
+    OpBuilder::InsertionGuard insertGuard(rewriter);
+    rewriter.setInsertionPointToStart(module.getBody());
+
+    StringRef data =
+        krnlGlobalOp.value().getValue().cast<OpaqueElementsAttr>().getValue();
+    // Check data size.
+    int64_t sizeInBytes = computeSizeInBytes(krnlGlobalOp);
+    assert(((int64_t)data.size() == sizeInBytes) && "Data size mismatch.");
+
+    StringAttr llvmStringAttr = StringAttr::get(context, data);
+    auto llvmArrayI8Ty =
+        LLVM::LLVMArrayType::get(IntegerType::get(context, 8), sizeInBytes);
+    LLVM::GlobalOp global = rewriter.create<LLVM::GlobalOp>(loc, llvmArrayI8Ty,
+        /*isConstant=*/true, LLVM::Linkage::Internal, krnlGlobalOp.name(),
+        llvmStringAttr);
 
-    int64_t sizeInBytes = numElements * getMemRefEltSizeInBytes(memRefTy);
-    {
-      OpBuilder::InsertionGuard insertGuard(rewriter);
-      rewriter.setInsertionPointToStart(module.getBody());
+    LLVM_DEBUG(llvm::dbgs() << "global: " << global << "\n";);
+    return global;
+  }
+
+  LLVM::GlobalOp lowerDenseConstant(KrnlGlobalOp &krnlGlobalOp, Type globalType,
+      ConversionPatternRewriter &rewriter) const {
+    assert(krnlGlobalOp.value().hasValue() &&
+           "Expecting KrnlGlobalOp with a valid value");
+    assert(krnlGlobalOp.value().getValue().isa<DenseElementsAttr>() &&
+           "Expecting a global with a dense elements attribute");
+
+    MLIRContext *context = krnlGlobalOp.getContext();
+    Location loc = krnlGlobalOp.getLoc();
+    ModuleOp module = krnlGlobalOp->getParentOfType<ModuleOp>();
+
+    OpBuilder::InsertionGuard insertGuard(rewriter);
+    rewriter.setInsertionPointToStart(module.getBody());
+
+    DenseElementsAttr denseAttr =
+        krnlGlobalOp.value().getValue().cast<DenseElementsAttr>();
+
+    int64_t sizeInBytes = computeSizeInBytes(krnlGlobalOp);
+    LLVM::GlobalOp global;
+    if ((!denseAttr.isSplat()) && (sizeInBytes > 1024)) {
+      ArrayRef<char> rawData = denseAttr.getRawData();
+      assert(((int64_t)rawData.size() == sizeInBytes) && "Data size mismatch.");
+
+      StringRef data(rawData.data(), rawData.size());
+      StringAttr llvmStringAttr = StringAttr::get(context, data);
       auto llvmArrayI8Ty =
           LLVM::LLVMArrayType::get(IntegerType::get(context, 8), sizeInBytes);
-      if (krnlGlobalOp.value().getValue().isa<OpaqueElementsAttr>()) {
-        // LLVM::GlobalOp does not support OpaqueElementsAttr.
-        // Both StringAttr and OpaqueElementsAttr use StringRef for internal
-        // data array. Thus, it looks safe to use StringAtrr instead of
-        // OpaqueElementsAttr.
-        StringRef data = krnlGlobalOp.value()
-                             .getValue()
-                             .cast<OpaqueElementsAttr>()
-                             .getValue();
-        // Check data size.
-        assert(((int64_t)data.size() == sizeInBytes) && "Data size mismatch.");
-
-        StringAttr llvmStringAttr = StringAttr::get(context, data);
-        global = rewriter.create<LLVM::GlobalOp>(loc, llvmArrayI8Ty,
-            /*isConstant=*/true, LLVM::Linkage::Internal, name, llvmStringAttr);
-      } else if (krnlGlobalOp.value().getValue().isa<DenseElementsAttr>()) {
-        DenseElementsAttr denseAttr =
-            krnlGlobalOp.value().getValue().cast<DenseElementsAttr>();
-        if ((!denseAttr.isSplat()) && (sizeInBytes > 1024)) {
-          std::vector<char> rawData = denseAttr.getRawData();
-          // Check data size.
-          assert(((int64_t)rawData.size() == sizeInBytes) &&
-                 "Data size mismatch.");
-
-          StringRef data = StringRef((char *)rawData.data(), rawData.size());
-          StringAttr llvmStringAttr = StringAttr::get(context, data);
-          global = rewriter.create<LLVM::GlobalOp>(loc, llvmArrayI8Ty,
-              /*isConstant=*/true, LLVM::Linkage::Internal, name,
-              llvmStringAttr);
-        } else {
-          global = rewriter.create<LLVM::GlobalOp>(loc, llvmGlobalType,
-              /*isConstant=*/true, LLVM::Linkage::Internal, name,
-              krnlGlobalOp.value().getValue());
-        }
-      } else
-        llvm_unreachable("Unsupported attribute type");
+      global = rewriter.create<LLVM::GlobalOp>(loc, llvmArrayI8Ty,
+          /*isConstant=*/true, LLVM::Linkage::Internal, krnlGlobalOp.name(),
+          llvmStringAttr);
+    } else {
+      if (denseAttr.getElementType().isa<StringType>())
+        global = lowerStringLiteral(krnlGlobalOp, globalType, rewriter);
+      else
+        global = rewriter.create<LLVM::GlobalOp>(loc, globalType,
+            /*isConstant=*/true, LLVM::Linkage::Internal, krnlGlobalOp.name(),
+            krnlGlobalOp.value().getValue());
     }
 
-    // If the operation has a valid alignment attribute use it, otherwise
-    // attempt to set the alignment based on the module datalayout (if it
-    // exists).
+    // LLVM_DEBUG(llvm::dbgs() << "global: " << global << "\n";);
+    return global;
+  }
+
+  // If the operation has a valid alignment attribute use it, otherwise
+  // attempt to set the alignment based on the module datalayout (if it
+  // exists).
+  void setAlignment(LLVM::GlobalOp &global, IntegerAttr alignmentAttr,
+      ModuleOp module, OpBuilder &builder) const {
     if (alignmentAttr && alignmentAttr.getValue().getSExtValue() != 0)
      global.setAlignmentAttr(alignmentAttr);
    else if (module->getAttr(LLVM::LLVMDialect::getDataLayoutAttrName())) {
@@ -513,23 +579,22 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern {
           .getPreferredAlignment(global.getType(),
               getTypeConverter()->getDataLayout());
       align = std::max(align, MinGlobalAlign);
-      global.setAlignmentAttr(rewriter.getI64IntegerAttr(align));
+      global.setAlignmentAttr(builder.getI64IntegerAttr(align));
     } else
-      global.setAlignmentAttr(rewriter.getI64IntegerAttr(MinGlobalAlign));
-
-    // Prepare data to be inserted into a MemRefDescriptor (a struct).
-    Value globalOpAddr = rewriter.create<LLVM::AddressOfOp>(loc, global);
-    MemRefDescriptor memRefDescr =
-        createMemRefDescriptor(globalOpAddr, memRefTy, loc, rewriter);
+      global.setAlignmentAttr(builder.getI64IntegerAttr(MinGlobalAlign));
+  }
 
-    rewriter.replaceOp(op, {memRefDescr});
+  int64_t computeSizeInBytes(KrnlGlobalOp &krnlGlobalOp) const {
+    // Compute total number of elements.
+    const auto shape = (krnlGlobalOp.shape()).dyn_cast<ArrayAttr>();
+    int64_t numElements = 1;
+    for (unsigned int i = 0; i < shape.size(); ++i)
+      numElements *= ArrayAttrIntVal(shape, i);
 
-    return success();
-  }
+    const auto type = krnlGlobalOp.getResult().getType();
+    const auto memRefTy = type.cast<mlir::MemRefType>();
 
-private:
-  static int64_t ArrayAttrIntVal(ArrayAttr a, int i) {
-    return (a.getValue()[i]).cast<IntegerAttr>().getInt();
+    return numElements * getMemRefEltSizeInBytes(memRefTy);
   }
 
   // Store the given address into a MemRefDescriptor (a struct).
@@ -550,48 +615,64 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern {
 
   // Generate a global string for each krnlGlobalOp string value, and store
   // the address of the global strings into an array. Return the array address.
-  Value lowerStringLiteral(
-      KrnlGlobalOp &krnlGlobalOp, OpBuilder &builder) const {
+  LLVM::GlobalOp lowerStringLiteral(
+      KrnlGlobalOp &krnlGlobalOp, Type globalType, OpBuilder &builder) const {
     assert(krnlGlobalOp.value().getValue().isa<DenseElementsAttr>() &&
            "Expecting a dense value");
 
     Location loc = krnlGlobalOp.getLoc();
     ModuleOp module = krnlGlobalOp->getParentOfType<ModuleOp>();
 
-    DenseElementsAttr value =
+    DenseElementsAttr denseAttr =
         krnlGlobalOp.value().getValue().cast<DenseElementsAttr>();
 
     Type i8Type = IntegerType::get(builder.getContext(), 8);
-    Type i32Type = IntegerType::get(builder.getContext(), 32);
     Type i8PtrType = LLVM::LLVMPointerType::get(i8Type);
 
-    // Generate LLVM GlobalOps for each string in the KrnlGlobalOp dense value.
+    int64_t numStrings = denseAttr.getValues<StringRef>().size();
+    if (numStrings == 1) {
+      StringRef str = *denseAttr.getValues<StringRef>().begin();
+      LLVM::GlobalOp global =
+          getOrCreateGlobalString(str, loc, builder, module);
+
+      // return builder.create<LLVM::GlobalOp>(loc, globalType,
+      //    /*isConstant=*/true, LLVM::Linkage::Internal,
+      //    krnlGlobalOp.name(),
+      //    StringAttr::get(builder.getContext(), str));
+      return global;
+    }
+
+    // Generate LLVM GlobalOps for each string in the KrnlGlobalOp dense
+    // attribute.
     SmallVector<LLVM::GlobalOp> globalOps;
-    for (StringRef str : value.getValues<StringRef>()) {
+    for (StringRef str : denseAttr.getValues<StringRef>()) {
       LLVM::GlobalOp globalOp =
           getOrCreateGlobalString(str, loc, builder, module);
       globalOps.push_back(globalOp);
     }
 
-    // Allocate memory for an array (to hold the address of the global strings).
-    auto cstNumElems = builder.create<LLVM::ConstantOp>(
-        loc, i32Type, builder.getI32IntegerAttr(globalOps.size()));
-    Value alloca = builder.create<LLVM::AllocaOp>(loc,
-        LLVM::LLVMPointerType::get(i8PtrType), cstNumElems, /*alignment=*/16);
+    // Generate an LLVM GlobalOp with an initializer region containing one
+    // block.
+    auto arrayType = LLVM::LLVMArrayType::get(i8PtrType, globalOps.size());
+    auto global = builder.create<LLVM::GlobalOp>(loc, arrayType,
+        /*isConstant=*/true, LLVM::Linkage::Internal, krnlGlobalOp.name(),
+        Attribute());
+    Region &region = global.getInitializerRegion();
+    Block *block = builder.createBlock(&region);
+
+    // Initialize an array with the addresses of the global strings.
+    builder.setInsertionPoint(block, block->begin());
+    Value array = builder.create<LLVM::UndefOp>(loc, arrayType);
 
-    // Store the address of the global strings into the array.
     int32_t index = 0;
+    Value lastValue = array;
     for (const LLVM::GlobalOp &globalOp : globalOps) {
       LLVM::GEPOp strAddr = getPtrToGlobalString(globalOp, loc, builder);
-      Type llvmIndexType = typeConverter->convertType(builder.getIndexType());
-      Value cstIndex = builder.create<LLVM::ConstantOp>(
-          loc, llvmIndexType, builder.getIndexAttr(index++));
-      LLVM::GEPOp arrayElemAddr = builder.create<LLVM::GEPOp>(loc,
-          LLVM::LLVMPointerType::get(i8PtrType), alloca,
-          ArrayRef<Value>{cstIndex});
-      builder.create<LLVM::StoreOp>(loc, strAddr, arrayElemAddr);
+      lastValue = builder.create<LLVM::InsertValueOp>(loc, arrayType, lastValue,
+          strAddr, builder.getArrayAttr({builder.getIndexAttr(index++)}));
     }
 
-    return alloca;
+    builder.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({lastValue}));
+    return global;
   }
 
   // Return the GlobalOp for the given string, creating one if not found.
@@ -606,8 +687,9 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern {
     Type i8Type = IntegerType::get(builder.getContext(), 8);
     Type type = LLVM::LLVMArrayType::get(i8Type, str.size());
     global = builder.create<LLVM::GlobalOp>(loc, type, /*isConstant=*/true,
-        LLVM::Linkage::Internal, str, builder.getStringAttr(str),
-        /*alignment=*/0);
+        LLVM::Linkage::Internal, str, builder.getStringAttr(str));
+
+    setAlignment(global, nullptr, module, builder);
   }
   return global;
 }
@@ -970,8 +1052,8 @@ class KrnlUnaryMathOpLowering : public ConversionPattern {
 
   LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const override {
-    auto *context = op->getContext();
-    auto loc = op->getLoc();
+    MLIRContext *context = op->getContext();
+    Location loc = op->getLoc();
 
     // get the LLVM type for the function args and result
     mlir::Type inType = op->getOperand(0).getType();
@@ -1142,13 +1224,13 @@ class KrnlEntryPointOpLowering : public OpRewritePattern<KrnlEntryPointOp> {
     SmallVector<Value, 4> staticInputs;
     auto wrappedInput = entryPointEntryBlock.getArgument(0);
 
-    auto omTensorPtrArr =
+    Value omTensorPtrArr =
         callApi(rewriter, loc, apiRegistry, API::GET_OMT_ARRAY, {wrappedInput});
     auto one = rewriter.create<LLVM::ConstantOp>(
         loc, int64Ty, rewriter.getI64IntegerAttr(1));
 
     // Create a memref type for the return argument of the iface call
-    auto memRefOutPtrTy = staticEntryPointTy.getParamType(0);
+    Type memRefOutPtrTy = staticEntryPointTy.getParamType(0);
     Value ptrToOutMemRef = rewriter.create<LLVM::AllocaOp>(loc, memRefOutPtrTy,
         one,
         /*alignment=*/0);
@@ -1244,7 +1326,7 @@ class KrnlEntryPointOpLowering : public OpRewritePattern<KrnlEntryPointOp> {
     auto outMemRefRank = getRankFromMemRefType(outMemRefTy);
     auto outMemRefRankVal = rewriter.create<LLVM::ConstantOp>(
         loc, int64Ty, rewriter.getI64IntegerAttr(outMemRefRank));
-    auto outOMTensor = callApi(
+    Value outOMTensor = callApi(
         rewriter, loc, apiRegistry, API::CREATE_OMTENSOR, {outMemRefRankVal});
     // If output is a constant tensor, OMTensor does not own it.
     auto outOwning = constantOutputs[i] ? 0 : 1;
@@ -1266,7 +1348,7 @@ class KrnlEntryPointOpLowering : public OpRewritePattern<KrnlEntryPointOp> {
     }
 
     // Create wrapped output.
-    auto wrappedOutput = callApi(rewriter, loc, apiRegistry,
+    Value wrappedOutput = callApi(rewriter, loc, apiRegistry,
         API::CREATE_OMTENSOR_LIST, {outOmtPtrsArr, numOutput, one});
 
     // Return wrapped output.
@@ -1365,7 +1447,7 @@ class KrnlEntryPointOpLowering : public OpRewritePattern<KrnlEntryPointOp> {
     Value memRef = rewriter.create<LLVM::UndefOp>(loc, memRefTy);
 
     // Set dataPtr and alignedDataPtr;
-    auto dataPtr =
+    Value dataPtr =
         callApi(rewriter, loc, apiRegistry, API::GET_DATA, {rtMemRef});
     dataPtr = rewriter.create<LLVM::BitcastOp>(
         loc, memRefTy.cast<LLVM::LLVMStructType>().getBody()[0], dataPtr);
@@ -1382,9 +1464,9 @@ class KrnlEntryPointOpLowering : public OpRewritePattern<KrnlEntryPointOp> {
 
     // Get rank, sizes array ptr and strides array ptr.
     auto rank = getRankFromMemRefType(memRefTy.cast<LLVM::LLVMStructType>());
-    auto sizesArrayPtr =
+    Value sizesArrayPtr =
         callApi(rewriter, loc, apiRegistry, API::GET_DATA_SHAPE, {rtMemRef});
-    auto stridesArrayPtr =
+    Value stridesArrayPtr =
         callApi(rewriter, loc, apiRegistry, API::GET_DATA_STRIDES, {rtMemRef});
 
     for (decltype(rank) i = 0; i < rank; i++) {
@@ -1447,18 +1529,25 @@ class KrnlEntryPointOpLowering : public OpRewritePattern<KrnlEntryPointOp> {
     callApi(rewriter, loc, apiRegistry, API::SET_DATA,
         {outOMTensor, owning, outMemRefAllocatedPtr, outMemRefAlignedPtr});
 
-    auto elemTy =
+    Type elemTy =
         outMemRefTy.getBody()[0].cast<LLVM::LLVMPointerType>().getElementType();
-    auto onnxTy = llvmTypeToOnnxType(elemTy);
+
+    if (auto structType = elemTy.dyn_cast_or_null<LLVM::LLVMStructType>()) {
+      elemTy = structType.getBody()[0]
+                   .cast<LLVM::LLVMPointerType>()
+                   .getElementType();
+    }
+
+    onnx::TensorProto::DataType onnxTy = mlirTypeToOnnxType(elemTy);
     auto onnxTyVal = rewriter.create<LLVM::ConstantOp>(
         loc, int64Ty, rewriter.getI64IntegerAttr(onnxTy));
     callApi(rewriter, loc, apiRegistry, API::SET_DATA_TYPE,
         {outOMTensor, onnxTyVal});
 
-    auto rank = getRankFromMemRefType(outMemRefTy);
-    auto sizesArrayPtr =
+    int64_t rank = getRankFromMemRefType(outMemRefTy);
+    Value sizesArrayPtr =
         callApi(rewriter, loc, apiRegistry, API::GET_DATA_SHAPE, {outOMTensor});
-    auto stridesArrayPtr = callApi(
+    Value stridesArrayPtr = callApi(
         rewriter, loc, apiRegistry, API::GET_DATA_STRIDES, {outOMTensor});
 
     for (decltype(rank) i = 0; i < rank; i++) {
@@ -1682,7 +1771,7 @@ class KrnlFindIndexOpLowering : public ConversionPattern {
     // Get a symbol reference to the runtime function to use, creating one if
     // necessary.
     ModuleOp parentModule = findIndexOp->getParentOfType<ModuleOp>();
-    auto FindIndexRef = getOrInsertFindIndex(
+    FlatSymbolRefAttr findIndexRef = getOrInsertFindIndex(
         rewriter, parentModule, findIndexOp.input().getType());
 
     // Select the value to pass as the first argument based on the operator
@@ -1716,8 +1805,8 @@ class KrnlFindIndexOpLowering : public ConversionPattern {
     Value length = operandAdaptor.len();
 
     // Generate the call to the runtime function.
-    Type retType = IntegerType::get(context, 32);
-    auto funcCall = rewriter.create<LLVM::CallOp>(loc, FindIndexRef, retType,
+    Type retType = IntegerType::get(context, 64);
+    auto funcCall = rewriter.create<LLVM::CallOp>(loc, findIndexRef, retType,
         ArrayRef<Value>({firstOperand, extractedGPtr, extractedVPtr, length}));
 
     rewriter.replaceOp(op, funcCall.getResults()[0]);
@@ -1756,8 +1845,8 @@ class KrnlFindIndexOpLowering : public ConversionPattern {
     if (optFuncDecl.hasValue())
       return optFuncDecl.getValue();
 
-    // Create 'find_index_*' signature: `i32 ([i8*|i64], i32*, i32*, i32)`
-    Type fnType = LLVM::LLVMFunctionType::get(i32Type,
+    // Create 'find_index_*' signature: `i64 ([i8*|i64], i32*, i32*, i32)`
+    Type fnType = LLVM::LLVMFunctionType::get(i64Type,
        ArrayRef<Type>({firstArgType, i32PtrType, i32PtrType, i32Type}), false);
 
     // Insert the function declaration into the module.
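Note: for readers unfamiliar with the runtime side, the find_index_str/find_index_i64 helpers declared above implement a two-level, displacement-based perfect-hash lookup; the actual implementation is in src/Runtime/OMIndexLookup.inc, changed later in this patch. The following is a minimal C++ sketch of the scheme only, with an illustrative hashKey function standing in for the runtime's hash_string/hash_int64 (which differ):

    #include <cassert>
    #include <cstdint>
    #include <string>
    #include <vector>

    // Illustrative stand-in for the runtime's hash; any deterministic hash
    // over (seed, key) preserves the scheme.
    static uint32_t hashKey(uint32_t seed, const std::string &key) {
      uint32_t h = (seed == 0) ? 2166136261u : seed; // FNV-1a style
      for (unsigned char c : key)
        h = (h ^ c) * 16777619u;
      return h;
    }

    // Two-level lookup: G holds one displacement per bucket, V holds values.
    // A negative displacement d selects the direct slot V[-d - 1]; otherwise
    // d re-seeds the hash to pick the final slot. Keys that were in the
    // dictionary when G/V were built land in their unique slot; unknown keys
    // return some in-range index, so callers must validate the hit.
    static int64_t findIndex(const std::string &key,
        const std::vector<int32_t> &G, const std::vector<int32_t> &V) {
      const uint32_t size = static_cast<uint32_t>(G.size());
      int32_t d = G[hashKey(0, key) % size];
      int64_t index = (d < 0) ? V[-d - 1] : V[hashKey(d, key) % size];
      assert(index >= 0 && index < static_cast<int64_t>(V.size()));
      return index;
    }

The i32-to-i64 widening of the return type above keeps the runtime result in sync with the lowered result of krnl.find_index, whose index type converts to i64 (the tests below drop the former unrealized_conversion_cast accordingly).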
diff --git a/src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp b/src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp
index e419d2ef741c..9a7cb89a9959 100644
--- a/src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp
+++ b/src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp
@@ -175,8 +175,8 @@ struct ONNXCategoryMapperOpLowering : public ConversionPattern {
     MemRefType type = MemRefType::get(
         {static_cast<int64_t>(V.size())}, builder.getIntegerType(32));
-    res.G = create.krnl.constant(type, "G", builder.getI32VectorAttr(G));
-    res.V = create.krnl.constant(type, "V", builder.getI32VectorAttr(V));
+    res.G = create.krnl.constant(type, "G", builder.getI32TensorAttr(G));
+    res.V = create.krnl.constant(type, "V", builder.getI32TensorAttr(V));
     res.len = create.math.constant(builder.getIntegerType(32), G.size());
     return res;
   };
diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp
index bf2af02071ae..07077fcdc3d2 100644
--- a/src/Dialect/ONNX/ONNXOps.cpp
+++ b/src/Dialect/ONNX/ONNXOps.cpp
@@ -4716,9 +4716,6 @@ static LogicalResult verify(ONNXCategoryMapperOp op) {
     return op.emitError("'default_string' attribute is missing.");
   if (elementType.isa<StringType>() && !op.default_int64Attr())
     return op.emitError("'default_int64' attribute is missing.");
-  if (op.default_stringAttr() && op.default_int64Attr())
-    return op.emitError("Only one of 'default_int64' or 'default_string' "
-                        "attributes must be specified");
 
   return success();
 }
diff --git a/src/Runtime/OMIndexLookup.inc b/src/Runtime/OMIndexLookup.inc
index eba837aae11f..c4671e00b695 100644
--- a/src/Runtime/OMIndexLookup.inc
+++ b/src/Runtime/OMIndexLookup.inc
@@ -45,12 +45,12 @@ static inline uint32_t hash_int64(uint32_t hval, int64_t val) {
 #ifdef __cplusplus
 extern "C"
 #endif
-    uint32_t
+    uint64_t
     find_index_str(
         const char *str, int32_t G[], int32_t V[], int32_t dictSize) {
   assert(str && G && V && dictSize > 0);
   int32_t d = G[hash_string(0, str) % dictSize];
-  int32_t index = (d < 0) ? V[-d - 1] : V[hash_string(d, str) % dictSize];
+  int64_t index = (d < 0) ? V[-d - 1] : V[hash_string(d, str) % dictSize];
   assert(index >= 0 && index < dictSize);
   return index;
 }
@@ -62,11 +62,11 @@ extern "C"
 #ifdef __cplusplus
 extern "C"
 #endif
-    uint32_t
+    uint64_t
     find_index_i64(int64_t val, int32_t G[], int32_t V[], int32_t dictSize) {
   assert(G && V && dictSize > 0);
   int32_t d = G[hash_int64(0, val) % dictSize];
-  int32_t index = (d < 0) ? V[-d - 1] : V[hash_int64(d, val) % dictSize];
+  int64_t index = (d < 0) ? V[-d - 1] : V[hash_int64(d, val) % dictSize];
   assert(index >= 0 && index < dictSize);
   return index;
 }
diff --git a/src/Runtime/PyExecutionSession.cpp b/src/Runtime/PyExecutionSession.cpp
index 3594eafb3dd8..318da18db4ae 100644
--- a/src/Runtime/PyExecutionSession.cpp
+++ b/src/Runtime/PyExecutionSession.cpp
@@ -62,6 +62,7 @@ std::vector<py::array> PyExecutionSession::pyRun(
       dtype = ONNX_TYPE_INT32;
     else if (py::isinstance<py::array_t<std::int64_t>>(inputPyArray))
       dtype = ONNX_TYPE_INT64;
+    // string type missing
     else if (py::isinstance<py::array_t<bool>>(inputPyArray))
       dtype = ONNX_TYPE_BOOL;
     // Missing fp16 support.
@@ -71,6 +72,11 @@ std::vector<py::array> PyExecutionSession::pyRun(
       dtype = ONNX_TYPE_UINT32;
     else if (py::isinstance<py::array_t<std::uint64_t>>(inputPyArray))
       dtype = ONNX_TYPE_UINT64;
+    else if (py::isinstance<py::array_t<std::complex<float>>>(inputPyArray))
+      dtype = ONNX_TYPE_COMPLEX64;
+    else if (py::isinstance<py::array_t<std::complex<double>>>(inputPyArray))
+      dtype = ONNX_TYPE_COMPLEX128;
+    // Missing bfloat16 support
     else {
       std::cerr << "Numpy type not supported: " << inputPyArray.dtype()
                 << ".\n";
@@ -97,38 +103,55 @@ std::vector<py::array> PyExecutionSession::pyRun(
     // https://numpy.org/devdocs/user/basics.types.html
     py::dtype dtype;
-    if (omTensorGetDataType(omt) == (OM_DATA_TYPE)onnx::TensorProto::FLOAT)
+    switch (omTensorGetDataType(omt)) {
+    case (OM_DATA_TYPE)onnx::TensorProto::FLOAT:
       dtype = py::dtype("float32");
-    else if (omTensorGetDataType(omt) == (OM_DATA_TYPE)onnx::TensorProto::UINT8)
+      break;
+    case (OM_DATA_TYPE)onnx::TensorProto::UINT8:
       dtype = py::dtype("uint8");
-    else if (omTensorGetDataType(omt) == (OM_DATA_TYPE)onnx::TensorProto::INT8)
+      break;
+    case (OM_DATA_TYPE)onnx::TensorProto::INT8:
       dtype = py::dtype("int8");
-    else if (omTensorGetDataType(omt) ==
-             (OM_DATA_TYPE)onnx::TensorProto::UINT16)
+      break;
+    case (OM_DATA_TYPE)onnx::TensorProto::UINT16:
       dtype = py::dtype("uint16");
-    else if (omTensorGetDataType(omt) == (OM_DATA_TYPE)onnx::TensorProto::INT16)
+      break;
+    case (OM_DATA_TYPE)onnx::TensorProto::INT16:
       dtype = py::dtype("int16");
-    else if (omTensorGetDataType(omt) == (OM_DATA_TYPE)onnx::TensorProto::INT32)
+      break;
+    case (OM_DATA_TYPE)onnx::TensorProto::INT32:
       dtype = py::dtype("int32");
-    else if (omTensorGetDataType(omt) == (OM_DATA_TYPE)onnx::TensorProto::INT64)
+      break;
+    case (OM_DATA_TYPE)onnx::TensorProto::INT64:
       dtype = py::dtype("int64");
-    // TODO(tjingrant) wait for Tong's input for how to represent string.
-    else if (omTensorGetDataType(omt) == (OM_DATA_TYPE)onnx::TensorProto::BOOL)
+      break;
+    case (OM_DATA_TYPE)onnx::TensorProto::STRING:
+      dtype = py::dtype("str");
+      break;
+    case (OM_DATA_TYPE)onnx::TensorProto::BOOL:
       dtype = py::dtype("bool_");
-    else if (omTensorGetDataType(omt) ==
-             (OM_DATA_TYPE)onnx::TensorProto::FLOAT16)
+      break;
+    case (OM_DATA_TYPE)onnx::TensorProto::FLOAT16:
       dtype = py::dtype("float32");
-    else if (omTensorGetDataType(omt) ==
-             (OM_DATA_TYPE)onnx::TensorProto::DOUBLE)
+      break;
+    case (OM_DATA_TYPE)onnx::TensorProto::DOUBLE:
       dtype = py::dtype("float64");
-    else if (omTensorGetDataType(omt) ==
-             (OM_DATA_TYPE)onnx::TensorProto::UINT32)
+      break;
+    case (OM_DATA_TYPE)onnx::TensorProto::UINT32:
       dtype = py::dtype("uint32");
-    else if (omTensorGetDataType(omt) ==
-             (OM_DATA_TYPE)onnx::TensorProto::UINT64)
+      break;
+    case (OM_DATA_TYPE)onnx::TensorProto::UINT64:
       dtype = py::dtype("uint64");
-    else {
-      fprintf(stderr, "Unsupported ONNX type in OMTensor.");
+      break;
+    case (OM_DATA_TYPE)onnx::TensorProto::COMPLEX64:
+      dtype = py::dtype("csingle");
+      break;
+    case (OM_DATA_TYPE)onnx::TensorProto::COMPLEX128:
+      dtype = py::dtype("cdouble");
+      break;
+    default:
+      std::cerr << "Unsupported ONNX type in OMTensor: "
+                << omTensorGetDataType(omt) << ".\n";
       exit(1);
     }
diff --git a/test/mlir/krnl/krnl_category_mapper.mlir b/test/mlir/krnl/krnl_category_mapper.mlir
index 6c78a3c6e43d..b70f3318b3b1 100644
--- a/test/mlir/krnl/krnl_category_mapper.mlir
+++ b/test/mlir/krnl/krnl_category_mapper.mlir
@@ -4,47 +4,45 @@
 
 // Test that 'krnl.find_index' can be called when the first argument is a string.
 func private @test_find_index_str(%str: !krnl.string) -> index {
-  %G = "krnl.global"() {name = "G", shape = [3], value = dense<[1,0,-3]> : vector<3xi32>} : () -> memref<3xi32>
-  %V = "krnl.global"() {name = "V", shape = [3], value = dense<[1,2,0]> : vector<3xi32>} : () -> memref<3xi32>
+  %G = "krnl.global"() {name = "G", shape = [3], value = dense<[1,0,-3]> : tensor<3xi32>} : () -> memref<3xi32>
+  %V = "krnl.global"() {name = "V", shape = [3], value = dense<[1,2,0]> : tensor<3xi32>} : () -> memref<3xi32>
   %c3 = arith.constant 3 : i32
   %index = "krnl.find_index"(%str, %G, %V, %c3) : (!krnl.string, memref<3xi32>, memref<3xi32>, i32) -> index
   return %index : index
 }
 
-// CHECK-DAG: llvm.func @find_index_str(!llvm.ptr<i8>, !llvm.ptr<i32>, !llvm.ptr<i32>, i32) -> i32
-// CHECK-DAG: llvm.mlir.global internal constant @V(dense<[1, 2, 0]> : vector<3xi32>) {alignment = 16 : i64} : !llvm.array<3 x i32>
-// CHECK-DAG: llvm.mlir.global internal constant @G(dense<[1, 0, -3]> : vector<3xi32>) {alignment = 16 : i64} : !llvm.array<3 x i32>
+// CHECK-DAG: llvm.func @find_index_str(!llvm.ptr<i8>, !llvm.ptr<i32>, !llvm.ptr<i32>, i32) -> i64
+// CHECK-DAG: llvm.mlir.global internal constant @V(dense<[1, 2, 0]> : tensor<3xi32>) {alignment = 16 : i64} : !llvm.array<3 x i32>
+// CHECK-DAG: llvm.mlir.global internal constant @G(dense<[1, 0, -3]> : tensor<3xi32>) {alignment = 16 : i64} : !llvm.array<3 x i32>
 
 // CHECK-LABEL: @test_find_index_str(%arg0: !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>) -> i64
 // CHECK-DAG:   [[LEN:%.+]] = llvm.mlir.constant(3 : i32) : i32
 // CHECK-DAG:   [[STR:%.+]] = llvm.extractvalue %arg0[1] : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
 // CHECK-DAG:   [[G:%.+]] = llvm.extractvalue {{.*}}[1] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
 // CHECK-DAG:   [[V:%.+]] = llvm.extractvalue {{.*}}[1] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
-// CHECK:       [[INDEX:%.+]] = llvm.call @find_index_str([[STR]], [[G]], [[V]], [[LEN]]) : (!llvm.ptr<i8>, !llvm.ptr<i32>, !llvm.ptr<i32>, i32) -> i32
-// CHECK:       [[UNR_CONV:%.+]] = builtin.unrealized_conversion_cast [[INDEX]] : i32 to i64
-// CHECK:       llvm.return [[UNR_CONV]] : i64
+// CHECK:       [[INDEX:%.+]] = llvm.call @find_index_str([[STR]], [[G]], [[V]], [[LEN]]) : (!llvm.ptr<i8>, !llvm.ptr<i32>, !llvm.ptr<i32>, i32) -> i64
+// CHECK:       llvm.return [[INDEX]] : i64
 
 // -----
 
 // Test that 'krnl.find_index' can be called when the first argument is an int64_t.
 func private @test_find_index_int(%val: i64) -> index {
-  %G = "krnl.global"() {name = "G", shape = [3], value = dense<[1,0,-3]> : vector<3xi32>} : () -> memref<3xi32>
-  %V = "krnl.global"() {name = "V", shape = [3], value = dense<[1,2,0]> : vector<3xi32>} : () -> memref<3xi32>
+  %G = "krnl.global"() {name = "G", shape = [3], value = dense<[1,0,-3]> : tensor<3xi32>} : () -> memref<3xi32>
+  %V = "krnl.global"() {name = "V", shape = [3], value = dense<[1,2,0]> : tensor<3xi32>} : () -> memref<3xi32>
   %c3 = arith.constant 3 : i32
   %index = "krnl.find_index"(%val, %G, %V, %c3) : (i64, memref<3xi32>, memref<3xi32>, i32) -> index
   return %index : index
 
-// CHECK-DAG: llvm.func @find_index_i64(i64, !llvm.ptr<i32>, !llvm.ptr<i32>, i32) -> i32
-// CHECK-DAG: llvm.mlir.global internal constant @V(dense<[1, 2, 0]> : vector<3xi32>) {alignment = 16 : i64} : !llvm.array<3 x i32>
-// CHECK-DAG: llvm.mlir.global internal constant @G(dense<[1, 0, -3]> : vector<3xi32>) {alignment = 16 : i64} : !llvm.array<3 x i32>
+// CHECK-DAG: llvm.func @find_index_i64(i64, !llvm.ptr<i32>, !llvm.ptr<i32>, i32) -> i64
+// CHECK-DAG: llvm.mlir.global internal constant @V(dense<[1, 2, 0]> : tensor<3xi32>) {alignment = 16 : i64} : !llvm.array<3 x i32>
+// CHECK-DAG: llvm.mlir.global internal constant @G(dense<[1, 0, -3]> : tensor<3xi32>) {alignment = 16 : i64} : !llvm.array<3 x i32>
 
 // CHECK-LABEL: llvm.func @test_find_index_int(%arg0: i64) -> i64
 // CHECK-DAG:   [[LEN:%.+]] = llvm.mlir.constant(3 : i32) : i32
 // CHECK-DAG:   [[G:%.+]] = llvm.extractvalue {{.*}}[1] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
 // CHECK-DAG:   [[V:%.+]] = llvm.extractvalue {{.*}}[1] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
-// CHECK:       [[INDEX:%.+]] = llvm.call @find_index_i64(%arg0, [[G]], [[V]], [[LEN]]) : (i64, !llvm.ptr<i32>, !llvm.ptr<i32>, i32) -> i32
-// CHECK:       [[UNR_CONV:%.+]] = builtin.unrealized_conversion_cast [[INDEX]] : i32 to i64
-// CHECK:       llvm.return [[UNR_CONV]] : i64
+// CHECK:       [[INDEX:%.+]] = llvm.call @find_index_i64(%arg0, [[G]], [[V]], [[LEN]]) : (i64, !llvm.ptr<i32>, !llvm.ptr<i32>, i32) -> i64
+// CHECK:       llvm.return [[INDEX]] : i64
 }
 
 // -----
 
@@ -55,10 +53,10 @@ func private @test_category_mapper_string_to_int64(%arg0: memref<2x2x!krnl.strin
   %c-1_i64 = arith.constant -1 : i64
   %c3_i32 = arith.constant 3 : i32
   %0 = memref.alloc() {alignment = 16 : i64} : memref<2x2xi64>
-  %1 = "krnl.global"() {name = "G0", shape = [3], value = dense<[1, 0, -3]> : vector<3xi32>} : () -> memref<3xi32>
-  %2 = "krnl.global"() {name = "V1", shape = [3], value = dense<[1, 2, 0]> : vector<3xi32>} : () -> memref<3xi32>
-  %3 = "krnl.global"() {name = "cats_int64s2", shape = [3], value = dense<[1, 2, 3]> : tensor<3xi64>} : () -> memref<3xi64>
-  %4 = "krnl.global"() {name = "cats_strings3", shape = [3], value = dense<["cat", "dog", "cow"]> : tensor<3x!krnl.string>} : () -> memref<3x!krnl.string>
+  %1 = "krnl.global"() {name = "G", shape = [3], value = dense<[1, 0, -3]> : tensor<3xi32>} : () -> memref<3xi32>
+  %2 = "krnl.global"() {name = "V", shape = [3], value = dense<[1, 2, 0]> : tensor<3xi32>} : () -> memref<3xi32>
+  %3 = "krnl.global"() {name = "cats_int64s", shape = [3], value = dense<[1, 2, 3]> : tensor<3xi64>} : () -> memref<3xi64>
+  %4 = "krnl.global"() {name = "cats_strings", shape = [3], value = dense<["cat", "dog", "cow"]> : tensor<3x!krnl.string>} : () -> memref<3x!krnl.string>
   %5:2 = krnl.define_loops 2
   krnl.iterate(%5#0, %5#1) with (%5#0 -> %arg1 = 0 to 2, %5#1 -> %arg2 = 0 to 2) {
     %6:2 = krnl.get_induction_var_value(%5#0, %5#1) : (!krnl.loop, !krnl.loop) -> (index, index)
@@ -79,11 +77,29 @@ func private @test_category_mapper_string_to_int64(%arg0: memref<2x2x!krnl.strin
 
   // CHECK-DAG: llvm.func @strncmp(!llvm.ptr<i8>, !llvm.ptr<i8>, i64) -> i32
   // CHECK-DAG: llvm.func @strlen(!llvm.ptr<i8>) -> i64
-  // CHECK-DAG: llvm.func @find_index_str(!llvm.ptr<i8>, !llvm.ptr<i32>, !llvm.ptr<i32>, i32) -> i32
-  // CHECK-DAG: llvm.mlir.global internal constant {{.*}}(dense<["cat", "dog", "cow"]> : tensor<3x!krnl.string>) {alignment = 16 : i64} : !llvm.array<3 x struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>>
-  // CHECK-DAG: llvm.mlir.global internal constant {{.*}}(dense<[1, 2, 3]> : tensor<3xi64>) {alignment = 16 : i64} : !llvm.array<3 x i64>
-  // CHECK-DAG: llvm.mlir.global internal constant @V{{.*}}(dense<[1, 2, 0]> : vector<3xi32>) {alignment = 16 : i64} : !llvm.array<3 x i32>
-  // CHECK-DAG: llvm.mlir.global internal constant @G{{.*}}(dense<[1, 0, -3]> : vector<3xi32>) {alignment = 16 : i64} : !llvm.array<3 x i32>
+  // CHECK-DAG: llvm.func @find_index_str(!llvm.ptr<i8>, !llvm.ptr<i32>, !llvm.ptr<i32>, i32) -> i64
+  // CHECK-DAG: llvm.mlir.global internal constant @cat("cat")
+  // CHECK-DAG: llvm.mlir.global internal constant @dog("dog")
+  // CHECK-DAG: llvm.mlir.global internal constant @cow("cow")
+  // CHECK:     llvm.mlir.global internal constant @cats_strings{{.*}}() {alignment = 16 : i64} : !llvm.array<3 x ptr<i8>> {
+  // CHECK:       [[ARRAY:%.+]] = llvm.mlir.undef : !llvm.array<3 x ptr<i8>>
+  // CHECK:       [[CAT_ADDR:%.+]] = llvm.mlir.addressof @cat : !llvm.ptr<array<3 x i8>>
+  // CHECK:       [[ZERO:%.+]] = llvm.mlir.constant(0 : index) : i64
+  // CHECK:       [[CAT_GEP:%.+]] = llvm.getelementptr [[CAT_ADDR]]{{.*}}[[ZERO]], [[ZERO]]{{.*}} : (!llvm.ptr<array<3 x i8>>, i64, i64) -> !llvm.ptr<i8>
+  // CHECK:       [[CAT_INS_VAL:%.+]] = llvm.insertvalue [[CAT_GEP]], [[ARRAY]][0 : index] : !llvm.array<3 x ptr<i8>>
+  // CHECK:       [[DOG_ADDR:%.+]] = llvm.mlir.addressof @dog : !llvm.ptr<array<3 x i8>>
+  // CHECK:       [[ZERO:%.+]] = llvm.mlir.constant(0 : index) : i64
+  // CHECK:       [[DOG_GEP:%.+]] = llvm.getelementptr [[DOG_ADDR]]{{.*}}[[ZERO]], [[ZERO]]{{.*}} : (!llvm.ptr<array<3 x i8>>, i64, i64) -> !llvm.ptr<i8>
+  // CHECK:       [[DOG_INS_VAL:%.+]] = llvm.insertvalue [[DOG_GEP]], [[CAT_INS_VAL]][1 : index] : !llvm.array<3 x ptr<i8>>
+  // CHECK:       [[COW_ADDR:%.+]] = llvm.mlir.addressof @cow : !llvm.ptr<array<3 x i8>>
+  // CHECK:       [[ZERO:%.+]] = llvm.mlir.constant(0 : index) : i64
+  // CHECK:       [[COW_GEP:%.+]] = llvm.getelementptr [[COW_ADDR]]{{.*}}[[ZERO]], [[ZERO]]{{.*}} : (!llvm.ptr<array<3 x i8>>, i64, i64) -> !llvm.ptr<i8>
+  // CHECK:       [[COW_INS_VAL:%.+]] = llvm.insertvalue [[COW_GEP]], [[DOG_INS_VAL]][2 : index] : !llvm.array<3 x ptr<i8>>
+  // CHECK:       llvm.return [[COW_INS_VAL]] : !llvm.array<3 x ptr<i8>>
+  // CHECK:     }
+  // CHECK-DAG: llvm.mlir.global internal constant @cats_int64s{{.*}}(dense<[1, 2, 3]> : tensor<3xi64>) {alignment = 16 : i64} : !llvm.array<3 x i64>
+  // CHECK-DAG: llvm.mlir.global internal constant @V{{.*}}(dense<[1, 2, 0]> : tensor<3xi32>) {alignment = 16 : i64} : !llvm.array<3 x i32>
+  // CHECK-DAG: llvm.mlir.global internal constant @G{{.*}}(dense<[1, 0, -3]> : tensor<3xi32>) {alignment = 16 : i64} : !llvm.array<3 x i32>
 
   // CHECK-LABEL: @test_category_mapper_string_to_int64(%arg0: !llvm.ptr<struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>>, %arg1: !llvm.ptr<struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>>, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64) -> !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
@@ -94,7 +110,7 @@ func private @test_category_mapper_string_to_int64(%arg0: memref<2x2x!krnl.strin
   // CHECK-DAG: [[STR:%.+]] = llvm.extractvalue {{.*}}[1] : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
   // CHECK-DAG: [[G:%.+]] = llvm.extractvalue {{.*}}[1] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
   // CHECK-DAG: [[V:%.+]] = llvm.extractvalue {{.*}}[1] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
-  // CHECK:     [[INDEX:%.+]] = llvm.call @find_index_str([[STR]], [[G]], [[V]], [[LEN]]) : (!llvm.ptr<i8>, !llvm.ptr<i32>, !llvm.ptr<i32>, i32) -> i32
+  // CHECK:     [[INDEX:%.+]] = llvm.call @find_index_str([[STR]], [[G]], [[V]], [[LEN]]) : (!llvm.ptr<i8>, !llvm.ptr<i32>, !llvm.ptr<i32>, i32) -> i64
 
   /// Determine whether the index is valid:
   // CHECK:     [[STR1:%.+]] = llvm.extractvalue {{.*}}[1] : !llvm.struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>
@@ -123,11 +139,11 @@ func private @test_category_mapper_string_to_int64(%arg0: memref<2x2x!krnl.strin
 
 func private @test_category_mapper_int64_to_string(%arg0: memref<2x2xi64>) -> memref<2x2x!krnl.string> {
   %c3_i32 = arith.constant 3 : i32
   %0 = memref.alloc() {alignment = 16 : i64} : memref<2x2x!krnl.string>
-  %1 = "krnl.global"() {name = "G4", shape = [3], value = dense<[-1, 1, 0]> : vector<3xi32>} : () -> memref<3xi32>
-  %2 = "krnl.global"() {name = "V5", shape = [3], value = dense<[2, 1, 0]> : vector<3xi32>} : () -> memref<3xi32>
-  %3 = "krnl.global"() {name = "cats_int64s6", shape = [3], value = dense<[1, 2, 3]> : tensor<3xi64>} : () -> memref<3xi64>
-  %4 = "krnl.global"() {name = "cats_strings7", shape = [3], value = dense<["cat", "dog", "cow"]> : tensor<3x!krnl.string>} : () -> memref<3x!krnl.string>
-  %5 = "krnl.global"() {name = "default_string8", shape = [], value = dense<"none"> : tensor<!krnl.string>} : () -> memref<!krnl.string>
+  %1 = "krnl.global"() {name = "G", shape = [3], value = dense<[-1, 1, 0]> : tensor<3xi32>} : () -> memref<3xi32>
+  %2 = "krnl.global"() {name = "V", shape = [3], value = dense<[2, 1, 0]> : tensor<3xi32>} : () -> memref<3xi32>
+  %3 = "krnl.global"() {name = "cats_int64s", shape = [3], value = dense<[1, 2, 3]> : tensor<3xi64>} : () -> memref<3xi64>
+  %4 = "krnl.global"() {name = "cats_strings", shape = [3], value = dense<["cat", "dog", "cow"]> : tensor<3x!krnl.string>} : () -> memref<3x!krnl.string>
+  %5 = "krnl.global"() {name = "default_string", shape = [], value = dense<"none"> : tensor<!krnl.string>} : () -> memref<!krnl.string>
   %6:2 = krnl.define_loops 2
   krnl.iterate(%6#0, %6#1) with (%6#0 -> %arg1 = 0 to 2, %6#1 -> %arg2 = 0 to 2) {
     %7:2 = krnl.get_induction_var_value(%6#0, %6#1) : (!krnl.loop, !krnl.loop) -> (index, index)
@@ -145,12 +161,30 @@ func private @test_category_mapper_int64_to_string(%arg0: memref<2x2xi64>) -> me
   }
   return %0 : memref<2x2x!krnl.string>
 
-  // CHECK-DAG: llvm.func @find_index_i64(i64, !llvm.ptr<i32>, !llvm.ptr<i32>, i32) -> i32
-  // CHECK-DAG: llvm.mlir.global internal constant @default_string{{.*}}(dense<"none"> : tensor<!krnl.string>) {alignment = 16 : i64} : !llvm.array<1 x struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>>
-  // CHECK-DAG: llvm.mlir.global internal constant @cats_strings{{.*}}(dense<["cat", "dog", "cow"]> : tensor<3x!krnl.string>) {alignment = 16 : i64} : !llvm.array<3 x struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>>
+  // CHECK-DAG: llvm.func @find_index_i64(i64, !llvm.ptr<i32>, !llvm.ptr<i32>, i32) -> i64
+  // CHECK-DAG: llvm.mlir.global internal constant @none("none")
+  // CHECK-DAG: llvm.mlir.global internal constant @cat("cat")
+  // CHECK-DAG: llvm.mlir.global internal constant @dog("dog")
+  // CHECK-DAG: llvm.mlir.global internal constant @cow("cow")
+  // CHECK:     llvm.mlir.global internal constant @cats_strings{{.*}}() {alignment = 16 : i64} : !llvm.array<3 x ptr<i8>> {
+  // CHECK:       [[ARRAY:%.+]] = llvm.mlir.undef : !llvm.array<3 x ptr<i8>>
+  // CHECK:       [[CAT_ADDR:%.+]] = llvm.mlir.addressof @cat : !llvm.ptr<array<3 x i8>>
+  // CHECK:       [[ZERO:%.+]] = llvm.mlir.constant(0 : index) : i64
+  // CHECK:       [[CAT_GEP:%.+]] = llvm.getelementptr [[CAT_ADDR]]{{.*}}[[ZERO]], [[ZERO]]{{.*}} : (!llvm.ptr<array<3 x i8>>, i64, i64) -> !llvm.ptr<i8>
+  // CHECK:       [[CAT_INS_VAL:%.+]] = llvm.insertvalue [[CAT_GEP]], [[ARRAY]][0 : index] : !llvm.array<3 x ptr<i8>>
+  // CHECK:       [[DOG_ADDR:%.+]] = llvm.mlir.addressof @dog : !llvm.ptr<array<3 x i8>>
+  // CHECK:       [[ZERO:%.+]] = llvm.mlir.constant(0 : index) : i64
+  // CHECK:       [[DOG_GEP:%.+]] = llvm.getelementptr [[DOG_ADDR]]{{.*}}[[ZERO]], [[ZERO]]{{.*}} : (!llvm.ptr<array<3 x i8>>, i64, i64) -> !llvm.ptr<i8>
+  // CHECK:       [[DOG_INS_VAL:%.+]] = llvm.insertvalue [[DOG_GEP]], [[CAT_INS_VAL]][1 : index] : !llvm.array<3 x ptr<i8>>
+  // CHECK:       [[COW_ADDR:%.+]] = llvm.mlir.addressof @cow : !llvm.ptr<array<3 x i8>>
+  // CHECK:       [[ZERO:%.+]] = llvm.mlir.constant(0 : index) : i64
+  // CHECK:       [[COW_GEP:%.+]] = llvm.getelementptr [[COW_ADDR]]{{.*}}[[ZERO]], [[ZERO]]{{.*}} : (!llvm.ptr<array<3 x i8>>, i64, i64) -> !llvm.ptr<i8>
+  // CHECK:       [[COW_INS_VAL:%.+]] = llvm.insertvalue [[COW_GEP]], [[DOG_INS_VAL]][2 : index] : !llvm.array<3 x ptr<i8>>
+  // CHECK:       llvm.return [[COW_INS_VAL]] : !llvm.array<3 x ptr<i8>>
+  // CHECK:     }
   // CHECK-DAG: llvm.mlir.global internal constant @cats_int64s{{.*}}(dense<[1, 2, 3]> : tensor<3xi64>) {alignment = 16 : i64} : !llvm.array<3 x i64>
-  // CHECK-DAG: llvm.mlir.global internal constant @V{{.*}}(dense<[2, 1, 0]> : vector<3xi32>) {alignment = 16 : i64} : !llvm.array<3 x i32>
-  // CHECK-DAG: llvm.mlir.global internal constant @G{{.*}}(dense<[-1, 1, 0]> : vector<3xi32>) {alignment = 16 : i64} : !llvm.array<3 x i32>
+  // CHECK-DAG: llvm.mlir.global internal constant @V{{.*}}(dense<[2, 1, 0]> : tensor<3xi32>) {alignment = 16 : i64} : !llvm.array<3 x i32>
+  // CHECK-DAG: llvm.mlir.global internal constant @G{{.*}}(dense<[-1, 1, 0]> : tensor<3xi32>) {alignment = 16 : i64} : !llvm.array<3 x i32>
 
   // CHECK-LABEL: @test_category_mapper_int64_to_string(%arg0: !llvm.ptr<i64>, %arg1: !llvm.ptr<i64>, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64) -> !llvm.struct<(ptr<struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>>, ptr<struct<(ptr<i8>, ptr<i8>, i64, array<1 x i64>, array<1 x i64>)>>, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK-DAG: [[LEN:%.+]] = llvm.mlir.constant(3 : i32) : i32
@@ -165,12 +199,11 @@ func private @test_category_mapper_int64_to_string(%arg0: memref<2x2xi64>) -> me
   // CHECK-DAG: [[INPUT:%.+]] = llvm.load {{.*}} : !llvm.ptr<i64>
   // CHECK-DAG: [[G:%.+]] = llvm.extractvalue {{.*}}[1] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
   // CHECK-DAG: [[V:%.+]] = llvm.extractvalue {{.*}}[1] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<1 x i64>, array<1 x i64>)>
-  // CHECK:     [[INDEX:%.+]] = llvm.call @find_index_i64([[INPUT]], [[G]], [[V]], [[LEN]]) : (i64, !llvm.ptr<i32>, !llvm.ptr<i32>, i32) -> i32
-  // CHECK-NEXT: [[UNR_CONV:%.+]] = builtin.unrealized_conversion_cast [[INDEX]] : i32 to i64
+  // CHECK:     [[INDEX:%.+]] = llvm.call @find_index_i64([[INPUT]], [[G]], [[V]], [[LEN]]) : (i64, !llvm.ptr<i32>, !llvm.ptr<i32>, i32) -> i64
 
   /// Determine whether the index is valid:
   // CHECK:     [[EV1:%.+]] = llvm.extractvalue {{.*}}[1] : !llvm.struct<(ptr<ptr<i8>>, ptr<ptr<i8>>, i64, array<1 x i64>, array<1 x i64>)>
-  // CHECK-DAG: [[GEP1:%.+]] = llvm.getelementptr [[EV1]]{{.*}}[[UNR_CONV]]{{.*}} : (!llvm.ptr<ptr<i8>>, i64) -> !llvm.ptr<ptr<i8>>
+  // CHECK-DAG: [[GEP1:%.+]] = llvm.getelementptr [[EV1]]{{.*}}[[INDEX]]{{.*}} : (!llvm.ptr<ptr<i8>>, i64) -> !llvm.ptr<ptr<i8>>
  // CHECK-DAG: [[INDEX1:%.+]] = llvm.load [[GEP1]] : !llvm.ptr<ptr<i8>>
 
   /// Store the index if valid, otherwise store the default value:
diff --git a/test/mlir/onnx/onnx_lowering_category_mapper.mlir b/test/mlir/onnx/onnx_lowering_category_mapper.mlir
index ab8ac3f195a9..c42b8c837b7f 100644
--- a/test/mlir/onnx/onnx_lowering_category_mapper.mlir
+++ b/test/mlir/onnx/onnx_lowering_category_mapper.mlir
@@ -10,8 +10,8 @@ func private @test_category_mapper_string_to_int64(%arg0 : tensor<2x2x!onnx.Stri
 // CHECK-LABEL: test_category_mapper_string_to_int64
 // CHECK-DAG: [[LEN:%.+]] = arith.constant 3 : i32
 // CHECK-DAG: [[ALLOCA:%.+]] = memref.alloc() {alignment = 16 : i64} : memref<2x2xi64>
-// CHECK-DAG: [[G:%.+]] = "krnl.global"() {name = {{.*}}, shape = [3], value = dense<[1, 0, -3]> : vector<3xi32>} : () -> memref<3xi32>
-// CHECK-DAG: [[V:%.+]] = "krnl.global"() {name = {{.*}}, shape = [3], value = dense<[1, 2, 0]> : vector<3xi32>} : () -> memref<3xi32>
+// CHECK-DAG: [[G:%.+]] = "krnl.global"() {name = {{.*}}, shape = [3], value = dense<[1, 0, -3]> : tensor<3xi32>} : () -> memref<3xi32>
+// CHECK-DAG: [[V:%.+]] = "krnl.global"() {name = {{.*}}, shape = [3], value = dense<[1, 2, 0]> : tensor<3xi32>} : () -> memref<3xi32>
 // CHECK-DAG: [[CAT_INT64s:%.+]] = "krnl.global"() {name = {{.*}}, shape = [3], value = dense<[1, 2, 3]> : tensor<3xi64>} : () -> memref<3xi64>
 // CHECK-DAG: [[CAT_STRINGS:%.+]] = "krnl.global"() {name = {{.*}}, shape = [3], value = dense<["cat", "dog", "cow"]> : tensor<3x!krnl.string>} : () -> memref<3x!krnl.string>
 // CHECK-DAG: [[DEFAULT_INT64:%.+]] = arith.constant -1 : i64
@@ -44,8 +44,8 @@ func private @test_category_mapper_int64_to_string(%arg0 : tensor<2x2xi64>) -> t
 // CHECK-LABEL: test_category_mapper_int64_to_string
 // CHECK-DAG: [[LEN:%.+]] = arith.constant 3 : i32
 // CHECK-DAG: [[ALLOCA:%.+]] = memref.alloc() {alignment = 16 : i64} : memref<2x2x!krnl.string>
-// CHECK-DAG: [[G:%.+]] = "krnl.global"() {name = {{.*}}, shape = [3], value = dense<[-1, 1, 0]> : vector<3xi32>} : () -> memref<3xi32>
-// CHECK-DAG: [[V:%.+]] = "krnl.global"() {name = {{.*}}, shape = [3], value = dense<[2, 1, 0]> : vector<3xi32>} : () -> memref<3xi32>
+// CHECK-DAG: [[G:%.+]] = "krnl.global"() {name = {{.*}}, shape = [3], value = dense<[-1, 1, 0]> : tensor<3xi32>} : () -> memref<3xi32>
+// CHECK-DAG: [[V:%.+]] = "krnl.global"() {name = {{.*}}, shape = [3], value = dense<[2, 1, 0]> : tensor<3xi32>} : () -> memref<3xi32>
 // CHECK-DAG: [[CAT_INT64s:%.+]] = "krnl.global"() {name = {{.*}}, shape = [3], value = dense<[1, 2, 3]> : tensor<3xi64>} : () -> memref<3xi64>
 // CHECK-DAG: [[CAT_STRINGS:%.+]] = "krnl.global"() {name = {{.*}}, shape = [3], value = dense<["cat", "dog", "cow"]> : tensor<3x!krnl.string>} : () -> memref<3x!krnl.string>
 // CHECK-DAG: [[DEFAULT_STRING:%.+]] = "krnl.global"() {name = {{.*}}, shape = [], value = dense<"none"> : tensor<!krnl.string>} : () -> memref<!krnl.string>
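Note: to make the CHECK sequences above easier to follow, here is a hedged C++ sketch of the per-element logic the string-to-int64 lowering emits. The find_index_str declaration matches src/Runtime/OMIndexLookup.inc from this patch; mapOneElement and its parameter layout are illustrative only, standing in for the generated LLVM IR that the tests verify:

    #include <cstdint>
    #include <cstring>

    // Declared in src/Runtime/OMIndexLookup.inc (changed earlier in this patch).
    extern "C" uint64_t find_index_str(
        const char *str, int32_t G[], int32_t V[], int32_t dictSize);

    // Sketch of the code generated per input element: look up a candidate
    // slot with the perfect hash, validate the candidate by comparing strings
    // (the generated IR calls strlen/strncmp, per the CHECK lines above),
    // then select either the mapped category or the default value.
    int64_t mapOneElement(const char *input, int32_t G[], int32_t V[],
        int32_t len, const char *const *catsStrings, const int64_t *catsInt64s,
        int64_t defaultInt64) {
      uint64_t idx = find_index_str(input, G, V, len);
      const char *candidate = catsStrings[idx];
      // The perfect hash only guarantees a correct slot for dictionary keys,
      // so an explicit comparison is needed to detect unknown inputs.
      bool valid = std::strncmp(candidate, input, std::strlen(input)) == 0;
      return valid ? catsInt64s[idx] : defaultInt64;
    }

The int64-to-string direction is symmetric: find_index_i64 produces the slot, the candidate integer from cats_int64s is compared against the input, and the result is either the matching entry of cats_strings or the "none" default global.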