Support code generation for onnx.Erf (#2160)
* Support code generation for onnx.Erf.

Signed-off-by: Yasushi Negishi <[email protected]>
Co-authored-by: Alexandre Eichenberger <[email protected]>
negiyas and AlexandreEichenberger authored May 2, 2023
1 parent cf0fb99 commit dea431d
Showing 8 changed files with 222 additions and 25 deletions.
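
Summary: this replaces the former lowering of onnx.Erf to the krnl.erf runtime call with inlined code generation based on the Abramowitz & Stegun 7.1.26 approximation

erf(x) ≈ sign(x) · (1 − (a1·t + a2·t² + a3·t³ + a4·t⁴ + a5·t⁵) · e^(−x²)),  with t = 1/(1 + p·|x|),

which also makes the op eligible for the SIMD lowering path (see the O3 test below).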
79 changes: 73 additions & 6 deletions src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
@@ -256,12 +256,6 @@ double analyzeSimdFor<ONNXPowOp>(Type t, int64_t &von, int64_t &son) {
return simdAnalysis({GenericOps::PowGop}, {1}, t, von, son);
}

template <>
struct ScalarOp<ONNXErfOp> {
using FOp = KrnlErfOp;
using IOp = NotSuportedScalarOp;
};

template <>
struct ScalarOp<ONNXIsInfOp> {
using FOp = KrnlIsInfOp;
@@ -771,6 +765,79 @@ Value emitScalarOpFor<ONNXSignOp>(ConversionPatternRewriter &rewriter,
return create.math.select(zeroPredicate, zero, plusSelect);
}

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXErfOp
//===----------------------------------------------------------------------===//
template <>
struct ScalarOp<ONNXErfOp> {
using FOp = CustomScalarOp;
using IOp = NotSuportedScalarOp;
};

template <>
double analyzeSimdFor<ONNXErfOp>(Type t, int64_t &von, int64_t &son) {
return simdAnalysis(
{GenericOps::ArithmeticGop, GenericOps::CompareGop, GenericOps::DivGop,
GenericOps::FmaGop, GenericOps::MulGop, GenericOps::SelectGop,
GenericOps::AbsGop, GenericOps::ExpGop},
{2, 1, 1, 6, 2, 1, 1, 1}, t, von, son);
}
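
The cost vector {2, 1, 1, 6, 2, 1, 1, 1} mirrors the scalar ops emitted in emitScalarOpFor below: two subtractions (ArithmeticGop), one compare for the final sign test, one division, six fused multiply-adds, two multiplies, one select, one abs, and one exp.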

template <>
Value emitScalarOpFor<ONNXErfOp>(ConversionPatternRewriter &rewriter,
Location loc, Operation *op, Type elementType,
ArrayRef<Value> scalarOperands) {
// Approximate erf with the Abramowitz & Stegun 7.1.26 polynomial, following
// https://www.johndcook.com/blog/2009/01/19/stand-alone-error-function-erf/:
// ```
// def erf(x):
// a1 = 0.254829592
// a2 = -0.284496736
// a3 = 1.421413741
// a4 = -1.453152027
// a5 = 1.061405429
// p = 0.3275911
//
// t = 1.0 / (1.0 + p * abs(x))
// y = 1.0 -
// (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t *
// exp(-abs(x)*abs(x))
// return -y if x < 0.0 else y
// ```
CheckIfCustomScalarOpIsSupported<ONNXErfOp>(elementType);
Value operand = scalarOperands[0];
MultiDialectBuilder<MathBuilder> create(rewriter, loc);
Value zero = create.math.constant(elementType, 0);
Value one = create.math.constant(elementType, 1);
Value minusone = create.math.constant(elementType, -1);
Value a1 = create.math.constant(elementType, 0.254829592);
Value a2 = create.math.constant(elementType, -0.284496736);
Value a3 = create.math.constant(elementType, 1.421413741);
Value a4 = create.math.constant(elementType, -1.453152027);
Value a5 = create.math.constant(elementType, 1.061405429);
Value p = create.math.constant(elementType, 0.3275911);
Value absx = create.math.abs(operand);
Value t = create.math.div(one, create.math.fma(p, absx, one));
// y = 1.0 -
// (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t *
// exp(-absx*absx)
// minusy = (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t *
// exp(-absx*absx)
// + (-1.0)
Value minusy = create.math.fma(
create.math.mul(
create.math.fma(
create.math.fma(
create.math.fma(create.math.fma(a5, t, a4), t, a3), t, a2),
t, a1),
t),
create.math.exp(create.math.mul(create.math.sub(zero, absx), absx)),
minusone);
Value sign = create.math.gt(operand, zero);
return create.math.select(sign, create.math.sub(zero, minusy), minusy);
}
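
As a sanity check on the constants above, here is a small standalone C++ sketch (erfApprox and the driver are illustrative only, not part of this patch) that replays the same FMA chain and compares it against libm's erf; A&S 7.1.26 bounds the absolute error at roughly 1.5e-7:

```cpp
#include <cmath>
#include <cstdio>

// Standalone restatement of the approximation emitted by emitScalarOpFor.
static float erfApprox(float x) {
  const float a1 = 0.254829592f, a2 = -0.284496736f, a3 = 1.421413741f,
              a4 = -1.453152027f, a5 = 1.061405429f, p = 0.3275911f;
  float ax = std::fabs(x);
  float t = 1.0f / std::fmaf(p, ax, 1.0f); // t = 1 / (1 + p * |x|)
  // Same Horner/FMA nesting as the lowering above.
  float poly = std::fmaf(
      std::fmaf(std::fmaf(std::fmaf(a5, t, a4), t, a3), t, a2), t, a1);
  float y = 1.0f - poly * t * std::exp(-ax * ax);
  return x < 0.0f ? -y : y;
}

int main() {
  float maxErr = 0.0f;
  for (float x = -4.0f; x <= 4.0f; x += 1.0f / 1024.0f)
    maxErr = std::fmax(maxErr, std::fabs(erfApprox(x) - std::erf(x)));
  std::printf("max abs error: %g\n", maxErr); // expect ~1.5e-7
  return 0;
}
```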

//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXMaxOp
//===----------------------------------------------------------------------===//
9 changes: 9 additions & 0 deletions src/Dialect/Mlir/DialectBuilder.cpp
@@ -205,6 +205,15 @@ Value MathBuilder::floorDiv(Value lhs, Value rhs) const {
llvm_unreachable("expected int");
}

// return (lhs * rhs) + acc
Value MathBuilder::fma(Value lhs, Value rhs, Value acc) const {
assert((lhs.getType() == rhs.getType()) && (rhs.getType() == acc.getType()) &&
"expected same type");
if (isFloatWithVector(lhs.getType()) && !isa<FloatType>(lhs.getType()))
return b().create<vector::FMAOp>(loc(), lhs, rhs, acc);
return add(mul(lhs, rhs), acc);
}
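
Note that vector::FMAOp accepts only vector operands, hence the !isa<FloatType> guard: scalar floats (and all integer types) fall back to an explicit multiply-add. This is visible in the tests below, where the unvectorized lowering shows arith.mulf/arith.addf pairs and the O3 lowering shows vector.fma.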

Value MathBuilder::exp(Value val) const {
if (isFloatWithVector(val.getType()))
return b().create<math::ExpOp>(loc(), val);
5 changes: 3 additions & 2 deletions src/Dialect/Mlir/DialectBuilder.hpp
@@ -114,8 +114,9 @@ struct MathBuilder final : DialectBuilder {
mlir::Value exp2(mlir::Value val) const; // Float only.
mlir::Value floor(mlir::Value val) const; // Float only.
mlir::Value floorDiv(mlir::Value lhs, mlir::Value rhs) const; // Int only.
mlir::Value log(mlir::Value val) const; // Float only.
mlir::Value log2(mlir::Value val) const; // Float only.
mlir::Value fma(mlir::Value lhs, mlir::Value rhs, mlir::Value acc) const;
mlir::Value log(mlir::Value val) const; // Float only.
mlir::Value log2(mlir::Value val) const; // Float only.
mlir::Value mul(mlir::Value lhs, mlir::Value rhs) const;
mlir::Value neg(mlir::Value val) const;
mlir::Value ori(mlir::Value lhs, mlir::Value rhs) const; // Int only.
52 changes: 52 additions & 0 deletions test/mlir/onnx/onnx_lowering_with_canonicalize.mlir
@@ -4632,6 +4632,58 @@ func.func @test_matmulinteger_per_row_a(%arg0: tensor<16x32xui8>, %arg1: tensor<

// -----

func.func @test_erf(%arg0: tensor<?x10xf32>) -> (tensor<*xf32>) attributes {} {
%0 = "onnx.Erf"(%arg0): (tensor<?x10xf32>) -> (tensor<*xf32>)
return %0 : tensor<*xf32>
// mlir2FileCheck.py
// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func.func @test_erf
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<?x10xf32>) -> memref<?x10xf32> {
// CHECK-DAG: [[CST_0_dot_327591091_:%.+]] = arith.constant 0.327591091 : f32
// CHECK-DAG: [[CST_1_dot_06140542_:%.+]] = arith.constant 1.06140542 : f32
// CHECK-DAG: [[CST_minus_1_dot_45315206_:%.+]] = arith.constant -1.45315206 : f32
// CHECK-DAG: [[CST_1_dot_42141378_:%.+]] = arith.constant 1.42141378 : f32
// CHECK-DAG: [[CST_minus_0_dot_284496725_:%.+]] = arith.constant -0.284496725 : f32
// CHECK-DAG: [[CST_0_dot_254829586_:%.+]] = arith.constant 0.254829586 : f32
// CHECK-DAG: [[CST_minus_1_dot_000000_:%.+]] = arith.constant -1.000000e+00 : f32
// CHECK-DAG: [[CST_1_dot_000000_:%.+]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[VAR_dim_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_]] : memref<?x10xf32>
// CHECK-DAG: [[RES_:%.+]] = memref.alloc([[VAR_dim_]]) {{.*}}: memref<?x10xf32>
// CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2
// CHECK-DAG: [[VAR_dim_8_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_]] : memref<?x10xf32>
// CHECK-DAG: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[MAP_0_]]([[VAR_dim_8_]]), [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 10){
// CHECK-DAG: [[VAR_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index)
// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1] : memref<?x10xf32>
// CHECK-DAG: [[VAR_3_:%.+]] = math.absf [[LOAD_PARAM_0_MEM_]] : f32
// CHECK-DAG: [[VAR_4_:%.+]] = arith.mulf [[VAR_3_]], [[CST_0_dot_327591091_]] : f32
// CHECK-DAG: [[VAR_5_:%.+]] = arith.addf [[VAR_4_]], [[CST_1_dot_000000_]] : f32
// CHECK-DAG: [[VAR_6_:%.+]] = arith.divf [[CST_1_dot_000000_]], [[VAR_5_]] : f32
// CHECK-DAG: [[VAR_7_:%.+]] = arith.mulf [[VAR_6_]], [[CST_1_dot_06140542_]] : f32
// CHECK-DAG: [[VAR_8_:%.+]] = arith.addf [[VAR_7_]], [[CST_minus_1_dot_45315206_]] : f32
// CHECK-DAG: [[VAR_9_:%.+]] = arith.mulf [[VAR_8_]], [[VAR_6_]] : f32
// CHECK-DAG: [[VAR_10_:%.+]] = arith.addf [[VAR_9_]], [[CST_1_dot_42141378_]] : f32
// CHECK-DAG: [[VAR_11_:%.+]] = arith.mulf [[VAR_10_]], [[VAR_6_]] : f32
// CHECK-DAG: [[VAR_12_:%.+]] = arith.addf [[VAR_11_]], [[CST_minus_0_dot_284496725_]] : f32
// CHECK-DAG: [[VAR_13_:%.+]] = arith.mulf [[VAR_12_]], [[VAR_6_]] : f32
// CHECK-DAG: [[VAR_14_:%.+]] = arith.addf [[VAR_13_]], [[CST_0_dot_254829586_]] : f32
// CHECK-DAG: [[VAR_15_:%.+]] = arith.mulf [[VAR_14_]], [[VAR_6_]] : f32
// CHECK-DAG: [[VAR_16_:%.+]] = arith.subf [[CST_0_dot_000000_]], [[VAR_3_]] : f32
// CHECK-DAG: [[VAR_17_:%.+]] = arith.mulf [[VAR_16_]], [[VAR_3_]] : f32
// CHECK-DAG: [[VAR_18_:%.+]] = math.exp [[VAR_17_]] : f32
// CHECK-DAG: [[VAR_19_:%.+]] = arith.mulf [[VAR_15_]], [[VAR_18_]] : f32
// CHECK-DAG: [[VAR_20_:%.+]] = arith.addf [[VAR_19_]], [[CST_minus_1_dot_000000_]] : f32
// CHECK-DAG: [[VAR_21_:%.+]] = arith.cmpf ogt, [[LOAD_PARAM_0_MEM_]], [[CST_0_dot_000000_]] : f32
// CHECK-DAG: [[VAR_22_:%.+]] = arith.subf [[CST_0_dot_000000_]], [[VAR_20_]] : f32
// CHECK-DAG: [[VAR_23_:%.+]] = arith.select [[VAR_21_]], [[VAR_22_]], [[VAR_20_]] : f32
// CHECK: krnl.store [[VAR_23_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1] : memref<?x10xf32>
// CHECK: }
// CHECK: return [[RES_]] : memref<?x10xf32>
}
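
Note: the constants checked here (0.254829586, -0.284496725, etc.) are the f32 roundings of the double literals used in emitScalarOpFor, since the element type is f32.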

// -----

func.func @add_partial_splat(%arg0: tensor<2x3x4x5xf32>, %arg1: tensor<3x1x1xf32>) -> tensor<*xf32> attributes {input_names = ["x", "y"], output_names = ["sum"]} {
%0 = "onnx.Add"(%arg0, %arg1) : (tensor<2x3x4x5xf32>, tensor<3x1x1xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
61 changes: 61 additions & 0 deletions test/mlir/onnx/onnx_lowering_with_canonicalize_O3.mlir
@@ -5273,3 +5273,64 @@ func.func @test_matmulinteger_per_row_a(%arg0: tensor<16x32xui8>, %arg1: tensor<
// CHECK: return [[RES_8_]] : memref<16x64xi32>
// CHECK: }
}

// -----

// Test onnx.Erf lowering from ONNX to Krnl.
func.func @test_erf(%arg0: tensor<?x10xf32>) -> (tensor<*xf32>) {
%0 = "onnx.Erf"(%arg0): (tensor<?x10xf32>) -> (tensor<*xf32>)
return %0 : tensor<*xf32>
// mlir2FileCheck.py
// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 40 + 64)>
// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0] -> (s0 * 10)>
// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1] -> (s1)>
// CHECK-LABEL: func.func @test_erf
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<?x10xf32>) -> memref<?x10xf32> {
// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<0.327591091> : vector<16xf32>
// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<1.06140542> : vector<16xf32>
// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<-1.45315206> : vector<16xf32>
// CHECK-DAG: [[VAR_cst_2_:%.+]] = arith.constant dense<1.42141378> : vector<16xf32>
// CHECK-DAG: [[VAR_cst_3_:%.+]] = arith.constant dense<-0.284496725> : vector<16xf32>
// CHECK-DAG: [[VAR_cst_4_:%.+]] = arith.constant dense<0.254829586> : vector<16xf32>
// CHECK-DAG: [[VAR_cst_5_:%.+]] = arith.constant dense<-1.000000e+00> : vector<16xf32>
// CHECK-DAG: [[VAR_cst_6_:%.+]] = arith.constant dense<1.000000e+00> : vector<16xf32>
// CHECK-DAG: [[VAR_cst_7_:%.+]] = arith.constant dense<0.000000e+00> : vector<16xf32>
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[VAR_dim_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_]] : memref<?x10xf32>
// CHECK-DAG: [[VAR_0_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}}
// CHECK-DAG: [[RES_:%.+]] = memref.alloc([[VAR_0_]]) {{.*}}: memref<?xi8>
// CHECK-DAG: [[VAR_view_:%.+]] = memref.view [[RES_]]{{.}}[[CST_0_]]{{.}}{{.}}[[VAR_dim_]]{{.}} : memref<?xi8> to memref<?x10xf32>
// CHECK-DAG: [[VAR_dim_8_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_]] : memref<?x10xf32>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_1_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_dim_8_]]{{.}}
// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<1xindex>
// CHECK-DAG: affine.store [[VAR_1_]], [[RES_1_]][0] : memref<1xindex>
// CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_1_]]) : (memref<?x10xf32>, memref<1xindex>) -> memref<?xf32>
// CHECK-DAG: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_dim_]]{{.}}
// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<1xindex>
// CHECK-DAG: affine.store [[VAR_2_]], [[RES_2_]][0] : memref<1xindex>
// CHECK-DAG: [[VAR_reshape_11_:%.+]] = memref.reshape [[VAR_view_]]([[RES_2_]]) : (memref<?x10xf32>, memref<1xindex>) -> memref<?xf32>
// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1
// CHECK-DAG: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 16 : (!krnl.loop) -> (!krnl.loop, !krnl.loop)
// CHECK-DAG: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_2_]]{{.}}){
// CHECK-DAG: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index
// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_4_]]{{.}} : memref<?xf32>, vector<16xf32>
// CHECK-DAG: [[VAR_6_:%.+]] = math.absf [[LOAD_VAR_reshape_MEM_]] : vector<16xf32>
// CHECK-DAG: [[VAR_7_:%.+]] = vector.fma [[VAR_cst_]], [[VAR_6_]], [[VAR_cst_]]_6 : vector<16xf32>
// CHECK-DAG: [[VAR_8_:%.+]] = arith.divf [[VAR_cst_6_]], [[VAR_7_]] : vector<16xf32>
// CHECK-DAG: [[VAR_9_:%.+]] = vector.fma [[VAR_cst_0_]], [[VAR_8_]], [[VAR_cst_1_]] : vector<16xf32>
// CHECK-DAG: [[VAR_10_:%.+]] = vector.fma [[VAR_9_]], [[VAR_8_]], [[VAR_cst_2_]] : vector<16xf32>
// CHECK-DAG: [[VAR_11_:%.+]] = vector.fma [[VAR_10_]], [[VAR_8_]], [[VAR_cst_3_]] : vector<16xf32>
// CHECK-DAG: [[VAR_12_:%.+]] = vector.fma [[VAR_11_]], [[VAR_8_]], [[VAR_cst_4_]] : vector<16xf32>
// CHECK-DAG: [[VAR_13_:%.+]] = arith.mulf [[VAR_12_]], [[VAR_8_]] : vector<16xf32>
// CHECK-DAG: [[VAR_14_:%.+]] = arith.subf [[VAR_cst_7_]], [[VAR_6_]] : vector<16xf32>
// CHECK-DAG: [[VAR_15_:%.+]] = arith.mulf [[VAR_14_]], [[VAR_6_]] : vector<16xf32>
// CHECK-DAG: [[VAR_16_:%.+]] = math.exp [[VAR_15_]] : vector<16xf32>
// CHECK-DAG: [[VAR_17_:%.+]] = vector.fma [[VAR_13_]], [[VAR_16_]], [[VAR_cst_5_]] : vector<16xf32>
// CHECK-DAG: [[VAR_18_:%.+]] = arith.cmpf ogt, [[LOAD_VAR_reshape_MEM_]], [[VAR_cst_7_]] : vector<16xf32>
// CHECK-DAG: [[VAR_19_:%.+]] = arith.subf [[VAR_cst_7_]], [[VAR_17_]] : vector<16xf32>
// CHECK: [[VAR_20_:%.+]] = arith.select [[VAR_18_]], [[VAR_19_]], [[VAR_17_]] : vector<16xi1>, vector<16xf32>
// CHECK: vector.store [[VAR_20_]], [[VAR_reshape_11_]]{{.}}[[VAR_4_]]{{.}} : memref<?xf32>, vector<16xf32>
// CHECK: }
// CHECK: return [[VAR_view_]] : memref<?x10xf32>
}
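
Under -O3 the same computation is flattened to 1-D with memref.reshape, blocked by a vector length of 16, and each create.math.fma call lowers to a single vector.fma, matching the cost model in analyzeSimdFor.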
16 changes: 0 additions & 16 deletions test/mlir/onnx/onnx_math_functions_lowering.mlir
@@ -1,21 +1,5 @@
// RUN: onnx-mlir-opt -O3 --shape-inference --convert-onnx-to-krnl %s -split-input-file | FileCheck %s

/// onnx.Erf lowering to krnl.erf.
func.func @erf_function(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> {
%0 = "onnx.Erf"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32>
"func.return"(%0) : (tensor<10x10xf32>) -> ()
}

// CHECK-LABEL erf_function
// CHECK: [[ALLOC:%.+]] = memref.alloc() {{.*}}: memref<10x10xf32>
// CHECK: [[LOOP:%.+]]:2 = krnl.define_loops 2
// CHECK: krnl.iterate
// CHECK: [[IV:%.+]]:2 = krnl.get_induction_var_value([[LOOP]]#0, [[LOOP]]#1) : (!krnl.loop, !krnl.loop) -> (index, index)
// CHECK: [[LOAD:%.+]] = {{.*}}load %arg0[[[IV]]#0, [[IV]]#1] : memref<10x10xf32>
// CHECK: [[ERF:%.+]] = "krnl.erf"([[LOAD]]) : (f32) -> f32
// CHECK: {{.*}}store [[ERF]], [[ALLOC]][[[IV]]#0, [[IV]]#1] : memref<10x10xf32>
// CHECK: return [[ALLOC]] : memref<10x10xf32>

/// onnx.Acos lowering to krnl.acos.
func.func @acos_function(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> {
%0 = "onnx.Acos"(%arg0) : (tensor<10x10xf32>) -> tensor<10x10xf32>
15 changes: 14 additions & 1 deletion test/modellib/ElementwiseModel.cpp
@@ -34,14 +34,18 @@ namespace test {
// ONNXAddOp
// ONNXDivOp
// ONNXHardSigmoidOp
// ONNXErfOp

const float alphaVal = 2.0;
const float betaVal = 0.5;

Elementwise2DLibBuilder::Elementwise2DLibBuilder(const std::string &modelName,
const std::string &onnxOpName, const int I, const int J)
: ModelLibBuilder(modelName), onnxOpName(onnxOpName), I(I), J(J),
inputNum((onnxOpName.compare("ONNXHardSigmoidOp") == 0) ? 1 : 2) {}
inputNum(((onnxOpName.compare("ONNXHardSigmoidOp") == 0) ||
(onnxOpName.compare("ONNXErfOp") == 0))
? 1
: 2) {}

bool Elementwise2DLibBuilder::build() {
llvm::SmallVector<int64_t, 4> shape = {I, J};
@@ -70,6 +74,10 @@ bool Elementwise2DLibBuilder::build() {
auto op =
builder.create<ONNXHardSigmoidOp>(loc, yType, aVal, alpha, beta);
results.emplace_back(op.getResult());
} else if (onnxOpName.compare("ONNXErfOp") == 0) {
// Erf.
auto op = builder.create<ONNXErfOp>(loc, yType, aVal);
results.emplace_back(op.getResult());
} else
llvm_unreachable("unsupported unary elementwise op");

@@ -166,6 +174,11 @@ bool Elementwise2DLibBuilder::verifyOutputs() {
val = (val < 1.0) ? val : 1.0;
return val;
};
else if (onnxOpName.compare("ONNXErfOp") == 0)
fct = [](float a) -> float {
float val = erf(a);
return val;
};
else
llvm_unreachable("unsupported binary elementwise op");
for (int64_t i = 0; i < I; ++i) {
10 changes: 10 additions & 0 deletions test/numerical/TestElementwise.cpp
@@ -93,5 +93,15 @@ int main(int argc, char *argv[]) {
if (!success)
return 1;

printf("RapidCheck test Erf case generation.\n");
success = rc::check("Gemm implementation correctness", [&]() {
const int maxRange = 128;
const int I = *rc::gen::inRange(1, maxRange);
const int J = *rc::gen::inRange(1, maxRange);
RC_ASSERT(isOMElementwiseTheSameAsNaiveImplFor("ONNXErfOp", I, J));
});
if (!success)
return 1;

return 0;
}
