Skip to content

Commit

Permalink
v1.0.0 (#188)
Browse files Browse the repository at this point in the history
* Added the Abs operator

* Issue [#96](#96):
Update import paths.

* Setup CI/CD checks

* Breaking pipeline?

* Add linter and test jobs

* Fix syntax

* Split build jobs

* Fix imports for Abs operator

* add PRelu operator

* fix typos in prelu comments

* Fix new ONNX tests

* Better ordering

* Issue #151 : Update import path in readme

* Keep using Go1.19

* Fix a lot of linter issues

* Fix more lint

* Got to cast operator and disable `captLocal`

* WIP: more lint fixes

* WIP: More lint fixes

* More wip on lint errors

* WIP on lint

* More WIP

* Fix all lints except errors

* Fix all lints

* Fixed part of tests

* Fix validate input tests.

* Worked on fixing tests

* Fixed rest of tests

* Resolve all MR comments

* Added cos operator (#159)

* Added cos operator

* Replace s struct identifier

* Missed characters

* Fixed comment

* Resolved MR comments

* Fix tests

* Fix lint

* remove unused errors

* Fix naming

* Group declarations

---------

Co-authored-by: Swopper050 <[email protected]>

* Added acosh operator (#163)

* Added acosh operator

* Merged develop

---------

Co-authored-by: Swopper050 <[email protected]>

* Add Conv operator (#177)

* Working on conv operator

* Added all attributes for conv operator

* WIP on conv operator

* Shit's hard

* WIP on conv operator

* Set defaults for attributes

* Finished computation of dilated kernel

* Start of 1D conv implementation

* Almost finished 1D convolution

* Working 2D conv!

* Working convolution operator

* Add mnist model

* Added tests + bugfixes in conv

* Full coverage

* Remove unnecessary if

* Kept division by 2

* Fix last MR comment

* Fix lint?

---------

Co-authored-by: Swopper050 <[email protected]>

* Added acos operator (#162)

* Added acos operator

* Merge develop

* Group declarations

---------

Co-authored-by: Swopper050 <[email protected]>

* Added Sin operator (#157)

* Added Sin operator

* Added ONNX sin test coverage

* Fix tests

* Remove unused error

* Remove unused error

* Fix lint

* Fix lint

* Use float type

---------

Co-authored-by: Swopper050 <[email protected]>

* Added asin operator (#161)

* Added asin operator

* Use FloatType

---------

Co-authored-by: Swopper050 <[email protected]>

* Added Sinh operator (#158)

* Added sinh operator

* Updated comments

* Use FloatType

---------

Co-authored-by: Swopper050 <[email protected]>

* Added atan operator (#165)

* Added atan operator

* Use FloatType

---------

Co-authored-by: Swopper050 <[email protected]>

* Added atanh operator (#166)

* Added atanh operator

* Use FloatType

---------

Co-authored-by: Swopper050 <[email protected]>

* Added tan operator (#167)

* Added tan operator

* Use FloatType

---------

Co-authored-by: Swopper050 <[email protected]>

* Added asinh operator (#168)

* Added asinh operator

* Remove unused types

---------

Co-authored-by: Swopper050 <[email protected]>

* Added cosh operator (#160)

* Added cosh operator

* Group declarations

* Correct Apply comment

---------

Co-authored-by: Swopper050 <[email protected]>

* Added softmax operator (#171)

* Added softmax operator

* Merged develop

* Resolved MR comments

---------

Co-authored-by: Swopper050 <[email protected]>

* Added not operator (#170)

* Added not operator

* Resolved MR comments

---------

Co-authored-by: Swopper050 <[email protected]>

* Added boolean operators (#180)

* Added boolean operators

* Fix lint

* Fix final linst

* Create a generic apply boolean operator

* Refactored boolean operators to use binary op generic function

* Refactored even more

* Generalized binary operations

---------

Co-authored-by: Swopper050 <[email protected]>

* Added comparison operators (#173)

* Added equal, greater, greaterOrEqual operator

* Merged develop

* Resolved MR comments

---------

Co-authored-by: Swopper050 <[email protected]>

* Add RNN operator (#181)

* WIP on RNN

* WIP on RNN

* Working RNN version

* Added tests for RNN

* Resolved MR comments

* Do not export rnn specific constants

---------

Co-authored-by: Swopper050 <[email protected]>

* Add LSTM operator (#183)

* WIP on RNN

* WIP on RNN

* Working RNN version

* Added tests for RNN

* Working version of LSTM operator

* Reusable attrs and tests for LSTM

* Refactored recurrent operators to share code

* Resolved MR comments

* Do not export rnn specific constants

---------

Co-authored-by: Swopper050 <[email protected]>

* Add LinearRegressor operator (#184)

* WIP on LinearRegressor

* Added tests for linear regressor

* Added test descriptions and docstring

* Do not export constants

---------

Co-authored-by: Swopper050 <[email protected]>

---------

Co-authored-by: Swopper050 <[email protected]>
Co-authored-by: wisse <[email protected]>
Co-authored-by: wipsel <[email protected]>
Co-authored-by: Yannick Dylla <[email protected]>
  • Loading branch information
5 people authored Dec 14, 2023
1 parent a4a13c1 commit 3fbddf4
Show file tree
Hide file tree
Showing 133 changed files with 9,363 additions and 927 deletions.
75 changes: 75 additions & 0 deletions .github/workflows/go.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# This workflow will build a golang project
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go

name: Go

on:
push:
branches: [ "develop" ]
pull_request:
branches: [ "develop" ]

jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3

- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: 1.19

- name: Install linter
run: make install_lint

- name: Lint
run: make lint

tests:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3

- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: 1.19

- name: Install dependencies
run: make install

- name: Install Gotestsum
run: make install_gotestsum

- name: Setup ONNX test data
run: make test_data

- name: Tests
run: make test

build_amd64:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3

- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: 1.19

- name: Build amd64
run: make build_amd64

build_arm64:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3

- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: 1.19

- name: Build arm64
run: make build_arm64
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
test_data/
.coverage.out

sample_models/.env
15 changes: 11 additions & 4 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,17 @@ linters:
linters-settings:
gomnd:
ignored-functions:
- 'strconv.ParseInt'
- 'strconv.ParseFloat'
- 'strconv.FormatInt'
- 'strconv.FormatFloat'
- "strconv.ParseInt"
- "strconv.ParseFloat"
- "strconv.FormatInt"
- "strconv.FormatFloat"
gocritic:
disabled-checks:
# In the world of AI tensor's are often denoted with a capital letter.
# We want to adopt the go style guide as much as possible but we also want
# to be able to easily show when a variable is a Tensor. So we chose to
# disable captLocal. Note that any other parameter should use a lower case letters.
- "captLocal"
issues:
max-issues-per-linter: 0
max-same-issues: 0
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ A simple example is shown below:
package main

import (
"github.com/AdvancedClimateSystems/gonnx"
"github.com/advancedclimatesystems/gonnx"
"gorgonia.org/tensor"
)

Expand Down
33 changes: 29 additions & 4 deletions errors.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,32 @@
package gonnx

// InvalidShapeError is used when the shape of an input tensor does not match the expectation.
const InvalidShapeError = "input shape does not match for %v: expected %v but got %v"
import (
"errors"
"fmt"

// SetOutputTensorsError is used when the output of an operation could not be set.
const SetOutputTensorsError = "could not set output tensors, expected %v tensors but got %v"
"github.com/advancedclimatesystems/gonnx/onnx"
)

var errModel = errors.New("gonnx model error")

type InvalidShapeError struct {
expected onnx.Shape
actual []int
}

func (i InvalidShapeError) Error() string {
return fmt.Sprintf("invalid shape error expected: %v actual %v", i.expected, i.actual)
}

func ErrInvalidShape(expected onnx.Shape, actual []int) error {
return InvalidShapeError{
expected: expected,
actual: actual,
}
}

// ErrModel is used for when an error ocured during setup of running onnx models.
// The user can specify a formatted message using the standard formatting rules.
func ErrModel(format string, a ...any) error {
return fmt.Errorf("%w: %s", errModel, fmt.Sprintf(format, a...))
}
1 change: 0 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ go 1.19

require (
github.com/stretchr/testify v1.8.1
gitlab.advancedclimate.nl/smartbase/software/core/airgo v0.0.0-b1
google.golang.org/protobuf v1.28.1
gorgonia.org/tensor v0.9.24
)
Expand Down
6 changes: 2 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,6 @@ github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY=
github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
gitlab.advancedclimate.nl/smartbase/software/core/airgo v0.0.0-b1 h1:RkBEoyN6PzTsB8i4ZHLrBpG2q3KxPjJS5kLoLT9Op8c=
gitlab.advancedclimate.nl/smartbase/software/core/airgo v0.0.0-b1/go.mod h1:EzFgx3V68Upc4HwJDakYaztVI9AZn0ABtegpQJT0Q+4=
go4.org/unsafe/assume-no-moving-gc v0.0.0-20220617031537-928513b29760 h1:FyBZqvoA/jbNzuAWLQE2kG820zMAkcilx6BMjGbL/E4=
go4.org/unsafe/assume-no-moving-gc v0.0.0-20220617031537-928513b29760/go.mod h1:FftLjUGFEDu5k8lt0ddY+HcrH/qU/0qk+H8j9/nTl3E=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
Expand Down Expand Up @@ -123,8 +121,8 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200904194848-62affa334b73/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974 h1:IX6qOQeG5uLjB/hjjwjedwfjND0hgjPMMyO1RoIXQNI=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6 h1:0PC75Fz/kyMGhL0e1QnypqK2kQMqKt9csD1GnMJR+Zk=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
Expand Down Expand Up @@ -164,8 +162,8 @@ gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJ
gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0=
gonum.org/v1/gonum v0.9.1 h1:HCWmqqNoELL0RAQeKBXWtkp04mGk8koafcB4He6+uhc=
gonum.org/v1/gonum v0.9.1/go.mod h1:TZumC3NeyVQskjXqmyWt4S3bINhy7B4eYwW69EbyX+0=
gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0 h1:OE9mWmgKkjJyEmDAAtGMPjXu+YNeGvK9VTSHY6+Qihc=
gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw=
gonum.org/v1/netlib v0.0.0-20201012070519-2390d26c3658 h1:/DNJ3wcvPHjTLVNG6rmSHK7uEwdBihyiJRJXB16wXoU=
gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc=
gonum.org/v1/plot v0.9.0/go.mod h1:3Pcqqmp6RHvJI72kgb8fThyUnav364FOsdDo2aGW5lY=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
Expand Down
46 changes: 22 additions & 24 deletions model.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@ package gonnx

import (
"archive/zip"
"fmt"
"io/ioutil"
"io"
"os"

"gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/onnx"
"gitlab.advancedclimate.nl/smartbase/software/core/airgo/gonnx/ops"
"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"google.golang.org/protobuf/proto"
"gorgonia.org/tensor"
)

// Tensors is a map with tensors
// Tensors is a map with tensors.
type Tensors map[string]tensor.Tensor

// Model defines a model that can be used for inference.
Expand All @@ -23,7 +23,7 @@ type Model struct {

// NewModelFromFile creates a new model from a path to a file.
func NewModelFromFile(path string) (*Model, error) {
bytesModel, err := ioutil.ReadFile(path)
bytesModel, err := os.ReadFile(path)
if err != nil {
return nil, err
}
Expand All @@ -38,7 +38,7 @@ func NewModelFromZipFile(file *zip.File) (*Model, error) {
return nil, err
}

bytesModel, err := ioutil.ReadAll(fc)
bytesModel, err := io.ReadAll(fc)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -66,6 +66,7 @@ func NewModel(mp *onnx.ModelProto) (*Model, error) {
opsetImports := mp.GetOpsetImport()

var opsetID int64

for i := 0; i < len(opsetImports); i++ {
version := opsetImports[i].GetVersion()
if version > opsetID {
Expand All @@ -78,12 +79,11 @@ func NewModel(mp *onnx.ModelProto) (*Model, error) {
return nil, err
}

model := &Model{
return &Model{
mp: mp,
parameters: params,
GetOperator: GetOperator,
}
return model, nil
}, nil
}

// ModelProtoFromBytes creates an onnx.ModelProto based on a list of bytes.
Expand All @@ -92,6 +92,7 @@ func ModelProtoFromBytes(bytesModel []byte) (*onnx.ModelProto, error) {
if err := proto.Unmarshal(bytesModel, mp); err != nil {
return nil, err
}

return mp, nil
}

Expand All @@ -108,16 +109,13 @@ func (m *Model) InputShapes() onnx.Shapes {
// InputDimSize returns the size of the input dimension given an input tensor.
func (m *Model) InputDimSize(input string, i int) (int, error) {
if !m.hasInput(input) {
return 0, fmt.Errorf("input %v does not exist", input)
return 0, ErrModel("input %v does not exist", input)
}

inputShape := m.mp.Graph.InputShapes()[input]

if i >= len(inputShape) {
err := fmt.Errorf(
"input %v only has %d dimensions, but index %d was required", input, len(inputShape), i,
)
return 0, err
return 0, ErrModel("input %v only has %d dimensions, but index %d was required", input, len(inputShape), i)
}

return int(inputShape[i].Size), nil
Expand Down Expand Up @@ -159,7 +157,7 @@ func (m *Model) Run(inputs Tensors) (Tensors, error) {
return nil, err
}

var tensors = make(Tensors)
tensors := make(Tensors)
for inputName, inputTensor := range inputs {
tensors[inputName] = inputTensor
}
Expand All @@ -179,7 +177,7 @@ func (m *Model) Run(inputs Tensors) (Tensors, error) {
}
}

var outputTensors = make(Tensors)
outputTensors := make(Tensors)
for _, outputName := range m.OutputNames() {
outputTensors[outputName] = tensors[outputName]
}
Expand All @@ -189,7 +187,7 @@ func (m *Model) Run(inputs Tensors) (Tensors, error) {

// applyOp applies the operation to the graph.
func (m *Model) applyOp(op ops.Operator, n *onnx.NodeProto, tensors Tensors) error {
if err := op.Init(n.GetAttribute()); err != nil {
if err := op.Init(n); err != nil {
return err
}

Expand Down Expand Up @@ -222,13 +220,13 @@ func (m *Model) validateShapes(inputTensors Tensors) error {

tensor, ok := inputTensors[name]
if !ok {
return fmt.Errorf("tensor: %v not found", name)
return ErrModel("tensor: %v not found", name)
}

shapeReceived := tensor.Shape()

if len(shapeReceived) != len(shapeExpected) {
return fmt.Errorf(InvalidShapeError, name, shapeExpected, shapeReceived)
return ErrInvalidShape(shapeExpected, shapeReceived)
}

for i, dim := range shapeExpected {
Expand All @@ -239,7 +237,7 @@ func (m *Model) validateShapes(inputTensors Tensors) error {
}

if dim.Size != int64(shapeReceived[i]) {
return fmt.Errorf(InvalidShapeError, name, shapeExpected, shapeReceived)
return ErrInvalidShape(shapeExpected, shapeReceived)
}
}
}
Expand All @@ -249,6 +247,7 @@ func (m *Model) validateShapes(inputTensors Tensors) error {

func getInputTensorsForNode(names []string, tensors Tensors) ([]tensor.Tensor, error) {
var inputTensors []tensor.Tensor

for _, tensorName := range names {
// An empty name can happen in between optional inputs, like:
// [<required_input>, <optional_input>, nil, <optional_input>]
Expand All @@ -259,7 +258,7 @@ func getInputTensorsForNode(names []string, tensors Tensors) ([]tensor.Tensor, e
} else if tensor, ok := tensors[tensorName]; ok {
inputTensors = append(inputTensors, tensor)
} else {
return nil, fmt.Errorf("no tensor yet for name %v", tensorName)
return nil, ErrModel("no tensor yet for name %v", tensorName)
}
}

Expand All @@ -269,9 +268,8 @@ func getInputTensorsForNode(names []string, tensors Tensors) ([]tensor.Tensor, e
func setOutputTensorsOfNode(
names []string, outputTensors []tensor.Tensor, tensors Tensors,
) error {

if len(names) != len(outputTensors) {
return fmt.Errorf(SetOutputTensorsError, len(names), len(outputTensors))
return ErrModel("could not set output tensor")
}

for i, tensor := range outputTensors {
Expand Down
Loading

0 comments on commit 3fbddf4

Please sign in to comment.