Skip to content

Commit

Permalink
Adding support for And (#113)
Browse files Browse the repository at this point in the history
* Adding support for And

* Update .gitignore

* Update Project.toml
  • Loading branch information
dstarkenburg authored Jan 22, 2025
1 parent 4a134f7 commit 7bbb35b
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 6 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/Test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ jobs:
fail-fast: false
matrix:
version:
- '1.6'
- '1.9'
- '1.10'
- '1'
os:
- ubuntu-latest
arch:
Expand Down
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ONNX"
uuid = "d0dd6a25-fac6-55c0-abf7-829e0c774d20"
version = "0.2.7"
version = "0.3.0"

[deps]
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
Expand All @@ -18,5 +18,5 @@ EnumX = "1"
NNlib = "0.8, 0.9"
ProtoBuf = "1.0"
StaticArrays = "1"
Umlaut = "0.4, 0.5, 0.6, 0.7"
julia = "1.6"
Umlaut = "0.7"
julia = "1.10"
4 changes: 4 additions & 0 deletions src/load.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ function load_node!(tape::Tape, ::OpConfig{:ONNX, :Acosh}, args::VarVec, attrs::
return push_call!(tape, _acosh, args[1])
end

function load_node!(tape::Tape, ::OpConfig{:ONNX, :And}, args::VarVec, attrs::AttrDict)
return push_call!(tape, and, args...)
end

function load_node!(tape::Tape, nd::NodeProto, backend::Symbol)
args = [tape.c.name2var[name] for name in nd.input]
attrs = convert(Dict{Symbol, Any}, Dict(nd.attribute))
Expand Down
4 changes: 4 additions & 0 deletions src/ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ function onnx_flatten(x; axis = 1)
return flatten(x; dim = dim)
end

function and(x, y)
return x .& y
end

add(xs...) = .+(xs...)
sub(xs...) = .-(xs...)
_sin(x) = sin.(x)
Expand Down
5 changes: 5 additions & 0 deletions src/save.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(_acosh)}, op::Umlaut
push!(g.node, nd)
end

function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(and)}, op::Umlaut.Call)
nd = NodeProto("And", op)
push!(g.node, nd)
end

function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(*)}, op::Umlaut.Call)
nd = NodeProto(
input=[onnx_name(v) for v in reverse(op.args)],
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Umlaut = "92992a2b-8ce5-4a9c-bb9d-58be9a7dc841"

[compat]
Umlaut = "0.4"
Umlaut = "0.7"
10 changes: 10 additions & 0 deletions test/saveload.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@ import ONNX: NodeProto, ValueInfoProto, AttributeProto, onnx_name
ort_test(tape, args...)
end

@testset "And" begin
# Testing matricies of similar shape
args = rand(Bool, 3, 4), rand(Bool, 3, 4)
ort_test(ONNX.and, args...)

# Testing Numpy-style broadcasting
args = rand(Bool, 3, 3), rand(Bool, 1, 3)
ort_test(ONNX.and, args...)
end

@testset "Basic ops" begin
args = (rand(3, 4), rand(3, 4))
ort_test(ONNX.add, args...)
Expand Down

0 comments on commit 7bbb35b

Please sign in to comment.