Skip to content

Commit

Permalink
Scalable operator set implementations (#219)
Browse files Browse the repository at this point in the history
* WIP on dynamic operator sets

* WIP on migrating operators

* WIP on operator migration

* Refactored versions

* Finished all operator refactors

* Fix lint errors

* Add constants file

* Remove print statement

* Fix lint last ones

* Proposal for base operators

* POC: new design for multiple operator versions

* Rewrote abs operator so it shares code

* Rewrote cos into base operator

* Refactor acosh into base operator

* Refactor xor

* Refactored unsqueeze operator

* Refactored Transpose operator

* Refactored tanh operator

* Refactored Tan operator

* Refactored input validation tests

* Refactored Sub operator

* Refactored squeeze operator

* Refactored Softmax operator

* Refactored slice, sinh and sin operators

* Refactored gemm, relu, reshape, rnn, scaler, shape and sigmoid operator

* Refactored reducemin operator

* Refactored ReduceMax operator

* Refactored PRelu operator

* Refactor Mul, Not and Or operator

* Refactored MatMul operator

* Refactored LessOrEqual operator

* Refacotred LSTM and LogSoftmax operators

* Refactored LinearRegressor operator

* Refactored Less operator

* Refactor GRU, GreaterOrEqual, Greater, Gather and Concat operators

* Refactored Expand and Equal operators

* Refactored Div operator

* Refactored Cosh operator

* Refactored cos operator

* Wip on constant of shape

* Refactored remaining operators

* Changed the way we use NewOperatorConstructor

* Small fixes

* Reinitialize operators on multiple uses

* Fix lint

* Fix lstm tests

---------

Co-authored-by: wisse <[email protected]>
  • Loading branch information
Swopper050 and wisse authored Dec 18, 2024
1 parent 7ed9d6f commit 7c2b171
Show file tree
Hide file tree
Showing 229 changed files with 7,253 additions and 4,921 deletions.
1 change: 0 additions & 1 deletion .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ linters:
- godot
- godox
- goerr113
- gomnd
- goprintffuncname
- govet
- ineffassign
Expand Down
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ test_ci: ## Run tests using normal test runner for ci output.

test_data: ## Creates test data from the ONNX test module.
rm -R ./test_data; mkdir ./test_data; touch ./test_data/
git clone --depth 1 --branch v1.15.0 https://github.com/onnx/onnx.git temp_onnx
git clone --depth 1 --branch v1.17.0 https://github.com/onnx/onnx.git temp_onnx
cp -r temp_onnx/onnx/backend/test/data/node/* ./test_data
rm -Rf temp_onnx

Expand All @@ -58,7 +58,7 @@ install: ## Install project with its depedencies.

install_lint: ## Install the linter.
curl -sfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh \
| sh -s -- -b $(shell go env GOPATH)/bin v1.50.1
| sh -s -- -b $(shell go env GOPATH)/bin v1.61.0

install_gotestsum: ## Install a tool for prettier test output.
curl -sfL https://github.com/gotestyourself/gotestsum/releases/download/v1.9.0/gotestsum_1.9.0_linux_amd64.tar.gz \
Expand Down
22 changes: 11 additions & 11 deletions model.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ type Tensors map[string]tensor.Tensor

// Model defines a model that can be used for inference.
type Model struct {
mp *onnx.ModelProto
parameters Tensors
GetOperator OpGetter
mp *onnx.ModelProto
parameters Tensors
Opset Opset
}

// NewModelFromFile creates a new model from a path to a file.
Expand Down Expand Up @@ -74,15 +74,15 @@ func NewModel(mp *onnx.ModelProto) (*Model, error) {
}
}

GetOperator, err := ResolveOperatorGetter(opsetID)
opset, err := ResolveOpset(opsetID)
if err != nil {
return nil, err
}

return &Model{
mp: mp,
parameters: params,
GetOperator: GetOperator,
mp: mp,
parameters: params,
Opset: opset,
}, nil
}

Expand Down Expand Up @@ -167,12 +167,12 @@ func (m *Model) Run(inputs Tensors) (Tensors, error) {
}

for _, n := range m.mp.Graph.GetNode() {
op, err := m.GetOperator(n.GetOpType())
if err != nil {
return nil, err
op, ok := m.Opset[n.GetOpType()]
if !ok {
return nil, ops.ErrUnknownOperatorType(n.GetOpType())
}

if err := m.applyOp(op, n, tensors); err != nil {
if err := m.applyOp(op(), n, tensors); err != nil {
return nil, err
}
}
Expand Down
2 changes: 2 additions & 0 deletions model_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gonnx

import (
"fmt"
"testing"

"github.com/advancedclimatesystems/gonnx/onnx"
Expand Down Expand Up @@ -96,6 +97,7 @@ func TestModel(t *testing.T) {
}

for _, test := range tests {
fmt.Println(test.path)
model, err := NewModelFromFile(test.path)
assert.Nil(t, err)

Expand Down
44 changes: 44 additions & 0 deletions ops/abs/abs.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package abs

import (
"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

var absTypeConstraint = [][]tensor.Dtype{
{tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64, tensor.Int8, tensor.Int16, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64},
}

// Abs represents the ONNX abs operator.
type Abs struct {
ops.BaseOperator
}

// newAbs creates a new abs operator.
func newAbs(version int, typeConstraint [][]tensor.Dtype) ops.Operator {
return &Abs{
BaseOperator: ops.NewBaseOperator(
version,
1,
1,
typeConstraint,
"abs",
),
}
}

// Init initializes the abs operator.
func (a *Abs) Init(*onnx.NodeProto) error {
return nil
}

// Apply applies the abs operator.
func (a *Abs) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
out, err := tensor.Abs(inputs[0])
if err != nil {
return nil, err
}

return []tensor.Tensor{out}, nil
}
107 changes: 101 additions & 6 deletions ops/opset13/abs_test.go → ops/abs/abs_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package opset13
package abs

import (
"testing"
Expand Down Expand Up @@ -59,83 +59,178 @@ func TestAbs(t *testing.T) {

func TestInputValidationAbs(t *testing.T) {
tests := []struct {
inputs []tensor.Tensor
err error
inputs []tensor.Tensor
err error
version int64
}{
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]uint8{1, 2}, 2),
},
nil,
6,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]uint16{1, 2}, 2),
},
nil,
6,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]uint32{1, 2}, 2),
},
nil,
6,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]uint64{1, 2}, 2),
},
nil,
6,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int8{1, 2}, 2),
},
nil,
6,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int16{1, 2}, 2),
},
nil,
6,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int32{1, 2}, 2),
},
nil,
6,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int64{1, 2}, 2),
},
nil,
6,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float32{1, 2}, 2),
},
nil,
6,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float64{1, 2}, 2),
},
nil,
6,
},
{
[]tensor.Tensor{},
ops.ErrInvalidInputCount(0, &Abs{}),
ops.ErrInvalidInputCount(0, ops.NewBaseOperator(6, 1, 1, absTypeConstraint, "abs")),
6,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int{1, 2}, 2),
},
ops.ErrInvalidInputType(0, "int", &Abs{}),
ops.ErrInvalidInputType(0, "int", ops.NewBaseOperator(6, 1, 1, absTypeConstraint, "abs")),
6,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]uint8{1, 2}, 2),
},
nil,
13,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]uint16{1, 2}, 2),
},
nil,
13,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]uint32{1, 2}, 2),
},
nil,
13,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]uint64{1, 2}, 2),
},
nil,
13,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int8{1, 2}, 2),
},
nil,
13,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int16{1, 2}, 2),
},
nil,
13,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int32{1, 2}, 2),
},
nil,
13,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int64{1, 2}, 2),
},
nil,
13,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float32{1, 2}, 2),
},
nil,
13,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float64{1, 2}, 2),
},
nil,
13,
},
{
[]tensor.Tensor{},
ops.ErrInvalidInputCount(0, ops.NewBaseOperator(13, 1, 1, absTypeConstraint, "abs")),
13,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int{1, 2}, 2),
},
ops.ErrInvalidInputType(0, "int", ops.NewBaseOperator(13, 1, 1, absTypeConstraint, "abs")),
13,
},
}

for _, test := range tests {
abs := &Abs{}
abs := absVersions[test.version]()
validated, err := abs.ValidateInputs(test.inputs)

assert.Equal(t, test.err, err)
Expand Down
14 changes: 14 additions & 0 deletions ops/abs/versions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package abs

import (
"github.com/advancedclimatesystems/gonnx/ops"
)

var absVersions = ops.OperatorVersions{
6: ops.NewOperatorConstructor(newAbs, 6, absTypeConstraint),
13: ops.NewOperatorConstructor(newAbs, 13, absTypeConstraint),
}

func GetAbsVersions() ops.OperatorVersions {
return absVersions
}
Loading

0 comments on commit 7c2b171

Please sign in to comment.