From b1af3cccc93a3f8887b8de1931495df6c00d7826 Mon Sep 17 00:00:00 2001 From: zjgarvey Date: Mon, 23 Dec 2024 17:46:37 -0600 Subject: [PATCH 1/3] [VM] Add support for UI64 to F32 casts Signed-off-by: zjgarvey --- .../Dialect/VM/Conversion/ArithToVM/Patterns.cpp | 4 +++- .../VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp | 1 + .../src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp | 10 ++++++++++ .../src/iree/compiler/Dialect/VM/IR/VMOpcodesF32.td | 1 + compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td | 8 ++++++++ runtime/src/iree/vm/bytecode/dispatch.c | 5 +++++ .../src/iree/vm/bytecode/utils/generated/op_table.h | 2 +- runtime/src/iree/vm/bytecode/verifier.c | 4 ++++ runtime/src/iree/vm/ops.h | 3 +++ 9 files changed, 36 insertions(+), 2 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/Patterns.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/Patterns.cpp index 6d902e549611..27197d605956 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/Patterns.cpp @@ -598,7 +598,9 @@ struct UIToFPOpConversion : public OpConversionPattern { } if (srcType.isUnsignedInteger(64) || srcType.isSignlessInteger(64)) { if (dstType.isF32()) { - return rewriter.notifyMatchFailure(srcOp, "unsupported type"); + rewriter.replaceOpWithNewOp(srcOp, resultType, + input); + return success(); } rewriter.replaceOpWithNewOp(srcOp, resultType, diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp index 845c3e3ace4b..abb6886a590e 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp @@ -4522,6 +4522,7 @@ void populateVMToEmitCPatterns(ConversionTarget &conversionTarget, ADD_GENERIC_PATTERN(IREE::VM::CastF32UI64Op, "vm_cast_f32ui64"); ADD_GENERIC_PATTERN(IREE::VM::CastSI32F32Op, "vm_cast_si32f32"); ADD_GENERIC_PATTERN(IREE::VM::CastSI64F32Op, "vm_cast_si64f32"); + ADD_GENERIC_PATTERN(IREE::VM::CastUI64F32Op, "vm_cast_ui64f32"); ADD_GENERIC_PATTERN(IREE::VM::CastUI32F32Op, "vm_cast_ui32f32"); ADD_GENERIC_PATTERN(IREE::VM::CeilF32Op, "vm_ceil_f32"); ADD_GENERIC_PATTERN(IREE::VM::CmpEQF32OOp, "vm_cmp_eq_f32o"); diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp index f227292e3e4a..19f004aba6b0 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp @@ -1693,6 +1693,16 @@ OpFoldResult CastSI64F32Op::fold(FoldAdaptor operands) { }); } +OpFoldResult CastUI64F32Op::fold(FoldAdaptor operands) { + return constFoldCastOp( + Float32Type::get(getContext()), operands.getOperand(), + [&](const APInt &a) { + APFloat b = APFloat(0.0f); + b.convertFromAPInt(a, /*IsSigned=*/false, APFloat::rmNearestTiesToAway); + return b; + }); +} + OpFoldResult CastUI32F32Op::fold(FoldAdaptor operands) { return constFoldCastOp( Float32Type::get(getContext()), operands.getOperand(), diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpcodesF32.td b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpcodesF32.td index 1ce37ae032f0..e79ca33ab616 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpcodesF32.td +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpcodesF32.td @@ -46,6 +46,7 @@ def VM_OPC_MaxF32 : VM_OPC<0x38, "MaxF32">; def VM_OPC_CastSI32F32 : VM_OPC<0x14, "CastSI32F32">; def VM_OPC_CastSI64F32 : VM_OPC<0x3C, "CastSI64F32">; +def VM_OPC_CastUI64F32 : VM_OPC<0x3D, "CastUI64F32">; def VM_OPC_CastUI32F32 : VM_OPC<0x15, "CastUI32F32">; def VM_OPC_CastF32SI32 : VM_OPC<0x16, "CastF32SI32">; def VM_OPC_CastF32SI64 : VM_OPC<0x3A, "CastF32SI64">; diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td index f7e59449211b..feb088e35d02 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td @@ -3174,6 +3174,14 @@ def VM_CastSI64F32Op : let hasFolder = 1; } +def VM_CastUI64F32Op : + VM_ConversionOp { + let summary = [{cast from an unsigned integer to a float-point value}]; + let hasFolder = 1; +} + + def VM_CastUI64F64Op : VM_ConversionOp { diff --git a/runtime/src/iree/vm/bytecode/dispatch.c b/runtime/src/iree/vm/bytecode/dispatch.c index ba48f3228477..f7f4d3d4bcf7 100644 --- a/runtime/src/iree/vm/bytecode/dispatch.c +++ b/runtime/src/iree/vm/bytecode/dispatch.c @@ -2051,6 +2051,11 @@ static iree_status_t iree_vm_bytecode_dispatch( float* result = VM_DecResultRegF32("result"); *result = vm_cast_si64f32(operand); }); + DISPATCH_OP(EXT_F32, CastUI64F32, { + int64_t operand = (int64_t)VM_DecOperandRegI64("operand"); + float* result = VM_DecResultRegF32("result"); + *result = vm_cast_ui64f32(operand); + }); DISPATCH_OP(EXT_F32, CastUI32F32, { int32_t operand = (int32_t)VM_DecOperandRegI32("operand"); float* result = VM_DecResultRegF32("result"); diff --git a/runtime/src/iree/vm/bytecode/utils/generated/op_table.h b/runtime/src/iree/vm/bytecode/utils/generated/op_table.h index 2c760731a023..e9823f88caf3 100644 --- a/runtime/src/iree/vm/bytecode/utils/generated/op_table.h +++ b/runtime/src/iree/vm/bytecode/utils/generated/op_table.h @@ -585,7 +585,7 @@ typedef enum { IREE_VM_OP_EXT_F32_CastF32SI64 = 0x3A, IREE_VM_OP_EXT_F32_CastF32UI64 = 0x3B, IREE_VM_OP_EXT_F32_CastSI64F32 = 0x3C, - IREE_VM_OP_EXT_F32_RSV_0x3D, + IREE_VM_OP_EXT_F32_CastUI64F32 = 0x3D, IREE_VM_OP_EXT_F32_RSV_0x3E, IREE_VM_OP_EXT_F32_RSV_0x3F, IREE_VM_OP_EXT_F32_RSV_0x40, diff --git a/runtime/src/iree/vm/bytecode/verifier.c b/runtime/src/iree/vm/bytecode/verifier.c index 5c726db74e8f..c0ef2c34fb02 100644 --- a/runtime/src/iree/vm/bytecode/verifier.c +++ b/runtime/src/iree/vm/bytecode/verifier.c @@ -1827,6 +1827,10 @@ static iree_status_t iree_vm_bytecode_function_verify_bytecode_op( VM_VerifyOperandRegI64(operand); VM_VerifyResultRegF32(result); }); + VERIFY_OP(EXT_F32, CastUI64F32, { + VM_VerifyOperandRegI64(operand); + VM_VerifyResultRegF32(result); + }); VERIFY_OP(EXT_F32, CastUI32F32, { VM_VerifyOperandRegI32(operand); VM_VerifyResultRegF32(result); diff --git a/runtime/src/iree/vm/ops.h b/runtime/src/iree/vm/ops.h index 68c939e62350..ae280629e11f 100644 --- a/runtime/src/iree/vm/ops.h +++ b/runtime/src/iree/vm/ops.h @@ -600,6 +600,9 @@ static inline float vm_erf_f32(float operand) { return erff(operand); } static inline float vm_cast_si32f32(int32_t operand) { return (float)operand; } static inline float vm_cast_si64f32(int64_t operand) { return (float)operand; } +static inline float vm_cast_ui64f32(int64_t operand) { + return (float)(uint64_t)operand; +} static inline float vm_cast_ui32f32(int32_t operand) { return (float)(uint32_t)operand; } From 646dd98965b94066df271f1c3efdb976577ed125 Mon Sep 17 00:00:00 2001 From: zjgarvey Date: Mon, 23 Dec 2024 17:52:42 -0600 Subject: [PATCH 2/3] Add a test for arith to vm conversion Signed-off-by: zjgarvey --- .../VM/Conversion/ArithToVM/test/conversion_ops.mlir | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/test/conversion_ops.mlir b/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/test/conversion_ops.mlir index 5b8da0ba9f03..019b1b623bcd 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/test/conversion_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/test/conversion_ops.mlir @@ -312,6 +312,18 @@ module @uitofp_i32_f32 { // ----- +// CHECK-LABEL: @uitofp_i64_f32 +module @uitofp_i64_f32 { + // CHECK: vm.func private @fn(%[[ARG0:.+]]: i64) + func.func @fn(%arg0: i64) -> f32 { + // CHECK: vm.cast.ui64.f32 %[[ARG0]] : i64 -> f32 + %0 = arith.uitofp %arg0 : i64 to f32 + return %0 : f32 + } +} + +// ----- + // CHECK-LABEL: @fptosi_fp32_i8 module @fptosi_fp32_i8 { // CHECK: vm.func private @fn(%[[ARG0:.+]]: f32) From 6e2d2e85c0ed45c853075b4ccf37d7dc60a78d6c Mon Sep 17 00:00:00 2001 From: zjgarvey Date: Mon, 6 Jan 2025 12:51:34 -0600 Subject: [PATCH 3/3] change input signatures for uitofp casts to uint*_t Signed-off-by: zjgarvey --- runtime/src/iree/vm/ops.h | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/runtime/src/iree/vm/ops.h b/runtime/src/iree/vm/ops.h index ae280629e11f..3dc5ee0e7851 100644 --- a/runtime/src/iree/vm/ops.h +++ b/runtime/src/iree/vm/ops.h @@ -600,12 +600,8 @@ static inline float vm_erf_f32(float operand) { return erff(operand); } static inline float vm_cast_si32f32(int32_t operand) { return (float)operand; } static inline float vm_cast_si64f32(int64_t operand) { return (float)operand; } -static inline float vm_cast_ui64f32(int64_t operand) { - return (float)(uint64_t)operand; -} -static inline float vm_cast_ui32f32(int32_t operand) { - return (float)(uint32_t)operand; -} +static inline float vm_cast_ui64f32(uint64_t operand) { return (float)operand; } +static inline float vm_cast_ui32f32(uint32_t operand) { return (float)operand; } static inline int32_t vm_cast_f32si32(float operand) { return (int32_t)lroundf(operand); }