From 3fbddf4606200e7b688f1c6bbba75ec10e5f30c6 Mon Sep 17 00:00:00 2001 From: Bram Date: Thu, 14 Dec 2023 14:39:01 +0000 Subject: [PATCH] v1.0.0 (#188) * Added the Abs operator * Issue [#96](https://github.com/AdvancedClimateSystems/gonnx/issues/96): Update import paths. * Setup CI/CD checks * Breaking pipeline? * Add linter and test jobs * Fix syntax * Split build jobs * Fix imports for Abs operator * add PRelu operator * fix typos in prelu comments * Fix new ONNX tests * Better ordering * Issue #151 : Update import path in readme * Keep using Go1.19 * Fix a lot of linter issues * Fix more lint * Got to cast operator and disable `captLocal` * WIP: more lint fixes * WIP: More lint fixes * More wip on lint errors * WIP on lint * More WIP * Fix all lints except errors * Fix all lints * Fixed part of tests * Fix validate input tests. * Worked on fixing tests * Fixed rest of tests * Resolve all MR comments * Added cos operator (#159) * Added cos operator * Replace s struct identifier * Missed characters * Fixed comment * Resolved MR comments * Fix tests * Fix lint * remove unused errors * Fix naming * Group declarations --------- Co-authored-by: Swopper050 * Added acosh operator (#163) * Added acosh operator * Merged develop --------- Co-authored-by: Swopper050 * Add Conv operator (#177) * Working on conv operator * Added all attributes for conv operator * WIP on conv operator * Shit's hard * WIP on conv operator * Set defaults for attributes * Finished computation of dilated kernel * Start of 1D conv implementation * Almost finished 1D convolution * Working 2D conv! * Working convolution operator * Add mnist model * Added tests + bugfixes in conv * Full coverage * Remove unnecessary if * Kept division by 2 * Fix last MR comment * Fix lint? --------- Co-authored-by: Swopper050 * Added acos operator (#162) * Added acos operator * Merge develop * Group declarations --------- Co-authored-by: Swopper050 * Added Sin operator (#157) * Added Sin operator * Added ONNX sin test coverage * Fix tests * Remove unused error * Remove unused error * Fix lint * Fix lint * Use float type --------- Co-authored-by: Swopper050 * Added asin operator (#161) * Added asin operator * Use FloatType --------- Co-authored-by: Swopper050 * Added Sinh operator (#158) * Added sinh operator * Updated comments * Use FloatType --------- Co-authored-by: Swopper050 * Added atan operator (#165) * Added atan operator * Use FloatType --------- Co-authored-by: Swopper050 * Added atanh operator (#166) * Added atanh operator * Use FloatType --------- Co-authored-by: Swopper050 * Added tan operator (#167) * Added tan operator * Use FloatType --------- Co-authored-by: Swopper050 * Added asinh operator (#168) * Added asinh operator * Remove unused types --------- Co-authored-by: Swopper050 * Added cosh operator (#160) * Added cosh operator * Group declarations * Correct Apply comment --------- Co-authored-by: Swopper050 * Added softmax operator (#171) * Added softmax operator * Merged develop * Resolved MR comments --------- Co-authored-by: Swopper050 * Added not operator (#170) * Added not operator * Resolved MR comments --------- Co-authored-by: Swopper050 * Added boolean operators (#180) * Added boolean operators * Fix lint * Fix final linst * Create a generic apply boolean operator * Refactored boolean operators to use binary op generic function * Refactored even more * Generalized binary operations --------- Co-authored-by: Swopper050 * Added comparison operators (#173) * Added equal, greater, greaterOrEqual operator * Merged develop * Resolved MR comments --------- Co-authored-by: Swopper050 * Add RNN operator (#181) * WIP on RNN * WIP on RNN * Working RNN version * Added tests for RNN * Resolved MR comments * Do not export rnn specific constants --------- Co-authored-by: Swopper050 * Add LSTM operator (#183) * WIP on RNN * WIP on RNN * Working RNN version * Added tests for RNN * Working version of LSTM operator * Reusable attrs and tests for LSTM * Refactored recurrent operators to share code * Resolved MR comments * Do not export rnn specific constants --------- Co-authored-by: Swopper050 * Add LinearRegressor operator (#184) * WIP on LinearRegressor * Added tests for linear regressor * Added test descriptions and docstring * Do not export constants --------- Co-authored-by: Swopper050 --------- Co-authored-by: Swopper050 Co-authored-by: wisse Co-authored-by: wipsel Co-authored-by: Yannick Dylla <17772145+ydylla@users.noreply.github.com> --- .github/workflows/go.yml | 75 ++ .gitignore | 2 + .golangci.yml | 15 +- README.md | 2 +- errors.go | 33 +- go.mod | 1 - go.sum | 6 +- model.go | 46 +- model_test.go | 21 +- onnx/graph_proto.go | 131 +++- ops/activation.go | 35 +- ops/binary_op.go | 161 ++++ ops/convert.go | 17 +- ops/convert_test.go | 6 +- ops/errors.go | 280 ++++++- ops/fixtures.go | 22 +- ops/multidir_broadcast.go | 33 +- ops/multidir_broadcast_test.go | 16 +- ops/operator.go | 7 +- ops/opset13/abs.go | 63 ++ ops/opset13/abs_test.go | 144 ++++ ops/opset13/acos.go | 75 ++ ops/opset13/acos_test.go | 99 +++ ops/opset13/acosh.go | 75 ++ ops/opset13/acosh_test.go | 99 +++ ops/opset13/add.go | 32 +- ops/opset13/add_test.go | 20 +- ops/opset13/and.go | 61 ++ ops/opset13/and_test.go | 104 +++ ops/opset13/asin.go | 75 ++ ops/opset13/asin_test.go | 99 +++ ops/opset13/asinh.go | 75 ++ ops/opset13/asinh_test.go | 99 +++ ops/opset13/atan.go | 75 ++ ops/opset13/atan_test.go | 99 +++ ops/opset13/atanh.go | 75 ++ ops/opset13/atanh_test.go | 99 +++ ops/opset13/cast.go | 24 +- ops/opset13/cast_test.go | 14 +- ops/opset13/concat.go | 20 +- ops/opset13/concat_test.go | 11 +- ops/opset13/constant.go | 19 +- ops/opset13/constant_of_shape.go | 65 +- ops/opset13/constant_of_shape_test.go | 37 +- ops/opset13/constant_test.go | 23 +- ops/opset13/conv.go | 590 +++++++++++++++ ops/opset13/conv_test.go | 692 ++++++++++++++++++ ops/opset13/cos.go | 75 ++ ops/opset13/cos_test.go | 99 +++ ops/opset13/cosh.go | 75 ++ ops/opset13/cosh_test.go | 99 +++ ops/opset13/div.go | 32 +- ops/opset13/div_test.go | 8 +- ops/opset13/equal.go | 61 ++ ops/opset13/equal_test.go | 133 ++++ ops/opset13/gather.go | 104 ++- ops/opset13/gather_test.go | 176 +++-- ops/opset13/gemm.go | 22 +- ops/opset13/gemm_test.go | 39 +- ops/opset13/greater.go | 61 ++ ops/opset13/greater_or_equal.go | 61 ++ ops/opset13/greater_or_equal_test.go | 133 ++++ ops/opset13/greater_test.go | 133 ++++ ops/opset13/gru.go | 214 +++--- ops/opset13/gru_test.go | 102 ++- ops/opset13/less.go | 61 ++ ops/opset13/less_or_equal.go | 61 ++ ops/opset13/less_or_equal_test.go | 133 ++++ ops/opset13/less_test.go | 133 ++++ ops/opset13/linear_regressor.go | 117 +++ ops/opset13/linear_regressor_test.go | 196 +++++ ops/opset13/lstm.go | 413 +++++++++++ ops/opset13/lstm_test.go | 386 ++++++++++ ops/opset13/matmul.go | 64 +- ops/opset13/matmul_test.go | 12 +- ops/opset13/mul.go | 32 +- ops/opset13/mul_test.go | 15 +- ops/opset13/not.go | 60 ++ ops/opset13/not_test.go | 93 +++ ops/opset13/opset13.go | 42 +- ops/opset13/opset13_test.go | 125 +++- ops/opset13/or.go | 61 ++ ops/opset13/or_test.go | 104 +++ ops/opset13/prelu.go | 134 ++++ ops/opset13/prelu_test.go | 108 +++ ops/opset13/relu.go | 20 +- ops/opset13/relu_test.go | 8 +- ops/opset13/reshape.go | 32 +- ops/opset13/reshape_test.go | 8 +- ops/opset13/rnn.go | 249 +++++++ ops/opset13/rnn_test.go | 334 +++++++++ ops/opset13/scaler.go | 25 +- ops/opset13/scaler_test.go | 30 +- ops/opset13/shape.go | 17 +- ops/opset13/shape_test.go | 8 +- ops/opset13/sigmoid.go | 7 +- ops/opset13/sigmoid_test.go | 26 +- ops/opset13/sin.go | 75 ++ ops/opset13/sin_test.go | 99 +++ ops/opset13/sinh.go | 75 ++ ops/opset13/sinh_test.go | 99 +++ ops/opset13/slice.go | 19 +- ops/opset13/slice_test.go | 9 +- ops/opset13/softmax.go | 84 +++ ops/opset13/softmax_test.go | 140 ++++ ops/opset13/squeeze.go | 41 +- ops/opset13/squeeze_test.go | 10 +- ops/opset13/sub.go | 32 +- ops/opset13/sub_test.go | 8 +- ops/opset13/tan.go | 75 ++ ops/opset13/tan_test.go | 99 +++ ops/opset13/tanh.go | 7 +- ops/opset13/tanh_test.go | 8 +- ops/opset13/transpose.go | 24 +- ops/opset13/transpose_test.go | 28 +- ops/opset13/unsqueeze.go | 32 +- ops/opset13/unsqueeze_test.go | 74 +- ops/opset13/xor.go | 61 ++ ops/opset13/xor_test.go | 102 +++ ops/recurrent_utils.go | 73 ++ ops/slicer.go | 4 +- ops/types.go | 18 + ops/unidir_broadcast.go | 20 +- ops/unidir_broadcast_test.go | 20 +- ops/utils.go | 51 +- ops/utils_test.go | 13 +- ops/validate_inputs.go | 27 +- ops/validate_inputs_test.go | 20 +- ops_test.go | 201 ++++- opset.go | 16 +- opset_test.go | 4 +- .../onnx_models/mnist-8-opset13.onnx | 3 + sample_models/requirements.txt | 6 +- 133 files changed, 9363 insertions(+), 927 deletions(-) create mode 100644 .github/workflows/go.yml create mode 100644 ops/binary_op.go create mode 100644 ops/opset13/abs.go create mode 100644 ops/opset13/abs_test.go create mode 100644 ops/opset13/acos.go create mode 100644 ops/opset13/acos_test.go create mode 100644 ops/opset13/acosh.go create mode 100644 ops/opset13/acosh_test.go create mode 100644 ops/opset13/and.go create mode 100644 ops/opset13/and_test.go create mode 100644 ops/opset13/asin.go create mode 100644 ops/opset13/asin_test.go create mode 100644 ops/opset13/asinh.go create mode 100644 ops/opset13/asinh_test.go create mode 100644 ops/opset13/atan.go create mode 100644 ops/opset13/atan_test.go create mode 100644 ops/opset13/atanh.go create mode 100644 ops/opset13/atanh_test.go create mode 100644 ops/opset13/conv.go create mode 100644 ops/opset13/conv_test.go create mode 100644 ops/opset13/cos.go create mode 100644 ops/opset13/cos_test.go create mode 100644 ops/opset13/cosh.go create mode 100644 ops/opset13/cosh_test.go create mode 100644 ops/opset13/equal.go create mode 100644 ops/opset13/equal_test.go create mode 100644 ops/opset13/greater.go create mode 100644 ops/opset13/greater_or_equal.go create mode 100644 ops/opset13/greater_or_equal_test.go create mode 100644 ops/opset13/greater_test.go create mode 100644 ops/opset13/less.go create mode 100644 ops/opset13/less_or_equal.go create mode 100644 ops/opset13/less_or_equal_test.go create mode 100644 ops/opset13/less_test.go create mode 100644 ops/opset13/linear_regressor.go create mode 100644 ops/opset13/linear_regressor_test.go create mode 100644 ops/opset13/lstm.go create mode 100644 ops/opset13/lstm_test.go create mode 100644 ops/opset13/not.go create mode 100644 ops/opset13/not_test.go create mode 100644 ops/opset13/or.go create mode 100644 ops/opset13/or_test.go create mode 100644 ops/opset13/prelu.go create mode 100644 ops/opset13/prelu_test.go create mode 100644 ops/opset13/rnn.go create mode 100644 ops/opset13/rnn_test.go create mode 100644 ops/opset13/sin.go create mode 100644 ops/opset13/sin_test.go create mode 100644 ops/opset13/sinh.go create mode 100644 ops/opset13/sinh_test.go create mode 100644 ops/opset13/softmax.go create mode 100644 ops/opset13/softmax_test.go create mode 100644 ops/opset13/tan.go create mode 100644 ops/opset13/tan_test.go create mode 100644 ops/opset13/xor.go create mode 100644 ops/opset13/xor_test.go create mode 100644 ops/recurrent_utils.go create mode 100644 ops/types.go create mode 100644 sample_models/onnx_models/mnist-8-opset13.onnx diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml new file mode 100644 index 0000000..d53567a --- /dev/null +++ b/.github/workflows/go.yml @@ -0,0 +1,75 @@ +# This workflow will build a golang project +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go + +name: Go + +on: + push: + branches: [ "develop" ] + pull_request: + branches: [ "develop" ] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Set up Go + uses: actions/setup-go@v3 + with: + go-version: 1.19 + + - name: Install linter + run: make install_lint + + - name: Lint + run: make lint + + tests: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Set up Go + uses: actions/setup-go@v3 + with: + go-version: 1.19 + + - name: Install dependencies + run: make install + + - name: Install Gotestsum + run: make install_gotestsum + + - name: Setup ONNX test data + run: make test_data + + - name: Tests + run: make test + + build_amd64: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Set up Go + uses: actions/setup-go@v3 + with: + go-version: 1.19 + + - name: Build amd64 + run: make build_amd64 + + build_arm64: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Set up Go + uses: actions/setup-go@v3 + with: + go-version: 1.19 + + - name: Build arm64 + run: make build_arm64 diff --git a/.gitignore b/.gitignore index 37ffce7..fa81ed7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ test_data/ .coverage.out + +sample_models/.env diff --git a/.golangci.yml b/.golangci.yml index 6fb336b..e71a88c 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -43,10 +43,17 @@ linters: linters-settings: gomnd: ignored-functions: - - 'strconv.ParseInt' - - 'strconv.ParseFloat' - - 'strconv.FormatInt' - - 'strconv.FormatFloat' + - "strconv.ParseInt" + - "strconv.ParseFloat" + - "strconv.FormatInt" + - "strconv.FormatFloat" + gocritic: + disabled-checks: + # In the world of AI tensor's are often denoted with a capital letter. + # We want to adopt the go style guide as much as possible but we also want + # to be able to easily show when a variable is a Tensor. So we chose to + # disable captLocal. Note that any other parameter should use a lower case letters. + - "captLocal" issues: max-issues-per-linter: 0 max-same-issues: 0 diff --git a/README.md b/README.md index 12e5e22..9d934b6 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ A simple example is shown below: package main import ( - "github.com/AdvancedClimateSystems/gonnx" + "github.com/advancedclimatesystems/gonnx" "gorgonia.org/tensor" ) diff --git a/errors.go b/errors.go index 82523e8..1f8f48c 100644 --- a/errors.go +++ b/errors.go @@ -1,7 +1,32 @@ package gonnx -// InvalidShapeError is used when the shape of an input tensor does not match the expectation. -const InvalidShapeError = "input shape does not match for %v: expected %v but got %v" +import ( + "errors" + "fmt" -// SetOutputTensorsError is used when the output of an operation could not be set. -const SetOutputTensorsError = "could not set output tensors, expected %v tensors but got %v" + "github.com/advancedclimatesystems/gonnx/onnx" +) + +var errModel = errors.New("gonnx model error") + +type InvalidShapeError struct { + expected onnx.Shape + actual []int +} + +func (i InvalidShapeError) Error() string { + return fmt.Sprintf("invalid shape error expected: %v actual %v", i.expected, i.actual) +} + +func ErrInvalidShape(expected onnx.Shape, actual []int) error { + return InvalidShapeError{ + expected: expected, + actual: actual, + } +} + +// ErrModel is used for when an error ocured during setup of running onnx models. +// The user can specify a formatted message using the standard formatting rules. +func ErrModel(format string, a ...any) error { + return fmt.Errorf("%w: %s", errModel, fmt.Sprintf(format, a...)) +} diff --git a/go.mod b/go.mod index 2502285..7295fcb 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.19 require ( github.com/stretchr/testify v1.8.1 - gitlab.advancedclimate.nl/smartbase/software/core/airgo v0.0.0-b1 google.golang.org/protobuf v1.28.1 gorgonia.org/tensor v0.9.24 ) diff --git a/go.sum b/go.sum index a1de286..88906c4 100644 --- a/go.sum +++ b/go.sum @@ -84,8 +84,6 @@ github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY= github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -gitlab.advancedclimate.nl/smartbase/software/core/airgo v0.0.0-b1 h1:RkBEoyN6PzTsB8i4ZHLrBpG2q3KxPjJS5kLoLT9Op8c= -gitlab.advancedclimate.nl/smartbase/software/core/airgo v0.0.0-b1/go.mod h1:EzFgx3V68Upc4HwJDakYaztVI9AZn0ABtegpQJT0Q+4= go4.org/unsafe/assume-no-moving-gc v0.0.0-20220617031537-928513b29760 h1:FyBZqvoA/jbNzuAWLQE2kG820zMAkcilx6BMjGbL/E4= go4.org/unsafe/assume-no-moving-gc v0.0.0-20220617031537-928513b29760/go.mod h1:FftLjUGFEDu5k8lt0ddY+HcrH/qU/0qk+H8j9/nTl3E= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= @@ -123,8 +121,8 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200904194848-62affa334b73/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20201021035429-f5854403a974 h1:IX6qOQeG5uLjB/hjjwjedwfjND0hgjPMMyO1RoIXQNI= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6 h1:0PC75Fz/kyMGhL0e1QnypqK2kQMqKt9csD1GnMJR+Zk= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -164,8 +162,8 @@ gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJ gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0= gonum.org/v1/gonum v0.9.1 h1:HCWmqqNoELL0RAQeKBXWtkp04mGk8koafcB4He6+uhc= gonum.org/v1/gonum v0.9.1/go.mod h1:TZumC3NeyVQskjXqmyWt4S3bINhy7B4eYwW69EbyX+0= +gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0 h1:OE9mWmgKkjJyEmDAAtGMPjXu+YNeGvK9VTSHY6+Qihc= gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= -gonum.org/v1/netlib v0.0.0-20201012070519-2390d26c3658 h1:/DNJ3wcvPHjTLVNG6rmSHK7uEwdBihyiJRJXB16wXoU= gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc= gonum.org/v1/plot v0.9.0/go.mod h1:3Pcqqmp6RHvJI72kgb8fThyUnav364FOsdDo2aGW5lY= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= diff --git a/model.go b/model.go index 96ca54b..9c9fcd1 100644 --- a/model.go +++ b/model.go @@ -2,16 +2,16 @@ package gonnx import ( "archive/zip" - "fmt" - "io/ioutil" + "io" + "os" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" "google.golang.org/protobuf/proto" "gorgonia.org/tensor" ) -// Tensors is a map with tensors +// Tensors is a map with tensors. type Tensors map[string]tensor.Tensor // Model defines a model that can be used for inference. @@ -23,7 +23,7 @@ type Model struct { // NewModelFromFile creates a new model from a path to a file. func NewModelFromFile(path string) (*Model, error) { - bytesModel, err := ioutil.ReadFile(path) + bytesModel, err := os.ReadFile(path) if err != nil { return nil, err } @@ -38,7 +38,7 @@ func NewModelFromZipFile(file *zip.File) (*Model, error) { return nil, err } - bytesModel, err := ioutil.ReadAll(fc) + bytesModel, err := io.ReadAll(fc) if err != nil { return nil, err } @@ -66,6 +66,7 @@ func NewModel(mp *onnx.ModelProto) (*Model, error) { opsetImports := mp.GetOpsetImport() var opsetID int64 + for i := 0; i < len(opsetImports); i++ { version := opsetImports[i].GetVersion() if version > opsetID { @@ -78,12 +79,11 @@ func NewModel(mp *onnx.ModelProto) (*Model, error) { return nil, err } - model := &Model{ + return &Model{ mp: mp, parameters: params, GetOperator: GetOperator, - } - return model, nil + }, nil } // ModelProtoFromBytes creates an onnx.ModelProto based on a list of bytes. @@ -92,6 +92,7 @@ func ModelProtoFromBytes(bytesModel []byte) (*onnx.ModelProto, error) { if err := proto.Unmarshal(bytesModel, mp); err != nil { return nil, err } + return mp, nil } @@ -108,16 +109,13 @@ func (m *Model) InputShapes() onnx.Shapes { // InputDimSize returns the size of the input dimension given an input tensor. func (m *Model) InputDimSize(input string, i int) (int, error) { if !m.hasInput(input) { - return 0, fmt.Errorf("input %v does not exist", input) + return 0, ErrModel("input %v does not exist", input) } inputShape := m.mp.Graph.InputShapes()[input] if i >= len(inputShape) { - err := fmt.Errorf( - "input %v only has %d dimensions, but index %d was required", input, len(inputShape), i, - ) - return 0, err + return 0, ErrModel("input %v only has %d dimensions, but index %d was required", input, len(inputShape), i) } return int(inputShape[i].Size), nil @@ -159,7 +157,7 @@ func (m *Model) Run(inputs Tensors) (Tensors, error) { return nil, err } - var tensors = make(Tensors) + tensors := make(Tensors) for inputName, inputTensor := range inputs { tensors[inputName] = inputTensor } @@ -179,7 +177,7 @@ func (m *Model) Run(inputs Tensors) (Tensors, error) { } } - var outputTensors = make(Tensors) + outputTensors := make(Tensors) for _, outputName := range m.OutputNames() { outputTensors[outputName] = tensors[outputName] } @@ -189,7 +187,7 @@ func (m *Model) Run(inputs Tensors) (Tensors, error) { // applyOp applies the operation to the graph. func (m *Model) applyOp(op ops.Operator, n *onnx.NodeProto, tensors Tensors) error { - if err := op.Init(n.GetAttribute()); err != nil { + if err := op.Init(n); err != nil { return err } @@ -222,13 +220,13 @@ func (m *Model) validateShapes(inputTensors Tensors) error { tensor, ok := inputTensors[name] if !ok { - return fmt.Errorf("tensor: %v not found", name) + return ErrModel("tensor: %v not found", name) } shapeReceived := tensor.Shape() if len(shapeReceived) != len(shapeExpected) { - return fmt.Errorf(InvalidShapeError, name, shapeExpected, shapeReceived) + return ErrInvalidShape(shapeExpected, shapeReceived) } for i, dim := range shapeExpected { @@ -239,7 +237,7 @@ func (m *Model) validateShapes(inputTensors Tensors) error { } if dim.Size != int64(shapeReceived[i]) { - return fmt.Errorf(InvalidShapeError, name, shapeExpected, shapeReceived) + return ErrInvalidShape(shapeExpected, shapeReceived) } } } @@ -249,6 +247,7 @@ func (m *Model) validateShapes(inputTensors Tensors) error { func getInputTensorsForNode(names []string, tensors Tensors) ([]tensor.Tensor, error) { var inputTensors []tensor.Tensor + for _, tensorName := range names { // An empty name can happen in between optional inputs, like: // [, , nil, ] @@ -259,7 +258,7 @@ func getInputTensorsForNode(names []string, tensors Tensors) ([]tensor.Tensor, e } else if tensor, ok := tensors[tensorName]; ok { inputTensors = append(inputTensors, tensor) } else { - return nil, fmt.Errorf("no tensor yet for name %v", tensorName) + return nil, ErrModel("no tensor yet for name %v", tensorName) } } @@ -269,9 +268,8 @@ func getInputTensorsForNode(names []string, tensors Tensors) ([]tensor.Tensor, e func setOutputTensorsOfNode( names []string, outputTensors []tensor.Tensor, tensors Tensors, ) error { - if len(names) != len(outputTensors) { - return fmt.Errorf(SetOutputTensorsError, len(names), len(outputTensors)) + return ErrModel("could not set output tensor") } for i, tensor := range outputTensors { diff --git a/model_test.go b/model_test.go index d76d6f3..ed1c4b9 100644 --- a/model_test.go +++ b/model_test.go @@ -1,12 +1,10 @@ package gonnx import ( - "errors" - "fmt" "testing" + "github.com/advancedclimatesystems/gonnx/onnx" "github.com/stretchr/testify/assert" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" "gorgonia.org/tensor" ) @@ -39,9 +37,7 @@ func TestModel(t *testing.T) { [][]float32{rangeFloat(16)}, ), nil, - errors.New( - "input shape does not match for data_input: expected [0 3] but got (2, 4, 2)", - ), + ErrInvalidShape([]onnx.Dim{{IsDynamic: true, Name: "batch_size", Size: 0}, {IsDynamic: false, Name: "", Size: 3}}, []int{2, 4, 2}), }, { "./sample_models/onnx_models/mlp.onnx", @@ -51,7 +47,7 @@ func TestModel(t *testing.T) { [][]float32{rangeFloat(6)}, ), nil, - errors.New("tensor: data_input not found"), + ErrModel("tensor: %v not found", "data_input"), }, { "./sample_models/onnx_models/gru.onnx", @@ -106,6 +102,7 @@ func TestModel(t *testing.T) { outputs, err := model.Run(test.input) assert.Equal(t, test.err, err) + if test.expected == nil { assert.Nil(t, outputs) } else { @@ -128,6 +125,7 @@ func TestModelIOUtil(t *testing.T) { {IsDynamic: false, Name: "", Size: 3}, }, } + assert.Equal(t, []string{"data_input"}, model.InputNames()) assert.Equal(t, expectedInputShapes, model.InputShapes()) @@ -137,6 +135,7 @@ func TestModelIOUtil(t *testing.T) { {IsDynamic: false, Name: "", Size: 2}, }, } + assert.Equal(t, []string{"preds"}, model.OutputNames()) assert.Equal(t, expectedOutputShapes, model.OutputShapes()) assert.Equal(t, expectedOutputShapes["preds"], model.OutputShape("preds")) @@ -165,11 +164,12 @@ func TestInputDimSizeInvalidInput(t *testing.T) { assert.Nil(t, err) _, err = model.InputDimSize("swagger", 0) - assert.Equal(t, fmt.Errorf("input swagger does not exist"), err) + + assert.Equal(t, ErrModel("input %v does not exist", "swagger"), err) } // tensorsFixture creates Tensors with the given names shapes and backings. This is useful for -// providing a model with inputs and checking it's outputs +// providing a model with inputs and checking it's outputs. func tensorsFixture(names []string, shapes [][]int, backing [][]float32) Tensors { res := make(Tensors, len(names)) for i, name := range names { @@ -178,6 +178,7 @@ func tensorsFixture(names []string, shapes [][]int, backing [][]float32) Tensors tensor.WithBacking(backing[i]), ) } + return res } @@ -186,6 +187,7 @@ func rangeFloat(size int) []float32 { for i := 0; i < size; i++ { res[i] = float32(i) } + return res } @@ -194,6 +196,7 @@ func rangeZeros(size int) []float32 { for i := range res { res[i] = 0.0 } + return res } diff --git a/onnx/graph_proto.go b/onnx/graph_proto.go index 87decfa..046b7e0 100644 --- a/onnx/graph_proto.go +++ b/onnx/graph_proto.go @@ -6,6 +6,7 @@ package onnx import ( "bytes" "encoding/binary" + "errors" "fmt" "io" "math" @@ -13,7 +14,7 @@ import ( "gorgonia.org/tensor" ) -// InputNames returns the input names for a GraphProto +// InputNames returns the input names for a GraphProto. func (g *GraphProto) InputNames() []string { return getNamesFromValueProto(g.GetInput()) } @@ -51,6 +52,7 @@ func (g *GraphProto) Params() (map[string]tensor.Tensor, error) { res[i.Name] = t } + return res, nil } @@ -62,13 +64,17 @@ type Shape []Dim // String prints a shape in a human-friendly matter. func (s Shape) String() string { - var dimSizes []int64 + dimSizes := make([]int64, 0, len(s)) + for _, dim := range s { dimSizes = append(dimSizes, dim.Size) } + return fmt.Sprintf("%d", dimSizes) } +var ErrInvalidType = errors.New("invalid type") + // Dim is a dimension. type Dim struct { IsDynamic bool @@ -95,6 +101,7 @@ func getShapesFromValueProto(protos []*ValueInfoProto) Shapes { if protos == nil { return map[string]Shape{} } + shapes := make(map[string]Shape, len(protos)) for _, p := range protos { @@ -119,6 +126,7 @@ func getShapesFromValueProto(protos []*ValueInfoProto) Shapes { } shape := make([]Dim, len(dims)) + for i, dim := range dims { param := dim.GetDimParam() v := dim.GetDimValue() @@ -130,6 +138,7 @@ func getShapesFromValueProto(protos []*ValueInfoProto) Shapes { shape[i] = Dim{IsDynamic: isDynamic, Name: param, Size: v} } + shapes[p.GetName()] = shape } @@ -146,12 +155,15 @@ func getNamesFromTensorProto(protos []*TensorProto) []string { return res } -// TensorFromProto returns a tensor.Tensor from an onnx.TensorProto +// TensorFromProto returns a tensor.Tensor from an onnx.TensorProto. func TensorFromProto(tp *TensorProto) (tensor.Tensor, error) { - var values interface{} - var err error + var ( + values interface{} + err error + ) typeMap := TensorProto_DataType_value + switch tp.DataType { case typeMap["FLOAT"]: values, err = getFloatData(tp) @@ -173,6 +185,8 @@ func TensorFromProto(tp *TensorProto) (tensor.Tensor, error) { values, err = getInt64Data(tp) case typeMap["DOUBLE"]: values, err = getDoubleData(tp) + case typeMap["BOOL"]: + values = getBoolData(tp) default: // At this moment the datatype is either UNDEFINED or some datatype we currently // do not support. @@ -188,7 +202,7 @@ func TensorFromProto(tp *TensorProto) (tensor.Tensor, error) { case len(tp.Uint64Data) > 0: values, err = getUint64Data(tp) default: - return nil, fmt.Errorf("unsupported datatype for Tensor: %v", tp.DataType) + return nil, ErrInvalidType } } @@ -279,8 +293,17 @@ func getDoubleData(tp *TensorProto) ([]float64, error) { return ReadFloat64ArrayFromBytes(tp.RawData) } +func getBoolData(tp *TensorProto) []bool { + if len(tp.Int32Data) > 0 { + return Int32ArrayToBoolArray(tp.GetInt32Data()) + } + + return ReadBoolArrayFromBytes(tp.RawData) +} + const ( float32Size int = 4 + boolSize int = 1 uint8Size int = 1 int8Size int = 1 uint16Size int = 2 @@ -297,11 +320,14 @@ func ReadFloat32ArrayFromBytes(data []byte) ([]float32, error) { buffer := bytes.NewReader(data) element := make([]byte, float32Size) - var err error - var values []float32 + var ( + err error + values []float32 + ) for { var n int + n, err = buffer.Read(element) if n != float32Size || err != nil { break @@ -323,11 +349,14 @@ func ReadFloat64ArrayFromBytes(data []byte) ([]float64, error) { buffer := bytes.NewReader(data) element := make([]byte, float64Size) - var err error - var values []float64 + var ( + err error + values []float64 + ) for { var n int + n, err = buffer.Read(element) if n != float64Size || err != nil { break @@ -344,21 +373,37 @@ func ReadFloat64ArrayFromBytes(data []byte) ([]float64, error) { return values, nil } +// ReadBoolArrayFromBytes reads data and parses it to an array of bool. +// The data is parsed to a bool by comparing the value to 0. If it is +// greater than 0, the bool is considered to be true. +func ReadBoolArrayFromBytes(data []byte) []bool { + values := make([]bool, len(data)) + for i, b := range data { + values[i] = b > 0 + } + + return values +} + // ReadUint8ArrayFromBytes reads data and parses it to an array of uint8. func ReadUint8ArrayFromBytes(data []byte) ([]uint8, error) { buffer := bytes.NewReader(data) element := make([]byte, uint8Size) - var err error - var values []uint8 + var ( + err error + values []uint8 + ) for { var n int + n, err = buffer.Read(element) if n != uint8Size || err != nil { break } - values = append(values, uint8(element[0])) + + values = append(values, element[0]) } if err != io.EOF { @@ -373,11 +418,14 @@ func ReadInt8ArrayFromBytes(data []byte) ([]int8, error) { buffer := bytes.NewReader(data) element := make([]byte, int8Size) - var err error - var values []int8 + var ( + err error + values []int8 + ) for { var n int + n, err = buffer.Read(element) if n != int8Size || err != nil { break @@ -398,11 +446,14 @@ func ReadUint16ArrayFromBytes(data []byte) ([]uint16, error) { buffer := bytes.NewReader(data) element := make([]byte, uint16Size) - var err error - var values []uint16 + var ( + err error + values []uint16 + ) for { var n int + n, err = buffer.Read(element) if n != uint16Size || err != nil { break @@ -423,11 +474,14 @@ func ReadInt16ArrayFromBytes(data []byte) ([]int16, error) { buffer := bytes.NewReader(data) element := make([]byte, uint16Size) - var err error - var values []int16 + var ( + err error + values []int16 + ) for { var n int + n, err = buffer.Read(element) if n != int16Size || err != nil { break @@ -448,11 +502,14 @@ func ReadUint32ArrayFromBytes(data []byte) ([]uint32, error) { buffer := bytes.NewReader(data) element := make([]byte, int32Size) - var err error - var values []uint32 + var ( + err error + values []uint32 + ) for { var n int + n, err = buffer.Read(element) if n != uint32Size || err != nil { break @@ -473,11 +530,14 @@ func ReadInt32ArrayFromBytes(data []byte) ([]int32, error) { buffer := bytes.NewReader(data) element := make([]byte, int32Size) - var err error - var values []int32 + var ( + err error + values []int32 + ) for { var n int + n, err = buffer.Read(element) if n != int32Size || err != nil { break @@ -498,11 +558,14 @@ func ReadUint64ArrayFromBytes(data []byte) ([]uint64, error) { buffer := bytes.NewReader(data) element := make([]byte, int32Size) - var err error - var values []uint64 + var ( + err error + values []uint64 + ) for { var n int + n, err = buffer.Read(element) if n != uint64Size || err != nil { break @@ -523,11 +586,14 @@ func ReadInt64ArrayFromBytes(data []byte) ([]int64, error) { buffer := bytes.NewReader(data) element := make([]byte, int64Size) - var err error - var values []int64 + var ( + err error + values []int64 + ) for { var n int + n, err = buffer.Read(element) if n != int64Size || err != nil { break @@ -543,6 +609,17 @@ func ReadInt64ArrayFromBytes(data []byte) ([]int64, error) { return values, nil } +// Int32ArrayToBoolArray converts an int32 array to a bool array. +// When the value is equal to 1 the boolean is considered to be true. +func Int32ArrayToBoolArray(arr []int32) []bool { + newArr := make([]bool, len(arr)) + for i, value := range arr { + newArr[i] = value == 1 + } + + return newArr +} + // Int32ArrayToInt8Array converts an int32 array to an int8 array. func Int32ArrayToInt8Array(arr []int32) []int8 { newArr := make([]int8, len(arr)) diff --git a/ops/activation.go b/ops/activation.go index 89aca38..1fdee54 100644 --- a/ops/activation.go +++ b/ops/activation.go @@ -1,10 +1,28 @@ package ops -import "gorgonia.org/tensor" +import ( + "gorgonia.org/tensor" +) // Activation is an activation function. type Activation func(n tensor.Tensor) (tensor.Tensor, error) +// activations maps strings to the activation function. This is +// used by operators like LSTM, GRU and RNN. +var activations = map[string]Activation{ + "tanh": Tanh, + "sigmoid": Sigmoid, + "relu": ReLU, +} + +func GetActivation(activation string) (Activation, error) { + if a, ok := activations[activation]; ok { + return a, nil + } + + return nil, ErrActivationNotImplemented(activation) +} + // Tanh performs the tanh operation on a tensor. func Tanh(X tensor.Tensor) (tensor.Tensor, error) { return tensor.Tanh(X) @@ -34,3 +52,18 @@ func Sigmoid(X tensor.Tensor) (tensor.Tensor, error) { return tensor.Div(typedOne, numeratorX) } + +// ReLU performs the ReLU operation on a tensor. +func ReLU(X tensor.Tensor) (tensor.Tensor, error) { + typedZero, err := GetValueAsTensorType(0.0, X.Dtype()) + if err != nil { + return nil, err + } + + comparison, err := tensor.Gt(X, typedZero, tensor.AsSameType()) + if err != nil { + return nil, err + } + + return tensor.Mul(X, comparison) +} diff --git a/ops/binary_op.go b/ops/binary_op.go new file mode 100644 index 0000000..3df36e8 --- /dev/null +++ b/ops/binary_op.go @@ -0,0 +1,161 @@ +package ops + +import ( + "gorgonia.org/tensor" +) + +// BinaryOp describes a general operation between 2 tensors with 1 tensor as result. +type BinaryOp func(A, B tensor.Tensor) (tensor.Tensor, error) + +// ApplyBinaryOperation applies a binary operation (an operation of arity 2) to 2 tensors. +// It returns a list of tensors with only 1 output tensor in order for this function to +// be easily used in operators. +func ApplyBinaryOperation(A, B tensor.Tensor, op BinaryOp, broadcastOption BroadcastType) ([]tensor.Tensor, error) { + var err error + + switch broadcastOption { + case NoBroadcasting: + break + case UnidirectionalBroadcasting: + A, B, err = UnidirectionalBroadcast(A, B) + if err != nil { + return nil, err + } + case MultidirectionalBroadcasting: + A, B, err = MultidirectionalBroadcast(A, B) + if err != nil { + return nil, err + } + } + + out, err := op(A, B) + + return []tensor.Tensor{out}, err +} + +// Add adds 2 tensors to each other. +func Add(A, B tensor.Tensor) (tensor.Tensor, error) { + return tensor.Add(A, B) +} + +// Div divides 1 tensor by the other. +func Div(A, B tensor.Tensor) (tensor.Tensor, error) { + return tensor.Div(A, B) +} + +// Mul multiplies 2 tensors with each other. +func Mul(A, B tensor.Tensor) (tensor.Tensor, error) { + return tensor.Mul(A, B) +} + +// Sub subtracts 1 tensor from the other. +func Sub(A, B tensor.Tensor) (tensor.Tensor, error) { + return tensor.Sub(A, B) +} + +// Or applies the boolean 'or' operation on 2 tensors. +func Or(A, B tensor.Tensor) (tensor.Tensor, error) { + return applyBooleanBinaryOperator( + A, + B, + func(a, b bool) bool { return a || b }, + ) +} + +// Gt applies the greater than (>) operator on 2 tensors. +func Gt(A, B tensor.Tensor) (tensor.Tensor, error) { + return tensor.Gt(A, B) +} + +// Gte applies the greater or equal than (>=) operator on 2 tensors. +func Gte(A, B tensor.Tensor) (tensor.Tensor, error) { + return tensor.Gte(A, B) +} + +// Lt applies the less than (<) operator on 2 tensors. +func Lt(A, B tensor.Tensor) (tensor.Tensor, error) { + return tensor.Lt(A, B) +} + +// Lte applies the less or equal than (<=) operator on 2 tensors. +func Lte(A, B tensor.Tensor) (tensor.Tensor, error) { + return tensor.Lte(A, B) +} + +// Equal applies the equal operator (=) operator on 2 tensors. +func Equal(A, B tensor.Tensor) (tensor.Tensor, error) { + return tensor.ElEq(A, B) +} + +// And applies the boolean 'and' operation on 2 tensors. +func And(A, B tensor.Tensor) (tensor.Tensor, error) { + return applyBooleanBinaryOperator( + A, + B, + func(a, b bool) bool { return a && b }, + ) +} + +// Xor applies the boolean 'xor' operation on 2 tensors. +func Xor(A, B tensor.Tensor) (tensor.Tensor, error) { + return applyBooleanBinaryOperator( + A, + B, + func(a, b bool) bool { return a != b }, + ) +} + +// BooleanOp describes a binary operation between two booleans that also returns a boolean. +type BooleanOp func(a, b bool) bool + +// ApplyBooleanOperator is a function that applies a boolean operator element-wise to +// to 2 tensors. This assumes that A and B have exactly the same shape. +// We create an iterator that loops over all elements of A (which can also be used for B). +// Using this iterator, the given boolean operator is applied to all pairs of elements from +// A and B and the result is returned. +func applyBooleanBinaryOperator(A, B tensor.Tensor, op BooleanOp) (tensor.Tensor, error) { + A, B, err := MultidirectionalBroadcast(A, B) + if err != nil { + return nil, err + } + + output := tensor.NewDense(tensor.Bool, A.Shape()) + output.Zero() + + iterator := A.Iterator() + iterator.Reset() + + for !iterator.Done() { + valA, err := A.At(iterator.Coord()...) + if err != nil { + return nil, err + } + + boolA, ok := valA.(bool) + if !ok { + return nil, ErrTypeAssert("bool", valA) + } + + valB, err := B.At(iterator.Coord()...) + if err != nil { + return nil, err + } + + boolB, ok := valB.(bool) + if !ok { + return nil, ErrTypeAssert("bool", valB) + } + + err = output.SetAt(op(boolA, boolB), iterator.Coord()...) + if err != nil { + return nil, err + } + + _, err = iterator.Next() + if err != nil { + return nil, err + } + } + + return output, nil +} diff --git a/ops/convert.go b/ops/convert.go index 4bb74c2..0637f49 100644 --- a/ops/convert.go +++ b/ops/convert.go @@ -1,9 +1,7 @@ package ops import ( - "fmt" - - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/onnx" "gorgonia.org/tensor" ) @@ -14,8 +12,11 @@ type Number interface { // ConvertTensorDtype converts an interface of a specific dtype to a new dtype. func ConvertTensorDtype(t tensor.Tensor, newType int32) (tensor.Tensor, error) { - var err error - var newBacking any + var ( + err error + newBacking any + ) + backing := IfScalarToSlice(t.Data()) switch t.Dtype() { @@ -40,7 +41,7 @@ func ConvertTensorDtype(t tensor.Tensor, newType int32) (tensor.Tensor, error) { case tensor.Uint64: newBacking, err = convertBacking(backing.([]uint64), newType) default: - return nil, fmt.Errorf("unable to convert tensor of type %v to type %v", t.Dtype(), newType) + return nil, ErrConversionInvalidType(t.Dtype(), newType) } if err != nil { @@ -72,8 +73,10 @@ func convertBacking[B Number](backing []B, dataType int32) (any, error) { return createNewBacking[B, uint32](backing), nil case onnx.TensorProto_UINT64: return createNewBacking[B, uint64](backing), nil + case onnx.TensorProto_BFLOAT16, onnx.TensorProto_BOOL, onnx.TensorProto_COMPLEX64, onnx.TensorProto_COMPLEX128, onnx.TensorProto_FLOAT16, onnx.TensorProto_UNDEFINED, onnx.TensorProto_STRING: + return nil, ErrConversionNotSupported(dataType) default: - return nil, fmt.Errorf("converting to onnx datatype %d is not supported yet", dataType) + return nil, ErrConversionNotSupported(dataType) } } diff --git a/ops/convert_test.go b/ops/convert_test.go index 8154e85..400f67e 100644 --- a/ops/convert_test.go +++ b/ops/convert_test.go @@ -1,7 +1,6 @@ package ops import ( - "fmt" "testing" "github.com/stretchr/testify/assert" @@ -91,13 +90,13 @@ func TestConvertTensorDtype(t *testing.T) { tensor.New(tensor.WithShape(2), tensor.WithBacking([]bool{true, false})), tensor.New(tensor.WithShape(2), tensor.WithBacking([]float32{1.0, 2.0})), 1, - fmt.Errorf("unable to convert tensor of type bool to type 1"), + ErrConversionInvalidType(tensor.Bool, 1), }, { tensor.New(tensor.WithShape(2), tensor.WithBacking([]float32{1.0, 2.1})), nil, 8, - fmt.Errorf("converting to onnx datatype 8 is not supported yet"), + ErrConversionNotSupported(8), }, } @@ -105,6 +104,7 @@ func TestConvertTensorDtype(t *testing.T) { out, err := ConvertTensorDtype(test.tensorIn, test.newType) assert.Equal(t, test.err, err) + if test.err != nil { continue } diff --git a/ops/errors.go b/ops/errors.go index ab25387..e7f01f3 100644 --- a/ops/errors.go +++ b/ops/errors.go @@ -1,16 +1,66 @@ package ops -// UnknownAttributeErrTemplate is used to format an error -// when an operator finds an unknown attribute during its initialization. -const UnknownAttributeErrTemplate = "%v: unknown attribute: %v" +import ( + "errors" + "fmt" + "reflect" -// UnsupportedAttrErrTemplate is used to format an error when an operator receives -// an attribute that is not supported yet. -const UnsupportedAttrErrTemplate = "%v: %v attribute not supported yet" + "gorgonia.org/tensor" +) -// InvalidAttrCountErrTemplate is used to format an error when an operator -// got the wrong amount of attributes. -const InvalidAttrCountErrTemplate = "%v: expected %v attributes, got %d" +type AttributeErrorKind string + +const ( + AttributeErrorCount AttributeErrorKind = "count" + AttributeErrorInvalid AttributeErrorKind = "invalid" + AttributeErrorUnsupported AttributeErrorKind = "unsupported" +) + +type AttributeError struct { + kind AttributeErrorKind + attributeCount int + expectedCount int + attributeName string + operator Operator +} + +func (t *AttributeError) Error() string { + switch t.kind { + case AttributeErrorCount: + return fmt.Sprintf("%s attribute error: invalid count %d expected %d", t.operator.String(), t.attributeCount, t.expectedCount) + case AttributeErrorInvalid: + return fmt.Sprintf("%s attribute error: invalid attribute %s", t.operator.String(), t.attributeName) + case AttributeErrorUnsupported: + return fmt.Sprintf("%s attribute error: unsupported attribute %s", t.operator.String(), t.attributeName) + default: + return fmt.Sprintf("%s unknown error attribute error kind %s", t.operator.String(), t.kind) + } +} + +func ErrInvalidAttribute(attributeName string, operator Operator) *AttributeError { + return &AttributeError{attributeName: attributeName, kind: "invalid", operator: operator} +} + +func ErrInvalidAttributeCount(expected, actual int, operator Operator) error { + return &AttributeError{attributeCount: actual, expectedCount: expected, kind: "count", operator: operator} +} + +func ErrUnsupportedAttribute(attributeName string, operator Operator) error { + return &AttributeError{attributeName: attributeName, kind: "unsupported", operator: operator} +} + +type TypeAssertError struct { + expectedType string + actualType any +} + +func (t *TypeAssertError) Error() string { + return fmt.Sprintf("type assert error: expected %v, got %v", t.expectedType, reflect.TypeOf(t.actualType)) +} + +func ErrTypeAssert(expected string, actual any) error { + return &TypeAssertError{expectedType: expected, actualType: actual} +} // InvalidInputCountErrTemplate is used to format an error when an operator got // the wrong amount of input tensors. @@ -20,21 +70,205 @@ const InvalidInputCountErrTemplate = "%v: expected %d input tensors, got %d" // the wrong amount of input tensors when optional inputs are present. const InvalidOptionalInputCountErrTemplate = "%v: expected %d-%d input tensors, got %d" -// UnknowOpTypeErrTemplate is used to format an error when the operator type is unknown. -const UnknowOpTypeErrTemplate = "unknown operator type: %v" +// UnsupportedInputErrTemplate is used to format an error when an operator got +// the wrong amount of input tensors when optional inputs are present. +const UnsupportedInputErrTemplate = "unsupported input for %v: %v" + +// InvalidInputErrTemplate is used to format an error when an operator got +// an invalid input tensor as input. +const InvalidInputErrTemplate = "invalid input tensor for %v: %v" + +type InputErrorKind string + +const ( + InputErrorType InputErrorKind = "type" + InputErrorCount InputErrorKind = "count" + InputErrorUnsupported InputErrorKind = "unsupported" + InputErrorInvalid InputErrorKind = "invalid" +) + +type InputError struct { + kind InputErrorKind + operator Operator + reason string + + // Attributes for input type error. + inputNumber int + actualType string + + // Attributes for input count error. + hasOptionalInputs bool + actualCount int + + // Attributes for unsupported input error. + inputName string +} + +func (i *InputError) Error() string { + switch i.kind { + case InputErrorType: + return fmt.Sprintf("input %d for op %v does not allow dtype %v", i.inputNumber, i.operator, i.actualType) + case InputErrorCount: + if i.hasOptionalInputs { + return fmt.Sprintf(InvalidOptionalInputCountErrTemplate, i.operator, i.operator.GetMinInputs(), i.operator.GetMaxInputs(), i.actualCount) + } + + return fmt.Sprintf(InvalidInputCountErrTemplate, i.operator, i.operator.GetMinInputs(), i.actualCount) + case InputErrorUnsupported: + return fmt.Sprintf(UnsupportedInputErrTemplate, i.operator, i.inputName) + case InputErrorInvalid: + return fmt.Sprintf(InvalidInputErrTemplate, i.operator, i.reason) + default: + return fmt.Sprintf("%s unknown error input error kind %s", i.operator.String(), i.kind) + } +} + +func ErrInvalidInputType(inputNumber int, dType string, operator Operator) error { + return &InputError{ + kind: InputErrorType, + operator: operator, + inputNumber: inputNumber, + actualType: dType, + } +} + +func ErrInvalidInputCount(actual int, operator Operator) error { + return &InputError{ + kind: InputErrorCount, + actualCount: actual, + operator: operator, + } +} + +func ErrInvalidOptionalInputCount(actual int, operator Operator) error { + return &InputError{ + kind: InputErrorCount, + hasOptionalInputs: true, + actualCount: actual, + operator: operator, + } +} + +func ErrUnsupportedInput(inputName string, operator Operator) error { + return &InputError{ + kind: InputErrorUnsupported, + inputName: inputName, + operator: operator, + } +} + +func ErrInvalidInput(reason string, operator Operator) error { + return &InputError{ + kind: InputErrorInvalid, + reason: reason, + operator: operator, + } +} + +type BroadcastError struct { + broadcastType string + shapeA tensor.Shape + shapeB tensor.Shape + err error +} + +func (b *BroadcastError) Error() string { + return fmt.Sprintf("%v: could not perform %v, inputs with shape %d and %d.", b.err, b.broadcastType, b.shapeA, b.shapeB) +} + +func ErrMultidirBroadcast(shapeA, shapeB tensor.Shape, err error) error { + return &BroadcastError{ + broadcastType: "multidirectional broadcast", + shapeA: shapeA, + shapeB: shapeB, + err: err, + } +} + +func ErrUnidirBroadcast(shapeA, shapeB tensor.Shape) error { + return &BroadcastError{ + broadcastType: "Unidirectional broadcast", + shapeA: shapeA, + shapeB: shapeB, + } +} + +type InvalidTensorError struct { + reason string + operator Operator +} + +func (i *InvalidTensorError) Error() string { + return fmt.Sprintf("%v invalid tensor found, reason: %s", i.operator.String(), i.reason) +} + +func ErrInvalidTensor(reason string, operator Operator) error { + return &InvalidTensorError{reason: reason, operator: operator} +} + +var ErrUnsupportedOperator = errors.New("unsupported operator") + +func ErrUnknownOperatorType(operatorType string) error { + return fmt.Errorf("%w: %s", ErrUnsupportedOperator, operatorType) +} + +var ErrAxisNotInRange = errors.New("axis out of range") + +func ErrNotAllAxesInRange(min, max int) error { + return fmt.Errorf("%w: all indices entries must be in the range -%d <= x < %d", ErrAxisNotInRange, min, max) +} + +func ErrAxisOutOfRange(min, max, actual int) error { + return fmt.Errorf("%w: axis argument must be in the range -%d <= x < %d, was %d", ErrAxisNotInRange, min, max, actual) +} + +var ErrUnsupportedOpsetVersion = errors.New("unsupported opset version") + +type DimensionErrorKind string + +const ( + DimensionErrorIncompatible DimensionErrorKind = "incompatible" +) + +type DimensionError struct { + kind DimensionErrorKind + reason string +} + +func (d *DimensionError) Error() string { + switch d.kind { + case DimensionErrorIncompatible: + return fmt.Sprintf("dimensions error: incompatible dimensions") + default: + return fmt.Sprintf("dimension error: %s", d.reason) + } +} + +func ErrIncompatibleDimensions() error { + return &DimensionError{kind: DimensionErrorIncompatible, reason: ""} +} + +func ErrDimension(reason string) error { + return &DimensionError{reason: reason} +} + +var ( + ErrCast = errors.New("cast error") + ErrInvalidShape = errors.New("invalid shape error") +) + +var ErrConversion = errors.New("unable to convert") -// MultidirBroadcastErrTemplate is used to format an error when two inputs cannot be -// broadcasted together with Multidirectional broadcasting. -const MultidirBroadcastErrTemplate = "could not multidir broadcast inputs with shape %d and %d: %v" +func ErrConversionInvalidType(dType tensor.Dtype, newType int32) error { + return fmt.Errorf("%w: type %v, to %v is invalid", ErrConversion, dType, newType) +} -// UnidirBroadcastErrTemplate is used to format an error when two inputs cannot be -// broadcasted together with Unidirectional broadcasting. -const UnidirBroadcastErrTemplate = "could not unidir broadcast inputs with shape %d and %d" +func ErrConversionNotSupported(dType int32) error { + return fmt.Errorf("%w: to %v is not supported yet", ErrConversion, dType) +} -// AxisOutOfRangeErrTemplate is used to format an error when an given axis is out of range -// given a certain rank. -const AxisOutOfRangeErrTemplate = "axis argument must be in the range -%d <= x < %d, was %d" +var ErrActivationNotImplementedBase = errors.New("the given activation function is not implemented") -// AxesNotAllInRangeErrTemplate is used to format an error when not all indices -// are within a given range. -const AxesNotAllInRangeErrTemplate = "all indices entries must be in the range -%d <= x < %d" +func ErrActivationNotImplemented(activation string) error { + return fmt.Errorf("%w: %s", ErrActivationNotImplementedBase, activation) +} diff --git a/ops/fixtures.go b/ops/fixtures.go index 12d552a..0e731da 100644 --- a/ops/fixtures.go +++ b/ops/fixtures.go @@ -1,10 +1,13 @@ package ops import ( + "math/rand" + + "github.com/advancedclimatesystems/gonnx/onnx" "gorgonia.org/tensor" ) -// InputFixture is a function that generates inputs for ops. Useful in testing +// InputFixture is a function that generates inputs for ops. Useful in testing. type InputFixture func() []tensor.Tensor // Float32TensorFixture returns a float32 backed gorgonia node. It initializes all its values @@ -16,6 +19,18 @@ func Float32TensorFixture(shp ...int) tensor.Tensor { ) } +func RandomFloat32TensorFixture(shp ...int) tensor.Tensor { + rands := make([]float32, NElements(shp...)) + for i := 0; i < NElements(shp...); i++ { + rands[i] = rand.Float32() + } + + return tensor.New( + tensor.WithShape(shp...), + tensor.WithBacking(rands), + ) +} + // TensorWithBackingFixture returns a gorgonia node with a tensor using the given backing. func TensorWithBackingFixture(b interface{}, shp ...int) tensor.Tensor { return tensor.New(tensor.WithShape(shp...), tensor.WithBacking(b)) @@ -30,3 +45,8 @@ func TensorInputsFixture(nTensors int) []tensor.Tensor { return result } + +// EmptyNodeProto returns a node proto with no attributes. +func EmptyNodeProto() *onnx.NodeProto { + return &onnx.NodeProto{Attribute: []*onnx.AttributeProto{}} +} diff --git a/ops/multidir_broadcast.go b/ops/multidir_broadcast.go index 9b81961..5565bd9 100644 --- a/ops/multidir_broadcast.go +++ b/ops/multidir_broadcast.go @@ -1,8 +1,6 @@ package ops import ( - "fmt" - "gorgonia.org/tensor" ) @@ -11,12 +9,12 @@ import ( func MultidirectionalBroadcast(A, B tensor.Tensor) (tensor.Tensor, tensor.Tensor, error) { newA, newB, err := ReshapeTensorsForMultidirBroadcast(A, B) if err != nil { - return nil, nil, fmt.Errorf(MultidirBroadcastErrTemplate, A.Shape(), B.Shape(), err) + return nil, nil, ErrMultidirBroadcast(A.Shape(), B.Shape(), err) } newA, newB, err = repeatTensorsForMutltidirBroadcast(newA, newB) if err != nil { - return nil, nil, fmt.Errorf(MultidirBroadcastErrTemplate, A.Shape(), B.Shape(), err) + return nil, nil, ErrMultidirBroadcast(A.Shape(), B.Shape(), err) } return newA, newB, nil @@ -38,12 +36,14 @@ func ReshapeTensorsForMultidirBroadcast(A, B tensor.Tensor) (tensor.Tensor, tens if err != nil { return nil, nil, err } + return A, newB, nil case nDimsB > nDimsA: newA, err := addExtraDimsToTensor(A, nDimsB-nDimsA) if err != nil { return nil, nil, err } + return newA, B, nil default: return A, B, nil @@ -55,9 +55,11 @@ func ReshapeTensorsForMultidirBroadcast(A, B tensor.Tensor) (tensor.Tensor, tens // the dimension of the other. If both sizes are not 1, the tensors cannot be broadcasted to // each other. It is assumed that both tensors are reshaped accordingly first. // Example: -// shapeA=(1, 3, 4) and shapeB=(2, 3, 1) yields shapeNewA=(2, 3, 4) and shapeNewB=(2, 3, 4). +// +// shapeA=(1, 3, 4) and shapeB=(2, 3, 1) yields shapeNewA=(2, 3, 4) and shapeNewB=(2, 3, 4). func repeatTensorsForMutltidirBroadcast(A, B tensor.Tensor) (tensor.Tensor, tensor.Tensor, error) { var err error + shapeA := A.Shape() shapeB := B.Shape() nDims := len(shapeA) @@ -73,13 +75,15 @@ func repeatTensorsForMutltidirBroadcast(A, B tensor.Tensor) (tensor.Tensor, tens if err != nil { return nil, nil, err } + case sizeDimB == 1: B, err = tensor.Repeat(B, axis, sizeDimA) if err != nil { return nil, nil, err } + default: - return nil, nil, fmt.Errorf("incompatible dimensions") + return nil, nil, ErrIncompatibleDimensions() } } } @@ -91,15 +95,22 @@ func repeatTensorsForMutltidirBroadcast(A, B tensor.Tensor) (tensor.Tensor, tens // All extra dimensions are given size one (otherwise the tensor cannot be reshaped). // The given tensor is cloned such that the tensor is not modified in place. // Example: if we add 2 extra dimensions to shape (2, 3) we get shape (1, 1, 2, 3). -func addExtraDimsToTensor(t tensor.Tensor, nExtraDims int) (tensor.Tensor, error) { - t = t.Clone().(tensor.Tensor) +func addExtraDimsToTensor(originalT tensor.Tensor, nExtraDims int) (tensor.Tensor, error) { + t, ok := originalT.Clone().(tensor.Tensor) + if !ok { + return nil, ErrTypeAssert("tensor.Tensor", originalT.Clone()) + } - var newShape []int + newShape := []int{} for i := 0; i < nExtraDims; i++ { newShape = append(newShape, 1) } + newShape = append(newShape, t.Shape()...) - err := t.Reshape(newShape...) - return t, err + if err := t.Reshape(newShape...); err != nil { + return nil, err + } + + return t, nil } diff --git a/ops/multidir_broadcast_test.go b/ops/multidir_broadcast_test.go index c0f454e..02e5cac 100644 --- a/ops/multidir_broadcast_test.go +++ b/ops/multidir_broadcast_test.go @@ -1,7 +1,6 @@ package ops import ( - "fmt" "testing" "github.com/stretchr/testify/assert" @@ -47,22 +46,12 @@ func TestMultidirectionalBroadcast(t *testing.T) { { [][]int{{1, 4, 5}, {2, 1, 1, 3}}, nil, - fmt.Errorf( - MultidirBroadcastErrTemplate, - []int{1, 4, 5}, - []int{2, 1, 1, 3}, - "incompatible dimensions", - ), + ErrMultidirBroadcast([]int{1, 4, 5}, []int{2, 1, 1, 3}, ErrIncompatibleDimensions()), }, { [][]int{{5}, {2, 3, 4}}, nil, - fmt.Errorf( - MultidirBroadcastErrTemplate, - []int{5}, - []int{2, 3, 4}, - "incompatible dimensions", - ), + ErrMultidirBroadcast([]int{5}, []int{2, 3, 4}, ErrIncompatibleDimensions()), }, } @@ -73,6 +62,7 @@ func TestMultidirectionalBroadcast(t *testing.T) { newA, newB, err := MultidirectionalBroadcast(A, B) assert.Equal(t, test.err, err) + if err == nil { assert.Equal(t, test.expectedShape, newA.Shape()) assert.Equal(t, test.expectedShape, newB.Shape()) diff --git a/ops/operator.go b/ops/operator.go index 82b7486..7f26e4d 100644 --- a/ops/operator.go +++ b/ops/operator.go @@ -1,7 +1,7 @@ package ops import ( - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/onnx" "gorgonia.org/tensor" ) @@ -10,10 +10,11 @@ type Operator interface { // String should return a simple string describing the operator String() string - // Init should initialize the operator based on the given attributes. How these + // Init should initialize the operator based on the given node. + // This node contains attributes, which outputs are expected and more. How these // attributes influence the operator is defined by the ONNX standard, and can be // found in https://github.com/onnx/onnx/blob/main/docs/Operators.md - Init([]*onnx.AttributeProto) error + Init(*onnx.NodeProto) error // Apply should apply the operator to the list of input tensors. It should return a // list with output tensors, the result of the operator. diff --git a/ops/opset13/abs.go b/ops/opset13/abs.go new file mode 100644 index 0000000..482d80e --- /dev/null +++ b/ops/opset13/abs.go @@ -0,0 +1,63 @@ +package opset13 + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +const ( + MinAbsInputs = 1 + MaxAbsInputs = 1 +) + +// Abs represents the ONNX abs operator. +type Abs struct{} + +// newAbs creates a new abs operator. +func newAbs() ops.Operator { + return &Abs{} +} + +// Init initializes the abs operator. +func (a *Abs) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the abs operator. +func (a *Abs) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + out, err := tensor.Abs(inputs[0]) + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (a *Abs) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(a, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (a *Abs) GetMinInputs() int { + return MinAbsInputs +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (a *Abs) GetMaxInputs() int { + return MaxAbsInputs +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (a *Abs) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{ + {tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64, tensor.Int8, tensor.Int16, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, + } +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (a *Abs) String() string { + return "abs operator" +} diff --git a/ops/opset13/abs_test.go b/ops/opset13/abs_test.go new file mode 100644 index 0000000..e9e0791 --- /dev/null +++ b/ops/opset13/abs_test.go @@ -0,0 +1,144 @@ +package opset13 + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestAbsInit(t *testing.T) { + a := &Abs{} + + // since 'abs' does not have any attributes we pass in nil. This should not + // fail initializing the abs. + err := a.Init(nil) + assert.Nil(t, err) +} + +func TestAbs(t *testing.T) { + tests := []struct { + abs *Abs + backing []float32 + shape []int + expected []float32 + }{ + { + &Abs{}, + []float32{-2, -1, 0, 1}, + []int{2, 2}, + []float32{2, 1, 0, 1}, + }, + { + &Abs{}, + []float32{1, 3, 4, 5}, + []int{1, 4}, + []float32{1, 3, 4, 5}, + }, + { + &Abs{}, + []float32{-1, -1, -1, -1}, + []int{1, 4}, + []float32{1, 1, 1, 1}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backing, test.shape...), + } + + res, err := test.abs.Apply(inputs) + assert.Nil(t, err) + + assert.Nil(t, err) + assert.Equal(t, test.expected, res[0].Data()) + } +} + +func TestInputValidationAbs(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint8{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint16{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint32{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint64{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int8{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int16{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int64{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float64{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{}, + ops.ErrInvalidInputCount(0, &Abs{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int", &Abs{}), + }, + } + + for _, test := range tests { + abs := &Abs{} + validated, err := abs.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + assert.Equal(t, test.inputs, validated) + } +} diff --git a/ops/opset13/acos.go b/ops/opset13/acos.go new file mode 100644 index 0000000..139c1ed --- /dev/null +++ b/ops/opset13/acos.go @@ -0,0 +1,75 @@ +package opset13 + +import ( + "math" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +// Acos represents the ONNX acos operator. +type Acos struct{} + +// newAcos creates a new acos operator. +func newAcos() ops.Operator { + return &Acos{} +} + +// Init initializes the acos operator. +func (c *Acos) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the acos operator. +func (c *Acos) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var ( + out tensor.Tensor + err error + ) + + switch inputs[0].Dtype() { + case tensor.Float32: + out, err = inputs[0].Apply(acos[float32]) + case tensor.Float64: + out, err = inputs[0].Apply(acos[float64]) + default: + return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), c) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (c *Acos) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(c, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (c *Acos) GetMinInputs() int { + return 1 +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (c *Acos) GetMaxInputs() int { + return 1 +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (c *Acos) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (c *Acos) String() string { + return "acos operator" +} + +func acos[T ops.FloatType](x T) T { + return T(math.Acos(float64(x))) +} diff --git a/ops/opset13/acos_test.go b/ops/opset13/acos_test.go new file mode 100644 index 0000000..e2c755a --- /dev/null +++ b/ops/opset13/acos_test.go @@ -0,0 +1,99 @@ +package opset13 + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestAcosInit(t *testing.T) { + c := &Acos{} + + // since 'acos' does not have any attributes we pass in nil. This should not + // fail initializing the acos. + err := c.Init(nil) + assert.Nil(t, err) +} + +func TestAcos(t *testing.T) { + tests := []struct { + acos *Acos + backing []float32 + shape []int + expected []float32 + }{ + { + &Acos{}, + []float32{-1, -1, 0, 1}, + []int{2, 2}, + []float32{3.1415927, 3.1415927, 1.5707964, 0}, + }, + { + &Acos{}, + []float32{1, 0.5, 0.0, -0.5}, + []int{1, 4}, + []float32{0, 1.0471976, 1.5707964, 2.0943952}, + }, + { + &Acos{}, + []float32{-1, -1, -1, -1}, + []int{1, 4}, + []float32{3.1415927, 3.1415927, 3.1415927, 3.1415927}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backing, test.shape...), + } + + res, err := test.acos.Apply(inputs) + assert.Nil(t, err) + + assert.Nil(t, err) + assert.Equal(t, test.expected, res[0].Data()) + } +} + +func TestInputValidationAcos(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float64{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{}, + ops.ErrInvalidInputCount(0, &Acos{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int", &Acos{}), + }, + } + + for _, test := range tests { + acos := &Acos{} + validated, err := acos.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + assert.Equal(t, test.inputs, validated) + } + } +} diff --git a/ops/opset13/acosh.go b/ops/opset13/acosh.go new file mode 100644 index 0000000..0e1404c --- /dev/null +++ b/ops/opset13/acosh.go @@ -0,0 +1,75 @@ +package opset13 + +import ( + "math" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +// Acosh represents the ONNX acosh operator. +type Acosh struct{} + +// newAcosh creates a new acosh operator. +func newAcosh() ops.Operator { + return &Acosh{} +} + +// Init initializes the acosh operator. +func (c *Acosh) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the acosh operator. +func (c *Acosh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var ( + out tensor.Tensor + err error + ) + + switch inputs[0].Dtype() { + case tensor.Float32: + out, err = inputs[0].Apply(acosh[float32]) + case tensor.Float64: + out, err = inputs[0].Apply(acosh[float64]) + default: + return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), c) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (c *Acosh) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(c, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (c *Acosh) GetMinInputs() int { + return 1 +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (c *Acosh) GetMaxInputs() int { + return 1 +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (c *Acosh) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (c *Acosh) String() string { + return "acosh operator" +} + +func acosh[T ops.FloatType](x T) T { + return T(math.Acosh(float64(x))) +} diff --git a/ops/opset13/acosh_test.go b/ops/opset13/acosh_test.go new file mode 100644 index 0000000..d6c155d --- /dev/null +++ b/ops/opset13/acosh_test.go @@ -0,0 +1,99 @@ +package opset13 + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestAcoshInit(t *testing.T) { + c := &Acosh{} + + // since 'acosh' does not have any attributes we pass in nil. This should not + // fail initializing the acosh. + err := c.Init(nil) + assert.Nil(t, err) +} + +func TestAcosh(t *testing.T) { + tests := []struct { + acosh *Acosh + backing []float32 + shape []int + expected []float32 + }{ + { + &Acosh{}, + []float32{1, 2, 3, 4}, + []int{2, 2}, + []float32{0, 1.316958, 1.7627472, 2.063437}, + }, + { + &Acosh{}, + []float32{1, 2, 3, 4}, + []int{1, 4}, + []float32{0, 1.316958, 1.7627472, 2.063437}, + }, + { + &Acosh{}, + []float32{2, 2, 2, 2}, + []int{1, 4}, + []float32{1.316958, 1.316958, 1.316958, 1.316958}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backing, test.shape...), + } + + res, err := test.acosh.Apply(inputs) + assert.Nil(t, err) + + assert.Nil(t, err) + assert.Equal(t, test.expected, res[0].Data()) + } +} + +func TestInputValidationAcosh(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float64{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{}, + ops.ErrInvalidInputCount(0, &Acosh{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int", &Acosh{}), + }, + } + + for _, test := range tests { + acosh := &Acosh{} + validated, err := acosh.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + assert.Equal(t, test.inputs, validated) + } + } +} diff --git a/ops/opset13/add.go b/ops/opset13/add.go index fbd9f2b..cf5b566 100644 --- a/ops/opset13/add.go +++ b/ops/opset13/add.go @@ -1,11 +1,16 @@ package opset13 import ( - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" ) +const ( + MinAddInputs = 2 + MaxAddInputs = 2 +) + // Add represents the ONNX add operator. type Add struct{} @@ -15,23 +20,18 @@ func newAdd() ops.Operator { } // Init initializes the add operator. -func (a *Add) Init(attributes []*onnx.AttributeProto) error { +func (a *Add) Init(*onnx.NodeProto) error { return nil } // Apply applies the add operator. func (a *Add) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - in1, in2, err := ops.MultidirectionalBroadcast(inputs[0], inputs[1]) - if err != nil { - return nil, err - } - - out, err := tensor.Add(in1, in2) - if err != nil { - return nil, err - } - - return []tensor.Tensor{out}, nil + return ops.ApplyBinaryOperation( + inputs[0], + inputs[1], + ops.Add, + ops.MultidirectionalBroadcasting, + ) } // ValidateInputs validates the inputs that will be given to Apply for this operator. @@ -41,12 +41,12 @@ func (a *Add) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { // GetMinInputs returns the minimum number of input tensors this operator expects. func (a *Add) GetMinInputs() int { - return 2 + return MinAddInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (a *Add) GetMaxInputs() int { - return 2 + return MaxAddInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/add_test.go b/ops/opset13/add_test.go index 6cec09c..f7dacd1 100644 --- a/ops/opset13/add_test.go +++ b/ops/opset13/add_test.go @@ -1,11 +1,10 @@ package opset13 import ( - "fmt" "testing" + "github.com/advancedclimatesystems/gonnx/ops" "github.com/stretchr/testify/assert" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" "gorgonia.org/tensor" ) @@ -54,7 +53,6 @@ func TestAdd(t *testing.T) { res, err := test.add.Apply(inputs) assert.Nil(t, err) - assert.Nil(t, err) assert.Equal(t, test.expected, res[0].Data()) } } @@ -67,16 +65,7 @@ func TestAddFail(t *testing.T) { add := &Add{} _, err := add.Apply(inputs) - assert.Equal( - t, - err, - fmt.Errorf( - ops.MultidirBroadcastErrTemplate, - []int{2, 2}, - []int{3}, - "incompatible dimensions", - ), - ) + assert.Equal(t, err, ops.ErrMultidirBroadcast(inputs[0].Shape(), inputs[1].Shape(), ops.ErrIncompatibleDimensions())) } func TestInputValidationAdd(t *testing.T) { @@ -130,14 +119,14 @@ func TestInputValidationAdd(t *testing.T) { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - fmt.Errorf("add operator: expected 2 input tensors, got 1"), + ops.ErrInvalidInputCount(1, &Add{}), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), }, - fmt.Errorf("add operator: input 0 does not allow type int"), + ops.ErrInvalidInputType(0, "int", &Add{}), }, } @@ -146,6 +135,7 @@ func TestInputValidationAdd(t *testing.T) { validated, err := add.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } diff --git a/ops/opset13/and.go b/ops/opset13/and.go new file mode 100644 index 0000000..68b2a22 --- /dev/null +++ b/ops/opset13/and.go @@ -0,0 +1,61 @@ +package opset13 + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var ( + MinAndInputs = 2 + MaxAndInputs = 2 +) + +// And represents the ONNX and operator. +type And struct{} + +// newAnd creates a new and operator. +func newAnd() ops.Operator { + return &And{} +} + +// Init initializes the and operator. +func (a *And) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the and operator. +func (a *And) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ApplyBinaryOperation( + inputs[0], + inputs[1], + ops.And, + ops.MultidirectionalBroadcasting, + ) +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (a *And) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(a, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (a *And) GetMinInputs() int { + return MinAndInputs +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (a *And) GetMaxInputs() int { + return MaxAndInputs +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (a *And) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{{tensor.Bool}, {tensor.Bool}} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (a *And) String() string { + return "and operator" +} diff --git a/ops/opset13/and_test.go b/ops/opset13/and_test.go new file mode 100644 index 0000000..b17fc35 --- /dev/null +++ b/ops/opset13/and_test.go @@ -0,0 +1,104 @@ +package opset13 + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestAndInit(t *testing.T) { + a := &And{} + + // since 'and' does not have any attributes we pass in nil. This should not + // fail initializing the and. + err := a.Init(nil) + assert.Nil(t, err) +} + +func TestAnd(t *testing.T) { + tests := []struct { + and *And + backings [][]bool + shapes [][]int + expected []bool + }{ + { + &And{}, + [][]bool{{true, false, true, false}, {true, true, true, false}}, + [][]int{{2, 2}, {2, 2}}, + []bool{true, false, true, false}, + }, + { + &And{}, + [][]bool{{true, false, true, false}, {true, false}}, + [][]int{{2, 2}, {1, 2}}, + []bool{true, false, true, false}, + }, + { + &And{}, + [][]bool{{true, false, true, false}, {true, false}}, + [][]int{{2, 2}, {2, 1}}, + []bool{true, false, false, false}, + }, + { + &And{}, + [][]bool{{true, false, true, false, true, false}, {false, false}}, + [][]int{{3, 2}, {1, 2}}, + []bool{false, false, false, false, false, false}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backings[0], test.shapes[0]...), + ops.TensorWithBackingFixture(test.backings[1], test.shapes[1]...), + } + + res, err := test.and.Apply(inputs) + assert.Nil(t, err) + + assert.Nil(t, err) + assert.Equal(t, test.expected, res[0].Data()) + } +} + +func TestInputValidationAnd(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]bool{false, false}, 2), + ops.TensorWithBackingFixture([]bool{false, false}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]bool{false, false}, 2), + }, + ops.ErrInvalidInputCount(1, &And{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]bool{false, false}, 2), + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputType(1, "int", &And{}), + }, + } + + for _, test := range tests { + and := &And{} + validated, err := and.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + assert.Equal(t, test.inputs, validated) + } + } +} diff --git a/ops/opset13/asin.go b/ops/opset13/asin.go new file mode 100644 index 0000000..0dae65f --- /dev/null +++ b/ops/opset13/asin.go @@ -0,0 +1,75 @@ +package opset13 + +import ( + "math" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +// Asin represents the ONNX asin operator. +type Asin struct{} + +// newSin creates a new asin operator. +func newAsin() ops.Operator { + return &Asin{} +} + +// Init initializes the asin operator. +func (s *Asin) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the asin operator. +func (s *Asin) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var ( + out tensor.Tensor + err error + ) + + switch inputs[0].Dtype() { + case tensor.Float32: + out, err = inputs[0].Apply(asin[float32]) + case tensor.Float64: + out, err = inputs[0].Apply(asin[float64]) + default: + return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), s) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (s *Asin) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(s, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (s *Asin) GetMinInputs() int { + return 1 +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (s *Asin) GetMaxInputs() int { + return 1 +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (s *Asin) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (s *Asin) String() string { + return "asin operator" +} + +func asin[T ops.FloatType](x T) T { + return T(math.Asin(float64(x))) +} diff --git a/ops/opset13/asin_test.go b/ops/opset13/asin_test.go new file mode 100644 index 0000000..c145649 --- /dev/null +++ b/ops/opset13/asin_test.go @@ -0,0 +1,99 @@ +package opset13 + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestAsinInit(t *testing.T) { + s := &Asin{} + + // since 'asin' does not have any attributes we pass in nil. This should not + // fail initializing the asin. + err := s.Init(nil) + assert.Nil(t, err) +} + +func TestAsin(t *testing.T) { + tests := []struct { + asin *Asin + backing []float32 + shape []int + expected []float32 + }{ + { + &Asin{}, + []float32{-1, -1, 0, 1}, + []int{2, 2}, + []float32{-1.5707964, -1.5707964, 0, 1.5707964}, + }, + { + &Asin{}, + []float32{1, 0.5, 0.0, -0.5}, + []int{1, 4}, + []float32{1.5707964, 0.5235988, 0, -0.5235988}, + }, + { + &Asin{}, + []float32{-1, -1, -1, -1}, + []int{1, 4}, + []float32{-1.5707964, -1.5707964, -1.5707964, -1.5707964}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backing, test.shape...), + } + + res, err := test.asin.Apply(inputs) + assert.Nil(t, err) + + assert.Nil(t, err) + assert.Equal(t, test.expected, res[0].Data()) + } +} + +func TestInputValidationAsin(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float64{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{}, + ops.ErrInvalidInputCount(0, &Asin{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int", &Asin{}), + }, + } + + for _, test := range tests { + asin := &Asin{} + validated, err := asin.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + assert.Equal(t, test.inputs, validated) + } + } +} diff --git a/ops/opset13/asinh.go b/ops/opset13/asinh.go new file mode 100644 index 0000000..8490711 --- /dev/null +++ b/ops/opset13/asinh.go @@ -0,0 +1,75 @@ +package opset13 + +import ( + "math" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +// Asinh represents the ONNX asinh operator. +type Asinh struct{} + +// newAsinh creates a new asinh operator. +func newAsinh() ops.Operator { + return &Asinh{} +} + +// Init initializes the asinh operator. +func (a *Asinh) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the asinh operator. +func (a *Asinh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var ( + out tensor.Tensor + err error + ) + + switch inputs[0].Dtype() { + case tensor.Float32: + out, err = inputs[0].Apply(asinh[float32]) + case tensor.Float64: + out, err = inputs[0].Apply(asinh[float64]) + default: + return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), a) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (a *Asinh) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(a, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (a *Asinh) GetMinInputs() int { + return 1 +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (a *Asinh) GetMaxInputs() int { + return 1 +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (a *Asinh) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (a *Asinh) String() string { + return "asinh operator" +} + +func asinh[T ops.FloatType](x T) T { + return T(math.Asinh(float64(x))) +} diff --git a/ops/opset13/asinh_test.go b/ops/opset13/asinh_test.go new file mode 100644 index 0000000..da5c6fc --- /dev/null +++ b/ops/opset13/asinh_test.go @@ -0,0 +1,99 @@ +package opset13 + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestAsinhInit(t *testing.T) { + c := &Asinh{} + + // since 'asinh' does not have any attributes we pass in nil. This should not + // fail initializing the asinh. + err := c.Init(nil) + assert.Nil(t, err) +} + +func TestAsinh(t *testing.T) { + tests := []struct { + asinh *Asinh + backing []float32 + shape []int + expected []float32 + }{ + { + &Asinh{}, + []float32{1, 2, 3, 4}, + []int{2, 2}, + []float32{0.8813736, 1.4436355, 1.8184465, 2.0947125}, + }, + { + &Asinh{}, + []float32{1, 2, 3, 4}, + []int{1, 4}, + []float32{0.8813736, 1.4436355, 1.8184465, 2.0947125}, + }, + { + &Asinh{}, + []float32{2, 2, 2, 2}, + []int{1, 4}, + []float32{1.4436355, 1.4436355, 1.4436355, 1.4436355}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backing, test.shape...), + } + + res, err := test.asinh.Apply(inputs) + assert.Nil(t, err) + + assert.Nil(t, err) + assert.Equal(t, test.expected, res[0].Data()) + } +} + +func TestInputValidationAsinh(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float64{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{}, + ops.ErrInvalidInputCount(0, &Asinh{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int", &Asinh{}), + }, + } + + for _, test := range tests { + asinh := &Asinh{} + validated, err := asinh.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + assert.Equal(t, test.inputs, validated) + } + } +} diff --git a/ops/opset13/atan.go b/ops/opset13/atan.go new file mode 100644 index 0000000..d373d65 --- /dev/null +++ b/ops/opset13/atan.go @@ -0,0 +1,75 @@ +package opset13 + +import ( + "math" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +// Atan represents the ONNX atan operator. +type Atan struct{} + +// newAtan creates a new atan operator. +func newAtan() ops.Operator { + return &Atan{} +} + +// Init initializes the atan operator. +func (a *Atan) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the atan operator. +func (a *Atan) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var ( + out tensor.Tensor + err error + ) + + switch inputs[0].Dtype() { + case tensor.Float32: + out, err = inputs[0].Apply(atan[float32]) + case tensor.Float64: + out, err = inputs[0].Apply(atan[float64]) + default: + return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), a) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (a *Atan) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(a, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (a *Atan) GetMinInputs() int { + return 1 +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (a *Atan) GetMaxInputs() int { + return 1 +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (a *Atan) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (a *Atan) String() string { + return "atan operator" +} + +func atan[T ops.FloatType](x T) T { + return T(math.Atan(float64(x))) +} diff --git a/ops/opset13/atan_test.go b/ops/opset13/atan_test.go new file mode 100644 index 0000000..f6d1d97 --- /dev/null +++ b/ops/opset13/atan_test.go @@ -0,0 +1,99 @@ +package opset13 + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestAtanInit(t *testing.T) { + a := &Atan{} + + // since 'atan' does not have any attributes we pass in nil. This should not + // fail initializing the atan. + err := a.Init(nil) + assert.Nil(t, err) +} + +func TestAtan(t *testing.T) { + tests := []struct { + atan *Atan + backing []float32 + shape []int + expected []float32 + }{ + { + &Atan{}, + []float32{1, 2, 3, 4}, + []int{2, 2}, + []float32{0.7853982, 1.1071488, 1.2490457, 1.3258177}, + }, + { + &Atan{}, + []float32{1, 2, 3, 4}, + []int{1, 4}, + []float32{0.7853982, 1.1071488, 1.2490457, 1.3258177}, + }, + { + &Atan{}, + []float32{2, 2, 2, 2}, + []int{1, 4}, + []float32{1.1071488, 1.1071488, 1.1071488, 1.1071488}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backing, test.shape...), + } + + res, err := test.atan.Apply(inputs) + assert.Nil(t, err) + + assert.Nil(t, err) + assert.Equal(t, test.expected, res[0].Data()) + } +} + +func TestInputValidationAtan(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float64{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{}, + ops.ErrInvalidInputCount(0, &Atan{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int", &Atan{}), + }, + } + + for _, test := range tests { + atan := &Atan{} + validated, err := atan.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + assert.Equal(t, test.inputs, validated) + } + } +} diff --git a/ops/opset13/atanh.go b/ops/opset13/atanh.go new file mode 100644 index 0000000..f60b6d1 --- /dev/null +++ b/ops/opset13/atanh.go @@ -0,0 +1,75 @@ +package opset13 + +import ( + "math" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +// Atanh represents the ONNX atanh operator. +type Atanh struct{} + +// newAtanh creates a new atanh operator. +func newAtanh() ops.Operator { + return &Atanh{} +} + +// Init initializes the atanh operator. +func (a *Atanh) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the atanh operator. +func (a *Atanh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var ( + out tensor.Tensor + err error + ) + + switch inputs[0].Dtype() { + case tensor.Float32: + out, err = inputs[0].Apply(atanh[float32]) + case tensor.Float64: + out, err = inputs[0].Apply(atanh[float64]) + default: + return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), a) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (a *Atanh) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(a, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (a *Atanh) GetMinInputs() int { + return 1 +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (a *Atanh) GetMaxInputs() int { + return 1 +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (a *Atanh) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (a *Atanh) String() string { + return "atanh operator" +} + +func atanh[T ops.FloatType](x T) T { + return T(math.Atanh(float64(x))) +} diff --git a/ops/opset13/atanh_test.go b/ops/opset13/atanh_test.go new file mode 100644 index 0000000..65441a7 --- /dev/null +++ b/ops/opset13/atanh_test.go @@ -0,0 +1,99 @@ +package opset13 + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestAtanhInit(t *testing.T) { + a := &Atanh{} + + // since 'atanh' does not have any attributes we pass in nil. This should not + // fail initializing the atanh. + err := a.Init(nil) + assert.Nil(t, err) +} + +func TestAtanh(t *testing.T) { + tests := []struct { + atanh *Atanh + backing []float32 + shape []int + expected []float32 + }{ + { + &Atanh{}, + []float32{-0.9, -0.5, 0, 0.5}, + []int{2, 2}, + []float32{-1.4722193, -0.54930615, 0, 0.54930615}, + }, + { + &Atanh{}, + []float32{-0.9, -0.5, 0, 0.5}, + []int{1, 4}, + []float32{-1.4722193, -0.54930615, 0, 0.54930615}, + }, + { + &Atanh{}, + []float32{0.5, 0.5, 0.5, 0.5}, + []int{1, 4}, + []float32{0.54930615, 0.54930615, 0.54930615, 0.54930615}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backing, test.shape...), + } + + res, err := test.atanh.Apply(inputs) + assert.Nil(t, err) + + assert.Nil(t, err) + assert.Equal(t, test.expected, res[0].Data()) + } +} + +func TestInputValidationAtanh(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float64{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{}, + ops.ErrInvalidInputCount(0, &Atanh{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int", &Atanh{}), + }, + } + + for _, test := range tests { + atanh := &Atanh{} + validated, err := atanh.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + assert.Equal(t, test.inputs, validated) + } + } +} diff --git a/ops/opset13/cast.go b/ops/opset13/cast.go index 1784223..8a8a552 100644 --- a/ops/opset13/cast.go +++ b/ops/opset13/cast.go @@ -1,13 +1,16 @@ package opset13 import ( - "fmt" - - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" ) +const ( + MinCastInputs = 1 + MaxCastInputs = 1 +) + // Cast represents the ONNX cast operator. type Cast struct { to int32 // DataType to cast to, as defined by TensorProto @@ -19,16 +22,18 @@ func newCast() ops.Operator { } // Init initializes the cast operator. -func (c *Cast) Init(attributes []*onnx.AttributeProto) error { +func (c *Cast) Init(n *onnx.NodeProto) error { + attributes := n.GetAttribute() + if len(attributes) != 1 { - return fmt.Errorf(ops.InvalidAttrCountErrTemplate, c, 1, len(attributes)) + return ops.ErrInvalidAttributeCount(1, len(attributes), c) } attr := attributes[0] if attr.GetName() == "to" { c.to = int32(attr.GetI()) } else { - return fmt.Errorf(ops.UnknownAttributeErrTemplate, c, attr.GetName()) + return ops.ErrInvalidAttribute(attr.GetName(), c) } return nil @@ -51,12 +56,12 @@ func (c *Cast) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { // GetMinInputs returns the minimum number of input tensors this operator expects. func (c *Cast) GetMinInputs() int { - return 1 + return MinCastInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (c *Cast) GetMaxInputs() int { - return 1 + return MaxCastInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes @@ -68,7 +73,6 @@ func (c *Cast) GetInputTypeConstraints() [][]tensor.Dtype { tensor.Float32, tensor.Float64, }, } - } // String implements the stringer interface, and can be used to format errors or messages. diff --git a/ops/opset13/cast_test.go b/ops/opset13/cast_test.go index 0634cd3..74d4648 100644 --- a/ops/opset13/cast_test.go +++ b/ops/opset13/cast_test.go @@ -1,19 +1,18 @@ package opset13 import ( - "fmt" "testing" + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" "github.com/stretchr/testify/assert" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" "gorgonia.org/tensor" ) func TestCastInit(t *testing.T) { c := &Cast{} - err := c.Init([]*onnx.AttributeProto{{Name: "to", I: 1}}) + err := c.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "to", I: 1}}}) assert.Nil(t, err) assert.Equal(t, int32(1), c.to) } @@ -64,7 +63,7 @@ func TestCast(t *testing.T) { } for _, test := range tests { - test.cast.Init([]*onnx.AttributeProto{{Name: "to", I: test.to}}) + _ = test.cast.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "to", I: test.to}}}) inputs := []tensor.Tensor{ops.TensorWithBackingFixture(test.backing, test.shape...)} res, err := test.cast.Apply(inputs) @@ -93,13 +92,13 @@ func TestInputValidationCast(t *testing.T) { ops.TensorWithBackingFixture([]float64{1, 2}, 2), ops.TensorWithBackingFixture([]float64{3, 4}, 2), }, - fmt.Errorf("cast operator: expected 1 input tensors, got 2"), + ops.ErrInvalidInputCount(2, &Cast{}), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]bool{true, false}, 2), }, - fmt.Errorf("cast operator: input 0 does not allow type bool"), + ops.ErrInvalidInputType(0, "bool", &Cast{}), }, } @@ -108,6 +107,7 @@ func TestInputValidationCast(t *testing.T) { validated, err := cast.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } diff --git a/ops/opset13/concat.go b/ops/opset13/concat.go index 0a2c35e..a7a24a0 100644 --- a/ops/opset13/concat.go +++ b/ops/opset13/concat.go @@ -1,13 +1,15 @@ package opset13 import ( - "fmt" - - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" ) +const ( + MinConcatInputs = 1 +) + // Concat represents the ONNX concat operator. type Concat struct { axis int @@ -21,12 +23,15 @@ func newConcat() ops.Operator { } // Init initializes the concat operator. -func (c *Concat) Init(attributes []*onnx.AttributeProto) error { +func (c *Concat) Init(n *onnx.NodeProto) error { + attributes := n.GetAttribute() + if len(attributes) != 1 { - return fmt.Errorf(ops.InvalidAttrCountErrTemplate, c, 1, len(attributes)) + return ops.ErrInvalidAttributeCount(1, len(attributes), c) } c.axis = int(attributes[0].GetI()) + return nil } @@ -56,6 +61,7 @@ func (c *Concat) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) // of inputs dynamically, based on our inputs. Every input can have any type. c.maxInputs = len(inputs) c.inputTypeConstraints = make([][]tensor.Dtype, len(inputs)) + for i := 0; i < len(inputs); i++ { c.inputTypeConstraints[i] = ops.AllTypes } @@ -65,7 +71,7 @@ func (c *Concat) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) // GetMinInputs returns the minimum number of input tensors this operator expects. func (c *Concat) GetMinInputs() int { - return 1 + return MinConcatInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. diff --git a/ops/opset13/concat_test.go b/ops/opset13/concat_test.go index cd75745..01fd033 100644 --- a/ops/opset13/concat_test.go +++ b/ops/opset13/concat_test.go @@ -1,18 +1,17 @@ package opset13 import ( - "fmt" "testing" + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" "github.com/stretchr/testify/assert" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" "gorgonia.org/tensor" ) func TestConcatInit(t *testing.T) { concat := &Concat{} - err := concat.Init([]*onnx.AttributeProto{{Name: "axis", I: 3}}) + err := concat.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "axis", I: 3}}}) assert.Nil(t, err) assert.Equal(t, 3, concat.axis) @@ -20,9 +19,9 @@ func TestConcatInit(t *testing.T) { func TestConcatInitFail(t *testing.T) { concat := &Concat{} - err := concat.Init([]*onnx.AttributeProto{}) + err := concat.Init(ops.EmptyNodeProto()) - expected := fmt.Errorf(ops.InvalidAttrCountErrTemplate, concat, 1, 0) + expected := ops.ErrInvalidAttributeCount(1, 0, concat) assert.Equal(t, expected, err) } diff --git a/ops/opset13/constant.go b/ops/opset13/constant.go index bd41249..d0c1261 100644 --- a/ops/opset13/constant.go +++ b/ops/opset13/constant.go @@ -1,10 +1,8 @@ package opset13 import ( - "fmt" - - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" ) @@ -20,20 +18,23 @@ func newConstant() ops.Operator { // Init initializes the constant operator. It supports all constant types except // `sparse_value`, `value_string`, and `value_strings`. -func (c *Constant) Init(attributes []*onnx.AttributeProto) error { +func (c *Constant) Init(n *onnx.NodeProto) error { + attributes := n.GetAttribute() if len(attributes) != 1 { - return fmt.Errorf(ops.InvalidAttrCountErrTemplate, c, 1, len(attributes)) + return ops.ErrInvalidAttributeCount(1, len(attributes), c) } + attr := attributes[0] switch attr.GetName() { case "sparse_value", "value_string", "value_strings": - return fmt.Errorf(ops.UnsupportedAttrErrTemplate, c, attr.GetName()) + return ops.ErrUnsupportedAttribute(attr.GetName(), c) case "value": t, err := onnx.TensorFromProto(attr.GetT()) if err != nil { return err } + c.value = t case "value_float": c.value = tensor.New(tensor.FromScalar(attr.GetF())) @@ -46,14 +47,14 @@ func (c *Constant) Init(attributes []*onnx.AttributeProto) error { ints := attr.GetInts() c.value = tensor.New(tensor.WithShape(len(ints)), tensor.WithBacking(ints)) default: - return fmt.Errorf(ops.UnknownAttributeErrTemplate, c, attr.GetName()) + return ops.ErrUnsupportedAttribute(attr.GetName(), c) } return nil } // Apply applies the constant operator. -func (c *Constant) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { +func (c *Constant) Apply(_ []tensor.Tensor) ([]tensor.Tensor, error) { return []tensor.Tensor{c.value}, nil } diff --git a/ops/opset13/constant_of_shape.go b/ops/opset13/constant_of_shape.go index 507da74..9511108 100644 --- a/ops/opset13/constant_of_shape.go +++ b/ops/opset13/constant_of_shape.go @@ -1,13 +1,16 @@ package opset13 import ( - "fmt" - - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" ) +const ( + MinConstantOfShapeInputs = 1 + MaxConstantOfShapeInputs = 1 +) + // ConstantOfShape represents the ONNX constant of shape operator. type ConstantOfShape struct { // One element tensor, giving the value and type of the output tensor @@ -21,9 +24,11 @@ func newConstantOfShape() ops.Operator { } // Init initializes the constant of shape operator. -func (op *ConstantOfShape) Init(attributes []*onnx.AttributeProto) error { +func (c *ConstantOfShape) Init(n *onnx.NodeProto) error { + attributes := n.GetAttribute() + if len(attributes) > 1 { - return fmt.Errorf(ops.InvalidAttrCountErrTemplate, op, "0 or 1", len(attributes)) + return ops.ErrInvalidAttributeCount(1, len(attributes), c) } if len(attributes) == 1 { @@ -34,26 +39,22 @@ func (op *ConstantOfShape) Init(attributes []*onnx.AttributeProto) error { return err } - op.value = tensor.New(tensor.WithBacking(t.Data())) - if op.value.Len() != 1 { - return fmt.Errorf( - "Value input tensor should be a single element tensor, but was %v", - op.value, - ) + c.value = tensor.New(tensor.WithBacking(t.Data())) + if c.value.Len() != 1 { + return ops.ErrInvalidTensor("expected tensor to have one element", c) } } else { - return fmt.Errorf(ops.UnknownAttributeErrTemplate, op, attr.GetName()) + return ops.ErrInvalidAttribute(attr.GetName(), c) } } else { - // Default - op.value = tensor.New(tensor.FromScalar(float32(0.0))) + c.value = tensor.New(tensor.FromScalar(float32(0.0))) } return nil } // Apply applies the constant of shape operator. -func (op *ConstantOfShape) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { +func (c *ConstantOfShape) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { shape, err := ops.AnyToIntSlice(ops.IfScalarToSlice(inputs[0].Data())) if err != nil { return nil, err @@ -62,44 +63,44 @@ func (op *ConstantOfShape) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error // Empty dimensions in a tensor are not supported for i := range shape { if shape[i] <= 0 { - return nil, fmt.Errorf( - "Non positive dimensions are not allowed (must be > 0). Given: %v", - shape, - ) + return nil, ops.ErrInvalidTensor("empty dimensions are not allowed", c) } } - t := tensor.New(tensor.WithShape(shape...), tensor.Of(op.value.Dtype())) - t, err = t.AddScalar(op.value, true) + + t := tensor.New(tensor.WithShape(shape...), tensor.Of(c.value.Dtype())) + + t, err = t.AddScalar(c.value, true) + if err != nil { + return nil, err + } return []tensor.Tensor{t}, err } // ValidateInputs validates the inputs that will be given to Apply for this operator. -func (op *ConstantOfShape) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(op, inputs) +func (c *ConstantOfShape) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(c, inputs) } // GetMinInputs returns the minimum number of input tensors this operator expects. -func (op *ConstantOfShape) GetMinInputs() int { - return 1 +func (c *ConstantOfShape) GetMinInputs() int { + return MinConstantOfShapeInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. -func (op *ConstantOfShape) GetMaxInputs() int { - return 1 +func (c *ConstantOfShape) GetMaxInputs() int { + return MaxConstantOfShapeInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes // for the corresponding input tensor. -func (op *ConstantOfShape) GetInputTypeConstraints() [][]tensor.Dtype { +func (c *ConstantOfShape) GetInputTypeConstraints() [][]tensor.Dtype { return [][]tensor.Dtype{ {tensor.Int64}, } - } // String implements the stringer interface, and can be used to format errors or messages. -func (op *ConstantOfShape) String() string { +func (c *ConstantOfShape) String() string { return "constant of shape operator" - } diff --git a/ops/opset13/constant_of_shape_test.go b/ops/opset13/constant_of_shape_test.go index b539919..e294c25 100644 --- a/ops/opset13/constant_of_shape_test.go +++ b/ops/opset13/constant_of_shape_test.go @@ -2,12 +2,11 @@ package opset13 import ( "encoding/binary" - "fmt" "testing" + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" "github.com/stretchr/testify/assert" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" "gorgonia.org/tensor" ) @@ -19,6 +18,7 @@ func TensorProtoFromNumber(n interface{}) *onnx.TensorProto { size := 1 rawData := make([]byte, size) rawData[0] = uint8(x) + return &onnx.TensorProto{ DataType: onnx.TensorProto_DataType_value["INT8"], Dims: []int64{1}, @@ -29,6 +29,7 @@ func TensorProtoFromNumber(n interface{}) *onnx.TensorProto { size := 2 rawData := make([]byte, size) binary.LittleEndian.PutUint16(rawData, uint16(x)) + return &onnx.TensorProto{ DataType: onnx.TensorProto_DataType_value["INT16"], Dims: []int64{1}, @@ -85,11 +86,12 @@ func TestConstantOfShape(t *testing.T) { // Make the input tensor tp := TensorProtoFromNumber(test.input) assert.NotNil(t, tp) - attr := []*onnx.AttributeProto{{Name: "value", T: tp}} + + node := &onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "value", T: tp}}} // Create operator op := ConstantOfShape{} - err := op.Init(attr) + err := op.Init(node) assert.NoError(t, err) assert.Equal(t, test.input, op.value.Data()) @@ -108,7 +110,7 @@ func TestConstantOfShapeEmptyInit(t *testing.T) { op := &ConstantOfShape{} // No init value given - err := op.Init([]*onnx.AttributeProto{}) + err := op.Init(ops.EmptyNodeProto()) assert.NoError(t, err) assert.Equal(t, float32(0.0), op.value.Data()) @@ -120,7 +122,6 @@ func TestConstantOfShapeEmptyInit(t *testing.T) { assert.NoError(t, err) assert.Equal(t, []float32{0, 0, 0, 0}, res[0].Data()) - } func TestIncorrectInput(t *testing.T) { @@ -129,22 +130,21 @@ func TestIncorrectInput(t *testing.T) { Dims: []int64{3}, Int32Data: []int32{1, 2, 3}, } - attr := []*onnx.AttributeProto{{Name: "value", T: tp}} + node := &onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "value", T: tp}}} op := &ConstantOfShape{} - err := op.Init(attr) + err := op.Init(node) assert.NotNil(t, err) assert.Equal( t, - "Value input tensor should be a single element tensor, but was [1 2 3]", + "constant of shape operator invalid tensor found, reason: expected tensor to have one element", err.Error(), ) - } func TestNegativeShapeNotAllowed(t *testing.T) { op := &ConstantOfShape{} - op.Init([]*onnx.AttributeProto{}) + _ = op.Init(ops.EmptyNodeProto()) shape := []int64{1, -1} @@ -154,13 +154,13 @@ func TestNegativeShapeNotAllowed(t *testing.T) { assert.Equal( t, - "Non positive dimensions are not allowed (must be > 0). Given: [1 -1]", + "constant of shape operator invalid tensor found, reason: empty dimensions are not allowed", err.Error()) } func TestEmptyTensorNotAllowed(t *testing.T) { op := &ConstantOfShape{} - op.Init([]*onnx.AttributeProto{}) + _ = op.Init(ops.EmptyNodeProto()) shape := []int64{0} @@ -170,13 +170,13 @@ func TestEmptyTensorNotAllowed(t *testing.T) { assert.Equal( t, - "Non positive dimensions are not allowed (must be > 0). Given: [0]", + "constant of shape operator invalid tensor found, reason: empty dimensions are not allowed", err.Error()) } func TestScalarShapeInput(t *testing.T) { op := &ConstantOfShape{} - op.Init([]*onnx.AttributeProto{}) + _ = op.Init(ops.EmptyNodeProto()) shape := []int64{6} input := tensor.New(tensor.WithBacking(shape)) @@ -200,11 +200,11 @@ func TestInputValidationConstantOfShape(t *testing.T) { }, { []tensor.Tensor{}, - fmt.Errorf("constant of shape operator: expected 1 input tensors, got 0"), + ops.ErrInvalidInputCount(0, &ConstantOfShape{}), }, { []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, - fmt.Errorf("constant of shape operator: input 0 does not allow type int"), + ops.ErrInvalidInputType(0, "int", &ConstantOfShape{}), }, } @@ -213,6 +213,7 @@ func TestInputValidationConstantOfShape(t *testing.T) { validated, err := constantOfShape.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } diff --git a/ops/opset13/constant_test.go b/ops/opset13/constant_test.go index 3e57f40..ffebccf 100644 --- a/ops/opset13/constant_test.go +++ b/ops/opset13/constant_test.go @@ -2,12 +2,11 @@ package opset13 import ( "encoding/binary" - "fmt" "testing" + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" "github.com/stretchr/testify/assert" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" "gorgonia.org/tensor" ) @@ -45,30 +44,30 @@ func TestConstantInit(t *testing.T) { { []*onnx.AttributeProto{{Name: "sparse_value"}}, nil, - fmt.Errorf(ops.UnsupportedAttrErrTemplate, &Constant{}, "sparse_value"), + ops.ErrUnsupportedAttribute("sparse_value", &Constant{}), }, { []*onnx.AttributeProto{{Name: "unknownAttribute"}}, nil, - fmt.Errorf(ops.UnknownAttributeErrTemplate, &Constant{}, "unknownAttribute"), + ops.ErrUnsupportedAttribute("unknownAttribute", &Constant{}), }, { []*onnx.AttributeProto{}, nil, - fmt.Errorf(ops.InvalidAttrCountErrTemplate, &Constant{}, 1, 0), + ops.ErrInvalidAttributeCount(1, 0, &Constant{}), }, } for _, test := range tests { constant := &Constant{} - err := constant.Init(test.initAttr) + err := constant.Init(&onnx.NodeProto{Attribute: test.initAttr}) assert.Equal(t, test.err, err) + if err != nil { assert.Equal(t, test.expected, constant.value) } } - } func TestConstant(t *testing.T) { @@ -105,7 +104,7 @@ func TestConstant(t *testing.T) { } for _, test := range tests { - test.constant.Init(test.initAttr) + _ = test.constant.Init(&onnx.NodeProto{Attribute: test.initAttr}) res, err := test.constant.Apply([]tensor.Tensor{}) assert.Nil(t, err) @@ -115,7 +114,7 @@ func TestConstant(t *testing.T) { func TestConstantSingleIntShapeTensor(t *testing.T) { constant := &Constant{} - err := constant.Init([]*onnx.AttributeProto{{Name: "value_ints", Ints: []int64{2}}}) + err := constant.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "value_ints", Ints: []int64{2}}}}) assert.Nil(t, err) assert.False(t, constant.value.IsScalar()) @@ -134,7 +133,7 @@ func TestInputValidationConstant(t *testing.T) { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - fmt.Errorf("constant operator: expected 0 input tensors, got 1"), + ops.ErrInvalidInputCount(1, &Constant{}), }, } @@ -143,6 +142,7 @@ func TestInputValidationConstant(t *testing.T) { validated, err := constant.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } @@ -158,6 +158,7 @@ func ConstantValueAttrProtoFixture() []*onnx.AttributeProto { binary.LittleEndian.PutUint64(bValues[16:24], uint64(values[2])) tp := &onnx.TensorProto{DataType: int32(7), Dims: []int64{3}, RawData: bValues} + return []*onnx.AttributeProto{{Name: "value", T: tp}} } diff --git a/ops/opset13/conv.go b/ops/opset13/conv.go new file mode 100644 index 0000000..801a5e9 --- /dev/null +++ b/ops/opset13/conv.go @@ -0,0 +1,590 @@ +package opset13 + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var ( + MinConvInputs = 2 + MaxConvInputs = 3 + NDims1DConvolution = 3 + NDims2DConvolution = 4 +) + +type AutoPadSetting string + +const ( + NotSet AutoPadSetting = "NOTSET" + SameUpper AutoPadSetting = "SAME_UPPER" + SameLower AutoPadSetting = "SAME_LOWER" + Valid AutoPadSetting = "VALID" +) + +// The number of non spatial dimensions inputs and kernels will always have. +// For input tensors, the first dimension will be the batch size. +// For kernel tensors, the first dimension will be the number of kernels. +// For all tensors, the second dimension will be the number of channels. +const nNonSpatialDims = 2 + +// Conv represents the ONNX conv operator. +type Conv struct { + autoPad AutoPadSetting + dilations []int + group int + kernelShape []int + pads []int + strides []int +} + +// newConv creates a new conv operator. +func newConv() ops.Operator { + return &Conv{ + autoPad: NotSet, + } +} + +// Init initializes the conv operator. +func (c *Conv) Init(n *onnx.NodeProto) error { + var err error + + for _, attr := range n.GetAttribute() { + switch attr.GetName() { + case "auto_pad": + c.autoPad = AutoPadSetting(attr.GetS()) + case "dilations": + c.dilations, err = ops.AnyToIntSlice(attr.GetInts()) + if err != nil { + return ops.ErrInvalidAttribute(attr.GetName(), c) + } + case "group": + c.group = int(attr.GetI()) + if c.group != 1 { + return ops.ErrUnsupportedAttribute(attr.GetName(), c) + } + case "kernel_shape": + c.kernelShape, err = ops.AnyToIntSlice(attr.GetInts()) + if err != nil { + return ops.ErrInvalidAttribute(attr.GetName(), c) + } + case "pads": + c.pads, err = ops.AnyToIntSlice(attr.GetInts()) + if err != nil { + return ops.ErrInvalidAttribute(attr.GetName(), c) + } + case "strides": + c.strides, err = ops.AnyToIntSlice(attr.GetInts()) + if err != nil { + return ops.ErrInvalidAttribute(attr.GetName(), c) + } + default: + return ops.ErrUnsupportedAttribute(attr.GetName(), c) + } + } + + return nil +} + +// Apply applies the conv operator. +func (c *Conv) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + x := inputs[0] + kernel := inputs[1] + bias := inputs[2] + + if len(c.dilations) == 0 { + c.setDefaultDilations(x) + } + + if len(c.kernelShape) == 0 { + c.setKernelShape(kernel) + } + + if len(c.pads) == 0 { + c.setDefaultPaddings(x) + } + + if len(c.strides) == 0 { + c.setDefaultStrides(x) + } + + kernel, err := c.getDilatedKernel(kernel) + if err != nil { + return nil, err + } + + if c.autoPad != NotSet { + c.setPaddingWithAutoPad(x) + } + + var out tensor.Tensor + + switch len(x.Shape()) { + case NDims1DConvolution: + out, err = c.applyConv1D(x, kernel) + case NDims2DConvolution: + out, err = c.applyConv2D(x, kernel) + default: + return nil, ops.ErrInvalidInput("the convolution operator currently only supports 1D or 2D convolution, i.e. shape [N x C x H (x W)]", c) + } + + if err != nil { + return nil, err + } + + if bias != nil { + out, err = c.addBias(out, bias) + if err != nil { + return nil, err + } + } + + return []tensor.Tensor{out}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (c *Conv) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(c, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (c *Conv) GetMinInputs() int { + return MinConvInputs +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (c *Conv) GetMaxInputs() int { + return MaxConvInputs +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (c *Conv) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{ + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + } +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (c *Conv) String() string { + return "conv operator" +} + +// setDefaultDilations sets the dilations attribute to the default. Can be called when no +// dilations were set when initializing. +func (c *Conv) setDefaultDilations(x tensor.Tensor) { + nDims := len(x.Shape()[2:]) + + dilations := make([]int, nDims) + for i := 0; i < nDims; i++ { + dilations[i] = 1 + } + + c.dilations = dilations +} + +// setKernelShape infers the shape of the kernel when it was not given in the attributes. +func (c *Conv) setKernelShape(kernel tensor.Tensor) { + c.kernelShape = kernel.Shape()[2:] +} + +// setDefaultPaddings sets default paddings as attribute. Can be called when no paddings +// were set during initialization. +func (c *Conv) setDefaultPaddings(x tensor.Tensor) { + NPadsPerDim := 2 + paddingLength := len(x.Shape()[2:]) * NPadsPerDim + + pads := make([]int, paddingLength) + for i := 0; i < paddingLength; i++ { + pads[i] = 0 + } + + c.pads = pads +} + +// setDefaultStrides sets default strides as attribute. Can be called when no strides +// were set during initialization. +func (c *Conv) setDefaultStrides(x tensor.Tensor) { + nDims := len(x.Shape()[2:]) + + strides := make([]int, nDims) + for i := 0; i < nDims; i++ { + strides[i] = 1 + } + + c.strides = strides +} + +// setPaddingWithAutoPad sets the padding attribute of the operator based on +// the input tensor `x`, the shape of the kernel and the strides. +func (c *Conv) setPaddingWithAutoPad(x tensor.Tensor) { + if c.autoPad == NotSet { + return + } + + NPadsPerDim := 2 + inputShape := x.Shape() + nDims := len(inputShape) + nSpatialDims := nDims - nNonSpatialDims + + c.pads = make([]int, nSpatialDims*NPadsPerDim) + + for i := 0; i < nSpatialDims; i++ { + dim := inputShape[i] + targetSize := (dim + c.strides[i] - 1) / c.strides[i] + padNeeded := (targetSize-1)*c.strides[i] + c.kernelShape[i] - dim + + var padHead int + if c.autoPad == SameLower { + // nolint as the division by zero is literally division by two + padHead = (padNeeded + 1) / 2 + } else { + // nolint as the division by two is literally division by two + padHead = padNeeded / 2 + } + + padTail := padNeeded - padHead + c.pads[i] = padHead + c.pads[i+nSpatialDims] = padTail + } +} + +// getDilatedKernel creates a new kernel given the `dilations` attribute of this +// conv operator. A dilated kernel basically means inserting zeros in between +// the kernels, i.e. a 2D kernel like: +// +// 1 2 +// 3 4 +// +// Dilated by one in both dimensions yields a new kernel of: +// +// 1 0 2 +// 0 0 0 +// 3 0 4 +// +// This function updates the given kernel and dilates it by the given amount +// for each dimensions separately. It returns a new tensor with the new kernel. +func (c *Conv) getDilatedKernel(kernel tensor.Tensor) (tensor.Tensor, error) { + oldKernelShape := kernel.Shape() + newKernelShape := make([]int, len(oldKernelShape)) + + // Add the non spatial dimensions of the kernel, i.e. the number of + // kernels (index 0) and the number of channels (index 1). These + // dimensions do not have to be dilated. + for i := 0; i < nNonSpatialDims; i++ { + newKernelShape[i] = oldKernelShape[i] + } + + // Add the dilated spatial dimensions of the kernel, i.e. in the case + // of 2D images these are the width and height dimensions. + for i, dilation := range c.dilations { + oldKernelDim := oldKernelShape[nNonSpatialDims+i] + newKernelShape[nNonSpatialDims+i] = oldKernelDim + (oldKernelDim-1)*(dilation-1) + } + + newKernel := tensor.NewDense(kernel.Dtype(), newKernelShape) + newKernel.Zero() + + // Now we fill the empty kernel with the original kernel values at the + // right positions. + iterator := kernel.Iterator() + iterator.Reset() + + for !iterator.Done() { + oldCoords := iterator.Coord() + + value, err := kernel.At(oldCoords...) + if err != nil { + return nil, err + } + + newCoords := c.getNewCoordsAfterDilation(oldCoords) + + err = newKernel.SetAt(value, newCoords...) + if err != nil { + return nil, err + } + + _, err = iterator.Next() + if err != nil { + return nil, err + } + } + + c.setKernelShape(newKernel) + + return newKernel, nil +} + +// getNewCoordsAfterDilation returns the new coordinates of a value given the old coordinates of that +// value in the old kernel and its shape. The new coordinates can be used to store the value/weight +// in the dilated kernel. +func (c *Conv) getNewCoordsAfterDilation(oldCoords []int) []int { + newCoords := make([]int, len(oldCoords)) + + for i := 0; i < nNonSpatialDims; i++ { + newCoords[i] = oldCoords[i] + } + + for i, dilation := range c.dilations { + newCoords[nNonSpatialDims+i] = oldCoords[nNonSpatialDims+i] * dilation + } + + return newCoords +} + +// Applies 1D convolution to tensor X with the 'kernel' tensor. +// X will have 3 dimensions: [N, C, H] where N is the batch size, C is the number +// of channels and H is the number of dimensions on which to apply the convolutions. +// The kernel will have shape [kernelDim], where 'kernelDim' is the size of the kernel +// size of the kernel. +func (c *Conv) applyConv1D(x, kernel tensor.Tensor) (tensor.Tensor, error) { + outputShape := c.getOutputShape(x, kernel) + out := tensor.Tensor(tensor.NewDense(x.Dtype(), outputShape)) + out.Zero() + + paddedX, err := c.padInput(x) + if err != nil { + return nil, err + } + + nBatches := x.Shape()[0] + nKernels := kernel.Shape()[0] + strideSize := c.strides[0] + outputHDim := outputShape[nNonSpatialDims] + + for batchIdx := 0; batchIdx < nBatches; batchIdx++ { + for kernelIdx := 0; kernelIdx < nKernels; kernelIdx++ { + subKernelView, err := kernel.Slice(ops.NewSlicer(kernelIdx, kernelIdx+1)) + if err != nil { + return nil, err + } + + subKernel := subKernelView.Materialize() + + for h := 0; h < paddedX.Shape()[2]; h += strideSize { + dimHOutputIdx := h / strideSize + if dimHOutputIdx >= outputHDim { + continue + } + + subImage, err := c.getSubImage(paddedX, batchIdx, h) + if err != nil { + return nil, err + } + + subImage, subKernel, err = ops.UnidirectionalBroadcast(subImage, subKernel) + if err != nil { + return nil, err + } + + convResult, err := tensor.Mul(subImage, subKernel) + if err != nil { + return nil, err + } + + convValue, err := tensor.Sum(convResult) + if err != nil { + return nil, err + } + + err = out.SetAt(convValue.ScalarValue(), batchIdx, kernelIdx, dimHOutputIdx) + if err != nil { + return nil, err + } + } + } + } + + return out, nil +} + +// Applies 2D convolution to tensor X with the 'kernel' tensor. +// X will have 4 dimensions: [N, C, H, W] where N is the batch size, C is the number +// of channels, H and W are the height and width dimensions on which to apply the convolutions. +// The kernel will have shape [M, C, H, W]. +func (c *Conv) applyConv2D(x, kernel tensor.Tensor) (tensor.Tensor, error) { + outputShape := c.getOutputShape(x, kernel) + out := tensor.Tensor(tensor.NewDense(x.Dtype(), outputShape)) + out.Zero() + + outputHDim := outputShape[nNonSpatialDims] + outputWDim := outputShape[nNonSpatialDims+1] + + paddedX, err := c.padInput(x) + if err != nil { + return nil, err + } + + nBatches := x.Shape()[0] + nKernels := kernel.Shape()[0] + + for batchIdx := 0; batchIdx < nBatches; batchIdx++ { + for kernelIdx := 0; kernelIdx < nKernels; kernelIdx++ { + subKernelView, err := kernel.Slice(ops.NewSlicer(kernelIdx, kernelIdx+1)) + if err != nil { + return nil, err + } + + subKernel := subKernelView.Materialize() + + // Loop over all 2D subImages of the input image and compute the convolution + // for that subImage. Store the result at the right place in the output tensor. + for h := 0; h < paddedX.Shape()[2]; h += c.strides[0] { + dimHOutputIdx := h / c.strides[0] + if dimHOutputIdx >= outputHDim { + continue + } + + for w := 0; w < paddedX.Shape()[2]; w += c.strides[1] { + dimWOutputIdx := w / c.strides[1] + if dimWOutputIdx >= outputWDim { + continue + } + + subImage, err := c.getSubImage(paddedX, batchIdx, h, w) + if err != nil { + return nil, err + } + + subImage, subKernel, err = ops.UnidirectionalBroadcast(subImage, subKernel) + if err != nil { + return nil, err + } + + convResult, err := tensor.Mul(subImage, subKernel) + if err != nil { + return nil, err + } + + convValue, err := tensor.Sum(convResult) + if err != nil { + return nil, err + } + + err = out.SetAt(convValue.ScalarValue(), batchIdx, kernelIdx, dimHOutputIdx, dimWOutputIdx) + if err != nil { + return nil, err + } + } + } + } + } + + return out, nil +} + +// getOutputShape calculates the shape of the output tensor resulting from +// the convolution operation between `x` and `kernel`. +// `x` has shape [N, C, H, W, ...] and `kernel` has shape [M, C, H, W, ...]. +// The output shape will be [N, M, newH, newW, ...], where values like `newH` +// are calculated based on the input shape, kernel size, padding and strides. +func (c *Conv) getOutputShape(x, kernel tensor.Tensor) tensor.Shape { + outputShape := make([]int, len(x.Shape())) + + outputShape[0] = x.Shape()[0] + outputShape[1] = kernel.Shape()[0] + + nSpatialDims := len(x.Shape()) - nNonSpatialDims + for i := 0; i < nSpatialDims; i++ { + inputDim := x.Shape()[nNonSpatialDims+i] + kernelDim := c.kernelShape[i] + outputShape[nNonSpatialDims+i] = ((inputDim - kernelDim + c.pads[i] + c.pads[i+nSpatialDims]) / c.strides[i]) + 1 + } + + return outputShape +} + +// padInput pads the input with zeros according to the `pads` attribute. +// The pad attribute specifies how many zeros should be added before and +// after the values in that specific dimension. +// Please note that according to ONNX specs, the `pads` attributes is an +// array with pads as [x1_begin, x2_begin, ..., x1_after, x2_after]. +// This method achieves padding by concatting tensors with zero values +// before and after each spatial dimension of the input tensor `x`. +func (c *Conv) padInput(x tensor.Tensor) (tensor.Tensor, error) { + var err error + + nSpatialDims := len(x.Shape()[nNonSpatialDims:]) + + for i := 0; i < nSpatialDims; i++ { + if c.pads[i] != 0 { + padsBeforeShape := x.Shape().Clone() + padsBeforeShape[nNonSpatialDims+i] = c.pads[i] + zerosBefore := tensor.Tensor(tensor.NewDense(x.Dtype(), padsBeforeShape)) + zerosBefore.Zero() + + x, err = tensor.Concat(nNonSpatialDims+i, zerosBefore, x) + if err != nil { + return nil, err + } + } + + if c.pads[i+nSpatialDims] != 0 { + padsAfterShape := x.Shape().Clone() + padsAfterShape[nNonSpatialDims+i] = c.pads[i+nSpatialDims] + zerosAfter := tensor.Tensor(tensor.NewDense(x.Dtype(), padsAfterShape)) + zerosAfter.Zero() + + x, err = tensor.Concat(nNonSpatialDims+i, x, zerosAfter) + if err != nil { + return nil, err + } + } + } + + return x, nil +} + +// getSubImage returns a the subimage for a specific example in the batch, based on the +// kernel shape and the given start coordinates. The resulting sub image will be of +// shape [C, kernelShape[0], kernelShape[1], ...]. +func (c *Conv) getSubImage(x tensor.Tensor, batchIdx int, startSpatialCoords ...int) (tensor.Tensor, error) { + if len(startSpatialCoords) != len(c.kernelShape) { + return nil, ops.ErrDimension("expected the coordinates to have the same number of dimensions as the kernel") + } + + slices := []tensor.Slice{ + ops.NewSlicer(batchIdx, batchIdx+1), + nil, // Take all channels at once. + } + + for i := 0; i < len(c.kernelShape); i++ { + dimStartIdx := startSpatialCoords[i] + dimKernelSize := c.kernelShape[i] + slices = append(slices, ops.NewSlicer(dimStartIdx, dimStartIdx+dimKernelSize)) + } + + subImage, err := x.Slice(slices...) + if err != nil { + return nil, err + } + + return subImage.Materialize(), nil +} + +// addBias adds a bias to the output of the convolution. It reshapes the +// bias such that it can be broadcasted, and then is added to the output +// tensor. +func (c *Conv) addBias(out, bias tensor.Tensor) (tensor.Tensor, error) { + biasShape := make([]int, len(out.Shape())) + for i := 0; i < len(out.Shape()); i++ { + biasShape[i] = 1 + } + + biasShape[1] = bias.Shape()[0] + + err := bias.Reshape(biasShape...) + if err != nil { + return nil, err + } + + out, bias, err = ops.UnidirectionalBroadcast(out, bias) + if err != nil { + return nil, err + } + + return tensor.Add(out, bias) +} diff --git a/ops/opset13/conv_test.go b/ops/opset13/conv_test.go new file mode 100644 index 0000000..8da4b87 --- /dev/null +++ b/ops/opset13/conv_test.go @@ -0,0 +1,692 @@ +package opset13 + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestConvInit(t *testing.T) { + c := &Conv{} + err := c.Init(Conv2DOnnxNodeProtoFixture()) + + assert.Nil(t, err) + + var expectedAutopad AutoPadSetting = "VALID" + + assert.Equal(t, expectedAutopad, c.autoPad) + assert.Equal(t, []int{1, 1}, c.dilations) + assert.Equal(t, []int{2, 2}, c.kernelShape) + assert.Equal(t, []int{1, 2}, c.pads) + assert.Equal(t, []int{1, 1}, c.strides) +} + +func TestConvInitUnsupported(t *testing.T) { + c := &Conv{} + err := c.Init(ConvUnsupportedOnnxNodeProtoFixture()) + + assert.Equal( + t, + err, + ops.ErrUnsupportedAttribute("group", c), + ) +} + +func TestConv(t *testing.T) { + tests := []struct { + conv *Conv + shapes [][]int + backings [][]float32 + expectedShape tensor.Shape + expected []float32 + }{ + // Test 1D Convolution. + { + &Conv{ + autoPad: "NOTSET", + dilations: []int{}, + group: 1, + kernelShape: []int{3}, + pads: []int{0, 0}, + strides: []int{1, 1}, + }, + [][]int{{1, 1, 6}, {1, 1, 3}}, + [][]float32{{0, 1, 2, 3, 4, 5}, {1, 1, 1}}, + []int{1, 1, 4}, + []float32{3, 6, 9, 12}, + }, + // Test 2D Convolution. + { + &Conv{ + autoPad: "NOTSET", + dilations: []int{}, + group: 1, + kernelShape: []int{2, 2}, + pads: []int{0, 0, 0, 0}, + strides: []int{1, 1}, + }, + [][]int{{1, 1, 3, 3}, {1, 1, 2, 2}}, + [][]float32{{0, 1, 2, 3, 4, 5, 6, 7, 8}, {1, 1, 1, 1}}, + []int{1, 1, 2, 2}, + []float32{8, 12, 20, 24}, + }, + // Test SAME_LOWER autopad setting. + { + &Conv{ + autoPad: "SAME_LOWER", + dilations: []int{}, + group: 1, + kernelShape: []int{2, 2}, + pads: []int{}, + strides: []int{1, 1}, + }, + [][]int{{1, 1, 3, 3}, {1, 1, 2, 2}}, + [][]float32{{0, 1, 2, 3, 4, 5, 6, 7, 8}, {1, 1, 1, 1}}, + []int{1, 1, 3, 3}, + []float32{0, 1, 3, 3, 8, 12, 9, 20, 24}, + }, + // Test SAME_UPPER autopad setting. + { + &Conv{ + autoPad: "SAME_UPPER", + dilations: []int{}, + group: 1, + kernelShape: []int{2, 2}, + pads: []int{}, + strides: []int{1, 1}, + }, + [][]int{{1, 1, 3, 3}, {1, 1, 2, 2}}, + [][]float32{{0, 1, 2, 3, 4, 5, 6, 7, 8}, {1, 1, 1, 1}}, + []int{1, 1, 3, 3}, + []float32{8, 12, 7, 20, 24, 13, 13, 15, 8}, + }, + // Test VALID autopad setting. + { + &Conv{ + autoPad: "VALID", + dilations: []int{}, + group: 1, + kernelShape: []int{2, 2}, + pads: []int{}, + strides: []int{1, 1}, + }, + [][]int{{1, 1, 3, 3}, {1, 1, 2, 2}}, + [][]float32{{0, 1, 2, 3, 4, 5, 6, 7, 8}, {1, 1, 1, 1}}, + []int{1, 1, 3, 3}, + []float32{8, 12, 7, 20, 24, 13, 13, 15, 8}, + }, + // Test dilation attribute. + { + &Conv{ + autoPad: "NOTSET", + dilations: []int{2, 2}, + group: 1, + kernelShape: []int{2, 2}, + pads: []int{0, 0, 0, 0}, + strides: []int{1, 1}, + }, + [][]int{{1, 1, 4, 4}, {1, 1, 2, 2}}, + [][]float32{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, {1, 1, 1, 1}}, + []int{1, 1, 2, 2}, + []float32{20, 24, 36, 40}, + }, + // Test pads attribute. + { + &Conv{ + autoPad: "NOTSET", + dilations: []int{1, 1}, + group: 1, + kernelShape: []int{2, 2}, + pads: []int{1, 1, 2, 2}, + strides: []int{1, 1}, + }, + [][]int{{1, 1, 2, 2}, {1, 1, 2, 2}}, + [][]float32{{0, 1, 2, 3}, {1, 1, 1, 1}}, + []int{1, 1, 4, 4}, + []float32{0, 1, 1, 0, 2, 6, 4, 0, 2, 5, 3, 0, 0, 0, 0, 0}, + }, + // Test strides attribute. + { + &Conv{ + autoPad: "NOTSET", + dilations: []int{}, + group: 1, + kernelShape: []int{2, 2}, + pads: []int{0, 0, 0, 0}, + strides: []int{2, 2}, + }, + [][]int{{1, 1, 4, 4}, {1, 1, 2, 2}}, + [][]float32{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, {1, 1, 1, 1}}, + []int{1, 1, 2, 2}, + []float32{10, 18, 42, 50}, + }, + // Test batch dimension. + { + &Conv{ + autoPad: "NOTSET", + dilations: []int{}, + group: 1, + kernelShape: []int{2, 2}, + pads: []int{0, 0, 0, 0}, + strides: []int{1, 1}, + }, + [][]int{{2, 1, 3, 3}, {1, 1, 2, 2}}, + [][]float32{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, {1, 1, 1, 1}}, + []int{2, 1, 2, 2}, + []float32{8, 12, 20, 24, 44, 48, 56, 60}, + }, + // Test 2D convolution with multiple channels. + { + &Conv{ + autoPad: "NOTSET", + dilations: []int{}, + group: 1, + kernelShape: []int{2, 2}, + pads: []int{0, 0, 0, 0}, + strides: []int{1, 1}, + }, + [][]int{{1, 2, 3, 3}, {1, 1, 2, 2}}, + [][]float32{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, {1, 1, 1, 1}}, + []int{1, 1, 2, 2}, + []float32{52, 60, 76, 84}, + }, + // Test multiple kernels. + { + &Conv{ + autoPad: "NOTSET", + dilations: []int{}, + group: 1, + kernelShape: []int{2, 2}, + pads: []int{0, 0, 0, 0}, + strides: []int{1, 1}, + }, + [][]int{{1, 1, 3, 3}, {2, 1, 2, 2}}, + [][]float32{{0, 1, 2, 3, 4, 5, 6, 7, 8}, {1, 1, 1, 1, 2, 2, 2, 2}}, + []int{1, 2, 2, 2}, + []float32{8, 12, 20, 24, 16, 24, 40, 48}, + }, + // Test bias. + { + &Conv{ + autoPad: "NOTSET", + dilations: []int{}, + group: 1, + kernelShape: []int{2, 2}, + pads: []int{0, 0, 0, 0}, + strides: []int{1, 1}, + }, + [][]int{{1, 1, 3, 3}, {1, 1, 2, 2}, {1}}, + [][]float32{{0, 1, 2, 3, 4, 5, 6, 7, 8}, {1, 1, 1, 1}, {0.5}}, + []int{1, 1, 2, 2}, + []float32{8.5, 12.5, 20.5, 24.5}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backings[0], test.shapes[0]...), + ops.TensorWithBackingFixture(test.backings[1], test.shapes[1]...), + nil, + } + + if len(test.backings) == 3 { + inputs[2] = ops.TensorWithBackingFixture(test.backings[2], test.shapes[2]...) + } + + res, err := test.conv.Apply(inputs) + assert.Nil(t, err) + + assert.Equal(t, test.expectedShape, res[0].Shape()) + assert.Equal(t, test.expected, res[0].Data()) + } +} + +func TestInputValidationConv(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{3, 4}, 2), + nil, + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float64{1, 2}, 2), + ops.TensorWithBackingFixture([]float64{3, 4}, 2), + ops.TensorWithBackingFixture([]float64{5, 6}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidOptionalInputCount(1, &Conv{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + ops.TensorWithBackingFixture([]int{3, 4}, 2), + }, + ops.ErrInvalidInputType(0, "int", &Conv{}), + }, + } + + for _, test := range tests { + conv := &Conv{} + validated, err := conv.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + assert.Equal(t, test.inputs, validated) + } + } +} + +func TestSetDefaultDilations(t *testing.T) { + c := &Conv{} + x := ops.TensorWithBackingFixture([]float32{0, 1, 2, 3, 4, 5, 6, 7, 8}, 1, 1, 3, 3) + + c.setDefaultDilations(x) + + assert.Equal(t, []int{1, 1}, c.dilations) +} + +func TestSetKernelShape(t *testing.T) { + c := &Conv{} + kernel := ops.TensorWithBackingFixture([]float32{0, 1, 2, 3}, 1, 1, 2, 2) + + c.setKernelShape(kernel) + + assert.Equal(t, []int{2, 2}, c.kernelShape) +} + +func TestSetDefaultPaddings(t *testing.T) { + c := &Conv{} + x := ops.TensorWithBackingFixture([]float32{0, 1, 2, 3, 4, 5, 6, 7, 8}, 1, 1, 3, 3) + + c.setDefaultPaddings(x) + + assert.Equal(t, []int{0, 0, 0, 0}, c.pads) +} + +func TestSetDefaultStrides(t *testing.T) { + c := &Conv{} + x := ops.TensorWithBackingFixture([]float32{0, 1, 2, 3, 4, 5, 6, 7, 8}, 1, 1, 3, 3) + + c.setDefaultStrides(x) + + assert.Equal(t, []int{1, 1}, c.strides) +} + +func TestSetPaddingWithAutoPad(t *testing.T) { + x := ops.TensorWithBackingFixture([]float32{0, 1, 2, 3, 4, 5, 6, 7, 8}, 1, 1, 3, 3) + + tests := []struct { + setting AutoPadSetting + expectedPads []int + }{ + {"NOTSET", []int{0, 0, 0, 0}}, + {"SAME_LOWER", []int{1, 1, 0, 0}}, + {"SAME_UPPER", []int{0, 0, 1, 1}}, + {"VALID", []int{0, 0, 1, 1}}, + } + + for _, test := range tests { + conv := &Conv{ + autoPad: test.setting, + pads: []int{0, 0, 0, 0}, + kernelShape: []int{2, 2}, + strides: []int{1, 1}, + } + conv.setPaddingWithAutoPad(x) + + assert.Equal(t, test.expectedPads, conv.pads) + } +} + +func TestGetDilatedKernel(t *testing.T) { + tests := []struct { + dilations []int + kernelShape []int + kernelBacking []float32 + expectedShape tensor.Shape + expectedBacking []float32 + }{ + { + []int{1}, + []int{1, 1, 3}, + []float32{1, 1, 1}, + []int{1, 1, 3}, + []float32{1, 1, 1}, + }, + { + []int{2}, + []int{1, 1, 3}, + []float32{1, 1, 1}, + []int{1, 1, 5}, + []float32{1, 0, 1, 0, 1}, + }, + { + []int{2, 1}, + []int{1, 1, 2, 2}, + []float32{1, 1, 1, 1}, + []int{1, 1, 3, 2}, + []float32{1, 1, 0, 0, 1, 1}, + }, + { + []int{1, 2}, + []int{1, 1, 2, 2}, + []float32{1, 1, 1, 1}, + []int{1, 1, 2, 3}, + []float32{1, 0, 1, 1, 0, 1}, + }, + { + []int{2, 2}, + []int{1, 1, 3, 3}, + []float32{0, 1, 2, 3, 4, 5, 6, 7, 8}, + []int{1, 1, 5, 5}, + []float32{0, 0, 1, 0, 2, 0, 0, 0, 0, 0, 3, 0, 4, 0, 5, 0, 0, 0, 0, 0, 6, 0, 7, 0, 8}, + }, + { + []int{3, 2}, + []int{1, 1, 2, 3}, + []float32{1, 2, 3, 4, 5, 6}, + []int{1, 1, 4, 5}, + []float32{1, 0, 2, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 5, 0, 6}, + }, + } + + for _, test := range tests { + conv := &Conv{ + dilations: test.dilations, + kernelShape: []int{2, 2}, + } + kernel := ops.TensorWithBackingFixture(test.kernelBacking, test.kernelShape...) + + dilatedKernel, err := conv.getDilatedKernel(kernel) + assert.Nil(t, err) + + assert.Equal(t, test.expectedShape, dilatedKernel.Shape()) + assert.Equal(t, test.expectedBacking, dilatedKernel.Data()) + } +} + +func TestGetOutputShape(t *testing.T) { + tests := []struct { + conv *Conv + xShape []int + xBacking []float32 + kernelShape []int + kernelBacking []float32 + expected tensor.Shape + }{ + { + &Conv{ + kernelShape: []int{3}, + pads: []int{0, 0}, + strides: []int{1}, + }, + []int{1, 1, 6}, + []float32{0, 1, 2, 3, 4, 5}, + []int{1, 1, 3}, + []float32{1, 1, 1}, + []int{1, 1, 4}, + }, + { + &Conv{ + kernelShape: []int{3}, + pads: []int{1, 2}, + strides: []int{2}, + }, + []int{1, 1, 6}, + []float32{0, 1, 2, 3, 4, 5}, + []int{1, 1, 3}, + []float32{1, 1, 1}, + []int{1, 1, 4}, + }, + { + &Conv{ + kernelShape: []int{2, 2}, + pads: []int{1, 2, 1, 2}, + strides: []int{2, 1}, + }, + []int{1, 1, 4, 4}, + []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + []int{1, 1, 2, 2}, + []float32{1, 1, 1, 1}, + []int{1, 1, 3, 7}, + }, + { + &Conv{ + kernelShape: []int{2, 2}, + pads: []int{0, 0, 0, 0}, + strides: []int{1, 1}, + }, + []int{1, 1, 4, 4}, + []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + []int{1, 1, 2, 2}, + []float32{1, 1, 1, 1}, + []int{1, 1, 3, 3}, + }, + } + + for _, test := range tests { + outputShape := test.conv.getOutputShape( + ops.TensorWithBackingFixture(test.xBacking, test.xShape...), + ops.TensorWithBackingFixture(test.kernelBacking, test.kernelShape...), + ) + + assert.Equal(t, test.expected, outputShape) + } +} + +func TestPadInput(t *testing.T) { + tests := []struct { + conv *Conv + xShape []int + xBacking []float32 + expectedShape tensor.Shape + expectedBacking []float32 + }{ + { + &Conv{ + pads: []int{0, 0}, + }, + []int{1, 1, 6}, + []float32{0, 1, 2, 3, 4, 5}, + []int{1, 1, 6}, + []float32{0, 1, 2, 3, 4, 5}, + }, + { + &Conv{ + pads: []int{1, 2}, + }, + []int{1, 1, 6}, + []float32{0, 1, 2, 3, 4, 5}, + []int{1, 1, 9}, + []float32{0, 0, 1, 2, 3, 4, 5, 0, 0}, + }, + { + &Conv{ + pads: []int{1, 1, 1, 1}, + }, + []int{1, 1, 2, 2}, + []float32{1, 2, 3, 4}, + []int{1, 1, 4, 4}, + []float32{0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 0, 0, 0}, + }, + { + &Conv{ + pads: []int{1, 0, 2, 0}, + }, + []int{1, 1, 2, 2}, + []float32{1, 2, 3, 4}, + []int{1, 1, 5, 2}, + []float32{0, 0, 1, 2, 3, 4, 0, 0, 0, 0}, + }, + } + + for _, test := range tests { + paddedX, err := test.conv.padInput( + ops.TensorWithBackingFixture(test.xBacking, test.xShape...), + ) + + assert.Nil(t, err) + assert.Equal(t, test.expectedShape, paddedX.Shape()) + assert.Equal(t, test.expectedBacking, paddedX.Data()) + } +} + +func TestGetSubImage(t *testing.T) { + tests := []struct { + conv *Conv + xShape []int + xBacking []float32 + batchIdx int + startSpatialCoords []int + expectedShape tensor.Shape + expectedBacking []float32 + }{ + { + &Conv{kernelShape: []int{2}}, + []int{1, 1, 3}, + []float32{0, 1, 2}, + 0, + []int{0}, + []int{1, 2}, + []float32{0, 1}, + }, + { + &Conv{kernelShape: []int{2}}, + []int{1, 2, 3}, + []float32{0, 1, 2, 3, 4, 5}, + 0, + []int{0}, + []int{2, 2}, + []float32{0, 1, 3, 4}, + }, + { + &Conv{kernelShape: []int{2, 2}}, + []int{1, 1, 3, 3}, + []float32{0, 1, 2, 3, 4, 5, 6, 7, 8}, + 0, + []int{0, 0}, + []int{1, 2, 2}, + []float32{0, 1, 3, 4}, + }, + { + &Conv{kernelShape: []int{2, 2}}, + []int{1, 1, 3, 3}, + []float32{0, 1, 2, 3, 4, 5, 6, 7, 8}, + 0, + []int{1, 1}, + []int{1, 2, 2}, + []float32{4, 5, 7, 8}, + }, + { + &Conv{kernelShape: []int{2}}, + []int{2, 1, 3}, + []float32{0, 1, 2, 3, 4, 5}, + 1, + []int{1}, + []int{1, 2}, + []float32{4, 5}, + }, + } + + for _, test := range tests { + subImage, err := test.conv.getSubImage( + ops.TensorWithBackingFixture(test.xBacking, test.xShape...), + test.batchIdx, + test.startSpatialCoords..., + ) + + assert.Nil(t, err) + assert.Equal(t, test.expectedShape, subImage.Shape()) + assert.Equal(t, test.expectedBacking, subImage.Data()) + } +} + +func TestAddBias(t *testing.T) { + tests := []struct { + conv *Conv + outShape []int + outBacking []float32 + biasShape []int + biasBacking []float32 + expected []float32 + }{ + { + &Conv{}, + []int{1, 1, 3}, + []float32{0, 1, 2}, + []int{1}, + []float32{0.5}, + []float32{0.5, 1.5, 2.5}, + }, + { + &Conv{}, + []int{1, 1, 3, 3}, + []float32{0, 1, 2, 3, 4, 5, 6, 7, 8}, + []int{1}, + []float32{0.5}, + []float32{0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5}, + }, + { + &Conv{}, + []int{1, 2, 2, 2}, + []float32{0, 1, 2, 3, 4, 5, 6, 7}, + []int{2}, + []float32{-1, 1}, + []float32{-1, 0, 1, 2, 5, 6, 7, 8}, + }, + { + &Conv{}, + []int{2, 2, 2, 2}, + []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + []int{2}, + []float32{-1, 1}, + []float32{-1, 0, 1, 2, 5, 6, 7, 8, 7, 8, 9, 10, 13, 14, 15, 16}, + }, + } + + for _, test := range tests { + out, err := test.conv.addBias( + ops.TensorWithBackingFixture(test.outBacking, test.outShape...), + ops.TensorWithBackingFixture(test.biasBacking, test.biasShape...), + ) + + assert.Nil(t, err) + assert.Equal(t, test.expected, out.Data()) + } +} + +func Conv2DOnnxNodeProtoFixture() *onnx.NodeProto { + return &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "auto_pad", S: []byte("VALID")}, + {Name: "dilations", Ints: []int64{1, 1}}, + {Name: "kernel_shape", Ints: []int64{2, 2}}, + {Name: "pads", Ints: []int64{1, 2}}, + {Name: "strides", Ints: []int64{1, 1}}, + }, + } +} + +func ConvUnsupportedOnnxNodeProtoFixture() *onnx.NodeProto { + return &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "group", I: 2}, + }, + } +} diff --git a/ops/opset13/cos.go b/ops/opset13/cos.go new file mode 100644 index 0000000..ad01f82 --- /dev/null +++ b/ops/opset13/cos.go @@ -0,0 +1,75 @@ +package opset13 + +import ( + "math" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +// Cos represents the ONNX cos operator. +type Cos struct{} + +// newCos creates a new cos operator. +func newCos() ops.Operator { + return &Cos{} +} + +// Init initializes the cos operator. +func (c *Cos) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the cos operator. +func (c *Cos) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var ( + out tensor.Tensor + err error + ) + + switch inputs[0].Dtype() { + case tensor.Float32: + out, err = inputs[0].Apply(cos[float32]) + case tensor.Float64: + out, err = inputs[0].Apply(cos[float64]) + default: + return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), c) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (c *Cos) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(c, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (c *Cos) GetMinInputs() int { + return 1 +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (c *Cos) GetMaxInputs() int { + return 1 +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (c *Cos) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (c *Cos) String() string { + return "cos operator" +} + +func cos[T ops.FloatType](x T) T { + return T(math.Cos(float64(x))) +} diff --git a/ops/opset13/cos_test.go b/ops/opset13/cos_test.go new file mode 100644 index 0000000..b1087c4 --- /dev/null +++ b/ops/opset13/cos_test.go @@ -0,0 +1,99 @@ +package opset13 + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestCosInit(t *testing.T) { + c := &Cos{} + + // since 'cos' does not have any attributes we pass in nil. This should not + // fail initializing the cos. + err := c.Init(nil) + assert.Nil(t, err) +} + +func TestCos(t *testing.T) { + tests := []struct { + cos *Cos + backing []float32 + shape []int + expected []float32 + }{ + { + &Cos{}, + []float32{-2, -1, 0, 1}, + []int{2, 2}, + []float32{-0.41614684, 0.5403023, 1, 0.5403023}, + }, + { + &Cos{}, + []float32{1, 3, 4, 5}, + []int{1, 4}, + []float32{0.5403023, -0.9899925, -0.6536436, 0.2836622}, + }, + { + &Cos{}, + []float32{-1, -1, -1, -1}, + []int{1, 4}, + []float32{0.5403023, 0.5403023, 0.5403023, 0.5403023}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backing, test.shape...), + } + + res, err := test.cos.Apply(inputs) + assert.Nil(t, err) + + assert.Nil(t, err) + assert.Equal(t, test.expected, res[0].Data()) + } +} + +func TestInputValidationCos(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float64{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{}, + ops.ErrInvalidInputCount(0, &Cos{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int", &Cos{}), + }, + } + + for _, test := range tests { + cos := &Cos{} + validated, err := cos.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + assert.Equal(t, test.inputs, validated) + } + } +} diff --git a/ops/opset13/cosh.go b/ops/opset13/cosh.go new file mode 100644 index 0000000..cddb129 --- /dev/null +++ b/ops/opset13/cosh.go @@ -0,0 +1,75 @@ +package opset13 + +import ( + "math" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +// Cosh represents the ONNX cosh operator. +type Cosh struct{} + +// newCosh creates a new cosh operator. +func newCosh() ops.Operator { + return &Cosh{} +} + +// Init initializes the cosh operator. +func (c *Cosh) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the cosh operator. +func (c *Cosh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var ( + out tensor.Tensor + err error + ) + + switch inputs[0].Dtype() { + case tensor.Float32: + out, err = inputs[0].Apply(cosh[float32]) + case tensor.Float64: + out, err = inputs[0].Apply(cosh[float64]) + default: + return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), c) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (c *Cosh) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(c, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (c *Cosh) GetMinInputs() int { + return 1 +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (c *Cosh) GetMaxInputs() int { + return 1 +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (c *Cosh) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (c *Cosh) String() string { + return "cosh operator" +} + +func cosh[T ops.FloatType](x T) T { + return T(math.Cosh(float64(x))) +} diff --git a/ops/opset13/cosh_test.go b/ops/opset13/cosh_test.go new file mode 100644 index 0000000..3359ada --- /dev/null +++ b/ops/opset13/cosh_test.go @@ -0,0 +1,99 @@ +package opset13 + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestCoshInit(t *testing.T) { + c := &Cosh{} + + // since 'cosh' does not have any attributes we pass in nil. This should not + // fail initializing the cosh. + err := c.Init(nil) + assert.Nil(t, err) +} + +func TestCosh(t *testing.T) { + tests := []struct { + cosh *Cosh + backing []float32 + shape []int + expected []float32 + }{ + { + &Cosh{}, + []float32{-2, -1, 0, 1}, + []int{2, 2}, + []float32{3.7621956, 1.5430807, 1, 1.5430807}, + }, + { + &Cosh{}, + []float32{1, 3, 4, 5}, + []int{1, 4}, + []float32{1.5430807, 10.067662, 27.308233, 74.209946}, + }, + { + &Cosh{}, + []float32{-1, -1, -1, -1}, + []int{1, 4}, + []float32{1.5430807, 1.5430807, 1.5430807, 1.5430807}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backing, test.shape...), + } + + res, err := test.cosh.Apply(inputs) + assert.Nil(t, err) + + assert.Nil(t, err) + assert.Equal(t, test.expected, res[0].Data()) + } +} + +func TestInputValidationCosh(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float64{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{}, + ops.ErrInvalidInputCount(0, &Cosh{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int", &Cosh{}), + }, + } + + for _, test := range tests { + cosh := &Cosh{} + validated, err := cosh.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + assert.Equal(t, test.inputs, validated) + } + } +} diff --git a/ops/opset13/div.go b/ops/opset13/div.go index a047caf..e918e7f 100644 --- a/ops/opset13/div.go +++ b/ops/opset13/div.go @@ -1,11 +1,16 @@ package opset13 import ( - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" ) +const ( + MinDivInputs = 2 + MaxDivInputs = 2 +) + // Div represents the ONNX div operator. type Div struct{} @@ -15,23 +20,18 @@ func newDiv() ops.Operator { } // Init initializes the div operator. -func (d *Div) Init(attributes []*onnx.AttributeProto) error { +func (d *Div) Init(*onnx.NodeProto) error { return nil } // Apply applies the div operator. func (d *Div) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - in1, in2, err := ops.MultidirectionalBroadcast(inputs[0], inputs[1]) - if err != nil { - return nil, err - } - - out, err := tensor.Div(in1, in2) - if err != nil { - return nil, err - } - - return []tensor.Tensor{out}, nil + return ops.ApplyBinaryOperation( + inputs[0], + inputs[1], + ops.Div, + ops.MultidirectionalBroadcasting, + ) } // ValidateInputs validates the inputs that will be given to Apply for this operator. @@ -41,12 +41,12 @@ func (d *Div) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { // GetMinInputs returns the minimum number of input tensors this operator expects. func (d *Div) GetMinInputs() int { - return 2 + return MinDivInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (d *Div) GetMaxInputs() int { - return 2 + return MaxDivInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/div_test.go b/ops/opset13/div_test.go index 4a94e2c..06a4f45 100644 --- a/ops/opset13/div_test.go +++ b/ops/opset13/div_test.go @@ -1,11 +1,10 @@ package opset13 import ( - "fmt" "testing" + "github.com/advancedclimatesystems/gonnx/ops" "github.com/stretchr/testify/assert" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" "gorgonia.org/tensor" ) @@ -108,14 +107,14 @@ func TestInputValidationDiv(t *testing.T) { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - fmt.Errorf("div operator: expected 2 input tensors, got 1"), + ops.ErrInvalidInputCount(1, &Div{}), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), }, - fmt.Errorf("div operator: input 0 does not allow type int"), + ops.ErrInvalidInputType(0, "int", &Div{}), }, } @@ -124,6 +123,7 @@ func TestInputValidationDiv(t *testing.T) { validated, err := div.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } diff --git a/ops/opset13/equal.go b/ops/opset13/equal.go new file mode 100644 index 0000000..db888b8 --- /dev/null +++ b/ops/opset13/equal.go @@ -0,0 +1,61 @@ +package opset13 + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var ( + MinEqualInputs = 2 + MaxEqualInputs = 2 +) + +// Equal represents the ONNX equal operator. +type Equal struct{} + +// newEqual creates a new equal operator. +func newEqual() ops.Operator { + return &Equal{} +} + +// Init initializes the equal operator. +func (e *Equal) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the equal operator. +func (e *Equal) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ApplyBinaryOperation( + inputs[0], + inputs[1], + ops.Equal, + ops.MultidirectionalBroadcasting, + ) +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (e *Equal) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(e, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (e *Equal) GetMinInputs() int { + return MinEqualInputs +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (e *Equal) GetMaxInputs() int { + return MaxEqualInputs +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (e *Equal) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{ops.AllTypes, ops.AllTypes} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (e *Equal) String() string { + return "equal operator" +} diff --git a/ops/opset13/equal_test.go b/ops/opset13/equal_test.go new file mode 100644 index 0000000..9014e78 --- /dev/null +++ b/ops/opset13/equal_test.go @@ -0,0 +1,133 @@ +package opset13 + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestEqualInit(t *testing.T) { + e := &Equal{} + + // since 'equal' does not have any attributes we pass in nil. This should not + // fail initializing the equal. + err := e.Init(ops.EmptyNodeProto()) + assert.Nil(t, err) +} + +func TestEqual(t *testing.T) { + tests := []struct { + equal *Equal + backings [][]float32 + shapes [][]int + expected []bool + }{ + { + &Equal{}, + [][]float32{{0, 1, 2, 3}, {1, 1, 1, 1}}, + [][]int{{2, 2}, {2, 2}}, + []bool{false, true, false, false}, + }, + { + &Equal{}, + [][]float32{{0, 1, 2, 2, 4, 5}, {2, 2, 2, 2, 2, 2}}, + [][]int{{3, 2}, {3, 2}}, + []bool{false, false, true, true, false, false}, + }, + { + &Equal{}, + [][]float32{{0, 1}, {0, 1, 0, 1}}, + [][]int{{2}, {2, 2}}, + []bool{true, true, true, true}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backings[0], test.shapes[0]...), + ops.TensorWithBackingFixture(test.backings[1], test.shapes[1]...), + } + + res, err := test.equal.Apply(inputs) + assert.Nil(t, err) + + assert.Nil(t, err) + assert.Equal(t, test.expected, res[0].Data()) + } +} + +func TestInputValidationEqual(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint32{1, 2}, 2), + ops.TensorWithBackingFixture([]uint32{3, 4}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint64{1, 2}, 2), + ops.TensorWithBackingFixture([]uint64{3, 4}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + ops.TensorWithBackingFixture([]int32{3, 4}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int64{1, 2}, 2), + ops.TensorWithBackingFixture([]int64{3, 4}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{3, 4}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float64{1, 2}, 2), + ops.TensorWithBackingFixture([]float64{3, 4}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputCount(1, &Equal{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + ops.TensorWithBackingFixture([]int{3, 4}, 2), + }, + ops.ErrInvalidInputType(0, "int", &Equal{}), + }, + } + + for _, test := range tests { + equal := &Equal{} + validated, err := equal.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + assert.Equal(t, test.inputs, validated) + } + } +} diff --git a/ops/opset13/gather.go b/ops/opset13/gather.go index e278974..e6e7f3f 100644 --- a/ops/opset13/gather.go +++ b/ops/opset13/gather.go @@ -1,13 +1,16 @@ package opset13 import ( - "fmt" - - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" ) +const ( + MinGatherInputs = 2 + MaxGatherInputs = 2 +) + // Gather represents the ONNX gather operator. type Gather struct { axis int // axis to gather on, default is 0 @@ -15,24 +18,27 @@ type Gather struct { // newGather creates a new gather operator. func newGather() ops.Operator { - return &Gather{} + return &Gather{ + axis: 0, + } } // Init initializes the gather operator. -func (g *Gather) Init(attributes []*onnx.AttributeProto) error { - switch length := len(attributes); { - case length > 1: - return fmt.Errorf(ops.InvalidAttrCountErrTemplate, g, "0 or 1", len(attributes)) - case length == 1: +func (g *Gather) Init(n *onnx.NodeProto) error { + attributes := n.GetAttribute() + + if len(attributes) == 1 { attr := attributes[0] + if attr.GetName() == "axis" { g.axis = int(attr.GetI()) } else { - return fmt.Errorf(ops.UnknownAttributeErrTemplate, g, attr.GetName()) + return ops.ErrInvalidAttribute(attr.GetName(), g) } - default: - g.axis = 0 + } else if len(attributes) > 1 { + return ops.ErrInvalidAttributeCount(1, len(attributes), g) } + return nil } @@ -43,6 +49,7 @@ func (g *Gather) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { if err != nil { return nil, err } + indices := tensor.New(tensor.WithBacking(indicesData), tensor.WithShape(inputs[1].Shape()...)) data := inputs[0] @@ -50,8 +57,9 @@ func (g *Gather) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { // Make sure axis is in the correct range (according to the size of the data tensor) rank := len(data.Shape()) dataAxis := g.axis + if dataAxis < -rank || dataAxis > rank-1 { - return nil, fmt.Errorf(ops.AxisOutOfRangeErrTemplate, rank, rank, dataAxis) + return nil, ops.ErrAxisOutOfRange(rank, rank, dataAxis) } // Offset axis if a negative index is given. if dataAxis < 0 { @@ -62,9 +70,13 @@ func (g *Gather) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { // dimension which is selected by `axis`) axisDimSize := data.Shape()[dataAxis] if !ops.AllInRange(indicesData, -axisDimSize, axisDimSize-1) { - return nil, fmt.Errorf(ops.AxesNotAllInRangeErrTemplate, axisDimSize, axisDimSize) + return nil, ops.ErrNotAllAxesInRange(axisDimSize, axisDimSize) + } + + err = ops.OffsetTensorIfNegative(indices, axisDimSize) + if err != nil { + return nil, err } - ops.OffsetTensorIfNegative(indices, axisDimSize) // Make the shape of the output tensor os := insertWithReplace(indices.Shape(), data.Shape(), dataAxis) @@ -75,6 +87,7 @@ func (g *Gather) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { if err != nil { return nil, err } + return []tensor.Tensor{output}, nil } @@ -85,12 +98,12 @@ func (g *Gather) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) // GetMinInputs returns the minimum number of input tensors this operator expects. func (g *Gather) GetMinInputs() int { - return 2 + return MinGatherInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (g *Gather) GetMaxInputs() int { - return 2 + return MaxGatherInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes @@ -118,26 +131,33 @@ func (g *Gather) String() string { // Then output[i_{0}, ..., i_{q-1}, j_{0}, ..., j_{r-2}] = input[j_{0}, k, j_{1}, ..., j_{r-2}] // -------------------------- // where q: size of `indices` -// r: size of `data` -// i and j are here indices which should be iterated over. +// +// r: size of `data` +// i and j are here indices which should be iterated over. // // A simplified example of how i and j work in such a statement (not related to gather): // suppose x = [1, 2] and y = [4, 5], and we have statement: -// l = x[i_0] -// output[i_0, j_0] = y[j_0] - l +// +// l = x[i_0] +// output[i_0, j_0] = y[j_0] - l +// // This means, for each valid combination of (i_0, j_0) (in this case (0,0) (0,1), (1,0) (1,1) ) // we evaluate the expression, so: -// l = x[0] -> l = 1 -// output[0, 0] = y[0] - l -> output[0,0] = 4 - 1 = 3 -// l = x[0] -> l = 1 -// output[0, 1] = y[1] - l -> output[0,1] = 5 - 1 = 4 -// l = x[1] -> l = 2 -// output[1, 0] = y[0] - l -> output[1,0] = 4 - 2 = 2 -// l = x[1] -> l = 2 -// output[1, 1] = y[1] - l -> output[1,1] = 5 - 2 = 3 +// +// l = x[0] -> l = 1 +// output[0, 0] = y[0] - l -> output[0,0] = 4 - 1 = 3 +// l = x[0] -> l = 1 +// output[0, 1] = y[1] - l -> output[0,1] = 5 - 1 = 4 +// l = x[1] -> l = 2 +// output[1, 0] = y[0] - l -> output[1,0] = 4 - 2 = 2 +// l = x[1] -> l = 2 +// output[1, 1] = y[1] - l -> output[1,1] = 5 - 2 = 3 +// // so this results in: -// output = [ 3 4 ] -// [ 2 3 ] +// +// output = [ 3 4 ] +// [ 2 3 ] +// // ------------------------- // The implementation iterates over each element in 'indices', and k is extracted. // For each given k (and therefore also [i_0, ..., i_q-1]) we need to iterate over each combination @@ -145,13 +165,20 @@ func (g *Gather) String() string { // slicing to extract the blocks that we need to assign, and then pairwise assign them. func gather(out, data, indices tensor.Tensor, axis int) error { it := indices.Iterator() - for it.Reset(); !it.Done(); it.Next() { + it.Reset() + + for !it.Done() { coords := it.Coord() + at, err := indices.At(coords...) if err != nil { return err } - k := at.(int) + + k, ok := at.(int) + if !ok { + return ops.ErrTypeAssert("int", at) + } // Slice that selects `k` on the given axis. // Equivalent to: data[:, ... , :, k, :, ..., :], where `k` is on the index `axis` @@ -168,12 +195,18 @@ func gather(out, data, indices tensor.Tensor, axis int) error { for i, s := range coords { oslices[i+axis] = ops.NewSlicer(s) } + outputSlice, _ := out.Slice(oslices...) err = ops.PairwiseAssign(outputSlice, dataSlice) if err != nil { return err } + + _, err = it.Next() + if err != nil { + return err + } } return nil @@ -185,10 +218,11 @@ func gather(out, data, indices tensor.Tensor, axis int) error { // Example: // > a = [-1, -2, -3] // > x = [1, 2, 3, 4, 5, 6, 7] -// insertWithReplace(a, x, 3) -> [1, 2, 3, -1, -2, -3, 5, 6, 7] +// insertWithReplace(a, x, 3) -> [1, 2, 3, -1, -2, -3, 5, 6, 7]. func insertWithReplace(a, x []int, axis int) []int { y := append([]int{}, x[:axis]...) y = append(y, a...) + if axis+1 < len(x) { y = append(y, x[axis+1:]...) } diff --git a/ops/opset13/gather_test.go b/ops/opset13/gather_test.go index 5186dd2..e48925a 100644 --- a/ops/opset13/gather_test.go +++ b/ops/opset13/gather_test.go @@ -1,17 +1,18 @@ package opset13 import ( - "fmt" "testing" + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" "github.com/stretchr/testify/assert" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" "gorgonia.org/tensor" ) -func makeAxisProto(n int) []*onnx.AttributeProto { - return []*onnx.AttributeProto{{Name: "axis", I: int64(n)}} +func makeAxisProto(n int) *onnx.NodeProto { + return &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{{Name: "axis", I: int64(n)}}, + } } func TestGatherInit(t *testing.T) { @@ -24,21 +25,21 @@ func TestGatherInit(t *testing.T) { func TestGatherInitDefault(t *testing.T) { op := Gather{} - err := op.Init([]*onnx.AttributeProto{}) + err := op.Init(ops.EmptyNodeProto()) assert.NoError(t, err) assert.Equal(t, op.axis, 0) } func TestGatherInitTooManyAttrs(t *testing.T) { op := Gather{} - err := op.Init([]*onnx.AttributeProto{{Name: "axis"}, {Name: "default"}}) - assert.EqualError(t, err, "gather operator: expected 0 or 1 attributes, got 2") + err := op.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "axis"}, {Name: "default"}}}) + assert.EqualError(t, err, "gather operator attribute error: invalid count 2 expected 1") } func TestGatherInitInvalidAttrName(t *testing.T) { op := Gather{} - err := op.Init([]*onnx.AttributeProto{{Name: "axes"}}) // should be axis - assert.EqualError(t, err, "gather operator: unknown attribute: axes") + err := op.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "axes"}}}) // should be axis + assert.EqualError(t, err, "gather operator attribute error: invalid attribute axes") } func TestGather(t *testing.T) { @@ -63,53 +64,125 @@ func TestGather(t *testing.T) { // >>> np.take(x, i, axis=0).shape // Out: (1, 2) - {[]float32{1, 2, 3, 4}, []int{4}, - []int64{0}, []int{1}, 0, - []float32{1}, tensor.Shape([]int{1})}, + { + []float32{1, 2, 3, 4}, + []int{4}, + []int64{0}, + []int{1}, + 0, + []float32{1}, + tensor.Shape([]int{1}), + }, - {[]float32{1, 2, 3, 4}, []int{2, 2}, - []int64{0}, []int{1}, 0, - []float32{1, 2}, tensor.Shape([]int{1, 2})}, + { + []float32{1, 2, 3, 4}, + []int{2, 2}, + []int64{0}, + []int{1}, + 0, + []float32{1, 2}, + tensor.Shape([]int{1, 2}), + }, - {[]float32{1, 2, 3, 4}, []int{2, 2}, - []int64{0}, []int{1}, 1, - []float32{1, 3}, tensor.Shape([]int{2, 1})}, + { + []float32{1, 2, 3, 4}, + []int{2, 2}, + []int64{0}, + []int{1}, + 1, + []float32{1, 3}, + tensor.Shape([]int{2, 1}), + }, - {[]float32{1, 2, 3, 4}, []int{2, 2}, - []int64{0}, []int{1}, -1, - []float32{1, 3}, tensor.Shape([]int{2, 1})}, + { + []float32{1, 2, 3, 4}, + []int{2, 2}, + []int64{0}, + []int{1}, + -1, + []float32{1, 3}, + tensor.Shape([]int{2, 1}), + }, - {[]float32{1, 2, 3, 4}, []int{2, 2}, - []int64{1}, []int{1}, 1, - []float32{2, 4}, tensor.Shape([]int{2, 1})}, + { + []float32{1, 2, 3, 4}, + []int{2, 2}, + []int64{1}, + []int{1}, + 1, + []float32{2, 4}, + tensor.Shape([]int{2, 1}), + }, - {[]float32{1, 2, 3, 4}, []int{2, 2}, - []int64{0}, []int{1, 1}, 1, - []float32{1, 3}, tensor.Shape([]int{2, 1, 1})}, + { + []float32{1, 2, 3, 4}, + []int{2, 2}, + []int64{0}, + []int{1, 1}, + 1, + []float32{1, 3}, + tensor.Shape([]int{2, 1, 1}), + }, - {[]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, []int{3, 2, 2}, - []int64{0}, []int{1}, 2, - []float32{1, 3, 5, 7, 9, 11}, tensor.Shape([]int{3, 2, 1})}, + { + []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + []int{3, 2, 2}, + []int64{0}, + []int{1}, + 2, + []float32{1, 3, 5, 7, 9, 11}, + tensor.Shape([]int{3, 2, 1}), + }, - {[]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, []int{3, 2, 2}, - []int64{0}, []int{1}, 1, - []float32{1, 2, 5, 6, 9, 10}, tensor.Shape([]int{3, 1, 2})}, + { + []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + []int{3, 2, 2}, + []int64{0}, + []int{1}, + 1, + []float32{1, 2, 5, 6, 9, 10}, + tensor.Shape([]int{3, 1, 2}), + }, - {[]float32{1, 2, 3, 4, 5, 6, 7, 8, 9}, []int{3, 3}, - []int64{0, 2}, []int{1, 2}, 1, - []float32{1, 3, 4, 6, 7, 9}, tensor.Shape([]int{3, 1, 2})}, + { + []float32{1, 2, 3, 4, 5, 6, 7, 8, 9}, + []int{3, 3}, + []int64{0, 2}, + []int{1, 2}, + 1, + []float32{1, 3, 4, 6, 7, 9}, + tensor.Shape([]int{3, 1, 2}), + }, - {[]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, []int{3, 2, 2}, - []int64{-2}, []int{1}, 1, - []float32{1, 2, 5, 6, 9, 10}, tensor.Shape([]int{3, 1, 2})}, + { + []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + []int{3, 2, 2}, + []int64{-2}, + []int{1}, + 1, + []float32{1, 2, 5, 6, 9, 10}, + tensor.Shape([]int{3, 1, 2}), + }, - {[]float32{1, 2, 3, 4}, []int{4}, - []int64{-4}, []int{1}, 0, - []float32{1}, tensor.Shape([]int{1})}, + { + []float32{1, 2, 3, 4}, + []int{4}, + []int64{-4}, + []int{1}, + 0, + []float32{1}, + tensor.Shape([]int{1}), + }, - {[]float32{1, 2, 3, 4}, []int{2, 2}, - []int64{0}, []int{1}, -1, - []float32{1, 3}, tensor.Shape([]int{2, 1})}, + { + []float32{1, 2, 3, 4}, + []int{2, 2}, + []int64{0}, + []int{1}, + -1, + []float32{1, 3}, + tensor.Shape([]int{2, 1}), + }, } for _, test := range tests { @@ -130,7 +203,7 @@ func TestGather(t *testing.T) { func TestCombinedWithOtherOp(t *testing.T) { concat := &Concat{} - err := concat.Init([]*onnx.AttributeProto{{Name: "axis", I: 0}}) + err := concat.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "axis", I: 0}}}) assert.NoError(t, err) data0 := tensor.New(tensor.WithBacking([]int64{1}), tensor.WithShape(1)) @@ -170,7 +243,7 @@ func TestGatherAxesIndexOutOfRange(t *testing.T) { _, err = op.Apply([]tensor.Tensor{dataIn, indicesIn}) assert.Error(t, err) - assert.EqualError(t, err, "axis argument must be in the range -1 <= x < 1, was 1") + assert.EqualError(t, err, "axis out of range: axis argument must be in the range -1 <= x < 1, was 1") } func TestGatherIndexOutOfRange(t *testing.T) { @@ -181,7 +254,7 @@ func TestGatherIndexOutOfRange(t *testing.T) { _, err := op.Apply([]tensor.Tensor{dataIn, indicesIn}) assert.Error(t, err) - assert.EqualError(t, err, "all indices entries must be in the range -1 <= x < 1") + assert.EqualError(t, err, "axis out of range: all indices entries must be in the range -1 <= x < 1") } func TestInputValidationGather(t *testing.T) { @@ -205,14 +278,14 @@ func TestInputValidationGather(t *testing.T) { }, { []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, - fmt.Errorf("gather operator: expected 2 input tensors, got 1"), + ops.ErrInvalidInputCount(1, &Gather{}), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), ops.TensorWithBackingFixture([]float32{3, 4}, 2), }, - fmt.Errorf("gather operator: input 1 does not allow type float32"), + ops.ErrInvalidInputType(1, "float32", &Gather{}), }, } @@ -221,6 +294,7 @@ func TestInputValidationGather(t *testing.T) { validated, err := gather.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } diff --git a/ops/opset13/gemm.go b/ops/opset13/gemm.go index a3382a5..2db2a44 100644 --- a/ops/opset13/gemm.go +++ b/ops/opset13/gemm.go @@ -1,13 +1,16 @@ package opset13 import ( - "fmt" - - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" ) +const ( + MinGemmInputs = 2 + MaxGemmInputs = 3 +) + // Gemm represents the ONNX gemm operator. type Gemm struct { alpha float32 @@ -27,8 +30,8 @@ func newGemm() ops.Operator { } // Init initializes the Gemm operator based on the ModelProto attributes. -func (g *Gemm) Init(attributes []*onnx.AttributeProto) error { - for _, attr := range attributes { +func (g *Gemm) Init(n *onnx.NodeProto) error { + for _, attr := range n.GetAttribute() { switch attr.GetName() { case "alpha": g.alpha = attr.GetF() @@ -39,7 +42,7 @@ func (g *Gemm) Init(attributes []*onnx.AttributeProto) error { case "transB": g.transB = ops.Int64ToBool(attr.GetI()) default: - return fmt.Errorf(ops.UnknownAttributeErrTemplate, g, attr.GetName()) + return ops.ErrInvalidAttribute(attr.GetName(), g) } } @@ -49,6 +52,7 @@ func (g *Gemm) Init(attributes []*onnx.AttributeProto) error { // Apply applies the gemm operator on the given graph. func (g *Gemm) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { var err error + a := inputs[0] b := inputs[1] c := inputs[2] @@ -107,12 +111,12 @@ func (g *Gemm) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { // GetMinInputs returns the minimum number of input tensors this operator expects. func (g *Gemm) GetMinInputs() int { - return 2 + return MinGemmInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (g *Gemm) GetMaxInputs() int { - return 3 + return MaxGemmInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/gemm_test.go b/ops/opset13/gemm_test.go index 2766da2..37255d4 100644 --- a/ops/opset13/gemm_test.go +++ b/ops/opset13/gemm_test.go @@ -1,18 +1,17 @@ package opset13 import ( - "fmt" "testing" + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" "github.com/stretchr/testify/assert" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" "gorgonia.org/tensor" ) func TestGemmInit(t *testing.T) { gemm := Gemm{} - err := gemm.Init(GemmOnnxAttributeProtoFixture()) + err := gemm.Init(GemmOnnxNodeProtoFixture()) assert.Nil(t, err) assert.Equal(t, float32(10.0), gemm.alpha) @@ -23,9 +22,9 @@ func TestGemmInit(t *testing.T) { func TestGemmInitFail(t *testing.T) { gemm := &Gemm{} - err := gemm.Init([]*onnx.AttributeProto{{Name: "unknownAttribute"}}) + err := gemm.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "unknownAttribute"}}}) - expected := fmt.Errorf(ops.UnknownAttributeErrTemplate, gemm, "unknownAttribute") + expected := ops.ErrInvalidAttribute("unknownAttribute", gemm) assert.Equal(t, expected, err) } @@ -68,7 +67,8 @@ func TestGemm(t *testing.T) { { &Gemm{1, 1, false, false}, [][]int{{20, 4}, {4, 6}, {6}}, - []float32{84, 91, 98, 105, 112, 119, 228, 251, 274, + []float32{ + 84, 91, 98, 105, 112, 119, 228, 251, 274, 297, 320, 343, 372, 411, 450, 489, 528, 567, 516, 571, 626, 681, 736, 791, 660, 731, 802, 873, 944, 1015, 804, 891, 978, 1065, 1152, 1239, 948, 1051, 1154, 1257, @@ -80,7 +80,8 @@ func TestGemm(t *testing.T) { 2331, 2562, 2793, 3024, 3255, 2244, 2491, 2738, 2985, 3232, 3479, 2388, 2651, 2914, 3177, 3440, 3703, 2532, 2811, 3090, 3369, 3648, 3927, 2676, 2971, 3266, 3561, - 3856, 4151, 2820, 3131, 3442, 3753, 4064, 4375}, + 3856, 4151, 2820, 3131, 3442, 3753, 4064, 4375, + }, }, } @@ -94,6 +95,7 @@ func TestGemm(t *testing.T) { } else { inputs = append(inputs, nil) } + res, err := test.gemm.Apply(inputs) assert.Nil(t, err) @@ -132,7 +134,7 @@ func TestInputValidationGemm(t *testing.T) { { []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, nil, - fmt.Errorf("gemm operator: expected 2-3 input tensors, got 1"), + ops.ErrInvalidOptionalInputCount(1, &Gemm{}), }, { []tensor.Tensor{ @@ -142,7 +144,7 @@ func TestInputValidationGemm(t *testing.T) { ops.TensorWithBackingFixture([]uint32{1, 2}, 2), }, nil, - fmt.Errorf("gemm operator: expected 2-3 input tensors, got 4"), + ops.ErrInvalidOptionalInputCount(4, &Gemm{}), }, { []tensor.Tensor{ @@ -150,7 +152,7 @@ func TestInputValidationGemm(t *testing.T) { ops.TensorWithBackingFixture([]int{3, 4}, 2), }, nil, - fmt.Errorf("gemm operator: input 0 does not allow type int"), + ops.ErrInvalidInputType(0, "int", &Gemm{}), }, } @@ -159,6 +161,7 @@ func TestInputValidationGemm(t *testing.T) { validated, err := gemm.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { if test.expected != nil { assert.Equal(t, test.expected, validated) @@ -169,11 +172,13 @@ func TestInputValidationGemm(t *testing.T) { } } -func GemmOnnxAttributeProtoFixture() []*onnx.AttributeProto { - return []*onnx.AttributeProto{ - {Name: "alpha", F: 10.0}, - {Name: "beta", F: 0.98}, - {Name: "transA", I: 1}, - {Name: "transB", I: 1}, +func GemmOnnxNodeProtoFixture() *onnx.NodeProto { + return &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "alpha", F: 10.0}, + {Name: "beta", F: 0.98}, + {Name: "transA", I: 1}, + {Name: "transB", I: 1}, + }, } } diff --git a/ops/opset13/greater.go b/ops/opset13/greater.go new file mode 100644 index 0000000..37e5af4 --- /dev/null +++ b/ops/opset13/greater.go @@ -0,0 +1,61 @@ +package opset13 + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var ( + MinGreaterInputs = 2 + MaxGreaterInputs = 2 +) + +// Greater represents the ONNX greater operator. +type Greater struct{} + +// newGreater creates a new greater operator. +func newGreater() ops.Operator { + return &Greater{} +} + +// Init initializes the greater operator. +func (g *Greater) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the greater operator. +func (g *Greater) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ApplyBinaryOperation( + inputs[0], + inputs[1], + ops.Gt, + ops.MultidirectionalBroadcasting, + ) +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (g *Greater) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(g, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (g *Greater) GetMinInputs() int { + return MinGreaterInputs +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (g *Greater) GetMaxInputs() int { + return MaxGreaterInputs +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (g *Greater) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{ops.AllTypes, ops.AllTypes} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (g *Greater) String() string { + return "greater operator" +} diff --git a/ops/opset13/greater_or_equal.go b/ops/opset13/greater_or_equal.go new file mode 100644 index 0000000..25eb27b --- /dev/null +++ b/ops/opset13/greater_or_equal.go @@ -0,0 +1,61 @@ +package opset13 + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var ( + MinGreaterOrEqualInputs = 2 + MaxGreaterOrEqualInputs = 2 +) + +// GreaterOrEqual represents the ONNX greaterOrEqual operator. +type GreaterOrEqual struct{} + +// newGreaterOrEqual creates a new greaterOrEqual operator. +func newGreaterOrEqual() ops.Operator { + return &GreaterOrEqual{} +} + +// Init initializes the greaterOrEqual operator. +func (g *GreaterOrEqual) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the greaterOrEqual operator. +func (g *GreaterOrEqual) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ApplyBinaryOperation( + inputs[0], + inputs[1], + ops.Gte, + ops.MultidirectionalBroadcasting, + ) +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (g *GreaterOrEqual) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(g, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (g *GreaterOrEqual) GetMinInputs() int { + return MinGreaterOrEqualInputs +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (g *GreaterOrEqual) GetMaxInputs() int { + return MaxGreaterOrEqualInputs +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (g *GreaterOrEqual) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{ops.AllTypes, ops.AllTypes} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (g *GreaterOrEqual) String() string { + return "greaterOrEqual operator" +} diff --git a/ops/opset13/greater_or_equal_test.go b/ops/opset13/greater_or_equal_test.go new file mode 100644 index 0000000..37f5dec --- /dev/null +++ b/ops/opset13/greater_or_equal_test.go @@ -0,0 +1,133 @@ +package opset13 + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestGreaterOrEqualInit(t *testing.T) { + g := &GreaterOrEqual{} + + // since 'greaterOrEqual' does not have any attributes we pass in nil. This should not + // fail initializing the greaterOrEqual. + err := g.Init(ops.EmptyNodeProto()) + assert.Nil(t, err) +} + +func TestGreaterOrEqual(t *testing.T) { + tests := []struct { + greaterOrEqual *GreaterOrEqual + backings [][]float32 + shapes [][]int + expected []bool + }{ + { + &GreaterOrEqual{}, + [][]float32{{0, 1, 2, 3}, {1, 1, 1, 1}}, + [][]int{{2, 2}, {2, 2}}, + []bool{false, true, true, true}, + }, + { + &GreaterOrEqual{}, + [][]float32{{0, 1, 2, 3, 4, 5}, {2, 2, 2, 2, 2, 2}}, + [][]int{{3, 2}, {3, 2}}, + []bool{false, false, true, true, true, true}, + }, + { + &GreaterOrEqual{}, + [][]float32{{0, 1}, {0, 1, 2, 3}}, + [][]int{{2}, {2, 2}}, + []bool{true, true, false, false}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backings[0], test.shapes[0]...), + ops.TensorWithBackingFixture(test.backings[1], test.shapes[1]...), + } + + res, err := test.greaterOrEqual.Apply(inputs) + assert.Nil(t, err) + + assert.Nil(t, err) + assert.Equal(t, test.expected, res[0].Data()) + } +} + +func TestInputValidationGreaterOrEqual(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint32{1, 2}, 2), + ops.TensorWithBackingFixture([]uint32{3, 4}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint64{1, 2}, 2), + ops.TensorWithBackingFixture([]uint64{3, 4}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + ops.TensorWithBackingFixture([]int32{3, 4}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int64{1, 2}, 2), + ops.TensorWithBackingFixture([]int64{3, 4}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{3, 4}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float64{1, 2}, 2), + ops.TensorWithBackingFixture([]float64{3, 4}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputCount(1, &GreaterOrEqual{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + ops.TensorWithBackingFixture([]int{3, 4}, 2), + }, + ops.ErrInvalidInputType(0, "int", &GreaterOrEqual{}), + }, + } + + for _, test := range tests { + greaterOrEqual := &GreaterOrEqual{} + validated, err := greaterOrEqual.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + assert.Equal(t, test.inputs, validated) + } + } +} diff --git a/ops/opset13/greater_test.go b/ops/opset13/greater_test.go new file mode 100644 index 0000000..18bc294 --- /dev/null +++ b/ops/opset13/greater_test.go @@ -0,0 +1,133 @@ +package opset13 + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestGreaterInit(t *testing.T) { + g := &Greater{} + + // since 'greater' does not have any attributes we pass in nil. This should not + // fail initializing the greater. + err := g.Init(ops.EmptyNodeProto()) + assert.Nil(t, err) +} + +func TestGreater(t *testing.T) { + tests := []struct { + greater *Greater + backings [][]float32 + shapes [][]int + expected []bool + }{ + { + &Greater{}, + [][]float32{{0, 1, 2, 3}, {1, 1, 1, 1}}, + [][]int{{2, 2}, {2, 2}}, + []bool{false, false, true, true}, + }, + { + &Greater{}, + [][]float32{{0, 1, 2, 3, 4, 5}, {2, 2, 2, 2, 2, 2}}, + [][]int{{3, 2}, {3, 2}}, + []bool{false, false, false, true, true, true}, + }, + { + &Greater{}, + [][]float32{{0, 1}, {0, 1, 2, 3}}, + [][]int{{2}, {2, 2}}, + []bool{false, false, false, false}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backings[0], test.shapes[0]...), + ops.TensorWithBackingFixture(test.backings[1], test.shapes[1]...), + } + + res, err := test.greater.Apply(inputs) + assert.Nil(t, err) + + assert.Nil(t, err) + assert.Equal(t, test.expected, res[0].Data()) + } +} + +func TestInputValidationGreater(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint32{1, 2}, 2), + ops.TensorWithBackingFixture([]uint32{3, 4}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint64{1, 2}, 2), + ops.TensorWithBackingFixture([]uint64{3, 4}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + ops.TensorWithBackingFixture([]int32{3, 4}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int64{1, 2}, 2), + ops.TensorWithBackingFixture([]int64{3, 4}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{3, 4}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float64{1, 2}, 2), + ops.TensorWithBackingFixture([]float64{3, 4}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputCount(1, &Greater{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + ops.TensorWithBackingFixture([]int{3, 4}, 2), + }, + ops.ErrInvalidInputType(0, "int", &Greater{}), + }, + } + + for _, test := range tests { + greater := &Greater{} + validated, err := greater.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + assert.Equal(t, test.inputs, validated) + } + } +} diff --git a/ops/opset13/gru.go b/ops/opset13/gru.go index 42812ff..e5f3ff6 100644 --- a/ops/opset13/gru.go +++ b/ops/opset13/gru.go @@ -1,41 +1,67 @@ package opset13 import ( - "fmt" - - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" ) +const ( + MinGRUInputs = 3 + MaxGRUInputs = 6 +) + // GRU represents the ONNX gru operator. It only supports a simple forward gru // operation with default activations. type GRU struct { - // Number of neurons in the hidden state. - hiddenSize int - - // When computing the output of the hidden gate, apply the linear - // transformation before multiplying by the output of the reset gate. + activationAlpha []float32 + activationBeta []float32 + activations []string + direction ops.SequenceProcessDirection + hiddenSize int linearBeforeReset bool } // newGRU creates a new gru operator. func newGRU() ops.Operator { - return &GRU{} + return &GRU{ + activations: []string{"sigmoid", "tanh"}, + direction: ops.Forward, + linearBeforeReset: false, + } } // Init initializes the gru operator. Currently, our GRU operator does not support all // attributes as specified by the ONNX operator. The basic functionality is working and // the other attributes can be added later on. -func (g *GRU) Init(attributes []*onnx.AttributeProto) error { +func (g *GRU) Init(n *onnx.NodeProto) error { + attributes := n.GetAttribute() for _, attr := range attributes { switch attr.GetName() { - case "hidden_size": + case ops.ActivationAlphaAttr: + g.activationAlpha = attr.GetFloats() + case ops.ActivationBetaAttr: + g.activationBeta = attr.GetFloats() + case ops.ActivationsAttr: + activations := []string{} + for _, activation := range attr.GetStrings() { + activations = append(activations, string(activation)) + } + + g.activations = activations + case ops.ClipAttr: + return ops.ErrUnsupportedAttribute(attr.GetName(), g) + case ops.DirectionAttr: + g.direction = ops.SequenceProcessDirection(attr.GetS()) + if g.direction != ops.Forward { + return ops.ErrUnsupportedAttribute(attr.GetName(), g) + } + case ops.HiddenSizeAttr: g.hiddenSize = int(attr.GetI()) case "linear_before_reset": g.linearBeforeReset = ops.Int64ToBool(attr.GetI()) default: - return fmt.Errorf(ops.UnsupportedAttrErrTemplate, g, attr.GetName()) + return ops.ErrInvalidAttribute(attr.GetName(), g) } } @@ -44,30 +70,29 @@ func (g *GRU) Init(attributes []*onnx.AttributeProto) error { // Apply applies the gru operator. func (g *GRU) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - X := inputs[0] - W := inputs[1] - R := inputs[2] - B := inputs[3] if inputs[4] != nil { - return nil, fmt.Errorf("%v: sequence lens not yet supported as input", g) + return nil, ops.ErrUnsupportedInput("sequence lens", g) } - initialH := inputs[5] + X := inputs[0] seqLength := X.Shape()[0] batchSize := X.Shape()[1] - Wz, Wr, Wh, err := g.getForwardWeights(W) + Wz, Wr, Wh, err := g.getWeights(inputs[1]) if err != nil { return nil, err } - Rz, Rr, Rh, err := g.getRecurrentWeights(R) + Rz, Rr, Rh, err := g.getWeights(inputs[2]) if err != nil { return nil, err } + B := inputs[3] if B == nil { - B = g.initialB() + // 6 is the number of bias matrices required by ONNX definition. + nBiasMatrices := 6 + B = ops.ZeroTensor(1, nBiasMatrices*g.hiddenSize) } Wbz, Wbr, Wbh, Rbz, Rbr, Rbh, err := g.getBiases(B) @@ -75,39 +100,49 @@ func (g *GRU) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return nil, err } - var prevH tensor.Tensor - if initialH == nil { - prevH = g.initialH(batchSize) - } else { - prevH = initialH.Clone().(tensor.Tensor) + prevH := inputs[5] + if prevH == nil { + prevH = ops.ZeroTensor(1, batchSize, g.hiddenSize) } // Extract the shape of the hidden dimensions without the bidirectional dimension, as // we do not support bidirectional GRU yet. shapeWithoutBidir := prevH.Shape().Clone()[1:] + err = prevH.Reshape(shapeWithoutBidir...) if err != nil { return nil, err } + fActivation, err := ops.GetActivation(g.activations[0]) + if err != nil { + return nil, err + } + + gActivation, err := ops.GetActivation(g.activations[1]) + if gActivation == nil { + return nil, err + } + outputs := []tensor.Tensor{} + for i := 0; i < seqLength; i++ { Xt, err := g.extractXt(X, i) if err != nil { return nil, err } - zt, err := g.gateCalculation(Xt, prevH, Wz, Rz, Wbz, Rbz, ops.Sigmoid) + zt, err := g.gateCalculation(Xt, prevH, Wz, Rz, Wbz, Rbz, fActivation) if err != nil { return nil, err } - rt, err := g.gateCalculation(Xt, prevH, Wr, Rr, Wbr, Rbr, ops.Sigmoid) + rt, err := g.gateCalculation(Xt, prevH, Wr, Rr, Wbr, Rbr, fActivation) if err != nil { return nil, err } - ht, err := g.htCalculation(Xt, prevH, rt, Wh, Rh, Wbh, Rbh, ops.Tanh) + ht, err := g.htCalculation(Xt, prevH, rt, Wh, Rh, Wbh, Rbh, gActivation) if err != nil { return nil, err } @@ -136,12 +171,17 @@ func (g *GRU) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return nil, err } - Yh := prevH.Clone().(tensor.Tensor) + Yh, ok := prevH.Clone().(tensor.Tensor) + if !ok { + return nil, ops.ErrTypeAssert("tensor.Tensor", prevH.Clone()) + } + // Reshape the output so it adds the num_directions as specified by onnx. err = Yh.Reshape([]int{1, batchSize, g.hiddenSize}...) if err != nil { return nil, err } + return []tensor.Tensor{Y, Yh}, nil } @@ -152,12 +192,12 @@ func (g *GRU) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { // GetMinInputs returns the minimum number of input tensors this operator expects. func (g *GRU) GetMinInputs() int { - return 3 + return MinGRUInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (g *GRU) GetMaxInputs() int { - return 6 + return MaxGRUInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes @@ -186,8 +226,8 @@ func (g *GRU) extractXt(X tensor.Tensor, t int) (tensor.Tensor, error) { func (g *GRU) gateCalculation( Xt, H, W, R, Wb, Rb tensor.Tensor, activation ops.Activation, ) (tensor.Tensor, error) { - gemm := &Gemm{transA: false, transB: true, alpha: 1.0, beta: 1.0} + inputCalc, err := gemm.Apply([]tensor.Tensor{Xt, W, Wb}) if err != nil { return nil, err @@ -209,12 +249,12 @@ func (g *GRU) gateCalculation( func (g *GRU) htCalculation( Xt, prevH, rt, W, R, Wb, Rb tensor.Tensor, activation ops.Activation, ) (tensor.Tensor, error) { - if !g.linearBeforeReset { temp1, err := tensor.Mul(rt, prevH) if err != nil { return nil, err } + return g.gateCalculation(Xt, temp1, W, R, Wb, Rb, activation) } @@ -244,7 +284,7 @@ func (g *GRU) htCalculation( } func (g *GRU) hiddenCalculation(zt, ht, prevH tensor.Tensor) (tensor.Tensor, error) { - temp1, err := tensor.Sub(onesTensor(zt), zt) + temp1, err := tensor.Sub(ops.OnesTensor(zt), zt) if err != nil { return nil, err } @@ -262,102 +302,32 @@ func (g *GRU) hiddenCalculation(zt, ht, prevH tensor.Tensor) (tensor.Tensor, err return tensor.Add(temp2, temp3) } -// getForwardWeights returns the weights for the gate. -func (g *GRU) getForwardWeights(W tensor.Tensor) (Wz, Wr, Wh tensor.Tensor, err error) { - n, err := g.extractWeights(W) - if err != nil { - return nil, nil, nil, err - } - return n[0], n[1], n[2], nil -} +// getWeights splits tensor W into 3 weight matrices. +// The W tensor, by GONNX definition, has 3 dimensions with 3 weight +// tensors in it (6 if bidirectional, but that is not supported). +func (g *GRU) getWeights(W tensor.Tensor) (Wz, Wr, Wh tensor.Tensor, err error) { + nWeightMatrices := 3 + nWeightDimensions := 3 -// getRecurrentWeights returns recurrent weights. -func (g *GRU) getRecurrentWeights(R tensor.Tensor) (Rz, Rr, Rh tensor.Tensor, err error) { - recurrentWeights, err := g.extractWeights(R) + weights, err := ops.ExtractMatrices(W, nWeightMatrices, nWeightDimensions, g.hiddenSize) if err != nil { return nil, nil, nil, err } - return recurrentWeights[0], recurrentWeights[1], recurrentWeights[2], nil + + return weights[0], weights[1], weights[2], nil } // getBiases returns the biases from the Bias node as specified by the ONNX standard. +// The B tensor, by GONNX definition, has 2 dimensions with 6 bias +// tensors in it (12 if bidirectional, but that is not supported). func (g *GRU) getBiases(B tensor.Tensor) (Wbz, Wbr, Wbh, Rbz, Rbr, Rbh tensor.Tensor, err error) { - biases, err := g.extractBiases(B) + nBiasMatrices := 6 + nBiasDimensions := 2 + + biases, err := ops.ExtractMatrices(B, nBiasMatrices, nBiasDimensions, g.hiddenSize) if err != nil { return nil, nil, nil, nil, nil, nil, err } - return biases[0], biases[1], biases[2], biases[3], biases[4], biases[5], nil -} - -// extractWeights extracts 3 weight tensors from node W. -// W contains all 3 weight tensors concatenated on top of each other in the following order: -// forward weights: [Wz, Wr, Wh] -// recurrent weights: [Rz, Rr, Rh] -// -// W will have a shape of (num_directions, 3 * hidden_size, ...) and we extract the -// by slicing over the '3 * hidden_size' dimension. -func (g *GRU) extractWeights(W tensor.Tensor) ([]tensor.Tensor, error) { - dirSlice := ops.NewSlicer(0) - weights := make([]tensor.Tensor, 3) - - for i := 0; i < 3; i++ { - slice := ops.NewSlicer(i*g.hiddenSize, (i+1)*g.hiddenSize) - w, err := W.Slice(dirSlice, slice, nil) - if err != nil { - return nil, err - } - - weights[i] = w - } - return weights, nil -} -// extractBiases extracts the 6 bias tensors from tensor B. -// B contains all 6 bias tensors concatenated on top of each other in the following order: -// [Wbz, Wbr, Wbh, Rbz, Rbr, Rbh] -// B has a shape of (num_directions, 6 * hidden_size) and every individual bias tensor should have -// shape (hidden_size). We extract the biases by slicing over the '6 * hidden_size' dimension. -func (g *GRU) extractBiases(B tensor.Tensor) ([]tensor.Tensor, error) { - dirSlice := ops.NewSlicer(0) - biases := make([]tensor.Tensor, 7) - - for i := 0; i < 6; i++ { - slice := ops.NewSlicer(i*g.hiddenSize, (i+1)*g.hiddenSize) - w, err := B.Slice(dirSlice, slice) - if err != nil { - return nil, err - } - - biases[i] = w - } - return biases, nil -} - -// initialB returns the initialB tensor. If the biases are not specified by the inputs -// of the gru operator this tensor can be used as the biases tensor. By default biases -// are all 0. -func (g *GRU) initialB() tensor.Tensor { - return tensor.New( - tensor.WithShape(1, 6*g.hiddenSize), - tensor.WithBacking(ops.Zeros(6*g.hiddenSize)), - ) -} - -// initialH can be used for initialH when it is not specified by the inputs of the operator. -// if it is not specified by the inputs assumed to be 0. It has shape -// (num_directions, batch_size, hidden_size). -func (g *GRU) initialH(batchSize int) tensor.Tensor { - hiddenFloats := ops.Zeros(batchSize * g.hiddenSize) - return tensor.New( - tensor.WithShape(1, batchSize, g.hiddenSize), - tensor.WithBacking(hiddenFloats), - ) -} - -// onesTensor returns a new tensor with the same shape as the given tensor intialized with all ones. -func onesTensor(t tensor.Tensor) tensor.Tensor { - return tensor.New( - tensor.WithShape(t.Shape()...), - tensor.WithBacking(ops.Ones(ops.NElements(t.Shape()...))), - ) + return biases[0], biases[1], biases[2], biases[3], biases[4], biases[5], nil } diff --git a/ops/opset13/gru_test.go b/ops/opset13/gru_test.go index ccc648d..44140f9 100644 --- a/ops/opset13/gru_test.go +++ b/ops/opset13/gru_test.go @@ -1,22 +1,25 @@ package opset13 import ( - "fmt" "testing" + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" "github.com/stretchr/testify/assert" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" "gorgonia.org/tensor" ) func TestGruInit(t *testing.T) { gru := &GRU{} - err := gru.Init(GRUOnnxAttributeProtoFixture()) + err := gru.Init(GRUOnnxNodeProtoFixture()) assert.Nil(t, err) - assert.Equal(t, true, gru.linearBeforeReset) + assert.Equal(t, []float32{1.0}, gru.activationAlpha) + assert.Equal(t, []float32{2.0}, gru.activationBeta) + assert.Equal(t, []string{"sigmoid", "tanh"}, gru.activations) + assert.Equal(t, gru.direction, ops.Forward) assert.Equal(t, 5, gru.hiddenSize) + assert.Equal(t, true, gru.linearBeforeReset) } func TestGruInitUnkownAttr(t *testing.T) { @@ -25,34 +28,18 @@ func TestGruInitUnkownAttr(t *testing.T) { attr []*onnx.AttributeProto err error }{ - { - []*onnx.AttributeProto{{Name: "activation_alpha"}}, - fmt.Errorf(ops.UnsupportedAttrErrTemplate, &gru, "activation_alpha"), - }, - { - []*onnx.AttributeProto{{Name: "activation_beta"}}, - fmt.Errorf(ops.UnsupportedAttrErrTemplate, &gru, "activation_beta"), - }, - { - []*onnx.AttributeProto{{Name: "direction"}}, - fmt.Errorf(ops.UnsupportedAttrErrTemplate, &gru, "direction"), - }, { []*onnx.AttributeProto{{Name: "clip"}}, - fmt.Errorf(ops.UnsupportedAttrErrTemplate, &gru, "clip"), - }, - { - []*onnx.AttributeProto{{Name: "activation"}}, - fmt.Errorf(ops.UnsupportedAttrErrTemplate, &gru, "activation"), + ops.ErrUnsupportedAttribute("clip", &gru), }, { []*onnx.AttributeProto{{Name: "unknown"}}, - fmt.Errorf(ops.UnsupportedAttrErrTemplate, &gru, "unknown"), + ops.ErrInvalidAttribute("unknown", &gru), }, } for _, test := range tests { - err := gru.Init(test.attr) + err := gru.Init(&onnx.NodeProto{Attribute: test.attr}) assert.Equal(t, test.err, err) } } @@ -65,25 +52,53 @@ func TestGru(t *testing.T) { err error }{ { - &GRU{4, true}, + &GRU{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"sigmoid", "tanh"}, + direction: ops.Forward, + hiddenSize: 4, + linearBeforeReset: true, + }, gruInput0, []float32{6.6936556e-03, 8.3446503e-07, 0.0000000e+00, 0.0000000e+00}, nil, }, { - &GRU{4, false}, + &GRU{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"sigmoid", "tanh"}, + direction: ops.Forward, + hiddenSize: 4, + linearBeforeReset: false, + }, gruInput0, []float32{6.6936556e-03, 8.3446503e-07, 0.0000000e+00, 0.0000000e+00}, nil, }, { - &GRU{4, false}, + &GRU{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"sigmoid", "tanh"}, + direction: ops.Forward, + hiddenSize: 4, + linearBeforeReset: false, + }, gruInput1, []float32{0.44905475, 0.4406946, 0.43368173, 0.42782417}, nil, }, { - &GRU{4, false}, + &GRU{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"sigmoid", "tanh"}, + direction: ops.Forward, + hiddenSize: 4, + linearBeforeReset: false, + }, gruInputNoBNoH, []float32{0.24553154, 0.24553154, 0.24553154, 0.24553154}, nil, @@ -138,7 +153,7 @@ func TestInputValidationGRU(t *testing.T) { { []tensor.Tensor{ops.TensorWithBackingFixture([]float32{1, 2}, 2)}, nil, - fmt.Errorf("gru operator: expected 3-6 input tensors, got 1"), + ops.ErrInvalidOptionalInputCount(1, &GRU{}), }, { []tensor.Tensor{ @@ -147,7 +162,7 @@ func TestInputValidationGRU(t *testing.T) { ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, nil, - fmt.Errorf("gru operator: input 1 does not allow type int"), + ops.ErrInvalidInputType(1, "int", &GRU{}), }, { []tensor.Tensor{ @@ -156,7 +171,7 @@ func TestInputValidationGRU(t *testing.T) { ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, nil, - fmt.Errorf("gru operator: input 0 does not allow type int"), + ops.ErrInvalidInputType(0, "int", &GRU{}), }, { []tensor.Tensor{ @@ -165,7 +180,7 @@ func TestInputValidationGRU(t *testing.T) { ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, nil, - fmt.Errorf("gru operator: input 1 does not allow type int"), + ops.ErrInvalidInputType(1, "int", &GRU{}), }, { []tensor.Tensor{ @@ -174,7 +189,7 @@ func TestInputValidationGRU(t *testing.T) { ops.TensorWithBackingFixture([]int{1, 2}, 2), }, nil, - fmt.Errorf("gru operator: input 2 does not allow type int"), + ops.ErrInvalidInputType(2, "int", &GRU{}), }, { []tensor.Tensor{ @@ -184,7 +199,7 @@ func TestInputValidationGRU(t *testing.T) { ops.TensorWithBackingFixture([]int{1, 2}, 2), }, nil, - fmt.Errorf("gru operator: input 3 does not allow type int"), + ops.ErrInvalidInputType(3, "int", &GRU{}), }, { []tensor.Tensor{ @@ -195,7 +210,7 @@ func TestInputValidationGRU(t *testing.T) { ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, nil, - fmt.Errorf("gru operator: input 4 does not allow type float32"), + ops.ErrInvalidInputType(4, "float32", &GRU{}), }, { []tensor.Tensor{ @@ -207,7 +222,7 @@ func TestInputValidationGRU(t *testing.T) { ops.TensorWithBackingFixture([]int{1, 2}, 2), }, nil, - fmt.Errorf("gru operator: input 5 does not allow type int"), + ops.ErrInvalidInputType(5, "int", &GRU{}), }, } @@ -216,6 +231,7 @@ func TestInputValidationGRU(t *testing.T) { validated, err := gru.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { if test.expected != nil { assert.Equal(t, test.expected, validated) @@ -268,9 +284,15 @@ func gruInputNoBNoH() []tensor.Tensor { return inputs } -func GRUOnnxAttributeProtoFixture() []*onnx.AttributeProto { - return []*onnx.AttributeProto{ - {Name: "linear_before_reset", I: 1}, - {Name: "hidden_size", I: 5}, +func GRUOnnxNodeProtoFixture() *onnx.NodeProto { + return &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "activation_alpha", Floats: []float32{1.0}}, + {Name: "activation_beta", Floats: []float32{2.0}}, + {Name: "activations", Strings: [][]byte{[]byte("sigmoid"), []byte("tanh")}}, + {Name: "direction", S: []byte("forward")}, + {Name: "hidden_size", I: 5}, + {Name: "linear_before_reset", I: 1}, + }, } } diff --git a/ops/opset13/less.go b/ops/opset13/less.go new file mode 100644 index 0000000..d8e271d --- /dev/null +++ b/ops/opset13/less.go @@ -0,0 +1,61 @@ +package opset13 + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var ( + MinLessInputs = 2 + MaxLessInputs = 2 +) + +// Less represents the ONNX less operator. +type Less struct{} + +// newLess creates a new less operator. +func newLess() ops.Operator { + return &Less{} +} + +// Init initializes the less operator. +func (l *Less) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the less operator. +func (l *Less) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ApplyBinaryOperation( + inputs[0], + inputs[1], + ops.Lt, + ops.MultidirectionalBroadcasting, + ) +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (l *Less) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(l, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (l *Less) GetMinInputs() int { + return MinLessInputs +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (l *Less) GetMaxInputs() int { + return MaxLessInputs +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (l *Less) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{ops.AllTypes, ops.AllTypes} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (l *Less) String() string { + return "less operator" +} diff --git a/ops/opset13/less_or_equal.go b/ops/opset13/less_or_equal.go new file mode 100644 index 0000000..3fcb85f --- /dev/null +++ b/ops/opset13/less_or_equal.go @@ -0,0 +1,61 @@ +package opset13 + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var ( + MinLessOrEqualInputs = 2 + MaxLessOrEqualInputs = 2 +) + +// LessOrEqual represents the ONNX lessOrEqual operator. +type LessOrEqual struct{} + +// newLessOrEqual creates a new lessOrEqual operator. +func newLessOrEqual() ops.Operator { + return &LessOrEqual{} +} + +// Init initializes the lessOrEqual operator. +func (l *LessOrEqual) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the lessOrEqual operator. +func (l *LessOrEqual) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ApplyBinaryOperation( + inputs[0], + inputs[1], + ops.Lte, + ops.MultidirectionalBroadcasting, + ) +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (l *LessOrEqual) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(l, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (l *LessOrEqual) GetMinInputs() int { + return MinLessOrEqualInputs +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (l *LessOrEqual) GetMaxInputs() int { + return MaxLessOrEqualInputs +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (l *LessOrEqual) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{ops.AllTypes, ops.AllTypes} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (l *LessOrEqual) String() string { + return "lessOrEqual operator" +} diff --git a/ops/opset13/less_or_equal_test.go b/ops/opset13/less_or_equal_test.go new file mode 100644 index 0000000..fbba443 --- /dev/null +++ b/ops/opset13/less_or_equal_test.go @@ -0,0 +1,133 @@ +package opset13 + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestLessOrEqualInit(t *testing.T) { + l := &LessOrEqual{} + + // since 'lessOrEqual' does not have any attributes we pass in nil. This should not + // fail initializing the lessOrEqual. + err := l.Init(ops.EmptyNodeProto()) + assert.Nil(t, err) +} + +func TestLessOrEqual(t *testing.T) { + tests := []struct { + lessOrEqual *LessOrEqual + backings [][]float32 + shapes [][]int + expected []bool + }{ + { + &LessOrEqual{}, + [][]float32{{0, 1, 2, 3}, {1, 1, 1, 1}}, + [][]int{{2, 2}, {2, 2}}, + []bool{true, true, false, false}, + }, + { + &LessOrEqual{}, + [][]float32{{0, 1, 2, 3, 4, 5}, {2, 2, 2, 2, 2, 2}}, + [][]int{{3, 2}, {3, 2}}, + []bool{true, true, true, false, false, false}, + }, + { + &LessOrEqual{}, + [][]float32{{0, 1}, {0, 1, 2, 3}}, + [][]int{{2}, {2, 2}}, + []bool{true, true, true, true}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backings[0], test.shapes[0]...), + ops.TensorWithBackingFixture(test.backings[1], test.shapes[1]...), + } + + res, err := test.lessOrEqual.Apply(inputs) + assert.Nil(t, err) + + assert.Nil(t, err) + assert.Equal(t, test.expected, res[0].Data()) + } +} + +func TestInputValidationLessOrEqual(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint32{1, 2}, 2), + ops.TensorWithBackingFixture([]uint32{3, 4}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint64{1, 2}, 2), + ops.TensorWithBackingFixture([]uint64{3, 4}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + ops.TensorWithBackingFixture([]int32{3, 4}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int64{1, 2}, 2), + ops.TensorWithBackingFixture([]int64{3, 4}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{3, 4}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float64{1, 2}, 2), + ops.TensorWithBackingFixture([]float64{3, 4}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputCount(1, &LessOrEqual{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + ops.TensorWithBackingFixture([]int{3, 4}, 2), + }, + ops.ErrInvalidInputType(0, "int", &LessOrEqual{}), + }, + } + + for _, test := range tests { + lessOrEqual := &LessOrEqual{} + validated, err := lessOrEqual.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + assert.Equal(t, test.inputs, validated) + } + } +} diff --git a/ops/opset13/less_test.go b/ops/opset13/less_test.go new file mode 100644 index 0000000..a7a4036 --- /dev/null +++ b/ops/opset13/less_test.go @@ -0,0 +1,133 @@ +package opset13 + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestLessInit(t *testing.T) { + l := &Less{} + + // since 'less' does not have any attributes we pass in nil. This should not + // fail initializing the less. + err := l.Init(ops.EmptyNodeProto()) + assert.Nil(t, err) +} + +func TestLess(t *testing.T) { + tests := []struct { + less *Less + backings [][]float32 + shapes [][]int + expected []bool + }{ + { + &Less{}, + [][]float32{{0, 1, 2, 3}, {1, 1, 1, 1}}, + [][]int{{2, 2}, {2, 2}}, + []bool{true, false, false, false}, + }, + { + &Less{}, + [][]float32{{0, 1, 2, 3, 4, 5}, {2, 2, 2, 2, 2, 2}}, + [][]int{{3, 2}, {3, 2}}, + []bool{true, true, false, false, false, false}, + }, + { + &Less{}, + [][]float32{{0, 1}, {0, 1, 2, 3}}, + [][]int{{2}, {2, 2}}, + []bool{false, false, true, true}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backings[0], test.shapes[0]...), + ops.TensorWithBackingFixture(test.backings[1], test.shapes[1]...), + } + + res, err := test.less.Apply(inputs) + assert.Nil(t, err) + + assert.Nil(t, err) + assert.Equal(t, test.expected, res[0].Data()) + } +} + +func TestInputValidationLess(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint32{1, 2}, 2), + ops.TensorWithBackingFixture([]uint32{3, 4}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint64{1, 2}, 2), + ops.TensorWithBackingFixture([]uint64{3, 4}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + ops.TensorWithBackingFixture([]int32{3, 4}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int64{1, 2}, 2), + ops.TensorWithBackingFixture([]int64{3, 4}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{3, 4}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float64{1, 2}, 2), + ops.TensorWithBackingFixture([]float64{3, 4}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputCount(1, &Less{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + ops.TensorWithBackingFixture([]int{3, 4}, 2), + }, + ops.ErrInvalidInputType(0, "int", &Less{}), + }, + } + + for _, test := range tests { + less := &Less{} + validated, err := less.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + assert.Equal(t, test.inputs, validated) + } + } +} diff --git a/ops/opset13/linear_regressor.go b/ops/opset13/linear_regressor.go new file mode 100644 index 0000000..ceb0cb1 --- /dev/null +++ b/ops/opset13/linear_regressor.go @@ -0,0 +1,117 @@ +package opset13 + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +const ( + MinLinearRegressorInputs = 1 + MaxLinearRegressorInputs = 1 +) + +// PostTransformOption describes all possible post transform options for the +// linear regressor operator. +type postTransformOption string + +const ( + noTransform postTransformOption = "NONE" + softmaxTransform postTransformOption = "SOFTMAX" + logisticTransform postTransformOption = "LOGISTIC" + softmaxZeroTransform postTransformOption = "SOFTMAX_ZERO" + probitTransform postTransformOption = "PROBIT" +) + +// LinearRegressor represents the ONNX-ml linearRegressor operator. +type LinearRegressor struct { + coefficients tensor.Tensor + intercepts tensor.Tensor + postTransform postTransformOption + targets int +} + +// newLinearRegressor creates a new linearRegressor operator. +func newLinearRegressor() ops.Operator { + return &LinearRegressor{ + postTransform: noTransform, + targets: 1, + } +} + +// Init initializes the linearRegressor operator. +func (l *LinearRegressor) Init(n *onnx.NodeProto) error { + for _, attr := range n.GetAttribute() { + switch attr.GetName() { + case "coefficients": + floats := attr.GetFloats() + l.coefficients = tensor.New(tensor.WithShape(len(floats)), tensor.WithBacking(floats)) + case "intercepts": + floats := attr.GetFloats() + l.intercepts = tensor.New(tensor.WithShape(len(floats)), tensor.WithBacking(floats)) + case "post_transform": + return ops.ErrUnsupportedAttribute(attr.GetName(), l) + case "targets": + l.targets = int(attr.GetI()) + default: + return ops.ErrInvalidAttribute(attr.GetName(), l) + } + } + + err := l.coefficients.Reshape(l.targets, ops.NElements(l.coefficients.Shape()...)/l.targets) + if err != nil { + return err + } + + return l.coefficients.T() +} + +// Apply applies the linearRegressor operator. +func (l *LinearRegressor) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + X := inputs[0] + + result, err := tensor.MatMul(X, l.coefficients) + if err != nil { + return nil, err + } + + result, intercepts, err := ops.UnidirectionalBroadcast(result, l.intercepts) + if err != nil { + return nil, err + } + + Y, err := tensor.Add(result, intercepts) + if err != nil { + return nil, err + } + + return []tensor.Tensor{Y}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (l *LinearRegressor) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(l, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (l *LinearRegressor) GetMinInputs() int { + return MinLinearRegressorInputs +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (l *LinearRegressor) GetMaxInputs() int { + return MaxLinearRegressorInputs +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (l *LinearRegressor) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{ + {tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, + } +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (l *LinearRegressor) String() string { + return "linearRegressor operator" +} diff --git a/ops/opset13/linear_regressor_test.go b/ops/opset13/linear_regressor_test.go new file mode 100644 index 0000000..abaeb25 --- /dev/null +++ b/ops/opset13/linear_regressor_test.go @@ -0,0 +1,196 @@ +package opset13 + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestLinearRegressorInit(t *testing.T) { + linearRegressor := &LinearRegressor{} + err := linearRegressor.Init(LinearRegressorOnnxNodeProtoFixture()) + + assert.Nil(t, err) + assert.Equal(t, []float32{1.5, 2.5, 3.5}, linearRegressor.coefficients.Data()) + assert.Equal(t, []float32{0.5}, linearRegressor.intercepts.Data()) + assert.Equal(t, 1, linearRegressor.targets) +} + +func TestLinearRegressorInitFailUnsupportedAttribute(t *testing.T) { + linearRegressor := &LinearRegressor{} + err := linearRegressor.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "post_transform"}, {Name: "Another"}}}) + + expected := ops.ErrUnsupportedAttribute("post_transform", linearRegressor) + assert.Equal(t, expected, err) +} + +func TestLinearRegressorInitFailInvalidAttribute(t *testing.T) { + linearRegressor := &LinearRegressor{} + err := linearRegressor.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "much_invalid"}}}) + + expected := ops.ErrInvalidAttribute("much_invalid", linearRegressor) + assert.Equal(t, expected, err) +} + +func TestLinearRegressor(t *testing.T) { + tests := []struct { + attrs []*onnx.AttributeProto + shape []int + backing []float32 + expectedShape tensor.Shape + expectedBacking []float32 + description string + }{ + { + []*onnx.AttributeProto{ + {Name: "coefficients", Floats: []float32{-0.45977323}}, + {Name: "intercepts", Floats: []float32{0.21509616}}, + {Name: "targets", I: 1}, + }, + []int{1, 1}, + []float32{0.7777024}, + []int{1, 1}, + []float32{-0.14247058}, + "linear regressor with 1 input and 1 output variable, 1 sample", + }, + { + []*onnx.AttributeProto{ + {Name: "coefficients", Floats: []float32{-0.45977323}}, + {Name: "intercepts", Floats: []float32{0.21509616}}, + {Name: "targets", I: 1}, + }, + []int{5, 1}, + []float32{0.7777024, 0.23754121, 0.82427853, 0.9657492, 0.9726011}, + []int{5, 1}, + []float32{-0.14247058, 0.105881065, -0.16388504, -0.22892947, -0.23207982}, + "linear regressor with 1 input and 1 output variable, 5 samples", + }, + { + []*onnx.AttributeProto{ + {Name: "coefficients", Floats: []float32{0.24118852, 0.22617804, 0.27858477}}, + {Name: "intercepts", Floats: []float32{-0.43156273}}, + {Name: "targets", I: 1}, + }, + []int{1, 3}, + []float32{0.7777024, 0.23754121, 0.82427853}, + []int{1, 1}, + []float32{0.039368242}, + "linear regressor with 3 inputs and 1 output variable, 1 sample", + }, + { + []*onnx.AttributeProto{ + {Name: "coefficients", Floats: []float32{0.24118852, 0.22617804, 0.27858477}}, + {Name: "intercepts", Floats: []float32{-0.43156273}}, + {Name: "targets", I: 1}, + }, + []int{2, 3}, + []float32{0.7777024, 0.23754121, 0.82427853, 0.9657492, 0.9726011, 0.45344925}, + []int{2, 1}, + []float32{0.039368242, 0.14766997}, + "linear regressor with 3 inputs and 1 output variable, 2 samples", + }, + { + []*onnx.AttributeProto{ + {Name: "coefficients", Floats: []float32{ + 0.5384742, 0.36729308, 0.13292366, -0.03843413, + 0.28054297, -0.27832435, 0.4381632, 0.00726224, + -0.64418418, -0.35812317, 0.69767598, 0.12989015, + }}, + {Name: "intercepts", Floats: []float32{-0.37036705, -0.34072968, 0.05487297}}, + {Name: "targets", I: 3}, + }, + []int{1, 4}, + []float32{0.7777024, 0.23754121, 0.82427853, 0.9657492}, + []int{1, 3}, + []float32{0.20810121, 0.17951778, 0.16934107}, + "linear regressor with 4 input and 3 output variables, 1 samples", + }, + { + []*onnx.AttributeProto{ + {Name: "coefficients", Floats: []float32{ + 0.5384742, 0.36729308, 0.13292366, -0.03843413, + 0.28054297, -0.27832435, 0.4381632, 0.00726224, + -0.64418418, -0.35812317, 0.69767598, 0.12989015, + }}, + {Name: "intercepts", Floats: []float32{-0.37036705, -0.34072968, 0.05487297}}, + {Name: "targets", I: 3}, + }, + []int{2, 4}, + []float32{0.7777024, 0.23754121, 0.82427853, 0.9657492, 0.9726011, 0.45344925, 0.60904247, 0.7755265}, + []int{2, 3}, + []float32{0.20810121, 0.17951778, 0.16934107, 0.37105185, 0.0784128, -0.20840444}, + "linear regressor with 4 input and 3 output variables, 2 samples", + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backing, test.shape...), + } + + linearRegressor := newLinearRegressor() + err := linearRegressor.Init(&onnx.NodeProto{Attribute: test.attrs}) + assert.Nil(t, err, test.description) + + res, err := linearRegressor.Apply(inputs) + assert.Nil(t, err, test.description) + assert.Equal(t, test.expectedShape, res[0].Shape(), test.description) + assert.Equal(t, test.expectedBacking, res[0].Data(), test.description) + } +} + +func TestInputValidationLinearRegressor(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ops.TensorWithBackingFixture([]int32{1, 2}, 2)}, + nil, + }, + { + []tensor.Tensor{ops.TensorWithBackingFixture([]int64{1, 2}, 2)}, + nil, + }, + { + []tensor.Tensor{ops.TensorWithBackingFixture([]float32{1, 2}, 2)}, + nil, + }, + { + []tensor.Tensor{ops.TensorWithBackingFixture([]float64{1, 2}, 2)}, + nil, + }, + { + []tensor.Tensor{}, + ops.ErrInvalidInputCount(0, &LinearRegressor{}), + }, + { + []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, + ops.ErrInvalidInputType(0, "int", &LinearRegressor{}), + }, + } + + for _, test := range tests { + linearRegressor := &LinearRegressor{} + validated, err := linearRegressor.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + assert.Equal(t, test.inputs, validated) + } + } +} + +func LinearRegressorOnnxNodeProtoFixture() *onnx.NodeProto { + return &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "coefficients", Floats: []float32{1.5, 2.5, 3.5}}, + {Name: "intercepts", Floats: []float32{0.5}}, + {Name: "targets", I: 1}, + }, + } +} diff --git a/ops/opset13/lstm.go b/ops/opset13/lstm.go new file mode 100644 index 0000000..8b32b2a --- /dev/null +++ b/ops/opset13/lstm.go @@ -0,0 +1,413 @@ +package opset13 + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +const ( + MinLSTMInputs = 3 + MaxLSTMInputs = 8 +) + +// LSTM represents the ONNX lstm operator. +type LSTM struct { + activationAlpha []float32 + activationBeta []float32 + activations []string + direction ops.SequenceProcessDirection + hiddenSize int + inputForget bool + + outputs []string +} + +// newLSTM creates a new lstm operator. +func newLSTM() ops.Operator { + return &LSTM{ + activations: []string{"sigmoid", "tanh", "tanh"}, + direction: ops.Forward, + inputForget: false, + outputs: []string{"Y", "Y_h", "Y_c"}, + } +} + +// Init initializes the lstm operator. +func (l *LSTM) Init(n *onnx.NodeProto) error { + for _, attr := range n.GetAttribute() { + switch attr.GetName() { + case ops.ActivationAlphaAttr: + l.activationAlpha = attr.GetFloats() + case ops.ActivationBetaAttr: + l.activationBeta = attr.GetFloats() + case ops.ActivationsAttr: + activations := []string{} + for _, activation := range attr.GetStrings() { + activations = append(activations, string(activation)) + } + + l.activations = activations + case ops.ClipAttr: + return ops.ErrUnsupportedAttribute(attr.GetName(), l) + case ops.DirectionAttr: + l.direction = ops.SequenceProcessDirection(attr.GetS()) + if l.direction != ops.Forward { + return ops.ErrUnsupportedAttribute(attr.GetName(), l) + } + case ops.HiddenSizeAttr: + l.hiddenSize = int(attr.GetI()) + case "input_forget": + l.inputForget = attr.GetI() == 1 + default: + return ops.ErrInvalidAttribute(attr.GetName(), l) + } + } + + l.outputs = n.GetOutput() + + return nil +} + +// Apply applies the lstm operator. +func (l *LSTM) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + if inputs[4] != nil { + return nil, ops.ErrUnsupportedInput("sequence_lens", l) + } + + X := inputs[0] + seqLength := X.Shape()[0] + batchSize := X.Shape()[1] + + Wi, Wo, Wf, Wc, err := l.getWeights(inputs[1]) + if err != nil { + return nil, err + } + + Ri, Ro, Rf, Rc, err := l.getWeights(inputs[2]) + if err != nil { + return nil, err + } + + B := inputs[3] + if B == nil { + // 8 is the number of bias matrices required by ONNX definition. + nBiasMatrices := 8 + B = ops.ZeroTensor(1, nBiasMatrices*l.hiddenSize) + } + + Wbi, Wbo, Wbf, Wbc, Rbi, Rbo, Rbf, Rbc, err := l.getBiases(B) + if err != nil { + return nil, err + } + + Ht := inputs[5] + if Ht == nil { + Ht = ops.ZeroTensor(1, batchSize, l.hiddenSize) + } + + Ct := inputs[6] + if Ct == nil { + Ct = ops.ZeroTensor(1, batchSize, l.hiddenSize) + } + + var Pi, Po, Pf tensor.Tensor + + P := inputs[7] + if P != nil { + Pi, Po, Pf, err = l.getPeepholes(P) + if err != nil { + return nil, err + } + } + + // Reshape the hidden and cell tensor without the bidirectional dimension, as + // we do not support bidirectional yet. This is the dimension at + // index 0. + if err = Ht.Reshape(Ht.Shape().Clone()[1:]...); err != nil { + return nil, err + } + + if err = Ct.Reshape(Ct.Shape().Clone()[1:]...); err != nil { + return nil, err + } + + fActivation, err := ops.GetActivation(l.activations[0]) + if err != nil { + return nil, err + } + + gActivation, err := ops.GetActivation(l.activations[1]) + if gActivation == nil { + return nil, err + } + + hActivation, err := ops.GetActivation(l.activations[2]) + if err != nil { + return nil, err + } + + outputs := []tensor.Tensor{} + + // Loop over all timesteps of the input, applying the LSTM calculation to every + // timesteps while updating the hidden tensor. + for t := 0; t < seqLength; t++ { + Xt, err := X.Slice(ops.NewSlicer(t, t+1), nil, nil) + if err != nil { + return nil, err + } + + it, err := l.gateCalculation(Xt, Wi, Wbi, Ht, Ri, Rbi, Pi, Ct, fActivation) + if err != nil { + return nil, err + } + + ft, err := l.gateCalculation(Xt, Wf, Wbf, Ht, Rf, Rbf, Pf, Ct, fActivation) + if err != nil { + return nil, err + } + + ct, err := l.gateCalculation(Xt, Wc, Wbc, Ht, Rc, Rbc, nil, nil, gActivation) + if err != nil { + return nil, err + } + + Ct, err = l.cellCalculation(ft, it, ct, Ct) + if err != nil { + return nil, err + } + + ot, err := l.gateCalculation(Xt, Wo, Wbo, Ht, Ro, Rbo, Po, Ct, fActivation) + if err != nil { + return nil, err + } + + Ht, err = l.hiddenCalculation(ot, Ct, hActivation) + if err != nil { + return nil, err + } + + outputs = append(outputs, Ht) + } + + Y := outputs[0] + if len(outputs) > 1 { + Y, err = tensor.Concat(0, Y, outputs[1:]...) + if err != nil { + return nil, err + } + } + + Yh, ok := Ht.Clone().(tensor.Tensor) + if !ok { + return nil, ops.ErrTypeAssert("tensor.Tensor", Ht.Clone()) + } + + Yc, ok := Ct.Clone().(tensor.Tensor) + if !ok { + return nil, ops.ErrTypeAssert("tensor.Tensor", Ct.Clone()) + } + + // Reshape the hidden tensor without the bidirectional dimension, as + // we do not support bidirectional RNN yet. This is the dimension at + // index 0. + if err = Y.Reshape(seqLength, 1, batchSize, l.hiddenSize); err != nil { + return nil, err + } + + if err = Yh.Reshape(1, batchSize, l.hiddenSize); err != nil { + return nil, err + } + + if err = Yc.Reshape(1, batchSize, l.hiddenSize); err != nil { + return nil, err + } + + outputMap := map[string]tensor.Tensor{ + "Y": Y, "Y_h": Yh, "Y_c": Yc, + } + + result := []tensor.Tensor{} + for _, outputName := range l.outputs { + result = append(result, outputMap[outputName]) + } + + return result, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (l *LSTM) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(l, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (l *LSTM) GetMinInputs() int { + return MinLSTMInputs +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (l *LSTM) GetMaxInputs() int { + return MaxLSTMInputs +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (l *LSTM) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{ + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Int32}, + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + } +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (l *LSTM) String() string { + return "lstm operator" +} + +// gateCalculation performs a standard gate calculation for an LSTM gate defined as: +// +// o = f(Xt*(W^T) + Wb + H*(R^T) + Rb + P (.) C) +// +// Where: +// - 'f()' is an activation function +// - 'Xt' is the input tensor +// - 'W' is the input weight +// - 'Wb' is the input bias +// - 'H' is the hidden tensor +// - 'R' is the hidden weight tensor +// - 'Rb' is the hidden bias +// - 'P' are peephole weights (optional, can be nil) +// - 'C' is the cell state +// - '(.)' is element-wise multiplication +// +// 'o' is the result tensor that is returned. +// This calculation can be used for the forget gate, input gate, cell gate +// and output gate calculations. +func (l *LSTM) gateCalculation( + Xt, W, Wb, H, R, Rb, P, C tensor.Tensor, activation ops.Activation, +) (tensor.Tensor, error) { + gemm := &Gemm{transA: false, transB: true, alpha: 1.0, beta: 1.0} + + inputCalc, err := gemm.Apply([]tensor.Tensor{Xt, W, Wb}) + if err != nil { + return nil, err + } + + hiddenCalc, err := gemm.Apply([]tensor.Tensor{H, R, Rb}) + if err != nil { + return nil, err + } + + output, err := tensor.Add(inputCalc[0], hiddenCalc[0]) + if err != nil { + return nil, err + } + + if P != nil { + C, broadcastedP, err := ops.UnidirectionalBroadcast(C, P) + if err != nil { + return nil, err + } + + peepholeActivation, err := tensor.Mul(broadcastedP, C) + if err != nil { + return nil, err + } + + output, err = tensor.Add(output, peepholeActivation) + if err != nil { + return nil, err + } + } + + return activation(output) +} + +// cellCalculation performs the calculation of the LSTM cell update defined by: +// +// Ct = ft (.) Ct-1 + it (.) ct +// +// Where 'ft' is the forget gate activation at time t, (.) denotes element-wise +// multiplication, 'Ct-1' denotes the cell state at time t-1, 'it' denotes the input +// gate activation at time t and 'ct' denotes the cell state activation at time t (which) +// is not the same as Ct or Ct-1). +func (l *LSTM) cellCalculation(ft, it, ct, Ct tensor.Tensor) (tensor.Tensor, error) { + cellForget, err := tensor.Mul(ft, Ct) + if err != nil { + return nil, err + } + + cellInput, err := tensor.Mul(it, ct) + if err != nil { + return nil, err + } + + return tensor.Add(cellForget, cellInput) +} + +// hiddenCalculation performs the calculation of the new LSTM hidden state defined by: +// +// Ht = ot (.) h(Ct) +// +// Where Ht is the new hidden state at time t, 'ot' is the output at time t, (.) denotes +// element-wise multiplication, 'h()' denotes an activation function and 'Ct' denotes the +// cell state at time t. +func (l *LSTM) hiddenCalculation(ot, Ct tensor.Tensor, activation ops.Activation) (tensor.Tensor, error) { + cellActivated, err := activation(Ct) + if err != nil { + return nil, err + } + + return tensor.Mul(ot, cellActivated) +} + +// getWeights splits tensor W into 4 weight matrices. +// The W tensor, by GONNX definition, has 3 dimensions with 4 weight +// tensors in it (8 if bidirectional, but that is not supported). +func (l *LSTM) getWeights(W tensor.Tensor) (Wi, Wo, Wf, Wh tensor.Tensor, err error) { + nWeightMatrices := 4 + nWeightDimensions := 3 + + weights, err := ops.ExtractMatrices(W, nWeightMatrices, nWeightDimensions, l.hiddenSize) + if err != nil { + return nil, nil, nil, nil, err + } + + return weights[0], weights[1], weights[2], weights[3], nil +} + +// getBiases splits tensor B into 8 bias matrices. +// The B tensor, by GONNX definition, has 2 dimensions with 8 bias +// tensors in it (16 if bidirectional, but that is not supported). +func (l *LSTM) getBiases(B tensor.Tensor) (Wbi, Wbo, Wbf, Wbc, Rbi, Rbo, Rbf, Rbc tensor.Tensor, err error) { + nBiasMatrices := 8 + nBiasDimensions := 2 + + b, err := ops.ExtractMatrices(B, nBiasMatrices, nBiasDimensions, l.hiddenSize) + if err != nil { + return nil, nil, nil, nil, nil, nil, nil, nil, err + } + + return b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7], nil +} + +// getPeepholes splits tensor P into 3 bias matrices. +// The P tensor, by GONNX definition, has 2 dimensions with 3 peephole +// tensors in it (6 if bidirectional, but that is not supported). +func (l *LSTM) getPeepholes(P tensor.Tensor) (Pi, Po, Pf tensor.Tensor, err error) { + nPeepholeMatrices := 3 + nPeepholeDimensions := 2 + + p, err := ops.ExtractMatrices(P, nPeepholeMatrices, nPeepholeDimensions, l.hiddenSize) + if err != nil { + return nil, nil, nil, err + } + + return p[0], p[1], p[2], nil +} diff --git a/ops/opset13/lstm_test.go b/ops/opset13/lstm_test.go new file mode 100644 index 0000000..83bfc86 --- /dev/null +++ b/ops/opset13/lstm_test.go @@ -0,0 +1,386 @@ +package opset13 + +import ( + "math/rand" + "testing" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestLSTMInit(t *testing.T) { + lstm := &LSTM{} + err := lstm.Init(LSTMOnnxNodeProtoFixture()) + + assert.Nil(t, err) + assert.Equal(t, []float32{1.0}, lstm.activationAlpha) + assert.Equal(t, []float32{2.0}, lstm.activationBeta) + assert.Equal(t, []string{"sigmoid", "tanh", "relu"}, lstm.activations) + assert.Equal(t, ops.Forward, lstm.direction) + assert.Equal(t, 5, lstm.hiddenSize) + assert.Equal(t, false, lstm.inputForget) + assert.Equal(t, []string{"Y", "Y_h"}, lstm.outputs) +} + +func TestLSTMInitUnkownAttr(t *testing.T) { + lstm := LSTM{} + tests := []struct { + attr []*onnx.AttributeProto + err error + }{ + { + []*onnx.AttributeProto{{Name: "clip"}}, + ops.ErrUnsupportedAttribute("clip", &lstm), + }, + { + []*onnx.AttributeProto{{Name: "unknown"}}, + ops.ErrInvalidAttribute("unknown", &lstm), + }, + } + + for _, test := range tests { + err := lstm.Init(&onnx.NodeProto{Attribute: test.attr}) + assert.Equal(t, test.err, err) + } +} + +func TestLSTM(t *testing.T) { + tests := []struct { + lstm *LSTM + inputs ops.InputFixture + expected []float32 + err error + }{ + { + &LSTM{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"sigmoid", "tanh", "tanh"}, + direction: ops.Forward, + hiddenSize: 4, + outputs: []string{"Y", "Y_h", "Y_c"}, + }, + lstmInput0, + []float32{0.9159305, 0.9356764, 0.87070554, 0.84180677}, + nil, + }, + { + &LSTM{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"sigmoid", "tanh", "relu"}, + direction: ops.Forward, + hiddenSize: 4, + outputs: []string{"Y", "Y_h", "Y_c"}, + }, + lstmInput0, + []float32{1.7530097, 1.7829735, 1.6231446, 1.5197954}, + nil, + }, + { + &LSTM{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"sigmoid", "tanh", "relu"}, + direction: ops.Forward, + hiddenSize: 4, + outputs: []string{"Y", "Y_h", "Y_c"}, + }, + lstmInput1, + []float32{10.598255, 10.547241, 10.214846, 10.267471}, + nil, + }, + { + &LSTM{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"sigmoid", "tanh", "relu"}, + direction: ops.Forward, + hiddenSize: 4, + outputs: []string{"Y", "Y_h", "Y_c"}, + }, + lstmInputNoBNoH, + []float32{8.276371, 8.291079, 8.161418, 7.7900877}, + nil, + }, + { + &LSTM{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"sigmoid", "tanh", "tanh"}, + direction: ops.Forward, + hiddenSize: 4, + outputs: []string{"Y", "Y_h", "Y_c"}, + }, + lstmInputPeepholes, + []float32{0.99891853, 0.99994266, 0.9995524, 0.99171203}, + nil, + }, + } + + for _, test := range tests { + inputs := test.inputs() + res, err := test.lstm.Apply(inputs) + assert.Equal(t, test.err, err) + + if err == nil { + assert.Equal(t, test.expected, res[1].Data()) + } + } +} + +func TestInputValidationLSTM(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + expected []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + nil, + nil, + nil, + nil, + nil, + }, + nil, + }, + { + []tensor.Tensor{ops.TensorWithBackingFixture([]float32{1, 2}, 2)}, + nil, + ops.ErrInvalidOptionalInputCount(1, &LSTM{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(1, "int", &LSTM{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(0, "int", &LSTM{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(2, "int", &LSTM{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(3, "int", &LSTM{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(4, "float32", &LSTM{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(5, "int", &LSTM{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(6, "int", &LSTM{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(7, "int", &LSTM{}), + }, + } + + for _, test := range tests { + lstm := &LSTM{} + validated, err := lstm.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + if test.expected != nil { + assert.Equal(t, test.expected, validated) + } else { + assert.Equal(t, test.inputs, validated) + } + } + } +} + +func lstmInput0() []tensor.Tensor { + rand.Seed(10) + + return []tensor.Tensor{ + // Input X: (sequence_length, batch_size, input_size). + ops.RandomFloat32TensorFixture(2, 1, 3), + // Input W: (num_directions, 4 * hidden_size, input_size). + ops.RandomFloat32TensorFixture(1, 16, 3), + // Input R: (num_directions, 4 * hidden_size, hidden_size). + ops.RandomFloat32TensorFixture(1, 16, 4), + // Input B: (num_directions, 8 * hidden_size). + ops.RandomFloat32TensorFixture(1, 32), + // Input sequence_lens: not supported. + nil, + // Input initial_h: (num_directions, batch_size, hidden_size). + ops.TensorWithBackingFixture(ops.Zeros(4), 1, 1, 4), + // Input initial_c: (num_directions, batch_size, hidden_size). + ops.TensorWithBackingFixture(ops.Zeros(4), 1, 1, 4), + // Input P: peephole weights. + nil, + } +} + +func lstmInput1() []tensor.Tensor { + rand.Seed(11) + + return []tensor.Tensor{ + // Input X: (sequence_length, batch_size, input_size). + ops.RandomFloat32TensorFixture(10, 1, 3), + // Input W: (num_directions, 4 * hidden_size, input_size). + ops.RandomFloat32TensorFixture(1, 16, 3), + // Input R: (num_directions, 4 * hidden_size, hidden_size). + ops.RandomFloat32TensorFixture(1, 16, 4), + // Input B: (num_directions, 8 * hidden_size). + ops.RandomFloat32TensorFixture(1, 32), + // Input sequence_lens: not supported. + nil, + // Input initial_h: (num_directions, batch_size, hidden_size). + ops.RandomFloat32TensorFixture(1, 1, 4), + // Input initial_c: (num_directions, batch_size, hidden_size). + ops.RandomFloat32TensorFixture(1, 1, 4), + // Input P: peephole weights. + nil, + } +} + +func lstmInputNoBNoH() []tensor.Tensor { + rand.Seed(12) + + return []tensor.Tensor{ + // Input X: (sequence_length, batch_size, input_size). + ops.RandomFloat32TensorFixture(10, 1, 3), + // Input W: (num_directions, 4 * hidden_size, input_size). + ops.RandomFloat32TensorFixture(1, 16, 3), + // Input R: (num_directions, 4 * hidden_size, hidden_size). + ops.RandomFloat32TensorFixture(1, 16, 4), + // Input B. + nil, + // Input sequence_lens: not supported. + nil, + // Input initial_h. + nil, + // Input initial_c. + nil, + // Input P: peephole weights. + nil, + } +} + +func lstmInputPeepholes() []tensor.Tensor { + rand.Seed(13) + + return []tensor.Tensor{ + // Input X: (sequence_length, batch_size, input_size). + ops.RandomFloat32TensorFixture(10, 1, 3), + // Input W: (num_directions, 4 * hidden_size, input_size). + ops.RandomFloat32TensorFixture(1, 16, 3), + // Input R: (num_directions, 4 * hidden_size, hidden_size). + ops.RandomFloat32TensorFixture(1, 16, 4), + // Input B. + nil, + // Input sequence_lens: not supported. + nil, + // Input initial_h. + nil, + // Input initial_c. + nil, + // Input P: (num_directions, 3 * hidden_size). + ops.RandomFloat32TensorFixture(1, 12), + } +} + +func LSTMOnnxNodeProtoFixture() *onnx.NodeProto { + return &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "activation_alpha", Floats: []float32{1.0}}, + {Name: "activation_beta", Floats: []float32{2.0}}, + {Name: "activations", Strings: [][]byte{[]byte("sigmoid"), []byte("tanh"), []byte("relu")}}, + {Name: "direction", S: []byte("forward")}, + {Name: "hidden_size", I: 5}, + {Name: "input_forget", I: 0}, + }, + Output: []string{"Y", "Y_h"}, + } +} diff --git a/ops/opset13/matmul.go b/ops/opset13/matmul.go index 2cf7d0f..1212233 100644 --- a/ops/opset13/matmul.go +++ b/ops/opset13/matmul.go @@ -1,13 +1,16 @@ package opset13 import ( - "fmt" - - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" ) +const ( + MinMatMulInputs = 2 + MaxMatMulInputs = 2 +) + // MatMul represents the ONNX matmul operator. type MatMul struct{} @@ -17,7 +20,7 @@ func newMatMul() ops.Operator { } // Init initializes the matmul operator. -func (m *MatMul) Init(attributes []*onnx.AttributeProto) error { +func (m *MatMul) Init(*onnx.NodeProto) error { return nil } @@ -29,6 +32,10 @@ func (m *MatMul) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { // If both are normal matrices, apply normal matrix multiplication. if len(A.Shape()) == 2 && len(B.Shape()) == 2 { out, err := tensor.MatMul(A, B) + if err != nil { + return nil, err + } + return []tensor.Tensor{out}, err } @@ -36,22 +43,34 @@ func (m *MatMul) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { prependedDimension := false if len(A.Shape()) == 1 { prependedDimension = true - A = A.Clone().(tensor.Tensor) - err := A.Reshape(1, A.Shape()[0]) - if err != nil { + + reshapedA, ok := A.Clone().(tensor.Tensor) + if !ok { + return nil, ops.ErrTypeAssert("tensor.Tensor", A.Clone()) + } + + if err := reshapedA.Reshape(1, reshapedA.Shape()[0]); err != nil { return nil, err } + + A = reshapedA } // If B is a vector, promote to a matrix for the calculation. appendedDimension := false if len(B.Shape()) == 1 { appendedDimension = true - B = B.Clone().(tensor.Tensor) - err := B.Reshape(B.Shape()[0], 1) - if err != nil { + + reshapedB, ok := B.Clone().(tensor.Tensor) + if !ok { + return nil, ops.ErrTypeAssert("tensor.Tensor", B.Clone()) + } + + if err := reshapedB.Reshape(reshapedB.Shape()[0], 1); err != nil { return nil, err } + + B = reshapedB } // Now we have to perform batch matrix multiplication. First we need to broadcast @@ -71,8 +90,8 @@ func (m *MatMul) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { currentShape := out.Shape().Clone() newShape := currentShape[:len(currentShape)-2] newShape = append(newShape, currentShape[len(currentShape)-1]) - err = out.Reshape(newShape...) - if err != nil { + + if err := out.Reshape(newShape...); err != nil { return nil, err } } @@ -80,8 +99,8 @@ func (m *MatMul) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { if appendedDimension { currentShape := out.Shape().Clone() newShape := currentShape[:len(currentShape)-1] - err = out.Reshape(newShape...) - if err != nil { + + if err = out.Reshape(newShape...); err != nil { return nil, err } } @@ -96,12 +115,12 @@ func (m *MatMul) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) // GetMinInputs returns the minimum number of input tensors this operator expects. func (m *MatMul) GetMinInputs() int { - return 2 + return MinMatMulInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (m *MatMul) GetMaxInputs() int { - return 2 + return MaxMatMulInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes @@ -132,7 +151,10 @@ func (m *MatMul) broadcastTensors(A, B tensor.Tensor) (tensor.Tensor, tensor.Ten // want to broadcast those. All leading dimensions we do want to broadcast. shapeA := A.Shape() shapeB := B.Shape() - for axis := len(shapeA) - 3; axis >= 0; axis-- { + + nMatrixDims := 3 + + for axis := len(shapeA) - nMatrixDims; axis >= 0; axis-- { sizeDimA := shapeA[axis] sizeDimB := shapeB[axis] @@ -149,7 +171,7 @@ func (m *MatMul) broadcastTensors(A, B tensor.Tensor) (tensor.Tensor, tensor.Ten return nil, nil, err } default: - return nil, nil, fmt.Errorf("incompatible dimensions") + return nil, nil, ops.ErrIncompatibleDimensions() } } } @@ -163,6 +185,7 @@ func (m *MatMul) broadcastTensors(A, B tensor.Tensor) (tensor.Tensor, tensor.Ten func (m *MatMul) batchedMatMul(A, B tensor.Tensor) (tensor.Tensor, error) { shapeA := A.Shape() shapeB := B.Shape() + outerShape := append([]int{}, shapeA[:len(shapeA)-2]...) // This will be the shape of the output tensor. @@ -177,7 +200,9 @@ func (m *MatMul) batchedMatMul(A, B tensor.Tensor) (tensor.Tensor, error) { } var err error + var matrixA, matrixB, matrixOut tensor.Tensor + for { matrixA, err = A.Slice(slices...) if err != nil { @@ -224,6 +249,7 @@ func incrementSlices(slices []tensor.Slice, shape []int) bool { slices[i] = ops.NewSlicer(0) // Else we start again for this dimension. } else { slices[i] = ops.NewSlicer(dimSliceStart + 1) + return true } } diff --git a/ops/opset13/matmul_test.go b/ops/opset13/matmul_test.go index 3926836..fa8dcc2 100644 --- a/ops/opset13/matmul_test.go +++ b/ops/opset13/matmul_test.go @@ -1,11 +1,10 @@ package opset13 import ( - "fmt" "testing" + "github.com/advancedclimatesystems/gonnx/ops" "github.com/stretchr/testify/assert" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" "gorgonia.org/tensor" ) @@ -84,7 +83,7 @@ func TestMatMul(t *testing.T) { }, } - for _, test := range tests { + for i, test := range tests { matmul := &MatMul{} inputs := []tensor.Tensor{ ops.TensorWithBackingFixture(test.backings[0], test.shapes[0]...), @@ -93,7 +92,7 @@ func TestMatMul(t *testing.T) { res, err := matmul.Apply(inputs) assert.Nil(t, err) - assert.Equal(t, test.expected, res[0].Data()) + assert.Equal(t, test.expected, res[0].Data(), "test number %d", i) assert.Equal(t, test.expectedShape, res[0].Shape()) } } @@ -179,14 +178,14 @@ func TestInputValidationMatMul(t *testing.T) { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - fmt.Errorf("matmul operator: expected 2 input tensors, got 1"), + ops.ErrInvalidInputCount(1, &MatMul{}), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), }, - fmt.Errorf("matmul operator: input 0 does not allow type int"), + ops.ErrInvalidInputType(0, "int", &MatMul{}), }, } @@ -195,6 +194,7 @@ func TestInputValidationMatMul(t *testing.T) { validated, err := matmul.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } diff --git a/ops/opset13/mul.go b/ops/opset13/mul.go index a9b815d..3d4db10 100644 --- a/ops/opset13/mul.go +++ b/ops/opset13/mul.go @@ -1,11 +1,16 @@ package opset13 import ( - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" ) +const ( + MinMulInputs = 2 + MaxMulInputs = 2 +) + // Mul represents the ONNX mul operator. type Mul struct{} @@ -15,23 +20,18 @@ func newMul() ops.Operator { } // Init initializes the mul operator. -func (m *Mul) Init(attributes []*onnx.AttributeProto) error { +func (m *Mul) Init(*onnx.NodeProto) error { return nil } // Apply applies the mul operator. func (m *Mul) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - in1, in2, err := ops.MultidirectionalBroadcast(inputs[0], inputs[1]) - if err != nil { - return nil, err - } - - out, err := tensor.Mul(in1, in2) - if err != nil { - return nil, err - } - - return []tensor.Tensor{out}, nil + return ops.ApplyBinaryOperation( + inputs[0], + inputs[1], + ops.Mul, + ops.MultidirectionalBroadcasting, + ) } // ValidateInputs validates the inputs that will be given to Apply for this operator. @@ -41,12 +41,12 @@ func (m *Mul) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { // GetMinInputs returns the minimum number of input tensors this operator expects. func (m *Mul) GetMinInputs() int { - return 2 + return MinMulInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (m *Mul) GetMaxInputs() int { - return 2 + return MaxMulInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/mul_test.go b/ops/opset13/mul_test.go index 822333c..e6d00e4 100644 --- a/ops/opset13/mul_test.go +++ b/ops/opset13/mul_test.go @@ -1,11 +1,10 @@ package opset13 import ( - "fmt" "testing" + "github.com/advancedclimatesystems/gonnx/ops" "github.com/stretchr/testify/assert" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" "gorgonia.org/tensor" ) @@ -70,12 +69,7 @@ func TestMulFail(t *testing.T) { assert.Equal( t, err, - fmt.Errorf( - ops.MultidirBroadcastErrTemplate, - []int{2, 2}, - []int{3}, - "incompatible dimensions", - ), + ops.ErrMultidirBroadcast([]int{2, 2}, []int{3}, ops.ErrIncompatibleDimensions()), ) } @@ -130,14 +124,14 @@ func TestInputValidationMul(t *testing.T) { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - fmt.Errorf("mul operator: expected 2 input tensors, got 1"), + ops.ErrInvalidInputCount(1, &Mul{}), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), }, - fmt.Errorf("mul operator: input 0 does not allow type int"), + ops.ErrInvalidInputType(0, "int", &Mul{}), }, } @@ -146,6 +140,7 @@ func TestInputValidationMul(t *testing.T) { validated, err := mul.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } diff --git a/ops/opset13/not.go b/ops/opset13/not.go new file mode 100644 index 0000000..ba69c56 --- /dev/null +++ b/ops/opset13/not.go @@ -0,0 +1,60 @@ +package opset13 + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +// Not represents the ONNX not operator. +type Not struct{} + +// newNot creates a new not operator. +func newNot() ops.Operator { + return &Not{} +} + +// Init initializes the not operator. +func (n *Not) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the not operator. +func (n *Not) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + out, err := inputs[0].Apply(not) + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (n *Not) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(n, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (n *Not) GetMinInputs() int { + return 1 +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (n *Not) GetMaxInputs() int { + return 1 +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (n *Not) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{{tensor.Bool}} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (n *Not) String() string { + return "not operator" +} + +func not(x bool) bool { + return !x +} diff --git a/ops/opset13/not_test.go b/ops/opset13/not_test.go new file mode 100644 index 0000000..6069622 --- /dev/null +++ b/ops/opset13/not_test.go @@ -0,0 +1,93 @@ +package opset13 + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestNotInit(t *testing.T) { + n := &Not{} + + // since 'not' does not have any attributes we pass in nil. This should not + // fail initializing the not. + err := n.Init(nil) + assert.Nil(t, err) +} + +func TestNot(t *testing.T) { + tests := []struct { + not *Not + backing []bool + shape []int + expected []bool + }{ + { + &Not{}, + []bool{true, false, true, false}, + []int{2, 2}, + []bool{false, true, false, true}, + }, + { + &Not{}, + []bool{true, true, false, false}, + []int{1, 4}, + []bool{false, false, true, true}, + }, + { + &Not{}, + []bool{false, false, false, false}, + []int{4, 1}, + []bool{true, true, true, true}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backing, test.shape...), + } + + res, err := test.not.Apply(inputs) + assert.Nil(t, err) + + assert.Nil(t, err) + assert.Equal(t, test.expected, res[0].Data()) + } +} + +func TestInputValidationNot(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]bool{false, false}, 2), + }, + nil, + }, + { + []tensor.Tensor{}, + ops.ErrInvalidInputCount(0, &Not{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int", &Not{}), + }, + } + + for _, test := range tests { + not := &Not{} + validated, err := not.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + assert.Equal(t, test.inputs, validated) + } + } +} diff --git a/ops/opset13/opset13.go b/ops/opset13/opset13.go index d3918d9..636f498 100644 --- a/ops/opset13/opset13.go +++ b/ops/opset13/opset13.go @@ -1,49 +1,77 @@ package opset13 import ( - "fmt" - - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" + "github.com/advancedclimatesystems/gonnx/ops" ) var operators13 = map[string]func() ops.Operator{ + "Abs": newAbs, + "Acos": newAcos, + "Acosh": newAcosh, "Add": newAdd, + "And": newAnd, + "Asin": newAsin, + "Asinh": newAsinh, + "Atan": newAtan, + "Atanh": newAtanh, "Cast": newCast, "Concat": newConcat, "Constant": newConstant, "ConstantOfShape": newConstantOfShape, + "Conv": newConv, + "Cos": newCos, + "Cosh": newCosh, "Div": newDiv, + "Equal": newEqual, "Gather": newGather, "Gemm": newGemm, + "Greater": newGreater, + "GreaterOrEqual": newGreaterOrEqual, "GRU": newGRU, + "Less": newLess, + "LessOrEqual": newLessOrEqual, + "LinearRegressor": newLinearRegressor, + "LSTM": newLSTM, "MatMul": newMatMul, "Mul": newMul, + "Not": newNot, + "Or": newOr, + "PRelu": newPRelu, "Relu": newRelu, "Reshape": newReshape, + "RNN": newRNN, "Scaler": newScaler, "Shape": newShape, "Sigmoid": newSigmoid, + "Sin": newSin, + "Sinh": newSinh, "Slice": newSlice, + "Softmax": newSoftmax, "Squeeze": newSqueeze, "Sub": newSub, + "Tan": newTan, "Tanh": newTanh, "Transpose": newTranspose, "Unsqueeze": newUnsqueeze, + "Xor": newXor, } // GetOperator maps strings as found in the ModelProto to Operators from opset 13. -func GetOperator(opType string) (ops.Operator, error) { - if opInit, ok := operators13[opType]; ok { +func GetOperator(operatorType string) (ops.Operator, error) { + if opInit, ok := operators13[operatorType]; ok { return opInit(), nil } - return nil, fmt.Errorf(ops.UnknowOpTypeErrTemplate, opType) + + return nil, ops.ErrUnknownOperatorType(operatorType) } // GetOpNames returns a list with operator names for opset 13. func GetOpNames() []string { - var opList []string + opList := make([]string, 0, len(operators13)) + for opName := range operators13 { opList = append(opList, opName) } + return opList } diff --git a/ops/opset13/opset13_test.go b/ops/opset13/opset13_test.go index e82e1b8..aabd0a9 100644 --- a/ops/opset13/opset13_test.go +++ b/ops/opset13/opset13_test.go @@ -1,11 +1,10 @@ package opset13 import ( - "fmt" "testing" + "github.com/advancedclimatesystems/gonnx/ops" "github.com/stretchr/testify/assert" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" ) func TestGetOperator(t *testing.T) { @@ -14,11 +13,51 @@ func TestGetOperator(t *testing.T) { expected ops.Operator err error }{ + { + "Abs", + newAbs(), + nil, + }, + { + "Acos", + newAcos(), + nil, + }, + { + "Acosh", + newAcosh(), + nil, + }, { "Add", newAdd(), nil, }, + { + "And", + newAnd(), + nil, + }, + { + "Atan", + newAtan(), + nil, + }, + { + "Atanh", + newAtanh(), + nil, + }, + { + "Asin", + newAsin(), + nil, + }, + { + "Asinh", + newAsinh(), + nil, + }, { "Cast", newCast(), @@ -39,11 +78,31 @@ func TestGetOperator(t *testing.T) { newConstantOfShape(), nil, }, + { + "Conv", + newConv(), + nil, + }, + { + "Cos", + newCos(), + nil, + }, + { + "Cosh", + newCosh(), + nil, + }, { "Div", newDiv(), nil, }, + { + "Equal", + newEqual(), + nil, + }, { "Gather", newGather(), @@ -54,11 +113,26 @@ func TestGetOperator(t *testing.T) { newGemm(), nil, }, + { + "Greater", + newGreater(), + nil, + }, + { + "GreaterOrEqual", + newGreaterOrEqual(), + nil, + }, { "GRU", newGRU(), nil, }, + { + "LSTM", + newLSTM(), + nil, + }, { "MatMul", newMatMul(), @@ -69,6 +143,21 @@ func TestGetOperator(t *testing.T) { newMul(), nil, }, + { + "Not", + newNot(), + nil, + }, + { + "Or", + newOr(), + nil, + }, + { + "PRelu", + newPRelu(), + nil, + }, { "Relu", newRelu(), @@ -79,6 +168,11 @@ func TestGetOperator(t *testing.T) { newReshape(), nil, }, + { + "RNN", + newRNN(), + nil, + }, { "Scaler", newScaler(), @@ -94,11 +188,26 @@ func TestGetOperator(t *testing.T) { newSigmoid(), nil, }, + { + "Sin", + newSin(), + nil, + }, + { + "Sinh", + newSinh(), + nil, + }, { "Slice", newSlice(), nil, }, + { + "Softmax", + newSoftmax(), + nil, + }, { "Squeeze", newSqueeze(), @@ -109,6 +218,11 @@ func TestGetOperator(t *testing.T) { newSub(), nil, }, + { + "Tan", + newTan(), + nil, + }, { "Tanh", newTanh(), @@ -124,10 +238,15 @@ func TestGetOperator(t *testing.T) { newUnsqueeze(), nil, }, + { + "Xor", + newXor(), + nil, + }, { "NotYetImplemented", nil, - fmt.Errorf(ops.UnknowOpTypeErrTemplate, "NotYetImplemented"), + ops.ErrUnknownOperatorType("NotYetImplemented"), }, } diff --git a/ops/opset13/or.go b/ops/opset13/or.go new file mode 100644 index 0000000..f660891 --- /dev/null +++ b/ops/opset13/or.go @@ -0,0 +1,61 @@ +package opset13 + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var ( + MinOrInputs = 2 + MaxOrInputs = 2 +) + +// Or represents the ONNX or operator. +type Or struct{} + +// newOr creates a new or operator. +func newOr() ops.Operator { + return &Or{} +} + +// Init initializes the or operator. +func (o *Or) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the or operator. +func (o *Or) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ApplyBinaryOperation( + inputs[0], + inputs[1], + ops.Or, + ops.MultidirectionalBroadcasting, + ) +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (o *Or) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(o, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (o *Or) GetMinInputs() int { + return MinOrInputs +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (o *Or) GetMaxInputs() int { + return MaxOrInputs +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (o *Or) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{{tensor.Bool}, {tensor.Bool}} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (o *Or) String() string { + return "or operator" +} diff --git a/ops/opset13/or_test.go b/ops/opset13/or_test.go new file mode 100644 index 0000000..1c370a2 --- /dev/null +++ b/ops/opset13/or_test.go @@ -0,0 +1,104 @@ +package opset13 + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestOrInit(t *testing.T) { + o := &Or{} + + // since 'or' does not have any attributes we pass in nil. This should not + // fail initializing the or. + err := o.Init(nil) + assert.Nil(t, err) +} + +func TestOr(t *testing.T) { + tests := []struct { + or *Or + backings [][]bool + shapes [][]int + expected []bool + }{ + { + &Or{}, + [][]bool{{true, false, true, false}, {true, true, true, false}}, + [][]int{{2, 2}, {2, 2}}, + []bool{true, true, true, false}, + }, + { + &Or{}, + [][]bool{{true, false, true, false}, {true, false}}, + [][]int{{2, 2}, {1, 2}}, + []bool{true, false, true, false}, + }, + { + &Or{}, + [][]bool{{true, false, true, false}, {true, false}}, + [][]int{{2, 2}, {2, 1}}, + []bool{true, true, true, false}, + }, + { + &Or{}, + [][]bool{{true, false, true, false, true, false}, {false, false}}, + [][]int{{3, 2}, {1, 2}}, + []bool{true, false, true, false, true, false}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backings[0], test.shapes[0]...), + ops.TensorWithBackingFixture(test.backings[1], test.shapes[1]...), + } + + res, err := test.or.Apply(inputs) + assert.Nil(t, err) + + assert.Nil(t, err) + assert.Equal(t, test.expected, res[0].Data()) + } +} + +func TestInputValidationOr(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]bool{false, false}, 2), + ops.TensorWithBackingFixture([]bool{false, false}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]bool{false, false}, 2), + }, + ops.ErrInvalidInputCount(1, &Or{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]bool{false, false}, 2), + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputType(1, "int", &Or{}), + }, + } + + for _, test := range tests { + or := &Or{} + validated, err := or.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + assert.Equal(t, test.inputs, validated) + } + } +} diff --git a/ops/opset13/prelu.go b/ops/opset13/prelu.go new file mode 100644 index 0000000..bfdc5d2 --- /dev/null +++ b/ops/opset13/prelu.go @@ -0,0 +1,134 @@ +package opset13 + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +const ( + PReluMinInputs = 2 + PReluMaxInputs = 2 +) + +// PRelu represents the ONNX prelu operator. +type PRelu struct{} + +// newPRelu creates a new prelu operator. +func newPRelu() ops.Operator { + return &PRelu{} +} + +// Init initializes the prelu operator. +func (op *PRelu) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the prelu operator. +func (op *PRelu) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var err error + + x, slope := inputs[0], inputs[1] + + x, slope, err = ops.UnidirectionalBroadcast(x, slope) + if err != nil { + return nil, err + } + + y := tensor.NewDense(x.Dtype(), x.Shape()) + + switch x.Dtype() { + case tensor.Float32: + err = calcPRelu[float32](y.Data(), x.Data(), slope.Data()) + case tensor.Float64: + err = calcPRelu[float64](y.Data(), x.Data(), slope.Data()) + case tensor.Uint32: + err = calcPRelu[uint32](y.Data(), x.Data(), slope.Data()) + case tensor.Uint64: + err = calcPRelu[uint64](y.Data(), x.Data(), slope.Data()) + case tensor.Int32: + err = calcPRelu[int32](y.Data(), x.Data(), slope.Data()) + case tensor.Int64: + err = calcPRelu[int64](y.Data(), x.Data(), slope.Data()) + default: + return nil, ops.ErrInvalidInputType(0, x.Dtype().String(), op) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{y}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (op *PRelu) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + inputs, err := ops.ValidateInputs(op, inputs) + if err != nil { + return nil, err + } + + x, slope := inputs[0], inputs[1] + if x.Dtype() != slope.Dtype() { + return nil, ops.ErrInvalidTensor("DType of 'slope' does not match DType of 'x'", op) + } + + return inputs, nil +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (op *PRelu) GetMinInputs() int { + return PReluMinInputs +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (op *PRelu) GetMaxInputs() int { + return PReluMaxInputs +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (op *PRelu) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{ + {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, + {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, + } +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (op *PRelu) String() string { + return "prelu operator" +} + +func calcPRelu[T float32 | float64 | uint32 | uint64 | int32 | int64](result any, input any, slope any) error { + var convertedResult []T + + var convertedInput []T + + var convertedSlope []T + + convertedResult, ok := result.([]T) + if !ok { + return ops.ErrTypeAssert("numeric list", result) + } + + convertedInput, ok = input.([]T) + if !ok { + return ops.ErrTypeAssert("numeric list", input) + } + + convertedSlope, ok = slope.([]T) + if !ok { + return ops.ErrTypeAssert("numeric list", slope) + } + + for i, v := range convertedInput { + if v < 0 { + v = convertedSlope[i] * v + } + + convertedResult[i] = v + } + + return nil +} diff --git a/ops/opset13/prelu_test.go b/ops/opset13/prelu_test.go new file mode 100644 index 0000000..763cbfb --- /dev/null +++ b/ops/opset13/prelu_test.go @@ -0,0 +1,108 @@ +package opset13 + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestPReluInit(t *testing.T) { + p := &PRelu{} + + // since the prelu does not have any attributes we pass in nil. This should not + // fail initializing the prelu. + err := p.Init(nil) + assert.Nil(t, err) +} + +func TestPRelu(t *testing.T) { + tests := []struct { + prelu *PRelu + backing []float32 + slope []float32 + shape []int + expected []float32 + }{ + { + &PRelu{}, + []float32{-4, -4, -4, -3, -2, -1}, + []float32{2, 2, 4, 4, 0, 0}, + []int{3, 2}, + []float32{-8, -8, -16, -12, 0, 0}, + }, + { + &PRelu{}, + []float32{-4, -4, -4, 3, 2, 1}, + []float32{2, 2, 4, 4, 0, 0}, + []int{3, 2}, + []float32{-8, -8, -16, 3, 2, 1}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backing, test.shape...), + ops.TensorWithBackingFixture(test.slope, test.shape...), + } + res, err := test.prelu.Apply(inputs) + assert.Nil(t, err) + assert.Equal(t, test.expected, res[0].Data()) + } +} + +func TestInputValidationPRelu(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{}, + ops.ErrInvalidInputCount(0, &PRelu{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int", &PRelu{}), + }, + } + + for _, test := range tests { + prelu := &PRelu{} + validated, err := prelu.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + assert.Equal(t, test.inputs, validated) + } + } +} + +func BenchmarkPRelu_Apply(b *testing.B) { + prelu := &PRelu{} + input := ops.Float32TensorFixture(3, 256, 256) + slope := ops.Float32TensorFixture(3, 256, 256) + inputs := []tensor.Tensor{input, slope} + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + y, err := prelu.Apply(inputs) + if err != nil { + b.Fatal(err) + } + + _ = y + } +} diff --git a/ops/opset13/relu.go b/ops/opset13/relu.go index 186007d..370940a 100644 --- a/ops/opset13/relu.go +++ b/ops/opset13/relu.go @@ -1,8 +1,8 @@ package opset13 import ( - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" ) @@ -15,25 +15,13 @@ func newRelu() ops.Operator { } // Init initializes the relu operator. -func (r *Relu) Init(attributes []*onnx.AttributeProto) error { +func (r *Relu) Init(*onnx.NodeProto) error { return nil } // Apply applies the relu operator. func (r *Relu) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - t := inputs[0] - - typedZero, err := ops.GetValueAsTensorType(0.0, t.Dtype()) - if err != nil { - return nil, err - } - - comparison, err := tensor.Gt(t, typedZero, tensor.AsSameType()) - if err != nil { - return nil, err - } - - out, err := tensor.Mul(t, comparison) + out, err := ops.ReLU(inputs[0]) if err != nil { return nil, err } diff --git a/ops/opset13/relu_test.go b/ops/opset13/relu_test.go index 0528138..b2d5fa0 100644 --- a/ops/opset13/relu_test.go +++ b/ops/opset13/relu_test.go @@ -1,11 +1,10 @@ package opset13 import ( - "fmt" "testing" + "github.com/advancedclimatesystems/gonnx/ops" "github.com/stretchr/testify/assert" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" "gorgonia.org/tensor" ) @@ -68,11 +67,11 @@ func TestInputValidationRelu(t *testing.T) { }, { []tensor.Tensor{}, - fmt.Errorf("relu operator: expected 1 input tensors, got 0"), + ops.ErrInvalidInputCount(0, &Relu{}), }, { []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, - fmt.Errorf("relu operator: input 0 does not allow type int"), + ops.ErrInvalidInputType(0, "int", &Relu{}), }, } @@ -81,6 +80,7 @@ func TestInputValidationRelu(t *testing.T) { validated, err := relu.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } diff --git a/ops/opset13/reshape.go b/ops/opset13/reshape.go index 7a9508a..a2a8f59 100644 --- a/ops/opset13/reshape.go +++ b/ops/opset13/reshape.go @@ -1,13 +1,16 @@ package opset13 import ( - "fmt" - - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" ) +const ( + ReshapeMinInputs = 2 + ReshapeMaxInputs = 2 +) + // Reshape represents the ONNX reshape operator. type Reshape struct{} @@ -17,13 +20,14 @@ func newReshape() ops.Operator { } // Init initializes the reshape operator. -func (r *Reshape) Init(attributes []*onnx.AttributeProto) error { +func (r *Reshape) Init(*onnx.NodeProto) error { return nil } // Apply applies the reshape operator. func (r *Reshape) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { t := inputs[0] + newShape, err := ops.AnyToIntSlice(ops.IfScalarToSlice(inputs[1].Data().([]int64))) if err != nil { return nil, err @@ -34,8 +38,13 @@ func (r *Reshape) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return nil, err } - out := t.Clone().(tensor.Tensor) + out, ok := t.Clone().(tensor.Tensor) + if !ok { + return nil, ops.ErrTypeAssert("tensor.Tensor", t.Clone()) + } + err = out.Reshape(newShape...) + return []tensor.Tensor{out}, err } @@ -46,12 +55,12 @@ func (r *Reshape) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error // GetMinInputs returns the minimum number of input tensors this operator expects. func (r *Reshape) GetMinInputs() int { - return 2 + return ReshapeMinInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (r *Reshape) GetMaxInputs() int { - return 2 + return ReshapeMaxInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes @@ -69,8 +78,9 @@ func processShape(newShape, currentShape []int) error { for i := 0; i < len(newShape); i++ { if newShape[i] == 0 { if i >= len(currentShape) { - return fmt.Errorf("could not infer dim size") + return ops.ErrDimension("could not infer dim size") } + newShape[i] = currentShape[i] } } @@ -82,19 +92,21 @@ func processShape(newShape, currentShape []int) error { // When encountering a -1 dim size, calculate which size this should be. if newShape[i] == -1 { remainingSize := totalSize + for j := 0; j < len(newShape); j++ { if j == i { continue } if newShape[j] == -1 { - return fmt.Errorf("At most one -1 dim size is allowed") + return ops.ErrDimension("at most one -1 dim size is allowed") } remainingSize /= newShape[j] } newShape[i] = remainingSize + break } } diff --git a/ops/opset13/reshape_test.go b/ops/opset13/reshape_test.go index 567ca46..8651d32 100644 --- a/ops/opset13/reshape_test.go +++ b/ops/opset13/reshape_test.go @@ -1,11 +1,10 @@ package opset13 import ( - "fmt" "testing" + "github.com/advancedclimatesystems/gonnx/ops" "github.com/stretchr/testify/assert" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" "gorgonia.org/tensor" ) @@ -84,14 +83,14 @@ func TestInputValidationReshape(t *testing.T) { }, { []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, - fmt.Errorf("reshape operator: expected 2 input tensors, got 1"), + ops.ErrInvalidInputCount(1, &Reshape{}), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), }, - fmt.Errorf("reshape operator: input 1 does not allow type int"), + ops.ErrInvalidInputType(1, "int", &Reshape{}), }, } @@ -100,6 +99,7 @@ func TestInputValidationReshape(t *testing.T) { validated, err := reshape.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } diff --git a/ops/opset13/rnn.go b/ops/opset13/rnn.go new file mode 100644 index 0000000..b3248d8 --- /dev/null +++ b/ops/opset13/rnn.go @@ -0,0 +1,249 @@ +package opset13 + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +const ( + MinRNNInputs = 3 + MaxRNNInputs = 6 +) + +// RNN represents the ONNX rnn operator. +type RNN struct { + activationAlpha []float32 + activationBeta []float32 + activations []string + direction ops.SequenceProcessDirection + hiddenSize int +} + +// newRNN creates a new rnn operator. +func newRNN() ops.Operator { + return &RNN{ + activations: []string{"tanh"}, + direction: ops.Forward, + } +} + +// Init initializes the rnn operator. +func (r *RNN) Init(n *onnx.NodeProto) error { + for _, attr := range n.GetAttribute() { + switch attr.GetName() { + case ops.ActivationAlphaAttr: + r.activationAlpha = attr.GetFloats() + case ops.ActivationBetaAttr: + r.activationBeta = attr.GetFloats() + case ops.ActivationsAttr: + activations := []string{} + for _, activation := range attr.GetStrings() { + activations = append(activations, string(activation)) + } + + r.activations = activations + case ops.ClipAttr: + return ops.ErrUnsupportedAttribute(attr.GetName(), r) + case ops.DirectionAttr: + r.direction = ops.SequenceProcessDirection(attr.GetS()) + if r.direction != ops.Forward { + return ops.ErrUnsupportedAttribute(attr.GetName(), r) + } + case ops.HiddenSizeAttr: + r.hiddenSize = int(attr.GetI()) + default: + return ops.ErrInvalidAttribute(attr.GetName(), r) + } + } + + return nil +} + +// Apply applies the rnn operator. +func (r *RNN) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + if inputs[4] != nil { + return nil, ops.ErrUnsupportedInput("sequence lens", r) + } + + X := inputs[0] + seqLength := X.Shape()[0] + batchSize := X.Shape()[1] + + Wi, err := r.getWeights(inputs[1]) + if err != nil { + return nil, err + } + + Ri, err := r.getWeights(inputs[2]) + if err != nil { + return nil, err + } + + B := inputs[3] + if B == nil { + // 2 is the number of bias matrices required by ONNX definition. + nBiasMatrices := 2 + B = ops.ZeroTensor(1, nBiasMatrices*r.hiddenSize) + } + + Wbi, Rbi, err := r.getBiases(B) + if err != nil { + return nil, err + } + + Ht := inputs[5] + if Ht == nil { + Ht = ops.ZeroTensor(1, batchSize, r.hiddenSize) + } + + // Reshape the hidden tensor without the bidirectional dimension, as + // we do not support bidirectional RNN yet. This is the dimension at + // index 0. + if err = Ht.Reshape(Ht.Shape().Clone()[1:]...); err != nil { + return nil, err + } + + activation, err := ops.GetActivation(r.activations[0]) + if err != nil { + return nil, err + } + + outputs := []tensor.Tensor{} + + // Loop over all timesteps of the input, applying the RNN calculation to every + // timesteps while updating the hidden tensor. + for t := 0; t < seqLength; t++ { + Xt, err := X.Slice(ops.NewSlicer(t, t+1), nil, nil) + if err != nil { + return nil, err + } + + Ht, err = r.layerCalculation(Xt, Ht, Wi, Ri, Wbi, Rbi, activation) + if err != nil { + return nil, err + } + + outputs = append(outputs, Ht) + } + + Y := outputs[0] + if len(outputs) > 1 { + Y, err = tensor.Concat(0, Y, outputs[1:]...) + if err != nil { + return nil, err + } + } + + Yh, ok := Ht.Clone().(tensor.Tensor) + if !ok { + return nil, ops.ErrTypeAssert("tensor.Tensor", Ht.Clone()) + } + + // Reshape the hidden tensor without the bidirectional dimension, as + // we do not support bidirectional RNN yet. This is the dimension at + // index 0. + if err = Y.Reshape(seqLength, 1, batchSize, r.hiddenSize); err != nil { + return nil, err + } + + if err = Yh.Reshape(1, batchSize, r.hiddenSize); err != nil { + return nil, err + } + + return []tensor.Tensor{Y, Yh}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (r *RNN) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(r, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (r *RNN) GetMinInputs() int { + return MinRNNInputs +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (r *RNN) GetMaxInputs() int { + return MaxRNNInputs +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (r *RNN) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{ + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Int32}, + {tensor.Float32, tensor.Float64}, + } +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (r *RNN) String() string { + return "rnn operator" +} + +// layerCalculation performs the actual RNN calculation. By ONNX definition +// this is: +// +// Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi) +// +// We achieve this by two Gemm operations, adding them together and finally +// putting them through an activation function. +func (r *RNN) layerCalculation( + Xt, H, Wi, Ri, Wbi, Rbi tensor.Tensor, activation ops.Activation, +) (tensor.Tensor, error) { + gemm := &Gemm{transA: false, transB: true, alpha: 1.0, beta: 1.0} + + inputCalc, err := gemm.Apply([]tensor.Tensor{Xt, Wi, Wbi}) + if err != nil { + return nil, err + } + + hiddenCalc, err := gemm.Apply([]tensor.Tensor{H, Ri, Rbi}) + if err != nil { + return nil, err + } + + result, err := tensor.Add(inputCalc[0], hiddenCalc[0]) + if err != nil { + return nil, err + } + + return activation(result) +} + +// getWeights returns the weights from a concatenated weight tensor. The result is +// a single weight matrix. W has shape (num_directions, hidden_size, ...). +// The W tensor, by GONNX definition, has 3 dimensions with 1 weight +// tensor in it (2 if bidirectional, but that is not supported). +func (r *RNN) getWeights(W tensor.Tensor) (tensor.Tensor, error) { + nWeightMatrices := 1 + nWeightDimensions := 3 + + weights, err := ops.ExtractMatrices(W, nWeightMatrices, nWeightDimensions, r.hiddenSize) + if err != nil { + return nil, err + } + + return weights[0], nil +} + +// getBiases splits tensor B into 2 bias matrices. +// The B tensor, by GONNX definition, has 2 dimensions with 2 bias +// tensors in it (4 if bidirectional, but that is not supported). +func (r *RNN) getBiases(B tensor.Tensor) (Wbi, Rbi tensor.Tensor, err error) { + nBiasMatrices := 2 + nBiasDimensions := 2 + + b, err := ops.ExtractMatrices(B, nBiasMatrices, nBiasDimensions, r.hiddenSize) + if err != nil { + return nil, nil, err + } + + return b[0], b[1], nil +} diff --git a/ops/opset13/rnn_test.go b/ops/opset13/rnn_test.go new file mode 100644 index 0000000..a987ddd --- /dev/null +++ b/ops/opset13/rnn_test.go @@ -0,0 +1,334 @@ +package opset13 + +import ( + "math/rand" + "testing" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestRNNInit(t *testing.T) { + rnn := &RNN{} + err := rnn.Init(RNNOnnxNodeProtoFixture()) + + assert.Nil(t, err) + assert.Equal(t, []float32{1.0}, rnn.activationAlpha) + assert.Equal(t, []float32{2.0}, rnn.activationBeta) + assert.Equal(t, []string{"sigmoid"}, rnn.activations) + assert.Equal(t, ops.SequenceProcessDirection("forward"), rnn.direction) + assert.Equal(t, 5, rnn.hiddenSize) +} + +func TestRNNInitUnsupportedAttr(t *testing.T) { + rnn := RNN{} + err := rnn.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "clip"}}}) + assert.Equal(t, err, ops.ErrUnsupportedAttribute("clip", &rnn)) +} + +func TestRNNInitUnknownAttr(t *testing.T) { + rnn := RNN{} + err := rnn.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "unknown"}}}) + assert.Equal(t, err, ops.ErrInvalidAttribute("unknown", &rnn)) +} + +func TestRNN(t *testing.T) { + tests := []struct { + rnn *RNN + inputs ops.InputFixture + expected []float32 + err error + }{ + { + &RNN{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"tanh"}, + direction: ops.Forward, + hiddenSize: 4, + }, + rnnInput0, + []float32{0.78036773, 0.97858655, 0.94110376, 0.90722954}, + nil, + }, + { + &RNN{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"sigmoid"}, + direction: ops.Forward, + hiddenSize: 4, + }, + rnnInput0, + []float32{0.82048327, 0.922734, 0.89050114, 0.8620579}, + nil, + }, + { + &RNN{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"relu"}, + direction: ops.Forward, + hiddenSize: 4, + }, + rnnInput0, + []float32{1.0667435, 2.328037, 1.7986122, 1.545068}, + nil, + }, + { + &RNN{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"tanh"}, + direction: ops.Forward, + hiddenSize: 10, + }, + rnnInput1, + []float32{0.99996024, 0.9999855, 0.99998087, 0.9999288, 0.9997511, 0.99918234, 0.99999964, 0.9999981, 0.9997658, 0.9999618, 0.9998762, 0.9999353, 0.9999194, 0.9999428, 0.9997284, 0.9982606, 0.999999, 0.9999897, 0.99964744, 0.9998234, 0.99997497, 0.9999893, 0.9999906, 0.9999812, 0.99983937, 0.99967873, 0.9999998, 0.9999965, 0.9999516, 0.9999541}, + nil, + }, + { + &RNN{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"tanh"}, + direction: ops.Forward, + hiddenSize: 4, + }, + rnnInputNoB, + // Same values as first test, but B is initialized automatically. + []float32{0.78036773, 0.97858655, 0.94110376, 0.90722954}, + nil, + }, + { + &RNN{ + activationAlpha: []float32{}, + activationBeta: []float32{}, + activations: []string{"tanh"}, + direction: ops.Forward, + hiddenSize: 4, + }, + rnnInputNoBNoH, + // Same values as first test, but B and H are initialized automatically. + []float32{0.78036773, 0.97858655, 0.94110376, 0.90722954}, + nil, + }, + } + + for _, test := range tests { + inputs := test.inputs() + res, err := test.rnn.Apply(inputs) + assert.Equal(t, test.err, err) + + if err == nil { + assert.Equal(t, test.expected, res[1].Data()) + } + } +} + +func TestInputValidationRNN(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + expected []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + nil, + nil, + nil, + }, + nil, + }, + { + []tensor.Tensor{ops.TensorWithBackingFixture([]float32{1, 2}, 2)}, + nil, + ops.ErrInvalidOptionalInputCount(1, &RNN{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(0, "int", &RNN{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(1, "int", &RNN{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(2, "int", &RNN{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(3, "int", &RNN{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(4, "float32", &RNN{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + nil, + ops.ErrInvalidInputType(5, "int", &RNN{}), + }, + } + + for _, test := range tests { + rnn := &RNN{} + validated, err := rnn.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + if test.expected != nil { + assert.Equal(t, test.expected, validated) + } else { + assert.Equal(t, test.inputs, validated) + } + } + } +} + +func rnnInput0() []tensor.Tensor { + rand.Seed(13) + + return []tensor.Tensor{ + // Input X: (sequence_length, batch_size, input_size). + ops.RandomFloat32TensorFixture(2, 1, 3), + // Input W: (num_directions, hidden_size, input_size). + ops.RandomFloat32TensorFixture(1, 4, 3), + // Input R: (num_directions, hidden_size, hidden_size). + ops.RandomFloat32TensorFixture(1, 4, 4), + // Input B: (num_directions, 2 * hidden_size) + ops.TensorWithBackingFixture(ops.Zeros(ops.NElements(1, 8)), 1, 8), + // Input sequence_lens: not supported + nil, + // Input initial_h: (num_directions, batch_size, hidden_size) + ops.TensorWithBackingFixture(ops.Zeros(ops.NElements(1, 1, 4)), 1, 1, 4), + } +} + +func rnnInput1() []tensor.Tensor { + rand.Seed(13) + + return []tensor.Tensor{ + // Input X: (sequence_length, batch_size, input_size). + ops.RandomFloat32TensorFixture(10, 3, 4), + // Input W: (num_directions, hidden_size, input_size). + ops.RandomFloat32TensorFixture(1, 10, 4), + // Input R: (num_directions, hidden_size, hidden_size). + ops.RandomFloat32TensorFixture(1, 10, 10), + // Input B: (num_directions, 2 * hidden_size) + ops.TensorWithBackingFixture(ops.Zeros(ops.NElements(1, 20)), 1, 20), + // Input sequence_lens: not supported + nil, + // Input initial_h: (num_directions, batch_size, hidden_size) + ops.TensorWithBackingFixture(ops.Zeros(ops.NElements(1, 3, 10)), 1, 3, 10), + } +} + +func rnnInputNoB() []tensor.Tensor { + rand.Seed(13) + + return []tensor.Tensor{ + // Input X: (sequence_length, batch_size, input_size). + ops.RandomFloat32TensorFixture(2, 1, 3), + // Input W: (num_directions, hidden_size, input_size). + ops.RandomFloat32TensorFixture(1, 4, 3), + // Input R: (num_directions, hidden_size, hidden_size). + ops.RandomFloat32TensorFixture(1, 4, 4), + // Input B: not provided. + nil, + // Input sequence_lens: not supported + nil, + // Input initial_h: (num_directions, batch_size, hidden_size) + ops.TensorWithBackingFixture(ops.Zeros(ops.NElements(1, 1, 4)), 1, 1, 4), + } +} + +func rnnInputNoBNoH() []tensor.Tensor { + rand.Seed(13) + + return []tensor.Tensor{ + // Input X: (sequence_length, batch_size, input_size). + ops.RandomFloat32TensorFixture(2, 1, 3), + // Input W: (num_directions, hidden_size, input_size). + ops.RandomFloat32TensorFixture(1, 4, 3), + // Input R: (num_directions, hidden_size, hidden_size). + ops.RandomFloat32TensorFixture(1, 4, 4), + // Input B: not provided. + nil, + // Input sequence_lens: not supported + nil, + // Input initial_h: (num_directions, batch_size, hidden_size) + nil, + } +} + +func RNNOnnxNodeProtoFixture() *onnx.NodeProto { + return &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "activation_alpha", Floats: []float32{1.0}}, + {Name: "activation_beta", Floats: []float32{2.0}}, + {Name: "activations", Strings: [][]byte{[]byte("sigmoid")}}, + {Name: "direction", S: []byte("forward")}, + {Name: "hidden_size", I: 5}, + }, + } +} diff --git a/ops/opset13/scaler.go b/ops/opset13/scaler.go index e4c24e8..c5eb53b 100644 --- a/ops/opset13/scaler.go +++ b/ops/opset13/scaler.go @@ -1,13 +1,17 @@ package opset13 import ( - "fmt" - - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" ) +const ( + ScalerExpectedAttributes = 2 + MinScalerInputs = 1 + MaxScalerInputs = 1 +) + // Scaler represents the ONNX-ml scaler operator. type Scaler struct { offset tensor.Tensor @@ -20,9 +24,10 @@ func newScaler() ops.Operator { } // Init initializes the scaler operator. -func (s *Scaler) Init(attributes []*onnx.AttributeProto) error { - if len(attributes) != 2 { - return fmt.Errorf(ops.InvalidAttrCountErrTemplate, s, 2, len(attributes)) +func (s *Scaler) Init(n *onnx.NodeProto) error { + attributes := n.GetAttribute() + if len(attributes) != ScalerExpectedAttributes { + return ops.ErrInvalidAttributeCount(ScalerExpectedAttributes, len(attributes), s) } for _, attr := range attributes { @@ -34,7 +39,7 @@ func (s *Scaler) Init(attributes []*onnx.AttributeProto) error { floats := attr.GetFloats() s.scale = tensor.New(tensor.WithShape(len(floats)), tensor.WithBacking(floats)) default: - return fmt.Errorf(ops.UnknownAttributeErrTemplate, s, attr.GetName()) + return ops.ErrInvalidAttribute(attr.GetName(), s) } } @@ -73,12 +78,12 @@ func (s *Scaler) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) // GetMinInputs returns the minimum number of input tensors this operator expects. func (s *Scaler) GetMinInputs() int { - return 1 + return MinScalerInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (s *Scaler) GetMaxInputs() int { - return 1 + return MaxScalerInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/scaler_test.go b/ops/opset13/scaler_test.go index d83caee..2a6dfae 100644 --- a/ops/opset13/scaler_test.go +++ b/ops/opset13/scaler_test.go @@ -1,18 +1,17 @@ package opset13 import ( - "fmt" "testing" + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" "github.com/stretchr/testify/assert" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" "gorgonia.org/tensor" ) func TestScalerInit(t *testing.T) { scaler := &Scaler{} - err := scaler.Init(ScalerOnnxAttributeProtoFixture()) + err := scaler.Init(ScalerOnnxNodeProtoFixture()) assert.Nil(t, err) assert.Equal(t, []float32{1.5, 2.5, 3.5}, scaler.offset.Data()) @@ -21,17 +20,17 @@ func TestScalerInit(t *testing.T) { func TestScalerInitFailWrongAttribute(t *testing.T) { scaler := &Scaler{} - err := scaler.Init([]*onnx.AttributeProto{{Name: "unknownAttribute"}, {Name: "Another"}}) + err := scaler.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "unknownAttribute"}, {Name: "Another"}}}) - expected := fmt.Errorf(ops.UnknownAttributeErrTemplate, scaler, "unknownAttribute") + expected := ops.ErrInvalidAttribute("unknownAttribute", scaler) assert.Equal(t, expected, err) } func TestScalerInitFailAttrCount(t *testing.T) { scaler := &Scaler{} - err := scaler.Init([]*onnx.AttributeProto{}) + err := scaler.Init(ops.EmptyNodeProto()) - expected := fmt.Errorf(ops.InvalidAttrCountErrTemplate, scaler, 2, 0) + expected := ops.ErrInvalidAttributeCount(2, 0, scaler) assert.Equal(t, expected, err) } @@ -109,11 +108,11 @@ func TestInputValidationScaler(t *testing.T) { }, { []tensor.Tensor{}, - fmt.Errorf("scaler operator: expected 1 input tensors, got 0"), + ops.ErrInvalidInputCount(0, &Scaler{}), }, { []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, - fmt.Errorf("scaler operator: input 0 does not allow type int"), + ops.ErrInvalidInputType(0, "int", &Scaler{}), }, } @@ -122,15 +121,18 @@ func TestInputValidationScaler(t *testing.T) { validated, err := scaler.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } } } -func ScalerOnnxAttributeProtoFixture() []*onnx.AttributeProto { - return []*onnx.AttributeProto{ - {Name: "offset", Floats: []float32{1.5, 2.5, 3.5}}, - {Name: "scale", Floats: []float32{0.5, 1.0, 2.0}}, +func ScalerOnnxNodeProtoFixture() *onnx.NodeProto { + return &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "offset", Floats: []float32{1.5, 2.5, 3.5}}, + {Name: "scale", Floats: []float32{0.5, 1.0, 2.0}}, + }, } } diff --git a/ops/opset13/shape.go b/ops/opset13/shape.go index 10aab66..bb99709 100644 --- a/ops/opset13/shape.go +++ b/ops/opset13/shape.go @@ -1,11 +1,16 @@ package opset13 import ( - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" ) +const ( + MinShapeInputs = 1 + MaxShapeInputs = 1 +) + // Shape represents the ONNX shape operator. type Shape struct{} @@ -15,7 +20,7 @@ func newShape() ops.Operator { } // Init initializes the shape operator. -func (s *Shape) Init(attributes []*onnx.AttributeProto) error { +func (s *Shape) Init(*onnx.NodeProto) error { return nil } @@ -24,11 +29,13 @@ func (s *Shape) Init(attributes []*onnx.AttributeProto) error { func (s *Shape) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { nodeShape := inputs[0].Shape() shape := make([]int64, len(nodeShape)) + for i, dimSize := range nodeShape { shape[i] = int64(dimSize) } out := tensor.New(tensor.WithShape(len(nodeShape)), tensor.WithBacking(shape)) + return []tensor.Tensor{out}, nil } @@ -39,12 +46,12 @@ func (s *Shape) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) // GetMinInputs returns the minimum number of input tensors this operator expects. func (s *Shape) GetMinInputs() int { - return 1 + return MinShapeInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (s *Shape) GetMaxInputs() int { - return 1 + return MaxShapeInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/shape_test.go b/ops/opset13/shape_test.go index edf88e9..1ab9382 100644 --- a/ops/opset13/shape_test.go +++ b/ops/opset13/shape_test.go @@ -1,11 +1,10 @@ package opset13 import ( - "fmt" "testing" + "github.com/advancedclimatesystems/gonnx/ops" "github.com/stretchr/testify/assert" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" "gorgonia.org/tensor" ) @@ -60,11 +59,11 @@ func TestInputValidationShape(t *testing.T) { }, { []tensor.Tensor{}, - fmt.Errorf("shape operator: expected 1 input tensors, got 0"), + ops.ErrInvalidInputCount(0, &Shape{}), }, { []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, - fmt.Errorf("shape operator: input 0 does not allow type int"), + ops.ErrInvalidInputType(0, "int", &Shape{}), }, } @@ -73,6 +72,7 @@ func TestInputValidationShape(t *testing.T) { validated, err := shape.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } diff --git a/ops/opset13/sigmoid.go b/ops/opset13/sigmoid.go index 424104b..b8bc077 100644 --- a/ops/opset13/sigmoid.go +++ b/ops/opset13/sigmoid.go @@ -1,8 +1,8 @@ package opset13 import ( - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" ) @@ -15,13 +15,14 @@ func newSigmoid() ops.Operator { } // Init initializes the sigmoid operator. -func (s *Sigmoid) Init(attributes []*onnx.AttributeProto) error { +func (s *Sigmoid) Init(*onnx.NodeProto) error { return nil } // Apply the sigmoid operator to the input node. func (s *Sigmoid) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { out, err := ops.Sigmoid(inputs[0]) + return []tensor.Tensor{out}, err } diff --git a/ops/opset13/sigmoid_test.go b/ops/opset13/sigmoid_test.go index 71ed130..3277a6f 100644 --- a/ops/opset13/sigmoid_test.go +++ b/ops/opset13/sigmoid_test.go @@ -1,11 +1,10 @@ package opset13 import ( - "fmt" "testing" + "github.com/advancedclimatesystems/gonnx/ops" "github.com/stretchr/testify/assert" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" "gorgonia.org/tensor" ) @@ -26,23 +25,29 @@ func TestSigmoid(t *testing.T) { { []float32{-4, -3, -2, -1, 0, 12}, []int{3, 2}, - []float32{0.01798620996209155802679, + []float32{ + 0.01798620996209155802679, 0.04742587317756678087885, 0.1192029220221175559403, 0.2689414213699951207488, 0.5, - 0.9999938558253977852822}, + 0.9999938558253977852822, + }, }, { []float32{-4, -4, -4, 3, 2, 1}, []int{3, 2}, - []float32{0.01798621, 0.01798621, 0.01798621, - 0.95257413, 0.8807971, 0.7310586}, + []float32{ + 0.01798621, 0.01798621, 0.01798621, + 0.95257413, 0.8807971, 0.7310586, + }, }, { []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, []int{4, 3}, - []float32{0.5, 0.7310586, 0.8807971, 0.95257413, + []float32{ + 0.5, 0.7310586, 0.8807971, 0.95257413, 0.98201376, 0.9933072, 0.99752736, 0.99908894, - 0.99966466, 0.9998766, 0.9999546, 0.9999833}, + 0.99966466, 0.9998766, 0.9999546, 0.9999833, + }, }, } @@ -73,11 +78,11 @@ func TestInputValidationSigmoid(t *testing.T) { }, { []tensor.Tensor{}, - fmt.Errorf("sigmoid operator: expected 1 input tensors, got 0"), + ops.ErrInvalidInputCount(0, &Sigmoid{}), }, { []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, - fmt.Errorf("sigmoid operator: input 0 does not allow type int"), + ops.ErrInvalidInputType(0, "int", &Sigmoid{}), }, } @@ -86,6 +91,7 @@ func TestInputValidationSigmoid(t *testing.T) { validated, err := sigmoid.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } diff --git a/ops/opset13/sin.go b/ops/opset13/sin.go new file mode 100644 index 0000000..ff61a71 --- /dev/null +++ b/ops/opset13/sin.go @@ -0,0 +1,75 @@ +package opset13 + +import ( + "math" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +// Sin represents the ONNX sin operator. +type Sin struct{} + +// newSin creates a new sin operator. +func newSin() ops.Operator { + return &Sin{} +} + +// Init initializes the sin operator. +func (s *Sin) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the sin operator. +func (s *Sin) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var ( + out tensor.Tensor + err error + ) + + switch inputs[0].Dtype() { + case tensor.Float32: + out, err = inputs[0].Apply(sin[float32]) + case tensor.Float64: + out, err = inputs[0].Apply(sin[float64]) + default: + return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), s) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (s *Sin) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(s, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (s *Sin) GetMinInputs() int { + return 1 +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (s *Sin) GetMaxInputs() int { + return 1 +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (s *Sin) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (s *Sin) String() string { + return "sin operator" +} + +func sin[T ops.FloatType](x T) T { + return T(math.Sin(float64(x))) +} diff --git a/ops/opset13/sin_test.go b/ops/opset13/sin_test.go new file mode 100644 index 0000000..1ec4483 --- /dev/null +++ b/ops/opset13/sin_test.go @@ -0,0 +1,99 @@ +package opset13 + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestSinInit(t *testing.T) { + a := &Sin{} + + // since 'sin' does not have any attributes we pass in nil. This should not + // fail initializing the sin. + err := a.Init(nil) + assert.Nil(t, err) +} + +func TestSin(t *testing.T) { + tests := []struct { + sin *Sin + backing []float32 + shape []int + expected []float32 + }{ + { + &Sin{}, + []float32{-2, -1, 0, 1}, + []int{2, 2}, + []float32{-0.9092974, -0.84147096, 0, 0.84147096}, + }, + { + &Sin{}, + []float32{1, 3, 4, 5}, + []int{1, 4}, + []float32{0.84147096, 0.14112, -0.7568025, -0.9589243}, + }, + { + &Sin{}, + []float32{-1, -1, -1, -1}, + []int{1, 4}, + []float32{-0.84147096, -0.84147096, -0.84147096, -0.84147096}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backing, test.shape...), + } + + res, err := test.sin.Apply(inputs) + assert.Nil(t, err) + + assert.Nil(t, err) + assert.Equal(t, test.expected, res[0].Data()) + } +} + +func TestInputValidationSin(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float64{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{}, + ops.ErrInvalidInputCount(0, &Sin{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int", &Sin{}), + }, + } + + for _, test := range tests { + sin := &Sin{} + validated, err := sin.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + assert.Equal(t, test.inputs, validated) + } + } +} diff --git a/ops/opset13/sinh.go b/ops/opset13/sinh.go new file mode 100644 index 0000000..19d81e7 --- /dev/null +++ b/ops/opset13/sinh.go @@ -0,0 +1,75 @@ +package opset13 + +import ( + "math" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +// Sinh represents the ONNX sinh operator. +type Sinh struct{} + +// newSin creates a new sinh operator. +func newSinh() ops.Operator { + return &Sinh{} +} + +// Init initializes the sinh operator. +func (s *Sinh) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the sinh operator. +func (s *Sinh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var ( + out tensor.Tensor + err error + ) + + switch inputs[0].Dtype() { + case tensor.Float32: + out, err = inputs[0].Apply(sinh[float32]) + case tensor.Float64: + out, err = inputs[0].Apply(sinh[float64]) + default: + return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), s) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (s *Sinh) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(s, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (s *Sinh) GetMinInputs() int { + return 1 +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (s *Sinh) GetMaxInputs() int { + return 1 +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (s *Sinh) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (s *Sinh) String() string { + return "sinh operator" +} + +func sinh[T ops.FloatType](x T) T { + return T(math.Sinh(float64(x))) +} diff --git a/ops/opset13/sinh_test.go b/ops/opset13/sinh_test.go new file mode 100644 index 0000000..3288490 --- /dev/null +++ b/ops/opset13/sinh_test.go @@ -0,0 +1,99 @@ +package opset13 + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestSinhInit(t *testing.T) { + s := &Sinh{} + + // since 'sinh' does not have any attributes we pass in nil. This should not + // fail initializing the sinh. + err := s.Init(nil) + assert.Nil(t, err) +} + +func TestSinh(t *testing.T) { + tests := []struct { + sinh *Sinh + backing []float32 + shape []int + expected []float32 + }{ + { + &Sinh{}, + []float32{-2, -1, 0, 1}, + []int{2, 2}, + []float32{-3.6268604, -1.1752012, 0, 1.1752012}, + }, + { + &Sinh{}, + []float32{1, 3, 4, 5}, + []int{1, 4}, + []float32{1.1752012, 10.017875, 27.289917, 74.20321}, + }, + { + &Sinh{}, + []float32{-1, -1, -1, -1}, + []int{1, 4}, + []float32{-1.1752012, -1.1752012, -1.1752012, -1.1752012}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backing, test.shape...), + } + + res, err := test.sinh.Apply(inputs) + assert.Nil(t, err) + + assert.Nil(t, err) + assert.Equal(t, test.expected, res[0].Data()) + } +} + +func TestInputValidationSinh(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float64{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{}, + ops.ErrInvalidInputCount(0, &Sinh{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int", &Sinh{}), + }, + } + + for _, test := range tests { + sinh := &Sinh{} + validated, err := sinh.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + assert.Equal(t, test.inputs, validated) + } + } +} diff --git a/ops/opset13/slice.go b/ops/opset13/slice.go index c0c7fec..d7589f5 100644 --- a/ops/opset13/slice.go +++ b/ops/opset13/slice.go @@ -1,11 +1,16 @@ package opset13 import ( - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" ) +const ( + MinSliceInputs = 3 + MaxSliceInputs = 5 +) + // Slice represents the ONNX slice operator. type Slice struct{} @@ -15,13 +20,14 @@ func newSlice() ops.Operator { } // Init initializes the slice operator. -func (s *Slice) Init(attributes []*onnx.AttributeProto) error { +func (s *Slice) Init(*onnx.NodeProto) error { return nil } // Apply applies the slice operator. func (s *Slice) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { data := inputs[0] + starts, err := ops.AnyToIntSlice(ops.IfScalarToSlice(inputs[1].Data())) if err != nil { return nil, err @@ -65,12 +71,12 @@ func (s *Slice) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) // GetMinInputs returns the minimum number of input tensors this operator expects. func (s *Slice) GetMinInputs() int { - return 3 + return MinSliceInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (s *Slice) GetMaxInputs() int { - return 5 + return MaxSliceInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes @@ -102,6 +108,7 @@ func (s *Slice) constructSlices(starts, ends, steps, axes []int, nTotalSlices in if ax < 0 { ax = nTotalSlices + ax } + slices[ax] = ops.NewSlicer(starts[i], ends[i], steps[i]) } @@ -114,6 +121,7 @@ func (s *Slice) getDefaultAxes(nSlices int) []int { for i := 0; i < nSlices; i++ { axes[i] = i } + return axes } @@ -123,5 +131,6 @@ func (s *Slice) getDefaultSteps(nSlices int) []int { for i := 0; i < nSlices; i++ { steps[i] = 1 } + return steps } diff --git a/ops/opset13/slice_test.go b/ops/opset13/slice_test.go index f7eafaa..652608f 100644 --- a/ops/opset13/slice_test.go +++ b/ops/opset13/slice_test.go @@ -1,11 +1,10 @@ package opset13 import ( - "fmt" "testing" + "github.com/advancedclimatesystems/gonnx/ops" "github.com/stretchr/testify/assert" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" "gorgonia.org/tensor" ) @@ -140,6 +139,7 @@ func TestConstructSlices(t *testing.T) { ) assert.Equal(t, test.nSlices, len(slices)) + for i := 0; i < test.nSlices; i++ { if test.expectedSlices[i] == nil { assert.Nil(t, slices[i]) @@ -199,7 +199,7 @@ func TestInputValidationSlice(t *testing.T) { { []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, nil, - fmt.Errorf("slice operator: expected 3-5 input tensors, got 1"), + ops.ErrInvalidOptionalInputCount(1, &Slice{}), }, { []tensor.Tensor{ @@ -208,7 +208,7 @@ func TestInputValidationSlice(t *testing.T) { ops.TensorWithBackingFixture([]int{3, 4}, 2), }, nil, - fmt.Errorf("slice operator: input 1 does not allow type int"), + ops.ErrInvalidInputType(1, "int", &Slice{}), }, } @@ -217,6 +217,7 @@ func TestInputValidationSlice(t *testing.T) { validated, err := slice.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { if test.expected != nil { assert.Equal(t, test.expected, validated) diff --git a/ops/opset13/softmax.go b/ops/opset13/softmax.go new file mode 100644 index 0000000..8a2c0c0 --- /dev/null +++ b/ops/opset13/softmax.go @@ -0,0 +1,84 @@ +package opset13 + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +// Softmax represents the ONNX softmax operator. +type Softmax struct { + // The axis along which to perform the Softmax operation. + axis int +} + +// newSoftmax creates a new softmax operator. +func newSoftmax() ops.Operator { + return &Softmax{ + axis: -1, // This is the default value by ONNX definition. + } +} + +// Init initializes the softmax operator. +func (s *Softmax) Init(n *onnx.NodeProto) error { + attributes := n.GetAttribute() + nAttributes := len(attributes) + + if nAttributes > 1 { + return ops.ErrInvalidAttributeCount(1, nAttributes, s) + } + + if nAttributes == 1 { + s.axis = int(attributes[0].GetI()) + } + + return nil +} + +// Apply applies the softmax operator. +func (s *Softmax) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + input := inputs[0] + nDims := len(input.Shape()) + + if s.axis < -nDims || s.axis >= nDims { + return nil, ops.ErrAxisOutOfRange(-nDims, nDims, s.axis) + } + + axis := s.axis + if s.axis < 0 { + axis += nDims + } + + out, err := tensor.SoftMax(inputs[0], axis) + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (s *Softmax) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(s, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (s *Softmax) GetMinInputs() int { + return 1 +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (s *Softmax) GetMaxInputs() int { + return 1 +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (s *Softmax) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (s *Softmax) String() string { + return "softmax operator" +} diff --git a/ops/opset13/softmax_test.go b/ops/opset13/softmax_test.go new file mode 100644 index 0000000..5c01bc0 --- /dev/null +++ b/ops/opset13/softmax_test.go @@ -0,0 +1,140 @@ +package opset13 + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestSoftmaxInit(t *testing.T) { + s := &Softmax{} + + // since 'softmax' does not have any attributes we pass in nil. This should not + // fail initializing the softmax. + err := s.Init(nil) + assert.Nil(t, err) +} + +func TestSoftmax(t *testing.T) { + tests := []struct { + softmax *Softmax + backing []float32 + shape []int + expected []float32 + }{ + { + &Softmax{ + axis: -1, + }, + []float32{0, 1, 2, 3}, + []int{1, 4}, + []float32{0.032058604, 0.087144315, 0.2368828, 0.6439142}, + }, + { + &Softmax{ + axis: 1, + }, + []float32{0, 1, 2, 3}, + []int{1, 4}, + []float32{0.032058604, 0.087144315, 0.2368828, 0.6439142}, + }, + { + &Softmax{ + axis: -1, + }, + []float32{0, 1, 2, 3}, + []int{2, 2}, + []float32{0.26894143, 0.7310586, 0.26894143, 0.7310586}, + }, + { + &Softmax{ + axis: -1, + }, + []float32{0, 1, 2, 3, 4, 5}, + []int{1, 2, 3}, + []float32{0.09003057, 0.24472848, 0.66524094, 0.09003057, 0.24472848, 0.66524094}, + }, + { + &Softmax{ + axis: -1, + }, + []float32{0, 1, 2, 3}, + []int{4, 1}, + []float32{1, 1, 1, 1}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backing, test.shape...), + } + + res, err := test.softmax.Apply(inputs) + assert.Nil(t, err) + + assert.Equal(t, test.expected, res[0].Data()) + } +} + +func TestSoftmaxFail(t *testing.T) { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2, 3, 4}, 2, 2), + } + + softmax := &Softmax{ + // This axis is out of range, because the input tensor only has 2 dimensions. + axis: 3, + } + _, err := softmax.Apply(inputs) + assert.Equal( + t, + err, + ops.ErrAxisOutOfRange(-2, 2, 3), + ) +} + +func TestInputValidationSoftmax(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float64{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputCount(2, &Softmax{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int", &Softmax{}), + }, + } + + for _, test := range tests { + softmax := &Softmax{} + validated, err := softmax.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + assert.Equal(t, test.inputs, validated) + } + } +} diff --git a/ops/opset13/squeeze.go b/ops/opset13/squeeze.go index f621065..d4c9055 100644 --- a/ops/opset13/squeeze.go +++ b/ops/opset13/squeeze.go @@ -1,13 +1,16 @@ package opset13 import ( - "fmt" - - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" ) +const ( + MinSqueezeInputs = 1 + MaxSqueezeInputs = 2 +) + // Squeeze represents the ONNX squeeze operator. type Squeeze struct{} @@ -17,19 +20,20 @@ func newSqueeze() ops.Operator { } // Init initializes the squeeze operator. -func (s *Squeeze) Init(attributes []*onnx.AttributeProto) error { +func (s *Squeeze) Init(*onnx.NodeProto) error { return nil } // Apply applies the squeeze operator. func (s *Squeeze) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { var err error + currentShape := inputs[0].Shape() nDims := len(currentShape) dimsToSqueeze := getDimsToSqueezeFromShape(currentShape) if !ops.AllInRange(dimsToSqueeze, -nDims, nDims-1) { - return nil, fmt.Errorf(ops.AxesNotAllInRangeErrTemplate, nDims, nDims) + return nil, ops.ErrNotAllAxesInRange(nDims, nDims) } // negative entries should be offset by the rank of the output tensor @@ -45,8 +49,13 @@ func (s *Squeeze) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { newShape := getNewShape(currentShape, dimsToSqueeze) - out := inputs[0].Clone().(tensor.Tensor) + out, ok := inputs[0].Clone().(tensor.Tensor) + if !ok { + return nil, ops.ErrTypeAssert("tensor.Tensor", inputs[0].Clone()) + } + err = out.Reshape(newShape...) + return []tensor.Tensor{out}, err } @@ -57,12 +66,12 @@ func (s *Squeeze) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error // GetMinInputs returns the minimum number of input tensors this operator expects. func (s *Squeeze) GetMinInputs() int { - return 1 + return MinSqueezeInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (s *Squeeze) GetMaxInputs() int { - return 2 + return MaxSqueezeInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes @@ -91,29 +100,34 @@ func getDimsToSqueezeFromTensor(t tensor.Tensor, nDims int) ([]int, error) { dimsToSqueeze[i] = nDims + val } } + return dimsToSqueeze, nil } // getDimsToSqueezeFromShape creates a list with ints representing the dimensions/axes to squeeze // based on the current shape. All dimensions with only 1 value will be squeezed. func getDimsToSqueezeFromShape(shape []int) []int { - var res []int + result := []int{} + for i, size := range shape { if size == 1 { - res = append(res, i) + result = append(result, i) } } - return res + + return result } // getNewShape returns a new shape based on the current shape and a list of dims to squeeze. func getNewShape(currentShape tensor.Shape, dimsToSqueeze []int) []int { - var newShape []int + newShape := []int{} + for i, dimSize := range currentShape { if keepDim(i, dimsToSqueeze) { newShape = append(newShape, dimSize) } } + return newShape } @@ -124,5 +138,6 @@ func keepDim(dim int, dimsToSqueeze []int) bool { return false } } + return true } diff --git a/ops/opset13/squeeze_test.go b/ops/opset13/squeeze_test.go index 5f92b9a..bb160b5 100644 --- a/ops/opset13/squeeze_test.go +++ b/ops/opset13/squeeze_test.go @@ -1,11 +1,10 @@ package opset13 import ( - "fmt" "testing" + "github.com/advancedclimatesystems/gonnx/ops" "github.com/stretchr/testify/assert" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" "gorgonia.org/tensor" ) @@ -149,7 +148,7 @@ func TestInputValidationSqueeze(t *testing.T) { { []tensor.Tensor{}, nil, - fmt.Errorf("squeeze operator: expected 1-2 input tensors, got 0"), + ops.ErrInvalidOptionalInputCount(0, &Squeeze{}), }, { []tensor.Tensor{ @@ -158,7 +157,7 @@ func TestInputValidationSqueeze(t *testing.T) { ops.TensorWithBackingFixture([]int{3, 4}, 2), }, nil, - fmt.Errorf("squeeze operator: expected 1-2 input tensors, got 3"), + ops.ErrInvalidOptionalInputCount(3, &Squeeze{}), }, { []tensor.Tensor{ @@ -166,7 +165,7 @@ func TestInputValidationSqueeze(t *testing.T) { ops.TensorWithBackingFixture([]int{3, 4}, 2), }, nil, - fmt.Errorf("squeeze operator: input 1 does not allow type int"), + ops.ErrInvalidInputType(1, "int", &Squeeze{}), }, } @@ -175,6 +174,7 @@ func TestInputValidationSqueeze(t *testing.T) { validated, err := squeeze.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { if test.expected != nil { assert.Equal(t, test.expected, validated) diff --git a/ops/opset13/sub.go b/ops/opset13/sub.go index f0db11b..9c59508 100644 --- a/ops/opset13/sub.go +++ b/ops/opset13/sub.go @@ -1,11 +1,16 @@ package opset13 import ( - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" ) +const ( + MinSubInputs = 2 + MaxSubInputs = 2 +) + // Sub represents the ONNX sub operator. type Sub struct{} @@ -15,23 +20,18 @@ func newSub() ops.Operator { } // Init initializes the sub operator. -func (s *Sub) Init(attributes []*onnx.AttributeProto) error { +func (s *Sub) Init(*onnx.NodeProto) error { return nil } // Apply applies the sub operator. func (s *Sub) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - A, B, err := ops.MultidirectionalBroadcast(inputs[0], inputs[1]) - if err != nil { - return nil, err - } - - out, err := tensor.Sub(A, B) - if err != nil { - return nil, err - } - - return []tensor.Tensor{out}, nil + return ops.ApplyBinaryOperation( + inputs[0], + inputs[1], + ops.Sub, + ops.MultidirectionalBroadcasting, + ) } // ValidateInputs validates the inputs that will be given to Apply for this operator. @@ -41,12 +41,12 @@ func (s *Sub) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { // GetMinInputs returns the minimum number of input tensors this operator expects. func (s *Sub) GetMinInputs() int { - return 2 + return MinSubInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (s *Sub) GetMaxInputs() int { - return 2 + return MaxSubInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/sub_test.go b/ops/opset13/sub_test.go index 8273920..6812be0 100644 --- a/ops/opset13/sub_test.go +++ b/ops/opset13/sub_test.go @@ -1,11 +1,10 @@ package opset13 import ( - "fmt" "testing" + "github.com/advancedclimatesystems/gonnx/ops" "github.com/stretchr/testify/assert" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" "gorgonia.org/tensor" ) @@ -104,14 +103,14 @@ func TestInputValidationSub(t *testing.T) { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - fmt.Errorf("sub operator: expected 2 input tensors, got 1"), + ops.ErrInvalidInputCount(1, &Sub{}), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), }, - fmt.Errorf("sub operator: input 0 does not allow type int"), + ops.ErrInvalidInputType(0, "int", &Sub{}), }, } @@ -120,6 +119,7 @@ func TestInputValidationSub(t *testing.T) { validated, err := sub.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } diff --git a/ops/opset13/tan.go b/ops/opset13/tan.go new file mode 100644 index 0000000..a7b4a3b --- /dev/null +++ b/ops/opset13/tan.go @@ -0,0 +1,75 @@ +package opset13 + +import ( + "math" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +// Tan represents the ONNX tan operator. +type Tan struct{} + +// newTan creates a new tan operator. +func newTan() ops.Operator { + return &Tan{} +} + +// Init initializes the tan operator. +func (t *Tan) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the tan operator. +func (t *Tan) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var ( + out tensor.Tensor + err error + ) + + switch inputs[0].Dtype() { + case tensor.Float32: + out, err = inputs[0].Apply(tan[float32]) + case tensor.Float64: + out, err = inputs[0].Apply(tan[float64]) + default: + return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), t) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (t *Tan) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(t, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (t *Tan) GetMinInputs() int { + return 1 +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (t *Tan) GetMaxInputs() int { + return 1 +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (t *Tan) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (t *Tan) String() string { + return "tan operator" +} + +func tan[T ops.FloatType](x T) T { + return T(math.Tan(float64(x))) +} diff --git a/ops/opset13/tan_test.go b/ops/opset13/tan_test.go new file mode 100644 index 0000000..2fbaf88 --- /dev/null +++ b/ops/opset13/tan_test.go @@ -0,0 +1,99 @@ +package opset13 + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestTanInit(t *testing.T) { + a := &Tan{} + + // since 'tan' does not have any attributes we pass in nil. This should not + // fail initializing the tan. + err := a.Init(nil) + assert.Nil(t, err) +} + +func TestTan(t *testing.T) { + tests := []struct { + tan *Tan + backing []float32 + shape []int + expected []float32 + }{ + { + &Tan{}, + []float32{1, 2, 3, 4}, + []int{2, 2}, + []float32{1.5574077, -2.1850398, -0.14254655, 1.1578213}, + }, + { + &Tan{}, + []float32{1, 2, 3, 4}, + []int{1, 4}, + []float32{1.5574077, -2.1850398, -0.14254655, 1.1578213}, + }, + { + &Tan{}, + []float32{2, 2, 2, 2}, + []int{1, 4}, + []float32{-2.1850398, -2.1850398, -2.1850398, -2.1850398}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backing, test.shape...), + } + + res, err := test.tan.Apply(inputs) + assert.Nil(t, err) + + assert.Nil(t, err) + assert.Equal(t, test.expected, res[0].Data()) + } +} + +func TestInputValidationTan(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float64{1, 2}, 2), + }, + nil, + }, + { + []tensor.Tensor{}, + ops.ErrInvalidInputCount(0, &Tan{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int", &Tan{}), + }, + } + + for _, test := range tests { + tan := &Tan{} + validated, err := tan.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + assert.Equal(t, test.inputs, validated) + } + } +} diff --git a/ops/opset13/tanh.go b/ops/opset13/tanh.go index d199a5c..b435fb9 100644 --- a/ops/opset13/tanh.go +++ b/ops/opset13/tanh.go @@ -1,8 +1,8 @@ package opset13 import ( - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" ) @@ -15,13 +15,14 @@ func newTanh() ops.Operator { } // Init initializes the sigmoid operator. -func (t *Tanh) Init(attributes []*onnx.AttributeProto) error { +func (t *Tanh) Init(*onnx.NodeProto) error { return nil } // Apply the sigmoid operator to the input node. func (t *Tanh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { out, err := ops.Tanh(inputs[0]) + return []tensor.Tensor{out}, err } diff --git a/ops/opset13/tanh_test.go b/ops/opset13/tanh_test.go index 053d12b..44b5409 100644 --- a/ops/opset13/tanh_test.go +++ b/ops/opset13/tanh_test.go @@ -1,11 +1,10 @@ package opset13 import ( - "fmt" "testing" + "github.com/advancedclimatesystems/gonnx/ops" "github.com/stretchr/testify/assert" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" "gorgonia.org/tensor" ) @@ -71,11 +70,11 @@ func TestInputValidationTanh(t *testing.T) { }, { []tensor.Tensor{}, - fmt.Errorf("tanh operator: expected 1 input tensors, got 0"), + ops.ErrInvalidInputCount(0, &Tanh{}), }, { []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, - fmt.Errorf("tanh operator: input 0 does not allow type int"), + ops.ErrInvalidInputType(0, "int", &Tanh{}), }, } @@ -84,6 +83,7 @@ func TestInputValidationTanh(t *testing.T) { validated, err := tanh.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } diff --git a/ops/opset13/transpose.go b/ops/opset13/transpose.go index 13cc543..c89aa67 100644 --- a/ops/opset13/transpose.go +++ b/ops/opset13/transpose.go @@ -1,13 +1,16 @@ package opset13 import ( - "fmt" - - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" ) +const ( + MinTransposeInputs = 1 + MaxTransposeInputs = 1 +) + // Transpose represents the ONNX transpose operator. type Transpose struct { perm []int @@ -19,21 +22,24 @@ func newTranspose() ops.Operator { } // Init initializes the transpose operator. -func (t *Transpose) Init(attributes []*onnx.AttributeProto) error { +func (t *Transpose) Init(n *onnx.NodeProto) error { + attributes := n.GetAttribute() + if len(attributes) != 1 { - return fmt.Errorf(ops.InvalidAttrCountErrTemplate, t, 1, len(attributes)) + return ops.ErrInvalidAttributeCount(1, len(attributes), t) } attr := attributes[0] if attr.GetName() != "perm" { - return fmt.Errorf(ops.UnknownAttributeErrTemplate, t, attr.GetName()) + return ops.ErrInvalidAttribute(attr.GetName(), t) } attrPerm := attr.GetInts() for _, val := range attrPerm { t.perm = append(t.perm, int(val)) } + return nil } @@ -54,12 +60,12 @@ func (t *Transpose) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, err // GetMinInputs returns the minimum number of input tensors this operator expects. func (t *Transpose) GetMinInputs() int { - return 1 + return MinTransposeInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (t *Transpose) GetMaxInputs() int { - return 1 + return MaxTransposeInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes diff --git a/ops/opset13/transpose_test.go b/ops/opset13/transpose_test.go index d8ce2af..afe1c9e 100644 --- a/ops/opset13/transpose_test.go +++ b/ops/opset13/transpose_test.go @@ -1,18 +1,17 @@ package opset13 import ( - "fmt" "testing" + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" "github.com/stretchr/testify/assert" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" "gorgonia.org/tensor" ) func TestTransposeInit(t *testing.T) { trans := &Transpose{} - err := trans.Init(TransposeOnnxAttributeProtoFixture()) + err := trans.Init(TransposeOnnxNodeProtoFixture()) assert.Nil(t, err) assert.Equal(t, []int{1, 0}, trans.perm) @@ -20,17 +19,17 @@ func TestTransposeInit(t *testing.T) { func TestTransposeInitFailWrongAttribute(t *testing.T) { trans := &Transpose{} - err := trans.Init([]*onnx.AttributeProto{{Name: "unknownAttribute"}}) + err := trans.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "unknownAttribute"}}}) - expected := fmt.Errorf(ops.UnknownAttributeErrTemplate, trans, "unknownAttribute") + expected := ops.ErrInvalidAttribute("unknownAttribute", trans) assert.Equal(t, expected, err) } func TestTransposeInitFailAttrCount(t *testing.T) { trans := &Transpose{} - err := trans.Init([]*onnx.AttributeProto{}) + err := trans.Init(ops.EmptyNodeProto()) - expected := fmt.Errorf(ops.InvalidAttrCountErrTemplate, trans, 1, 0) + expected := ops.ErrInvalidAttributeCount(1, 0, trans) assert.Equal(t, expected, err) } @@ -85,11 +84,11 @@ func TestInputValidationTranspose(t *testing.T) { }, { []tensor.Tensor{}, - fmt.Errorf("transpose operator: expected 1 input tensors, got 0"), + ops.ErrInvalidInputCount(0, &Transpose{}), }, { []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, - fmt.Errorf("transpose operator: input 0 does not allow type int"), + ops.ErrInvalidInputType(0, "int", &Transpose{}), }, } @@ -98,14 +97,17 @@ func TestInputValidationTranspose(t *testing.T) { validated, err := transpose.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } } } -func TransposeOnnxAttributeProtoFixture() []*onnx.AttributeProto { - return []*onnx.AttributeProto{ - {Name: "perm", Ints: []int64{1, 0}}, +func TransposeOnnxNodeProtoFixture() *onnx.NodeProto { + return &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "perm", Ints: []int64{1, 0}}, + }, } } diff --git a/ops/opset13/unsqueeze.go b/ops/opset13/unsqueeze.go index 312342f..b7d4530 100644 --- a/ops/opset13/unsqueeze.go +++ b/ops/opset13/unsqueeze.go @@ -1,14 +1,18 @@ package opset13 import ( - "fmt" "sort" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" "gorgonia.org/tensor" ) +const ( + MinUnsqueezeInputs = 2 + MaxUnsqueezeInputs = 2 +) + // Unsqueeze represents the ONNX unsqueeze operator. type Unsqueeze struct{} @@ -18,13 +22,14 @@ func newUnsqueeze() ops.Operator { } // Init initializes the unsqueeze operator. -func (u *Unsqueeze) Init(attributes []*onnx.AttributeProto) error { +func (u *Unsqueeze) Init(*onnx.NodeProto) error { return nil } // Apply applies the unsqueeze operator. func (u *Unsqueeze) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { dataShape := inputs[0].Shape() + axes, err := ops.AnyToIntSlice(inputs[1].Data()) if err != nil { return nil, err @@ -33,7 +38,7 @@ func (u *Unsqueeze) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { outputRank := len(dataShape) + len(axes) if !ops.AllInRange(axes, -outputRank, outputRank-1) { - return nil, fmt.Errorf(ops.AxesNotAllInRangeErrTemplate, outputRank, outputRank) + return nil, ops.ErrNotAllAxesInRange(outputRank, outputRank) } // negative entries should be offset by the rank of the output tensor @@ -43,13 +48,18 @@ func (u *Unsqueeze) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { sort.Ints(axes) if ops.HasDuplicates(axes) { - return nil, fmt.Errorf("Axes cannot have duplicate entries after offset, axes: %v", axes) + return nil, ops.ErrInvalidInput("axes cannot have duplicate entries after offset", u) } newShape := insertOnes(dataShape, axes) - out := inputs[0].Clone().(tensor.Tensor) + out, ok := inputs[0].Clone().(tensor.Tensor) + if !ok { + return nil, ops.ErrTypeAssert("tensor.Tensor", inputs[0].Clone()) + } + err = out.Reshape(newShape...) + return []tensor.Tensor{out}, err } @@ -60,12 +70,12 @@ func (u *Unsqueeze) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, err // GetMinInputs returns the minimum number of input tensors this operator expects. func (u *Unsqueeze) GetMinInputs() int { - return 2 + return MinUnsqueezeInputs } // GetMaxInputs returns the maximum number of input tensors this operator expects. func (u *Unsqueeze) GetMaxInputs() int { - return 2 + return MaxUnsqueezeInputs } // GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes @@ -82,7 +92,7 @@ func (u *Unsqueeze) String() string { // Creates a new array, which is `original` with ones added at the indices specified by `indices` // `indices` may not contain duplicates, the elements are assumed to be in the range 0 <= x < N // and should be sorted in increasing order. -// Is done in a single pass through the new array with length: len(original) + len(indices) +// Is done in a single pass through the new array with length: len(original) + len(indices). func insertOnes(original, indices []int) []int { N := len(indices) + len(original) @@ -91,6 +101,7 @@ func insertOnes(original, indices []int) []int { originalIdx := 0 indicesIdx := 0 + for i := 0; i < N; i++ { if indicesIdx < len(indices) && indices[indicesIdx] == i { newShape[i] = 1 @@ -100,5 +111,6 @@ func insertOnes(original, indices []int) []int { originalIdx++ } } + return newShape } diff --git a/ops/opset13/unsqueeze_test.go b/ops/opset13/unsqueeze_test.go index cc7d21b..445d0c5 100644 --- a/ops/opset13/unsqueeze_test.go +++ b/ops/opset13/unsqueeze_test.go @@ -1,11 +1,10 @@ package opset13 import ( - "fmt" "testing" + "github.com/advancedclimatesystems/gonnx/ops" "github.com/stretchr/testify/assert" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" "gorgonia.org/tensor" ) @@ -20,21 +19,23 @@ func TestUnsqueezeInit(t *testing.T) { func TestAxesOutRangeError(t *testing.T) { op := &Unsqueeze{} - op.Init(nil) + err := op.Init(nil) + assert.Nil(t, err) axes := []int64{4} data := ops.Arange(9, 1) // 3 x 3 tensor dataIn := ops.TensorWithBackingFixture(data, 3, 3) axesIn := ops.TensorWithBackingFixture(axes, len(axes)) - _, err := op.Apply([]tensor.Tensor{dataIn, axesIn}) - expected := fmt.Errorf(ops.AxesNotAllInRangeErrTemplate, 3, 3) + _, err = op.Apply([]tensor.Tensor{dataIn, axesIn}) + expected := ops.ErrNotAllAxesInRange(3, 3) assert.Equal(t, err, expected) } func TestDuplicateEntriesAfterOffsetNotAllowed(t *testing.T) { op := &Unsqueeze{} - op.Init(nil) + err := op.Init(nil) + assert.Nil(t, err) // -1 will be offset to 3 (since outputrank = 4) axes := []int64{3, -1} @@ -42,21 +43,22 @@ func TestDuplicateEntriesAfterOffsetNotAllowed(t *testing.T) { dataIn := ops.TensorWithBackingFixture(data, 3, 3) axesIn := ops.TensorWithBackingFixture(axes, len(axes)) - _, err := op.Apply([]tensor.Tensor{dataIn, axesIn}) - assert.EqualError(t, err, "Axes cannot have duplicate entries after offset, axes: [3 3]") + _, err = op.Apply([]tensor.Tensor{dataIn, axesIn}) + assert.EqualError(t, err, "invalid input tensor for unsqueeze operator: axes cannot have duplicate entries after offset") } func TestDuplicateEntriesNotAllowed(t *testing.T) { op := &Unsqueeze{} - op.Init(nil) + err := op.Init(nil) + assert.Nil(t, err) axes := []int64{0, 0} data := ops.Arange(9, 1) // 3 x 3 tensor dataIn := ops.TensorWithBackingFixture(data, 3, 3) axesIn := ops.TensorWithBackingFixture(axes, len(axes)) - _, err := op.Apply([]tensor.Tensor{dataIn, axesIn}) - assert.EqualError(t, err, "Axes cannot have duplicate entries after offset, axes: [0 0]") + _, err = op.Apply([]tensor.Tensor{dataIn, axesIn}) + assert.EqualError(t, err, "invalid input tensor for unsqueeze operator: axes cannot have duplicate entries after offset") } func TestUnsqueeze(t *testing.T) { @@ -71,24 +73,45 @@ func TestUnsqueeze(t *testing.T) { {[]int64{1, 2, 3, 4}, []int{2, 2}, []int64{0, -1}, []int{1, 2, 2, 1}}, {[]int64{1, 2, 3, 4}, []int{2, 2}, []int64{-1, 0}, []int{1, 2, 2, 1}}, - {[]int16{1, 2, 3, 4, 5, 6, 7, 8}, []int{2, 2, 2}, - []int64{0, 2, 4, 6}, []int{1, 2, 1, 2, 1, 2, 1}}, + { + []int16{1, 2, 3, 4, 5, 6, 7, 8}, + []int{2, 2, 2}, + []int64{0, 2, 4, 6}, + []int{1, 2, 1, 2, 1, 2, 1}, + }, - {[]complex128{1, 2, 3, 4, 5, 6, 7, 8}, []int{2, 2, 2}, - []int64{6, 0, 4, 2}, []int{1, 2, 1, 2, 1, 2, 1}}, + { + []complex128{1, 2, 3, 4, 5, 6, 7, 8}, + []int{2, 2, 2}, + []int64{6, 0, 4, 2}, + []int{1, 2, 1, 2, 1, 2, 1}, + }, - {[]float32{1, 2, 3, 4, 5, 6, 7, 8}, []int{2, 2, 2}, - []int64{-7, -5, -3, -1}, []int{1, 2, 1, 2, 1, 2, 1}}, + { + []float32{1, 2, 3, 4, 5, 6, 7, 8}, + []int{2, 2, 2}, + []int64{-7, -5, -3, -1}, + []int{1, 2, 1, 2, 1, 2, 1}, + }, - {[]float32{1, 2, 3, 4, 5, 6, 7, 8}, []int{2, 2, 2}, - []int64{-1, -7, -3, -5}, []int{1, 2, 1, 2, 1, 2, 1}}, + { + []float32{1, 2, 3, 4, 5, 6, 7, 8}, + []int{2, 2, 2}, + []int64{-1, -7, -3, -5}, + []int{1, 2, 1, 2, 1, 2, 1}, + }, - {[]float32{1, 2, 3, 4, 5, 6, 7, 8}, []int{2, 2, 2}, - []int64{0, 1, 2, 3}, []int{1, 1, 1, 1, 2, 2, 2}}, + { + []float32{1, 2, 3, 4, 5, 6, 7, 8}, + []int{2, 2, 2}, + []int64{0, 1, 2, 3}, + []int{1, 1, 1, 1, 2, 2, 2}, + }, } for _, test := range tests { op := &Unsqueeze{} - op.Init(nil) + err := op.Init(nil) + assert.Nil(t, err) axes := test.axes data := test.data @@ -125,21 +148,21 @@ func TestInputValidationUnsqueeze(t *testing.T) { }, { []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, - fmt.Errorf("unsqueeze operator: expected 2 input tensors, got 1"), + ops.ErrInvalidInputCount(1, &Unsqueeze{}), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int64{3, 4}, 2), }, - fmt.Errorf("unsqueeze operator: input 0 does not allow type int"), + ops.ErrInvalidInputType(0, "int", &Unsqueeze{}), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]int32{3, 4}, 2), }, - fmt.Errorf("unsqueeze operator: input 1 does not allow type int32"), + ops.ErrInvalidInputType(1, "int32", &Unsqueeze{}), }, } @@ -148,6 +171,7 @@ func TestInputValidationUnsqueeze(t *testing.T) { validated, err := unsqueeze.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) + if test.err == nil { assert.Equal(t, test.inputs, validated) } diff --git a/ops/opset13/xor.go b/ops/opset13/xor.go new file mode 100644 index 0000000..f668a69 --- /dev/null +++ b/ops/opset13/xor.go @@ -0,0 +1,61 @@ +package opset13 + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var ( + MinXorInputs = 2 + MaxXorInputs = 2 +) + +// Xor represents the ONNX xor operator. +type Xor struct{} + +// newXor creates a new xor operator. +func newXor() ops.Operator { + return &Xor{} +} + +// Init initializes the xor operator. +func (x *Xor) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the xor operator. +func (x *Xor) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ApplyBinaryOperation( + inputs[0], + inputs[1], + ops.Xor, + ops.MultidirectionalBroadcasting, + ) +} + +// ValidateInputs validates the inputs that will be given to Apply for this operator. +func (x *Xor) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ValidateInputs(x, inputs) +} + +// GetMinInputs returns the minimum number of input tensors this operator expects. +func (x *Xor) GetMinInputs() int { + return MinXorInputs +} + +// GetMaxInputs returns the maximum number of input tensors this operator expects. +func (x *Xor) GetMaxInputs() int { + return MaxXorInputs +} + +// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes +// for the corresponding input tensor. +func (x *Xor) GetInputTypeConstraints() [][]tensor.Dtype { + return [][]tensor.Dtype{{tensor.Bool}, {tensor.Bool}} +} + +// String implements the stringer interface, and can be used to format errors or messages. +func (x *Xor) String() string { + return "xor operator" +} diff --git a/ops/opset13/xor_test.go b/ops/opset13/xor_test.go new file mode 100644 index 0000000..68658a4 --- /dev/null +++ b/ops/opset13/xor_test.go @@ -0,0 +1,102 @@ +package opset13 + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestXorInit(t *testing.T) { + x := &Xor{} + + err := x.Init(nil) + assert.Nil(t, err) +} + +func TestXor(t *testing.T) { + tests := []struct { + xor *Xor + backings [][]bool + shapes [][]int + expected []bool + }{ + { + &Xor{}, + [][]bool{{true, false, true, false}, {true, true, true, false}}, + [][]int{{2, 2}, {2, 2}}, + []bool{false, true, false, false}, + }, + { + &Xor{}, + [][]bool{{true, false, true, false}, {true, false}}, + [][]int{{2, 2}, {1, 2}}, + []bool{false, false, false, false}, + }, + { + &Xor{}, + [][]bool{{true, false, true, false}, {true, false}}, + [][]int{{2, 2}, {2, 1}}, + []bool{false, true, true, false}, + }, + { + &Xor{}, + [][]bool{{true, false, true, false, true, false}, {false, false}}, + [][]int{{3, 2}, {1, 2}}, + []bool{true, false, true, false, true, false}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backings[0], test.shapes[0]...), + ops.TensorWithBackingFixture(test.backings[1], test.shapes[1]...), + } + + res, err := test.xor.Apply(inputs) + assert.Nil(t, err) + + assert.Nil(t, err) + assert.Equal(t, test.expected, res[0].Data()) + } +} + +func TestInputValidationXor(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + err error + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]bool{false, false}, 2), + ops.TensorWithBackingFixture([]bool{false, false}, 2), + }, + nil, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]bool{false, false}, 2), + }, + ops.ErrInvalidInputCount(1, &Xor{}), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]bool{false, false}, 2), + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputType(1, "int", &Xor{}), + }, + } + + for _, test := range tests { + or := &Xor{} + validated, err := or.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + assert.Equal(t, test.inputs, validated) + } + } +} diff --git a/ops/recurrent_utils.go b/ops/recurrent_utils.go new file mode 100644 index 0000000..98564f1 --- /dev/null +++ b/ops/recurrent_utils.go @@ -0,0 +1,73 @@ +package ops + +import ( + "gorgonia.org/tensor" +) + +// SequenceProcessDirection is the direction in which a sequential input is processed. +// We can process sequential inputs forward (from first to last), in reverse (from +// last to first) or bidirectional (which is both forward and reverse added together). +type SequenceProcessDirection string + +const ( + Forward SequenceProcessDirection = "forward" + Reverse SequenceProcessDirection = "reverse" + Bidirectional SequenceProcessDirection = "bidirectional" +) + +// These constants define attributes that are applicable to GRU, LSTM and RNN operators. +const ( + ActivationAlphaAttr = "activation_alpha" + ActivationBetaAttr = "activation_beta" + ActivationsAttr = "activations" + ClipAttr = "clip" + DirectionAttr = "direction" + HiddenSizeAttr = "hidden_size" +) + +// ExtractMatrices extracts a given number of matrices from tensor M. +// M contains concatenated matrices along a certain dimension. +// M is assumed to have a shape of (num_directions, nMatrices * hidden_size, ...) and we extract the +// by slicing over the 'nMatrices * hidden_size' dimension. +// This method is specific for recurrent operators RNN, GRU and LSTM. +func ExtractMatrices(M tensor.Tensor, nMatrices, nDimensions, hiddenSize int) ([]tensor.Tensor, error) { + dirSlice := NewSlicer(0) + matrices := make([]tensor.Tensor, nMatrices) + + for i := 0; i < nMatrices; i++ { + hiddenSlice := NewSlicer(i*hiddenSize, (i+1)*hiddenSize) + + allSlices := make([]tensor.Slice, nDimensions) + allSlices[0] = dirSlice + allSlices[1] = hiddenSlice + + for i := 2; i < nDimensions; i++ { + allSlices[i] = nil + } + + m, err := M.Slice(allSlices...) + if err != nil { + return nil, err + } + + matrices[i] = m + } + + return matrices, nil +} + +// ZeroTensor returns a tensor filled with zeros with the given shape. +func ZeroTensor(shape ...int) tensor.Tensor { + return tensor.New( + tensor.WithShape(shape...), + tensor.WithBacking(Zeros(NElements(shape...))), + ) +} + +// OnesTensor returns a new tensor with the same shape as the given tensor intialized with all ones. +func OnesTensor(t tensor.Tensor) tensor.Tensor { + return tensor.New( + tensor.WithShape(t.Shape()...), + tensor.WithBacking(Ones(NElements(t.Shape()...))), + ) +} diff --git a/ops/slicer.go b/ops/slicer.go index 2dc1a00..8d82430 100644 --- a/ops/slicer.go +++ b/ops/slicer.go @@ -13,6 +13,8 @@ type Slicer struct { // will be set to 1. If options are given, it is assumed that the first element will be the value // for the end index and the second element the value for the step of the Slicer. func NewSlicer(start int, options ...int) tensor.Slice { + const maxOptionLength = 2 + end := start + 1 step := 1 @@ -20,7 +22,7 @@ func NewSlicer(start int, options ...int) tensor.Slice { end = options[0] } - if len(options) >= 2 { + if len(options) >= maxOptionLength { step = options[1] } diff --git a/ops/types.go b/ops/types.go new file mode 100644 index 0000000..edea5fb --- /dev/null +++ b/ops/types.go @@ -0,0 +1,18 @@ +package ops + +import "gorgonia.org/tensor" + +// FloatType is a type that describes a float value. Can be either float32 or float64. +type FloatType interface { + float32 | float64 +} + +// AllTypes is a type constraint which allows all types. +var AllTypes = []tensor.Dtype{ + tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64, + tensor.Int8, tensor.Int16, tensor.Int32, tensor.Int64, + tensor.Float32, tensor.Float64, + tensor.Complex64, tensor.Complex128, + tensor.String, + tensor.Bool, +} diff --git a/ops/unidir_broadcast.go b/ops/unidir_broadcast.go index 41b35d5..406876f 100644 --- a/ops/unidir_broadcast.go +++ b/ops/unidir_broadcast.go @@ -1,22 +1,27 @@ package ops import ( - "fmt" - "gorgonia.org/tensor" ) +type BroadcastType int + +const ( + NoBroadcasting BroadcastType = 0 + UnidirectionalBroadcasting BroadcastType = 1 + MultidirectionalBroadcasting BroadcastType = 2 +) + // UnidirectionalBroadcast tries to broadcast tensor B to tensor A according to the ONNX standards. func UnidirectionalBroadcast(A, B tensor.Tensor) (tensor.Tensor, tensor.Tensor, error) { - reshapedB, err := reshapeTensorsForUnidirBroadcast(A, B) if err != nil { - return nil, nil, fmt.Errorf(UnidirBroadcastErrTemplate, A.Shape(), B.Shape()) + return nil, nil, ErrUnidirBroadcast(A.Shape(), B.Shape()) } newB, err := repeatTensorsForUnidirBroadcast(A, reshapedB) if err != nil { - return nil, nil, fmt.Errorf(UnidirBroadcastErrTemplate, A.Shape(), B.Shape()) + return nil, nil, ErrUnidirBroadcast(A.Shape(), B.Shape()) } return A, newB, nil @@ -35,7 +40,7 @@ func reshapeTensorsForUnidirBroadcast(A, B tensor.Tensor) (tensor.Tensor, error) case nDimsA == nDimsB: return B, nil default: - return nil, fmt.Errorf("tensor B may not have more dimensions than tensor A") + return nil, ErrUnidirBroadcast(A.Shape(), B.Shape()) } } @@ -45,6 +50,7 @@ func reshapeTensorsForUnidirBroadcast(A, B tensor.Tensor) (tensor.Tensor, error) // Example: shapeA=(2, 3, 4) and shapeB=(1, 3, 4) yields shapeNewB=(2, 3, 4). func repeatTensorsForUnidirBroadcast(A, B tensor.Tensor) (tensor.Tensor, error) { var err error + shapeA := A.Shape() shapeB := B.Shape() @@ -55,7 +61,7 @@ func repeatTensorsForUnidirBroadcast(A, B tensor.Tensor) (tensor.Tensor, error) if sizeDimA != sizeDimB { if sizeDimB != 1 { - return nil, fmt.Errorf("incompatible dimensions") + return nil, ErrUnidirBroadcast(shapeA, shapeB) } B, err = tensor.Repeat(B, axis, sizeDimA) diff --git a/ops/unidir_broadcast_test.go b/ops/unidir_broadcast_test.go index 9890e02..98d2ea5 100644 --- a/ops/unidir_broadcast_test.go +++ b/ops/unidir_broadcast_test.go @@ -1,7 +1,6 @@ package ops import ( - "fmt" "testing" "github.com/stretchr/testify/assert" @@ -42,29 +41,17 @@ func TestUnidirectionalBroadcast(t *testing.T) { { [][]int{{1, 3, 1}, {3, 2}}, nil, - fmt.Errorf( - UnidirBroadcastErrTemplate, - []int{1, 3, 1}, - []int{3, 2}, - ), + ErrUnidirBroadcast([]int{1, 3, 1}, []int{3, 2}), }, { [][]int{{5}, {2, 3, 4}}, nil, - fmt.Errorf( - UnidirBroadcastErrTemplate, - []int{5}, - []int{2, 3, 4}, - ), + ErrUnidirBroadcast([]int{5}, []int{2, 3, 4}), }, { [][]int{{1, 4, 5}, {1, 1, 3}}, nil, - fmt.Errorf( - UnidirBroadcastErrTemplate, - []int{1, 4, 5}, - []int{1, 1, 3}, - ), + ErrUnidirBroadcast([]int{1, 4, 5}, []int{1, 1, 3}), }, } @@ -75,6 +62,7 @@ func TestUnidirectionalBroadcast(t *testing.T) { newA, newB, err := UnidirectionalBroadcast(A, B) assert.Equal(t, test.err, err) + if err == nil { assert.Equal(t, test.expectedShape, newA.Shape()) assert.Equal(t, test.expectedShape, newB.Shape()) diff --git a/ops/utils.go b/ops/utils.go index 369aedf..98d205f 100644 --- a/ops/utils.go +++ b/ops/utils.go @@ -1,8 +1,6 @@ package ops import ( - "fmt" - "gorgonia.org/tensor" ) @@ -11,6 +9,7 @@ func Abs(x int) int { if x < 0 { x *= -1 } + return x } @@ -26,21 +25,26 @@ func AllInRange(arr []int, min, max int) bool { return false } } + return true } -// HasDuplicates checks if there are duplicates in the sorted array `arr` +// HasDuplicates checks if there are duplicates in the sorted array `arr`. func HasDuplicates(arr []int) bool { if len(arr) < 1 { return false } + prev := arr[0] + for _, x := range arr[1:] { if prev == x { return true } + prev = x } + return false } @@ -51,50 +55,61 @@ func OffsetArrayIfNegative(arr []int, offset int) { if ax < 0 { ax += offset } + arr[i] = ax } } // OffsetTensorIfNegative adds an offset to every negative element in tensor t. // Works only for tensors with Dtype int (same as offset). -func OffsetTensorIfNegative(t tensor.Tensor, offset int) { +func OffsetTensorIfNegative(t tensor.Tensor, offset int) error { f := func(n int) int { if n < 0 { return n + offset } + return n } - t.Apply(f, tensor.WithReuse(t)) + + if _, err := t.Apply(f, tensor.WithReuse(t)); err != nil { + return err + } + + return nil } // AnyToIntSlice casts the data of a node to an int list. This will only // be done if the data is of some sort of int type. -func AnyToIntSlice(any interface{}) ([]int, error) { +func AnyToIntSlice(value interface{}) ([]int, error) { var res []int - switch data := any.(type) { + switch data := value.(type) { case []int8: for _, value := range data { res = append(res, int(value)) } + return res, nil case []int16: for _, value := range data { res = append(res, int(value)) } + return res, nil case []int32: for _, value := range data { res = append(res, int(value)) } + return res, nil case []int64: for _, value := range data { res = append(res, int(value)) } + return res, nil default: - return nil, fmt.Errorf("could not cast %v to int list", data) + return nil, ErrCast } } @@ -116,14 +131,14 @@ func GetValueAsTensorType(value float64, dtype tensor.Dtype) (interface{}, error case tensor.Float64: return value, nil default: - return nil, fmt.Errorf("unknown type %v, cannot cast constant to this type", dtype) + return nil, ErrCast } } // IfScalarToSlice will wrap the value in a slice if it is a scalar in a slice with that value, // otherwise will return itself. -func IfScalarToSlice(any interface{}) interface{} { - switch data := any.(type) { +func IfScalarToSlice(value any) any { + switch data := value.(type) { case int8: return []int8{data} case int16: @@ -143,7 +158,7 @@ func IfScalarToSlice(any interface{}) interface{} { case complex128: return []complex128{data} default: - return any + return value } } @@ -153,6 +168,7 @@ func Zeros(size int) []float32 { for i := range res { res[i] = 0.0 } + return res } @@ -162,6 +178,7 @@ func Full(size int, value float32) []float32 { for i := range res { res[i] = value } + return res } @@ -171,6 +188,7 @@ func Ones(size int) []float32 { for i := range res { res[i] = 1.0 } + return res } @@ -180,6 +198,7 @@ func Arange(size int, step float32) []float32 { for i := range res { res[i] = float32(i) * step } + return res } @@ -193,12 +212,15 @@ func NElements(shp ...int) int { return nElem } -// PairwiseAssign essentially does pairwise t1 = t2 in place! +// PairwiseAssign essentially does pairwise t1 = t2 in place!. func PairwiseAssign(t1, t2 tensor.Tensor) (err error) { if !t1.Shape().Eq(t2.Shape()) { - return fmt.Errorf("Shapes of tensors must be equal, were %v and %v", t1.Shape(), t2.Shape()) + return ErrInvalidShape } + it := t1.Iterator() + // We cannot check the error here since it is a post statement so ignore the nolint errcheck here. + // nolint errcheck for it.Reset(); !it.Done(); it.Next() { coord := it.Coord() @@ -212,5 +234,6 @@ func PairwiseAssign(t1, t2 tensor.Tensor) (err error) { return err } } + return nil } diff --git a/ops/utils_test.go b/ops/utils_test.go index a43d633..9ea0d87 100644 --- a/ops/utils_test.go +++ b/ops/utils_test.go @@ -1,7 +1,6 @@ package ops import ( - "fmt" "testing" "github.com/stretchr/testify/assert" @@ -106,7 +105,9 @@ func TestOffsetTensorIfNegative(t *testing.T) { } for _, test := range tests { tIn := tensor.New(tensor.WithShape(len(test.in)), tensor.WithBacking(test.in)) - OffsetTensorIfNegative(tIn, test.offset) + err := OffsetTensorIfNegative(tIn, test.offset) + + assert.Nil(t, err) assert.Equal(t, test.expected, tIn.Data()) } } @@ -140,12 +141,13 @@ func TestAnyToIntSlice(t *testing.T) { { "some string", nil, - fmt.Errorf("could not cast some string to int list"), + ErrCast, }, } for _, test := range tests { res, err := AnyToIntSlice(test.in) + assert.Equal(t, test.expected, res) assert.Equal(t, test.err, err) } @@ -204,7 +206,7 @@ func TestGetValueAsTensorType(t *testing.T) { 1.0, tensor.Complex64, nil, - fmt.Errorf("unknown type complex64, cannot cast constant to this type"), + ErrCast, }, } @@ -313,7 +315,7 @@ func TestPairwiseAssign(t *testing.T) { { tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float32{1, 2, 3, 4})), tensor.New(tensor.WithShape(1, 2), tensor.WithBacking([]float32{1, 1})), - fmt.Errorf("Shapes of tensors must be equal, were (2, 2) and (1, 2)"), + ErrInvalidShape, }, } @@ -321,6 +323,7 @@ func TestPairwiseAssign(t *testing.T) { err := PairwiseAssign(test.t1, test.t2) assert.Equal(t, err, test.err) + if err == nil { assert.Equal(t, test.t2.Data(), test.t1.Data()) } diff --git a/ops/validate_inputs.go b/ops/validate_inputs.go index 8b954cb..fa82a48 100644 --- a/ops/validate_inputs.go +++ b/ops/validate_inputs.go @@ -1,21 +1,9 @@ package ops import ( - "fmt" - "gorgonia.org/tensor" ) -// AllTypes is a type constraint which allows all types. -var AllTypes = []tensor.Dtype{ - tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64, - tensor.Int8, tensor.Int16, tensor.Int32, tensor.Int64, - tensor.Float32, tensor.Float64, - tensor.Complex64, tensor.Complex128, - tensor.String, - tensor.Bool, -} - // ValidateInputs validates if a list of nodes has enough (not too few or too many) nodes. // When there are fewer input nodes then the given max, the list is padded with nils. // Expects either 1 requirement ==> the expected number of inputs, or 2 requirements, @@ -38,19 +26,22 @@ func ValidateInputs(op Operator, inputs []tensor.Tensor) ([]tensor.Tensor, error func checkNInputs(op Operator, inputs []tensor.Tensor) (int, error) { nInputs := len(inputs) - var padLength int + padLength := 0 min := op.GetMinInputs() max := op.GetMaxInputs() + if min == max { if nInputs != min { - return 0, fmt.Errorf(InvalidInputCountErrTemplate, op, min, nInputs) + return 0, ErrInvalidInputCount(nInputs, op) } + padLength = min } else { if nInputs < min || nInputs > max { - return 0, fmt.Errorf(InvalidOptionalInputCountErrTemplate, op, min, max, nInputs) + return 0, ErrInvalidOptionalInputCount(nInputs, op) } + padLength = max } @@ -62,11 +53,13 @@ func padInputs(inputs []tensor.Tensor, length int) []tensor.Tensor { for len(inputs) < length { inputs = append(inputs, nil) } + return inputs } func checkInputTypes(op Operator, inputs []tensor.Tensor) error { typeConstraints := op.GetInputTypeConstraints() + for i, input := range inputs { // Optional inputs can be nil, we can not check for type constraints then. if input == nil { @@ -76,9 +69,10 @@ func checkInputTypes(op Operator, inputs []tensor.Tensor) error { typeConstraint := newTypeConstraint(typeConstraints[i]) if _, ok := typeConstraint[input.Dtype()]; !ok { - return fmt.Errorf("%v: input %d does not allow type %v", op, i, input.Dtype()) + return ErrInvalidInputType(i, input.Dtype().Name(), op) } } + return nil } @@ -89,5 +83,6 @@ func newTypeConstraint(allowedTypes []tensor.Dtype) map[tensor.Dtype]bool { for _, allowedType := range allowedTypes { typeConstraint[allowedType] = true } + return typeConstraint } diff --git a/ops/validate_inputs_test.go b/ops/validate_inputs_test.go index 1065a45..1edde37 100644 --- a/ops/validate_inputs_test.go +++ b/ops/validate_inputs_test.go @@ -1,11 +1,10 @@ package ops import ( - "fmt" "testing" + "github.com/advancedclimatesystems/gonnx/onnx" "github.com/stretchr/testify/assert" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" "gorgonia.org/tensor" ) @@ -76,7 +75,7 @@ func TestValidateInputs(t *testing.T) { }, PaddedInputsFixture(1, 0), 0, - fmt.Errorf(InvalidInputCountErrTemplate, &MockOp{}, 2, 1), + ErrInvalidInputCount(1, &MockOp{minInputs: 2, maxInputs: 2}), }, { &MockOp{ @@ -92,7 +91,7 @@ func TestValidateInputs(t *testing.T) { }, PaddedInputsFixture(7, 0), 0, - fmt.Errorf(InvalidOptionalInputCountErrTemplate, &MockOp{}, 3, 5, 7), + ErrInvalidOptionalInputCount(7, &MockOp{minInputs: 3, maxInputs: 5}), }, { &MockOp{ @@ -102,15 +101,17 @@ func TestValidateInputs(t *testing.T) { }, PaddedInputsFixture(2, 0), 0, - fmt.Errorf("%v: input %d does not allow type %v", &MockOp{}, 1, tensor.Float32), + ErrInvalidInputType(1, "float32", &MockOp{}), }, } for _, test := range tests { inputs, err := ValidateInputs(test.op, test.inputs) + if test.err != nil { + assert.EqualError(t, err, test.err.Error()) + } expectedLength := len(test.inputs) + test.expectedNil - assert.Equal(t, test.err, err) assert.Equal(t, expectedLength, len(inputs)) // Check if the added nodes are all nil. @@ -136,12 +137,15 @@ func TestPadInputs(t *testing.T) { func PaddedInputsFixture(nTensors, nNil int) []tensor.Tensor { result := make([]tensor.Tensor, nTensors+nNil) i := 0 + for ; i < nTensors; i++ { result[i] = tensor.New(tensor.WithBacking([]float32{0.0})) } + for ; i < nTensors+nNil; i++ { result[i] = nil } + return result } @@ -151,11 +155,11 @@ type MockOp struct { inputTypeConstraints [][]tensor.Dtype } -func (m *MockOp) Init(attr []*onnx.AttributeProto) error { +func (m *MockOp) Init(*onnx.NodeProto) error { return nil } -func (m *MockOp) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { +func (m *MockOp) Apply(_ []tensor.Tensor) ([]tensor.Tensor, error) { return nil, nil } diff --git a/ops_test.go b/ops_test.go index f7543e9..431df70 100644 --- a/ops_test.go +++ b/ops_test.go @@ -2,16 +2,17 @@ package gonnx import ( "fmt" - "io/ioutil" + "io" "os" "sort" "strings" "testing" + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops/opset13" "github.com/stretchr/testify/assert" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops/opset13" "google.golang.org/protobuf/proto" + "gorgonia.org/tensor" ) // Currently we ignore some of tests provided by ONNX. This has to do with the @@ -25,10 +26,8 @@ import ( var ignoredTests = []string{ "test_add_uint8", // Opset14 "test_div_uint8", // Opset14 - "test_gru_defaults", // Opset14 "test_gru_batchwise", // Opset14 - "test_gru_seq_length", // Opset14 - "test_gru_with_initial_bias", // Opset14 + "test_lstm_batchwise", // Opset14 "test_mul_uint8", // Opset14 "test_sub_uint8", // Opset14 "test_shape_clip_end", // Opset15 @@ -42,26 +41,67 @@ var ignoredTests = []string{ "test_shape_start_negative_1", // Opset15 "test_reshape_allowzero_reordered", // Opset14 - "test_constant_pad", // Pad is not implemented yet. - "test_constant_pad_axes", // Pad is not implemented yet. - "test_gemm_alpha", // For gemm in opset 11. - "test_gemm_default_no_bias", // For gemm in opset 11. - "test_gemm_default_scalar_bias", // For gemm in opset 11. - "test_relu_expanded_ver18", // CastLike operator not implemented yet. - "test_slice_start_out_of_bounds", // ONNX expects nil output, but we throw an error. - "test_slice_end_out_of_bounds", // ONNX expects nil output, but we throw an error. - "test_slice_neg_steps", // ONNX expects nil output, but we throw an error. - "test_slice_neg", // ONNX expects nil output, but we throw an error. - "test_transpose_default", // For transpose in opset 9. - - "test_cast_FLOAT_to_STRING", // Unsupported datatype STRING. - "test_cast_STRING_to_FLOAT", // Unsupported datatype STRING. - "test_cast_DOUBLE_to_FLOAT16", // Unsupported datatype FLOAT16. - "test_cast_FLOAT_to_FLOAT16", // Unsupported datatype FLOAT16. - "test_cast_FLOAT16_to_DOUBLE", // Unsupported datatype FLOAT16. - "test_cast_FLOAT16_to_FLOAT", // Unsupported datatype FLOAT16. - "test_cast_BFLOAT16_to_FLOAT", // Unsupported datatype BFLOAT16. - "test_cast_FLOAT_to_BFLOAT16", // Unsupported datatype BFLOAT16. + "test_constant_pad", // Pad is not implemented yet. + "test_constant_pad_axes", // Pad is not implemented yet. + "test_gemm_alpha", // For gemm in opset 11. + "test_gemm_default_no_bias", // For gemm in opset 11. + "test_gemm_default_scalar_bias", // For gemm in opset 11. + "test_lstm_with_peepholes", // Sequence lens attribute is not supported yet. + "test_relu_expanded_ver18", // CastLike operator not implemented yet. + "test_softmax_default_axis_expanded_ver18", // ReduceMax operator not implemented yet. + "test_softmax_axis_1_expanded_ver18", // ReduceMax operator not implemented yet. + "test_softmax_negative_axis_expanded_ver18", // ReduceMax operator not implemented yet. + "test_softmax_example_expanded_ver18", // ReduceMax operator not implemented yet. + "test_softmax_axis_0_expanded_ver18", // ReduceMax operator not implemented yet. + "test_softmax_large_number_expanded_ver18", // ReduceMax operator not implemented yet. + "test_softmax_axis_2_expanded_ver18", // ReduceMax operator not implemented yet. + "test_softmax_axis_0_expanded", // ReduceMax operator not implemented yet. + "test_softmax_negative_axis_expanded", // ReduceMax operator not implemented yet. + "test_softmax_large_number_expanded", // ReduceMax operator not implemented yet. + "test_softmax_axis_1_expanded", // ReduceMax operator not implemented yet. + "test_softmax_example_expanded", // ReduceMax operator not implemented yet. + "test_softmax_axis_2_expanded", // ReduceMax operator not implemented yet. + "test_softmax_default_axis_expanded", // ReduceMax operator not implemented yet. + "test_slice_start_out_of_bounds", // ONNX expects nil output, but we throw an error. + "test_slice_end_out_of_bounds", // ONNX expects nil output, but we throw an error. + "test_slice_neg_steps", // ONNX expects nil output, but we throw an error. + "test_slice_neg", // ONNX expects nil output, but we throw an error. + "test_transpose_default", // For transpose in opset 9. + + "test_equal_string", // Unsupported datatype String. + "test_equal_string_broadcast", // Unsupported datatype String. + "test_cast_FLOAT_to_STRING", // Unsupported datatype STRING. + "test_cast_STRING_to_FLOAT", // Unsupported datatype STRING. + "test_cast_DOUBLE_to_FLOAT16", // Unsupported datatype FLOAT16. + "test_cast_FLOAT_to_FLOAT16", // Unsupported datatype FLOAT16. + "test_cast_FLOAT16_to_DOUBLE", // Unsupported datatype FLOAT16. + "test_cast_FLOAT16_to_FLOAT", // Unsupported datatype FLOAT16. + "test_cast_BFLOAT16_to_FLOAT", // Unsupported datatype BFLOAT16. + "test_cast_FLOAT_to_BFLOAT16", // Unsupported datatype BFLOAT16. + "test_cast_FLOAT_to_FLOAT8E5M2", // Unsupported datatype. + "test_cast_FLOAT_to_FLOAT8E4M3FN", // Unsupported datatype. + "test_cast_FLOAT_to_FLOAT8E4M3FNUZ", // Unsupported datatype FLOAT8E4M3FNUZ. + "test_cast_FLOAT_to_FLOAT8E5M2FNUZ", // Unsupported datatype. + "test_cast_FLOAT16_to_FLOAT8E5M2", // Unsupported datatype. + "test_cast_FLOAT16_to_FLOAT8E4M3FN", // Unsupported datatype. + "test_cast_FLOAT16_to_FLOAT8E4M3FNUZ", // Unsupported datatype. + "test_cast_FLOAT16_to_FLOAT8E5M2FNUZ", // Unsupported datatype. + "test_cast_FLOAT8E5M2_to_FLOAT", // Unsupported datatype. + "test_cast_FLOAT8E5M2_to_FLOAT16", // Unsupported datatype. + "test_cast_FLOAT8E4M3FN_to_FLOAT", // Unsupported datatype. + "test_cast_FLOAT8E4M3FN_to_FLOAT16", // Unsupported datatype. + "test_cast_FLOAT8E4M3FNUZ_to_FLOAT", // Unsupported datatype. + "test_cast_FLOAT8E4M3FNUZ_to_FLOAT16", // Unsupported datatype. + "test_cast_FLOAT8E5M2FNUZ_to_FLOAT", // Unsupported datatype. + "test_cast_FLOAT8E5M2FNUZ_to_FLOAT16", // Unsupported datatype. + "test_cast_no_saturate_FLOAT_to_FLOAT8E5M2", // Unsupported datatype FLOAT8E5M2. + "test_cast_no_saturate_FLOAT_to_FLOAT8E4M3FNUZ", // Unsupported datatype. + "test_cast_no_saturate_FLOAT_to_FLOAT8E5M2FNUZ", // Unsupported datatype. + "test_cast_no_saturate_FLOAT_to_FLOAT8E4M3FN", // Unsupported datatype. + "test_cast_no_saturate_FLOAT16_to_FLOAT8E4M3FNUZ", // Unsupported datatype. + "test_cast_no_saturate_FLOAT16_to_FLOAT8E5M2FNUZ", // Unsupported datatype. + "test_cast_no_saturate_FLOAT16_to_FLOAT8E4M3FN", // Unsupported datatype. + "test_cast_no_saturate_FLOAT16_to_FLOAT8E5M2", // Unsupported datatype. "test_unsqueeze_axis_3", // Tests an old version of Unsqueeze (<= 11) "test_constantofshape_int_shape_zero", // Empty tensors are not supported in gorgonia @@ -69,6 +109,9 @@ var ignoredTests = []string{ "test_gather_elements_1", // Operator GatherElements is not implemented "test_gather_elements_negative_indices", // Operator GatherElements is not implemented + "test_prelu_broadcast_expanded", // Unsupported operator CastLike + "test_prelu_example_expanded", // Unsupported operator CastLike + "test_constant_pad_negative_axes", // Unsupported operator Pad } type ONNXTestCase struct { @@ -79,8 +122,9 @@ type ONNXTestCase struct { } func TestOps(t *testing.T) { - var runnedTests []string + runnedTests := []string{} opNames := opset13.GetOpNames() + for _, opName := range opNames { tests, err := getTestCasesForOp(opName) assert.Nil(t, err) @@ -93,15 +137,22 @@ func TestOps(t *testing.T) { for outputName := range test.outputs { expectedTensor := test.outputs[outputName] actualTensor := outputs[outputName] - assert.InDeltaSlice(t, expectedTensor.Data(), actualTensor.Data(), 0.00001) + + if expectedTensor.Dtype() == tensor.Bool { + assert.ElementsMatch(t, expectedTensor.Data(), actualTensor.Data()) + } else { + assert.InDeltaSlice(t, expectedTensor.Data(), actualTensor.Data(), 0.00001) + } } }) runnedTests = append(runnedTests, test.name) } } + sort.Strings(expectedTests) sort.Strings(runnedTests) + assert.Equal(t, expectedTests, runnedTests) } @@ -119,6 +170,7 @@ func getTestCasesForOp(opName string) ([]*ONNXTestCase, error) { } var tests []*ONNXTestCase + for _, testFolder := range testFolders { if shouldRunTest(testFolder, opFilter) { testcase, err := getTestCase(fmt.Sprintf("./test_data/%v", testFolder)) @@ -147,6 +199,7 @@ func shouldRunTest(folder, opFilter string) bool { return true } } + return false } @@ -159,6 +212,7 @@ func getTestCase(folder string) (*ONNXTestCase, error) { } basePath := fmt.Sprintf("%v/test_data_set_0", folder) + inputs, err := readTestTensors(basePath, "input", model.mp.Graph.GetInput()) if err != nil { return nil, err @@ -172,11 +226,17 @@ func getTestCase(folder string) (*ONNXTestCase, error) { testcase.model = model testcase.inputs = inputs testcase.outputs = outputs + return testcase, nil } func readTestModel(folder string) (*Model, error) { - bytesModel, err := ioutil.ReadFile(folder + "/model.onnx") + file, err := os.Open(folder + "/model.onnx") + if err != nil { + return nil, err + } + + bytesModel, err := io.ReadAll(file) if err != nil { return nil, err } @@ -202,9 +262,14 @@ func readTestTensors(basePath, baseFile string, inputs []*onnx.ValueInfoProto) ( tensors := make(Tensors) for i := 0; i < len(inputs); i++ { - filePath := fmt.Sprintf("%v/%v_%d.pb", basePath, baseFile, i) - bytesInput, err := ioutil.ReadFile(filePath) + + file, err := os.Open(filePath) + if err != nil { + return nil, err + } + + bytesInput, err := io.ReadAll(file) if err != nil { return nil, err } @@ -227,8 +292,26 @@ func readTestTensors(basePath, baseFile string, inputs []*onnx.ValueInfoProto) ( // With this we check if we truly run all tests we expected from the integration test. var expectedTests = []string{ + "test_abs", + "test_acos", + "test_acos_example", + "test_acosh", + "test_acosh_example", "test_add", "test_add_bcast", + "test_and_bcast3v1d", + "test_and_bcast3v2d", + "test_and_bcast4v2d", + "test_and_bcast4v3d", + "test_and_bcast4v4d", + "test_asin", + "test_asin_example", + "test_asinh", + "test_asinh_example", + "test_atan", + "test_atan_example", + "test_atanh", + "test_atanh_example", "test_cast_DOUBLE_to_FLOAT", "test_cast_FLOAT_to_DOUBLE", "test_concat_1d_axis_0", @@ -246,9 +329,19 @@ var expectedTests = []string{ "test_constant", "test_constantofshape_float_ones", "test_constantofshape_int_zeros", + "test_conv_with_autopad_same", + "test_conv_with_strides_and_asymmetric_padding", + "test_conv_with_strides_no_padding", + "test_conv_with_strides_padding", + "test_cos", + "test_cos_example", + "test_cosh", + "test_cosh_example", "test_div", "test_div_bcast", "test_div_example", + "test_equal", + "test_equal_bcast", "test_gather_0", "test_gather_1", "test_gather_2d_indices", @@ -261,12 +354,39 @@ var expectedTests = []string{ "test_gemm_default_zero_bias", "test_gemm_beta", "test_gemm_transposeB", + "test_greater", + "test_greater_bcast", + "test_greater_equal", + "test_greater_equal_bcast", + "test_greater_equal_bcast_expanded", + "test_greater_equal_expanded", + "test_gru_defaults", + "test_gru_seq_length", + "test_gru_with_initial_bias", + "test_less", + "test_less_bcast", + "test_less_equal", + "test_less_equal_bcast", + "test_less_equal_bcast_expanded", + "test_less_equal_expanded", + "test_lstm_defaults", + "test_lstm_with_initial_bias", "test_matmul_4d", "test_matmul_3d", "test_matmul_2d", "test_mul", "test_mul_bcast", "test_mul_example", + "test_not_2d", + "test_not_3d", + "test_not_4d", + "test_or_bcast3v1d", + "test_or_bcast3v2d", + "test_or_bcast4v2d", + "test_or_bcast4v3d", + "test_or_bcast4v4d", + "test_prelu_broadcast", + "test_prelu_example", "test_relu", "test_reshape_extended_dims", "test_reshape_negative_dim", @@ -277,18 +397,32 @@ var expectedTests = []string{ "test_reshape_reordered_last_dims", "test_reshape_zero_and_negative_dim", "test_reshape_zero_dim", + "test_rnn_seq_length", "test_shape", + "test_sin", + "test_sin_example", "test_sigmoid_example", "test_sigmoid", + "test_sinh", + "test_sinh_example", "test_slice_negative_axes", "test_slice_default_steps", "test_slice", "test_slice_default_axes", + "test_softmax_axis_0", + "test_softmax_axis_1", + "test_softmax_axis_2", + "test_softmax_default_axis", "test_squeeze_negative_axes", + "test_softmax_example", + "test_softmax_large_number", + "test_softmax_negative_axis", "test_squeeze", "test_sub", "test_sub_bcast", "test_sub_example", + "test_tan", + "test_tan_example", "test_tanh", "test_tanh_example", "test_transpose_all_permutations_2", @@ -304,4 +438,9 @@ var expectedTests = []string{ "test_unsqueeze_three_axes", "test_unsqueeze_two_axes", "test_unsqueeze_unsorted_axes", + "test_xor_bcast3v1d", + "test_xor_bcast3v2d", + "test_xor_bcast4v2d", + "test_xor_bcast4v3d", + "test_xor_bcast4v4d", } diff --git a/opset.go b/opset.go index 676d2a3..2b83900 100644 --- a/opset.go +++ b/opset.go @@ -1,10 +1,8 @@ package gonnx import ( - "fmt" - - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops" - "gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops/opset13" + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/advancedclimatesystems/gonnx/ops/opset13" ) // OpGetter is a function that gets an operator based on a string. @@ -16,13 +14,9 @@ var operatorGetters = map[int64]OpGetter{ // ResolveOperatorGetter resolves the getter for operators based on the opset version. func ResolveOperatorGetter(opsetID int64) (OpGetter, error) { - if GetOperator, ok := operatorGetters[opsetID]; ok { - return GetOperator, nil + if getOperator, ok := operatorGetters[opsetID]; ok { + return getOperator, nil } - var opsets []int64 - for version := range operatorGetters { - opsets = append(opsets, version) - } - return nil, fmt.Errorf("expected opset to be in %d, got operator set %d", opsets, opsetID) + return nil, ops.ErrUnsupportedOpsetVersion } diff --git a/opset_test.go b/opset_test.go index e6bb90b..bd986c9 100644 --- a/opset_test.go +++ b/opset_test.go @@ -1,14 +1,14 @@ package gonnx import ( - "fmt" "testing" + "github.com/advancedclimatesystems/gonnx/ops" "github.com/stretchr/testify/assert" ) func TestResolveOperatorGetterFail(t *testing.T) { opGetter, err := ResolveOperatorGetter(12) assert.Nil(t, opGetter) - assert.Equal(t, fmt.Errorf("expected opset to be in [13], got operator set 12"), err) + assert.Equal(t, ops.ErrUnsupportedOpsetVersion, err) } diff --git a/sample_models/onnx_models/mnist-8-opset13.onnx b/sample_models/onnx_models/mnist-8-opset13.onnx new file mode 100644 index 0000000..5258a6b --- /dev/null +++ b/sample_models/onnx_models/mnist-8-opset13.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d6267e75ad19e51ad643554f861f21fc76bcb54b625074a845ccf329c465bad6 +size 26454 diff --git a/sample_models/requirements.txt b/sample_models/requirements.txt index c72c959..3b70550 100644 --- a/sample_models/requirements.txt +++ b/sample_models/requirements.txt @@ -1,4 +1,4 @@ -numpy==1.21.2 -scikit-learn==0.24.2 -skl2onnx==1.9.2 +numpy==1.26.1 +scikit-learn==1.3.1 +skl2onnx==1.15.0 torch==1.9.0