Skip to content

Commit

Permalink
Adding support for Acosh (#111)
Browse files Browse the repository at this point in the history
* Adding support for Acosh

* Acosh defined for X >= 1

* Acosh not implemented in ONNXRunTime as Float64, using Float32

* Typo in first commit, Acos vs Acosh
  • Loading branch information
dstarkenburg authored Dec 10, 2024
1 parent 4abb65e commit 4a134f7
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/load.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ function load_node!(tape::Tape, ::OpConfig{:ONNX, :Acos}, args::VarVec, attrs::A
return push_call!(tape, _acos, args[1])
end

function load_node!(tape::Tape, ::OpConfig{:ONNX, :Acosh}, args::VarVec, attrs::AttrDict)
return push_call!(tape, _acosh, args[1])
end

function load_node!(tape::Tape, nd::NodeProto, backend::Symbol)
args = [tape.c.name2var[name] for name in nd.input]
attrs = convert(Dict{Symbol, Any}, Dict(nd.attribute))
Expand Down
1 change: 1 addition & 0 deletions src/ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ _sin(x) = sin.(x)
_cos(x) = cos.(x)
_abs(x) = abs.(x)
_acos(x) = acos.(x)
_acosh(x) = acosh.(x)
mul(xs...) = .*(xs...)
relu(x) = NNlib.relu.(x)
leakyrelu(x;a = 0.01) = NNlib.leakyrelu.(x,a)
Expand Down
5 changes: 5 additions & 0 deletions src/save.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,11 @@ function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(_acos)}, op::Umlaut.
push!(g.node, nd)
end

function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(_acosh)}, op::Umlaut.Call)
nd = NodeProto("Acosh", op)
push!(g.node, nd)
end

function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(*)}, op::Umlaut.Call)
nd = NodeProto(
input=[onnx_name(v) for v in reverse(op.args)],
Expand Down
8 changes: 8 additions & 0 deletions test/saveload.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ import ONNX: NodeProto, ValueInfoProto, AttributeProto, onnx_name
ort_test(ONNX._acos, A)
end

@testset "Acosh" begin
# ONNXRunTime has no implementation for Acosh(x::Float64), using Float32
A = rand(Float32, 3, 4)
# Acosh defined for A >= 1
A = A .+ 1
ort_test(ONNX._acosh, A)
end

@testset "Gemm" begin
A, B, C = (rand(3, 4), rand(3, 4), rand(3, 3))
ort_test(ONNX.onnx_gemm, A, B')
Expand Down

0 comments on commit 4a134f7

Please sign in to comment.