diff --git a/.gitattributes b/.gitattributes index e02ed0b7..6cdd98b8 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1 +1 @@ -paper/* linguist-documentation +archives/paper/* linguist-documentation diff --git a/.github/workflows/TagBot.yml b/.github/workflows/TagBot.yml new file mode 100644 index 00000000..b5f35cc0 --- /dev/null +++ b/.github/workflows/TagBot.yml @@ -0,0 +1,14 @@ +name: TagBot +on: + issue_comment: + types: + - created + workflow_dispatch: +jobs: + TagBot: + if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot' + runs-on: ubuntu-latest + steps: + - uses: JuliaRegistries/TagBot@v1 + with: + token: ${{ secrets.GITHUB_TOKEN }} \ No newline at end of file diff --git a/Manifest.toml b/Manifest.toml index a9d3fe57..cf400cba 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -1,154 +1,16 @@ # This file is machine-generated - editing it directly is not advised -[[AbstractTrees]] -deps = ["Markdown", "Test"] -git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b" -uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" -version = "0.2.1" - -[[Adapt]] -deps = ["LinearAlgebra", "Test"] -git-tree-sha1 = "53d8fec4f662088c1202530e338a11a919407f3b" -uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "0.4.2" - -[[BSON]] -deps = ["Profile", "Test"] -git-tree-sha1 = "6453cef4f9cb8ded8e28e4d6d12e11e20eb692ea" -uuid = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" -version = "0.2.3" - [[Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" -[[BinDeps]] -deps = ["Compat", "Libdl", "SHA", "URIParser"] -git-tree-sha1 = "12093ca6cdd0ee547c39b1870e0c9c3f154d9ca9" -uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee" -version = "0.8.10" - -[[BinaryProvider]] -deps = ["Libdl", "SHA"] -git-tree-sha1 = "c7361ce8a2129f20b0e05a89f7070820cfed6648" -uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232" -version = "0.5.4" - -[[CSTParser]] -deps = ["Tokenize"] -git-tree-sha1 = "376a39f1862000442011390f1edf5e7f4dcc7142" -uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f" -version = "0.6.0" - -[[CodecZlib]] -deps = ["BinaryProvider", "Libdl", "Test", "TranscodingStreams"] -git-tree-sha1 = "36bbf5374c661054d41410dc53ff752972583b9b" -uuid = "944b1d66-785c-5afd-91f1-9de20f533193" -version = "0.5.2" - -[[ColorTypes]] -deps = ["FixedPointNumbers", "Random"] -git-tree-sha1 = "10050a24b09e8e41b951e9976b109871ce98d965" -uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" -version = "0.8.0" - -[[Colors]] -deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport", "Test"] -git-tree-sha1 = "9f0a0210450acb91c730b730a994f8eef1d3d543" -uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" -version = "0.9.5" - -[[CommonSubexpressions]] -deps = ["Test"] -git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0" -uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" -version = "0.2.0" - -[[Compat]] -deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] -git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f" -uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "2.1.0" - -[[Crayons]] -deps = ["Test"] -git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523" -uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" -version = "4.0.0" - -[[DataFlow]] -deps = ["Juno", "Lazy", "MacroTools"] -git-tree-sha1 = "e95561c5bf352d58eacf348e30e85f2d87d37321" -uuid = "a237f610-4214-5ca7-a9c6-385896804134" -version = 
"0.5.0" - -[[DataStructures]] -deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"] -git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038" -uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.15.0" - -[[Dates]] -deps = ["Printf"] -uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" - -[[DelimitedFiles]] -deps = ["Mmap"] -uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" - -[[DiffResults]] -deps = ["Compat", "StaticArrays"] -git-tree-sha1 = "34a4a1e8be7bc99bc9c611b895b5baf37a80584c" -uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" -version = "0.0.4" - -[[DiffRules]] -deps = ["Random", "Test"] -git-tree-sha1 = "dc0869fb2f5b23466b32ea799bd82c76480167f7" -uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "0.0.10" - [[Distributed]] deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" -[[FixedPointNumbers]] -git-tree-sha1 = "d14a6fa5890ea3a7e5dcab6811114f132fec2b4b" -uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" -version = "0.6.1" - -[[Flux]] -deps = ["AbstractTrees", "Adapt", "CodecZlib", "Colors", "DelimitedFiles", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "Pkg", "Printf", "Random", "Reexport", "Requires", "SHA", "Statistics", "StatsBase", "Tracker", "ZipFile"] -git-tree-sha1 = "ea0eedf3f8b3bd8cb4427ccc4d51735d997f426c" -repo-rev = "master" -repo-url = "https://github.com/FluxML/Flux.jl.git" -uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.8.3" - -[[ForwardDiff]] -deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"] -git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b" -uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.3" - [[InteractiveUtils]] deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" -[[Juno]] -deps = ["Base64", "Logging", "Media", "Profile", "Test"] -git-tree-sha1 = "4e4a8d43aa7ecec66cadaf311fbd1e5c9d7b9175" -uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d" -version = "0.7.0" - -[[Lazy]] -deps = ["Compat", "MacroTools", "Test"] -git-tree-sha1 = "aec38c7e7f255a678af22651c74100e3cd39ea20" -uuid = "50d2b5c4-7a5e-59d5-8109-a42b560f39c0" -version = "0.13.2" - -[[LibGit2]] -uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" - [[Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" @@ -159,174 +21,34 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [[Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" -[[MacroTools]] -deps = ["CSTParser", "Compat", "DataStructures", "Test"] -git-tree-sha1 = "daecd9e452f38297c686eba90dba2a6d5da52162" -uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.0" - [[Markdown]] deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" -[[Media]] -deps = ["MacroTools", "Test"] -git-tree-sha1 = "75a54abd10709c01f1b86b84ec225d26e840ed58" -uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27" -version = "0.5.0" - -[[Missings]] -deps = ["SparseArrays", "Test"] -git-tree-sha1 = "f0719736664b4358aa9ec173077d4285775f8007" -uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "0.4.1" - -[[Mmap]] -uuid = "a63ad114-7e13-5084-954f-fe012c677804" - -[[NNlib]] -deps = ["Libdl", "LinearAlgebra", "Requires", "Statistics", "TimerOutputs"] -git-tree-sha1 = "0c667371391fc6bb31f7f12f96a56a17098b3de8" -uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.6.0" - -[[NaNMath]] -deps = ["Compat"] -git-tree-sha1 = "ce3b85e484a5d4c71dd5316215069311135fa9f2" -uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = 
"0.3.2" - -[[OrderedCollections]] -deps = ["Random", "Serialization", "Test"] -git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1" -uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.1.0" - -[[Pkg]] -deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] -uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" - -[[Printf]] -deps = ["Unicode"] -uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" - -[[Profile]] -deps = ["Printf"] -uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" - [[ProtoBuf]] -deps = ["Random", "Serialization", "Sockets", "Test"] -git-tree-sha1 = "cc57b8d2d37f51f9bdfb440ba933c75764e6e171" +deps = ["Logging"] +git-tree-sha1 = "eb9460532c18a82d77f68bf90bb270f6f7aac3a9" uuid = "3349acd9-ac6a-5e09-bcdb-63829b23a429" -version = "0.7.0" - -[[REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" +version = "0.9.0" [[Random]] deps = ["Serialization"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -[[Reexport]] -deps = ["Pkg"] -git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0" -uuid = "189a3867-3050-52da-a836-e630ba90ab69" -version = "0.2.0" - -[[Requires]] -deps = ["Test"] -git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1" -uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "0.5.2" - -[[SHA]] -uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" - [[Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" -[[SharedArrays]] -deps = ["Distributed", "Mmap", "Random", "Serialization"] -uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" - [[Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" -[[SortingAlgorithms]] -deps = ["DataStructures", "Random", "Test"] -git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd" -uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "0.3.1" - [[SparseArrays]] deps = ["LinearAlgebra", "Random"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -[[SpecialFunctions]] -deps = ["BinDeps", "BinaryProvider", "Libdl", "Test"] -git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea" -uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "0.7.2" - -[[StaticArrays]] -deps = ["LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "db23bbf50064c582b6f2b9b043c8e7e98ea8c0c6" -uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "0.11.0" - [[Statistics]] deps = ["LinearAlgebra", "SparseArrays"] uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -[[StatsBase]] -deps = ["DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"] -git-tree-sha1 = "8a0f4b09c7426478ab677245ab2b0b68552143c7" -uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.30.0" - [[Test]] deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[[TimerOutputs]] -deps = ["Crayons", "Printf", "Test", "Unicode"] -git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c" -uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -version = "0.5.0" - -[[Tokenize]] -git-tree-sha1 = "0de343efc07da00cd449d5b04e959ebaeeb3305d" -uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624" -version = "0.5.4" - -[[Tracker]] -deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"] -git-tree-sha1 = "327342fec6e09f68ced0c2dc5731ed475e4b696b" -uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -version = "0.2.2" - -[[TranscodingStreams]] -deps = ["Random", "Test"] -git-tree-sha1 = 
"a25d8e5a28c3b1b06d3859f30757d43106791919" -uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.9.4" - -[[URIParser]] -deps = ["Test", "Unicode"] -git-tree-sha1 = "6ddf8244220dfda2f17539fa8c9de20d6c575b69" -uuid = "30578b45-9adc-5946-b283-645ec420af67" -version = "0.4.0" - -[[UUIDs]] -deps = ["Random", "SHA"] -uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" - -[[Unicode]] -uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" - -[[ZipFile]] -deps = ["BinaryProvider", "Libdl", "Printf"] -git-tree-sha1 = "580ce62b6c14244916cc28ad54f8a2e2886f843d" -uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" -version = "0.8.3" diff --git a/Project.toml b/Project.toml index 23769a55..4cc8b13a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,15 +1,12 @@ name = "ONNX" uuid = "d0dd6a25-fac6-55c0-abf7-829e0c774d20" -version = "0.1.1" +version = "0.2.0" [deps] -BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" -Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" -DataFlow = "a237f610-4214-5ca7-a9c6-385896804134" -Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" -MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" ProtoBuf = "3349acd9-ac6a-5e09-bcdb-63829b23a429" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] -julia = "1" \ No newline at end of file +ProtoBuf = "= 0.9.0" +julia = "1" diff --git a/README.md b/README.md index f77ca468..7608cb0a 100644 --- a/README.md +++ b/README.md @@ -1,39 +1,6 @@ # ONNX -[![Build Status](https://travis-ci.org/ayush1999/ONNX.jl.svg?branch=master)](https://travis-ci.org/ayush1999/ONNX.jl.svg?branch=master) -[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.3994216.svg)](https://doi.org/10.5281/zenodo.3994216) +ONNX.jl is currently in the process of the total recostruction and doesn't have user friendly API yet. You may want to see: -ONNX.jl : Read [ONNX](https://onnx.ai/) graphs and load these models in Julia. ONNX.jl provides an instance of transfer learning into Julia, by reading pretrained models from ONNX format to [Flux.jl](https://github.com/FluxML/Flux.jl). This is done by generating the DataFlow graph from the model, and then reading it as Julia code. - -## Loading models - -You need to have the `model.onnx` ( or in some cases `model.pb` ) file, which will be read. Several pretrained ONNX model files can also be downloaded from [here](https://github.com/onnx/models). Now that we have the `model.onnx` file, we can read it into Flux as : - -``` -julia> using Flux, ONNX # Import the required packages. -julia> ONNX.load_model("model.onnx") # If you are in some other directory, specify the entire path. - # This creates two files: model.jl and weights.bson. -julia> weights = ONNX.load_weights("weights.bson") # Read the weights from the binary serialized file. -julia> model = include("model.jl") # Loads the model from the model.jl file. -``` - -And `model` is the corresponding model in Flux! - -This package is currently under development, don't tell us we didn't warn you! - -## Running the tests - -It's always better to run the tests before moving on to importing a model. The operator tests ensure that all ops are working. Use `]test ONNX` to run the tests. - -* Running these tests may take some time, as it may initially download the test files if you don't already have them.(You need to have git preinstalled in order to download the tests) - -In order to read more about these tests and run model specific tests, please go through the docs in the `test` directory. 
- -## Contributing and Help - -If you're looking to contribute to the development of this package, and don't know where to begin, [this blog post](https://medium.com/@ayush1999/onnx-jl-the-past-present-and-future-d3b497a0cd4c) can be a good -starting point. It lists the approach taken towards developing this package, the current obstacles, and the work to be done in the future. - -Since this package is currently under development, feel free to open an [issue](https://github.com/FluxML/ONNX.jl/issues) if you find any error/bug. - -For more discussion, you can get in touch with us on [Julia Slack](https://slackinvite.julialang.org/). We're pretty active on the #machine-learning channel. \ No newline at end of file + * [old version of this README](https://github.com/FluxML/ONNX.jl/blob/b7c3d0b48036947257e439c31e00430b0a94690a/README.md) + * [RFC for the new implementation](https://github.com/FluxML/ML-Coordination-Tracker/discussions/24) \ No newline at end of file diff --git a/paper/.latexmkrc b/archives/paper/.latexmkrc similarity index 100% rename from paper/.latexmkrc rename to archives/paper/.latexmkrc diff --git a/paper/bib.tex b/archives/paper/bib.tex similarity index 100% rename from paper/bib.tex rename to archives/paper/bib.tex diff --git a/paper/header.tex b/archives/paper/header.tex similarity index 100% rename from paper/header.tex rename to archives/paper/header.tex diff --git a/paper/jlcode.sty b/archives/paper/jlcode.sty similarity index 100% rename from paper/jlcode.sty rename to archives/paper/jlcode.sty diff --git a/paper/journal_dat.tex b/archives/paper/journal_dat.tex similarity index 100% rename from paper/journal_dat.tex rename to archives/paper/journal_dat.tex diff --git a/paper/juliacon.bst b/archives/paper/juliacon.bst similarity index 100% rename from paper/juliacon.bst rename to archives/paper/juliacon.bst diff --git a/paper/juliacon.cls b/archives/paper/juliacon.cls similarity index 100% rename from paper/juliacon.cls rename to archives/paper/juliacon.cls diff --git a/paper/juliagraphs.png b/archives/paper/juliagraphs.png similarity index 100% rename from paper/juliagraphs.png rename to archives/paper/juliagraphs.png diff --git a/paper/logojuliacon.pdf b/archives/paper/logojuliacon.pdf similarity index 100% rename from paper/logojuliacon.pdf rename to archives/paper/logojuliacon.pdf diff --git a/paper/onnx-3.png b/archives/paper/onnx-3.png similarity index 100% rename from paper/onnx-3.png rename to archives/paper/onnx-3.png diff --git a/paper/paper.fdb_latexmk b/archives/paper/paper.fdb_latexmk similarity index 100% rename from paper/paper.fdb_latexmk rename to archives/paper/paper.fdb_latexmk diff --git a/paper/paper.fls b/archives/paper/paper.fls similarity index 100% rename from paper/paper.fls rename to archives/paper/paper.fls diff --git a/paper/paper.tex b/archives/paper/paper.tex similarity index 100% rename from paper/paper.tex rename to archives/paper/paper.tex diff --git a/paper/paper.yml b/archives/paper/paper.yml similarity index 100% rename from paper/paper.yml rename to archives/paper/paper.yml diff --git a/paper/prep.rb b/archives/paper/prep.rb similarity index 100% rename from paper/prep.rb rename to archives/paper/prep.rb diff --git a/paper/ref.bib b/archives/paper/ref.bib similarity index 100% rename from paper/ref.bib rename to archives/paper/ref.bib diff --git a/src/ONNX.jl b/src/ONNX.jl index b251a8ab..53465686 100644 --- a/src/ONNX.jl +++ b/src/ONNX.jl @@ -1,12 +1,6 @@ module ONNX - -using ProtoBuf, MacroTools, 
DataFlow, Statistics - -include("onnx_pb.jl") -include("convert.jl") -include("new_types.jl") -include("graph/graph.jl") - -using Flux - -end # module + const _ProtoBuf_Top_ = @static isdefined(parentmodule(@__MODULE__), :_ProtoBuf_Top_) ? (parentmodule(@__MODULE__))._ProtoBuf_Top_ : parentmodule(@__MODULE__) + include("onnx_pb.jl") + include("read.jl") + include("write.jl") +end diff --git a/src/convert.jl b/src/convert.jl deleted file mode 100644 index 4ef2d0f3..00000000 --- a/src/convert.jl +++ /dev/null @@ -1,252 +0,0 @@ -using BSON - -rawproto(io::IO) = readproto(io, Proto.ModelProto()) -rawproto(path::String) = open(rawproto, path) - -""" -Helper function to check the layers present -in the ONNX model. -""" -function layers(filename::String) - f = filename |> open |> rawproto - lay = [] - for node in f.graph.node - push!(lay, node.op_type) - end - - return lay |> unique -end - -""" -Retrieve only the useful information from a AttributeProto -object into a Dict format. -""" -function convert_model(x::Proto.AttributeProto) - if (x._type != 0) - field = [:f, :i, :s, :t, :g, :floats, :ints, :strings, :tensors, :graphs][x._type] - return Symbol(x.name) => getfield(x, field) - end -end - -convert_array(as) = Dict(convert_model(a) for a in as) - -""" -Convert a ValueInfoProto to ValueInfo. -""" -function convert_model(model::Proto.ValueInfoProto) - a = Types.ValueInfo(model.name, model.doc_string) - return a -end - -""" -Convert an OperatorSetIdProto to Dict. -""" -function convert_model(model::ONNX.Proto.OperatorSetIdProto) - a = Dict{Symbol, Any}() - fields = [:domain, :version] - for ele in fields - a[ele] = getfield(model, ele) - end - return a -end - -""" -Convert a StringStringEntryProto to Dict. -""" -function convert_model(model::ONNX.Proto.StringStringEntryProto) - a = Dict{Symbol, Any}() - fields = [:key, :value] - for ele in fields - a[ele] = getfield(model, ele) - end - return a -end - -""" -Get the array from a TensorProto object. -""" -function get_array(x::Proto.TensorProto) - if (x.data_type == 1) - if !isempty(x.float_data) - x = reshape(reinterpret(Float32, x.float_data), reverse(x.dims)...) - else - x = reshape(reinterpret(Float32, x.raw_data), reverse(x.dims)...) - end - return x - end - if x.data_type == 7 - if !isempty(x.raw_data) - x = reshape(reinterpret(Int64, x.raw_data), reverse(x.dims)...) - else - x = reshape(reinterpret(Int64, x.int64_data), reverse(x.dims)...) - end - return x - end - if x.data_type == 9 - x = reshape(reinterpret(Int8, x.raw_data), reverse(x.dims)...) - return x - end - if x.data_type == 6 - x = reshape(reinterpret(Int32, x.raw_data), reverse(x.dims)...) - return x - end - if x.data_type == 11 - if !isempty(x.raw_data) - x = reshape(reinterpret(Float64, x.raw_data), reverse(x.dims)...) - else - x = Base.convert(Array{Float32, N} where N, reshape(x.double_data , reverse(x.dims)...)) - end - return x - end - if x.data_type == 10 - x = reshape(reinterpret(Float16, x.raw_data), reverse(x.dims)...) - return x - end -end - -""" -Convert a ModelProto object to a Model type. 
-""" -function convert(model::Proto.ModelProto) - # conversion for opset_import - arr1 = Array{Any, 1}() - for ele in model.opset_import - push!(arr1, convert_model(ele)) - end - - # conversion for stringstringentry proto - arr2 = Array{Any, 1}() - for ele in model.metadata_props - push!(arr2, convert_model(ele)) - end - - m = Types.Model(model.ir_version, - arr1, model.producer_name, - model.producer_version, - model.domain, model.model_version, - model.doc_string, convert(model.graph), - arr2) - return m -end - -""" -Convert a GraphProto object to Graph type. -""" -function convert(model::Proto.GraphProto) - # conversion for vector of nodeproto - arr1 = Array{Any, 1}() - for ele in model.node - push!(arr1, convert(ele)) - end - - # conversion for vector of tensorproto - arr2 = Dict{Any, Any}() - for ele in model.initializer - arr2[ele.name] = get_array(ele) - end - - #conversion for vector of valueinfoproto - arr3 = Array{Types.ValueInfo ,1}() - for ele in model.input - push!(arr3, convert_model(ele)) - end - - arr4 = Array{Types.ValueInfo ,1}() - for ele in model.output - push!(arr4, convert_model(ele)) - end - - arr5 = Array{Types.ValueInfo ,1}() - for ele in model.value_info - push!(arr5, convert_model(ele)) - end - - m = Types.Graph(arr1, - model.name, - arr2, model.doc_string, - arr3, arr4, arr5) - return m -end - -""" -Convert a Proto.NodeProto to Node type. -""" -function convert(model::Proto.NodeProto) - # Conversion of attribute - arr1 = convert_array(model.attribute) - - m = Types.Node(model.input, - model.output, - model.name, - model.op_type, - model.domain, - arr1, - model.doc_string) - return m -end - -function parent(path) - temp = split(path, "/") - res = "" - for element in temp - if (element != temp[end]) - res = res * element * "/" - end - end - return res -end - -""" -Serialize the weights to a binary format and stores in the -weights.bson file. -""" -function write_weights(model) - f = rawproto(model) - g = convert(f.graph) - temp = weights(g) # If weights are stored in Constant, we'll store them - weights_dict = Dict{Symbol, Any}() # in reverse order. - for ele in keys(temp) - weights_dict[Symbol(ele)] = temp[ele] - end - if '/' in model - cd(parent(model)) - end - bson("weights.bson", weights_dict) -end - -""" -Retrieve the dictionary form the binary file (String to Any). -format. -""" -function load_weights(name) - a = BSON.load(name) - weights = Dict{String, Any}() - for ele in keys(a) - weights[string(ele)] = a[ele] - end - return weights -end - -""" -Create the model.jl file and write the model to it. -""" -function write_julia_file(model_file) - f = readproto(open(model_file), ONNX.Proto.ModelProto()) - data = ONNX.code(convert(f).graph) - touch("model.jl") - str1="using Statistics \n" - str2="Mul(a,b,c) = b .* reshape(c, (1,1,size(c)[a],1)) \n" - str3 = "Add(axis, A ,B) = A .+ reshape(B, (1,1,size(B)[1],1)) \n" - open("model.jl","w") do file - write(file, str1*str2*str3*string(data)) - end -end - -""" -Create the two files from the model.pb file. -""" -function load_model(model) - write_weights(model) - write_julia_file(model) - return nothing -end \ No newline at end of file diff --git a/src/graph/graph.jl b/src/graph/graph.jl deleted file mode 100644 index 02d2d32b..00000000 --- a/src/graph/graph.jl +++ /dev/null @@ -1,118 +0,0 @@ -using DataFlow: Call, constant, inputnode, syntax - -const ops = Dict{Symbol,Any}() -include("ops.jl") - -# This is to fetch weights when they are stored in -# the constant tensor and not in intializer. 
-function get_weights(g::Types.Graph) - temp = Dict{Any, Any}() - for node in g.node - if node.op_type == "Constant" - temp[node.name] = get_array(node.attribute[:value]) - end - end - return temp -end - -vcall(a...) = vertex(Call(), constant.(a)...) - -# Placeholder for array values -weights(f::Types.Model) = weights(f.graph) - -""" -Checks location of weights and returns appropriate -values. -Note: Constant weight is deprecated now. -""" -function weights(g::Types.Graph) - count = 0 - for node in g.node - if (node.op_type == "Constant") - count = count + 1 - break - end - end - if (count > 0) - return get_weights(g) - end - return g.initializer -end - -function inputs(g::Types.Graph) - ws = weights(g) - i = 0 - Dict(x.name => haskey(ws, x.name) ? - constant(:(weights[$(x.name)])) : - inputnode(i += 1) - for x in g.input), i -end - -function _graph(g::Types.Graph) - vs, n = inputs(g) - for node in g.node - if node.op_type == "Constant" - vs[node.output[1]] = ops[Symbol(node.op_type)](node, map(n -> vs[n], node.input)...) - else - vs[node.output[1]] = ops[Symbol(node.op_type)](node.attribute, map(n -> vs[n], node.input)...) - end - end - return vs[g.output[1].name], n -end - -# Graph Cleanups - -ischainable(v) = DataFlow.iscall(v) && all(x -> DataFlow.isconstant(x), v[3:end]) -chaindepth(v) = ischainable(v) ? chaindepth(v[2]) + 1 : 0 - -function _tochain(v, ch) - ischainable(v) || return v - if length(v[:]) ≤ 2 - push!(ch, v[1]) - else - push!(ch, vertex(DataFlow.Lambda(1, vcall(v[1], inputnode(1), v[3:end]...)))) - end - return _tochain(v[2], ch) -end - -function tochain(v) - ch = [] - v = _tochain(v, ch) - vcall(vcall(:Chain, reverse(ch)...), v) -end - -function chainify(v) - MacroTools.prewalk(v) do v - chaindepth(v) > 3 ? tochain(v) : v - end -end - -# Interface -function graph(g::Types.Graph) - v, n = _graph(g) - v = chainify(v) - return vertex(DataFlow.Lambda(n, v)) |> DataFlow.λopen |> DataFlow.λclose -end - -""" -Write out the Julia code for the model -""" -code(g::Types.Graph) = graph(g) |> syntax |> - MacroTools.flatten |> MacroTools.gensym_ids - -# function breakcalls(ex) -# MacroTools.prewalk(ex) do ex -# iscall(ex) || return ex -# @capture(ex, f_(args__) | f_.(args__)) -# count(x -> iscall(x), args) ≥ 2 || return ex -# vars = [] -# args = map(args) do x -# iscall(x) || return x -# var = gensym() -# push!(vars, :($var = $x)) -# return var -# end -# ex = isexpr(ex, :call) ? :($f($(args...))) : :($f.($(args...))) -# :($(vars...); $ex) -# end -# end diff --git a/src/graph/ops.jl b/src/graph/ops.jl deleted file mode 100644 index f29fa691..00000000 --- a/src/graph/ops.jl +++ /dev/null @@ -1,508 +0,0 @@ -# This file contains the implementation of various operators. -# Tests for them is at test/runtests.jl. - -using Base -using Statistics -# TODO: we need kwarg support for many of these - -# Generic -get_tuple(x) = (x...,) -get_tuple() = nothing -convert_type(x) = Base.convert(Array{Float32, 1}, x) - -ops[:Concat] = function (params, ip...) 
- return vcall(:cat, ip..., Symbol("dims = 3")) -end - -ops[:Gemm] = function (params, A, B, C) - if !haskey(params, :transA) - params[:transA] = 0 - end - if !haskey(params, :transB) - params[:transB] = 0 - end - if !haskey(params, :alpha) - params[:alpha] = 1 - end - if !haskey(params, :beta) - params[:beta] = 1 - end - if !haskey(params, :broadcast) - params[:broadcast] = 0 - end - if (params[:transA] != 1) - A = vcall(:permutedims, A, vcall(:reverse, vcall(:range, 1, vcall(:ndims, A)))) - end - if (params[:transB] != 1) - B = vcall(:permutedims, B, vcall(:reverse, vcall(:range, 1, vcall(:ndims, B)))) - end - ip1 = vcall(:*, params[:alpha], A, B) - ip2 = vcall(:*, params[:beta], C) - if params[:broadcast] == 0 - ip1 = vcall(:permutedims, ip1, vcall(:reverse, vcall(:range, 1, vcall(:ndims, ip1)))) - res = vcall(:broadcast, :+, ip1, ip2) - return res - end - res = vcall(:broadcast, :+, ip1, ip2) - return vcall(:permutedims, res, vcall(:reverse, vcall(:range, 1, vcall(:ndims, res)))) -end - -# Image - -ops[:Conv] = function (params, x, w, b...) - if !haskey(params, Symbol("pads")) - params[:pads] = [0,0,0,0] - end - if !haskey(params, Symbol("strides")) - params[:strides] = (1,1) - end - if !haskey(params, Symbol("dilations")) - params[:dilations] = (1,1) - end - if (haskey(params, Symbol("auto_pad"))) - if (String(params[:auto_pad]) == "SAME_UPPER" || String(params[:auto_pad] == "SAME_LOWER")) - temp = Base.convert(Array{Int64,1}, (params[:kernel_shape] .- 1)./2) # Only for strides = [1,1] - params[:pads] = vcat(temp, temp) # To Do: Add support for other stride values. - elseif String(params[:auto_pad]) == "VALID" - params[:pads] = [0,0,0,0] - end - end - #if haskey(params, :group) - # s = vcall(:Int, vcall(:/, vcall(:size, x, 3), params[:group])) - # x = vcall(:reshape, x, vcall(:size, x, 1), vcall(:size, x, 2), s, params[:group], vcall(:size, x, 4)) - # temp_x = vcall(:getindex, x, :,:,:,1,:) - # temp = vcall(vcall(:Conv, Float32[0], :relu, - # Symbol("stride=$((params[:strides]...,))"), Symbol("pad=$(pads(params[:pads]))"), - # Symbol("dilation=$((params[:dilations]...,))")), temp_x) - # if isempty(b) - # for i=2:params[:group] - # temp = vcall(:cat, 3, temp, vcall(vcall(:Conv, Float32[0], :relu, - # Symbol("stride=$((params[:strides]...,))"), Symbol("pad=$(pads(params[:pads]))"), - # Symbol("dilation=$((params[:dilations]...,))")), temp)) - # end - # - # else - # for i=2:params[:group] - # temp = vcall(:cat, 3, temp, vcall(vcall(:Conv, b[1], :relu, - # Symbol("stride=$((params[:strides]...,))"), Symbol("pad=$(pads(params[:pads]))"), - # Symbol("dilation=$((params[:dilations]...,))")), temp)) - # end - # end - # return temp - #end - if isempty(b) - return vcall(vcall(:CrossCor, w, Float32[0], :relu, Symbol("stride=$((params[:strides]...,))"), - Symbol("pad=$((params[:pads]...,))"), Symbol("dilation=$((params[:dilations]...,))")), x) - # temp change (Until type fix) - end - vcall(vcall(:CrossCor, w, b[1], Symbol("stride=$((params[:strides]...,))"), - Symbol("pad=$((params[:pads]...,))"), Symbol("dilation=$((params[:dilations]...,))")), x) -end - -ops[:MaxPool] = function (params, x) - if !(haskey(params, :strides)) - params[:strides] = [1,1] - end - if !(haskey(params, :pads)) - params[:pads] = [0,0,0,0] - end - strides = params[:strides] == params[:kernel_shape] ? 
[] : [params[:strides]] - if length(params[:kernel_shape]) == 1 - push!(params[:kernel_shape], 1) - n_size = vcall(:Tuple, vcall(:push!, vcall(:collect, vcall(:size, x)), 1)) - new_x = vcall(:reshape, x, n_size) - return vcall(:dropdims, vcall(vcall(:MaxPool, (params[:kernel_shape]...,), - Symbol("pad=$((params[:pads]...,))"),Symbol("stride=$((params[:strides]...,))")), new_x), - Symbol("dims=4")) - end - - vcall(vcall(:MaxPool, (params[:kernel_shape]...,), Symbol("pad=$((params[:pads]...,))"),Symbol("stride=$((params[:strides]...,))")), x) -end - -ops[:GlobalAveragePool] = function (params, x) - vcall(:mean, x, Symbol("dims = (1,2)")) -end - -ops[:GlobalMaxPool] = function (params, x) - vcall(:getindex, vcall(:findmax, x, Symbol("dims=(1,2)")), 1) -end - -ops[:AveragePool] = function (params, x) - length(params[:kernel_shape]) <= 2 || error("Only averagepool2d currently supported") - if !haskey(params, :strides) - params[:strides] = [1,1] - end - strides = params[:strides] == params[:kernel_shape] ? [] : [params[:strides]] - if !haskey(params, :pads) - params[:pads] = [0,0,0,0] - end - if length(params[:kernel_shape]) == 1 - push!(params[:kernel_shape], 1) - n_size = vcall(:Tuple, vcall(:push!, vcall(:collect, vcall(:size, x)), 1)) - new_x = vcall(:reshape, x, n_size) - return vcall(:dropdims, vcall(vcall(:MeanPool, (params[:kernel_shape]...,), Symbol("pad=$((params[:pads]...,))"), - Symbol("stride=$((params[:strides]...,))")), new_x), Symbol("dims=4")) - end - if params[:pads] == [0,0,0,0] - return vcall(vcall(:MeanPool, (params[:kernel_shape]...,), Symbol("pad=$((params[:pads]...,))"), - Symbol("stride=$((params[:strides]...,))")), x) - else - params[:strides_temp] = [1,1] - params[:kernel_shape_temp] = [1,1] - params[:pads_temp] = [0,0,0,0] - temp = vcall(vcall(:MeanPool, (params[:kernel_shape_temp]...,), Symbol("pad=$((params[:pads]...,))"), - Symbol("stride=$((params[:strides_temp]...,))")), x) - return vcall(vcall(:MeanPool, (params[:kernel_shape]...,), Symbol("pad=$((params[:pads_temp]...,))"), - Symbol("stride=$((params[:strides]...,))")), temp) - end -end - -ops[:BatchNormalization] = function (params, x, scale, b, mean, var) - if !haskey(params ,Symbol("momentum")) - params[:momentum] = 0.9 - end - if !haskey(params, Symbol("epsilon")) - params[:epsilon] = 1e-5 - end - t = typeof(params[:momentum]) - q = vcall(:broadcast, :+, params[:epsilon], var) - p = vcall(:broadcast, sqrt ,q) - r = vcall(:broadcast, Float32, p) - return vcall(vcall(:BatchNorm,identity, b, scale, vcall(:broadcast, :Float32, mean), r, t(params[:epsilon]), params[:momentum], false), x) -end - -function slice(a, s, e) - return a[s:e] -end - -ops[:LSTM] = function(params, ip...) - if length(ip) == 3 - len = params[:hidden_size] - arg1 = vcall(reshape, ip[2], (4*len,2)) - arg2 = vcall(reshape, ip[3], (4*len,3)) - ip_ = vcall(reshape, ip[1], vcall(slice ,vcall(:size, ip[1]), 1, 2)) - - a = vcall(LSTM, arg1, arg2, zeros(len*4), zeros(len), zeros(len)) - - return vcall(a, ip_) - elseif length(ip) == 4 - len = params[:hidden_size] - arg1 = vcall(reshape, ip[2], (4*len,3)) - arg2 = vcall(reshape, ip[3], (4*len,4)) - arg3 = ip[4][1:4*len] - b1 = vcall(:broadcast, Float32, vcall(reinterpret, Float32, vcall(zeros, 2))) - a = vcall(LSTM, arg1, arg2, arg3, b1, b1) - - ip_ = vcall(reshape, ip[1], vcall(slice ,vcall(:size, ip[1]), 1, 2)) - return vcall(a, ip_) - end -end - -# Regularise - -ops[:Dropout] = function (params, x) - return vcall(:identity, x) # Inference mode: Dropout just bypasses input. 
-end - -# Activation - -iscallp(f, v) = DataFlow.iscall(v) && f(v[1]) -islayer(v, name) = iscallp(l -> iscallp(x -> x == constant(name), l), v) - -ops[:Identity] = function(params, x) - vcall(:identity, x) -end - -ops[:Flatten] = function(params, x) - if !haskey(params, :axis) - params[:axis] = 1 - end - l = vcall(:length, x) - rev = vcall(:reverse, vcall(:size, x)) - if (params[:axis] == 0) - return vcall(:reshape, x, l, 1) - else - s = vcall(:prod, vcall(:getindex, rev, 1:params[:axis])) - return vcall(:reshape, x, vcall(:div, l, s), s) - end -end - -ops[:Relu] = function (params, x) - vcall(broadcast, :relu, x) - #end -end - -ops[:LeakyRelu] = function(params, x) - if !haskey(params, :alpha) - params[:alpha] = 0.01 - end - vcall(:broadcast, :leakyrelu, x, params[:alpha]) -end - -ops[:PRelu] = function(params, x, slope) - ip1 = vcall(:broadcast, :clamp, x, 0, Inf) - ip2 = vcall(:.*, vcall(:broadcast, :clamp, x, -Inf, 0), slope) - return vcall(:broadcast, Float32, vcall(:+, ip1, ip2)) -end - -ops[:ArgMax] = function(params, x) - return vcall(Flux.argmax, x) -end - -ops[:Abs] = function (params, x) - vcall(:broadcast, abs, x) -end - -ops[:Clip] = function (params, x) - if !haskey(params, :min) - params[:min] = vcall(:getindex, vcall(:findmin, x), 1) - end - if !haskey(params, :max) - params[:max] = vcall(:getindex, vcall(:findmax, x), 1) - end - vcall(:broadcast, clamp, x, params[:min], params[:max]) -end - -ops[:Equal] = function(params, x, y) - return vcall(:broadcast, :Int, vcall(:broadcast, :isequal, x, y)) -end - -ops[:Greater] = function(params, x, y) - return vcall(:broadcast, :Int, vcall(:broadcast, :isless, y, x)) -end - -ops[:Sigmoid] = function (params, x) - vcall(:sigmoid, x) -end - -ops[:Softmax] = function (params, x) - vcall(:softmax, vcall(:vec, x)) -end - -ops[:Floor] = function (params, x) - vcall(:broadcast, :floor, x) -end - -ops[:Exp] = function(params, x) - vcall(:broadcast, :exp, x) -end - -ops[:Log] = function(params, x) - vcall(:broadcast, :log, x) -end - -ops[:Neg] = function(params, x) - vcall(:*, -1, x) -end - -ops[:Sum] = function (params, x, y...) - if isempty(y) - return vcall(:.+, x, 0) - end - vcall(:+, x, y[1]) -end - -ops[:Cast] = function(params, x) - if (params[:to] == 1) - return vcall(:broadcast, :Float32, x) - elseif params[:to] == 10 - return vcall(:broadcast, :Float16, x) - elseif params[:to] == 11 - return vcall(:broadcast, :Float64, x) - end -end - -ops[:Constant] = function (params) - constant(Symbol("weights[\"$(params.name)\"]")) -end - -ops[:Ceil] = function (params ,x) - vcall(:broadcast, :ceil, x) -end - -ops[:Unsqueeze] = function(params, x) - l1 = length(params[:axes]) - l2 = vcall(:+, l1, vcall(:ndims, x)) - temp = x - for ele in params[:axes] - temp = vcall(Flux.unsqueeze, temp, vcall(:-, vcall(:+, vcall(:ndims, temp), 1), ele)) - end - return temp -end - -ops[:Reshape] = function(params, tensor1, shape...) 
- if haskey(params, :shape) - return vcall(:reshape, tensor1, vcall(:broadcast, Int64, vcall(:Tuple, params[:shape]))) - end - vcall(:reshape, tensor1, vcall(:broadcast, Int64, vcall(:Tuple, vcall(:reverse, shape[1])))) -end - -ops[:Transpose] = function(params ,tensor) - temp_tens = vcall(:permutedims, tensor, vcall(:reverse, vcall(:range, 1, vcall(:ndims, tensor)))) - order = vcall(:.+, params[:perm], 1) - l = vcall(:permutedims, temp_tens, order) - return vcall(:permutedims, l, vcall(:reverse, vcall(:range, 1, vcall(:ndims, l)))) -end - -ops[:LRN] = function(params, x) - if !haskey(params, :bias) - params[:bias] = 1 - end - if !haskey(params, :alpha) - params[:alpha] = 1e-4 - end - if !haskey(params, :beta) - params[:beta] = 0.75 - end - return vcall(vcall(:LRNorm, params[:bias], params[:size], params[:alpha], params[:beta]), x) - # currently, just bypassing the output - #return vcall(:.+, 0, x) -end - -#To-Do : add broadcast here (Urgent) -# Add axis condition here -ops[:Add] = function(params, A, B) - s1 = vcall(:size, A) - s2 = vcall(:size, B) - if (s1==s2) - return vcall(:Add, params[:axis], A, B) - else - return vcall(:.+, A, B) - end -end - -ops[:Sub] = function(params, A , B) - s1 = vcall(:size, A) - s2 = vcall(:size, B) - if (s1==s2) - return vcall(:-, A, B) - else - return vcall(:.-, A, B) - end -end - -ops[:Div] = function(params, A , B) - if (haskey(params, :broadcast) && params[:broadcast] == 1) - if !haskey(params, :axis) - return vcall(:./, A, B) - end - return vcall( :Div, params[:axis], A, B) # To-Do define Div function - else - return vcall(:./ , A, B) # In case of no broadcast, Perform normal div operation. - end -end - -ops[:Mul] = function (params, A, B) - if (haskey(params, :broadcast) && params[:broadcast] == 1) - if !haskey(params, :axis) - return vcall(:.*, A, B) - end - return vcall( :Mul, params[:axis], A, B) # To-Do define Mul function - else - return vcall(:.*, A, B) # In case of no broadcast, Perform normal Mul operation. - end -end - -ops[:Pow] = function (params, A, B) - if (haskey(params, :broadcast) && params[:broadcast] == 1) - if !haskey(params, :axis) - return vcall(:.^, A, B) - end - return vcall( :Pow, params[:axis], A, B) # To-Do define Pow function - else - return vcall(:.^, A, B) # In case of no broadcast, Perform normal Power operation. 
- end -end - -ops[:MatMul] = function(params, A, B) - #tempa = vcall(:permutedims, A, vcall(:reverse, vcall(:range, 1, vcall(:ndims, A)))) - #tempb = vcall(:permutedims, B, vcall(:reverse, vcall(:range, 1, vcall(:ndims, B)))) - vcall(:*, B, A) -end - -ops[:Shape] = function(params, A) - vcall(:size, A) -end - -ops[:size] = function(params, A) - vcall(:prod, vcall(:size, A)) -end - -ops[:Sqrt] = function(params, A) - vcall(:broadcast, :sqrt, A) -end - -ops[:Reciprocal] = function(params, A) - vcall(:./ , 1, A) -end - -ops[:Xor] = function (params, A, B) - ip1 = vcall(:broadcast, &, vcall(:Array, vcall(:broadcast, Bool, A)), vcall(:Array, - vcall(:broadcast, !, vcall(:broadcast, Bool, B)))) - ip2 = vcall(:broadcast, &, vcall(:Array, vcall(:broadcast, Bool, B)), vcall(:Array, - vcall(:broadcast, !, vcall(:broadcast, Bool, A)))) - return vcall(:broadcast, :Int, vcall(:broadcast, |, ip1, ip2)) -end - -ops[:And] = function(params, A, B) - if (haskey(params, :broadcast) && params[:broadcast] == 1) - if !haskey(params, :axis) - return vcall(:.*, vcall(:broadcast, :Bool, A), vcall(:broadcast, :Bool, B)) - end - return vcall( :And, params[:axis], A, B) # To-Do define And function - else - return vcall(:.*, vcall(:broadcast, :Bool, A), vcall(:broadcast, :Bool, B)) # In case of no broadcast, - #Perform normal And operation. - end -end - -ops[:Or] = function(params, A, B) - if (haskey(params, :broadcast) && params[:broadcast] == 1) - if !haskey(params, :axis) - return vcall(:.+, vcall(:broadcast, :Bool, A), vcall(:broadcast, :Bool, B)) - end - return vcall( :Or, params[:axis], A, B) # To-Do define Or function - else - return vcall(:.+, vcall(:broadcast, :Bool, A), vcall(:broadcast, :Bool, B)) # In case of no broadcast, - #Perform normal Or operation. - end -end - -ops[:Expand] = function(params, A, B) - shape_new = vcall(:reverse, B) - return vcall(:repeat , A, Symbol("inner=$(vcall(:reverse, B))")) -end -# Preprocessing - -ops[:ImageScaler] = function(params, A) - if !haskey(params, :scale) - params[:scale] = 1 - end - vcall(:.*, A, params[:scale]) -end - -#Trigonometric ops - -ops[:Cos] = function(params, A) - vcall(:broadcast, :cos, A) -end - -ops[:Sin] = function(params, A) - vcall(:broadcast, :sin, A) -end - -ops[:Tan] = function(params, A) - vcall(:broadcast, :tan, A) -end - -ops[:Acos] = function(params, A) - vcall(:broadcast, :acos, A) -end - -ops[:Asin] = function(params, A) - vcall(:broadcast, :asin, A) -end - -ops[:Atan] = function(params, A) - vcall(:broadcast, :atan, A) -end diff --git a/src/new_types.jl b/src/new_types.jl deleted file mode 100644 index 267d39d3..00000000 --- a/src/new_types.jl +++ /dev/null @@ -1,53 +0,0 @@ -#= - Now, we define the new data types. - Model => ModelProto - Graph => GraphProto - Node => NodeProto - Attribute => AttributeProto - The new types will consist of Julian attributes. - - The purpose of dealing with these newer type is to make - the process simpler and easier to debug. -=# - -module Types - -mutable struct ValueInfo - name::AbstractString - doc_string::AbstractString -end - -mutable struct Node - input::Vector{AbstractString} - output::Vector{AbstractString} - name::AbstractString - op_type::AbstractString # Done! - domain::AbstractString - attribute::Dict{Any, Any} # AttributeProto to Dict - doc_string::AbstractString -end - -mutable struct Graph - node::Array{Any, 1} - name::AbstractString - initializer::Dict{Any ,Any} #Storing the array data instead of the tensorproto vector. - doc_string::AbstractString #in Dict format. 
- input::Array{ValueInfo ,1} # ValueInfoProto -> ValueInfo - output::Array{ValueInfo, 1} # - value_info::Array{ValueInfo, 1} # Done! -end - -mutable struct Model - ir_version::Int64 - opset_import::Array{Any, 1} #OperatorSetIdProto to Dict - producer_name::AbstractString - producer_version::AbstractString # Done! - domain::AbstractString - model_version::Int64 - doc_string::AbstractString - graph::Graph - metadata_props::Array{Any, 1} #StringStringEntryProto to Dict -end - -export Model, Graph, Node, ValueInfo -end \ No newline at end of file diff --git a/src/onnx_pb.jl b/src/onnx_pb.jl index 1fe55590..7552de79 100644 --- a/src/onnx_pb.jl +++ b/src/onnx_pb.jl @@ -1,231 +1,856 @@ -# Generate protoBuf code, donot change directly. - -module Proto - # syntax: proto2 -using Compat using ProtoBuf import ProtoBuf.meta -import Base: hash, isequal, == -struct __enum_Version <: ProtoEnum - _START_VERSION::Int32 - IR_VERSION_2017_10_10::Int32 - IR_VERSION_2017_10_30::Int32 - IR_VERSION::Int32 - __enum_Version() = new(0,1,2,3) -end #struct __enum_Version -const Version = __enum_Version() +const Version = (;[ + Symbol("_START_VERSION") => Int32(0), + Symbol("IR_VERSION_2017_10_10") => Int32(1), + Symbol("IR_VERSION_2017_10_30") => Int32(2), + Symbol("IR_VERSION_2017_11_3") => Int32(3), + Symbol("IR_VERSION_2019_1_22") => Int32(4), + Symbol("IR_VERSION_2019_3_18") => Int32(5), + Symbol("IR_VERSION_2019_9_19") => Int32(6), + Symbol("IR_VERSION") => Int32(7), +]...) mutable struct StringStringEntryProto <: ProtoType - key::AbstractString - value::AbstractString - StringStringEntryProto(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o) -end #mutable struct StringStringEntryProto -hash(v::StringStringEntryProto) = ProtoBuf.protohash(v) -isequal(v1::StringStringEntryProto, v2::StringStringEntryProto) = ProtoBuf.protoisequal(v1, v2) -==(v1::StringStringEntryProto, v2::StringStringEntryProto) = ProtoBuf.protoeq(v1, v2) - -struct __enum_TensorProto_DataType <: ProtoEnum - UNDEFINED::Int32 - FLOAT::Int32 - UINT8::Int32 - INT8::Int32 - UINT16::Int32 - INT16::Int32 - INT32::Int32 - INT64::Int32 - STRING::Int32 - BOOL::Int32 - FLOAT16::Int32 - DOUBLE::Int32 - UINT32::Int32 - UINT64::Int32 - COMPLEX64::Int32 - COMPLEX128::Int32 - __enum_TensorProto_DataType() = new(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15) -end #struct __enum_TensorProto_DataType -const TensorProto_DataType = __enum_TensorProto_DataType() + __protobuf_jl_internal_meta::ProtoMeta + __protobuf_jl_internal_values::Dict{Symbol,Any} + + function StringStringEntryProto(; kwargs...) + obj = new(meta(StringStringEntryProto), Dict{Symbol,Any}()) + values = obj.__protobuf_jl_internal_values + symdict = obj.__protobuf_jl_internal_meta.symdict + for nv in kwargs + fldname, fldval = nv + fldtype = symdict[fldname].jtyp + (fldname in keys(symdict)) || error(string(typeof(obj), " has no field with name ", fldname)) + values[fldname] = isa(fldval, fldtype) ? 
fldval : convert(fldtype, fldval) + end + obj + end +end # mutable struct StringStringEntryProto +const __meta_StringStringEntryProto = Ref{ProtoMeta}() +function meta(::Type{StringStringEntryProto}) + ProtoBuf.metalock() do + if !isassigned(__meta_StringStringEntryProto) + __meta_StringStringEntryProto[] = target = ProtoMeta(StringStringEntryProto) + allflds = Pair{Symbol,Union{Type,String}}[:key => AbstractString, :value => AbstractString] + meta(target, StringStringEntryProto, allflds, ProtoBuf.DEF_REQ, ProtoBuf.DEF_FNUM, ProtoBuf.DEF_VAL, ProtoBuf.DEF_PACK, ProtoBuf.DEF_WTYPES, ProtoBuf.DEF_ONEOFS, ProtoBuf.DEF_ONEOF_NAMES) + end + __meta_StringStringEntryProto[] + end +end +function Base.getproperty(obj::StringStringEntryProto, name::Symbol) + if name === :key + return (obj.__protobuf_jl_internal_values[name])::AbstractString + elseif name === :value + return (obj.__protobuf_jl_internal_values[name])::AbstractString + else + getfield(obj, name) + end +end + +mutable struct TensorAnnotation <: ProtoType + __protobuf_jl_internal_meta::ProtoMeta + __protobuf_jl_internal_values::Dict{Symbol,Any} + + function TensorAnnotation(; kwargs...) + obj = new(meta(TensorAnnotation), Dict{Symbol,Any}()) + values = obj.__protobuf_jl_internal_values + symdict = obj.__protobuf_jl_internal_meta.symdict + for nv in kwargs + fldname, fldval = nv + fldtype = symdict[fldname].jtyp + (fldname in keys(symdict)) || error(string(typeof(obj), " has no field with name ", fldname)) + values[fldname] = isa(fldval, fldtype) ? fldval : convert(fldtype, fldval) + end + obj + end +end # mutable struct TensorAnnotation +const __meta_TensorAnnotation = Ref{ProtoMeta}() +function meta(::Type{TensorAnnotation}) + ProtoBuf.metalock() do + if !isassigned(__meta_TensorAnnotation) + __meta_TensorAnnotation[] = target = ProtoMeta(TensorAnnotation) + allflds = Pair{Symbol,Union{Type,String}}[:tensor_name => AbstractString, :quant_parameter_tensor_names => Base.Vector{StringStringEntryProto}] + meta(target, TensorAnnotation, allflds, ProtoBuf.DEF_REQ, ProtoBuf.DEF_FNUM, ProtoBuf.DEF_VAL, ProtoBuf.DEF_PACK, ProtoBuf.DEF_WTYPES, ProtoBuf.DEF_ONEOFS, ProtoBuf.DEF_ONEOF_NAMES) + end + __meta_TensorAnnotation[] + end +end +function Base.getproperty(obj::TensorAnnotation, name::Symbol) + if name === :tensor_name + return (obj.__protobuf_jl_internal_values[name])::AbstractString + elseif name === :quant_parameter_tensor_names + return (obj.__protobuf_jl_internal_values[name])::Base.Vector{StringStringEntryProto} + else + getfield(obj, name) + end +end + +const TensorProto_DataType = (;[ + Symbol("UNDEFINED") => Int32(0), + Symbol("FLOAT") => Int32(1), + Symbol("UINT8") => Int32(2), + Symbol("INT8") => Int32(3), + Symbol("UINT16") => Int32(4), + Symbol("INT16") => Int32(5), + Symbol("INT32") => Int32(6), + Symbol("INT64") => Int32(7), + Symbol("STRING") => Int32(8), + Symbol("BOOL") => Int32(9), + Symbol("FLOAT16") => Int32(10), + Symbol("DOUBLE") => Int32(11), + Symbol("UINT32") => Int32(12), + Symbol("UINT64") => Int32(13), + Symbol("COMPLEX64") => Int32(14), + Symbol("COMPLEX128") => Int32(15), + Symbol("BFLOAT16") => Int32(16), +]...) + +const TensorProto_DataLocation = (;[ + Symbol("DEFAULT") => Int32(0), + Symbol("EXTERNAL") => Int32(1), +]...) mutable struct TensorProto_Segment <: ProtoType - _begin::Int64 - _end::Int64 - TensorProto_Segment(; kwargs...) 
= (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o) -end #mutable struct TensorProto_Segment -hash(v::TensorProto_Segment) = ProtoBuf.protohash(v) -isequal(v1::TensorProto_Segment, v2::TensorProto_Segment) = ProtoBuf.protoisequal(v1, v2) -==(v1::TensorProto_Segment, v2::TensorProto_Segment) = ProtoBuf.protoeq(v1, v2) + __protobuf_jl_internal_meta::ProtoMeta + __protobuf_jl_internal_values::Dict{Symbol,Any} + + function TensorProto_Segment(; kwargs...) + obj = new(meta(TensorProto_Segment), Dict{Symbol,Any}()) + values = obj.__protobuf_jl_internal_values + symdict = obj.__protobuf_jl_internal_meta.symdict + for nv in kwargs + fldname, fldval = nv + fldtype = symdict[fldname].jtyp + (fldname in keys(symdict)) || error(string(typeof(obj), " has no field with name ", fldname)) + values[fldname] = isa(fldval, fldtype) ? fldval : convert(fldtype, fldval) + end + obj + end +end # mutable struct TensorProto_Segment +const __meta_TensorProto_Segment = Ref{ProtoMeta}() +function meta(::Type{TensorProto_Segment}) + ProtoBuf.metalock() do + if !isassigned(__meta_TensorProto_Segment) + __meta_TensorProto_Segment[] = target = ProtoMeta(TensorProto_Segment) + allflds = Pair{Symbol,Union{Type,String}}[:_begin => Int64, :_end => Int64] + meta(target, TensorProto_Segment, allflds, ProtoBuf.DEF_REQ, ProtoBuf.DEF_FNUM, ProtoBuf.DEF_VAL, ProtoBuf.DEF_PACK, ProtoBuf.DEF_WTYPES, ProtoBuf.DEF_ONEOFS, ProtoBuf.DEF_ONEOF_NAMES) + end + __meta_TensorProto_Segment[] + end +end +function Base.getproperty(obj::TensorProto_Segment, name::Symbol) + if name === :_begin + return (obj.__protobuf_jl_internal_values[name])::Int64 + elseif name === :_end + return (obj.__protobuf_jl_internal_values[name])::Int64 + else + getfield(obj, name) + end +end mutable struct TensorProto <: ProtoType - dims::Vector{Int64} - data_type::Int32 - segment::TensorProto_Segment - float_data::Vector{Float32} - int32_data::Vector{Int32} - string_data::Vector{Array{UInt8,1}} - int64_data::Vector{Int64} - name::AbstractString - doc_string::AbstractString - raw_data::Array{UInt8,1} - double_data::Vector{Float64} - uint64_data::Vector{UInt64} - TensorProto(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o) -end #mutable struct TensorProto -const __fnum_TensorProto = Int[1,2,3,4,5,6,7,8,12,9,10,11] -const __pack_TensorProto = Symbol[:float_data,:int32_data,:int64_data,:double_data,:uint64_data] -meta(t::Type{TensorProto}) = meta(t, ProtoBuf.DEF_REQ, __fnum_TensorProto, ProtoBuf.DEF_VAL, true, __pack_TensorProto, ProtoBuf.DEF_WTYPES, ProtoBuf.DEF_ONEOFS, ProtoBuf.DEF_ONEOF_NAMES, ProtoBuf.DEF_FIELD_TYPES) -hash(v::TensorProto) = ProtoBuf.protohash(v) -isequal(v1::TensorProto, v2::TensorProto) = ProtoBuf.protoisequal(v1, v2) -==(v1::TensorProto, v2::TensorProto) = ProtoBuf.protoeq(v1, v2) + __protobuf_jl_internal_meta::ProtoMeta + __protobuf_jl_internal_values::Dict{Symbol,Any} + + function TensorProto(; kwargs...) + obj = new(meta(TensorProto), Dict{Symbol,Any}()) + values = obj.__protobuf_jl_internal_values + symdict = obj.__protobuf_jl_internal_meta.symdict + for nv in kwargs + fldname, fldval = nv + fldtype = symdict[fldname].jtyp + (fldname in keys(symdict)) || error(string(typeof(obj), " has no field with name ", fldname)) + values[fldname] = isa(fldval, fldtype) ? 
fldval : convert(fldtype, fldval) + end + obj + end +end # mutable struct TensorProto +const __meta_TensorProto = Ref{ProtoMeta}() +function meta(::Type{TensorProto}) + ProtoBuf.metalock() do + if !isassigned(__meta_TensorProto) + __meta_TensorProto[] = target = ProtoMeta(TensorProto) + fnum = Int[1,2,3,4,5,6,7,8,12,9,13,14,10,11] + pack = Symbol[:float_data,:int32_data,:int64_data,:double_data,:uint64_data] + allflds = Pair{Symbol,Union{Type,String}}[:dims => Base.Vector{Int64}, :data_type => Int32, :segment => TensorProto_Segment, :float_data => Base.Vector{Float32}, :int32_data => Base.Vector{Int32}, :string_data => Base.Vector{Array{UInt8,1}}, :int64_data => Base.Vector{Int64}, :name => AbstractString, :doc_string => AbstractString, :raw_data => Array{UInt8,1}, :external_data => Base.Vector{StringStringEntryProto}, :data_location => Int32, :double_data => Base.Vector{Float64}, :uint64_data => Base.Vector{UInt64}] + meta(target, TensorProto, allflds, ProtoBuf.DEF_REQ, fnum, ProtoBuf.DEF_VAL, pack, ProtoBuf.DEF_WTYPES, ProtoBuf.DEF_ONEOFS, ProtoBuf.DEF_ONEOF_NAMES) + end + __meta_TensorProto[] + end +end +function Base.getproperty(obj::TensorProto, name::Symbol) + if name === :dims + return (obj.__protobuf_jl_internal_values[name])::Base.Vector{Int64} + elseif name === :data_type + return (obj.__protobuf_jl_internal_values[name])::Int32 + elseif name === :segment + return (obj.__protobuf_jl_internal_values[name])::TensorProto_Segment + elseif name === :float_data + return (obj.__protobuf_jl_internal_values[name])::Base.Vector{Float32} + elseif name === :int32_data + return (obj.__protobuf_jl_internal_values[name])::Base.Vector{Int32} + elseif name === :string_data + return (obj.__protobuf_jl_internal_values[name])::Base.Vector{Array{UInt8,1}} + elseif name === :int64_data + return (obj.__protobuf_jl_internal_values[name])::Base.Vector{Int64} + elseif name === :name + return (obj.__protobuf_jl_internal_values[name])::AbstractString + elseif name === :doc_string + return (obj.__protobuf_jl_internal_values[name])::AbstractString + elseif name === :raw_data + return (obj.__protobuf_jl_internal_values[name])::Array{UInt8,1} + elseif name === :external_data + return (obj.__protobuf_jl_internal_values[name])::Base.Vector{StringStringEntryProto} + elseif name === :data_location + return (obj.__protobuf_jl_internal_values[name])::Int32 + elseif name === :double_data + return (obj.__protobuf_jl_internal_values[name])::Base.Vector{Float64} + elseif name === :uint64_data + return (obj.__protobuf_jl_internal_values[name])::Base.Vector{UInt64} + else + getfield(obj, name) + end +end + +mutable struct SparseTensorProto <: ProtoType + __protobuf_jl_internal_meta::ProtoMeta + __protobuf_jl_internal_values::Dict{Symbol,Any} + + function SparseTensorProto(; kwargs...) + obj = new(meta(SparseTensorProto), Dict{Symbol,Any}()) + values = obj.__protobuf_jl_internal_values + symdict = obj.__protobuf_jl_internal_meta.symdict + for nv in kwargs + fldname, fldval = nv + fldtype = symdict[fldname].jtyp + (fldname in keys(symdict)) || error(string(typeof(obj), " has no field with name ", fldname)) + values[fldname] = isa(fldval, fldtype) ? 
fldval : convert(fldtype, fldval) + end + obj + end +end # mutable struct SparseTensorProto +const __meta_SparseTensorProto = Ref{ProtoMeta}() +function meta(::Type{SparseTensorProto}) + ProtoBuf.metalock() do + if !isassigned(__meta_SparseTensorProto) + __meta_SparseTensorProto[] = target = ProtoMeta(SparseTensorProto) + allflds = Pair{Symbol,Union{Type,String}}[:values => TensorProto, :indices => TensorProto, :dims => Base.Vector{Int64}] + meta(target, SparseTensorProto, allflds, ProtoBuf.DEF_REQ, ProtoBuf.DEF_FNUM, ProtoBuf.DEF_VAL, ProtoBuf.DEF_PACK, ProtoBuf.DEF_WTYPES, ProtoBuf.DEF_ONEOFS, ProtoBuf.DEF_ONEOF_NAMES) + end + __meta_SparseTensorProto[] + end +end +function Base.getproperty(obj::SparseTensorProto, name::Symbol) + if name === :values + return (obj.__protobuf_jl_internal_values[name])::TensorProto + elseif name === :indices + return (obj.__protobuf_jl_internal_values[name])::TensorProto + elseif name === :dims + return (obj.__protobuf_jl_internal_values[name])::Base.Vector{Int64} + else + getfield(obj, name) + end +end mutable struct TensorShapeProto_Dimension <: ProtoType - dim_value::Int64 - dim_param::AbstractString - TensorShapeProto_Dimension(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o) -end #mutable struct TensorShapeProto_Dimension -const __oneofs_TensorShapeProto_Dimension = Int[1,1] -const __oneof_names_TensorShapeProto_Dimension = [Symbol("value")] -meta(t::Type{TensorShapeProto_Dimension}) = meta(t, ProtoBuf.DEF_REQ, ProtoBuf.DEF_FNUM, ProtoBuf.DEF_VAL, true, ProtoBuf.DEF_PACK, ProtoBuf.DEF_WTYPES, __oneofs_TensorShapeProto_Dimension, __oneof_names_TensorShapeProto_Dimension, ProtoBuf.DEF_FIELD_TYPES) -hash(v::TensorShapeProto_Dimension) = ProtoBuf.protohash(v) -isequal(v1::TensorShapeProto_Dimension, v2::TensorShapeProto_Dimension) = ProtoBuf.protoisequal(v1, v2) -==(v1::TensorShapeProto_Dimension, v2::TensorShapeProto_Dimension) = ProtoBuf.protoeq(v1, v2) + __protobuf_jl_internal_meta::ProtoMeta + __protobuf_jl_internal_values::Dict{Symbol,Any} + + function TensorShapeProto_Dimension(; kwargs...) + obj = new(meta(TensorShapeProto_Dimension), Dict{Symbol,Any}()) + values = obj.__protobuf_jl_internal_values + symdict = obj.__protobuf_jl_internal_meta.symdict + for nv in kwargs + fldname, fldval = nv + fldtype = symdict[fldname].jtyp + (fldname in keys(symdict)) || error(string(typeof(obj), " has no field with name ", fldname)) + values[fldname] = isa(fldval, fldtype) ? 
fldval : convert(fldtype, fldval) + end + obj + end +end # mutable struct TensorShapeProto_Dimension +const __meta_TensorShapeProto_Dimension = Ref{ProtoMeta}() +function meta(::Type{TensorShapeProto_Dimension}) + ProtoBuf.metalock() do + if !isassigned(__meta_TensorShapeProto_Dimension) + __meta_TensorShapeProto_Dimension[] = target = ProtoMeta(TensorShapeProto_Dimension) + allflds = Pair{Symbol,Union{Type,String}}[:dim_value => Int64, :dim_param => AbstractString, :denotation => AbstractString] + oneofs = Int[1,1,1] + oneof_names = Symbol[Symbol("value")] + meta(target, TensorShapeProto_Dimension, allflds, ProtoBuf.DEF_REQ, ProtoBuf.DEF_FNUM, ProtoBuf.DEF_VAL, ProtoBuf.DEF_PACK, ProtoBuf.DEF_WTYPES, oneofs, oneof_names) + end + __meta_TensorShapeProto_Dimension[] + end +end +function Base.getproperty(obj::TensorShapeProto_Dimension, name::Symbol) + if name === :dim_value + return (obj.__protobuf_jl_internal_values[name])::Int64 + elseif name === :dim_param + return (obj.__protobuf_jl_internal_values[name])::AbstractString + elseif name === :denotation + return (obj.__protobuf_jl_internal_values[name])::AbstractString + else + getfield(obj, name) + end +end mutable struct TensorShapeProto <: ProtoType - dim::Vector{TensorShapeProto_Dimension} - TensorShapeProto(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o) -end #mutable struct TensorShapeProto -hash(v::TensorShapeProto) = ProtoBuf.protohash(v) -isequal(v1::TensorShapeProto, v2::TensorShapeProto) = ProtoBuf.protoisequal(v1, v2) -==(v1::TensorShapeProto, v2::TensorShapeProto) = ProtoBuf.protoeq(v1, v2) + __protobuf_jl_internal_meta::ProtoMeta + __protobuf_jl_internal_values::Dict{Symbol,Any} -mutable struct TypeProto_Tensor <: ProtoType - elem_type::Int32 - shape::TensorShapeProto - TypeProto_Tensor(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o) -end #mutable struct TypeProto_Tensor -hash(v::TypeProto_Tensor) = ProtoBuf.protohash(v) -isequal(v1::TypeProto_Tensor, v2::TypeProto_Tensor) = ProtoBuf.protoisequal(v1, v2) -==(v1::TypeProto_Tensor, v2::TypeProto_Tensor) = ProtoBuf.protoeq(v1, v2) + function TensorShapeProto(; kwargs...) + obj = new(meta(TensorShapeProto), Dict{Symbol,Any}()) + values = obj.__protobuf_jl_internal_values + symdict = obj.__protobuf_jl_internal_meta.symdict + for nv in kwargs + fldname, fldval = nv + fldtype = symdict[fldname].jtyp + (fldname in keys(symdict)) || error(string(typeof(obj), " has no field with name ", fldname)) + values[fldname] = isa(fldval, fldtype) ? fldval : convert(fldtype, fldval) + end + obj + end +end # mutable struct TensorShapeProto +const __meta_TensorShapeProto = Ref{ProtoMeta}() +function meta(::Type{TensorShapeProto}) + ProtoBuf.metalock() do + if !isassigned(__meta_TensorShapeProto) + __meta_TensorShapeProto[] = target = ProtoMeta(TensorShapeProto) + allflds = Pair{Symbol,Union{Type,String}}[:dim => Base.Vector{TensorShapeProto_Dimension}] + meta(target, TensorShapeProto, allflds, ProtoBuf.DEF_REQ, ProtoBuf.DEF_FNUM, ProtoBuf.DEF_VAL, ProtoBuf.DEF_PACK, ProtoBuf.DEF_WTYPES, ProtoBuf.DEF_ONEOFS, ProtoBuf.DEF_ONEOF_NAMES) + end + __meta_TensorShapeProto[] + end +end +function Base.getproperty(obj::TensorShapeProto, name::Symbol) + if name === :dim + return (obj.__protobuf_jl_internal_values[name])::Base.Vector{TensorShapeProto_Dimension} + else + getfield(obj, name) + end +end -mutable struct TypeProto <: ProtoType - tensor_type::TypeProto_Tensor - TypeProto(; kwargs...) 
= (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o) -end #mutable struct TypeProto -const __oneofs_TypeProto = Int[1] -const __oneof_names_TypeProto = [Symbol("value")] -meta(t::Type{TypeProto}) = meta(t, ProtoBuf.DEF_REQ, ProtoBuf.DEF_FNUM, ProtoBuf.DEF_VAL, true, ProtoBuf.DEF_PACK, ProtoBuf.DEF_WTYPES, __oneofs_TypeProto, __oneof_names_TypeProto, ProtoBuf.DEF_FIELD_TYPES) -hash(v::TypeProto) = ProtoBuf.protohash(v) -isequal(v1::TypeProto, v2::TypeProto) = ProtoBuf.protoisequal(v1, v2) -==(v1::TypeProto, v2::TypeProto) = ProtoBuf.protoeq(v1, v2) +mutable struct OperatorSetIdProto <: ProtoType + __protobuf_jl_internal_meta::ProtoMeta + __protobuf_jl_internal_values::Dict{Symbol,Any} -mutable struct ValueInfoProto <: ProtoType - name::AbstractString - _type::TypeProto - doc_string::AbstractString - ValueInfoProto(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o) -end #mutable struct ValueInfoProto -hash(v::ValueInfoProto) = ProtoBuf.protohash(v) -isequal(v1::ValueInfoProto, v2::ValueInfoProto) = ProtoBuf.protoisequal(v1, v2) -==(v1::ValueInfoProto, v2::ValueInfoProto) = ProtoBuf.protoeq(v1, v2) + function OperatorSetIdProto(; kwargs...) + obj = new(meta(OperatorSetIdProto), Dict{Symbol,Any}()) + values = obj.__protobuf_jl_internal_values + symdict = obj.__protobuf_jl_internal_meta.symdict + for nv in kwargs + fldname, fldval = nv + fldtype = symdict[fldname].jtyp + (fldname in keys(symdict)) || error(string(typeof(obj), " has no field with name ", fldname)) + values[fldname] = isa(fldval, fldtype) ? fldval : convert(fldtype, fldval) + end + obj + end +end # mutable struct OperatorSetIdProto +const __meta_OperatorSetIdProto = Ref{ProtoMeta}() +function meta(::Type{OperatorSetIdProto}) + ProtoBuf.metalock() do + if !isassigned(__meta_OperatorSetIdProto) + __meta_OperatorSetIdProto[] = target = ProtoMeta(OperatorSetIdProto) + allflds = Pair{Symbol,Union{Type,String}}[:domain => AbstractString, :version => Int64] + meta(target, OperatorSetIdProto, allflds, ProtoBuf.DEF_REQ, ProtoBuf.DEF_FNUM, ProtoBuf.DEF_VAL, ProtoBuf.DEF_PACK, ProtoBuf.DEF_WTYPES, ProtoBuf.DEF_ONEOFS, ProtoBuf.DEF_ONEOF_NAMES) + end + __meta_OperatorSetIdProto[] + end +end +function Base.getproperty(obj::OperatorSetIdProto, name::Symbol) + if name === :domain + return (obj.__protobuf_jl_internal_values[name])::AbstractString + elseif name === :version + return (obj.__protobuf_jl_internal_values[name])::Int64 + else + getfield(obj, name) + end +end -mutable struct OperatorSetIdProto <: ProtoType - domain::AbstractString - version::Int64 - OperatorSetIdProto(; kwargs...) 
= (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o) -end #mutable struct OperatorSetIdProto -hash(v::OperatorSetIdProto) = ProtoBuf.protohash(v) -isequal(v1::OperatorSetIdProto, v2::OperatorSetIdProto) = ProtoBuf.protoisequal(v1, v2) -==(v1::OperatorSetIdProto, v2::OperatorSetIdProto) = ProtoBuf.protoeq(v1, v2) - -struct __enum_AttributeProto_AttributeType <: ProtoEnum - UNDEFINED::Int32 - FLOAT::Int32 - INT::Int32 - STRING::Int32 - TENSOR::Int32 - GRAPH::Int32 - FLOATS::Int32 - INTS::Int32 - STRINGS::Int32 - TENSORS::Int32 - GRAPHS::Int32 - __enum_AttributeProto_AttributeType() = new(0,1,2,3,4,5,6,7,8,9,10) -end #struct __enum_AttributeProto_AttributeType -const AttributeProto_AttributeType = __enum_AttributeProto_AttributeType() +const AttributeProto_AttributeType = (;[ + Symbol("UNDEFINED") => Int32(0), + Symbol("FLOAT") => Int32(1), + Symbol("INT") => Int32(2), + Symbol("STRING") => Int32(3), + Symbol("TENSOR") => Int32(4), + Symbol("GRAPH") => Int32(5), + Symbol("SPARSE_TENSOR") => Int32(11), + Symbol("FLOATS") => Int32(6), + Symbol("INTS") => Int32(7), + Symbol("STRINGS") => Int32(8), + Symbol("TENSORS") => Int32(9), + Symbol("GRAPHS") => Int32(10), + Symbol("SPARSE_TENSORS") => Int32(12), +]...) mutable struct AttributeProto <: ProtoType - name::AbstractString - doc_string::AbstractString - _type::Int32 - f::Float32 - i::Int64 - s::Array{UInt8,1} - t::TensorProto - g::Any - floats::Vector{Float32} - ints::Vector{Int64} - strings::Vector{Array{UInt8,1}} - tensors::Vector{TensorProto} - graphs::Any - AttributeProto(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o) -end #mutable struct AttributeProto -const __fnum_AttributeProto = Int[1,13,20,2,3,4,5,6,7,8,9,10,11] -const __ftype_AttributeProto = Dict(:g => "GraphProto", :graphs => "Vector{GraphProto}") -meta(t::Type{AttributeProto}) = meta(t, ProtoBuf.DEF_REQ, __fnum_AttributeProto, ProtoBuf.DEF_VAL, true, ProtoBuf.DEF_PACK, ProtoBuf.DEF_WTYPES, ProtoBuf.DEF_ONEOFS, ProtoBuf.DEF_ONEOF_NAMES, __ftype_AttributeProto) -hash(v::AttributeProto) = ProtoBuf.protohash(v) -isequal(v1::AttributeProto, v2::AttributeProto) = ProtoBuf.protoisequal(v1, v2) -==(v1::AttributeProto, v2::AttributeProto) = ProtoBuf.protoeq(v1, v2) + __protobuf_jl_internal_meta::ProtoMeta + __protobuf_jl_internal_values::Dict{Symbol,Any} + + function AttributeProto(; kwargs...) + obj = new(meta(AttributeProto), Dict{Symbol,Any}()) + values = obj.__protobuf_jl_internal_values + symdict = obj.__protobuf_jl_internal_meta.symdict + for nv in kwargs + fldname, fldval = nv + fldtype = symdict[fldname].jtyp + (fldname in keys(symdict)) || error(string(typeof(obj), " has no field with name ", fldname)) + values[fldname] = isa(fldval, fldtype) ? 
fldval : convert(fldtype, fldval) + end + obj + end +end # mutable struct AttributeProto (has cyclic type dependency) +const __meta_AttributeProto = Ref{ProtoMeta}() +function meta(::Type{AttributeProto}) + ProtoBuf.metalock() do + if !isassigned(__meta_AttributeProto) + __meta_AttributeProto[] = target = ProtoMeta(AttributeProto) + fnum = Int[1,21,13,20,2,3,4,5,6,22,7,8,9,10,11,23] + allflds = Pair{Symbol,Union{Type,String}}[:name => AbstractString, :ref_attr_name => AbstractString, :doc_string => AbstractString, :_type => Int32, :f => Float32, :i => Int64, :s => Array{UInt8,1}, :t => TensorProto, :g => "GraphProto", :sparse_tensor => SparseTensorProto, :floats => Base.Vector{Float32}, :ints => Base.Vector{Int64}, :strings => Base.Vector{Array{UInt8,1}}, :tensors => Base.Vector{TensorProto}, :graphs => "Base.Vector{GraphProto}", :sparse_tensors => Base.Vector{SparseTensorProto}] + meta(target, AttributeProto, allflds, ProtoBuf.DEF_REQ, fnum, ProtoBuf.DEF_VAL, ProtoBuf.DEF_PACK, ProtoBuf.DEF_WTYPES, ProtoBuf.DEF_ONEOFS, ProtoBuf.DEF_ONEOF_NAMES) + end + __meta_AttributeProto[] + end +end +function Base.getproperty(obj::AttributeProto, name::Symbol) + if name === :name + return (obj.__protobuf_jl_internal_values[name])::AbstractString + elseif name === :ref_attr_name + return (obj.__protobuf_jl_internal_values[name])::AbstractString + elseif name === :doc_string + return (obj.__protobuf_jl_internal_values[name])::AbstractString + elseif name === :_type + return (obj.__protobuf_jl_internal_values[name])::Int32 + elseif name === :f + return (obj.__protobuf_jl_internal_values[name])::Float32 + elseif name === :i + return (obj.__protobuf_jl_internal_values[name])::Int64 + elseif name === :s + return (obj.__protobuf_jl_internal_values[name])::Array{UInt8,1} + elseif name === :t + return (obj.__protobuf_jl_internal_values[name])::TensorProto + elseif name === :g + return (obj.__protobuf_jl_internal_values[name])::Any + elseif name === :sparse_tensor + return (obj.__protobuf_jl_internal_values[name])::SparseTensorProto + elseif name === :floats + return (obj.__protobuf_jl_internal_values[name])::Base.Vector{Float32} + elseif name === :ints + return (obj.__protobuf_jl_internal_values[name])::Base.Vector{Int64} + elseif name === :strings + return (obj.__protobuf_jl_internal_values[name])::Base.Vector{Array{UInt8,1}} + elseif name === :tensors + return (obj.__protobuf_jl_internal_values[name])::Base.Vector{TensorProto} + elseif name === :graphs + return (obj.__protobuf_jl_internal_values[name])::Any + elseif name === :sparse_tensors + return (obj.__protobuf_jl_internal_values[name])::Base.Vector{SparseTensorProto} + else + getfield(obj, name) + end +end + +mutable struct ValueInfoProto <: ProtoType + __protobuf_jl_internal_meta::ProtoMeta + __protobuf_jl_internal_values::Dict{Symbol,Any} + + function ValueInfoProto(; kwargs...) + obj = new(meta(ValueInfoProto), Dict{Symbol,Any}()) + values = obj.__protobuf_jl_internal_values + symdict = obj.__protobuf_jl_internal_meta.symdict + for nv in kwargs + fldname, fldval = nv + fldtype = symdict[fldname].jtyp + (fldname in keys(symdict)) || error(string(typeof(obj), " has no field with name ", fldname)) + values[fldname] = isa(fldval, fldtype) ? 
fldval : convert(fldtype, fldval) + end + obj + end +end # mutable struct ValueInfoProto (has cyclic type dependency) +const __meta_ValueInfoProto = Ref{ProtoMeta}() +function meta(::Type{ValueInfoProto}) + ProtoBuf.metalock() do + if !isassigned(__meta_ValueInfoProto) + __meta_ValueInfoProto[] = target = ProtoMeta(ValueInfoProto) + allflds = Pair{Symbol,Union{Type,String}}[:name => AbstractString, :_type => "TypeProto", :doc_string => AbstractString] + meta(target, ValueInfoProto, allflds, ProtoBuf.DEF_REQ, ProtoBuf.DEF_FNUM, ProtoBuf.DEF_VAL, ProtoBuf.DEF_PACK, ProtoBuf.DEF_WTYPES, ProtoBuf.DEF_ONEOFS, ProtoBuf.DEF_ONEOF_NAMES) + end + __meta_ValueInfoProto[] + end +end +function Base.getproperty(obj::ValueInfoProto, name::Symbol) + if name === :name + return (obj.__protobuf_jl_internal_values[name])::AbstractString + elseif name === :_type + return (obj.__protobuf_jl_internal_values[name])::Any + elseif name === :doc_string + return (obj.__protobuf_jl_internal_values[name])::AbstractString + else + getfield(obj, name) + end +end mutable struct NodeProto <: ProtoType - input::Vector{AbstractString} - output::Vector{AbstractString} - name::AbstractString - op_type::AbstractString - domain::AbstractString - attribute::Vector{AttributeProto} - doc_string::AbstractString - NodeProto(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o) -end #mutable struct NodeProto -const __fnum_NodeProto = Int[1,2,3,4,7,5,6] -meta(t::Type{NodeProto}) = meta(t, ProtoBuf.DEF_REQ, __fnum_NodeProto, ProtoBuf.DEF_VAL, true, ProtoBuf.DEF_PACK, ProtoBuf.DEF_WTYPES, ProtoBuf.DEF_ONEOFS, ProtoBuf.DEF_ONEOF_NAMES, ProtoBuf.DEF_FIELD_TYPES) -hash(v::NodeProto) = ProtoBuf.protohash(v) -isequal(v1::NodeProto, v2::NodeProto) = ProtoBuf.protoisequal(v1, v2) -==(v1::NodeProto, v2::NodeProto) = ProtoBuf.protoeq(v1, v2) + __protobuf_jl_internal_meta::ProtoMeta + __protobuf_jl_internal_values::Dict{Symbol,Any} -mutable struct GraphProto <: ProtoType - node::Vector{NodeProto} - name::AbstractString - initializer::Vector{TensorProto} - doc_string::AbstractString - input::Vector{ValueInfoProto} - output::Vector{ValueInfoProto} - value_info::Vector{ValueInfoProto} - GraphProto(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o) -end #mutable struct GraphProto -const __fnum_GraphProto = Int[1,2,5,10,11,12,13] -meta(t::Type{GraphProto}) = meta(t, ProtoBuf.DEF_REQ, __fnum_GraphProto, ProtoBuf.DEF_VAL, true, ProtoBuf.DEF_PACK, ProtoBuf.DEF_WTYPES, ProtoBuf.DEF_ONEOFS, ProtoBuf.DEF_ONEOF_NAMES, ProtoBuf.DEF_FIELD_TYPES) -hash(v::GraphProto) = ProtoBuf.protohash(v) -isequal(v1::GraphProto, v2::GraphProto) = ProtoBuf.protoisequal(v1, v2) -==(v1::GraphProto, v2::GraphProto) = ProtoBuf.protoeq(v1, v2) + function NodeProto(; kwargs...) + obj = new(meta(NodeProto), Dict{Symbol,Any}()) + values = obj.__protobuf_jl_internal_values + symdict = obj.__protobuf_jl_internal_meta.symdict + for nv in kwargs + fldname, fldval = nv + fldtype = symdict[fldname].jtyp + (fldname in keys(symdict)) || error(string(typeof(obj), " has no field with name ", fldname)) + values[fldname] = isa(fldval, fldtype) ? 
fldval : convert(fldtype, fldval) + end + obj + end +end # mutable struct NodeProto (has cyclic type dependency) +const __meta_NodeProto = Ref{ProtoMeta}() +function meta(::Type{NodeProto}) + ProtoBuf.metalock() do + if !isassigned(__meta_NodeProto) + __meta_NodeProto[] = target = ProtoMeta(NodeProto) + fnum = Int[1,2,3,4,7,5,6] + allflds = Pair{Symbol,Union{Type,String}}[:input => Base.Vector{AbstractString}, :output => Base.Vector{AbstractString}, :name => AbstractString, :op_type => AbstractString, :domain => AbstractString, :attribute => Base.Vector{AttributeProto}, :doc_string => AbstractString] + meta(target, NodeProto, allflds, ProtoBuf.DEF_REQ, fnum, ProtoBuf.DEF_VAL, ProtoBuf.DEF_PACK, ProtoBuf.DEF_WTYPES, ProtoBuf.DEF_ONEOFS, ProtoBuf.DEF_ONEOF_NAMES) + end + __meta_NodeProto[] + end +end +function Base.getproperty(obj::NodeProto, name::Symbol) + if name === :input + return (obj.__protobuf_jl_internal_values[name])::Base.Vector{AbstractString} + elseif name === :output + return (obj.__protobuf_jl_internal_values[name])::Base.Vector{AbstractString} + elseif name === :name + return (obj.__protobuf_jl_internal_values[name])::AbstractString + elseif name === :op_type + return (obj.__protobuf_jl_internal_values[name])::AbstractString + elseif name === :domain + return (obj.__protobuf_jl_internal_values[name])::AbstractString + elseif name === :attribute + return (obj.__protobuf_jl_internal_values[name])::Base.Vector{AttributeProto} + elseif name === :doc_string + return (obj.__protobuf_jl_internal_values[name])::AbstractString + else + getfield(obj, name) + end +end + +mutable struct TrainingInfoProto <: ProtoType + __protobuf_jl_internal_meta::ProtoMeta + __protobuf_jl_internal_values::Dict{Symbol,Any} + + function TrainingInfoProto(; kwargs...) + obj = new(meta(TrainingInfoProto), Dict{Symbol,Any}()) + values = obj.__protobuf_jl_internal_values + symdict = obj.__protobuf_jl_internal_meta.symdict + for nv in kwargs + fldname, fldval = nv + fldtype = symdict[fldname].jtyp + (fldname in keys(symdict)) || error(string(typeof(obj), " has no field with name ", fldname)) + values[fldname] = isa(fldval, fldtype) ? 
fldval : convert(fldtype, fldval) + end + obj + end +end # mutable struct TrainingInfoProto (has cyclic type dependency) +const __meta_TrainingInfoProto = Ref{ProtoMeta}() +function meta(::Type{TrainingInfoProto}) + ProtoBuf.metalock() do + if !isassigned(__meta_TrainingInfoProto) + __meta_TrainingInfoProto[] = target = ProtoMeta(TrainingInfoProto) + allflds = Pair{Symbol,Union{Type,String}}[:initialization => "GraphProto", :algorithm => "GraphProto", :initialization_binding => Base.Vector{StringStringEntryProto}, :update_binding => Base.Vector{StringStringEntryProto}] + meta(target, TrainingInfoProto, allflds, ProtoBuf.DEF_REQ, ProtoBuf.DEF_FNUM, ProtoBuf.DEF_VAL, ProtoBuf.DEF_PACK, ProtoBuf.DEF_WTYPES, ProtoBuf.DEF_ONEOFS, ProtoBuf.DEF_ONEOF_NAMES) + end + __meta_TrainingInfoProto[] + end +end +function Base.getproperty(obj::TrainingInfoProto, name::Symbol) + if name === :initialization + return (obj.__protobuf_jl_internal_values[name])::Any + elseif name === :algorithm + return (obj.__protobuf_jl_internal_values[name])::Any + elseif name === :initialization_binding + return (obj.__protobuf_jl_internal_values[name])::Base.Vector{StringStringEntryProto} + elseif name === :update_binding + return (obj.__protobuf_jl_internal_values[name])::Base.Vector{StringStringEntryProto} + else + getfield(obj, name) + end +end mutable struct ModelProto <: ProtoType - ir_version::Int64 - opset_import::Vector{OperatorSetIdProto} - producer_name::AbstractString - producer_version::AbstractString - domain::AbstractString - model_version::Int64 - doc_string::AbstractString - graph::GraphProto - metadata_props::Vector{StringStringEntryProto} - ModelProto(; kwargs...) = (o=new(); fillunset(o); isempty(kwargs) || ProtoBuf._protobuild(o, kwargs); o) -end #mutable struct ModelProto -const __fnum_ModelProto = Int[1,8,2,3,4,5,6,7,14] -meta(t::Type{ModelProto}) = meta(t, ProtoBuf.DEF_REQ, __fnum_ModelProto, ProtoBuf.DEF_VAL, true, ProtoBuf.DEF_PACK, ProtoBuf.DEF_WTYPES, ProtoBuf.DEF_ONEOFS, ProtoBuf.DEF_ONEOF_NAMES, ProtoBuf.DEF_FIELD_TYPES) -hash(v::ModelProto) = ProtoBuf.protohash(v) -isequal(v1::ModelProto, v2::ModelProto) = ProtoBuf.protoisequal(v1, v2) -==(v1::ModelProto, v2::ModelProto) = ProtoBuf.protoeq(v1, v2) - -export Version, AttributeProto_AttributeType, AttributeProto, ValueInfoProto, NodeProto, ModelProto, StringStringEntryProto, GraphProto, TensorProto_DataType, TensorProto_Segment, TensorProto, TensorShapeProto_Dimension, TensorShapeProto, TypeProto_Tensor, TypeProto, OperatorSetIdProto, AttributeProto_AttributeType, AttributeProto + __protobuf_jl_internal_meta::ProtoMeta + __protobuf_jl_internal_values::Dict{Symbol,Any} + + function ModelProto(; kwargs...) + obj = new(meta(ModelProto), Dict{Symbol,Any}()) + values = obj.__protobuf_jl_internal_values + symdict = obj.__protobuf_jl_internal_meta.symdict + for nv in kwargs + fldname, fldval = nv + fldtype = symdict[fldname].jtyp + (fldname in keys(symdict)) || error(string(typeof(obj), " has no field with name ", fldname)) + values[fldname] = isa(fldval, fldtype) ? 
fldval : convert(fldtype, fldval) + end + obj + end +end # mutable struct ModelProto (has cyclic type dependency) +const __meta_ModelProto = Ref{ProtoMeta}() +function meta(::Type{ModelProto}) + ProtoBuf.metalock() do + if !isassigned(__meta_ModelProto) + __meta_ModelProto[] = target = ProtoMeta(ModelProto) + fnum = Int[1,8,2,3,4,5,6,7,14,20] + allflds = Pair{Symbol,Union{Type,String}}[:ir_version => Int64, :opset_import => Base.Vector{OperatorSetIdProto}, :producer_name => AbstractString, :producer_version => AbstractString, :domain => AbstractString, :model_version => Int64, :doc_string => AbstractString, :graph => "GraphProto", :metadata_props => Base.Vector{StringStringEntryProto}, :training_info => Base.Vector{TrainingInfoProto}] + meta(target, ModelProto, allflds, ProtoBuf.DEF_REQ, fnum, ProtoBuf.DEF_VAL, ProtoBuf.DEF_PACK, ProtoBuf.DEF_WTYPES, ProtoBuf.DEF_ONEOFS, ProtoBuf.DEF_ONEOF_NAMES) + end + __meta_ModelProto[] + end +end +function Base.getproperty(obj::ModelProto, name::Symbol) + if name === :ir_version + return (obj.__protobuf_jl_internal_values[name])::Int64 + elseif name === :opset_import + return (obj.__protobuf_jl_internal_values[name])::Base.Vector{OperatorSetIdProto} + elseif name === :producer_name + return (obj.__protobuf_jl_internal_values[name])::AbstractString + elseif name === :producer_version + return (obj.__protobuf_jl_internal_values[name])::AbstractString + elseif name === :domain + return (obj.__protobuf_jl_internal_values[name])::AbstractString + elseif name === :model_version + return (obj.__protobuf_jl_internal_values[name])::Int64 + elseif name === :doc_string + return (obj.__protobuf_jl_internal_values[name])::AbstractString + elseif name === :graph + return (obj.__protobuf_jl_internal_values[name])::Any + elseif name === :metadata_props + return (obj.__protobuf_jl_internal_values[name])::Base.Vector{StringStringEntryProto} + elseif name === :training_info + return (obj.__protobuf_jl_internal_values[name])::Base.Vector{TrainingInfoProto} + else + getfield(obj, name) + end +end + +mutable struct GraphProto <: ProtoType + __protobuf_jl_internal_meta::ProtoMeta + __protobuf_jl_internal_values::Dict{Symbol,Any} + + function GraphProto(; kwargs...) + obj = new(meta(GraphProto), Dict{Symbol,Any}()) + values = obj.__protobuf_jl_internal_values + symdict = obj.__protobuf_jl_internal_meta.symdict + for nv in kwargs + fldname, fldval = nv + fldtype = symdict[fldname].jtyp + (fldname in keys(symdict)) || error(string(typeof(obj), " has no field with name ", fldname)) + values[fldname] = isa(fldval, fldtype) ? 
fldval : convert(fldtype, fldval) + end + obj + end +end # mutable struct GraphProto (has cyclic type dependency) +const __meta_GraphProto = Ref{ProtoMeta}() +function meta(::Type{GraphProto}) + ProtoBuf.metalock() do + if !isassigned(__meta_GraphProto) + __meta_GraphProto[] = target = ProtoMeta(GraphProto) + fnum = Int[1,2,5,15,10,11,12,13,14] + allflds = Pair{Symbol,Union{Type,String}}[:node => Base.Vector{NodeProto}, :name => AbstractString, :initializer => Base.Vector{TensorProto}, :sparse_initializer => Base.Vector{SparseTensorProto}, :doc_string => AbstractString, :input => Base.Vector{ValueInfoProto}, :output => Base.Vector{ValueInfoProto}, :value_info => Base.Vector{ValueInfoProto}, :quantization_annotation => Base.Vector{TensorAnnotation}] + meta(target, GraphProto, allflds, ProtoBuf.DEF_REQ, fnum, ProtoBuf.DEF_VAL, ProtoBuf.DEF_PACK, ProtoBuf.DEF_WTYPES, ProtoBuf.DEF_ONEOFS, ProtoBuf.DEF_ONEOF_NAMES) + end + __meta_GraphProto[] + end +end +function Base.getproperty(obj::GraphProto, name::Symbol) + if name === :node + return (obj.__protobuf_jl_internal_values[name])::Base.Vector{NodeProto} + elseif name === :name + return (obj.__protobuf_jl_internal_values[name])::AbstractString + elseif name === :initializer + return (obj.__protobuf_jl_internal_values[name])::Base.Vector{TensorProto} + elseif name === :sparse_initializer + return (obj.__protobuf_jl_internal_values[name])::Base.Vector{SparseTensorProto} + elseif name === :doc_string + return (obj.__protobuf_jl_internal_values[name])::AbstractString + elseif name === :input + return (obj.__protobuf_jl_internal_values[name])::Base.Vector{ValueInfoProto} + elseif name === :output + return (obj.__protobuf_jl_internal_values[name])::Base.Vector{ValueInfoProto} + elseif name === :value_info + return (obj.__protobuf_jl_internal_values[name])::Base.Vector{ValueInfoProto} + elseif name === :quantization_annotation + return (obj.__protobuf_jl_internal_values[name])::Base.Vector{TensorAnnotation} + else + getfield(obj, name) + end +end + +mutable struct TypeProto_Tensor <: ProtoType + __protobuf_jl_internal_meta::ProtoMeta + __protobuf_jl_internal_values::Dict{Symbol,Any} + + function TypeProto_Tensor(; kwargs...) + obj = new(meta(TypeProto_Tensor), Dict{Symbol,Any}()) + values = obj.__protobuf_jl_internal_values + symdict = obj.__protobuf_jl_internal_meta.symdict + for nv in kwargs + fldname, fldval = nv + fldtype = symdict[fldname].jtyp + (fldname in keys(symdict)) || error(string(typeof(obj), " has no field with name ", fldname)) + values[fldname] = isa(fldval, fldtype) ? 
fldval : convert(fldtype, fldval) + end + obj + end +end +const __meta_TypeProto_Tensor = Ref{ProtoMeta}() +function meta(::Type{TypeProto_Tensor}) + ProtoBuf.metalock() do + if !isassigned(__meta_TypeProto_Tensor) + __meta_TypeProto_Tensor[] = target = ProtoMeta(TypeProto_Tensor) + allflds = Pair{Symbol,Union{Type,String}}[:elem_type => Int32, :shape => TensorShapeProto] + meta(target, TypeProto_Tensor, allflds, ProtoBuf.DEF_REQ, ProtoBuf.DEF_FNUM, ProtoBuf.DEF_VAL, ProtoBuf.DEF_PACK, ProtoBuf.DEF_WTYPES, ProtoBuf.DEF_ONEOFS, ProtoBuf.DEF_ONEOF_NAMES) + end + __meta_TypeProto_Tensor[] + end +end +function Base.getproperty(obj::TypeProto_Tensor, name::Symbol) + if name === :elem_type + return (obj.__protobuf_jl_internal_values[name])::Int32 + elseif name === :shape + return (obj.__protobuf_jl_internal_values[name])::TensorShapeProto + else + getfield(obj, name) + end +end + + +mutable struct TypeProto_Sequence <: ProtoType + __protobuf_jl_internal_meta::ProtoMeta + __protobuf_jl_internal_values::Dict{Symbol,Any} + + function TypeProto_Sequence(; kwargs...) + obj = new(meta(TypeProto_Sequence), Dict{Symbol,Any}()) + values = obj.__protobuf_jl_internal_values + symdict = obj.__protobuf_jl_internal_meta.symdict + for nv in kwargs + fldname, fldval = nv + fldtype = symdict[fldname].jtyp + (fldname in keys(symdict)) || error(string(typeof(obj), " has no field with name ", fldname)) + values[fldname] = isa(fldval, fldtype) ? fldval : convert(fldtype, fldval) + end + obj + end +end # mutable struct TypeProto_Sequence (has cyclic type dependency) +const __meta_TypeProto_Sequence = Ref{ProtoMeta}() +function meta(::Type{TypeProto_Sequence}) + ProtoBuf.metalock() do + if !isassigned(__meta_TypeProto_Sequence) + __meta_TypeProto_Sequence[] = target = ProtoMeta(TypeProto_Sequence) + allflds = Pair{Symbol,Union{Type,String}}[:elem_type => "TypeProto"] + meta(target, TypeProto_Sequence, allflds, ProtoBuf.DEF_REQ, ProtoBuf.DEF_FNUM, ProtoBuf.DEF_VAL, ProtoBuf.DEF_PACK, ProtoBuf.DEF_WTYPES, ProtoBuf.DEF_ONEOFS, ProtoBuf.DEF_ONEOF_NAMES) + end + __meta_TypeProto_Sequence[] + end end +function Base.getproperty(obj::TypeProto_Sequence, name::Symbol) + if name === :elem_type + return (obj.__protobuf_jl_internal_values[name])::Any + else + getfield(obj, name) + end +end + +mutable struct TypeProto_Map <: ProtoType + __protobuf_jl_internal_meta::ProtoMeta + __protobuf_jl_internal_values::Dict{Symbol,Any} + + function TypeProto_Map(; kwargs...) + obj = new(meta(TypeProto_Map), Dict{Symbol,Any}()) + values = obj.__protobuf_jl_internal_values + symdict = obj.__protobuf_jl_internal_meta.symdict + for nv in kwargs + fldname, fldval = nv + fldtype = symdict[fldname].jtyp + (fldname in keys(symdict)) || error(string(typeof(obj), " has no field with name ", fldname)) + values[fldname] = isa(fldval, fldtype) ? 
fldval : convert(fldtype, fldval) + end + obj + end +end # mutable struct TypeProto_Map (has cyclic type dependency) +const __meta_TypeProto_Map = Ref{ProtoMeta}() +function meta(::Type{TypeProto_Map}) + ProtoBuf.metalock() do + if !isassigned(__meta_TypeProto_Map) + __meta_TypeProto_Map[] = target = ProtoMeta(TypeProto_Map) + allflds = Pair{Symbol,Union{Type,String}}[:key_type => Int32, :value_type => "TypeProto"] + meta(target, TypeProto_Map, allflds, ProtoBuf.DEF_REQ, ProtoBuf.DEF_FNUM, ProtoBuf.DEF_VAL, ProtoBuf.DEF_PACK, ProtoBuf.DEF_WTYPES, ProtoBuf.DEF_ONEOFS, ProtoBuf.DEF_ONEOF_NAMES) + end + __meta_TypeProto_Map[] + end +end +function Base.getproperty(obj::TypeProto_Map, name::Symbol) + if name === :key_type + return (obj.__protobuf_jl_internal_values[name])::Int32 + elseif name === :value_type + return (obj.__protobuf_jl_internal_values[name])::Any + else + getfield(obj, name) + end +end + +mutable struct TypeProto <: ProtoType + __protobuf_jl_internal_meta::ProtoMeta + __protobuf_jl_internal_values::Dict{Symbol,Any} + + function TypeProto(; kwargs...) + obj = new(meta(TypeProto), Dict{Symbol,Any}()) + values = obj.__protobuf_jl_internal_values + symdict = obj.__protobuf_jl_internal_meta.symdict + for nv in kwargs + fldname, fldval = nv + fldtype = symdict[fldname].jtyp + (fldname in keys(symdict)) || error(string(typeof(obj), " has no field with name ", fldname)) + values[fldname] = isa(fldval, fldtype) ? fldval : convert(fldtype, fldval) + end + obj + end +end # mutable struct TypeProto (has cyclic type dependency) +const __meta_TypeProto = Ref{ProtoMeta}() +function meta(::Type{TypeProto}) + ProtoBuf.metalock() do + if !isassigned(__meta_TypeProto) + __meta_TypeProto[] = target = ProtoMeta(TypeProto) + fnum = Int[1,4,5,6] + allflds = Pair{Symbol,Union{Type,String}}[:tensor_type => TypeProto_Tensor, :sequence_type => TypeProto_Sequence, :map_type => TypeProto_Map, :denotation => AbstractString] + oneofs = Int[1,1,1,1] + oneof_names = Symbol[Symbol("value")] + meta(target, TypeProto, allflds, ProtoBuf.DEF_REQ, fnum, ProtoBuf.DEF_VAL, ProtoBuf.DEF_PACK, ProtoBuf.DEF_WTYPES, oneofs, oneof_names) + end + __meta_TypeProto[] + end +end +function Base.getproperty(obj::TypeProto, name::Symbol) + if name === :tensor_type + return (obj.__protobuf_jl_internal_values[name])::TypeProto_Tensor + elseif name === :sequence_type + return (obj.__protobuf_jl_internal_values[name])::TypeProto_Sequence + elseif name === :map_type + return (obj.__protobuf_jl_internal_values[name])::TypeProto_Map + elseif name === :denotation + return (obj.__protobuf_jl_internal_values[name])::AbstractString + else + getfield(obj, name) + end +end + +export Version, AttributeProto_AttributeType, AttributeProto, ValueInfoProto, NodeProto, TrainingInfoProto, ModelProto, StringStringEntryProto, TensorAnnotation, GraphProto, TensorProto_DataType, TensorProto_DataLocation, TensorProto_Segment, TensorProto, SparseTensorProto, TensorShapeProto_Dimension, TensorShapeProto, TypeProto_Tensor, TypeProto_Sequence, TypeProto_Map, TypeProto, OperatorSetIdProto, AttributeProto_AttributeType, AttributeProto, ValueInfoProto, NodeProto, TrainingInfoProto, ModelProto, GraphProto, TypeProto_Sequence, TypeProto_Map, TypeProto diff --git a/src/read.jl b/src/read.jl new file mode 100644 index 00000000..8bb79923 --- /dev/null +++ b/src/read.jl @@ -0,0 +1,73 @@ + +""" + array(p::TensorProto) + +Return `p` as an `Array` of the correct type. 
The second argument `wrap` can be used to change the type of the returned array. +""" +function array(p::TensorProto, wrap=Array) + # Copy pasted from ONNX.jl + # Can probably be cleaned up a bit + # TODO: Add missing datatypes... + if p.data_type === TensorProto_DataType.INT64 + if hasproperty(p, :int64_data) && !isempty(p.int64_data) + return reshape(reinterpret(Int64, p.int64_data), reverse(p.dims)...) |> wrap + end + return reshape(reinterpret(Int64, p.raw_data), reverse(p.dims)...) |> wrap + end + + if p.data_type === TensorProto_DataType.INT32 + if hasproperty(p, :int32_data) && !isempty(p.int32_data) + return reshape(p.int32_data , reverse(p.dims)...) |> wrap + end + return reshape(reinterpret(Int32, p.raw_data), reverse(p.dims)...) |> wrap + end + + if p.data_type === TensorProto_DataType.INT8 + return reshape(reinterpret(Int8, p.raw_data), reverse(p.dims)...) |> wrap + end + + if p.data_type === TensorProto_DataType.DOUBLE + if hasproperty(p, :double_data) && !isempty(p.double_data) + return reshape(p.double_data , reverse(p.dims)...) |> wrap + end + return reshape(reinterpret(Float64, p.raw_data), reverse(p.dims)...) |> wrap + end + + if p.data_type === TensorProto_DataType.FLOAT + if hasproperty(p,:float_data) && !isempty(p.float_data) + return reshape(reinterpret(Float32, p.float_data), reverse(p.dims)...) |> wrap + end + return reshape(reinterpret(Float32, p.raw_data), reverse(p.dims)...) |> wrap + end + + if p.data_type === TensorProto_DataType.FLOAT16 + return reshape(reinterpret(Float16, p.raw_data), reverse(p.dims)...) |> wrap + end +end + +Base.size(vip::ValueInfoProto) = size(vip._type) +Base.size(tp::TypeProto) = size(tp.tensor_type) +Base.size(tp::TensorProto) = tp.dims +Base.size(tp_t::TypeProto_Tensor) = hasproperty(tp_t, :shape) ? size(tp_t.shape) : missing +Base.size(tsp::TensorShapeProto) = size.(Tuple(reverse(tsp.dim))) +Base.size(tsp_d::TensorShapeProto_Dimension) = hasproperty(tsp_d, :dim_value) ? tsp_d.dim_value : missing + +""" + attribute(p::AttributeProto) + +Return the attribute in `p` as a `name => value` pair. +""" +function attribute(p::AttributeProto) + # Copy paste from ONNX.jl + if (p._type != 0) + field = [:f, :i, :s, :t, :g, :floats, :ints, :strings, :tensors, :graphs][p._type] + if field === :s + return Symbol(p.name) => String(getproperty(p, field)) + elseif field === :strings + return Symbol(p.name) => String.(getproperty(p, field)) + end + return Symbol(p.name) => getproperty(p, field) + end +end + +Base.Dict(pa::AbstractVector{AttributeProto}) = Dict(attribute(p) for p in pa) diff --git a/src/read.jl b/src/read.jl new file mode 100644 index 00000000..b2f5bde8 --- /dev/null +++ b/src/write.jl @@ -0,0 +1,117 @@ + +# TODO: User supplied elemtype??
+ValueInfoProto(name::String, inshape, elemtype=Float32) = +ValueInfoProto( + name=name, + _type=TypeProto( + tensor_type=TypeProto_Tensor(inshape, elemtype) + ) +) + +TypeProto_Tensor(inshape, elemtype) = TypeProto_Tensor( + elem_type=tp_tensor_elemtype(elemtype), + shape=TensorShapeProto(inshape) +) +TypeProto_Tensor(::Missing, elemtype) = TypeProto_Tensor( + elem_type=tp_tensor_elemtype(elemtype) +) + +TensorShapeProto(shape) = TensorShapeProto(dim=[tsp_d(s) for s in reverse(shape)]) +tsp_d(::Missing) = TensorShapeProto_Dimension() +tsp_d(n::Integer) = TensorShapeProto_Dimension(dim_value=n) +tsp_d(s::String) = TensorShapeProto_Dimension(dim_param=s) +tsp_d(s::Symbol) = tsp_d(string(s)) + +tp_tensor_elemtype(i::Integer) = i +tp_tensor_elemtype(::Missing) = TensorProto_DataType.UNDEFINED +tp_tensor_elemtype(::Type{Float32}) = TensorProto_DataType.FLOAT + +TensorProto(x::Number, name ="") = TensorProto([x], name) + +TensorProto(t::AbstractArray{Float64,N}, name ="") where N = TensorProto( + dims=collect(reverse(size(t))), + data_type=TensorProto_DataType.DOUBLE, + double_data = reshape(t,:), + name=name) + +TensorProto(t::AbstractArray{Float32,N}, name ="") where N = TensorProto( + dims=collect(reverse(size(t))), + data_type=TensorProto_DataType.FLOAT, + float_data = reshape(t,:), + name=name) + +TensorProto(t::AbstractArray{Float16,N}, name ="") where N = TensorProto(t, TensorProto_DataType.FLOAT16, name) + +TensorProto(t::AbstractArray{Int64,N}, name ="") where N = TensorProto( + dims=collect(reverse(size(t))), + data_type=TensorProto_DataType.INT64, + int64_data = reshape(t,:), + name=name) + +TensorProto(t::AbstractArray{Int32,N}, name ="") where N = TensorProto( + dims=collect(reverse(size(t))), + data_type=TensorProto_DataType.INT32, + int32_data = reshape(t,:), + name=name) + +TensorProto(t::AbstractArray{Int8,N}, name ="") where N = TensorProto(t, TensorProto_DataType.INT8, name) + +TensorProto(t::AbstractArray, data_type::Int32, name) = TensorProto( + dims=collect(reverse(size(t))), + data_type=data_type, + raw_data = reinterpret(UInt8, reshape(t,:)), + name=name) + + +AttributeProto(p::Pair) = AttributeProto(first(p), last(p)) +AttributeProto(name::Symbol, v) = AttributeProto(string(name), v) + +AttributeProto(name::String, i::Int64) = AttributeProto( + name=name, + _type = AttributeProto_AttributeType.INT, + i = i +) + +AttributeProto(name::String, f::Float32) = AttributeProto( + name=name, + _type = AttributeProto_AttributeType.FLOAT, + f = f +) + +AttributeProto(name::String, floats::AbstractVector{Float32}) = AttributeProto( + name=name, + _type = AttributeProto_AttributeType.FLOATS, + floats = floats +) + +AttributeProto(name::String, f::Float64) = AttributeProto( + name=name, + _type = AttributeProto_AttributeType.FLOAT, + f = Float32(f) +) + +AttributeProto(name::String, i::NTuple{N, Int64}) where N = AttributeProto(name, collect(i)) + +AttributeProto(name::String, i::AbstractVector{Int64}) = AttributeProto( + name=name, + _type = AttributeProto_AttributeType.INTS, + ints = i +) + +AttributeProto(name::String, str::AbstractString) = AttributeProto( + name=name, + _type = AttributeProto_AttributeType.STRING, + s = Vector{UInt8}(str) +) + +AttributeProto(name::String, strings::AbstractVector{<:AbstractString}) = AttributeProto( + name=name, + _type = AttributeProto_AttributeType.STRINGS, + strings = Vector{UInt8}.(strings) +) + +AttributeProto(name::String, tensor::TensorProto) = AttributeProto( + name=name, + _type = AttributeProto_AttributeType.TENSOR, + t = tensor +) 
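The two files added above are meant to round-trip: the TensorProto constructors in src/write.jl store dims reversed (ONNX is row-major, Julia column-major), and array in src/read.jl undoes this by reshaping over reverse(p.dims). Below is a minimal usage sketch, assuming these helpers are reachable from the top-level ONNX module (the diff does not show how array and attribute are exported, so the qualified names are an assumption):

    using ONNX

    x = rand(Float32, 2, 3)
    tp = ONNX.TensorProto(x, "weight")  # dims stored as [3, 2], data in float_data
    y = ONNX.array(tp)                  # reshape over reverse(dims) recovers x
    @assert y == x

    # The AttributeProto constructors pair with attribute() from src/read.jl;
    # note that a Float64 value is narrowed to Float32 on write.
    ap = ONNX.AttributeProto(:alpha => 0.5)
    ONNX.attribute(ap)                  # :alpha => 0.5f0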
diff --git a/test/README.md b/test/README.md deleted file mode 100644 index 2deeb5b7..00000000 --- a/test/README.md +++ /dev/null @@ -1,40 +0,0 @@ -## Model Tests: - -While node operator tests are useful in testing a specific operator, Model -tests are used to test the model as a whole. These models can be pretty large (several -hundered MBs at times), and hence individual models are downloaded as and when they need -to be tested. Model tests are used not only to verify the functioning of the -operators working as a single unit (every operator taking an input from anouther node -and feeding the output to another), but these models can also be used directly for any -task, without having to reinvent the wheel by building and training the model from scratch. - -## Running model tests - -You need to run the `modeltests.jl` script to run the model tests on a specific model. - -For example, to test the MNIST pretrained model, run: - -``` -julia modeltests.jl MNIST -``` - -This creates a new `models` directory, downloads and extracts the MNIST pre-trained model -and tests it on the test data provided. (Note: You need to have `wget` installed to -download the model, and `tar` installed to extract it.) - -Currently, four model tests are available. These include the MNIST, Squeezenet, VGG19 and -Emotion_Ferplus models. - -## Writing your own node tests - -As these models become more diverse, it is likely that you might come across operators that -aren't supported by ONNX.jl. In such as case, you might have to implement it yourself (feel -free to open an issue too). - -The `ops.jl` file (`src/graph/ops.jl`) contains the implementation of all operators at this -point. In order to test your implementation, you need to make sure that ONNX provides the -test data for the operator. The `main_test` is the main function used to test individual -operators. It takes in the name of the test file, the expected output and the inputs as its -arguments. Also, please do create a Pull Request if your tests pass, as it might be useful -for the community. 
- diff --git a/test/arithmetic_ops.jl b/test/arithmetic_ops.jl deleted file mode 100644 index 221d4967..00000000 --- a/test/arithmetic_ops.jl +++ /dev/null @@ -1,274 +0,0 @@ -using ONNX, Flux, ProtoBuf -include("ops_tests.jl") - -#test add: -main_test("$ONNX_TEST_PATH/test_add", - read_output("$ONNX_TEST_PATH/test_add"), - read_input("$ONNX_TEST_PATH/test_add")[1], - read_input("$ONNX_TEST_PATH/test_add")[2]) - -#test add bcast -main_test("$ONNX_TEST_PATH/test_add_bcast", - read_output("$ONNX_TEST_PATH/test_add_bcast"), - read_input("$ONNX_TEST_PATH/test_add_bcast")[1], - read_input("$ONNX_TEST_PATH/test_add_bcast")[2]) - -#test mul -main_test("$ONNX_TEST_PATH/test_mul", - read_output("$ONNX_TEST_PATH/test_mul"), - read_input("$ONNX_TEST_PATH/test_mul")[1], - read_input("$ONNX_TEST_PATH/test_mul")[2]) - -#test mul bcast -main_test("$ONNX_TEST_PATH/test_mul_bcast", - read_output("$ONNX_TEST_PATH/test_mul_bcast"), - read_input("$ONNX_TEST_PATH/test_mul_bcast")[1], - read_input("$ONNX_TEST_PATH/test_mul_bcast")[2]) - -#test sub -main_test("$ONNX_TEST_PATH/test_sub", - read_output("$ONNX_TEST_PATH/test_sub"), - read_input("$ONNX_TEST_PATH/test_sub")[1], - read_input("$ONNX_TEST_PATH/test_sub")[2]) - -#test sub bcast -main_test("$ONNX_TEST_PATH/test_sub_bcast", - read_output("$ONNX_TEST_PATH/test_sub_bcast"), - read_input("$ONNX_TEST_PATH/test_sub_bcast")[1], - read_input("$ONNX_TEST_PATH/test_sub_bcast")[2]) - -#test div -main_test("$ONNX_TEST_PATH/test_div", - read_output("$ONNX_TEST_PATH/test_div"), - read_input("$ONNX_TEST_PATH/test_div")[1], - read_input("$ONNX_TEST_PATH/test_div")[2]) - -#test div bcast -main_test("$ONNX_TEST_PATH/test_div_bcast", - read_output("$ONNX_TEST_PATH/test_div_bcast"), - read_input("$ONNX_TEST_PATH/test_div_bcast")[1], - read_input("$ONNX_TEST_PATH/test_div_bcast")[2]) - -#test matmul 2d -main_test("$ONNX_TEST_PATH/test_matmul_2d", - read_output("$ONNX_TEST_PATH/test_matmul_2d"), - read_input("$ONNX_TEST_PATH/test_matmul_2d")[1], - read_input("$ONNX_TEST_PATH/test_matmul_2d")[2]) -#test exp -main_test("$ONNX_TEST_PATH/test_exp", - read_output("$ONNX_TEST_PATH/test_exp"), - read_input("$ONNX_TEST_PATH/test_exp")[1]) - -#test reciprocal -main_test("$ONNX_TEST_PATH/test_reciprocal", - read_output("$ONNX_TEST_PATH/test_reciprocal"), - read_input("$ONNX_TEST_PATH/test_reciprocal")[1]) - -#test reciprocal example -main_test("$ONNX_TEST_PATH/test_reciprocal_example", - read_output("$ONNX_TEST_PATH/test_reciprocal_example"), - read_input("$ONNX_TEST_PATH/test_reciprocal_example")[1]) - -#test floor -main_test("$ONNX_TEST_PATH/test_floor", - read_output("$ONNX_TEST_PATH/test_floor"), - read_input("$ONNX_TEST_PATH/test_floor")[1]) - -#test ceil -main_test("$ONNX_TEST_PATH/test_ceil", - read_output("$ONNX_TEST_PATH/test_ceil"), - read_input("$ONNX_TEST_PATH/test_ceil")[1]) - -#test log -main_test("$ONNX_TEST_PATH/test_log", - read_output("$ONNX_TEST_PATH/test_log"), - read_input("$ONNX_TEST_PATH/test_log")[1]) - -#test pow -main_test("$ONNX_TEST_PATH/test_pow", - read_output("$ONNX_TEST_PATH/test_pow"), - read_input("$ONNX_TEST_PATH/test_pow")[1], - read_input("$ONNX_TEST_PATH/test_pow")[2]) - -#test pow bcast -main_test("$ONNX_TEST_PATH/test_pow_bcast_array", - read_output("$ONNX_TEST_PATH/test_pow_bcast_array"), - read_input("$ONNX_TEST_PATH/test_pow_bcast_array")[1], - read_input("$ONNX_TEST_PATH/test_pow_bcast_array")[2]) - -#test pow bcast scalar -main_test("$ONNX_TEST_PATH/test_pow_bcast_scalar", - read_output("$ONNX_TEST_PATH/test_pow_bcast_scalar"), - 
read_input("$ONNX_TEST_PATH/test_pow_bcast_scalar")[1], - read_input("$ONNX_TEST_PATH/test_pow_bcast_scalar")[2]) - -#test pow example -main_test("$ONNX_TEST_PATH/test_pow_example", - read_output("$ONNX_TEST_PATH/test_pow_example"), - read_input("$ONNX_TEST_PATH/test_pow_example")[1], - read_input("$ONNX_TEST_PATH/test_pow_example")[2]) - -#test relu -main_test("$ONNX_TEST_PATH/test_relu", - read_output("$ONNX_TEST_PATH/test_relu"), - read_input("$ONNX_TEST_PATH/test_relu")[1]) - -#Test sum one input -main_test("$ONNX_TEST_PATH/test_sum_one_input", - read_output("$ONNX_TEST_PATH/test_sum_one_input"), - read_input( "$ONNX_TEST_PATH/test_sum_one_input")[1]) - -#Test sum two inputs -main_test("$ONNX_TEST_PATH/test_sum_two_inputs", - read_output("$ONNX_TEST_PATH/test_sum_two_inputs"), - read_input("$ONNX_TEST_PATH/test_sum_two_inputs")[1], - read_input("$ONNX_TEST_PATH/test_sum_two_inputs")[2]) - -# Test Prelu example -main_test("$ONNX_TEST_PATH/test_prelu_example", read_output("$ONNX_TEST_PATH/test_prelu_example"), - read_input("$ONNX_TEST_PATH/test_prelu_example")[1], - read_input("$ONNX_TEST_PATH/test_prelu_example")[2]) - -# Test PRelu broadcast -main_test("$ONNX_TEST_PATH/test_prelu_broadcast", read_output("$ONNX_TEST_PATH/test_prelu_broadcast"), - read_input("$ONNX_TEST_PATH/test_prelu_broadcast")[1], - read_input("$ONNX_TEST_PATH/test_prelu_broadcast")[2]) - -## Trigonometric ops - -#Test sin -main_test("$ONNX_TEST_PATH/test_sin", - read_output("$ONNX_TEST_PATH/test_sin"), - read_input("$ONNX_TEST_PATH/test_sin")[1]) - -main_test("$ONNX_TEST_PATH/test_sin_example", - read_output("$ONNX_TEST_PATH/test_sin_example"), - read_input("$ONNX_TEST_PATH/test_sin_example")[1]) -#Test cos -main_test("$ONNX_TEST_PATH/test_cos", - read_output("$ONNX_TEST_PATH/test_cos"), - read_input("$ONNX_TEST_PATH/test_cos")[1]) - -main_test("$ONNX_TEST_PATH/test_cos_example", - read_output("$ONNX_TEST_PATH/test_cos_example"), - read_input("$ONNX_TEST_PATH/test_cos_example")[1]) - -#Test tan -main_test("$ONNX_TEST_PATH/test_tan", - read_output("$ONNX_TEST_PATH/test_tan"), - read_input("$ONNX_TEST_PATH/test_tan")[1]) - -main_test("$ONNX_TEST_PATH/test_tan_example", - read_output("$ONNX_TEST_PATH/test_tan_example"), - read_input("$ONNX_TEST_PATH/test_tan_example")[1]) - -#test asin -main_test("$ONNX_TEST_PATH/test_asin", - read_output("$ONNX_TEST_PATH/test_asin"), - read_input("$ONNX_TEST_PATH/test_asin")[1]) - -main_test("$ONNX_TEST_PATH/test_asin_example", - read_output("$ONNX_TEST_PATH/test_asin_example"), - read_input("$ONNX_TEST_PATH/test_asin_example")[1]) - -#test acos -main_test("$ONNX_TEST_PATH/test_acos", - read_output("$ONNX_TEST_PATH/test_acos"), - read_input("$ONNX_TEST_PATH/test_acos")[1]) - -main_test("$ONNX_TEST_PATH/test_acos_example", - read_output("$ONNX_TEST_PATH/test_acos_example"), - read_input("$ONNX_TEST_PATH/test_acos_example")[1]) - -#test atan -main_test("$ONNX_TEST_PATH/test_atan", - read_output("$ONNX_TEST_PATH/test_atan"), - read_input("$ONNX_TEST_PATH/test_atan")[1]) - -main_test("$ONNX_TEST_PATH/test_atan_example", - read_output("$ONNX_TEST_PATH/test_atan_example"), - read_input("$ONNX_TEST_PATH/test_atan_example")[1]) - -# Flatten axis 0 -main_test("$ONNX_TEST_PATH/test_flatten_axis0", - read_output("$ONNX_TEST_PATH/test_flatten_axis0"), - read_input("$ONNX_TEST_PATH/test_flatten_axis0")[1]) - -# Flatten axis 1 -main_test("$ONNX_TEST_PATH/test_flatten_axis1", - read_output("$ONNX_TEST_PATH/test_flatten_axis1"), - read_input("$ONNX_TEST_PATH/test_flatten_axis1")[1]) - -# Flatten axis 
2 -main_test("$ONNX_TEST_PATH/test_flatten_axis2", - read_output("$ONNX_TEST_PATH/test_flatten_axis2"), - read_input("$ONNX_TEST_PATH/test_flatten_axis2")[1]) - -# Flatten axis 3 -main_test("$ONNX_TEST_PATH/test_flatten_axis3", - read_output("$ONNX_TEST_PATH/test_flatten_axis3"), - read_input("$ONNX_TEST_PATH/test_flatten_axis3")[1]) - -# Flatten default axis -main_test("$ONNX_TEST_PATH/test_flatten_default_axis", - read_output("$ONNX_TEST_PATH/test_flatten_default_axis"), - read_input("$ONNX_TEST_PATH/test_flatten_default_axis")[1]) - -# test gemm broadcast -main_test("$ONNX_TEST_PATH/test_gemm_broadcast", read_output("$ONNX_TEST_PATH/test_gemm_broadcast"), - read_input("$ONNX_TEST_PATH/test_gemm_broadcast")[1], - read_input("$ONNX_TEST_PATH/test_gemm_broadcast")[2], - read_input("$ONNX_TEST_PATH/test_gemm_broadcast")[3]) - -# test gemm nobroadcast -main_test("$ONNX_TEST_PATH/test_gemm_nobroadcast", read_output("$ONNX_TEST_PATH/test_gemm_nobroadcast"), - read_input("$ONNX_TEST_PATH/test_gemm_nobroadcast")[1], - read_input("$ONNX_TEST_PATH/test_gemm_nobroadcast")[2], - read_input("$ONNX_TEST_PATH/test_gemm_nobroadcast")[3]) - -# test unsqueeze -main_test("$ONNX_TEST_PATH/test_unsqueeze", read_output("$ONNX_TEST_PATH/test_unsqueeze"), - read_input("$ONNX_TEST_PATH/test_unsqueeze")[1]) - -# test abs -main_test("$ONNX_TEST_PATH/test_abs", read_output("$ONNX_TEST_PATH/test_abs"), - read_input("$ONNX_TEST_PATH/test_abs")[1]) - -# test clip -main_test("$ONNX_TEST_PATH/test_clip", read_output("$ONNX_TEST_PATH/test_clip"), - read_input("$ONNX_TEST_PATH/test_clip")[1]) - -# test clip default inbounds -main_test("$ONNX_TEST_PATH/test_clip_default_inbounds", - read_output("$ONNX_TEST_PATH/test_clip_default_inbounds"), - read_input("$ONNX_TEST_PATH/test_clip_default_inbounds")[1]) - -# test clip default max -main_test("$ONNX_TEST_PATH/test_clip_default_max", - read_output("$ONNX_TEST_PATH/test_clip_default_max"), - read_input("$ONNX_TEST_PATH/test_clip_default_max")[1]) - -# test clip default min -main_test("$ONNX_TEST_PATH/test_clip_default_min", - read_output("$ONNX_TEST_PATH/test_clip_default_min"), - read_input("$ONNX_TEST_PATH/test_clip_default_min")[1]) - -# test clip example -main_test("$ONNX_TEST_PATH/test_clip_example", - read_output("$ONNX_TEST_PATH/test_clip_example"), - read_input("$ONNX_TEST_PATH/test_clip_example")[1]) - -# test clip inbounds -main_test("$ONNX_TEST_PATH/test_clip_inbounds", - read_output("$ONNX_TEST_PATH/test_clip_inbounds"), - read_input("$ONNX_TEST_PATH/test_clip_inbounds")[1]) - -# test clip outbounds -main_test("$ONNX_TEST_PATH/test_clip_outbounds", - read_output("$ONNX_TEST_PATH/test_clip_outbounds"), - read_input("$ONNX_TEST_PATH/test_clip_outbounds")[1]) - -# test clip splitbounds -main_test("$ONNX_TEST_PATH/test_clip_splitbounds", - read_output("$ONNX_TEST_PATH/test_clip_splitbounds"), - read_input("$ONNX_TEST_PATH/test_clip_splitbounds")[1]) \ No newline at end of file diff --git a/test/constant.jl b/test/constant.jl deleted file mode 100644 index 2c9e49b7..00000000 --- a/test/constant.jl +++ /dev/null @@ -1,24 +0,0 @@ -using ONNX, Flux, ProtoBuf -include("ops_tests.jl") - -#Test Constant -main_test("$ONNX_TEST_PATH/test_constant", read_output("$ONNX_TEST_PATH/test_constant")) - - -#Test BatchNorm epsilon -main_test("$ONNX_TEST_PATH/test_batchnorm_epsilon", read_output("$ONNX_TEST_PATH/test_batchnorm_epsilon"), - read_input("$ONNX_TEST_PATH/test_batchnorm_epsilon")[1] , - read_input("$ONNX_TEST_PATH/test_batchnorm_epsilon")[2], - 
read_input("$ONNX_TEST_PATH/test_batchnorm_epsilon")[3], - read_input("$ONNX_TEST_PATH/test_batchnorm_epsilon")[4], - read_input("$ONNX_TEST_PATH/test_batchnorm_epsilon")[5]) -""" -# Test BatchNorm example -main_test("$ONNX_TEST_PATH/test_batchnorm_example", - read_output("$ONNX_TEST_PATH/test_batchnorm_example"), - read_input("$ONNX_TEST_PATH/test_batchnorm_example")[1] , - read_input("$ONNX_TEST_PATH/test_batchnorm_example")[2], - read_input("$ONNX_TEST_PATH/test_batchnorm_example")[3], - read_input("$ONNX_TEST_PATH/test_batchnorm_example")[4], - read_input("$ONNX_TEST_PATH/test_batchnorm_example")[5]) -""" \ No newline at end of file diff --git a/test/conv.jl b/test/conv.jl deleted file mode 100644 index eb522685..00000000 --- a/test/conv.jl +++ /dev/null @@ -1,40 +0,0 @@ -using ONNX, Flux, ProtoBuf -include("ops_tests.jl") - -# test basic conv with padding -ip = read_input("$ONNX_TEST_PATH/test_basic_conv_with_padding") -main_test("$ONNX_TEST_PATH/test_basic_conv_with_padding", - read_output("$ONNX_TEST_PATH/test_basic_conv_with_padding"), - read_input("$ONNX_TEST_PATH/test_basic_conv_with_padding")[1], - read_input("$ONNX_TEST_PATH/test_basic_conv_with_padding")[2]) - -# test basic conv without padding -ip = read_input("$ONNX_TEST_PATH/test_basic_conv_without_padding") -main_test("$ONNX_TEST_PATH/test_basic_conv_without_padding", - read_output("$ONNX_TEST_PATH/test_basic_conv_without_padding"), - read_input("$ONNX_TEST_PATH/test_basic_conv_without_padding")[1], - read_input("$ONNX_TEST_PATH/test_basic_conv_without_padding")[2]) - -# Conv with strides and no pads -ip = read_input("$ONNX_TEST_PATH/test_conv_with_strides_no_padding") -main_test("$ONNX_TEST_PATH/test_conv_with_strides_no_padding", - read_output("$ONNX_TEST_PATH/test_conv_with_strides_no_padding"), - read_input("$ONNX_TEST_PATH/test_conv_with_strides_no_padding")[1], - read_input("$ONNX_TEST_PATH/test_conv_with_strides_no_padding")[2]) - -# Conv with strides and pads -ip = read_input("$ONNX_TEST_PATH/test_conv_with_strides_padding") -main_test("$ONNX_TEST_PATH/test_conv_with_strides_padding", - read_output("$ONNX_TEST_PATH/test_conv_with_strides_padding"), - read_input("$ONNX_TEST_PATH/test_conv_with_strides_padding")[1], - read_input("$ONNX_TEST_PATH/test_conv_with_strides_padding")[2]) - -# Test Dropout default -main_test("$ONNX_TEST_PATH/test_dropout_default", - read_output("$ONNX_TEST_PATH/test_dropout_default"), - read_input("$ONNX_TEST_PATH/test_dropout_default")[1]) - -# Test Dropout random -main_test("$ONNX_TEST_PATH/test_dropout_random", - read_output("$ONNX_TEST_PATH/test_dropout_random"), - read_input("$ONNX_TEST_PATH/test_dropout_random")[1]) \ No newline at end of file diff --git a/test/conversions.jl b/test/conversions.jl deleted file mode 100644 index d878d24a..00000000 --- a/test/conversions.jl +++ /dev/null @@ -1,32 +0,0 @@ -using ONNX, Flux, ProtoBuf -include("ops_tests.jl") - -# Cast double to float -main_test("$ONNX_TEST_PATH/test_cast_DOUBLE_to_FLOAT", - read_output("$ONNX_TEST_PATH/test_cast_DOUBLE_to_FLOAT"), - read_input("$ONNX_TEST_PATH/test_cast_DOUBLE_to_FLOAT")[1]) - -# Cast double to float16 -main_test("$ONNX_TEST_PATH/test_cast_DOUBLE_to_FLOAT16", - read_output("$ONNX_TEST_PATH/test_cast_DOUBLE_to_FLOAT16"), - read_input("$ONNX_TEST_PATH/test_cast_DOUBLE_to_FLOAT16")[1]) - -# Cast Float16 to double -main_test("$ONNX_TEST_PATH/test_cast_FLOAT16_to_DOUBLE", - read_output("$ONNX_TEST_PATH/test_cast_FLOAT16_to_DOUBLE"), - read_input("$ONNX_TEST_PATH/test_cast_FLOAT16_to_DOUBLE")[1]) - -# Cast 
Float16 to Float -main_test("$ONNX_TEST_PATH/test_cast_FLOAT16_to_FLOAT", - read_output("$ONNX_TEST_PATH/test_cast_FLOAT16_to_FLOAT"), - read_input("$ONNX_TEST_PATH/test_cast_FLOAT16_to_FLOAT")[1]) - -# Cast Float to Double -main_test("$ONNX_TEST_PATH/test_cast_FLOAT_to_DOUBLE", - read_output("$ONNX_TEST_PATH/test_cast_FLOAT_to_DOUBLE"), - read_input("$ONNX_TEST_PATH/test_cast_FLOAT_to_DOUBLE")[1]) - -# Cast Float to Float16 -main_test("$ONNX_TEST_PATH/test_cast_FLOAT_to_FLOAT16", - read_output("$ONNX_TEST_PATH/test_cast_FLOAT_to_FLOAT16"), - read_input("$ONNX_TEST_PATH/test_cast_FLOAT_to_FLOAT16")[1]) \ No newline at end of file diff --git a/test/logical_ops.jl b/test/logical_ops.jl deleted file mode 100644 index 24a3719f..00000000 --- a/test/logical_ops.jl +++ /dev/null @@ -1,127 +0,0 @@ -using ONNX, Flux, ProtoBuf -include("ops_tests.jl") - -# test and 3v1d -main_test("$ONNX_TEST_PATH/test_and_bcast3v1d", - read_output("$ONNX_TEST_PATH/test_and_bcast3v1d"), - read_input("$ONNX_TEST_PATH/test_and_bcast3v1d")[1], - read_input("$ONNX_TEST_PATH/test_and_bcast3v1d")[2]) - -# test and 3v2d -main_test("$ONNX_TEST_PATH/test_and_bcast3v2d", - read_output("$ONNX_TEST_PATH/test_and_bcast3v2d"), - read_input("$ONNX_TEST_PATH/test_and_bcast3v2d")[1], - read_input("$ONNX_TEST_PATH/test_and_bcast3v2d")[2]) - -# test and 4v2d -main_test("$ONNX_TEST_PATH/test_and_bcast4v2d", - read_output("$ONNX_TEST_PATH/test_and_bcast4v2d"), - read_input("$ONNX_TEST_PATH/test_and_bcast4v2d")[1], - read_input("$ONNX_TEST_PATH/test_and_bcast4v2d")[2]) - -# test and 4v3d -main_test("$ONNX_TEST_PATH/test_and_bcast4v3d", - read_output("$ONNX_TEST_PATH/test_and_bcast4v3d"), - read_input("$ONNX_TEST_PATH/test_and_bcast4v3d")[1], - read_input("$ONNX_TEST_PATH/test_and_bcast4v3d")[2]) - -# test and 4v4d -main_test("$ONNX_TEST_PATH/test_and_bcast4v4d", - read_output("$ONNX_TEST_PATH/test_and_bcast4v4d"), - read_input("$ONNX_TEST_PATH/test_and_bcast4v4d")[1], - read_input("$ONNX_TEST_PATH/test_and_bcast4v4d")[2]) - -#Test and2d -main_test("$ONNX_TEST_PATH/test_and2d", - read_output("$ONNX_TEST_PATH/test_and2d"), - read_input("$ONNX_TEST_PATH/test_and2d")[1], - read_input("$ONNX_TEST_PATH/test_and2d")[2]) - -#Test and3d -main_test("$ONNX_TEST_PATH/test_and3d", - read_output("$ONNX_TEST_PATH/test_and3d"), - read_input("$ONNX_TEST_PATH/test_and3d")[1], - read_input("$ONNX_TEST_PATH/test_and3d")[2]) - -#Test and4d -main_test("$ONNX_TEST_PATH/test_and4d", - read_output("$ONNX_TEST_PATH/test_and4d"), - read_input("$ONNX_TEST_PATH/test_and4d")[1], - read_input("$ONNX_TEST_PATH/test_and4d")[2]) - -# Test identity -main_test("$ONNX_TEST_PATH/test_identity", - read_output("$ONNX_TEST_PATH/test_identity"), - read_input("$ONNX_TEST_PATH/test_identity")[1]) - -# Test equal -main_test("$ONNX_TEST_PATH/test_equal", - read_output("$ONNX_TEST_PATH/test_equal"), - read_input("$ONNX_TEST_PATH/test_equal")[1], - read_input("$ONNX_TEST_PATH/test_equal")[2]) - -# Test equal bcast -main_test("$ONNX_TEST_PATH/test_equal_bcast", - read_output("$ONNX_TEST_PATH/test_equal_bcast"), - read_input("$ONNX_TEST_PATH/test_equal_bcast")[1], - read_input("$ONNX_TEST_PATH/test_equal_bcast")[2]) - -# Test greater -main_test("$ONNX_TEST_PATH/test_greater", - read_output("$ONNX_TEST_PATH/test_greater"), - read_input("$ONNX_TEST_PATH/test_greater")[1], - read_input("$ONNX_TEST_PATH/test_greater")[2]) - -# Test greater bcast -main_test("$ONNX_TEST_PATH/test_greater_bcast", - read_output("$ONNX_TEST_PATH/test_greater_bcast"), - 
read_input("$ONNX_TEST_PATH/test_greater_bcast")[1], - read_input("$ONNX_TEST_PATH/test_greater_bcast")[2]) - -# test xor2d -main_test("$ONNX_TEST_PATH/test_xor2d", - read_output("$ONNX_TEST_PATH/test_xor2d"), - read_input("$ONNX_TEST_PATH/test_xor2d")[1], - read_input("$ONNX_TEST_PATH/test_xor2d")[2]) - -# test xor3d -main_test("$ONNX_TEST_PATH/test_xor3d", - read_output("$ONNX_TEST_PATH/test_xor3d"), - read_input("$ONNX_TEST_PATH/test_xor3d")[1], - read_input("$ONNX_TEST_PATH/test_xor3d")[2]) - -# test xor4d -main_test("$ONNX_TEST_PATH/test_xor4d", - read_output("$ONNX_TEST_PATH/test_xor4d"), - read_input("$ONNX_TEST_PATH/test_xor4d")[1], - read_input("$ONNX_TEST_PATH/test_xor4d")[2]) - -# test xor bcast 3v1d -main_test("$ONNX_TEST_PATH/test_xor_bcast3v1d", - read_output("$ONNX_TEST_PATH/test_xor_bcast3v1d"), - read_input("$ONNX_TEST_PATH/test_xor_bcast3v1d")[1], - read_input("$ONNX_TEST_PATH/test_xor_bcast3v1d")[2]) - -# test xor bcast 3v2d -main_test("$ONNX_TEST_PATH/test_xor_bcast3v2d", - read_output("$ONNX_TEST_PATH/test_xor_bcast3v2d"), - read_input("$ONNX_TEST_PATH/test_xor_bcast3v2d")[1], - read_input("$ONNX_TEST_PATH/test_xor_bcast3v2d")[2]) - -# test xor bcast 4v2d -main_test("$ONNX_TEST_PATH/test_xor_bcast4v2d", - read_output("$ONNX_TEST_PATH/test_xor_bcast4v2d"), - read_input("$ONNX_TEST_PATH/test_xor_bcast4v2d")[1], - read_input("$ONNX_TEST_PATH/test_xor_bcast4v2d")[2]) - -# test xor bcast 4v3d -main_test("$ONNX_TEST_PATH/test_xor_bcast4v3d", - read_output("$ONNX_TEST_PATH/test_xor_bcast4v3d"), - read_input("$ONNX_TEST_PATH/test_xor_bcast4v3d")[1], - read_input("$ONNX_TEST_PATH/test_xor_bcast4v3d")[2]) - -# test xor bcast 4v4d -main_test("$ONNX_TEST_PATH/test_xor_bcast4v4d", - read_output("$ONNX_TEST_PATH/test_xor_bcast4v4d"), - read_input("$ONNX_TEST_PATH/test_xor_bcast4v4d")[1], - read_input("$ONNX_TEST_PATH/test_xor_bcast4v4d")[2]) \ No newline at end of file diff --git a/test/lstm.jl b/test/lstm.jl deleted file mode 100644 index 2c4e0320..00000000 --- a/test/lstm.jl +++ /dev/null @@ -1,12 +0,0 @@ -using ONNX, Flux, ProtoBuf -include("ops_tests.jl") - -# Test LSTM with no bias -main_test("$ONNX_TEST_PATH/test_lstm_defaults", read_output("$ONNX_TEST_PATH/test_lstm_defaults"), - read_input("$ONNX_TEST_PATH/test_lstm_defaults")[1], read_input("$ONNX_TEST_PATH/test_lstm_defaults")[2], - read_input("$ONNX_TEST_PATH/test_lstm_defaults")[3]) - -# Test LSTM with bias -main_test("$ONNX_TEST_PATH/test_lstm_with_initial_bias", read_output("$ONNX_TEST_PATH/test_lstm_with_initial_bias"), - read_input("$ONNX_TEST_PATH/test_lstm_with_initial_bias")[1], read_input("$ONNX_TEST_PATH/test_lstm_with_initial_bias")[2], - read_input("$ONNX_TEST_PATH/test_lstm_with_initial_bias")[3], read_input("$ONNX_TEST_PATH/test_lstm_with_initial_bias")[4]) \ No newline at end of file diff --git a/test/modeltests.jl b/test/modeltests.jl deleted file mode 100644 index 696d755c..00000000 --- a/test/modeltests.jl +++ /dev/null @@ -1,71 +0,0 @@ -using ONNX, Flux, ProtoBuf -using Base.Test - -args = map(x->lowercase(string(x)), ARGS) - -name_to_link = Dict() -name_to_link["squeezenet"] = "https://s3.amazonaws.com/download.onnx/models/opset_8/squeezenet.tar.gz" -name_to_link["mnist"] = "https://www.cntk.ai/OnnxModels/mnist/opset_7/mnist.tar.gz" -name_to_link["emotion_ferplus"] = "https://www.cntk.ai/OnnxModels/emotion_ferplus/opset_7/emotion_ferplus.tar.gz" -name_to_link["vgg19"] = "https://s3.amazonaws.com/download.onnx/models/opset_8/vgg19.tar.gz" - -function read_ip(name) - ip = readproto(open(name), 
ONNX.Proto.TensorProto()) |> ONNX.get_array - if ndims(ip) == 2 - ip = reshape(ip, (size(ip)[1], size(ip)[2], 1, 1)) - elseif ndims(ip) == 3 - ip = reshape(ip, (size(ip)[1], size(ip)[2], size(ip)[3], 1)) - end - return ip -end - -function read_ip(name) - ip = readproto(open(name), ONNX.Proto.TensorProto()) |> ONNX.get_array - return ip -end - -function main(name) - if !("models" in readdir()) - mkdir("models") - cd("models") - else - cd("models") - end - if name in readdir() - println("Testing predownloaded model") - else - run(`wget $(name_to_link[name])`) - run(`tar -xvzf $name.tar.gz`) - end -end - - -main(args[1]) -cd(args[1]) -ONNX.load_model("model.onnx") -weights = ONNX.load_weights("weights.bson") -model = include(pwd()*"/model.jl") -num_test=2 -if args[1] == "squeezenet" - num_test = 11 -elseif args[1] == "vgg19" - num_test = 2 -end -@testset begin -if (args[1] == "squeezenet") || (args[1] == "vgg19") - for x=0:num_test - @test (findmax(model(read_ip("test_data_set_$x/input_0.pb")))[2] == - findmax(read_ip("test_data_set_$x/output_0.pb"))[2]) - end -elseif (args[1] == "mnist") - for x=0:num_test - @test (findmax(model(read_ip("test_data_set_$x/input_0.pb")))[2] == - findmax(read_ip("test_data_set_$x/output_0.pb"))[2]) - end -else - for x=0:num_test - @test (findmax(model(read_ip("test_data_set_$x/input_0.pb")))[2] == - findmax(read_ip("test_data_set_$x/output_0.pb"))[2]) - end -end -end \ No newline at end of file diff --git a/test/ops_tests.jl b/test/ops_tests.jl deleted file mode 100644 index 1429fca9..00000000 --- a/test/ops_tests.jl +++ /dev/null @@ -1,117 +0,0 @@ -using ONNX, Flux, ProtoBuf -using DataFlow: Call, vertex, syntax, constant -using Test -using Base: run -# tests taken from: https://github.com/onnx/onnx/tree/master/onnx/backend/test/data -# clone onnx here if the onnx dir does not exist - -if !("onnx" in readdir()) - # clone the package here - println("Downloading test data....") - Base.run(`git clone https://github.com/onnx/onnx.git`) -end - -ONNX_PATH = "./onnx" - -ONNX_TEST_PATH = "$ONNX_PATH/onnx/backend/test/data/node" - - -function read_input(folder_name) - ar = Array{Any, 1}() - for ele in readdir(folder_name*"/test_data_set_0") - push!(ar, Float32.(readproto(open(folder_name*"/test_data_set_0/"*ele), ONNX.Proto.TensorProto()) |> ONNX.get_array)) - end - return ar[1:end-1] -end - -function read_output(folder_name) - ar = Array{Any, 1}() - for ele in readdir(folder_name*"/test_data_set_0") - push!(ar, Float32.(readproto(open(folder_name*"/test_data_set_0/"*ele), ONNX.Proto.TensorProto()) |> ONNX.get_array)) - end - return ar[end] -end - -function read_model(folder_name) - for ele in readdir(folder_name) - if ele=="model.onnx" - return readproto(open(folder_name*"/model.onnx"), ONNX.Proto.ModelProto()) - end - end -end - -function get_optype(a::ONNX.Proto.ModelProto) - g = ONNX.convert(a.graph) - return g.node[1].op_type -end - -function get_dict(a::ONNX.Proto.ModelProto) - g = ONNX.convert(a.graph) - return g.node[1].attribute -end - -function main_test(filename, op_expected, ip...)
- if Symbol(get_optype(read_model(filename))) == :Constant - @test get_dict(read_model(filename))[:value] |> ONNX.get_array == op_expected - - elseif Symbol(get_optype(read_model(filename))) == :Conv - - temp = ONNX.ops[Symbol(get_optype(read_model(filename)))](get_dict(read_model(filename)), - Symbol("ip[1]"), Symbol("ip[2]")) |> syntax - - touch("temp_conv.jl") - open("temp_conv.jl","w") do file - str1 = "flipkernel(x) = x[end:-1:1, end:-1:1, :, :] \n" # Remove when Flux directly supports it - write(file, str1*string(temp)) - end - model = include("temp_conv.jl") - rm("temp_conv.jl") - @test model == op_expected - elseif Symbol(get_optype(read_model(filename))) == :MaxPool - temp = ONNX.ops[Symbol(get_optype(read_model(filename)))](get_dict(read_model(filename)), - Symbol("ip[1]")) |> syntax - touch("temp_maxpool.jl") - open("temp_maxpool.jl","w") do file - write(file, string(temp)) - end - - model = include("temp_maxpool.jl") - rm("temp_maxpool.jl") - @test model == op_expected - elseif Symbol(get_optype(read_model(filename))) == :AveragePool - temp = ONNX.ops[Symbol(get_optype(read_model(filename)))](get_dict(read_model(filename)), - Symbol("ip[1]")) |> syntax - touch("temp_averagepool.jl") - open("temp_averagepool.jl","w") do file - write(file, string(temp)) - end - - model = include("temp_averagepool.jl") - rm("temp_averagepool.jl") - @test model ≈ op_expected atol=0.001 - elseif Symbol(get_optype(read_model(filename))) in [:GlobalAveragePool, :GlobalMaxPool] - temp = ONNX.ops[Symbol(get_optype(read_model(filename)))](get_dict(read_model(filename)), - Symbol("ip[1]")) |> syntax - touch("temp_averagepool.jl") - open("temp_averagepool.jl","w") do file - write(file, "using Statistics \n" * string(temp)) - end - model = include("temp_averagepool.jl") - rm("temp_averagepool.jl") - @test model ≈ op_expected atol=0.001 - elseif Symbol(get_optype(read_model(filename))) in [:Expand, :Concat] - temp = ONNX.ops[Symbol(get_optype(read_model(filename)))](get_dict(read_model(filename)), - Symbol("ip[1]"), Symbol("ip[2]")) |> syntax - touch("temp_expand.jl") - open("temp_expand.jl","w") do file - write(file, string(temp)) - end - - model = include("temp_expand.jl") - rm("temp_expand.jl") - @test model ≈ op_expected atol=0.001 - else - @test ONNX.ops[Symbol(get_optype(read_model(filename)))](get_dict(read_model(filename)), - ip...) 
|> syntax |> eval ≈ op_expected atol=0.001 - end -end diff --git a/test/pooling.jl b/test/pooling.jl deleted file mode 100644 index 29184341..00000000 --- a/test/pooling.jl +++ /dev/null @@ -1,94 +0,0 @@ -using ONNX, Flux, ProtoBuf -include("ops_tests.jl") - -# Maxpool 1D default: -ip = read_input("$ONNX_TEST_PATH/test_maxpool_1d_default") -main_test("$ONNX_TEST_PATH/test_maxpool_1d_default", - read_output("$ONNX_TEST_PATH/test_maxpool_1d_default"), - read_input("$ONNX_TEST_PATH/test_maxpool_1d_default")[1]) - -# Maxpool 2D default -ip = read_input("$ONNX_TEST_PATH/test_maxpool_2d_default") -main_test("$ONNX_TEST_PATH/test_maxpool_2d_default", - read_output("$ONNX_TEST_PATH/test_maxpool_2d_default"), - read_input("$ONNX_TEST_PATH/test_maxpool_2d_default")[1]) - -# Maxpool 2D precomputed pads: -ip = read_input("$ONNX_TEST_PATH/test_maxpool_2d_precomputed_pads") -main_test("$ONNX_TEST_PATH/test_maxpool_2d_precomputed_pads", - read_output("$ONNX_TEST_PATH/test_maxpool_2d_precomputed_pads"), - read_input("$ONNX_TEST_PATH/test_maxpool_2d_precomputed_pads")[1]) - -# Maxpool 2D precomputed strides: -ip = read_input("$ONNX_TEST_PATH/test_maxpool_2d_precomputed_strides") -main_test("$ONNX_TEST_PATH/test_maxpool_2d_precomputed_strides", - read_output("$ONNX_TEST_PATH/test_maxpool_2d_precomputed_strides"), - read_input("$ONNX_TEST_PATH/test_maxpool_2d_precomputed_strides")[1]) - -# Maxpool 2D strides: -ip = read_input("$ONNX_TEST_PATH/test_maxpool_2d_strides") -main_test("$ONNX_TEST_PATH/test_maxpool_2d_strides", - read_output("$ONNX_TEST_PATH/test_maxpool_2d_strides"), - read_input("$ONNX_TEST_PATH/test_maxpool_2d_strides")[1]) - -# Averagepool 1D -ip = read_input("$ONNX_TEST_PATH/test_averagepool_1d_default") - main_test("$ONNX_TEST_PATH/test_averagepool_1d_default", - read_output("$ONNX_TEST_PATH/test_averagepool_1d_default"), - read_input("$ONNX_TEST_PATH/test_averagepool_1d_default")) - - -# AveragePool 2D Default -ip = read_input("$ONNX_TEST_PATH/test_averagepool_2d_default") -main_test("$ONNX_TEST_PATH/test_averagepool_2d_default", - read_output("$ONNX_TEST_PATH/test_averagepool_2d_default"), - read_input("$ONNX_TEST_PATH/test_averagepool_2d_default")[1]) - -# Averagepool 2d pads count include pad -ip = read_input("$ONNX_TEST_PATH/test_averagepool_2d_pads_count_include_pad") -main_test("$ONNX_TEST_PATH/test_averagepool_2d_pads_count_include_pad", - read_output("$ONNX_TEST_PATH/test_averagepool_2d_pads_count_include_pad"), - read_input("$ONNX_TEST_PATH/test_averagepool_2d_pads_count_include_pad")[1]); - -# AveragePool 2d precomputed pads count include pad -ip = read_input("$ONNX_TEST_PATH/test_averagepool_2d_precomputed_pads_count_include_pad") -main_test("$ONNX_TEST_PATH/test_averagepool_2d_precomputed_pads_count_include_pad", - read_output("$ONNX_TEST_PATH/test_averagepool_2d_precomputed_pads_count_include_pad"), - read_input("$ONNX_TEST_PATH/test_averagepool_2d_precomputed_pads_count_include_pad")[1]) - -# Averagepool precomputed strides -ip = read_input("$ONNX_TEST_PATH/test_averagepool_2d_precomputed_strides") - main_test("$ONNX_TEST_PATH/test_averagepool_2d_precomputed_strides", - read_output("$ONNX_TEST_PATH/test_averagepool_2d_precomputed_strides"), - read_input("$ONNX_TEST_PATH/test_averagepool_2d_precomputed_strides")) - -# AveragePool 2D Strides -ip = read_input("$ONNX_TEST_PATH/test_averagepool_2d_strides") -main_test("$ONNX_TEST_PATH/test_averagepool_2d_strides", - read_output("$ONNX_TEST_PATH/test_averagepool_2d_strides"), - 
read_input("$ONNX_TEST_PATH/test_averagepool_2d_strides")[1]) - - -# Test globalaveragepool -ip = read_input("$ONNX_TEST_PATH/test_globalaveragepool") -main_test("$ONNX_TEST_PATH/test_globalaveragepool", - read_output("$ONNX_TEST_PATH/test_globalaveragepool"), - read_input("$ONNX_TEST_PATH/test_globalaveragepool")[1]) - -# Test globalaveragepool precomputed -ip = read_input("$ONNX_TEST_PATH/test_globalaveragepool_precomputed") -main_test("$ONNX_TEST_PATH/test_globalaveragepool_precomputed", - read_output("$ONNX_TEST_PATH/test_globalaveragepool_precomputed"), - read_input("$ONNX_TEST_PATH/test_globalaveragepool_precomputed")[1]) - -# Test globalmaxpool -ip = read_input("$ONNX_TEST_PATH/test_globalmaxpool") -main_test("$ONNX_TEST_PATH/test_globalmaxpool", - read_output("$ONNX_TEST_PATH/test_globalmaxpool"), - read_input("$ONNX_TEST_PATH/test_globalmaxpool")[1]) - -# Test globalmaxpool precomputed -ip = read_input("$ONNX_TEST_PATH/test_globalmaxpool_precomputed") -main_test("$ONNX_TEST_PATH/test_globalmaxpool_precomputed", - read_output("$ONNX_TEST_PATH/test_globalmaxpool_precomputed"), - read_input("$ONNX_TEST_PATH/test_globalmaxpool_precomputed")[1]) \ No newline at end of file diff --git a/test/readwrite.jl b/test/readwrite.jl new file mode 100644 index 00000000..29f2c4b0 --- /dev/null +++ b/test/readwrite.jl @@ -0,0 +1,72 @@ +@testset "Read and write" begin + import ONNX + + function serdeser(p::T) where T + iob = PipeBuffer(); + ONNX.writeproto(iob, p) + return ONNX.readproto(iob, T()) + end + + @testset "TensorProto" begin + import ONNX: TensorProto, array + + @testset "Tensor type $T size $s" for T in (Int8, Int32, Int64, Float16, Float32, Float64), s in ((1,), + (1, 2), + (1, 2, 3), + (1, 2, 3, 4), + (1, 2, 3, 4, 5)) + exp = reshape(collect(T, 1:prod(s)), s...) 
+            @test TensorProto(exp) |> serdeser |> array == exp + end + end + + @testset "ValueInfo" begin + import ONNX: ValueInfoProto + + @testset "ValueInfo shape $s" for s in ((), (missing,), (1, 2), (3, 4, missing)) + + vip = ValueInfoProto("test", s) + + dvip = serdeser(vip) + + @test dvip.name == vip.name + + vsize = size(dvip) + @test length(vsize) == length(s) + if !isempty(s) + @test vsize[findall(!ismissing, s)] == Tuple(skipmissing(s)) + end + end + end + + @testset "Attribute" begin + import ONNX: AttributeProto, TensorProto, attribute, array + + @testset "Attribute type $(first(p))" for p in ( + :Int64 => 12, + :Float32 => 23f0, + :Float32s => Float32.(1:4), + :Int64s => [1, 2, 3, 4], + :String => "relu", + :Strings => split("abcdefg", "") + ) + @test AttributeProto(p) |> serdeser |> attribute == p + end + + @testset "Attribute type Float64" begin + # Float64 does not exist as an attribute type in ONNX, so the test above would fail with rounding errors + @test AttributeProto(:ff => 1.23) |> serdeser |> attribute |> last == 1.23f0 + end + + @testset "Attribute type TensorProto" begin + # TensorProto has undef fields which mess up a straight comparison + arr = collect(1:4) + @test AttributeProto(:ff => TensorProto(arr)) |> serdeser |> attribute |> last |> array == arr + end + + @testset "Attribute Dict" begin + attrs = Dict(:int => 12, :str => "aaa", :floats => Float32.(2:5)) + @test pairs(attrs) |> collect .|> AttributeProto .|> serdeser |> Dict == attrs + end + end +end \ No newline at end of file diff --git a/test/reshape.jl b/test/reshape.jl deleted file mode 100644 index cfddafa9..00000000 --- a/test/reshape.jl +++ /dev/null @@ -1,96 +0,0 @@ -using ONNX, Flux, ProtoBuf -include("ops_tests.jl") - -#test reshape one dim: -main_test("$ONNX_TEST_PATH/test_reshape_one_dim", - read_output("$ONNX_TEST_PATH/test_reshape_one_dim"), - read_input("$ONNX_TEST_PATH/test_reshape_one_dim")[1], - read_input("$ONNX_TEST_PATH/test_reshape_one_dim")[2]) - -#test reshape extended dim: -main_test("$ONNX_TEST_PATH/test_reshape_extended_dims", - read_output("$ONNX_TEST_PATH/test_reshape_extended_dims"), - read_input("$ONNX_TEST_PATH/test_reshape_extended_dims")[1], - read_input("$ONNX_TEST_PATH/test_reshape_extended_dims")[2]) - -#test reshape reordered dim: -main_test("$ONNX_TEST_PATH/test_reshape_reordered_dims", - read_output("$ONNX_TEST_PATH/test_reshape_reordered_dims"), - read_input("$ONNX_TEST_PATH/test_reshape_reordered_dims")[1], - read_input("$ONNX_TEST_PATH/test_reshape_reordered_dims")[2]) - -#test reshape reduced dim -main_test("$ONNX_TEST_PATH/test_reshape_reduced_dims", - read_output("$ONNX_TEST_PATH/test_reshape_reduced_dims"), - read_input("$ONNX_TEST_PATH/test_reshape_reduced_dims")[1], - read_input("$ONNX_TEST_PATH/test_reshape_reduced_dims")[2]) - -## Transpose test: - -main_test("$ONNX_TEST_PATH/test_transpose_all_permutations_0", - read_output("$ONNX_TEST_PATH/test_transpose_all_permutations_0"), - read_input("$ONNX_TEST_PATH/test_transpose_all_permutations_0")[1]) - -main_test("$ONNX_TEST_PATH/test_transpose_all_permutations_1", - read_output("$ONNX_TEST_PATH/test_transpose_all_permutations_1"), - read_input("$ONNX_TEST_PATH/test_transpose_all_permutations_1")[1]) - -main_test("$ONNX_TEST_PATH/test_transpose_all_permutations_2", - read_output("$ONNX_TEST_PATH/test_transpose_all_permutations_2"), - read_input("$ONNX_TEST_PATH/test_transpose_all_permutations_2")[1]) - -main_test("$ONNX_TEST_PATH/test_transpose_all_permutations_3", -
read_output("$ONNX_TEST_PATH/test_transpose_all_permutations_3"), - read_input("$ONNX_TEST_PATH/test_transpose_all_permutations_3")[1]) - -main_test("$ONNX_TEST_PATH/test_transpose_all_permutations_4", - read_output("$ONNX_TEST_PATH/test_transpose_all_permutations_4"), - read_input("$ONNX_TEST_PATH/test_transpose_all_permutations_4")[1]) - -main_test("$ONNX_TEST_PATH/test_transpose_all_permutations_5", - read_output("$ONNX_TEST_PATH/test_transpose_all_permutations_5"), - read_input("$ONNX_TEST_PATH/test_transpose_all_permutations_5")[1]) - -## Test concat - -#Test concat 1d axis 0 -#ip = read_input("$ONNX_TEST_PATH/test_concat_1d_axis_0") -#main_test("$ONNX_TEST_PATH/test_concat_1d_axis_0", -# read_output("$ONNX_TEST_PATH/test_concat_1d_axis_0"), -# read_input("$ONNX_TEST_PATH/test_concat_1d_axis_0")[1], -# read_input("$ONNX_TEST_PATH/test_concat_1d_axis_0")[2]) - -#Test concat 2d axis 0 -#ip = read_input("$ONNX_TEST_PATH/test_concat_2d_axis_0") -#main_test("$ONNX_TEST_PATH/test_concat_2d_axis_0", -# read_output("$ONNX_TEST_PATH/test_concat_2d_axis_0"), -# read_input("$ONNX_TEST_PATH/test_concat_2d_axis_0")[1], -# read_input("$ONNX_TEST_PATH/test_concat_2d_axis_0")[2]) - -#Test concat 2d axis 1 -#ip = read_input("$ONNX_TEST_PATH/test_concat_2d_axis_1") -#main_test("$ONNX_TEST_PATH/test_concat_2d_axis_1", -# read_output("$ONNX_TEST_PATH/test_concat_2d_axis_1"), -# read_input("$ONNX_TEST_PATH/test_concat_2d_axis_1")[1], -# read_input("$ONNX_TEST_PATH/test_concat_2d_axis_1")[2]) - -#Test concat 3d axis 0 -#ip = read_input("$ONNX_TEST_PATH/test_concat_3d_axis_0") -#main_test("$ONNX_TEST_PATH/test_concat_3d_axis_0", -# read_output("$ONNX_TEST_PATH/test_concat_3d_axis_0"), -# read_input("$ONNX_TEST_PATH/test_concat_3d_axis_0")[1], -# read_input("$ONNX_TEST_PATH/test_concat_3d_axis_0")[2]) - -#Test concat 3d axis 1 -#ip = read_input("$ONNX_TEST_PATH/test_concat_3d_axis_1") -#main_test("$ONNX_TEST_PATH/test_concat_3d_axis_1", -# read_output("$ONNX_TEST_PATH/test_concat_3d_axis_1"), -# read_input("$ONNX_TEST_PATH/test_concat_3d_axis_1")[1], -# read_input("$ONNX_TEST_PATH/test_concat_3d_axis_1")[2]) - -#Test concat 3d axis 2 -#ip = read_input("$ONNX_TEST_PATH/test_concat_3d_axis_2") -#main_test("$ONNX_TEST_PATH/test_concat_3d_axis_2", -# read_output("$ONNX_TEST_PATH/test_concat_3d_axis_2"), -# read_input("$ONNX_TEST_PATH/test_concat_3d_axis_2")[1], -# read_input("$ONNX_TEST_PATH/test_concat_3d_axis_2")[2]) diff --git a/test/runtests.jl b/test/runtests.jl index b4c902ae..eb804ac5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,17 +1,4 @@ -using ONNX, Flux, ProtoBuf +using ONNX using Test -include("ops_tests.jl") - -@testset "ONNX" begin - -include("conversions.jl") -#include("constant.jl") -include("logical_ops.jl") -include("pooling.jl") -include("conv.jl") -include("reshape.jl") -include("arithmetic_ops.jl") -include("lstm.jl") - -end +include("readwrite.jl")