diff --git a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp index d81b1f54c6..aa5fab3cb1 100644 --- a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp @@ -256,12 +256,6 @@ double analyzeSimdFor(Type t, int64_t &von, int64_t &son) { return simdAnalysis({GenericOps::PowGop}, {1}, t, von, son); } -template <> -struct ScalarOp { - using FOp = KrnlErfOp; - using IOp = NotSuportedScalarOp; -}; - template <> struct ScalarOp { using FOp = KrnlIsInfOp; @@ -771,6 +765,79 @@ Value emitScalarOpFor(ConversionPatternRewriter &rewriter, return create.math.select(zeroPredicate, zero, plusSelect); } +//===----------------------------------------------------------------------===// +// Scalar unary ops for lowering ONNXErfOp +//===----------------------------------------------------------------------===// +template <> +struct ScalarOp { + using FOp = CustomScalarOp; + using IOp = NotSuportedScalarOp; +}; + +template <> +double analyzeSimdFor(Type t, int64_t &von, int64_t &son) { + return simdAnalysis( + {GenericOps::ArithmeticGop, GenericOps::CompareGop, GenericOps::DivGop, + GenericOps::FmaGop, GenericOps::MulGop, GenericOps::SelectGop, + GenericOps::AbsGop, GenericOps::ExpGop}, + {2, 1, 1, 6, 2, 1, 1, 1}, t, von, son); +} + +template <> +Value emitScalarOpFor(ConversionPatternRewriter &rewriter, + Location loc, Operation *op, Type elementType, + ArrayRef scalarOperands) { + // Approximate erf(x) with the polynomial formula below, as described at + // https://www.johndcook.com/blog/2009/01/19/stand-alone-error-function-erf/. 
+ // ``` + // def erf(x): + // a1 = 0.254829592 + // a2 = -0.284496736 + // a3 = 1.421413741 + // a4 = -1.453152027 + // a5 = 1.061405429 + // p = 0.3275911 + // + // t = 1.0 / (1.0 + p * abs(x)) + // y = 1.0 - + // (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * + // exp(-abs(x)*abs(x)) + // return -y if x < 0.0 else y + // + // ``` + CheckIfCustomScalarOpIsSupported(elementType); + Value operand = scalarOperands[0]; + MultiDialectBuilder create(rewriter, loc); + Value zero = create.math.constant(elementType, 0); + Value one = create.math.constant(elementType, 1); + Value minusone = create.math.constant(elementType, -1); + Value a1 = create.math.constant(elementType, 0.254829592); + Value a2 = create.math.constant(elementType, -0.284496736); + Value a3 = create.math.constant(elementType, 1.421413741); + Value a4 = create.math.constant(elementType, -1.453152027); + Value a5 = create.math.constant(elementType, 1.061405429); + Value p = create.math.constant(elementType, 0.3275911); + Value absx = create.math.abs(operand); + Value t = create.math.div(one, create.math.fma(p, absx, one)); + // y = 1.0 - + // (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * + // exp(-absx*absx) + // minusy = (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * + // exp(-absx*absx) + // + (-1.0), i.e. minusy = -y + Value minusy = create.math.fma( + create.math.mul( + create.math.fma( + create.math.fma( + create.math.fma(create.math.fma(a5, t, a4), t, a3), t, a2), + t, a1), + t), + create.math.exp(create.math.mul(create.math.sub(zero, absx), absx)), + minusone); + Value sign = create.math.gt(operand, zero); + return create.math.select(sign, create.math.sub(zero, minusy), minusy); +} + //===----------------------------------------------------------------------===// // Scalar unary ops for lowering ONNXMaxOp //===----------------------------------------------------------------------===// diff --git a/src/Dialect/Mlir/DialectBuilder.cpp b/src/Dialect/Mlir/DialectBuilder.cpp index 
613e2bec5c..9eee0ab9cb 100644 --- a/src/Dialect/Mlir/DialectBuilder.cpp +++ b/src/Dialect/Mlir/DialectBuilder.cpp @@ -205,6 +205,15 @@ Value MathBuilder::floorDiv(Value lhs, Value rhs) const { llvm_unreachable("expected int"); } +// Return (lhs * rhs) + acc: one fused op when the guard holds, else mul + add. +Value MathBuilder::fma(Value lhs, Value rhs, Value acc) const { + assert((lhs.getType() == rhs.getType()) && (rhs.getType() == acc.getType()) && + "expected same type"); + if (isFloatWithVector(lhs.getType()) && !isa(lhs.getType())) + return b().create(loc(), lhs, rhs, acc); + return add(mul(lhs, rhs), acc); +} + Value MathBuilder::exp(Value val) const { if (isFloatWithVector(val.getType())) return b().create(loc(), val); diff --git a/src/Dialect/Mlir/DialectBuilder.hpp b/src/Dialect/Mlir/DialectBuilder.hpp index 6d83a83109..66825aa2b1 100644 --- a/src/Dialect/Mlir/DialectBuilder.hpp +++ b/src/Dialect/Mlir/DialectBuilder.hpp @@ -114,8 +114,9 @@ struct MathBuilder final : DialectBuilder { mlir::Value exp2(mlir::Value val) const; // Float only. mlir::Value floor(mlir::Value val) const; // Float only. mlir::Value floorDiv(mlir::Value lhs, mlir::Value rhs) const; // Int only. - mlir::Value log(mlir::Value val) const; // Float only. - mlir::Value log2(mlir::Value val) const; // Float only. + mlir::Value fma(mlir::Value lhs, mlir::Value rhs, mlir::Value acc) const; + mlir::Value log(mlir::Value val) const; // Float only. + mlir::Value log2(mlir::Value val) const; // Float only. mlir::Value mul(mlir::Value lhs, mlir::Value rhs) const; mlir::Value neg(mlir::Value val) const; mlir::Value ori(mlir::Value lhs, mlir::Value rhs) const; // Int only. 
diff --git a/test/mlir/onnx/onnx_lowering_with_canonicalize.mlir b/test/mlir/onnx/onnx_lowering_with_canonicalize.mlir index 8e2de225d5..ffbfd8436f 100644 --- a/test/mlir/onnx/onnx_lowering_with_canonicalize.mlir +++ b/test/mlir/onnx/onnx_lowering_with_canonicalize.mlir @@ -4632,6 +4632,58 @@ func.func @test_matmulinteger_per_row_a(%arg0: tensor<16x32xui8>, %arg1: tensor< // ----- +func.func @test_erf(%arg0: tensor) -> (tensor<*xf32>) attributes {} { + %0 = "onnx.Erf"(%arg0): (tensor) -> (tensor<*xf32>) + return %0 : tensor<*xf32> +// mlir2FileCheck.py +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> +// CHECK-LABEL: func.func @test_erf +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> memref { +// CHECK-DAG: [[CST_0_dot_327591091_:%.+]] = arith.constant 0.327591091 : f32 +// CHECK-DAG: [[CST_1_dot_06140542_:%.+]] = arith.constant 1.06140542 : f32 +// CHECK-DAG: [[CST_minus_1_dot_45315206_:%.+]] = arith.constant -1.45315206 : f32 +// CHECK-DAG: [[CST_1_dot_42141378_:%.+]] = arith.constant 1.42141378 : f32 +// CHECK-DAG: [[CST_minus_0_dot_284496725_:%.+]] = arith.constant -0.284496725 : f32 +// CHECK-DAG: [[CST_0_dot_254829586_:%.+]] = arith.constant 0.254829586 : f32 +// CHECK-DAG: [[CST_minus_1_dot_000000_:%.+]] = arith.constant -1.000000e+00 : f32 +// CHECK-DAG: [[CST_1_dot_000000_:%.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[VAR_dim_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_]] : memref +// CHECK-DAG: [[RES_:%.+]] = memref.alloc([[VAR_dim_]]) {{.*}}: memref +// CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 +// CHECK-DAG: [[VAR_dim_8_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_]] : memref +// CHECK-DAG: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[MAP_0_]]([[VAR_dim_8_]]), [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 10){ +// CHECK-DAG: [[VAR_1_:%.+]]:2 = 
krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1] : memref +// CHECK-DAG: [[VAR_3_:%.+]] = math.absf [[LOAD_PARAM_0_MEM_]] : f32 +// CHECK-DAG: [[VAR_4_:%.+]] = arith.mulf [[VAR_3_]], [[CST_0_dot_327591091_]] : f32 +// CHECK-DAG: [[VAR_5_:%.+]] = arith.addf [[VAR_4_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_6_:%.+]] = arith.divf [[CST_1_dot_000000_]], [[VAR_5_]] : f32 +// CHECK-DAG: [[VAR_7_:%.+]] = arith.mulf [[VAR_6_]], [[CST_1_dot_06140542_]] : f32 +// CHECK-DAG: [[VAR_8_:%.+]] = arith.addf [[VAR_7_]], [[CST_minus_1_dot_45315206_]] : f32 +// CHECK-DAG: [[VAR_9_:%.+]] = arith.mulf [[VAR_8_]], [[VAR_6_]] : f32 +// CHECK-DAG: [[VAR_10_:%.+]] = arith.addf [[VAR_9_]], [[CST_1_dot_42141378_]] : f32 +// CHECK-DAG: [[VAR_11_:%.+]] = arith.mulf [[VAR_10_]], [[VAR_6_]] : f32 +// CHECK-DAG: [[VAR_12_:%.+]] = arith.addf [[VAR_11_]], [[CST_minus_0_dot_284496725_]] : f32 +// CHECK-DAG: [[VAR_13_:%.+]] = arith.mulf [[VAR_12_]], [[VAR_6_]] : f32 +// CHECK-DAG: [[VAR_14_:%.+]] = arith.addf [[VAR_13_]], [[CST_0_dot_254829586_]] : f32 +// CHECK-DAG: [[VAR_15_:%.+]] = arith.mulf [[VAR_14_]], [[VAR_6_]] : f32 +// CHECK-DAG: [[VAR_16_:%.+]] = arith.subf [[CST_0_dot_000000_]], [[VAR_3_]] : f32 +// CHECK-DAG: [[VAR_17_:%.+]] = arith.mulf [[VAR_16_]], [[VAR_3_]] : f32 +// CHECK-DAG: [[VAR_18_:%.+]] = math.exp [[VAR_17_]] : f32 +// CHECK-DAG: [[VAR_19_:%.+]] = arith.mulf [[VAR_15_]], [[VAR_18_]] : f32 +// CHECK-DAG: [[VAR_20_:%.+]] = arith.addf [[VAR_19_]], [[CST_minus_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_21_:%.+]] = arith.cmpf ogt, [[LOAD_PARAM_0_MEM_]], [[CST_0_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_22_:%.+]] = arith.subf [[CST_0_dot_000000_]], [[VAR_20_]] : f32 +// CHECK-DAG: [[VAR_23_:%.+]] = arith.select [[VAR_21_]], [[VAR_22_]], [[VAR_20_]] : f32 +// CHECK: krnl.store [[VAR_23_]], [[RES_]]{{.}}[[VAR_1_]]#0, 
[[VAR_1_]]#1] : memref +// CHECK: } +// CHECK: return [[RES_]] : memref +} + +// ----- + func.func @add_partial_splat(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<3x1x1xf32>) -> tensor<*xf32> attributes {input_names = ["x", "y"], output_names = ["sum"]} { %0 = "onnx.Add"(%arg0, %arg1) : (tensor<2x3x4x5xf32>, tensor<3x1x1xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> diff --git a/test/mlir/onnx/onnx_lowering_with_canonicalize_O3.mlir b/test/mlir/onnx/onnx_lowering_with_canonicalize_O3.mlir index 9ef2ce5471..ee06f00825 100644 --- a/test/mlir/onnx/onnx_lowering_with_canonicalize_O3.mlir +++ b/test/mlir/onnx/onnx_lowering_with_canonicalize_O3.mlir @@ -5273,3 +5273,64 @@ func.func @test_matmulinteger_per_row_a(%arg0: tensor<16x32xui8>, %arg1: tensor< // CHECK: return [[RES_8_]] : memref<16x64xi32> // CHECK: } } + +// ----- + +// Test onnx.Erf lowering from onnx to krnl +func.func @test_erf(%arg0: tensor) -> (tensor<*xf32>) { + %0 = "onnx.Erf"(%arg0): (tensor) -> (tensor<*xf32>) + return %0 : tensor<*xf32> +// mlir2FileCheck.py +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 40 + 64)> +// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0] -> (s0 * 10)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1] -> (s1)> +// CHECK-LABEL: func.func @test_erf +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> memref { +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<0.327591091> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<1.06140542> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<-1.45315206> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<1.42141378> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant dense<-0.284496725> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_4_:%.+]] = arith.constant dense<0.254829586> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_5_:%.+]] = arith.constant dense<-1.000000e+00> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_6_:%.+]] = arith.constant 
dense<1.000000e+00> : vector<16xf32> +// CHECK-DAG: [[VAR_cst_7_:%.+]] = arith.constant dense<0.000000e+00> : vector<16xf32> +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[VAR_dim_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_]] : memref +// CHECK-DAG: [[VAR_0_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}} +// CHECK-DAG: [[RES_:%.+]] = memref.alloc([[VAR_0_]]) {{.*}}: memref +// CHECK-DAG: [[VAR_view_:%.+]] = memref.view [[RES_]]{{.}}[[CST_0_]]{{.}}{{.}}[[VAR_dim_]]{{.}} : memref to memref +// CHECK-DAG: [[VAR_dim_8_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_]] : memref +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_1_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_dim_8_]]{{.}} +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK-DAG: affine.store [[VAR_1_]], [[RES_1_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_1_]]) : (memref, memref<1xindex>) -> memref +// CHECK-DAG: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_dim_]]{{.}} +// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK-DAG: affine.store [[VAR_2_]], [[RES_2_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_11_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref, memref<1xindex>) -> memref +// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK-DAG: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 16 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK-DAG: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_2_]]{{.}}){ +// CHECK-DAG: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<16xf32> +// CHECK-DAG: [[VAR_6_:%.+]] = math.absf [[LOAD_VAR_reshape_MEM_]] : vector<16xf32> +// CHECK-DAG: 
[[VAR_7_:%.+]] = vector.fma [[VAR_cst_]], [[VAR_6_]], [[VAR_cst_6_]] : vector<16xf32> +// CHECK-DAG: [[VAR_8_:%.+]] = arith.divf [[VAR_cst_6_]], [[VAR_7_]] : vector<16xf32> +// CHECK-DAG: [[VAR_9_:%.+]] = vector.fma [[VAR_cst_0_]], [[VAR_8_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK-DAG: [[VAR_10_:%.+]] = vector.fma [[VAR_9_]], [[VAR_8_]], [[VAR_cst_2_]] : vector<16xf32> +// CHECK-DAG: [[VAR_11_:%.+]] = vector.fma [[VAR_10_]], [[VAR_8_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-DAG: [[VAR_12_:%.+]] = vector.fma [[VAR_11_]], [[VAR_8_]], [[VAR_cst_4_]] : vector<16xf32> +// CHECK-DAG: [[VAR_13_:%.+]] = arith.mulf [[VAR_12_]], [[VAR_8_]] : vector<16xf32> +// CHECK-DAG: [[VAR_14_:%.+]] = arith.subf [[VAR_cst_7_]], [[VAR_6_]] : vector<16xf32> +// CHECK-DAG: [[VAR_15_:%.+]] = arith.mulf [[VAR_14_]], [[VAR_6_]] : vector<16xf32> +// CHECK-DAG: [[VAR_16_:%.+]] = math.exp [[VAR_15_]] : vector<16xf32> +// CHECK-DAG: [[VAR_17_:%.+]] = vector.fma [[VAR_13_]], [[VAR_16_]], [[VAR_cst_5_]] : vector<16xf32> +// CHECK-DAG: [[VAR_18_:%.+]] = arith.cmpf ogt, [[LOAD_VAR_reshape_MEM_]], [[VAR_cst_7_]] : vector<16xf32> +// CHECK-DAG: [[VAR_19_:%.+]] = arith.subf [[VAR_cst_7_]], [[VAR_17_]] : vector<16xf32> +// CHECK: [[VAR_20_:%.+]] = arith.select [[VAR_18_]], [[VAR_19_]], [[VAR_17_]] : vector<16xi1>, vector<16xf32> +// CHECK: vector.store [[VAR_20_]], [[VAR_reshape_11_]]{{.}}[[VAR_4_]]{{.}} : memref, vector<16xf32> +// CHECK: } +// CHECK: return [[VAR_view_]] : memref +} diff --git a/test/mlir/onnx/onnx_math_functions_lowering.mlir b/test/mlir/onnx/onnx_math_functions_lowering.mlir index f0a9df57f4..51ed7711de 100644 --- a/test/mlir/onnx/onnx_math_functions_lowering.mlir +++ b/test/mlir/onnx/onnx_math_functions_lowering.mlir @@ -1,21 +1,5 @@ // RUN: onnx-mlir-opt -O3 --shape-inference --convert-onnx-to-krnl %s -split-input-file | FileCheck %s -/// onnx.Erf lowering to krnl.erf. 
-func.func @erf_function(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> { - %0 = "onnx.Erf"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32> - "func.return"(%0) : (tensor<10x10xf32>) -> () -} - -// CHECK-LABEL erf_function -// CHECK: [[ALLOC:%.+]] = memref.alloc() {{.*}}: memref<10x10xf32> -// CHECK: [[LOOP:%.+]]:2 = krnl.define_loops 2 -// CHECK: krnl.iterate -// CHECK: [[IV:%.+]]:2 = krnl.get_induction_var_value([[LOOP]]#0, [[LOOP]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK: [[LOAD:%.+]] = {{.*}}load %arg0[[[IV]]#0, [[IV]]#1] : memref<10x10xf32> -// CHECK: [[ERF:%.+]] = "krnl.erf"([[LOAD]]) : (f32) -> f32 -// CHECK: {{.*}}store [[ERF]], [[ALLOC]][[[IV]]#0, [[IV]]#1] : memref<10x10xf32> -// CHECK: return [[ALLOC]] : memref<10x10xf32> - /// onnx.Acos lowering to krnl.acos. func.func @acos_function(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> { %0 = "onnx.Acos"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32> diff --git a/test/modellib/ElementwiseModel.cpp b/test/modellib/ElementwiseModel.cpp index 148bb0190c..b4a267dc95 100644 --- a/test/modellib/ElementwiseModel.cpp +++ b/test/modellib/ElementwiseModel.cpp @@ -34,6 +34,7 @@ namespace test { // ONNXAddOp // ONNXDivOp // ONNXHardSigmoidOp +// ONNXErfOp const float alphaVal = 2.0; const float betaVal = 0.5; @@ -41,7 +42,10 @@ const float betaVal = 0.5; Elementwise2DLibBuilder::Elementwise2DLibBuilder(const std::string &modelName, const std::string &onnxOpName, const int I, const int J) : ModelLibBuilder(modelName), onnxOpName(onnxOpName), I(I), J(J), - inputNum((onnxOpName.compare("ONNXHardSigmoidOp") == 0) ? 1 : 2) {} + inputNum(((onnxOpName.compare("ONNXHardSigmoidOp") == 0) || + (onnxOpName.compare("ONNXErfOp") == 0)) + ? 
1 + : 2) {} bool Elementwise2DLibBuilder::build() { llvm::SmallVector shape = {I, J}; @@ -70,6 +74,10 @@ bool Elementwise2DLibBuilder::build() { auto op = builder.create(loc, yType, aVal, alpha, beta); results.emplace_back(op.getResult()); + } else if (onnxOpName.compare("ONNXErfOp") == 0) { + // Erf: unary op, built from the single input aVal. + auto op = builder.create(loc, yType, aVal); + results.emplace_back(op.getResult()); } else llvm_unreachable("unsupported unary elementwise op"); @@ -166,6 +174,11 @@ bool Elementwise2DLibBuilder::verifyOutputs() { val = (val < 1.0) ? val : 1.0; return val; }; + else if (onnxOpName.compare("ONNXErfOp") == 0) + fct = [](float a) -> float { + float val = erf(a); + return val; + }; else llvm_unreachable("unsupported binary elementwise op"); for (int64_t i = 0; i < I; ++i) { diff --git a/test/numerical/TestElementwise.cpp b/test/numerical/TestElementwise.cpp index b7d52efc3f..abb7f2b3f4 100644 --- a/test/numerical/TestElementwise.cpp +++ b/test/numerical/TestElementwise.cpp @@ -93,5 +93,15 @@ int main(int argc, char *argv[]) { if (!success) return 1; + printf("RapidCheck test Erf case generation.\n"); + success = rc::check("Erf implementation correctness", [&]() { + const int maxRange = 128; + const int I = *rc::gen::inRange(1, maxRange); + const int J = *rc::gen::inRange(1, maxRange); + RC_ASSERT(isOMElementwiseTheSameAsNaiveImplFor("ONNXErfOp", I, J)); + }); + if (!success) + return 1; + return 0; }