diff --git a/.golangci.yml b/.golangci.yml index e71a88c..b8d0e05 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -20,7 +20,6 @@ linters: - godot - godox - goerr113 - - gomnd - goprintffuncname - govet - ineffassign diff --git a/Makefile b/Makefile index ec0e94b..2a6d69a 100644 --- a/Makefile +++ b/Makefile @@ -44,7 +44,7 @@ test_ci: ## Run tests using normal test runner for ci output. test_data: ## Creates test data from the ONNX test module. rm -R ./test_data; mkdir ./test_data; touch ./test_data/ - git clone --depth 1 --branch v1.15.0 https://github.com/onnx/onnx.git temp_onnx + git clone --depth 1 --branch v1.17.0 https://github.com/onnx/onnx.git temp_onnx cp -r temp_onnx/onnx/backend/test/data/node/* ./test_data rm -Rf temp_onnx @@ -58,7 +58,7 @@ install: ## Install project with its depedencies. install_lint: ## Install the linter. curl -sfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh \ - | sh -s -- -b $(shell go env GOPATH)/bin v1.50.1 + | sh -s -- -b $(shell go env GOPATH)/bin v1.61.0 install_gotestsum: ## Install a tool for prettier test output. curl -sfL https://github.com/gotestyourself/gotestsum/releases/download/v1.9.0/gotestsum_1.9.0_linux_amd64.tar.gz \ diff --git a/model.go b/model.go index 9c9fcd1..3be4bcb 100644 --- a/model.go +++ b/model.go @@ -16,9 +16,9 @@ type Tensors map[string]tensor.Tensor // Model defines a model that can be used for inference. type Model struct { - mp *onnx.ModelProto - parameters Tensors - GetOperator OpGetter + mp *onnx.ModelProto + parameters Tensors + Opset Opset } // NewModelFromFile creates a new model from a path to a file. 
@@ -74,15 +74,15 @@ func NewModel(mp *onnx.ModelProto) (*Model, error) { } } - GetOperator, err := ResolveOperatorGetter(opsetID) + opset, err := ResolveOpset(opsetID) if err != nil { return nil, err } return &Model{ - mp: mp, - parameters: params, - GetOperator: GetOperator, + mp: mp, + parameters: params, + Opset: opset, }, nil } @@ -167,12 +167,12 @@ func (m *Model) Run(inputs Tensors) (Tensors, error) { } for _, n := range m.mp.Graph.GetNode() { - op, err := m.GetOperator(n.GetOpType()) - if err != nil { - return nil, err + op, ok := m.Opset[n.GetOpType()] + if !ok { + return nil, ops.ErrUnknownOperatorType(n.GetOpType()) } - if err := m.applyOp(op, n, tensors); err != nil { + if err := m.applyOp(op(), n, tensors); err != nil { return nil, err } } diff --git a/model_test.go b/model_test.go index ed1c4b9..54b9ed3 100644 --- a/model_test.go +++ b/model_test.go @@ -1,6 +1,7 @@ package gonnx import ( + "fmt" "testing" "github.com/advancedclimatesystems/gonnx/onnx" @@ -96,6 +97,7 @@ func TestModel(t *testing.T) { } for _, test := range tests { + fmt.Println(test.path) model, err := NewModelFromFile(test.path) assert.Nil(t, err) diff --git a/ops/abs/abs.go b/ops/abs/abs.go new file mode 100644 index 0000000..9987463 --- /dev/null +++ b/ops/abs/abs.go @@ -0,0 +1,44 @@ +package abs + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var absTypeConstraint = [][]tensor.Dtype{ + {tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64, tensor.Int8, tensor.Int16, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, +} + +// Abs represents the ONNX abs operator. +type Abs struct { + ops.BaseOperator +} + +// newAbs creates a new abs operator. 
+func newAbs(version int, typeConstraint [][]tensor.Dtype) ops.Operator { + return &Abs{ + BaseOperator: ops.NewBaseOperator( + version, + 1, + 1, + typeConstraint, + "abs", + ), + } +} + +// Init initializes the abs operator. +func (a *Abs) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the abs operator. +func (a *Abs) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + out, err := tensor.Abs(inputs[0]) + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} diff --git a/ops/opset13/abs_test.go b/ops/abs/abs_test.go similarity index 57% rename from ops/opset13/abs_test.go rename to ops/abs/abs_test.go index e9e0791..82d2c1d 100644 --- a/ops/opset13/abs_test.go +++ b/ops/abs/abs_test.go @@ -1,4 +1,4 @@ -package opset13 +package abs import ( "testing" @@ -59,83 +59,178 @@ func TestAbs(t *testing.T) { func TestInputValidationAbs(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + inputs []tensor.Tensor + err error + version int64 }{ { []tensor.Tensor{ ops.TensorWithBackingFixture([]uint8{1, 2}, 2), }, nil, + 6, }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]uint16{1, 2}, 2), }, nil, + 6, }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]uint32{1, 2}, 2), }, nil, + 6, }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]uint64{1, 2}, 2), }, nil, + 6, }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int8{1, 2}, 2), }, nil, + 6, }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int16{1, 2}, 2), }, nil, + 6, }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int32{1, 2}, 2), }, nil, + 6, }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int64{1, 2}, 2), }, nil, + 6, }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, nil, + 6, }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), }, nil, + 6, }, { []tensor.Tensor{}, - ops.ErrInvalidInputCount(0, &Abs{}), + ops.ErrInvalidInputCount(0, ops.NewBaseOperator(6, 1, 1, absTypeConstraint, 
"abs")), + 6, }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputType(0, "int", &Abs{}), + ops.ErrInvalidInputType(0, "int", ops.NewBaseOperator(6, 1, 1, absTypeConstraint, "abs")), + 6, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint8{1, 2}, 2), + }, + nil, + 13, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint16{1, 2}, 2), + }, + nil, + 13, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint32{1, 2}, 2), + }, + nil, + 13, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint64{1, 2}, 2), + }, + nil, + 13, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int8{1, 2}, 2), + }, + nil, + 13, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int16{1, 2}, 2), + }, + nil, + 13, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + }, + nil, + 13, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int64{1, 2}, 2), + }, + nil, + 13, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + 13, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float64{1, 2}, 2), + }, + nil, + 13, + }, + { + []tensor.Tensor{}, + ops.ErrInvalidInputCount(0, ops.NewBaseOperator(13, 1, 1, absTypeConstraint, "abs")), + 13, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int", ops.NewBaseOperator(13, 1, 1, absTypeConstraint, "abs")), + 13, }, } for _, test := range tests { - abs := &Abs{} + abs := absVersions[test.version]() validated, err := abs.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) diff --git a/ops/abs/versions.go b/ops/abs/versions.go new file mode 100644 index 0000000..545d455 --- /dev/null +++ b/ops/abs/versions.go @@ -0,0 +1,14 @@ +package abs + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var absVersions = ops.OperatorVersions{ + 6: ops.NewOperatorConstructor(newAbs, 6, 
absTypeConstraint), + 13: ops.NewOperatorConstructor(newAbs, 13, absTypeConstraint), +} + +func GetAbsVersions() ops.OperatorVersions { + return absVersions +} diff --git a/ops/acos/acos.go b/ops/acos/acos.go new file mode 100644 index 0000000..a3de114 --- /dev/null +++ b/ops/acos/acos.go @@ -0,0 +1,59 @@ +package acos + +import ( + "math" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +// Acos represents the ONNX acos operator. +type Acos struct { + ops.BaseOperator +} + +// newAcos creates a new acos operator. +func newAcos() ops.Operator { + return &Acos{ + BaseOperator: ops.NewBaseOperator( + 7, + 1, + 1, + [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}, + "acos", + ), + } +} + +// Init initializes the acos operator. +func (c *Acos) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the acos operator. +func (c *Acos) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var ( + out tensor.Tensor + err error + ) + + switch inputs[0].Dtype() { + case tensor.Float32: + out, err = inputs[0].Apply(acos[float32]) + case tensor.Float64: + out, err = inputs[0].Apply(acos[float64]) + default: + return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), c.BaseOperator) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +func acos[T ops.FloatType](x T) T { + return T(math.Acos(float64(x))) +} diff --git a/ops/opset13/acos_test.go b/ops/acos/acos_test.go similarity index 82% rename from ops/opset13/acos_test.go rename to ops/acos/acos_test.go index e2c755a..7eb3ed9 100644 --- a/ops/opset13/acos_test.go +++ b/ops/acos/acos_test.go @@ -1,4 +1,4 @@ -package opset13 +package acos import ( "testing" @@ -19,25 +19,25 @@ func TestAcosInit(t *testing.T) { func TestAcos(t *testing.T) { tests := []struct { - acos *Acos + acos ops.Operator backing []float32 shape []int expected []float32 }{ { - &Acos{}, + newAcos(), 
[]float32{-1, -1, 0, 1}, []int{2, 2}, []float32{3.1415927, 3.1415927, 1.5707964, 0}, }, { - &Acos{}, + newAcos(), []float32{1, 0.5, 0.0, -0.5}, []int{1, 4}, []float32{0, 1.0471976, 1.5707964, 2.0943952}, }, { - &Acos{}, + newAcos(), []float32{-1, -1, -1, -1}, []int{1, 4}, []float32{3.1415927, 3.1415927, 3.1415927, 3.1415927}, @@ -76,18 +76,18 @@ func TestInputValidationAcos(t *testing.T) { }, { []tensor.Tensor{}, - ops.ErrInvalidInputCount(0, &Acos{}), + ops.ErrInvalidInputCount(0, ops.NewBaseOperator(7, 1, 1, [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}, "acos")), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputType(0, "int", &Acos{}), + ops.ErrInvalidInputType(0, "int", ops.NewBaseOperator(7, 1, 1, [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}, "acos")), }, } for _, test := range tests { - acos := &Acos{} + acos := newAcos() validated, err := acos.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) diff --git a/ops/acos/versions.go b/ops/acos/versions.go new file mode 100644 index 0000000..31d6092 --- /dev/null +++ b/ops/acos/versions.go @@ -0,0 +1,13 @@ +package acos + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var acosVersions = ops.OperatorVersions{ + 7: newAcos, +} + +func GetAcosVersions() ops.OperatorVersions { + return acosVersions +} diff --git a/ops/acosh/acosh.go b/ops/acosh/acosh.go new file mode 100644 index 0000000..8e91430 --- /dev/null +++ b/ops/acosh/acosh.go @@ -0,0 +1,53 @@ +package acosh + +import ( + "math" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +// Acosh represents the ONNX acosh operator. +type Acosh struct { + ops.BaseOperator +} + +// newAcosh creates a new acosh operator. 
+func newAcosh() ops.Operator { + return &Acosh{ + BaseOperator: ops.NewBaseOperator(9, 1, 1, [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}, "acosh"), + } +} + +// Init initializes the acosh operator. +func (c *Acosh) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the acosh operator. +func (c *Acosh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var ( + out tensor.Tensor + err error + ) + + switch inputs[0].Dtype() { + case tensor.Float32: + out, err = inputs[0].Apply(acosh[float32]) + case tensor.Float64: + out, err = inputs[0].Apply(acosh[float64]) + default: + return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), c.BaseOperator) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +func acosh[T ops.FloatType](x T) T { + return T(math.Acosh(float64(x))) +} diff --git a/ops/opset13/acosh_test.go b/ops/acosh/acosh_test.go similarity index 78% rename from ops/opset13/acosh_test.go rename to ops/acosh/acosh_test.go index d6c155d..ae01374 100644 --- a/ops/opset13/acosh_test.go +++ b/ops/acosh/acosh_test.go @@ -1,4 +1,4 @@ -package opset13 +package acosh import ( "testing" @@ -8,7 +8,7 @@ import ( "gorgonia.org/tensor" ) -func TestAcoshInit(t *testing.T) { +func TestAcosh9Init(t *testing.T) { c := &Acosh{} // since 'acosh' does not have any attributes we pass in nil. 
This should not @@ -17,27 +17,27 @@ func TestAcoshInit(t *testing.T) { assert.Nil(t, err) } -func TestAcosh(t *testing.T) { +func TestAcosh9(t *testing.T) { tests := []struct { - acosh *Acosh + acosh ops.Operator backing []float32 shape []int expected []float32 }{ { - &Acosh{}, + newAcosh(), []float32{1, 2, 3, 4}, []int{2, 2}, []float32{0, 1.316958, 1.7627472, 2.063437}, }, { - &Acosh{}, + newAcosh(), []float32{1, 2, 3, 4}, []int{1, 4}, []float32{0, 1.316958, 1.7627472, 2.063437}, }, { - &Acosh{}, + newAcosh(), []float32{2, 2, 2, 2}, []int{1, 4}, []float32{1.316958, 1.316958, 1.316958, 1.316958}, @@ -76,18 +76,18 @@ func TestInputValidationAcosh(t *testing.T) { }, { []tensor.Tensor{}, - ops.ErrInvalidInputCount(0, &Acosh{}), + ops.ErrInvalidInputCount(0, ops.NewBaseOperator(9, 1, 1, [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}, "acosh")), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputType(0, "int", &Acosh{}), + ops.ErrInvalidInputType(0, "int", ops.NewBaseOperator(9, 1, 1, [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}, "acosh")), }, } for _, test := range tests { - acosh := &Acosh{} + acosh := newAcosh() validated, err := acosh.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) diff --git a/ops/acosh/versions.go b/ops/acosh/versions.go new file mode 100644 index 0000000..0953e5d --- /dev/null +++ b/ops/acosh/versions.go @@ -0,0 +1,13 @@ +package acosh + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var acoshVersions = ops.OperatorVersions{ + 9: newAcosh, +} + +func GetAcoshVersions() ops.OperatorVersions { + return acoshVersions +} diff --git a/ops/add/add.go b/ops/add/add.go new file mode 100644 index 0000000..7c03d59 --- /dev/null +++ b/ops/add/add.go @@ -0,0 +1,45 @@ +package add + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var addTypeConstraints = 
[][]tensor.Dtype{ + {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, + {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, +} + +// Add represents the ONNX add operator. +type Add struct { + ops.BaseOperator +} + +// newAdd creates a new add operator. +func newAdd(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Add{ + BaseOperator: ops.NewBaseOperator( + version, + 2, + 2, + typeConstraints, + "add", + ), + } +} + +// Init initializes the add operator. +func (a *Add) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the add operator. +func (a *Add) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ApplyBinaryOperation( + inputs[0], + inputs[1], + ops.Add, + ops.MultidirectionalBroadcasting, + ) +} diff --git a/ops/opset13/add_test.go b/ops/add/add_test.go similarity index 84% rename from ops/opset13/add_test.go rename to ops/add/add_test.go index f7dacd1..faf07cf 100644 --- a/ops/opset13/add_test.go +++ b/ops/add/add_test.go @@ -1,4 +1,4 @@ -package opset13 +package add import ( "testing" @@ -19,25 +19,25 @@ func TestAddInit(t *testing.T) { func TestAdd(t *testing.T) { tests := []struct { - add *Add + version int64 backings [][]float32 shapes [][]int expected []float32 }{ { - &Add{}, + 13, [][]float32{{0, 1, 2, 3}, {1, 1, 1, 1}}, [][]int{{2, 2}, {2, 2}}, []float32{1, 2, 3, 4}, }, { - &Add{}, + 13, [][]float32{{0, 1, 2, 3, 4, 5}, {2, 2, 2, 2, 2, 2}}, [][]int{{3, 2}, {3, 2}}, []float32{2, 3, 4, 5, 6, 7}, }, { - &Add{}, + 13, [][]float32{{0, 1}, {0, 1, 2, 3}}, [][]int{{2}, {2, 2}}, []float32{0, 2, 2, 4}, @@ -50,7 +50,9 @@ func TestAdd(t *testing.T) { ops.TensorWithBackingFixture(test.backings[1], test.shapes[1]...), } - res, err := test.add.Apply(inputs) + add := addVersions[test.version]() + + res, err := add.Apply(inputs) assert.Nil(t, err) assert.Equal(t, test.expected, res[0].Data()) @@ -70,10 +72,12 @@ func TestAddFail(t 
*testing.T) { func TestInputValidationAdd(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]uint32{1, 2}, 2), ops.TensorWithBackingFixture([]uint32{3, 4}, 2), @@ -81,6 +85,7 @@ func TestInputValidationAdd(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]uint64{1, 2}, 2), ops.TensorWithBackingFixture([]uint64{3, 4}, 2), @@ -88,6 +93,7 @@ func TestInputValidationAdd(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int32{1, 2}, 2), ops.TensorWithBackingFixture([]int32{3, 4}, 2), @@ -95,6 +101,7 @@ func TestInputValidationAdd(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int64{1, 2}, 2), ops.TensorWithBackingFixture([]int64{3, 4}, 2), @@ -102,6 +109,7 @@ func TestInputValidationAdd(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{3, 4}, 2), @@ -109,6 +117,7 @@ func TestInputValidationAdd(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), ops.TensorWithBackingFixture([]float64{3, 4}, 2), @@ -116,22 +125,24 @@ func TestInputValidationAdd(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputCount(1, &Add{}), + ops.ErrInvalidInputCount(1, add13BaseOpFixture()), }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), }, - ops.ErrInvalidInputType(0, "int", &Add{}), + ops.ErrInvalidInputType(0, "int", add13BaseOpFixture()), }, } for _, test := range tests { - add := &Add{} + add := addVersions[test.version]() validated, err := add.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -141,3 +152,7 @@ func TestInputValidationAdd(t *testing.T) { } } } + +func add13BaseOpFixture() 
ops.BaseOperator { + return ops.NewBaseOperator(13, 2, 2, addTypeConstraints, "add") +} diff --git a/ops/add/versions.go b/ops/add/versions.go new file mode 100644 index 0000000..e0ff80f --- /dev/null +++ b/ops/add/versions.go @@ -0,0 +1,14 @@ +package add + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var addVersions = ops.OperatorVersions{ + 7: ops.NewOperatorConstructor(newAdd, 7, addTypeConstraints), + 13: ops.NewOperatorConstructor(newAdd, 13, addTypeConstraints), +} + +func GetAddVersions() ops.OperatorVersions { + return addVersions +} diff --git a/ops/and/and.go b/ops/and/and.go new file mode 100644 index 0000000..21a2cdf --- /dev/null +++ b/ops/and/and.go @@ -0,0 +1,42 @@ +package and + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var andTypeConstraints = [][]tensor.Dtype{{tensor.Bool}, {tensor.Bool}} + +// And represents the ONNX and operator. +type And struct { + ops.BaseOperator +} + +// newAnd creates a new and operator. +func newAnd(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &And{ + BaseOperator: ops.NewBaseOperator( + version, + 2, + 2, + typeConstraints, + "and", + ), + } +} + +// Init initializes the and operator. +func (a *And) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the and operator. 
+func (a *And) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ApplyBinaryOperation( + inputs[0], + inputs[1], + ops.And, + ops.MultidirectionalBroadcasting, + ) +} diff --git a/ops/opset13/and_test.go b/ops/and/and_test.go similarity index 81% rename from ops/opset13/and_test.go rename to ops/and/and_test.go index b17fc35..201c9ca 100644 --- a/ops/opset13/and_test.go +++ b/ops/and/and_test.go @@ -1,4 +1,4 @@ -package opset13 +package and import ( "testing" @@ -19,31 +19,31 @@ func TestAndInit(t *testing.T) { func TestAnd(t *testing.T) { tests := []struct { - and *And + version int64 backings [][]bool shapes [][]int expected []bool }{ { - &And{}, + 7, [][]bool{{true, false, true, false}, {true, true, true, false}}, [][]int{{2, 2}, {2, 2}}, []bool{true, false, true, false}, }, { - &And{}, + 7, [][]bool{{true, false, true, false}, {true, false}}, [][]int{{2, 2}, {1, 2}}, []bool{true, false, true, false}, }, { - &And{}, + 7, [][]bool{{true, false, true, false}, {true, false}}, [][]int{{2, 2}, {2, 1}}, []bool{true, false, false, false}, }, { - &And{}, + 7, [][]bool{{true, false, true, false, true, false}, {false, false}}, [][]int{{3, 2}, {1, 2}}, []bool{false, false, false, false, false, false}, @@ -56,7 +56,9 @@ func TestAnd(t *testing.T) { ops.TensorWithBackingFixture(test.backings[1], test.shapes[1]...), } - res, err := test.and.Apply(inputs) + and := andVersions[test.version]() + + res, err := and.Apply(inputs) assert.Nil(t, err) assert.Nil(t, err) @@ -66,10 +68,12 @@ func TestAnd(t *testing.T) { func TestInputValidationAnd(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]bool{false, false}, 2), ops.TensorWithBackingFixture([]bool{false, false}, 2), @@ -77,22 +81,24 @@ func TestInputValidationAnd(t *testing.T) { nil, }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]bool{false, false}, 2), }, - 
ops.ErrInvalidInputCount(1, &And{}), + ops.ErrInvalidInputCount(1, and7BaseOpFixture()), }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]bool{false, false}, 2), ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputType(1, "int", &And{}), + ops.ErrInvalidInputType(1, "int", and7BaseOpFixture()), }, } for _, test := range tests { - and := &And{} + and := andVersions[test.version]() validated, err := and.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -102,3 +108,7 @@ func TestInputValidationAnd(t *testing.T) { } } } + +func and7BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(7, 2, 2, andTypeConstraints, "and") +} diff --git a/ops/and/versions.go b/ops/and/versions.go new file mode 100644 index 0000000..111cba9 --- /dev/null +++ b/ops/and/versions.go @@ -0,0 +1,13 @@ +package and + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var andVersions = ops.OperatorVersions{ + 7: ops.NewOperatorConstructor(newAnd, 7, andTypeConstraints), +} + +func GetAndVersions() ops.OperatorVersions { + return andVersions +} diff --git a/ops/opset13/argmax.go b/ops/argmax/argmax.go similarity index 62% rename from ops/opset13/argmax.go rename to ops/argmax/argmax.go index 5150ea7..fbb17d0 100644 --- a/ops/opset13/argmax.go +++ b/ops/argmax/argmax.go @@ -1,4 +1,4 @@ -package opset13 +package argmax import ( "github.com/advancedclimatesystems/gonnx/onnx" @@ -6,44 +6,44 @@ import ( "gorgonia.org/tensor" ) -const ( - MinArgMaxInputs = 1 - MaxArgMaxInputs = 1 -) +var argMaxTypeConstraints = [][]tensor.Dtype{ + {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, +} // ArgMax represents the ONNX argmax operator. type ArgMax struct { + ops.BaseOperator + axis int keepDims bool selectLastIndex bool } // newArgMax creates a new argmax operator. 
-func newArgMax() ops.Operator { +func newArgMax(version int, typeConstraints [][]tensor.Dtype) ops.Operator { return &ArgMax{ + BaseOperator: ops.NewBaseOperator( + version, + 1, + 1, + typeConstraints, + "argmax", + ), keepDims: true, selectLastIndex: false, } } -type ArgMaxAttribute string - -const ( - axis = "axis" - keepDims = "keepdims" - selectLastIndex = "select_last_index" -) - // Init initializes the argmax operator. func (a *ArgMax) Init(n *onnx.NodeProto) error { attributes := n.GetAttribute() for _, attr := range attributes { switch attr.GetName() { - case axis: + case "axis": a.axis = int(attr.GetI()) - case keepDims: + case "keepdims": a.keepDims = ops.Int64ToBool(attr.GetI()) - case selectLastIndex: + case "select_last_index": a.selectLastIndex = ops.Int64ToBool(attr.GetI()) // We have no way yet to perform argmax and keeping the @@ -97,31 +97,3 @@ func (a *ArgMax) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return []tensor.Tensor{reduced}, nil } - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (a *ArgMax) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(a, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (a *ArgMax) GetMinInputs() int { - return MinArgMaxInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (a *ArgMax) GetMaxInputs() int { - return MaxArgMaxInputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (a *ArgMax) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ - {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, - } -} - -// String implements the stringer interface, and can be used to format errors or messages. 
-func (a *ArgMax) String() string { - return "argmax operator" -} diff --git a/ops/opset13/argmax_test.go b/ops/argmax/argmax_test.go similarity index 71% rename from ops/opset13/argmax_test.go rename to ops/argmax/argmax_test.go index ea6fef8..fc7a245 100644 --- a/ops/opset13/argmax_test.go +++ b/ops/argmax/argmax_test.go @@ -1,4 +1,4 @@ -package opset13 +package argmax import ( "testing" @@ -30,21 +30,34 @@ func TestArgMaxInit(t *testing.T) { func TestArgMax(t *testing.T) { tests := []struct { - argmax *ArgMax + version int64 + node *onnx.NodeProto backing []float32 shape []int expectedShape tensor.Shape expectedData []int64 }{ { - &ArgMax{axis: 0, keepDims: true}, + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axis", I: 0}, + {Name: "keepdims", I: 1}, + }, + }, []float32{0, 1, 2, 3}, []int{2, 2}, []int{1, 2}, []int64{1, 1}, }, { - &ArgMax{axis: -1, keepDims: true}, + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axis", I: -1}, + {Name: "keepdims", I: 1}, + }, + }, []float32{0, 1, 2, 3}, []int{2, 2}, []int{2, 1}, @@ -57,7 +70,11 @@ func TestArgMax(t *testing.T) { ops.TensorWithBackingFixture(test.backing, test.shape...), } - res, err := test.argmax.Apply(inputs) + argmax := argMaxVersions[test.version]() + err := argmax.Init(test.node) + assert.Nil(t, err) + + res, err := argmax.Apply(inputs) assert.Nil(t, err) assert.Equal(t, test.expectedShape, res[0].Shape()) @@ -67,62 +84,71 @@ func TestArgMax(t *testing.T) { func TestInputValidationArgMax(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]uint32{1, 2}, 2), }, nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]uint64{1, 2}, 2), }, nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int32{1, 2}, 2), }, nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int64{1, 2}, 2), }, nil, }, { + 13, 
[]tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), }, nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, - ops.ErrInvalidInputCount(2, &ArgMax{}), + ops.ErrInvalidInputCount(2, argMax13BaseOpFixture()), }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputType(0, "int", &ArgMax{}), + ops.ErrInvalidInputType(0, "int", argMax13BaseOpFixture()), }, } for _, test := range tests { - argmax := &ArgMax{} + argmax := argMaxVersions[test.version]() validated, err := argmax.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -132,3 +158,7 @@ func TestInputValidationArgMax(t *testing.T) { } } } + +func argMax13BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(13, 1, 1, argMaxTypeConstraints, "argmax") +} diff --git a/ops/argmax/versions.go b/ops/argmax/versions.go new file mode 100644 index 0000000..7f4ac6d --- /dev/null +++ b/ops/argmax/versions.go @@ -0,0 +1,15 @@ +package argmax + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var argMaxVersions = ops.OperatorVersions{ + 11: ops.NewOperatorConstructor(newArgMax, 11, argMaxTypeConstraints), + 12: ops.NewOperatorConstructor(newArgMax, 12, argMaxTypeConstraints), + 13: ops.NewOperatorConstructor(newArgMax, 13, argMaxTypeConstraints), +} + +func GetArgMaxVersions() ops.OperatorVersions { + return argMaxVersions +} diff --git a/ops/asin/asin.go b/ops/asin/asin.go new file mode 100644 index 0000000..babe8ce --- /dev/null +++ b/ops/asin/asin.go @@ -0,0 +1,61 @@ +package asin + +import ( + "math" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var asinTypeConstraints = [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} + +// Asin represents the ONNX asin 
operator. +type Asin struct { + ops.BaseOperator +} + +// newSin creates a new asin operator. +func newAsin(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Asin{ + BaseOperator: ops.NewBaseOperator( + version, + 1, + 1, + typeConstraints, + "asin", + ), + } +} + +// Init initializes the asin operator. +func (s *Asin) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the asin operator. +func (s *Asin) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var ( + out tensor.Tensor + err error + ) + + switch inputs[0].Dtype() { + case tensor.Float32: + out, err = inputs[0].Apply(asin[float32]) + case tensor.Float64: + out, err = inputs[0].Apply(asin[float64]) + default: + return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), s.BaseOperator) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +func asin[T ops.FloatType](x T) T { + return T(math.Asin(float64(x))) +} diff --git a/ops/opset13/asin_test.go b/ops/asin/asin_test.go similarity index 77% rename from ops/opset13/asin_test.go rename to ops/asin/asin_test.go index c145649..c9e7a37 100644 --- a/ops/opset13/asin_test.go +++ b/ops/asin/asin_test.go @@ -1,4 +1,4 @@ -package opset13 +package asin import ( "testing" @@ -19,25 +19,25 @@ func TestAsinInit(t *testing.T) { func TestAsin(t *testing.T) { tests := []struct { - asin *Asin + version int64 backing []float32 shape []int expected []float32 }{ { - &Asin{}, + 7, []float32{-1, -1, 0, 1}, []int{2, 2}, []float32{-1.5707964, -1.5707964, 0, 1.5707964}, }, { - &Asin{}, + 7, []float32{1, 0.5, 0.0, -0.5}, []int{1, 4}, []float32{1.5707964, 0.5235988, 0, -0.5235988}, }, { - &Asin{}, + 7, []float32{-1, -1, -1, -1}, []int{1, 4}, []float32{-1.5707964, -1.5707964, -1.5707964, -1.5707964}, @@ -49,7 +49,9 @@ func TestAsin(t *testing.T) { ops.TensorWithBackingFixture(test.backing, test.shape...), } - res, err := test.asin.Apply(inputs) + asin := asinVersions[test.version]() + + res, err := 
asin.Apply(inputs) assert.Nil(t, err) assert.Nil(t, err) @@ -59,35 +61,40 @@ func TestAsin(t *testing.T) { func TestInputValidationAsin(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, nil, }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), }, nil, }, { + 7, []tensor.Tensor{}, - ops.ErrInvalidInputCount(0, &Asin{}), + ops.ErrInvalidInputCount(0, asin7BaseOpFixture()), }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputType(0, "int", &Asin{}), + ops.ErrInvalidInputType(0, "int", asin7BaseOpFixture()), }, } for _, test := range tests { - asin := &Asin{} + asin := asinVersions[test.version]() validated, err := asin.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -97,3 +104,7 @@ func TestInputValidationAsin(t *testing.T) { } } } + +func asin7BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(7, 1, 1, asinTypeConstraints, "asin") +} diff --git a/ops/asin/versions.go b/ops/asin/versions.go new file mode 100644 index 0000000..e8f7626 --- /dev/null +++ b/ops/asin/versions.go @@ -0,0 +1,13 @@ +package asin + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var asinVersions = ops.OperatorVersions{ + 7: ops.NewOperatorConstructor(newAsin, 7, asinTypeConstraints), +} + +func GetAsinVersions() ops.OperatorVersions { + return asinVersions +} diff --git a/ops/asinh/asinh.go b/ops/asinh/asinh.go new file mode 100644 index 0000000..209eb8b --- /dev/null +++ b/ops/asinh/asinh.go @@ -0,0 +1,61 @@ +package asinh + +import ( + "math" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var asinhTypeConstraints = [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} + +// Asinh represents the ONNX asinh operator. 
+type Asinh struct { + ops.BaseOperator +} + +// newAsinh creates a new asinh operator. +func newAsinh(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Asinh{ + BaseOperator: ops.NewBaseOperator( + version, + 1, + 1, + typeConstraints, + "asinh", + ), + } +} + +// Init initializes the asinh operator. +func (a *Asinh) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the asinh operator. +func (a *Asinh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var ( + out tensor.Tensor + err error + ) + + switch inputs[0].Dtype() { + case tensor.Float32: + out, err = inputs[0].Apply(asinh[float32]) + case tensor.Float64: + out, err = inputs[0].Apply(asinh[float64]) + default: + return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), a.BaseOperator) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +func asinh[T ops.FloatType](x T) T { + return T(math.Asinh(float64(x))) +} diff --git a/ops/opset13/asinh_test.go b/ops/asinh/asinh_test.go similarity index 76% rename from ops/opset13/asinh_test.go rename to ops/asinh/asinh_test.go index da5c6fc..3dbf8db 100644 --- a/ops/opset13/asinh_test.go +++ b/ops/asinh/asinh_test.go @@ -1,4 +1,4 @@ -package opset13 +package asinh import ( "testing" @@ -19,25 +19,25 @@ func TestAsinhInit(t *testing.T) { func TestAsinh(t *testing.T) { tests := []struct { - asinh *Asinh + version int64 backing []float32 shape []int expected []float32 }{ { - &Asinh{}, + 9, []float32{1, 2, 3, 4}, []int{2, 2}, []float32{0.8813736, 1.4436355, 1.8184465, 2.0947125}, }, { - &Asinh{}, + 9, []float32{1, 2, 3, 4}, []int{1, 4}, []float32{0.8813736, 1.4436355, 1.8184465, 2.0947125}, }, { - &Asinh{}, + 9, []float32{2, 2, 2, 2}, []int{1, 4}, []float32{1.4436355, 1.4436355, 1.4436355, 1.4436355}, @@ -49,7 +49,9 @@ func TestAsinh(t *testing.T) { ops.TensorWithBackingFixture(test.backing, test.shape...), } - res, err := test.asinh.Apply(inputs) + asinh := 
asinhVersions[test.version]() + + res, err := asinh.Apply(inputs) assert.Nil(t, err) assert.Nil(t, err) @@ -59,35 +61,40 @@ func TestAsinh(t *testing.T) { func TestInputValidationAsinh(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 9, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, nil, }, { + 9, []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), }, nil, }, { + 9, []tensor.Tensor{}, - ops.ErrInvalidInputCount(0, &Asinh{}), + ops.ErrInvalidInputCount(0, asinh9BaseOpFixture()), }, { + 9, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputType(0, "int", &Asinh{}), + ops.ErrInvalidInputType(0, "int", asinh9BaseOpFixture()), }, } for _, test := range tests { - asinh := &Asinh{} + asinh := asinhVersions[test.version]() validated, err := asinh.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -97,3 +104,7 @@ func TestInputValidationAsinh(t *testing.T) { } } } + +func asinh9BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(9, 1, 1, asinhTypeConstraints, "asinh") +} diff --git a/ops/asinh/versions.go b/ops/asinh/versions.go new file mode 100644 index 0000000..5c9588e --- /dev/null +++ b/ops/asinh/versions.go @@ -0,0 +1,13 @@ +package asinh + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var asinhVersions = ops.OperatorVersions{ + 9: ops.NewOperatorConstructor(newAsinh, 9, asinhTypeConstraints), +} + +func GetAsinhVersions() ops.OperatorVersions { + return asinhVersions +} diff --git a/ops/atan/atan.go b/ops/atan/atan.go new file mode 100644 index 0000000..1cf6c0b --- /dev/null +++ b/ops/atan/atan.go @@ -0,0 +1,61 @@ +package atan + +import ( + "math" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var atanTypeConstraints = [][]tensor.Dtype{{tensor.Float32, 
tensor.Float64}} + +// Atan represents the ONNX atan operator. +type Atan struct { + ops.BaseOperator +} + +// newAtan creates a new atan operator. +func newAtan(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Atan{ + BaseOperator: ops.NewBaseOperator( + version, + 1, + 1, + typeConstraints, + "atan", + ), + } +} + +// Init initializes the atan operator. +func (a *Atan) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the atan operator. +func (a *Atan) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var ( + out tensor.Tensor + err error + ) + + switch inputs[0].Dtype() { + case tensor.Float32: + out, err = inputs[0].Apply(atan[float32]) + case tensor.Float64: + out, err = inputs[0].Apply(atan[float64]) + default: + return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), a.BaseOperator) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +func atan[T ops.FloatType](x T) T { + return T(math.Atan(float64(x))) +} diff --git a/ops/opset13/atan_test.go b/ops/atan/atan_test.go similarity index 77% rename from ops/opset13/atan_test.go rename to ops/atan/atan_test.go index f6d1d97..cf861d6 100644 --- a/ops/opset13/atan_test.go +++ b/ops/atan/atan_test.go @@ -1,4 +1,4 @@ -package opset13 +package atan import ( "testing" @@ -19,25 +19,25 @@ func TestAtanInit(t *testing.T) { func TestAtan(t *testing.T) { tests := []struct { - atan *Atan + version int64 backing []float32 shape []int expected []float32 }{ { - &Atan{}, + 7, []float32{1, 2, 3, 4}, []int{2, 2}, []float32{0.7853982, 1.1071488, 1.2490457, 1.3258177}, }, { - &Atan{}, + 7, []float32{1, 2, 3, 4}, []int{1, 4}, []float32{0.7853982, 1.1071488, 1.2490457, 1.3258177}, }, { - &Atan{}, + 7, []float32{2, 2, 2, 2}, []int{1, 4}, []float32{1.1071488, 1.1071488, 1.1071488, 1.1071488}, @@ -49,7 +49,9 @@ func TestAtan(t *testing.T) { ops.TensorWithBackingFixture(test.backing, test.shape...), } - res, err := test.atan.Apply(inputs) + 
atan := atanVersions[test.version]() + + res, err := atan.Apply(inputs) assert.Nil(t, err) assert.Nil(t, err) @@ -59,35 +61,40 @@ func TestAtan(t *testing.T) { func TestInputValidationAtan(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, nil, }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), }, nil, }, { + 7, []tensor.Tensor{}, - ops.ErrInvalidInputCount(0, &Atan{}), + ops.ErrInvalidInputCount(0, atan7BaseOpFixture()), }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputType(0, "int", &Atan{}), + ops.ErrInvalidInputType(0, "int", atan7BaseOpFixture()), }, } for _, test := range tests { - atan := &Atan{} + atan := atanVersions[test.version]() validated, err := atan.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -97,3 +104,7 @@ func TestInputValidationAtan(t *testing.T) { } } } + +func atan7BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(7, 1, 1, atanTypeConstraints, "atan") +} diff --git a/ops/atan/versions.go b/ops/atan/versions.go new file mode 100644 index 0000000..af63f8a --- /dev/null +++ b/ops/atan/versions.go @@ -0,0 +1,13 @@ +package atan + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var atanVersions = ops.OperatorVersions{ + 7: ops.NewOperatorConstructor(newAtan, 7, atanTypeConstraints), +} + +func GetAtanVersions() ops.OperatorVersions { + return atanVersions +} diff --git a/ops/atanh/atanh.go b/ops/atanh/atanh.go new file mode 100644 index 0000000..6170dfb --- /dev/null +++ b/ops/atanh/atanh.go @@ -0,0 +1,61 @@ +package atanh + +import ( + "math" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var atanhTypeConstraints = [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} + +// 
Atanh represents the ONNX atanh operator. +type Atanh struct { + ops.BaseOperator +} + +// newAtanh creates a new atanh operator. +func newAtanh(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Atanh{ + BaseOperator: ops.NewBaseOperator( + version, + 1, + 1, + typeConstraints, + "atanh", + ), + } +} + +// Init initializes the atanh operator. +func (a *Atanh) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the atanh operator. +func (a *Atanh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var ( + out tensor.Tensor + err error + ) + + switch inputs[0].Dtype() { + case tensor.Float32: + out, err = inputs[0].Apply(atanh[float32]) + case tensor.Float64: + out, err = inputs[0].Apply(atanh[float64]) + default: + return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), a.BaseOperator) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +func atanh[T ops.FloatType](x T) T { + return T(math.Atanh(float64(x))) +} diff --git a/ops/opset13/atanh_test.go b/ops/atanh/atanh_test.go similarity index 76% rename from ops/opset13/atanh_test.go rename to ops/atanh/atanh_test.go index 65441a7..73196d2 100644 --- a/ops/opset13/atanh_test.go +++ b/ops/atanh/atanh_test.go @@ -1,4 +1,4 @@ -package opset13 +package atanh import ( "testing" @@ -19,25 +19,25 @@ func TestAtanhInit(t *testing.T) { func TestAtanh(t *testing.T) { tests := []struct { - atanh *Atanh + version int64 backing []float32 shape []int expected []float32 }{ { - &Atanh{}, + 9, []float32{-0.9, -0.5, 0, 0.5}, []int{2, 2}, []float32{-1.4722193, -0.54930615, 0, 0.54930615}, }, { - &Atanh{}, + 9, []float32{-0.9, -0.5, 0, 0.5}, []int{1, 4}, []float32{-1.4722193, -0.54930615, 0, 0.54930615}, }, { - &Atanh{}, + 9, []float32{0.5, 0.5, 0.5, 0.5}, []int{1, 4}, []float32{0.54930615, 0.54930615, 0.54930615, 0.54930615}, @@ -49,7 +49,9 @@ func TestAtanh(t *testing.T) { ops.TensorWithBackingFixture(test.backing, test.shape...), } - res, err 
:= test.atanh.Apply(inputs) + atanh := atanhVersions[test.version]() + + res, err := atanh.Apply(inputs) assert.Nil(t, err) assert.Nil(t, err) @@ -59,35 +61,40 @@ func TestAtanh(t *testing.T) { func TestInputValidationAtanh(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 9, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, nil, }, { + 9, []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), }, nil, }, { + 9, []tensor.Tensor{}, - ops.ErrInvalidInputCount(0, &Atanh{}), + ops.ErrInvalidInputCount(0, atanh9BaseOpFixture()), }, { + 9, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputType(0, "int", &Atanh{}), + ops.ErrInvalidInputType(0, "int", atanh9BaseOpFixture()), }, } for _, test := range tests { - atanh := &Atanh{} + atanh := atanhVersions[test.version]() validated, err := atanh.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -97,3 +104,7 @@ func TestInputValidationAtanh(t *testing.T) { } } } + +func atanh9BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(9, 1, 1, atanhTypeConstraints, "atanh") +} diff --git a/ops/atanh/versions.go b/ops/atanh/versions.go new file mode 100644 index 0000000..1acf1b4 --- /dev/null +++ b/ops/atanh/versions.go @@ -0,0 +1,13 @@ +package atanh + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var atanhVersions = ops.OperatorVersions{ + 9: ops.NewOperatorConstructor(newAtanh, 9, atanhTypeConstraints), +} + +func GetAtanhVersions() ops.OperatorVersions { + return atanhVersions +} diff --git a/ops/base.go b/ops/base.go new file mode 100644 index 0000000..f1ce4e2 --- /dev/null +++ b/ops/base.go @@ -0,0 +1,55 @@ +package ops + +import ( + "fmt" + + "gorgonia.org/tensor" +) + +// Concrete implementation for shared operator methods. 
+type BaseOperator struct { + name string + version int + minInputs int + maxInputs int + inputTypeConstraints [][]tensor.Dtype +} + +func NewBaseOperator(version, minInputs, maxInputs int, inputTypeConstraints [][]tensor.Dtype, name string) BaseOperator { + return BaseOperator{ + name: name, + version: version, + minInputs: minInputs, + maxInputs: maxInputs, + inputTypeConstraints: inputTypeConstraints, + } +} + +// ValidateInputs validates the inputs for the operator. +func (f BaseOperator) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ValidateInputs(f, inputs) +} + +// Version returns the version of the operator. +func (f BaseOperator) Version() int { + return f.version +} + +// GetMinInputs returns the minimum number of input tensors. +func (f BaseOperator) GetMinInputs() int { + return f.minInputs +} + +// GetMaxInputs returns the maximum number of input tensors. +func (f BaseOperator) GetMaxInputs() int { + return f.maxInputs +} + +// GetInputTypeConstraints returns allowed input types. +func (f BaseOperator) GetInputTypeConstraints() [][]tensor.Dtype { + return f.inputTypeConstraints +} + +func (f BaseOperator) String() string { + return fmt.Sprintf("%s v%d", f.name, f.version) +} diff --git a/ops/cast/cast.go b/ops/cast/cast.go new file mode 100644 index 0000000..71f63a8 --- /dev/null +++ b/ops/cast/cast.go @@ -0,0 +1,59 @@ +package cast + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var castTypeConstraints = [][]tensor.Dtype{ + {tensor.Int16, tensor.Uint16, tensor.Int32, tensor.Uint32, tensor.Int64, tensor.Uint64, tensor.Float32, tensor.Float64}, +} + +// Cast represents the ONNX cast operator. +type Cast struct { + ops.BaseOperator + + to int32 // DataType to cast to, as defined by TensorProto +} + +// newCast creates a new cast operator. 
+func newCast(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Cast{ + BaseOperator: ops.NewBaseOperator( + version, + 1, + 1, + typeConstraints, + "cast", + ), + } +} + +// Init initializes the cast operator. +func (c *Cast) Init(n *onnx.NodeProto) error { + attributes := n.GetAttribute() + + if len(attributes) != 1 { + return ops.ErrInvalidAttributeCount(1, len(attributes), c) + } + + attr := attributes[0] + if attr.GetName() == "to" { + c.to = int32(attr.GetI()) + } else { + return ops.ErrInvalidAttribute(attr.GetName(), c) + } + + return nil +} + +// Apply applies the cast operator. +func (c *Cast) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + out, err := ops.ConvertTensorDtype(inputs[0], c.to) + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} diff --git a/ops/opset13/cast_test.go b/ops/cast/cast_test.go similarity index 74% rename from ops/opset13/cast_test.go rename to ops/cast/cast_test.go index 74d4648..5d22875 100644 --- a/ops/opset13/cast_test.go +++ b/ops/cast/cast_test.go @@ -1,4 +1,4 @@ -package opset13 +package cast import ( "testing" @@ -19,42 +19,42 @@ func TestCastInit(t *testing.T) { func TestCast(t *testing.T) { tests := []struct { - cast *Cast + version int64 backing interface{} shape []int to int64 expected interface{} }{ { - &Cast{}, + 13, []float32{1.0, 1.0}, []int{2}, 11, []float64{1.0, 1.0}, }, { - &Cast{}, + 9, []float32{1.3, 1.8}, []int{2}, 4, []uint16{1, 1}, }, { - &Cast{}, + 6, []int8{1, 1}, []int{2}, 1, []float32{1.0, 1.0}, }, { - &Cast{}, + 13, []int64{1, 1}, []int{2}, 11, []float64{1.0, 1.0}, }, { - &Cast{}, + 13, []float64{1.4, 1.5}, []int{2}, 3, @@ -63,10 +63,11 @@ func TestCast(t *testing.T) { } for _, test := range tests { - _ = test.cast.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "to", I: test.to}}}) + cast := castVersions[test.version]() + _ = cast.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "to", I: test.to}}}) inputs := 
[]tensor.Tensor{ops.TensorWithBackingFixture(test.backing, test.shape...)} - res, err := test.cast.Apply(inputs) + res, err := cast.Apply(inputs) assert.Nil(t, err) assert.Nil(t, err) @@ -76,34 +77,39 @@ func TestCast(t *testing.T) { func TestInputValidationCast(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 13, []tensor.Tensor{ops.TensorWithBackingFixture([]uint32{1, 2}, 2)}, nil, }, { + 13, []tensor.Tensor{ops.TensorWithBackingFixture([]float32{1, 2}, 2)}, nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), ops.TensorWithBackingFixture([]float64{3, 4}, 2), }, - ops.ErrInvalidInputCount(2, &Cast{}), + ops.ErrInvalidInputCount(2, cast13BaseOpFixture()), }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]bool{true, false}, 2), }, - ops.ErrInvalidInputType(0, "bool", &Cast{}), + ops.ErrInvalidInputType(0, "bool", cast13BaseOpFixture()), }, } for _, test := range tests { - cast := &Cast{} + cast := castVersions[test.version]() validated, err := cast.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -113,3 +119,7 @@ func TestInputValidationCast(t *testing.T) { } } } + +func cast13BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(13, 1, 1, castTypeConstraints, "cast") +} diff --git a/ops/cast/versions.go b/ops/cast/versions.go new file mode 100644 index 0000000..9cbdf0b --- /dev/null +++ b/ops/cast/versions.go @@ -0,0 +1,15 @@ +package cast + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var castVersions = ops.OperatorVersions{ + 6: ops.NewOperatorConstructor(newCast, 6, castTypeConstraints), + 9: ops.NewOperatorConstructor(newCast, 9, castTypeConstraints), + 13: ops.NewOperatorConstructor(newCast, 13, castTypeConstraints), +} + +func GetCastVersions() ops.OperatorVersions { + return castVersions +} diff --git a/ops/opset13/concat.go b/ops/concat/concat.go similarity index 50% rename from 
ops/opset13/concat.go rename to ops/concat/concat.go index a7a24a0..2e32538 100644 --- a/ops/opset13/concat.go +++ b/ops/concat/concat.go @@ -1,4 +1,4 @@ -package opset13 +package concat import ( "github.com/advancedclimatesystems/gonnx/onnx" @@ -6,20 +6,30 @@ import ( "gorgonia.org/tensor" ) +var concatTypeConstraints = [][]tensor.Dtype{ops.AllTypes} + const ( MinConcatInputs = 1 ) // Concat represents the ONNX concat operator. type Concat struct { - axis int - maxInputs int - inputTypeConstraints [][]tensor.Dtype + ops.BaseOperator + + axis int } // newConcat creates a new concat operator. -func newConcat() ops.Operator { - return &Concat{} +func newConcat(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Concat{ + BaseOperator: ops.NewBaseOperator( + version, + 1, + 1, + typeConstraints, + "concat", + ), + } } // Init initializes the concat operator. @@ -56,36 +66,22 @@ func (c *Concat) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { } // ValidateInputs validates the inputs that will be given to Apply for this operator. +// Because Concat can have an infinite number of inputs, we set the maximum number +// of inputs dynamically, based on our inputs. Every input can have any type. func (c *Concat) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - // Because Concat can have an infinite number of inputs, we set the maximum number - // of inputs dynamically, based on our inputs. Every input can have any type. - c.maxInputs = len(inputs) - c.inputTypeConstraints = make([][]tensor.Dtype, len(inputs)) + inputTypeConstraints := make([][]tensor.Dtype, len(inputs)) for i := 0; i < len(inputs); i++ { - c.inputTypeConstraints[i] = ops.AllTypes + inputTypeConstraints[i] = ops.AllTypes } - return ops.ValidateInputs(c, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. 
-func (c *Concat) GetMinInputs() int { - return MinConcatInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (c *Concat) GetMaxInputs() int { - return c.maxInputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (c *Concat) GetInputTypeConstraints() [][]tensor.Dtype { - return c.inputTypeConstraints -} + c.BaseOperator = ops.NewBaseOperator( + c.BaseOperator.Version(), + c.BaseOperator.GetMinInputs(), + len(inputs), + inputTypeConstraints, + "concat", + ) -// String implements the stringer interface, and can be used to format errors or messages. -func (c *Concat) String() string { - return "concat operator" + return ops.ValidateInputs(c.BaseOperator, inputs) } diff --git a/ops/opset13/concat_test.go b/ops/concat/concat_test.go similarity index 68% rename from ops/opset13/concat_test.go rename to ops/concat/concat_test.go index 01fd033..77b9ffa 100644 --- a/ops/opset13/concat_test.go +++ b/ops/concat/concat_test.go @@ -1,4 +1,4 @@ -package opset13 +package concat import ( "testing" @@ -27,21 +27,24 @@ func TestConcatInitFail(t *testing.T) { func TestConcat(t *testing.T) { tests := []struct { - concat *Concat + version int64 + node *onnx.NodeProto backings [][]float32 shapes [][]int expectedShape tensor.Shape expectedBacking []float32 }{ { - &Concat{1, 2, [][]tensor.Dtype{ops.AllTypes, ops.AllTypes}}, + 13, + &onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "axis", I: 1}}}, [][]float32{{0, 1, 2, 3}, {10, 20}}, [][]int{{2, 2}, {2, 1}}, []int{2, 3}, []float32{0, 1, 10, 2, 3, 20}, }, { - &Concat{1, 2, [][]tensor.Dtype{ops.AllTypes, ops.AllTypes}}, + 13, + &onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "axis", I: 1}}}, [][]float32{{0, 1, 2, 3}, {10, 20, 30, 40, 50, 60}}, [][]int{{2, 2}, {2, 3}}, []int{2, 5}, @@ -55,7 +58,14 @@ func TestConcat(t *testing.T) { ops.TensorWithBackingFixture(test.backings[1], 
test.shapes[1]...), } - res, err := test.concat.Apply(inputs) + concat := concatVersions[test.version]() + err := concat.Init(test.node) + assert.Nil(t, err) + + inputs, err = concat.ValidateInputs(inputs) + assert.Nil(t, err) + + res, err := concat.Apply(inputs) assert.Nil(t, err) assert.Equal(t, test.expectedShape, res[0].Shape()) @@ -65,24 +75,31 @@ func TestConcat(t *testing.T) { func TestInputValidationConcat(t *testing.T) { tests := []struct { - concat ops.Operator - inputs []tensor.Tensor + version int64 + node *onnx.NodeProto + inputs []tensor.Tensor }{ { - &Concat{1, 2, [][]tensor.Dtype{ops.AllTypes, ops.AllTypes}}, + 13, + &onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "axis", I: 1}}}, []tensor.Tensor{ ops.TensorWithBackingFixture([]uint32{1, 2}, 2), ops.TensorWithBackingFixture([]uint32{3, 4}, 2), }, }, { - &Concat{1, 1, [][]tensor.Dtype{ops.AllTypes}}, + 13, + &onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "axis", I: 1}}}, []tensor.Tensor{ops.TensorWithBackingFixture([]float32{1, 2}, 2)}, }, } for _, test := range tests { - validated, err := test.concat.ValidateInputs(test.inputs) + concat := concatVersions[test.version]() + err := concat.Init(test.node) + assert.Nil(t, err) + + validated, err := concat.ValidateInputs(test.inputs) assert.Nil(t, err) assert.Equal(t, test.inputs, validated) } diff --git a/ops/concat/versions.go b/ops/concat/versions.go new file mode 100644 index 0000000..5c46e48 --- /dev/null +++ b/ops/concat/versions.go @@ -0,0 +1,15 @@ +package concat + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var concatVersions = ops.OperatorVersions{ + 4: ops.NewOperatorConstructor(newConcat, 4, concatTypeConstraints), + 11: ops.NewOperatorConstructor(newConcat, 11, concatTypeConstraints), + 13: ops.NewOperatorConstructor(newConcat, 13, concatTypeConstraints), +} + +func GetConcatVersions() ops.OperatorVersions { + return concatVersions +} diff --git a/ops/opset13/constant.go 
b/ops/constant/constant.go similarity index 55% rename from ops/opset13/constant.go rename to ops/constant/constant.go index d0c1261..7c926c1 100644 --- a/ops/opset13/constant.go +++ b/ops/constant/constant.go @@ -1,4 +1,4 @@ -package opset13 +package constant import ( "github.com/advancedclimatesystems/gonnx/onnx" @@ -8,12 +8,22 @@ import ( // Constant represents the ONNX constant operator. type Constant struct { + ops.BaseOperator + value tensor.Tensor } // newConstant creates a new constant operator. -func newConstant() ops.Operator { - return &Constant{} +func newConstant(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Constant{ + BaseOperator: ops.NewBaseOperator( + version, + 0, + 0, + typeConstraints, + "constant", + ), + } } // Init initializes the constant operator. It supports all constant types except @@ -27,23 +37,23 @@ func (c *Constant) Init(n *onnx.NodeProto) error { attr := attributes[0] switch attr.GetName() { - case "sparse_value", "value_string", "value_strings": + case sparseValue, valueString, valueStrings: return ops.ErrUnsupportedAttribute(attr.GetName(), c) - case "value": + case value: t, err := onnx.TensorFromProto(attr.GetT()) if err != nil { return err } c.value = t - case "value_float": + case valueFloat: c.value = tensor.New(tensor.FromScalar(attr.GetF())) - case "value_floats": + case valueFloats: floats := attr.GetFloats() c.value = tensor.New(tensor.WithShape(len(floats)), tensor.WithBacking(floats)) - case "value_int": + case valueInt: c.value = tensor.New(tensor.FromScalar(attr.GetI())) - case "value_ints": + case valueInts: ints := attr.GetInts() c.value = tensor.New(tensor.WithShape(len(ints)), tensor.WithBacking(ints)) default: @@ -57,29 +67,3 @@ func (c *Constant) Init(n *onnx.NodeProto) error { func (c *Constant) Apply(_ []tensor.Tensor) ([]tensor.Tensor, error) { return []tensor.Tensor{c.value}, nil } - -// ValidateInputs validates the inputs that will be given to Apply for this operator. 
-func (c *Constant) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(c, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (c *Constant) GetMinInputs() int { - return 0 -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (c *Constant) GetMaxInputs() int { - return 0 -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (c *Constant) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{} -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (c *Constant) String() string { - return "constant operator" -} diff --git a/ops/constant/constant_11.go b/ops/constant/constant_11.go new file mode 100644 index 0000000..f976356 --- /dev/null +++ b/ops/constant/constant_11.go @@ -0,0 +1,59 @@ +package constant + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +// Constant11 represents the ONNX constant operator. +type Constant11 struct { + ops.BaseOperator + + value tensor.Tensor +} + +// newConstant11 creates a new constant operator. +func newConstant11() ops.Operator { + return &Constant11{ + BaseOperator: ops.NewBaseOperator( + 11, + 0, + 0, + [][]tensor.Dtype{}, + "constant", + ), + } +} + +// Init initializes the constant operator. It supports all constant types except +// `sparse_value`. 
+func (c *Constant11) Init(n *onnx.NodeProto) error { + attributes := n.GetAttribute() + if len(attributes) != 1 { + return ops.ErrInvalidAttributeCount(1, len(attributes), c) + } + + attr := attributes[0] + + switch attr.GetName() { + case sparseValue: + return ops.ErrUnsupportedAttribute(attr.GetName(), c) + case value: + t, err := onnx.TensorFromProto(attr.GetT()) + if err != nil { + return err + } + + c.value = t + default: + return ops.ErrUnsupportedAttribute(attr.GetName(), c) + } + + return nil +} + +// Apply applies the constant operator. +func (c *Constant11) Apply(_ []tensor.Tensor) ([]tensor.Tensor, error) { + return []tensor.Tensor{c.value}, nil +} diff --git a/ops/constant/constant_legacy.go b/ops/constant/constant_legacy.go new file mode 100644 index 0000000..e005287 --- /dev/null +++ b/ops/constant/constant_legacy.go @@ -0,0 +1,56 @@ +package constant + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +// Constant9 represents the ONNX constant operator for version 9 and 1. +type Constant9 struct { + ops.BaseOperator + + value tensor.Tensor +} + +// newConstant9 creates a new constant operator. +func newConstant9(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Constant9{ + BaseOperator: ops.NewBaseOperator( + version, + 0, + 0, + typeConstraints, + "constant", + ), + } +} + +// Init initializes the constant operator. +func (c *Constant9) Init(n *onnx.NodeProto) error { + attributes := n.GetAttribute() + if len(attributes) != 1 { + return ops.ErrInvalidAttributeCount(1, len(attributes), c) + } + + attr := attributes[0] + + switch attr.GetName() { + case value: + t, err := onnx.TensorFromProto(attr.GetT()) + if err != nil { + return err + } + + c.value = t + default: + return ops.ErrUnsupportedAttribute(attr.GetName(), c) + } + + return nil +} + +// Apply applies the constant operator. 
+func (c *Constant9) Apply(_ []tensor.Tensor) ([]tensor.Tensor, error) { + return []tensor.Tensor{c.value}, nil +} diff --git a/ops/opset13/constant_test.go b/ops/constant/constant_test.go similarity index 81% rename from ops/opset13/constant_test.go rename to ops/constant/constant_test.go index ffebccf..e1f23c8 100644 --- a/ops/opset13/constant_test.go +++ b/ops/constant/constant_test.go @@ -1,4 +1,4 @@ -package opset13 +package constant import ( "encoding/binary" @@ -11,55 +11,66 @@ import ( ) func TestConstantInit(t *testing.T) { + constant := Constant{} + tests := []struct { + version int64 initAttr []*onnx.AttributeProto expected interface{} err error }{ { + 13, ConstantValueAttrProtoFixture(), tensor.New(tensor.WithBacking([]int64{1, 1, 1})), nil, }, { + 13, ConstantValueFloatAttrProtoFixture(), tensor.New(tensor.FromScalar(float32(0.2))), nil, }, { + 13, ConstantValueFloatsAttrProtoFixture(), tensor.New(tensor.WithBacking([]float32{0.1, 0.2})), nil, }, { + 13, ConstantValueIntAttrProtoFixture(), tensor.New(tensor.FromScalar(int64(1))), nil, }, { + 13, ConstantValueIntsAttrProtoFixture(), tensor.New(tensor.WithBacking([]int64{1, 2, 3})), nil, }, { + 13, []*onnx.AttributeProto{{Name: "sparse_value"}}, nil, - ops.ErrUnsupportedAttribute("sparse_value", &Constant{}), + ops.ErrUnsupportedAttribute("sparse_value", &constant), }, { + 13, []*onnx.AttributeProto{{Name: "unknownAttribute"}}, nil, - ops.ErrUnsupportedAttribute("unknownAttribute", &Constant{}), + ops.ErrUnsupportedAttribute("unknownAttribute", &constant), }, { + 13, []*onnx.AttributeProto{}, nil, - ops.ErrInvalidAttributeCount(1, 0, &Constant{}), + ops.ErrInvalidAttributeCount(1, 0, &constant), }, } for _, test := range tests { - constant := &Constant{} + constant.value = nil err := constant.Init(&onnx.NodeProto{Attribute: test.initAttr}) assert.Equal(t, test.err, err) @@ -72,40 +83,41 @@ func TestConstantInit(t *testing.T) { func TestConstant(t *testing.T) { tests := []struct { - constant *Constant + 
version int64 initAttr []*onnx.AttributeProto expected interface{} }{ { - &Constant{}, + 13, ConstantValueAttrProtoFixture(), []int64{1, 1, 1}, }, { - &Constant{}, + 13, ConstantValueFloatAttrProtoFixture(), float32(0.2), }, { - &Constant{}, + 13, ConstantValueFloatsAttrProtoFixture(), []float32{0.1, 0.2}, }, { - &Constant{}, + 13, ConstantValueIntAttrProtoFixture(), int64(1), }, { - &Constant{}, + 13, ConstantValueIntsAttrProtoFixture(), []int64{1, 2, 3}, }, } for _, test := range tests { - _ = test.constant.Init(&onnx.NodeProto{Attribute: test.initAttr}) - res, err := test.constant.Apply([]tensor.Tensor{}) + constant := constantVersions[test.version]() + _ = constant.Init(&onnx.NodeProto{Attribute: test.initAttr}) + res, err := constant.Apply([]tensor.Tensor{}) assert.Nil(t, err) assert.Equal(t, test.expected, res[0].Data()) @@ -122,23 +134,26 @@ func TestConstantSingleIntShapeTensor(t *testing.T) { func TestInputValidationConstant(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 13, []tensor.Tensor{}, nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputCount(1, &Constant{}), + ops.ErrInvalidInputCount(1, constant13BaseOpFixture()), }, } for _, test := range tests { - constant := &Constant{} + constant := constantVersions[test.version]() validated, err := constant.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -177,3 +192,7 @@ func ConstantValueIntAttrProtoFixture() []*onnx.AttributeProto { func ConstantValueIntsAttrProtoFixture() []*onnx.AttributeProto { return []*onnx.AttributeProto{{Name: "value_ints", Ints: []int64{1, 2, 3}}} } + +func constant13BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(13, 0, 0, [][]tensor.Dtype{}, "constant") +} diff --git a/ops/constant/constants.go b/ops/constant/constants.go new file mode 100644 index 0000000..5ec00f6 --- /dev/null +++ b/ops/constant/constants.go @@ 
-0,0 +1,12 @@ +package constant + +const ( + value = "value" + sparseValue = "sparse_value" + valueString = "value_string" + valueStrings = "value_strings" + valueFloat = "value_float" + valueFloats = "value_floats" + valueInt = "value_int" + valueInts = "value_ints" +) diff --git a/ops/constant/versions.go b/ops/constant/versions.go new file mode 100644 index 0000000..f236ada --- /dev/null +++ b/ops/constant/versions.go @@ -0,0 +1,18 @@ +package constant + +import ( + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var constantVersions = ops.OperatorVersions{ + 1: ops.NewOperatorConstructor(newConstant9, 1, [][]tensor.Dtype{}), + 9: ops.NewOperatorConstructor(newConstant9, 9, [][]tensor.Dtype{}), + 11: newConstant11, + 12: ops.NewOperatorConstructor(newConstant, 12, [][]tensor.Dtype{}), + 13: ops.NewOperatorConstructor(newConstant, 13, [][]tensor.Dtype{}), +} + +func GetConstantVersions() ops.OperatorVersions { + return constantVersions +} diff --git a/ops/opset13/constant_of_shape.go b/ops/constantofshape/constant_of_shape.go similarity index 60% rename from ops/opset13/constant_of_shape.go rename to ops/constantofshape/constant_of_shape.go index 9511108..82a2b05 100644 --- a/ops/opset13/constant_of_shape.go +++ b/ops/constantofshape/constant_of_shape.go @@ -1,4 +1,4 @@ -package opset13 +package constantofshape import ( "github.com/advancedclimatesystems/gonnx/onnx" @@ -6,21 +6,28 @@ import ( "gorgonia.org/tensor" ) -const ( - MinConstantOfShapeInputs = 1 - MaxConstantOfShapeInputs = 1 -) +var constantOfShapeTypeConstraints = [][]tensor.Dtype{{tensor.Int64}} // ConstantOfShape represents the ONNX constant of shape operator. type ConstantOfShape struct { + ops.BaseOperator + // One element tensor, giving the value and type of the output tensor // defaults to value 0 and type float32. value *tensor.Dense } // newConstantOfShape creates a new constant of shape operator. 
-func newConstantOfShape() ops.Operator { - return &ConstantOfShape{} +func newConstantOfShape(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &ConstantOfShape{ + BaseOperator: ops.NewBaseOperator( + version, + 1, + 1, + typeConstraints, + "constantofshape", + ), + } } // Init initializes the constant of shape operator. @@ -76,31 +83,3 @@ func (c *ConstantOfShape) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) return []tensor.Tensor{t}, err } - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (c *ConstantOfShape) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(c, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (c *ConstantOfShape) GetMinInputs() int { - return MinConstantOfShapeInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (c *ConstantOfShape) GetMaxInputs() int { - return MaxConstantOfShapeInputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (c *ConstantOfShape) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ - {tensor.Int64}, - } -} - -// String implements the stringer interface, and can be used to format errors or messages. 
-func (c *ConstantOfShape) String() string { - return "constant of shape operator" -} diff --git a/ops/opset13/constant_of_shape_test.go b/ops/constantofshape/constant_of_shape_test.go similarity index 76% rename from ops/opset13/constant_of_shape_test.go rename to ops/constantofshape/constant_of_shape_test.go index e294c25..5707009 100644 --- a/ops/opset13/constant_of_shape_test.go +++ b/ops/constantofshape/constant_of_shape_test.go @@ -1,4 +1,4 @@ -package opset13 +package constantofshape import ( "encoding/binary" @@ -68,17 +68,18 @@ func TestConstantOfShape(t *testing.T) { // Test cases, verifying that all these types work. // Unfortunately uint* and bool are not supported. tests := []struct { + version int64 input interface{} expectTensor interface{} }{ - {float32(42.0), []float32{42.0, 42.0, 42.0, 42.0}}, - {float64(42.0), []float64{42.0, 42.0, 42.0, 42.0}}, - {int8(42), []int8{42.0, 42.0, 42.0, 42.0}}, - {int16(42), []int16{42.0, 42.0, 42.0, 42.0}}, - {int32(42), []int32{42.0, 42.0, 42.0, 42.0}}, - {int64(42), []int64{42.0, 42.0, 42.0, 42.0}}, - {int32(-1), []int32{-1, -1, -1, -1}}, - {int32(0), []int32{0, 0, 0, 0}}, + {9, float32(42.0), []float32{42.0, 42.0, 42.0, 42.0}}, + {9, float64(42.0), []float64{42.0, 42.0, 42.0, 42.0}}, + {9, int8(42), []int8{42.0, 42.0, 42.0, 42.0}}, + {9, int16(42), []int16{42.0, 42.0, 42.0, 42.0}}, + {9, int32(42), []int32{42.0, 42.0, 42.0, 42.0}}, + {9, int64(42), []int64{42.0, 42.0, 42.0, 42.0}}, + {9, int32(-1), []int32{-1, -1, -1, -1}}, + {9, int32(0), []int32{0, 0, 0, 0}}, } for _, test := range tests { @@ -88,10 +89,11 @@ func TestConstantOfShape(t *testing.T) { assert.NotNil(t, tp) node := &onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "value", T: tp}}} + op, ok := constantOfShapeVersions[test.version]().(*ConstantOfShape) + assert.True(t, ok) - // Create operator - op := ConstantOfShape{} err := op.Init(node) + assert.NoError(t, err) assert.Equal(t, test.input, op.value.Data()) @@ -132,18 +134,18 @@ func 
TestIncorrectInput(t *testing.T) { } node := &onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "value", T: tp}}} - op := &ConstantOfShape{} + op := constantOfShapeVersions[9]() err := op.Init(node) assert.NotNil(t, err) assert.Equal( t, - "constant of shape operator invalid tensor found, reason: expected tensor to have one element", + "constantofshape v9 invalid tensor found, reason: expected tensor to have one element", err.Error(), ) } func TestNegativeShapeNotAllowed(t *testing.T) { - op := &ConstantOfShape{} + op := constantOfShapeVersions[9]() _ = op.Init(ops.EmptyNodeProto()) shape := []int64{1, -1} @@ -154,12 +156,12 @@ func TestNegativeShapeNotAllowed(t *testing.T) { assert.Equal( t, - "constant of shape operator invalid tensor found, reason: empty dimensions are not allowed", + "constantofshape v9 invalid tensor found, reason: empty dimensions are not allowed", err.Error()) } func TestEmptyTensorNotAllowed(t *testing.T) { - op := &ConstantOfShape{} + op := constantOfShapeVersions[9]() _ = op.Init(ops.EmptyNodeProto()) shape := []int64{0} @@ -170,7 +172,7 @@ func TestEmptyTensorNotAllowed(t *testing.T) { assert.Equal( t, - "constant of shape operator invalid tensor found, reason: empty dimensions are not allowed", + "constantofshape v9 invalid tensor found, reason: empty dimensions are not allowed", err.Error()) } @@ -189,27 +191,31 @@ func TestScalarShapeInput(t *testing.T) { func TestInputValidationConstantOfShape(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 9, []tensor.Tensor{ ops.TensorWithBackingFixture([]int64{1}, 1), }, nil, }, { + 9, []tensor.Tensor{}, - ops.ErrInvalidInputCount(0, &ConstantOfShape{}), + ops.ErrInvalidInputCount(0, constantOfShape9BaseOpFixture()), }, { + 9, []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, - ops.ErrInvalidInputType(0, "int", &ConstantOfShape{}), + ops.ErrInvalidInputType(0, "int", 
constantOfShape9BaseOpFixture()), }, } for _, test := range tests { - constantOfShape := &ConstantOfShape{} + constantOfShape := constantOfShapeVersions[test.version]() validated, err := constantOfShape.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -219,3 +225,7 @@ func TestInputValidationConstantOfShape(t *testing.T) { } } } + +func constantOfShape9BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(9, 1, 1, constantOfShapeTypeConstraints, "constantofshape") +} diff --git a/ops/constantofshape/versions.go b/ops/constantofshape/versions.go new file mode 100644 index 0000000..9f5f65c --- /dev/null +++ b/ops/constantofshape/versions.go @@ -0,0 +1,13 @@ +package constantofshape + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var constantOfShapeVersions = ops.OperatorVersions{ + 9: ops.NewOperatorConstructor(newConstantOfShape, 9, constantOfShapeTypeConstraints), +} + +func GetConstantOfShapeVersions() ops.OperatorVersions { + return constantOfShapeVersions +} diff --git a/ops/opset13/conv.go b/ops/conv/conv.go similarity index 93% rename from ops/opset13/conv.go rename to ops/conv/conv.go index 801a5e9..01fefb3 100644 --- a/ops/opset13/conv.go +++ b/ops/conv/conv.go @@ -1,4 +1,4 @@ -package opset13 +package conv import ( "github.com/advancedclimatesystems/gonnx/onnx" @@ -6,6 +6,12 @@ import ( "gorgonia.org/tensor" ) +var convTypeConstraints = [][]tensor.Dtype{ + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, +} + var ( MinConvInputs = 2 MaxConvInputs = 3 @@ -30,6 +36,8 @@ const nNonSpatialDims = 2 // Conv represents the ONNX conv operator. type Conv struct { + ops.BaseOperator + autoPad AutoPadSetting dilations []int group int @@ -39,8 +47,15 @@ type Conv struct { } // newConv creates a new conv operator. 
-func newConv() ops.Operator { +func newConv(version int, typeConstraints [][]tensor.Dtype) ops.Operator { return &Conv{ + BaseOperator: ops.NewBaseOperator( + version, + MinConvInputs, + MaxConvInputs, + typeConstraints, + "conv", + ), autoPad: NotSet, } } @@ -125,7 +140,7 @@ func (c *Conv) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { case NDims2DConvolution: out, err = c.applyConv2D(x, kernel) default: - return nil, ops.ErrInvalidInput("the convolution operator currently only supports 1D or 2D convolution, i.e. shape [N x C x H (x W)]", c) + return nil, ops.ErrInvalidInput("the convolution operator currently only supports 1D or 2D convolution, i.e. shape [N x C x H (x W)]", c.BaseOperator) } if err != nil { @@ -142,36 +157,6 @@ func (c *Conv) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return []tensor.Tensor{out}, nil } -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (c *Conv) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(c, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (c *Conv) GetMinInputs() int { - return MinConvInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (c *Conv) GetMaxInputs() int { - return MaxConvInputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (c *Conv) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ - {tensor.Float32, tensor.Float64}, - {tensor.Float32, tensor.Float64}, - {tensor.Float32, tensor.Float64}, - } -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (c *Conv) String() string { - return "conv operator" -} - // setDefaultDilations sets the dilations attribute to the default. 
Can be called when no // dilations were set when initializing. func (c *Conv) setDefaultDilations(x tensor.Tensor) { diff --git a/ops/opset13/conv_test.go b/ops/conv/conv_test.go similarity index 74% rename from ops/opset13/conv_test.go rename to ops/conv/conv_test.go index 8da4b87..0c264bd 100644 --- a/ops/opset13/conv_test.go +++ b/ops/conv/conv_test.go @@ -1,4 +1,4 @@ -package opset13 +package conv import ( "testing" @@ -37,7 +37,8 @@ func TestConvInitUnsupported(t *testing.T) { func TestConv(t *testing.T) { tests := []struct { - conv *Conv + version int64 + node *onnx.NodeProto shapes [][]int backings [][]float32 expectedShape tensor.Shape @@ -45,13 +46,16 @@ func TestConv(t *testing.T) { }{ // Test 1D Convolution. { - &Conv{ - autoPad: "NOTSET", - dilations: []int{}, - group: 1, - kernelShape: []int{3}, - pads: []int{0, 0}, - strides: []int{1, 1}, + 11, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "auto_pad", S: []byte("NOTSET")}, + {Name: "dilations", Ints: []int64{}}, + {Name: "group", I: 1}, + {Name: "kernel_shape", Ints: []int64{3}}, + {Name: "pads", Ints: []int64{0, 0}}, + {Name: "strides", Ints: []int64{1, 1}}, + }, }, [][]int{{1, 1, 6}, {1, 1, 3}}, [][]float32{{0, 1, 2, 3, 4, 5}, {1, 1, 1}}, @@ -60,13 +64,16 @@ func TestConv(t *testing.T) { }, // Test 2D Convolution. { - &Conv{ - autoPad: "NOTSET", - dilations: []int{}, - group: 1, - kernelShape: []int{2, 2}, - pads: []int{0, 0, 0, 0}, - strides: []int{1, 1}, + 11, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "auto_pad", S: []byte("NOTSET")}, + {Name: "dilations", Ints: []int64{}}, + {Name: "group", I: 1}, + {Name: "kernel_shape", Ints: []int64{2, 2}}, + {Name: "pads", Ints: []int64{0, 0, 0, 0}}, + {Name: "strides", Ints: []int64{1, 1}}, + }, }, [][]int{{1, 1, 3, 3}, {1, 1, 2, 2}}, [][]float32{{0, 1, 2, 3, 4, 5, 6, 7, 8}, {1, 1, 1, 1}}, @@ -75,13 +82,16 @@ func TestConv(t *testing.T) { }, // Test SAME_LOWER autopad setting. 
{ - &Conv{ - autoPad: "SAME_LOWER", - dilations: []int{}, - group: 1, - kernelShape: []int{2, 2}, - pads: []int{}, - strides: []int{1, 1}, + 11, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "auto_pad", S: []byte("SAME_LOWER")}, + {Name: "dilations", Ints: []int64{}}, + {Name: "group", I: 1}, + {Name: "kernel_shape", Ints: []int64{2, 2}}, + {Name: "pads", Ints: []int64{}}, + {Name: "strides", Ints: []int64{1, 1}}, + }, }, [][]int{{1, 1, 3, 3}, {1, 1, 2, 2}}, [][]float32{{0, 1, 2, 3, 4, 5, 6, 7, 8}, {1, 1, 1, 1}}, @@ -90,13 +100,16 @@ func TestConv(t *testing.T) { }, // Test SAME_UPPER autopad setting. { - &Conv{ - autoPad: "SAME_UPPER", - dilations: []int{}, - group: 1, - kernelShape: []int{2, 2}, - pads: []int{}, - strides: []int{1, 1}, + 11, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "auto_pad", S: []byte("SAME_UPPER")}, + {Name: "dilations", Ints: []int64{}}, + {Name: "group", I: 1}, + {Name: "kernel_shape", Ints: []int64{2, 2}}, + {Name: "pads", Ints: []int64{}}, + {Name: "strides", Ints: []int64{1, 1}}, + }, }, [][]int{{1, 1, 3, 3}, {1, 1, 2, 2}}, [][]float32{{0, 1, 2, 3, 4, 5, 6, 7, 8}, {1, 1, 1, 1}}, @@ -105,13 +118,16 @@ func TestConv(t *testing.T) { }, // Test VALID autopad setting. { - &Conv{ - autoPad: "VALID", - dilations: []int{}, - group: 1, - kernelShape: []int{2, 2}, - pads: []int{}, - strides: []int{1, 1}, + 11, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "auto_pad", S: []byte("VALID")}, + {Name: "dilations", Ints: []int64{}}, + {Name: "group", I: 1}, + {Name: "kernel_shape", Ints: []int64{2, 2}}, + {Name: "pads", Ints: []int64{}}, + {Name: "strides", Ints: []int64{1, 1}}, + }, }, [][]int{{1, 1, 3, 3}, {1, 1, 2, 2}}, [][]float32{{0, 1, 2, 3, 4, 5, 6, 7, 8}, {1, 1, 1, 1}}, @@ -120,13 +136,16 @@ func TestConv(t *testing.T) { }, // Test dilation attribute. 
{ - &Conv{ - autoPad: "NOTSET", - dilations: []int{2, 2}, - group: 1, - kernelShape: []int{2, 2}, - pads: []int{0, 0, 0, 0}, - strides: []int{1, 1}, + 11, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "auto_pad", S: []byte("NOTSET")}, + {Name: "dilations", Ints: []int64{2, 2}}, + {Name: "group", I: 1}, + {Name: "kernel_shape", Ints: []int64{2, 2}}, + {Name: "pads", Ints: []int64{0, 0, 0, 0}}, + {Name: "strides", Ints: []int64{1, 1}}, + }, }, [][]int{{1, 1, 4, 4}, {1, 1, 2, 2}}, [][]float32{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, {1, 1, 1, 1}}, @@ -135,13 +154,16 @@ func TestConv(t *testing.T) { }, // Test pads attribute. { - &Conv{ - autoPad: "NOTSET", - dilations: []int{1, 1}, - group: 1, - kernelShape: []int{2, 2}, - pads: []int{1, 1, 2, 2}, - strides: []int{1, 1}, + 11, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "auto_pad", S: []byte("NOTSET")}, + {Name: "dilations", Ints: []int64{1, 1}}, + {Name: "group", I: 1}, + {Name: "kernel_shape", Ints: []int64{2, 2}}, + {Name: "pads", Ints: []int64{1, 1, 2, 2}}, + {Name: "strides", Ints: []int64{1, 1}}, + }, }, [][]int{{1, 1, 2, 2}, {1, 1, 2, 2}}, [][]float32{{0, 1, 2, 3}, {1, 1, 1, 1}}, @@ -150,13 +172,16 @@ func TestConv(t *testing.T) { }, // Test strides attribute. { - &Conv{ - autoPad: "NOTSET", - dilations: []int{}, - group: 1, - kernelShape: []int{2, 2}, - pads: []int{0, 0, 0, 0}, - strides: []int{2, 2}, + 11, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "auto_pad", S: []byte("NOTSET")}, + {Name: "dilations", Ints: []int64{}}, + {Name: "group", I: 1}, + {Name: "kernel_shape", Ints: []int64{2, 2}}, + {Name: "pads", Ints: []int64{0, 0, 0, 0}}, + {Name: "strides", Ints: []int64{2, 2}}, + }, }, [][]int{{1, 1, 4, 4}, {1, 1, 2, 2}}, [][]float32{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, {1, 1, 1, 1}}, @@ -165,13 +190,16 @@ func TestConv(t *testing.T) { }, // Test batch dimension. 
{ - &Conv{ - autoPad: "NOTSET", - dilations: []int{}, - group: 1, - kernelShape: []int{2, 2}, - pads: []int{0, 0, 0, 0}, - strides: []int{1, 1}, + 11, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "auto_pad", S: []byte("NOTSET")}, + {Name: "dilations", Ints: []int64{}}, + {Name: "group", I: 1}, + {Name: "kernel_shape", Ints: []int64{2, 2}}, + {Name: "pads", Ints: []int64{0, 0, 0, 0}}, + {Name: "strides", Ints: []int64{1, 1}}, + }, }, [][]int{{2, 1, 3, 3}, {1, 1, 2, 2}}, [][]float32{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, {1, 1, 1, 1}}, @@ -180,13 +208,16 @@ func TestConv(t *testing.T) { }, // Test 2D convolution with multiple channels. { - &Conv{ - autoPad: "NOTSET", - dilations: []int{}, - group: 1, - kernelShape: []int{2, 2}, - pads: []int{0, 0, 0, 0}, - strides: []int{1, 1}, + 11, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "auto_pad", S: []byte("NOTSET")}, + {Name: "dilations", Ints: []int64{}}, + {Name: "group", I: 1}, + {Name: "kernel_shape", Ints: []int64{2, 2}}, + {Name: "pads", Ints: []int64{0, 0, 0, 0}}, + {Name: "strides", Ints: []int64{1, 1}}, + }, }, [][]int{{1, 2, 3, 3}, {1, 1, 2, 2}}, [][]float32{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, {1, 1, 1, 1}}, @@ -195,13 +226,16 @@ func TestConv(t *testing.T) { }, // Test multiple kernels. { - &Conv{ - autoPad: "NOTSET", - dilations: []int{}, - group: 1, - kernelShape: []int{2, 2}, - pads: []int{0, 0, 0, 0}, - strides: []int{1, 1}, + 11, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "auto_pad", S: []byte("NOTSET")}, + {Name: "dilations", Ints: []int64{}}, + {Name: "group", I: 1}, + {Name: "kernel_shape", Ints: []int64{2, 2}}, + {Name: "pads", Ints: []int64{0, 0, 0, 0}}, + {Name: "strides", Ints: []int64{1, 1}}, + }, }, [][]int{{1, 1, 3, 3}, {2, 1, 2, 2}}, [][]float32{{0, 1, 2, 3, 4, 5, 6, 7, 8}, {1, 1, 1, 1, 2, 2, 2, 2}}, @@ -210,13 +244,16 @@ func TestConv(t *testing.T) { }, // Test bias. 
{ - &Conv{ - autoPad: "NOTSET", - dilations: []int{}, - group: 1, - kernelShape: []int{2, 2}, - pads: []int{0, 0, 0, 0}, - strides: []int{1, 1}, + 11, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "auto_pad", S: []byte("NOTSET")}, + {Name: "dilations", Ints: []int64{}}, + {Name: "group", I: 1}, + {Name: "kernel_shape", Ints: []int64{2, 2}}, + {Name: "pads", Ints: []int64{0, 0, 0, 0}}, + {Name: "strides", Ints: []int64{1, 1}}, + }, }, [][]int{{1, 1, 3, 3}, {1, 1, 2, 2}, {1}}, [][]float32{{0, 1, 2, 3, 4, 5, 6, 7, 8}, {1, 1, 1, 1}, {0.5}}, @@ -236,7 +273,11 @@ func TestConv(t *testing.T) { inputs[2] = ops.TensorWithBackingFixture(test.backings[2], test.shapes[2]...) } - res, err := test.conv.Apply(inputs) + conv := convVersions[test.version]() + err := conv.Init(test.node) + assert.Nil(t, err) + + res, err := conv.Apply(inputs) assert.Nil(t, err) assert.Equal(t, test.expectedShape, res[0].Shape()) @@ -246,10 +287,12 @@ func TestConv(t *testing.T) { func TestInputValidationConv(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 11, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{3, 4}, 2), @@ -258,6 +301,7 @@ func TestInputValidationConv(t *testing.T) { nil, }, { + 11, []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), ops.TensorWithBackingFixture([]float64{3, 4}, 2), @@ -266,22 +310,24 @@ func TestInputValidationConv(t *testing.T) { nil, }, { + 11, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidOptionalInputCount(1, &Conv{}), + ops.ErrInvalidOptionalInputCount(1, conv11BaseOpFixture()), }, { + 11, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), }, - ops.ErrInvalidInputType(0, "int", &Conv{}), + ops.ErrInvalidInputType(0, "int", conv11BaseOpFixture()), }, } for _, test := range tests { - conv := &Conv{} 
+ conv := convVersions[test.version]() validated, err := conv.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -690,3 +736,7 @@ func ConvUnsupportedOnnxNodeProtoFixture() *onnx.NodeProto { }, } } + +func conv11BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(11, 2, 3, convTypeConstraints, "conv") +} diff --git a/ops/conv/versions.go b/ops/conv/versions.go new file mode 100644 index 0000000..3f795e9 --- /dev/null +++ b/ops/conv/versions.go @@ -0,0 +1,14 @@ +package conv + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var convVersions = ops.OperatorVersions{ + 1: ops.NewOperatorConstructor(newConv, 1, convTypeConstraints), + 11: ops.NewOperatorConstructor(newConv, 11, convTypeConstraints), +} + +func GetConvVersions() ops.OperatorVersions { + return convVersions +} diff --git a/ops/cos/cos.go b/ops/cos/cos.go new file mode 100644 index 0000000..eba410b --- /dev/null +++ b/ops/cos/cos.go @@ -0,0 +1,61 @@ +package cos + +import ( + "math" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var cosTypeConstraints = [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} + +// Cos represents the ONNX cos operator. +type Cos struct { + ops.BaseOperator +} + +// newCos creates a new cos operator. +func newCos(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Cos{ + BaseOperator: ops.NewBaseOperator( + version, + 1, + 1, + typeConstraints, + "cos", + ), + } +} + +// Init initializes the cos operator. +func (c *Cos) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the cos operator. 
+func (c *Cos) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var ( + out tensor.Tensor + err error + ) + + switch inputs[0].Dtype() { + case tensor.Float32: + out, err = inputs[0].Apply(cos[float32]) + case tensor.Float64: + out, err = inputs[0].Apply(cos[float64]) + default: + return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), c.BaseOperator) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +func cos[T ops.FloatType](x T) T { + return T(math.Cos(float64(x))) +} diff --git a/ops/opset13/cos_test.go b/ops/cos/cos_test.go similarity index 77% rename from ops/opset13/cos_test.go rename to ops/cos/cos_test.go index b1087c4..56d019e 100644 --- a/ops/opset13/cos_test.go +++ b/ops/cos/cos_test.go @@ -1,4 +1,4 @@ -package opset13 +package cos import ( "testing" @@ -19,25 +19,25 @@ func TestCosInit(t *testing.T) { func TestCos(t *testing.T) { tests := []struct { - cos *Cos + version int64 backing []float32 shape []int expected []float32 }{ { - &Cos{}, + 7, []float32{-2, -1, 0, 1}, []int{2, 2}, []float32{-0.41614684, 0.5403023, 1, 0.5403023}, }, { - &Cos{}, + 7, []float32{1, 3, 4, 5}, []int{1, 4}, []float32{0.5403023, -0.9899925, -0.6536436, 0.2836622}, }, { - &Cos{}, + 7, []float32{-1, -1, -1, -1}, []int{1, 4}, []float32{0.5403023, 0.5403023, 0.5403023, 0.5403023}, @@ -49,7 +49,9 @@ func TestCos(t *testing.T) { ops.TensorWithBackingFixture(test.backing, test.shape...), } - res, err := test.cos.Apply(inputs) + cos := cosVersions[test.version]() + + res, err := cos.Apply(inputs) assert.Nil(t, err) assert.Nil(t, err) @@ -59,35 +61,40 @@ func TestCos(t *testing.T) { func TestInputValidationCos(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, nil, }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), }, nil, }, { + 7, 
[]tensor.Tensor{}, - ops.ErrInvalidInputCount(0, &Cos{}), + ops.ErrInvalidInputCount(0, cos7BaseOperator()), }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputType(0, "int", &Cos{}), + ops.ErrInvalidInputType(0, "int", cos7BaseOperator()), }, } for _, test := range tests { - cos := &Cos{} + cos := cosVersions[test.version]() validated, err := cos.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -97,3 +104,7 @@ func TestInputValidationCos(t *testing.T) { } } } + +func cos7BaseOperator() ops.BaseOperator { + return ops.NewBaseOperator(7, 1, 1, cosTypeConstraints, "cos") +} diff --git a/ops/cos/versions.go b/ops/cos/versions.go new file mode 100644 index 0000000..cbf8ebd --- /dev/null +++ b/ops/cos/versions.go @@ -0,0 +1,13 @@ +package cos + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var cosVersions = ops.OperatorVersions{ + 7: ops.NewOperatorConstructor(newCos, 7, cosTypeConstraints), +} + +func GetCosVersions() ops.OperatorVersions { + return cosVersions +} diff --git a/ops/cosh/cosh.go b/ops/cosh/cosh.go new file mode 100644 index 0000000..0f52ff5 --- /dev/null +++ b/ops/cosh/cosh.go @@ -0,0 +1,61 @@ +package cosh + +import ( + "math" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var coshTypeConstraints = [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} + +// Cosh represents the ONNX cosh operator. +type Cosh struct { + ops.BaseOperator +} + +// newCosh creates a new cosh operator. +func newCosh(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Cosh{ + BaseOperator: ops.NewBaseOperator( + version, + 1, + 1, + typeConstraints, + "cosh", + ), + } +} + +// Init initializes the cosh operator. +func (c *Cosh) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the cosh operator. 
+func (c *Cosh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var ( + out tensor.Tensor + err error + ) + + switch inputs[0].Dtype() { + case tensor.Float32: + out, err = inputs[0].Apply(cosh[float32]) + case tensor.Float64: + out, err = inputs[0].Apply(cosh[float64]) + default: + return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), c.BaseOperator) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +func cosh[T ops.FloatType](x T) T { + return T(math.Cosh(float64(x))) +} diff --git a/ops/opset13/cosh_test.go b/ops/cosh/cosh_test.go similarity index 77% rename from ops/opset13/cosh_test.go rename to ops/cosh/cosh_test.go index 3359ada..35af5ae 100644 --- a/ops/opset13/cosh_test.go +++ b/ops/cosh/cosh_test.go @@ -1,4 +1,4 @@ -package opset13 +package cosh import ( "testing" @@ -19,25 +19,25 @@ func TestCoshInit(t *testing.T) { func TestCosh(t *testing.T) { tests := []struct { - cosh *Cosh + version int64 backing []float32 shape []int expected []float32 }{ { - &Cosh{}, + 9, []float32{-2, -1, 0, 1}, []int{2, 2}, []float32{3.7621956, 1.5430807, 1, 1.5430807}, }, { - &Cosh{}, + 9, []float32{1, 3, 4, 5}, []int{1, 4}, []float32{1.5430807, 10.067662, 27.308233, 74.209946}, }, { - &Cosh{}, + 9, []float32{-1, -1, -1, -1}, []int{1, 4}, []float32{1.5430807, 1.5430807, 1.5430807, 1.5430807}, @@ -49,7 +49,9 @@ func TestCosh(t *testing.T) { ops.TensorWithBackingFixture(test.backing, test.shape...), } - res, err := test.cosh.Apply(inputs) + cosh := coshVersions[test.version]() + + res, err := cosh.Apply(inputs) assert.Nil(t, err) assert.Nil(t, err) @@ -59,35 +61,40 @@ func TestCosh(t *testing.T) { func TestInputValidationCosh(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 9, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, nil, }, { + 9, []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), }, nil, }, 
{ + 9, []tensor.Tensor{}, - ops.ErrInvalidInputCount(0, &Cosh{}), + ops.ErrInvalidInputCount(0, cosh9BaseOpFixture()), }, { + 9, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputType(0, "int", &Cosh{}), + ops.ErrInvalidInputType(0, "int", cosh9BaseOpFixture()), }, } for _, test := range tests { - cosh := &Cosh{} + cosh := coshVersions[test.version]() validated, err := cosh.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -97,3 +104,7 @@ func TestInputValidationCosh(t *testing.T) { } } } + +func cosh9BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(9, 1, 1, coshTypeConstraints, "cosh") +} diff --git a/ops/cosh/versions.go b/ops/cosh/versions.go new file mode 100644 index 0000000..334b4ea --- /dev/null +++ b/ops/cosh/versions.go @@ -0,0 +1,13 @@ +package cosh + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var coshVersions = ops.OperatorVersions{ + 9: ops.NewOperatorConstructor(newCosh, 9, coshTypeConstraints), +} + +func GetCoshVersions() ops.OperatorVersions { + return coshVersions +} diff --git a/ops/div/div.go b/ops/div/div.go new file mode 100644 index 0000000..90a48a3 --- /dev/null +++ b/ops/div/div.go @@ -0,0 +1,45 @@ +package div + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var divTypeConstraints = [][]tensor.Dtype{ + {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, + {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, +} + +// Div represents the ONNX div operator. +type Div struct { + ops.BaseOperator +} + +// newDiv creates a new div operator. +func newDiv(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Div{ + BaseOperator: ops.NewBaseOperator( + version, + 2, + 2, + typeConstraints, + "div", + ), + } +} + +// Init initializes the div operator. 
+func (d *Div) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the div operator. +func (d *Div) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ApplyBinaryOperation( + inputs[0], + inputs[1], + ops.Div, + ops.MultidirectionalBroadcasting, + ) +} diff --git a/ops/opset13/div_test.go b/ops/div/div_13_test.go similarity index 83% rename from ops/opset13/div_test.go rename to ops/div/div_13_test.go index 06a4f45..7acd68f 100644 --- a/ops/opset13/div_test.go +++ b/ops/div/div_13_test.go @@ -1,4 +1,4 @@ -package opset13 +package div import ( "testing" @@ -19,25 +19,25 @@ func TestDivInit(t *testing.T) { func TestDiv(t *testing.T) { tests := []struct { - div *Div + version int64 shapes [][]int backings [][]float32 expected []float32 }{ { - &Div{}, + 13, [][]int{{2, 2}, {2, 2}}, [][]float32{{10, 10, 10, 10}, {2, 5, 2.5, 1.0}}, []float32{5, 2, 4, 10}, }, { - &Div{}, + 13, [][]int{{2, 2}, {2}}, [][]float32{{1, 1, 1, 1}, {1, 2}}, []float32{1, 0.5, 1, 0.5}, }, { - &Div{}, + 13, [][]int{{2, 2}, {1}}, [][]float32{{1, 1, 1, 1}, {2}}, []float32{0.5, 0.5, 0.5, 0.5}, @@ -49,7 +49,10 @@ func TestDiv(t *testing.T) { ops.TensorWithBackingFixture(test.backings[0], test.shapes[0]...), ops.TensorWithBackingFixture(test.backings[1], test.shapes[1]...), } - res, err := test.div.Apply(inputs) + + div := divVersions[test.version]() + + res, err := div.Apply(inputs) assert.Nil(t, err) assert.Equal(t, test.expected, res[0].Data()) @@ -58,10 +61,12 @@ func TestDiv(t *testing.T) { func TestInputValidationDiv(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]uint32{1, 2}, 2), ops.TensorWithBackingFixture([]uint32{3, 4}, 2), @@ -69,6 +74,7 @@ func TestInputValidationDiv(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]uint64{1, 2}, 2), ops.TensorWithBackingFixture([]uint64{3, 4}, 2), @@ -76,6 
+82,7 @@ func TestInputValidationDiv(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int32{1, 2}, 2), ops.TensorWithBackingFixture([]int32{3, 4}, 2), @@ -83,6 +90,7 @@ func TestInputValidationDiv(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int64{1, 2}, 2), ops.TensorWithBackingFixture([]int64{3, 4}, 2), @@ -90,6 +98,7 @@ func TestInputValidationDiv(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{3, 4}, 2), @@ -97,6 +106,7 @@ func TestInputValidationDiv(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), ops.TensorWithBackingFixture([]float64{3, 4}, 2), @@ -104,22 +114,24 @@ func TestInputValidationDiv(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputCount(1, &Div{}), + ops.ErrInvalidInputCount(1, div13BaseOpFixture()), }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), }, - ops.ErrInvalidInputType(0, "int", &Div{}), + ops.ErrInvalidInputType(0, "int", div13BaseOpFixture()), }, } for _, test := range tests { - div := &Div{} + div := divVersions[test.version]() validated, err := div.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -129,3 +141,7 @@ func TestInputValidationDiv(t *testing.T) { } } } + +func div13BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(13, 2, 2, divTypeConstraints, "div") +} diff --git a/ops/div/versions.go b/ops/div/versions.go new file mode 100644 index 0000000..20b814e --- /dev/null +++ b/ops/div/versions.go @@ -0,0 +1,14 @@ +package div + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var divVersions = ops.OperatorVersions{ + 7: ops.NewOperatorConstructor(newDiv, 7, divTypeConstraints), + 13: ops.NewOperatorConstructor(newDiv, 13, divTypeConstraints), 
+} + +func GetDivVersions() ops.OperatorVersions { + return divVersions +} diff --git a/ops/equal/equal.go b/ops/equal/equal.go new file mode 100644 index 0000000..0e85255 --- /dev/null +++ b/ops/equal/equal.go @@ -0,0 +1,47 @@ +package equal + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var equal7TypeConstraints = [][]tensor.Dtype{ + {tensor.Bool, tensor.Int32, tensor.Int64}, + {tensor.Bool, tensor.Int32, tensor.Int64}, +} + +var equalTypeConstraints = [][]tensor.Dtype{ops.AllTypes, ops.AllTypes} + +// Equal represents the ONNX equal operator. +type Equal struct { + ops.BaseOperator +} + +// newEqual creates a new equal operator. +func newEqual(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Equal{ + BaseOperator: ops.NewBaseOperator( + version, + 2, + 2, + typeConstraints, + "equal", + ), + } +} + +// Init initializes the equal operator. +func (e *Equal) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the equal operator. 
+func (e *Equal) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ApplyBinaryOperation( + inputs[0], + inputs[1], + ops.Equal, + ops.MultidirectionalBroadcasting, + ) +} diff --git a/ops/opset13/equal_test.go b/ops/equal/equal_test.go similarity index 74% rename from ops/opset13/equal_test.go rename to ops/equal/equal_test.go index 9014e78..3271c67 100644 --- a/ops/opset13/equal_test.go +++ b/ops/equal/equal_test.go @@ -1,4 +1,4 @@ -package opset13 +package equal import ( "testing" @@ -19,25 +19,25 @@ func TestEqualInit(t *testing.T) { func TestEqual(t *testing.T) { tests := []struct { - equal *Equal + version int64 backings [][]float32 shapes [][]int expected []bool }{ { - &Equal{}, + 13, [][]float32{{0, 1, 2, 3}, {1, 1, 1, 1}}, [][]int{{2, 2}, {2, 2}}, []bool{false, true, false, false}, }, { - &Equal{}, + 13, [][]float32{{0, 1, 2, 2, 4, 5}, {2, 2, 2, 2, 2, 2}}, [][]int{{3, 2}, {3, 2}}, []bool{false, false, true, true, false, false}, }, { - &Equal{}, + 13, [][]float32{{0, 1}, {0, 1, 0, 1}}, [][]int{{2}, {2, 2}}, []bool{true, true, true, true}, @@ -50,7 +50,9 @@ func TestEqual(t *testing.T) { ops.TensorWithBackingFixture(test.backings[1], test.shapes[1]...), } - res, err := test.equal.Apply(inputs) + equal := equalVersions[test.version]() + + res, err := equal.Apply(inputs) assert.Nil(t, err) assert.Nil(t, err) @@ -60,10 +62,12 @@ func TestEqual(t *testing.T) { func TestInputValidationEqual(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]uint32{1, 2}, 2), ops.TensorWithBackingFixture([]uint32{3, 4}, 2), @@ -71,6 +75,7 @@ func TestInputValidationEqual(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]uint64{1, 2}, 2), ops.TensorWithBackingFixture([]uint64{3, 4}, 2), @@ -78,6 +83,7 @@ func TestInputValidationEqual(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ 
ops.TensorWithBackingFixture([]int32{1, 2}, 2), ops.TensorWithBackingFixture([]int32{3, 4}, 2), @@ -85,6 +91,7 @@ func TestInputValidationEqual(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int64{1, 2}, 2), ops.TensorWithBackingFixture([]int64{3, 4}, 2), @@ -92,6 +99,7 @@ func TestInputValidationEqual(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{3, 4}, 2), @@ -99,6 +107,7 @@ func TestInputValidationEqual(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), ops.TensorWithBackingFixture([]float64{3, 4}, 2), @@ -106,22 +115,33 @@ func TestInputValidationEqual(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputCount(1, &Equal{}), + ops.ErrInvalidInputCount(1, equal13BaseOpFixture()), + }, + { + 7, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{3, 4}, 2), + }, + ops.ErrInvalidInputType(0, "float32", equal7BaseOpFixture()), }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), }, - ops.ErrInvalidInputType(0, "int", &Equal{}), + ops.ErrInvalidInputType(0, "int", equal13BaseOpFixture()), }, } for _, test := range tests { - equal := &Equal{} + equal := equalVersions[test.version]() + validated, err := equal.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -131,3 +151,11 @@ func TestInputValidationEqual(t *testing.T) { } } } + +func equal7BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(7, 2, 2, equal7TypeConstraints, "equal") +} + +func equal13BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(13, 2, 2, equalTypeConstraints, "equal") +} diff --git a/ops/equal/versions.go b/ops/equal/versions.go new file mode 100644 index 0000000..79de43c --- /dev/null +++ 
b/ops/equal/versions.go @@ -0,0 +1,15 @@ +package equal + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var equalVersions = ops.OperatorVersions{ + 7: ops.NewOperatorConstructor(newEqual, 7, equal7TypeConstraints), + 11: ops.NewOperatorConstructor(newEqual, 11, equalTypeConstraints), + 13: ops.NewOperatorConstructor(newEqual, 13, equalTypeConstraints), +} + +func GetEqualVersions() ops.OperatorVersions { + return equalVersions +} diff --git a/ops/errors.go b/ops/errors.go index 282ace6..0518d6f 100644 --- a/ops/errors.go +++ b/ops/errors.go @@ -98,7 +98,7 @@ const ( type InputError struct { kind InputErrorKind - operator Operator + Operator BaseOperator reason string // Attributes for input type error. @@ -116,61 +116,61 @@ type InputError struct { func (i *InputError) Error() string { switch i.kind { case InputErrorType: - return fmt.Sprintf("input %d for op %v does not allow dtype %v", i.inputNumber, i.operator, i.actualType) + return fmt.Sprintf("input %d for op %v does not allow dtype %v", i.inputNumber, i.Operator, i.actualType) case InputErrorCount: if i.hasOptionalInputs { - return fmt.Sprintf(InvalidOptionalInputCountErrTemplate, i.operator, i.operator.GetMinInputs(), i.operator.GetMaxInputs(), i.actualCount) + return fmt.Sprintf(InvalidOptionalInputCountErrTemplate, i.Operator, i.Operator.GetMinInputs(), i.Operator.GetMaxInputs(), i.actualCount) } - return fmt.Sprintf(InvalidInputCountErrTemplate, i.operator, i.operator.GetMinInputs(), i.actualCount) + return fmt.Sprintf(InvalidInputCountErrTemplate, i.Operator, i.Operator.GetMinInputs(), i.actualCount) case InputErrorUnsupported: - return fmt.Sprintf(UnsupportedInputErrTemplate, i.operator, i.inputName) + return fmt.Sprintf(UnsupportedInputErrTemplate, i.Operator, i.inputName) case InputErrorInvalid: - return fmt.Sprintf(InvalidInputErrTemplate, i.operator, i.reason) + return fmt.Sprintf(InvalidInputErrTemplate, i.Operator, i.reason) default: - return fmt.Sprintf("%s unknown 
error input error kind %s", i.operator.String(), i.kind) + return fmt.Sprintf("%s unknown error input error kind %s", i.Operator.String(), i.kind) } } -func ErrInvalidInputType(inputNumber int, dType string, operator Operator) error { +func ErrInvalidInputType(inputNumber int, dType string, operator BaseOperator) error { return &InputError{ kind: InputErrorType, - operator: operator, + Operator: operator, inputNumber: inputNumber, actualType: dType, } } -func ErrInvalidInputCount(actual int, operator Operator) error { +func ErrInvalidInputCount(actual int, operator BaseOperator) error { return &InputError{ kind: InputErrorCount, actualCount: actual, - operator: operator, + Operator: operator, } } -func ErrInvalidOptionalInputCount(actual int, operator Operator) error { +func ErrInvalidOptionalInputCount(actual int, operator BaseOperator) error { return &InputError{ kind: InputErrorCount, hasOptionalInputs: true, actualCount: actual, - operator: operator, + Operator: operator, } } -func ErrUnsupportedInput(inputName string, operator Operator) error { +func ErrUnsupportedInput(inputName string, operator BaseOperator) error { return &InputError{ kind: InputErrorUnsupported, inputName: inputName, - operator: operator, + Operator: operator, } } -func ErrInvalidInput(reason string, operator Operator) error { +func ErrInvalidInput(reason string, operator BaseOperator) error { return &InputError{ kind: InputErrorInvalid, reason: reason, - operator: operator, + Operator: operator, } } @@ -221,14 +221,20 @@ func ErrUnknownOperatorType(operatorType string) error { return fmt.Errorf("%w: %s", ErrUnsupportedOperator, operatorType) } +var ErrUnsupportedOperatorVersion = errors.New("unsupported opset operator version") + +func ErrUnsupportedOperatorVersionType(opsetID int64, operatorType string) error { + return fmt.Errorf("%w: opset %d for operator %s", ErrUnsupportedOperator, opsetID, operatorType) +} + var ErrAxisNotInRange = errors.New("axis out of range") -func 
ErrNotAllAxesInRange(min, max int) error { - return fmt.Errorf("%w: all indices entries must be in the range -%d <= x < %d", ErrAxisNotInRange, min, max) +func ErrNotAllAxesInRange(minVal, maxVal int) error { + return fmt.Errorf("%w: all indices entries must be in the range -%d <= x < %d", ErrAxisNotInRange, minVal, maxVal) } -func ErrAxisOutOfRange(min, max, actual int) error { - return fmt.Errorf("%w: axis argument must be in the range -%d <= x < %d, was %d", ErrAxisNotInRange, min, max, actual) +func ErrAxisOutOfRange(minVal, maxVal, actual int) error { + return fmt.Errorf("%w: axis argument must be in the range -%d <= x < %d, was %d", ErrAxisNotInRange, minVal, maxVal, actual) } var ErrUnsupportedOpsetVersion = errors.New("unsupported opset version") diff --git a/ops/opset13/expand.go b/ops/expand/expand.go similarity index 50% rename from ops/opset13/expand.go rename to ops/expand/expand.go index f84fb3a..c7b4a11 100644 --- a/ops/opset13/expand.go +++ b/ops/expand/expand.go @@ -1,4 +1,4 @@ -package opset13 +package expand import ( "github.com/advancedclimatesystems/gonnx/onnx" @@ -6,17 +6,24 @@ import ( "gorgonia.org/tensor" ) -const ( - MinExpandInputs = 2 - MaxExpandInputs = 2 -) +var expandTypeConstraints = [][]tensor.Dtype{ops.AllTypes, {tensor.Int64}} // Expand represents the ONNX expand operator. -type Expand struct{} +type Expand struct { + ops.BaseOperator +} // newExpand creates a new expand operator. -func newExpand() ops.Operator { - return &Expand{} +func newExpand(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Expand{ + BaseOperator: ops.NewBaseOperator( + version, + 2, + 2, + typeConstraints, + "expand", + ), + } } // Init initializes the expand operator. @@ -53,29 +60,3 @@ func (f *Expand) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return []tensor.Tensor{input}, nil } - -// ValidateInputs validates the inputs that will be given to Apply for this operator. 
-func (f *Expand) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(f, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (f *Expand) GetMinInputs() int { - return MinExpandInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (f *Expand) GetMaxInputs() int { - return MaxExpandInputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (f *Expand) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ops.AllTypes, {tensor.Int64}} -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (f *Expand) String() string { - return "expand operator" -} diff --git a/ops/opset13/expand_test.go b/ops/expand/expand_test.go similarity index 82% rename from ops/opset13/expand_test.go rename to ops/expand/expand_test.go index 325d200..71cb76e 100644 --- a/ops/opset13/expand_test.go +++ b/ops/expand/expand_test.go @@ -1,4 +1,4 @@ -package opset13 +package expand import ( "testing" @@ -17,7 +17,7 @@ func TestExpandInit(t *testing.T) { func TestExpand(t *testing.T) { tests := []struct { - expand *Expand + version int64 backing []float32 shape []int newShapeBacking []int64 @@ -25,7 +25,7 @@ func TestExpand(t *testing.T) { expectedData []float32 }{ { - &Expand{}, + 13, []float32{0, 1, 2, 3}, []int{2, 2}, []int64{1, 1, 1}, @@ -33,7 +33,7 @@ func TestExpand(t *testing.T) { []float32{0, 1, 2, 3}, }, { - &Expand{}, + 13, []float32{0, 1, 2, 3}, []int{2, 2}, []int64{1, 3, 1, 1}, @@ -48,7 +48,9 @@ func TestExpand(t *testing.T) { ops.TensorWithBackingFixture(test.newShapeBacking, len(test.newShapeBacking)), } - res, err := test.expand.Apply(inputs) + expand := expandVersions[test.version]() + + res, err := expand.Apply(inputs) assert.Nil(t, err) assert.Equal(t, test.expectedShape, 
res[0].Shape()) @@ -58,10 +60,12 @@ func TestExpand(t *testing.T) { func TestInputValidationExpand(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]uint32{1, 2}, 2), ops.TensorWithBackingFixture([]int64{1, 1, 1}, 3), @@ -69,6 +73,7 @@ func TestInputValidationExpand(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]uint64{1, 2}, 2), ops.TensorWithBackingFixture([]int64{1, 1, 1}, 3), @@ -76,6 +81,7 @@ func TestInputValidationExpand(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int32{1, 2}, 2), ops.TensorWithBackingFixture([]int64{1, 1, 1}, 3), @@ -83,6 +89,7 @@ func TestInputValidationExpand(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int64{1, 2}, 2), ops.TensorWithBackingFixture([]int64{1, 1, 1}, 3), @@ -90,6 +97,7 @@ func TestInputValidationExpand(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]int64{1, 1, 1}, 3), @@ -97,6 +105,7 @@ func TestInputValidationExpand(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), ops.TensorWithBackingFixture([]int64{1, 1, 1}, 3), @@ -104,24 +113,26 @@ func TestInputValidationExpand(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]int64{1, 1, 1}, 3), ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, - ops.ErrInvalidInputCount(3, &Expand{}), + ops.ErrInvalidInputCount(3, expand13BaseOpFixture()), }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int64{1, 1, 1}, 3), }, - ops.ErrInvalidInputType(0, "int", &Expand{}), + ops.ErrInvalidInputType(0, "int", expand13BaseOpFixture()), }, } for _, test := range tests { - expand := &Expand{} + 
expand := expandVersions[test.version]() validated, err := expand.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -131,3 +142,7 @@ func TestInputValidationExpand(t *testing.T) { } } } + +func expand13BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(13, 2, 2, expandTypeConstraints, "expand") +} diff --git a/ops/expand/versions.go b/ops/expand/versions.go new file mode 100644 index 0000000..ed20a18 --- /dev/null +++ b/ops/expand/versions.go @@ -0,0 +1,13 @@ +package expand + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var expandVersions = ops.OperatorVersions{ + 13: ops.NewOperatorConstructor(newExpand, 13, expandTypeConstraints), +} + +func GetExpandVersions() ops.OperatorVersions { + return expandVersions +} diff --git a/ops/flatten/constants.go b/ops/flatten/constants.go new file mode 100644 index 0000000..59e0ec8 --- /dev/null +++ b/ops/flatten/constants.go @@ -0,0 +1,3 @@ +package flatten + +const axis = "axis" diff --git a/ops/flatten/flatten.go b/ops/flatten/flatten.go new file mode 100644 index 0000000..b712e55 --- /dev/null +++ b/ops/flatten/flatten.go @@ -0,0 +1,69 @@ +package flatten + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +// Flatten provides common functionality for all Flatten versions. +type Flatten struct { + ops.BaseOperator + axis int +} + +func newFlatten(version int, typeConstraint [][]tensor.Dtype) ops.Operator { + return &Flatten{ + BaseOperator: ops.NewBaseOperator( + version, + 1, + 1, + typeConstraint, + "flatten", + ), + } +} + +// Init initializes the flatten operator. +func (f *Flatten) Init(n *onnx.NodeProto) error { + for _, attr := range n.GetAttribute() { + switch attr.GetName() { + case axis: + f.axis = int(attr.GetI()) + default: + return ops.ErrInvalidAttribute(attr.GetName(), f) + } + } + + return nil +} + +// Apply applies the flatten operator. 
+func (f *Flatten) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + inputShape := inputs[0].Shape() + rank := len(inputShape) + + axis := f.axis + if axis < 0 { + axis = rank + axis + } + + out, ok := inputs[0].Clone().(tensor.Tensor) + if !ok { + return nil, ops.ErrTypeAssert("tensor.Tensor", inputs[0].Clone()) + } + + var err error + // Handle the special case where axis is 0. + if axis == 0 { + err = out.Reshape(1, ops.NElements(inputShape...)) + } else { + err = out.Reshape(ops.NElements(inputShape[:axis]...), ops.NElements(inputShape[axis:]...)) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} diff --git a/ops/flatten/flatten_test.go b/ops/flatten/flatten_test.go new file mode 100644 index 0000000..9ba478e --- /dev/null +++ b/ops/flatten/flatten_test.go @@ -0,0 +1,342 @@ +package flatten + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestFlattenInit(t *testing.T) { + f := &Flatten{axis: 1} + + err := f.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "axis", I: 2}}}) + assert.Nil(t, err) + + assert.Equal(t, 2, f.axis) +} + +func TestFlatten(t *testing.T) { + tests := []struct { + flatten *Flatten + backing []float32 + shape []int + expectedShape tensor.Shape + }{ + { + &Flatten{}, + []float32{0, 1, 2, 3}, + []int{2, 2}, + []int{1, 4}, + }, + { + &Flatten{}, + []float32{0, 1, 2, 3, 4, 5}, + []int{2, 3}, + []int{1, 6}, + }, + { + &Flatten{axis: 1}, + []float32{0, 1, 2, 3, 4, 5, 6, 7}, + []int{2, 2, 2}, + []int{2, 4}, + }, + { + &Flatten{axis: 2}, + []float32{0, 1, 2, 3, 4, 5, 6, 7}, + []int{2, 2, 2}, + []int{4, 2}, + }, + { + &Flatten{axis: -1}, + []float32{0, 1, 2, 3, 4, 5, 6, 7}, + []int{2, 2, 2}, + []int{4, 2}, + }, + { + &Flatten{axis: -2}, + []float32{0, 1, 2, 3, 4, 5, 6, 7}, + []int{2, 2, 2}, + []int{2, 4}, + }, + { + 
&Flatten{axis: -3}, + []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, + []int{3, 2, 3}, + []int{1, 18}, + }, + { + &Flatten{axis: 2}, + []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, + []int{3, 2, 3}, + []int{6, 3}, + }, + { + &Flatten{axis: 1}, + []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, + []int{3, 2, 3}, + []int{3, 6}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backing, test.shape...), + } + + res, err := test.flatten.Apply(inputs) + assert.Nil(t, err) + + assert.Equal(t, test.expectedShape, res[0].Shape()) + } +} + +func TestInputValidationFlatten(t *testing.T) { + tests := []struct { + inputs []tensor.Tensor + err error + version int64 + }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint32{1, 2}, 2), + }, + nil, + 13, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint64{1, 2}, 2), + }, + nil, + 13, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + }, + nil, + 13, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int64{1, 2}, 2), + }, + nil, + 13, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + 13, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float64{1, 2}, 2), + }, + nil, + 13, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + ops.ErrInvalidInputCount(2, ops.NewBaseOperator(13, 1, 1, [][]tensor.Dtype{ops.AllTypes}, "flatten")), + 13, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int", ops.NewBaseOperator(13, 1, 1, [][]tensor.Dtype{ops.AllTypes}, "flatten")), + 13, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint32{1, 2}, 2), + }, + nil, + 9, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint64{1, 2}, 2), + 
}, + nil, + 9, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + }, + nil, + 9, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int64{1, 2}, 2), + }, + nil, + 9, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + 9, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float64{1, 2}, 2), + }, + nil, + 9, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + ops.ErrInvalidInputCount(2, ops.NewBaseOperator(9, 1, 1, [][]tensor.Dtype{ops.AllTypes}, "flatten")), + 9, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int", ops.NewBaseOperator(9, 1, 1, [][]tensor.Dtype{ops.AllTypes}, "flatten")), + 9, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint32{1, 2}, 2), + }, + nil, + 11, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint64{1, 2}, 2), + }, + nil, + 11, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + }, + nil, + 11, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int64{1, 2}, 2), + }, + nil, + 11, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + 11, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float64{1, 2}, 2), + }, + nil, + 11, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + ops.ErrInvalidInputCount(2, ops.NewBaseOperator(11, 1, 1, [][]tensor.Dtype{ops.AllTypes}, "flatten")), + 11, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int", ops.NewBaseOperator(11, 1, 1, [][]tensor.Dtype{ops.AllTypes}, "flatten")), + 11, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + 1, + }, + { + []tensor.Tensor{ + 
ops.TensorWithBackingFixture([]float64{1, 2}, 2), + }, + nil, + 1, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint32{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "uint32", ops.NewBaseOperator(1, 1, 1, [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}, "flatten")), + 1, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint64{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "uint64", ops.NewBaseOperator(1, 1, 1, [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}, "flatten")), + 1, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int32", ops.NewBaseOperator(1, 1, 1, [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}, "flatten")), + 1, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int64{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int64", ops.NewBaseOperator(1, 1, 1, [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}, "flatten")), + 1, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + ops.ErrInvalidInputCount(2, ops.NewBaseOperator(1, 1, 1, [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}, "flatten")), + 1, + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int", ops.NewBaseOperator(1, 1, 1, [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}, "flatten")), + 1, + }, + } + + for _, test := range tests { + flatten := flattenVersions[test.version]() + validated, err := flatten.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + assert.Equal(t, test.inputs, validated) + } + } +} diff --git a/ops/flatten/versions.go b/ops/flatten/versions.go new file mode 100644 index 0000000..9efd57f --- /dev/null +++ b/ops/flatten/versions.go @@ -0,0 +1,17 @@ +package flatten + +import ( + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var flattenVersions = 
ops.OperatorVersions{ + 1: ops.NewOperatorConstructor(newFlatten, 1, [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}), + 9: ops.NewOperatorConstructor(newFlatten, 9, [][]tensor.Dtype{ops.AllTypes}), + 11: ops.NewOperatorConstructor(newFlatten, 11, [][]tensor.Dtype{ops.AllTypes}), + 13: ops.NewOperatorConstructor(newFlatten, 13, [][]tensor.Dtype{ops.AllTypes}), +} + +func GetFlattenVersions() ops.OperatorVersions { + return flattenVersions +} diff --git a/ops/gather/constants.go b/ops/gather/constants.go new file mode 100644 index 0000000..8c5beed --- /dev/null +++ b/ops/gather/constants.go @@ -0,0 +1,3 @@ +package gather + +const axis = "axis" diff --git a/ops/opset13/gather.go b/ops/gather/gather.go similarity index 84% rename from ops/opset13/gather.go rename to ops/gather/gather.go index e6e7f3f..3729dd1 100644 --- a/ops/opset13/gather.go +++ b/ops/gather/gather.go @@ -1,4 +1,4 @@ -package opset13 +package gather import ( "github.com/advancedclimatesystems/gonnx/onnx" @@ -6,19 +6,28 @@ import ( "gorgonia.org/tensor" ) -const ( - MinGatherInputs = 2 - MaxGatherInputs = 2 -) +var gatherTypeConstraints = [][]tensor.Dtype{ + ops.AllTypes, + {tensor.Int32, tensor.Int64}, +} // Gather represents the ONNX gather operator. type Gather struct { + ops.BaseOperator + axis int // axis to gather on, default is 0 } // newGather creates a new gather operator. 
-func newGather() ops.Operator { +func newGather(version int, typeConstraints [][]tensor.Dtype) ops.Operator { return &Gather{ + BaseOperator: ops.NewBaseOperator( + version, + 2, + 2, + typeConstraints, + "gather", + ), axis: 0, } } @@ -30,7 +39,7 @@ func (g *Gather) Init(n *onnx.NodeProto) error { if len(attributes) == 1 { attr := attributes[0] - if attr.GetName() == "axis" { + if attr.GetName() == axis { g.axis = int(attr.GetI()) } else { return ops.ErrInvalidAttribute(attr.GetName(), g) @@ -91,35 +100,6 @@ func (g *Gather) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return []tensor.Tensor{output}, nil } -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (g *Gather) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(g, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (g *Gather) GetMinInputs() int { - return MinGatherInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (g *Gather) GetMaxInputs() int { - return MaxGatherInputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (g *Gather) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ - ops.AllTypes, - {tensor.Int32, tensor.Int64}, - } -} - -// String implements the stringer interface, and can be used to format errors or messages. 
-func (g *Gather) String() string { - return "gather operator" -} - // Perform gather according to the definition given by ONNX : // -------------------------- // For axis = 0 : diff --git a/ops/opset13/gather_test.go b/ops/gather/gather_test.go similarity index 81% rename from ops/opset13/gather_test.go rename to ops/gather/gather_test.go index e48925a..d8d4bdb 100644 --- a/ops/opset13/gather_test.go +++ b/ops/gather/gather_test.go @@ -1,10 +1,11 @@ -package opset13 +package gather import ( "testing" "github.com/advancedclimatesystems/gonnx/onnx" "github.com/advancedclimatesystems/gonnx/ops" + "github.com/advancedclimatesystems/gonnx/ops/concat" "github.com/stretchr/testify/assert" "gorgonia.org/tensor" ) @@ -31,25 +32,27 @@ func TestGatherInitDefault(t *testing.T) { } func TestGatherInitTooManyAttrs(t *testing.T) { - op := Gather{} + op := Gather{BaseOperator: ops.NewBaseOperator(13, 2, 2, gatherTypeConstraints, "gather")} err := op.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "axis"}, {Name: "default"}}}) - assert.EqualError(t, err, "gather operator attribute error: invalid count 2 expected 1") + assert.EqualError(t, err, "gather v13 attribute error: invalid count 2 expected 1") } func TestGatherInitInvalidAttrName(t *testing.T) { - op := Gather{} + op := Gather{BaseOperator: ops.NewBaseOperator(13, 2, 2, gatherTypeConstraints, "gather")} err := op.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "axes"}}}) // should be axis - assert.EqualError(t, err, "gather operator attribute error: invalid attribute axes") + assert.EqualError(t, err, "gather v13 attribute error: invalid attribute axes") } func TestGather(t *testing.T) { tests := []struct { + version int64 + data interface{} shape []int indices interface{} indShape []int - axis int + node *onnx.NodeProto expected interface{} expectedShape tensor.Shape @@ -65,128 +68,142 @@ func TestGather(t *testing.T) { // Out: (1, 2) { + 13, []float32{1, 2, 3, 
4}, []int{4}, []int64{0}, []int{1}, - 0, + makeAxisProto(0), []float32{1}, tensor.Shape([]int{1}), }, { + 13, []float32{1, 2, 3, 4}, []int{2, 2}, []int64{0}, []int{1}, - 0, + makeAxisProto(0), []float32{1, 2}, tensor.Shape([]int{1, 2}), }, { + 13, []float32{1, 2, 3, 4}, []int{2, 2}, []int64{0}, []int{1}, - 1, + makeAxisProto(1), []float32{1, 3}, tensor.Shape([]int{2, 1}), }, { + 13, []float32{1, 2, 3, 4}, []int{2, 2}, []int64{0}, []int{1}, - -1, + makeAxisProto(-1), []float32{1, 3}, tensor.Shape([]int{2, 1}), }, { + 13, []float32{1, 2, 3, 4}, []int{2, 2}, []int64{1}, []int{1}, - 1, + makeAxisProto(1), []float32{2, 4}, tensor.Shape([]int{2, 1}), }, { + 13, []float32{1, 2, 3, 4}, []int{2, 2}, []int64{0}, []int{1, 1}, - 1, + makeAxisProto(1), []float32{1, 3}, tensor.Shape([]int{2, 1, 1}), }, { + 13, []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, []int{3, 2, 2}, []int64{0}, []int{1}, - 2, + makeAxisProto(2), []float32{1, 3, 5, 7, 9, 11}, tensor.Shape([]int{3, 2, 1}), }, { + 13, []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, []int{3, 2, 2}, []int64{0}, []int{1}, - 1, + makeAxisProto(1), []float32{1, 2, 5, 6, 9, 10}, tensor.Shape([]int{3, 1, 2}), }, { + 13, []float32{1, 2, 3, 4, 5, 6, 7, 8, 9}, []int{3, 3}, []int64{0, 2}, []int{1, 2}, - 1, + makeAxisProto(1), []float32{1, 3, 4, 6, 7, 9}, tensor.Shape([]int{3, 1, 2}), }, { + 13, []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, []int{3, 2, 2}, []int64{-2}, []int{1}, - 1, + makeAxisProto(1), []float32{1, 2, 5, 6, 9, 10}, tensor.Shape([]int{3, 1, 2}), }, { + 13, []float32{1, 2, 3, 4}, []int{4}, []int64{-4}, []int{1}, - 0, + makeAxisProto(0), []float32{1}, tensor.Shape([]int{1}), }, { + 13, []float32{1, 2, 3, 4}, []int{2, 2}, []int64{0}, []int{1}, - -1, + makeAxisProto(-1), []float32{1, 3}, tensor.Shape([]int{2, 1}), }, } for _, test := range tests { - op := &Gather{test.axis} + op := gatherVersions[test.version]() + err := op.Init(test.node) + assert.Nil(t, err) indices := test.indices data := test.data @@ -202,7 
+219,7 @@ func TestGather(t *testing.T) { } func TestCombinedWithOtherOp(t *testing.T) { - concat := &Concat{} + concat := &concat.Concat{} err := concat.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "axis", I: 0}}}) assert.NoError(t, err) @@ -212,7 +229,7 @@ func TestCombinedWithOtherOp(t *testing.T) { data, err := concat.Apply([]tensor.Tensor{data0, data1}) assert.NoError(t, err) - gather := &Gather{0} + gather := gatherVersions[13]() indices := tensor.New(tensor.WithBacking([]int64{1}), tensor.WithShape(1)) res, err := gather.Apply([]tensor.Tensor{data[0], indices}) @@ -221,7 +238,7 @@ func TestCombinedWithOtherOp(t *testing.T) { } func TestScalarInput(t *testing.T) { - op := &Gather{0} + op := gatherVersions[13]() dataIn := tensor.New(tensor.WithBacking([]int64{1}), tensor.WithShape(1)) @@ -247,7 +264,7 @@ func TestGatherAxesIndexOutOfRange(t *testing.T) { } func TestGatherIndexOutOfRange(t *testing.T) { - op := &Gather{0} + op := gatherVersions[13]() dataIn := tensor.New(tensor.WithBacking([]int64{1}), tensor.WithShape(1)) indicesIn := tensor.New(tensor.WithBacking([]int64{2}), tensor.WithShape(1)) @@ -259,10 +276,12 @@ func TestGatherIndexOutOfRange(t *testing.T) { func TestInputValidationGather(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]int32{3, 4}, 2), @@ -270,6 +289,7 @@ func TestInputValidationGather(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]uint32{1, 2}, 2), ops.TensorWithBackingFixture([]int64{3, 4}, 2), @@ -277,20 +297,23 @@ func TestInputValidationGather(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, - ops.ErrInvalidInputCount(1, &Gather{}), + ops.ErrInvalidInputCount(1, gather13BaseOpFixture()), }, { + 13, []tensor.Tensor{ 
ops.TensorWithBackingFixture([]float64{1, 2}, 2), ops.TensorWithBackingFixture([]float32{3, 4}, 2), }, - ops.ErrInvalidInputType(1, "float32", &Gather{}), + ops.ErrInvalidInputType(1, "float32", gather13BaseOpFixture()), }, } for _, test := range tests { - gather := &Gather{} + gather := gatherVersions[test.version]() + validated, err := gather.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -300,3 +323,7 @@ func TestInputValidationGather(t *testing.T) { } } } + +func gather13BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(13, 2, 2, gatherTypeConstraints, "gather") +} diff --git a/ops/gather/versions.go b/ops/gather/versions.go new file mode 100644 index 0000000..671dc15 --- /dev/null +++ b/ops/gather/versions.go @@ -0,0 +1,13 @@ +package gather + +import "github.com/advancedclimatesystems/gonnx/ops" + +var gatherVersions = ops.OperatorVersions{ + 1: ops.NewOperatorConstructor(newGather, 1, gatherTypeConstraints), + 11: ops.NewOperatorConstructor(newGather, 11, gatherTypeConstraints), + 13: ops.NewOperatorConstructor(newGather, 13, gatherTypeConstraints), +} + +func GetGatherVersions() ops.OperatorVersions { + return gatherVersions +} diff --git a/ops/gemm/constants.go b/ops/gemm/constants.go new file mode 100644 index 0000000..009a68a --- /dev/null +++ b/ops/gemm/constants.go @@ -0,0 +1,8 @@ +package gemm + +const ( + alpha = "alpha" + beta = "beta" + transA = "transA" + transB = "transB" +) diff --git a/ops/opset13/gemm.go b/ops/gemm/gemm.go similarity index 58% rename from ops/opset13/gemm.go rename to ops/gemm/gemm.go index 2db2a44..4560d48 100644 --- a/ops/opset13/gemm.go +++ b/ops/gemm/gemm.go @@ -1,4 +1,4 @@ -package opset13 +package gemm import ( "github.com/advancedclimatesystems/gonnx/onnx" @@ -6,6 +6,12 @@ import ( "gorgonia.org/tensor" ) +var gemmTypeConstraints = [][]tensor.Dtype{ + {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, + {tensor.Uint32, tensor.Uint64, 
tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, + {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, +} + const ( MinGemmInputs = 2 MaxGemmInputs = 3 @@ -13,6 +19,8 @@ const ( // Gemm represents the ONNX gemm operator. type Gemm struct { + ops.BaseOperator + alpha float32 beta float32 transA bool @@ -20,8 +28,15 @@ type Gemm struct { } // newGemm creates a new gemm operator and initializes it with the default values. -func newGemm() ops.Operator { +func newGemm(version int, typeConstraints [][]tensor.Dtype) ops.Operator { return &Gemm{ + BaseOperator: ops.NewBaseOperator( + version, + MinGemmInputs, + MaxGemmInputs, + typeConstraints, + "gemm", + ), alpha: 1.0, beta: 1.0, transA: false, @@ -33,13 +48,13 @@ func newGemm() ops.Operator { func (g *Gemm) Init(n *onnx.NodeProto) error { for _, attr := range n.GetAttribute() { switch attr.GetName() { - case "alpha": + case alpha: g.alpha = attr.GetF() - case "beta": + case beta: g.beta = attr.GetF() - case "transA": + case transA: g.transA = ops.Int64ToBool(attr.GetI()) - case "transB": + case transB: g.transB = ops.Int64ToBool(attr.GetI()) default: return ops.ErrInvalidAttribute(attr.GetName(), g) @@ -103,33 +118,3 @@ func (g *Gemm) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return []tensor.Tensor{output}, nil } - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (g *Gemm) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(g, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (g *Gemm) GetMinInputs() int { - return MinGemmInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (g *Gemm) GetMaxInputs() int { - return MaxGemmInputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. 
-func (g *Gemm) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ - {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, - {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, - {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, - } -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (g *Gemm) String() string { - return "gemm operator" -} diff --git a/ops/gemm/gemm_legacy.go b/ops/gemm/gemm_legacy.go new file mode 100644 index 0000000..c5195bf --- /dev/null +++ b/ops/gemm/gemm_legacy.go @@ -0,0 +1,98 @@ +package gemm + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +// Gemm9 represents the ONNX gemm operator, for version <= 9. +type Gemm9 struct { + ops.BaseOperator + + alpha float32 + beta float32 + transA bool + transB bool +} + +// newGemm9 creates a new gemm operator and initializes it with the default values. +func newGemm9(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Gemm9{ + BaseOperator: ops.NewBaseOperator(version, 3, 3, typeConstraints, "gemm"), + alpha: 1.0, + beta: 1.0, + transA: false, + transB: false, + } +} + +// Init initializes the Gemm9 operator based on the ModelProto attributes. +func (g *Gemm9) Init(n *onnx.NodeProto) error { + for _, attr := range n.GetAttribute() { + switch attr.GetName() { + case alpha: + g.alpha = attr.GetF() + case beta: + g.beta = attr.GetF() + case transA: + g.transA = ops.Int64ToBool(attr.GetI()) + case transB: + g.transB = ops.Int64ToBool(attr.GetI()) + default: + return ops.ErrInvalidAttribute(attr.GetName(), g) + } + } + + return nil + } + +// Apply applies the gemm operator on the given graph. 
+func (g *Gemm9) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var err error + + a := inputs[0] + b := inputs[1] + c := inputs[2] + + if g.transA { + a, err = tensor.Transpose(a) + if err != nil { + return nil, err + } + } + + if g.transB { + b, err = tensor.Transpose(b) + if err != nil { + return nil, err + } + } + + x, err := tensor.MatMul(a, b) + if err != nil { + return nil, err + } + + x, err = tensor.Mul(x, g.alpha) + if err != nil { + return nil, err + } + + y, err := tensor.Mul(c, g.beta) + if err != nil { + return nil, err + } + + x, y, err = ops.UnidirectionalBroadcast(x, y) + if err != nil { + return nil, err + } + + output, err := tensor.Add(x, y) + if err != nil { + return nil, err + } + + return []tensor.Tensor{output}, nil +} diff --git a/ops/opset13/gemm_test.go b/ops/gemm/gemm_test.go similarity index 54% rename from ops/opset13/gemm_test.go rename to ops/gemm/gemm_test.go index 37255d4..15bb4a9 100644 --- a/ops/opset13/gemm_test.go +++ b/ops/gemm/gemm_test.go @@ -1,4 +1,4 @@ -package opset13 +package gemm import ( "testing" @@ -30,57 +30,114 @@ func TestGemmInitFail(t *testing.T) { func TestGemm(t *testing.T) { tests := []struct { - gemm *Gemm + version int64 + attrs *onnx.NodeProto shapes [][]int expected []float32 }{ { - &Gemm{1, 1, false, false}, + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "alpha", F: 1.0}, + {Name: "beta", F: 1.0}, + {Name: "transA", I: 0}, + {Name: "transB", I: 0}, + }, + }, [][]int{{3, 2}, {2, 5}, {5}}, []float32{5, 7, 9, 11, 13, 15, 21, 27, 33, 39, 25, 35, 45, 55, 65}, }, { - &Gemm{1, 1, true, false}, + 7, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "alpha", F: 1.0}, + {Name: "beta", F: 1.0}, + {Name: "transA", I: 1}, + {Name: "transB", I: 0}, + }, + }, [][]int{{2, 3}, {2, 5}, {5}}, []float32{15, 19, 23, 27, 31, 20, 26, 32, 38, 44, 25, 33, 41, 49, 57}, }, { - &Gemm{1, 1, true, true}, + 9, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "alpha", F: 
1.0}, + {Name: "beta", F: 1.0}, + {Name: "transA", I: 1}, + {Name: "transB", I: 1}, + }, + }, [][]int{{2, 3}, {5, 2}, {5}}, []float32{3, 10, 17, 24, 31, 4, 15, 26, 37, 48, 5, 20, 35, 50, 65}, }, { - &Gemm{1, 1, false, true}, + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "alpha", F: 1.0}, + {Name: "beta", F: 1.0}, + {Name: "transA", I: 0}, + {Name: "transB", I: 1}, + }, + }, [][]int{{3, 2}, {5, 2}, {5}}, []float32{1, 4, 7, 10, 13, 3, 14, 25, 36, 47, 5, 24, 43, 62, 81}, }, { - &Gemm{1, 1, false, false}, + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "alpha", F: 1.0}, + {Name: "beta", F: 1.0}, + {Name: "transA", I: 0}, + {Name: "transB", I: 0}, + }, + }, [][]int{{1, 2}, {2, 5}, {5}}, []float32{5, 7, 9, 11, 13}, }, { - &Gemm{1, 1, false, false}, + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "alpha", F: 1.0}, + {Name: "beta", F: 1.0}, + {Name: "transA", I: 0}, + {Name: "transB", I: 0}, + }, + }, [][]int{{1, 2}, {2, 5}}, []float32{5, 6, 7, 8, 9}, }, { - &Gemm{1, 1, false, false}, + 7, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "alpha", F: 1.0}, + {Name: "beta", F: 1.0}, + {Name: "transA", I: 0}, + {Name: "transB", I: 0}, + }, + }, [][]int{{20, 4}, {4, 6}, {6}}, []float32{ 84, 91, 98, 105, 112, 119, 228, 251, 274, 297, 320, 343, 372, 411, 450, 489, 528, 567, 516, 571, 626, 681, 736, 791, 660, 731, 802, 873, 944, 1015, 804, - 891, 978, 1065, 1152, 1239, 948, 1051, 1154, 1257, - 1360, 1463, 1092, 1211, 1330, 1449, 1568, 1687, 1236, - 1371, 1506, 1641, 1776, 1911, 1380, 1531, 1682, 1833, - 1984, 2135, 1524, 1691, 1858, 2025, 2192, 2359, 1668, - 1851, 2034, 2217, 2400, 2583, 1812, 2011, 2210, 2409, - 2608, 2807, 1956, 2171, 2386, 2601, 2816, 3031, 2100, - 2331, 2562, 2793, 3024, 3255, 2244, 2491, 2738, 2985, - 3232, 3479, 2388, 2651, 2914, 3177, 3440, 3703, 2532, - 2811, 3090, 3369, 3648, 3927, 2676, 2971, 3266, 3561, - 3856, 4151, 2820, 3131, 3442, 3753, 4064, 4375, + 891, 978, 
1065, 1152, 1239, 948, 1051, 1154, 1257, 1360, + 1463, 1092, 1211, 1330, 1449, 1568, 1687, 1236, 1371, + 1506, 1641, 1776, 1911, 1380, 1531, 1682, 1833, 1984, + 2135, 1524, 1691, 1858, 2025, 2192, 2359, 1668, 1851, + 2034, 2217, 2400, 2583, 1812, 2011, 2210, 2409, 2608, + 2807, 1956, 2171, 2386, 2601, 2816, 3031, 2100, 2331, + 2562, 2793, 3024, 3255, 2244, 2491, 2738, 2985, 3232, + 3479, 2388, 2651, 2914, 3177, 3440, 3703, 2532, 2811, + 3090, 3369, 3648, 3927, 2676, 2971, 3266, 3561, 3856, + 4151, 2820, 3131, 3442, 3753, 4064, 4375, }, }, } @@ -96,7 +153,11 @@ func TestGemm(t *testing.T) { inputs = append(inputs, nil) } - res, err := test.gemm.Apply(inputs) + gemm := gemmVersions[test.version]() + err := gemm.Init(test.attrs) + assert.Nil(t, err) + + res, err := gemm.Apply(inputs) assert.Nil(t, err) assert.Nil(t, err) @@ -106,11 +167,13 @@ func TestGemm(t *testing.T) { func TestInputValidationGemm(t *testing.T) { tests := []struct { + version int64 inputs []tensor.Tensor expected []tensor.Tensor err error }{ { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]uint32{1, 2}, 2), ops.TensorWithBackingFixture([]uint32{3, 4}, 2), @@ -123,6 +186,7 @@ func TestInputValidationGemm(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{3, 4}, 2), @@ -132,11 +196,13 @@ func TestInputValidationGemm(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, nil, - ops.ErrInvalidOptionalInputCount(1, &Gemm{}), + ops.ErrInvalidOptionalInputCount(1, gemm13BaseOpFixture()), }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]uint32{1, 2}, 2), ops.TensorWithBackingFixture([]uint32{3, 4}, 2), @@ -144,20 +210,31 @@ func TestInputValidationGemm(t *testing.T) { ops.TensorWithBackingFixture([]uint32{1, 2}, 2), }, nil, - ops.ErrInvalidOptionalInputCount(4, &Gemm{}), + ops.ErrInvalidOptionalInputCount(4, gemm13BaseOpFixture()), }, { + 7, + []tensor.Tensor{ + 
ops.TensorWithBackingFixture([]uint32{1, 2}, 2), + ops.TensorWithBackingFixture([]uint32{3, 4}, 2), + ops.TensorWithBackingFixture([]uint32{5, 6}, 2), + }, + nil, + ops.ErrInvalidInputType(0, "uint32", gemm7BaseOpFixture()), + }, + { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), }, nil, - ops.ErrInvalidInputType(0, "int", &Gemm{}), + ops.ErrInvalidInputType(0, "int", gemm13BaseOpFixture()), }, } for _, test := range tests { - gemm := &Gemm{} + gemm := gemmVersions[test.version]() validated, err := gemm.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -182,3 +259,21 @@ func GemmOnnxNodeProtoFixture() *onnx.NodeProto { }, } } + +func gemm7BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator( + 7, + 3, + 3, + [][]tensor.Dtype{ + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + }, + "gemm", + ) +} + +func gemm13BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(13, 2, 3, gemmTypeConstraints, "gemm") +} diff --git a/ops/gemm/versions.go b/ops/gemm/versions.go new file mode 100644 index 0000000..c21eea8 --- /dev/null +++ b/ops/gemm/versions.go @@ -0,0 +1,25 @@ +package gemm + +import ( + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var gemmVersions = ops.OperatorVersions{ + 7: ops.NewOperatorConstructor( + newGemm9, + 7, + [][]tensor.Dtype{ + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + }, + ), + 9: ops.NewOperatorConstructor(newGemm9, 9, gemmTypeConstraints), + 11: ops.NewOperatorConstructor(newGemm, 11, gemmTypeConstraints), + 13: ops.NewOperatorConstructor(newGemm, 13, gemmTypeConstraints), +} + +func GetGemmVersions() ops.OperatorVersions { + return gemmVersions +} diff --git a/ops/greater/greater.go b/ops/greater/greater.go new file mode 100644 index 0000000..7f7129d --- /dev/null +++ 
b/ops/greater/greater.go @@ -0,0 +1,44 @@ +package greater + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var greater7TypeConstraints = [][]tensor.Dtype{{tensor.Float32, tensor.Float64}, {tensor.Float32, tensor.Float64}} + +var greaterTypeConstraints = [][]tensor.Dtype{ops.AllTypes, ops.AllTypes} + +// Greater represents the ONNX greater operator. +type Greater struct { + ops.BaseOperator +} + +// newGreater creates a new greater operator. +func newGreater(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Greater{ + BaseOperator: ops.NewBaseOperator( + version, + 2, + 2, + typeConstraints, + "greater", + ), + } +} + +// Init initializes the greater operator. +func (g *Greater) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the greater operator. +func (g *Greater) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ApplyBinaryOperation( + inputs[0], + inputs[1], + ops.Gt, + ops.MultidirectionalBroadcasting, + ) +} diff --git a/ops/opset13/greater_test.go b/ops/greater/greater_test.go similarity index 74% rename from ops/opset13/greater_test.go rename to ops/greater/greater_test.go index 18bc294..886edde 100644 --- a/ops/opset13/greater_test.go +++ b/ops/greater/greater_test.go @@ -1,4 +1,4 @@ -package opset13 +package greater import ( "testing" @@ -19,25 +19,25 @@ func TestGreaterInit(t *testing.T) { func TestGreater(t *testing.T) { tests := []struct { - greater *Greater + version int64 backings [][]float32 shapes [][]int expected []bool }{ { - &Greater{}, + 7, [][]float32{{0, 1, 2, 3}, {1, 1, 1, 1}}, [][]int{{2, 2}, {2, 2}}, []bool{false, false, true, true}, }, { - &Greater{}, + 9, [][]float32{{0, 1, 2, 3, 4, 5}, {2, 2, 2, 2, 2, 2}}, [][]int{{3, 2}, {3, 2}}, []bool{false, false, false, true, true, true}, }, { - &Greater{}, + 13, [][]float32{{0, 1}, {0, 1, 2, 3}}, [][]int{{2}, {2, 2}}, []bool{false, 
false, false, false}, @@ -50,7 +50,9 @@ func TestGreater(t *testing.T) { ops.TensorWithBackingFixture(test.backings[1], test.shapes[1]...), } - res, err := test.greater.Apply(inputs) + greater := greaterVersions[test.version]() + + res, err := greater.Apply(inputs) assert.Nil(t, err) assert.Nil(t, err) @@ -60,10 +62,12 @@ func TestGreater(t *testing.T) { func TestInputValidationGreater(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]uint32{1, 2}, 2), ops.TensorWithBackingFixture([]uint32{3, 4}, 2), @@ -71,6 +75,7 @@ func TestInputValidationGreater(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]uint64{1, 2}, 2), ops.TensorWithBackingFixture([]uint64{3, 4}, 2), @@ -78,6 +83,7 @@ func TestInputValidationGreater(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int32{1, 2}, 2), ops.TensorWithBackingFixture([]int32{3, 4}, 2), @@ -85,6 +91,7 @@ func TestInputValidationGreater(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int64{1, 2}, 2), ops.TensorWithBackingFixture([]int64{3, 4}, 2), @@ -92,6 +99,7 @@ func TestInputValidationGreater(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{3, 4}, 2), @@ -99,6 +107,7 @@ func TestInputValidationGreater(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), ops.TensorWithBackingFixture([]float64{3, 4}, 2), @@ -106,22 +115,32 @@ func TestInputValidationGreater(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputCount(1, &Greater{}), + ops.ErrInvalidInputCount(1, greater13BaseOpFixture()), }, { + 7, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + ops.TensorWithBackingFixture([]int32{3, 
4}, 2), + }, + ops.ErrInvalidInputType(0, "int32", greater7BaseOpFixture()), + }, + { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), }, - ops.ErrInvalidInputType(0, "int", &Greater{}), + ops.ErrInvalidInputType(0, "int", greater13BaseOpFixture()), }, } for _, test := range tests { - greater := &Greater{} + greater := greaterVersions[test.version]() validated, err := greater.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -131,3 +150,11 @@ func TestInputValidationGreater(t *testing.T) { } } } + +func greater7BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(7, 2, 2, greater7TypeConstraints, "greater") +} + +func greater13BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(13, 2, 2, greaterTypeConstraints, "greater") +} diff --git a/ops/greater/versions.go b/ops/greater/versions.go new file mode 100644 index 0000000..1f1194b --- /dev/null +++ b/ops/greater/versions.go @@ -0,0 +1,13 @@ +package greater + +import "github.com/advancedclimatesystems/gonnx/ops" + +var greaterVersions = ops.OperatorVersions{ + 7: ops.NewOperatorConstructor(newGreater, 7, greater7TypeConstraints), + 9: ops.NewOperatorConstructor(newGreater, 9, greaterTypeConstraints), + 13: ops.NewOperatorConstructor(newGreater, 13, greaterTypeConstraints), +} + +func GetGreaterVersions() ops.OperatorVersions { + return greaterVersions +} diff --git a/ops/greaterorequal/greater_or_equal.go b/ops/greaterorequal/greater_or_equal.go new file mode 100644 index 0000000..0515abc --- /dev/null +++ b/ops/greaterorequal/greater_or_equal.go @@ -0,0 +1,42 @@ +package greaterorequal + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var greaterOrEqualTypeConstraints = [][]tensor.Dtype{ops.AllTypes, ops.AllTypes} + +// GreaterOrEqual represents the ONNX greaterOrEqual operator. 
+type GreaterOrEqual struct { + ops.BaseOperator +} + +// newGreaterOrEqual creates a new greaterOrEqual operator. +func newGreaterOrEqual(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &GreaterOrEqual{ + BaseOperator: ops.NewBaseOperator( + version, + 2, + 2, + typeConstraints, + "greaterorequal", + ), + } +} + +// Init initializes the greaterOrEqual operator. +func (g *GreaterOrEqual) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the greaterOrEqual operator. +func (g *GreaterOrEqual) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ApplyBinaryOperation( + inputs[0], + inputs[1], + ops.Gte, + ops.MultidirectionalBroadcasting, + ) +} diff --git a/ops/opset13/greater_or_equal_test.go b/ops/greaterorequal/greater_or_equal_test.go similarity index 78% rename from ops/opset13/greater_or_equal_test.go rename to ops/greaterorequal/greater_or_equal_test.go index 37f5dec..b084988 100644 --- a/ops/opset13/greater_or_equal_test.go +++ b/ops/greaterorequal/greater_or_equal_test.go @@ -1,4 +1,4 @@ -package opset13 +package greaterorequal import ( "testing" @@ -19,25 +19,25 @@ func TestGreaterOrEqualInit(t *testing.T) { func TestGreaterOrEqual(t *testing.T) { tests := []struct { - greaterOrEqual *GreaterOrEqual - backings [][]float32 - shapes [][]int - expected []bool + version int64 + backings [][]float32 + shapes [][]int + expected []bool }{ { - &GreaterOrEqual{}, + 12, [][]float32{{0, 1, 2, 3}, {1, 1, 1, 1}}, [][]int{{2, 2}, {2, 2}}, []bool{false, true, true, true}, }, { - &GreaterOrEqual{}, + 12, [][]float32{{0, 1, 2, 3, 4, 5}, {2, 2, 2, 2, 2, 2}}, [][]int{{3, 2}, {3, 2}}, []bool{false, false, true, true, true, true}, }, { - &GreaterOrEqual{}, + 12, [][]float32{{0, 1}, {0, 1, 2, 3}}, [][]int{{2}, {2, 2}}, []bool{true, true, false, false}, @@ -50,7 +50,9 @@ func TestGreaterOrEqual(t *testing.T) { ops.TensorWithBackingFixture(test.backings[1], test.shapes[1]...), } - res, err := test.greaterOrEqual.Apply(inputs) 
+ greaterOrEqual := greaterOrEqualVersions[test.version]() + + res, err := greaterOrEqual.Apply(inputs) assert.Nil(t, err) assert.Nil(t, err) @@ -60,10 +62,12 @@ func TestGreaterOrEqual(t *testing.T) { func TestInputValidationGreaterOrEqual(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 12, []tensor.Tensor{ ops.TensorWithBackingFixture([]uint32{1, 2}, 2), ops.TensorWithBackingFixture([]uint32{3, 4}, 2), @@ -71,6 +75,7 @@ func TestInputValidationGreaterOrEqual(t *testing.T) { nil, }, { + 12, []tensor.Tensor{ ops.TensorWithBackingFixture([]uint64{1, 2}, 2), ops.TensorWithBackingFixture([]uint64{3, 4}, 2), @@ -78,6 +83,7 @@ func TestInputValidationGreaterOrEqual(t *testing.T) { nil, }, { + 12, []tensor.Tensor{ ops.TensorWithBackingFixture([]int32{1, 2}, 2), ops.TensorWithBackingFixture([]int32{3, 4}, 2), @@ -85,6 +91,7 @@ func TestInputValidationGreaterOrEqual(t *testing.T) { nil, }, { + 12, []tensor.Tensor{ ops.TensorWithBackingFixture([]int64{1, 2}, 2), ops.TensorWithBackingFixture([]int64{3, 4}, 2), @@ -92,6 +99,7 @@ func TestInputValidationGreaterOrEqual(t *testing.T) { nil, }, { + 12, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{3, 4}, 2), @@ -99,6 +107,7 @@ func TestInputValidationGreaterOrEqual(t *testing.T) { nil, }, { + 12, []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), ops.TensorWithBackingFixture([]float64{3, 4}, 2), @@ -106,22 +115,24 @@ func TestInputValidationGreaterOrEqual(t *testing.T) { nil, }, { + 12, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputCount(1, &GreaterOrEqual{}), + ops.ErrInvalidInputCount(1, greaterOrEqual12BaseOpFixture()), }, { + 12, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), }, - ops.ErrInvalidInputType(0, "int", &GreaterOrEqual{}), + ops.ErrInvalidInputType(0, 
"int", greaterOrEqual12BaseOpFixture()), }, } for _, test := range tests { - greaterOrEqual := &GreaterOrEqual{} + greaterOrEqual := greaterOrEqualVersions[test.version]() validated, err := greaterOrEqual.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -131,3 +142,7 @@ func TestInputValidationGreaterOrEqual(t *testing.T) { } } } + +func greaterOrEqual12BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(12, 2, 2, greaterOrEqualTypeConstraints, "greaterorequal") +} diff --git a/ops/greaterorequal/versions.go b/ops/greaterorequal/versions.go new file mode 100644 index 0000000..35085ca --- /dev/null +++ b/ops/greaterorequal/versions.go @@ -0,0 +1,11 @@ +package greaterorequal + +import "github.com/advancedclimatesystems/gonnx/ops" + +var greaterOrEqualVersions = ops.OperatorVersions{ + 12: ops.NewOperatorConstructor(newGreaterOrEqual, 12, greaterOrEqualTypeConstraints), +} + +func GetGreaterOrEqualVersions() ops.OperatorVersions { + return greaterOrEqualVersions +} diff --git a/ops/opset13/gru.go b/ops/gru/gru.go similarity index 84% rename from ops/opset13/gru.go rename to ops/gru/gru.go index e5f3ff6..01e3440 100644 --- a/ops/opset13/gru.go +++ b/ops/gru/gru.go @@ -1,11 +1,21 @@ -package opset13 +package gru import ( "github.com/advancedclimatesystems/gonnx/onnx" "github.com/advancedclimatesystems/gonnx/ops" + "github.com/advancedclimatesystems/gonnx/ops/gemm" "gorgonia.org/tensor" ) +var gruTypeConstraints = [][]tensor.Dtype{ + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Int32}, + {tensor.Float32, tensor.Float64}, +} + const ( MinGRUInputs = 3 MaxGRUInputs = 6 @@ -14,6 +24,8 @@ const ( // GRU represents the ONNX gru operator. It only supports a simple forward gru // operation with default activations. 
type GRU struct { + ops.BaseOperator + activationAlpha []float32 activationBeta []float32 activations []string @@ -23,8 +35,15 @@ type GRU struct { } // newGRU creates a new gru operator. -func newGRU() ops.Operator { +func newGRU(version int, typeConstraints [][]tensor.Dtype) ops.Operator { return &GRU{ + BaseOperator: ops.NewBaseOperator( + version, + MinGRUInputs, + MaxGRUInputs, + typeConstraints, + "gru", + ), activations: []string{"sigmoid", "tanh"}, direction: ops.Forward, linearBeforeReset: false, @@ -71,7 +90,7 @@ func (g *GRU) Init(n *onnx.NodeProto) error { // Apply applies the gru operator. func (g *GRU) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { if inputs[4] != nil { - return nil, ops.ErrUnsupportedInput("sequence lens", g) + return nil, ops.ErrUnsupportedInput("sequence lens", g.BaseOperator) } X := inputs[0] @@ -185,39 +204,6 @@ func (g *GRU) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return []tensor.Tensor{Y, Yh}, nil } -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (g *GRU) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(g, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (g *GRU) GetMinInputs() int { - return MinGRUInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (g *GRU) GetMaxInputs() int { - return MaxGRUInputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. 
-func (g *GRU) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ - {tensor.Float32, tensor.Float64}, - {tensor.Float32, tensor.Float64}, - {tensor.Float32, tensor.Float64}, - {tensor.Float32, tensor.Float64}, - {tensor.Int32}, - {tensor.Float32, tensor.Float64}, - } -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (g *GRU) String() string { - return "gru operator" -} - // extractXt extracts the value of x for timestep t. func (g *GRU) extractXt(X tensor.Tensor, t int) (tensor.Tensor, error) { return X.Slice(ops.NewSlicer(t, t+1), nil, nil) @@ -226,7 +212,21 @@ func (g *GRU) extractXt(X tensor.Tensor, t int) (tensor.Tensor, error) { func (g *GRU) gateCalculation( Xt, H, W, R, Wb, Rb tensor.Tensor, activation ops.Activation, ) (tensor.Tensor, error) { - gemm := &Gemm{transA: false, transB: true, alpha: 1.0, beta: 1.0} + gemm := gemm.GetGemmVersions()[13]() + + err := gemm.Init( + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "alpha", F: 1.0}, + {Name: "beta", F: 1.0}, + {Name: "transA", I: 0}, + {Name: "transB", I: 1}, + }, + }, + ) + if err != nil { + return nil, err + } inputCalc, err := gemm.Apply([]tensor.Tensor{Xt, W, Wb}) if err != nil { @@ -258,7 +258,21 @@ func (g *GRU) htCalculation( return g.gateCalculation(Xt, temp1, W, R, Wb, Rb, activation) } - gemm := &Gemm{transA: false, transB: true, alpha: 1.0, beta: 1.0} + gemm := gemm.GetGemmVersions()[13]() + + err := gemm.Init( + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "alpha", F: 1.0}, + {Name: "beta", F: 1.0}, + {Name: "transA", I: 0}, + {Name: "transB", I: 1}, + }, + }, + ) + if err != nil { + return nil, err + } inputCalc, err := gemm.Apply([]tensor.Tensor{Xt, W, Wb}) if err != nil { diff --git a/ops/opset13/gru_test.go b/ops/gru/gru_test.go similarity index 73% rename from ops/opset13/gru_test.go rename to ops/gru/gru_test.go index 44140f9..789f509 100644 --- a/ops/opset13/gru_test.go +++ 
b/ops/gru/gru_test.go @@ -1,4 +1,4 @@ -package opset13 +package gru import ( "testing" @@ -10,7 +10,7 @@ import ( ) func TestGruInit(t *testing.T) { - gru := &GRU{} + gru := GRU{} err := gru.Init(GRUOnnxNodeProtoFixture()) assert.Nil(t, err) @@ -46,58 +46,71 @@ func TestGruInitUnkownAttr(t *testing.T) { func TestGru(t *testing.T) { tests := []struct { - gru *GRU + version int64 + node *onnx.NodeProto inputs ops.InputFixture expected []float32 err error }{ { - &GRU{ - activationAlpha: []float32{}, - activationBeta: []float32{}, - activations: []string{"sigmoid", "tanh"}, - direction: ops.Forward, - hiddenSize: 4, - linearBeforeReset: true, + 7, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "activation_alpha", Floats: []float32{}}, + {Name: "activation_beta", Floats: []float32{}}, + {Name: "activations", Strings: [][]byte{[]byte("sigmoid"), []byte("tanh")}}, + {Name: "direction", S: []byte("forward")}, + {Name: "hidden_size", I: 4}, + {Name: "linear_before_reset", I: 1}, + }, }, gruInput0, []float32{6.6936556e-03, 8.3446503e-07, 0.0000000e+00, 0.0000000e+00}, nil, }, { - &GRU{ - activationAlpha: []float32{}, - activationBeta: []float32{}, - activations: []string{"sigmoid", "tanh"}, - direction: ops.Forward, - hiddenSize: 4, - linearBeforeReset: false, + 7, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "activation_alpha", Floats: []float32{}}, + {Name: "activation_beta", Floats: []float32{}}, + {Name: "activations", Strings: [][]byte{[]byte("sigmoid"), []byte("tanh")}}, + {Name: "direction", S: []byte("forward")}, + {Name: "hidden_size", I: 4}, + {Name: "linear_before_reset", I: 0}, + }, }, gruInput0, []float32{6.6936556e-03, 8.3446503e-07, 0.0000000e+00, 0.0000000e+00}, nil, }, { - &GRU{ - activationAlpha: []float32{}, - activationBeta: []float32{}, - activations: []string{"sigmoid", "tanh"}, - direction: ops.Forward, - hiddenSize: 4, - linearBeforeReset: false, + 7, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + 
{Name: "activation_alpha", Floats: []float32{}}, + {Name: "activation_beta", Floats: []float32{}}, + {Name: "activations", Strings: [][]byte{[]byte("sigmoid"), []byte("tanh")}}, + {Name: "direction", S: []byte("forward")}, + {Name: "hidden_size", I: 4}, + {Name: "linear_before_reset", I: 0}, + }, }, gruInput1, []float32{0.44905475, 0.4406946, 0.43368173, 0.42782417}, nil, }, { - &GRU{ - activationAlpha: []float32{}, - activationBeta: []float32{}, - activations: []string{"sigmoid", "tanh"}, - direction: ops.Forward, - hiddenSize: 4, - linearBeforeReset: false, + 7, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "activation_alpha", Floats: []float32{}}, + {Name: "activation_beta", Floats: []float32{}}, + {Name: "activations", Strings: [][]byte{[]byte("sigmoid"), []byte("tanh")}}, + {Name: "direction", S: []byte("forward")}, + {Name: "hidden_size", I: 4}, + {Name: "linear_before_reset", I: 0}, + }, }, gruInputNoBNoH, []float32{0.24553154, 0.24553154, 0.24553154, 0.24553154}, @@ -107,7 +120,12 @@ func TestGru(t *testing.T) { for _, test := range tests { inputs := test.inputs() - res, err := test.gru.Apply(inputs) + + gru := gruVersions[test.version]() + err := gru.Init(test.node) + assert.Nil(t, err) + + res, err := gru.Apply(inputs) assert.Equal(t, test.err, err) if err == nil { @@ -118,11 +136,13 @@ func TestGru(t *testing.T) { func TestInputValidationGRU(t *testing.T) { tests := []struct { + version int64 inputs []tensor.Tensor expected []tensor.Tensor err error }{ { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), @@ -135,6 +155,7 @@ func TestInputValidationGRU(t *testing.T) { nil, }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), @@ -151,47 +172,53 @@ func TestInputValidationGRU(t *testing.T) { nil, }, { + 7, []tensor.Tensor{ops.TensorWithBackingFixture([]float32{1, 2}, 2)}, nil, - 
ops.ErrInvalidOptionalInputCount(1, &GRU{}), + ops.ErrInvalidOptionalInputCount(1, gru7BaseOpFixture()), }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, nil, - ops.ErrInvalidInputType(1, "int", &GRU{}), + ops.ErrInvalidInputType(1, "int", gru7BaseOpFixture()), }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, nil, - ops.ErrInvalidInputType(0, "int", &GRU{}), + ops.ErrInvalidInputType(0, "int", gru7BaseOpFixture()), }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, nil, - ops.ErrInvalidInputType(1, "int", &GRU{}), + ops.ErrInvalidInputType(1, "int", gru7BaseOpFixture()), }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]int{1, 2}, 2), }, nil, - ops.ErrInvalidInputType(2, "int", &GRU{}), + ops.ErrInvalidInputType(2, "int", gru7BaseOpFixture()), }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), @@ -199,9 +226,10 @@ func TestInputValidationGRU(t *testing.T) { ops.TensorWithBackingFixture([]int{1, 2}, 2), }, nil, - ops.ErrInvalidInputType(3, "int", &GRU{}), + ops.ErrInvalidInputType(3, "int", gru7BaseOpFixture()), }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), @@ -210,9 +238,10 @@ func TestInputValidationGRU(t *testing.T) { ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, nil, - ops.ErrInvalidInputType(4, "float32", &GRU{}), + ops.ErrInvalidInputType(4, "float32", gru7BaseOpFixture()), }, { + 7, []tensor.Tensor{ 
ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), @@ -222,12 +251,12 @@ func TestInputValidationGRU(t *testing.T) { ops.TensorWithBackingFixture([]int{1, 2}, 2), }, nil, - ops.ErrInvalidInputType(5, "int", &GRU{}), + ops.ErrInvalidInputType(5, "int", gru7BaseOpFixture()), }, } for _, test := range tests { - gru := &GRU{} + gru := gruVersions[test.version]() validated, err := gru.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -296,3 +325,7 @@ func GRUOnnxNodeProtoFixture() *onnx.NodeProto { }, } } + +func gru7BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(7, 3, 6, gruTypeConstraints, "gru") +} diff --git a/ops/gru/versions.go b/ops/gru/versions.go new file mode 100644 index 0000000..72d87f2 --- /dev/null +++ b/ops/gru/versions.go @@ -0,0 +1,11 @@ +package gru + +import "github.com/advancedclimatesystems/gonnx/ops" + +var gruVersions = ops.OperatorVersions{ + 7: ops.NewOperatorConstructor(newGRU, 7, gruTypeConstraints), +} + +func GetGRUVersions() ops.OperatorVersions { + return gruVersions +} diff --git a/ops/less/less.go b/ops/less/less.go new file mode 100644 index 0000000..c03125e --- /dev/null +++ b/ops/less/less.go @@ -0,0 +1,44 @@ +package less + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var less7TypeConstraints = [][]tensor.Dtype{{tensor.Float32, tensor.Float64}, {tensor.Float32, tensor.Float64}} + +var lessTypeConstraints = [][]tensor.Dtype{ops.AllTypes, ops.AllTypes} + +// Less represents the ONNX less operator. +type Less struct { + ops.BaseOperator +} + +// newLess creates a new less operator. +func newLess(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Less{ + BaseOperator: ops.NewBaseOperator( + version, + 2, + 2, + typeConstraints, + "less", + ), + } +} + +// Init initializes the less operator. 
+func (l *Less) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the less operator. +func (l *Less) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ApplyBinaryOperation( + inputs[0], + inputs[1], + ops.Lt, + ops.MultidirectionalBroadcasting, + ) +} diff --git a/ops/opset13/less_test.go b/ops/less/less_test.go similarity index 75% rename from ops/opset13/less_test.go rename to ops/less/less_test.go index a7a4036..77b3395 100644 --- a/ops/opset13/less_test.go +++ b/ops/less/less_test.go @@ -1,4 +1,4 @@ -package opset13 +package less import ( "testing" @@ -19,25 +19,25 @@ func TestLessInit(t *testing.T) { func TestLess(t *testing.T) { tests := []struct { - less *Less + version int64 backings [][]float32 shapes [][]int expected []bool }{ { - &Less{}, + 7, [][]float32{{0, 1, 2, 3}, {1, 1, 1, 1}}, [][]int{{2, 2}, {2, 2}}, []bool{true, false, false, false}, }, { - &Less{}, + 9, [][]float32{{0, 1, 2, 3, 4, 5}, {2, 2, 2, 2, 2, 2}}, [][]int{{3, 2}, {3, 2}}, []bool{true, true, false, false, false, false}, }, { - &Less{}, + 13, [][]float32{{0, 1}, {0, 1, 2, 3}}, [][]int{{2}, {2, 2}}, []bool{false, false, true, true}, @@ -50,7 +50,8 @@ func TestLess(t *testing.T) { ops.TensorWithBackingFixture(test.backings[1], test.shapes[1]...), } - res, err := test.less.Apply(inputs) + less := lessVersions[test.version]() + res, err := less.Apply(inputs) assert.Nil(t, err) assert.Nil(t, err) @@ -60,10 +61,12 @@ func TestLess(t *testing.T) { func TestInputValidationLess(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]uint32{1, 2}, 2), ops.TensorWithBackingFixture([]uint32{3, 4}, 2), @@ -71,6 +74,7 @@ func TestInputValidationLess(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]uint64{1, 2}, 2), ops.TensorWithBackingFixture([]uint64{3, 4}, 2), @@ -78,6 +82,7 @@ func 
TestInputValidationLess(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int32{1, 2}, 2), ops.TensorWithBackingFixture([]int32{3, 4}, 2), @@ -85,6 +90,7 @@ func TestInputValidationLess(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int64{1, 2}, 2), ops.TensorWithBackingFixture([]int64{3, 4}, 2), @@ -92,6 +98,7 @@ func TestInputValidationLess(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{3, 4}, 2), @@ -99,6 +106,7 @@ func TestInputValidationLess(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), ops.TensorWithBackingFixture([]float64{3, 4}, 2), @@ -106,22 +114,32 @@ func TestInputValidationLess(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputCount(1, &Less{}), + ops.ErrInvalidInputCount(1, less13BaseOpFixture()), }, { + 7, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + ops.TensorWithBackingFixture([]int32{3, 4}, 2), + }, + ops.ErrInvalidInputType(0, "int32", less7BaseOpFixture()), + }, + { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), }, - ops.ErrInvalidInputType(0, "int", &Less{}), + ops.ErrInvalidInputType(0, "int", less13BaseOpFixture()), }, } for _, test := range tests { - less := &Less{} + less := lessVersions[test.version]() validated, err := less.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -131,3 +149,11 @@ func TestInputValidationLess(t *testing.T) { } } } + +func less7BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(7, 2, 2, less7TypeConstraints, "less") +} + +func less13BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(13, 2, 2, lessTypeConstraints, "less") +} diff --git a/ops/less/versions.go b/ops/less/versions.go new file mode 100644 index 
0000000..8e9c3be --- /dev/null +++ b/ops/less/versions.go @@ -0,0 +1,13 @@ +package less + +import "github.com/advancedclimatesystems/gonnx/ops" + +var lessVersions = ops.OperatorVersions{ + 7: ops.NewOperatorConstructor(newLess, 7, less7TypeConstraints), + 9: ops.NewOperatorConstructor(newLess, 9, lessTypeConstraints), + 13: ops.NewOperatorConstructor(newLess, 13, lessTypeConstraints), +} + +func GetLessVersions() ops.OperatorVersions { + return lessVersions +} diff --git a/ops/lessorequal/less_or_equal.go b/ops/lessorequal/less_or_equal.go new file mode 100644 index 0000000..52303b1 --- /dev/null +++ b/ops/lessorequal/less_or_equal.go @@ -0,0 +1,42 @@ +package lessorequal + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var lessOrEqualTypeConstraints = [][]tensor.Dtype{ops.AllTypes, ops.AllTypes} + +// LessOrEqual represents the ONNX lessOrEqual operator. +type LessOrEqual struct { + ops.BaseOperator +} + +// newLessOrEqual creates a new lessOrEqual operator. +func newLessOrEqual(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &LessOrEqual{ + BaseOperator: ops.NewBaseOperator( + version, + 2, + 2, + typeConstraints, + "lessorequal", + ), + } +} + +// Init initializes the lessOrEqual operator. +func (l *LessOrEqual) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the lessOrEqual operator. 
+func (l *LessOrEqual) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ApplyBinaryOperation( + inputs[0], + inputs[1], + ops.Lte, + ops.MultidirectionalBroadcasting, + ) +} diff --git a/ops/opset13/less_or_equal_test.go b/ops/lessorequal/less_or_equal_test.go similarity index 73% rename from ops/opset13/less_or_equal_test.go rename to ops/lessorequal/less_or_equal_test.go index fbba443..51bc766 100644 --- a/ops/opset13/less_or_equal_test.go +++ b/ops/lessorequal/less_or_equal_test.go @@ -1,4 +1,4 @@ -package opset13 +package lessorequal import ( "testing" @@ -9,35 +9,41 @@ import ( ) func TestLessOrEqualInit(t *testing.T) { - l := &LessOrEqual{} + tests := []struct { + version int64 + err error + }{ + {12, nil}, + } - // since 'lessOrEqual' does not have any attributes we pass in nil. This should not - // fail initializing the lessOrEqual. - err := l.Init(ops.EmptyNodeProto()) - assert.Nil(t, err) + for _, test := range tests { + l := lessOrEqualVersions[test.version]() + err := l.Init(nil) + assert.Equal(t, test.err, err) + } } func TestLessOrEqual(t *testing.T) { tests := []struct { - lessOrEqual *LessOrEqual - backings [][]float32 - shapes [][]int - expected []bool + version int64 + backings [][]float32 + shapes [][]int + expected []bool }{ { - &LessOrEqual{}, + 12, [][]float32{{0, 1, 2, 3}, {1, 1, 1, 1}}, [][]int{{2, 2}, {2, 2}}, []bool{true, true, false, false}, }, { - &LessOrEqual{}, + 12, [][]float32{{0, 1, 2, 3, 4, 5}, {2, 2, 2, 2, 2, 2}}, [][]int{{3, 2}, {3, 2}}, []bool{true, true, true, false, false, false}, }, { - &LessOrEqual{}, + 12, [][]float32{{0, 1}, {0, 1, 2, 3}}, [][]int{{2}, {2, 2}}, []bool{true, true, true, true}, @@ -50,7 +56,8 @@ func TestLessOrEqual(t *testing.T) { ops.TensorWithBackingFixture(test.backings[1], test.shapes[1]...), } - res, err := test.lessOrEqual.Apply(inputs) + lessOrEqual := lessOrEqualVersions[test.version]() + res, err := lessOrEqual.Apply(inputs) assert.Nil(t, err) assert.Nil(t, err) @@ -60,10 
+67,12 @@ func TestLessOrEqual(t *testing.T) { func TestInputValidationLessOrEqual(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 12, []tensor.Tensor{ ops.TensorWithBackingFixture([]uint32{1, 2}, 2), ops.TensorWithBackingFixture([]uint32{3, 4}, 2), @@ -71,6 +80,7 @@ func TestInputValidationLessOrEqual(t *testing.T) { nil, }, { + 12, []tensor.Tensor{ ops.TensorWithBackingFixture([]uint64{1, 2}, 2), ops.TensorWithBackingFixture([]uint64{3, 4}, 2), @@ -78,6 +88,7 @@ func TestInputValidationLessOrEqual(t *testing.T) { nil, }, { + 12, []tensor.Tensor{ ops.TensorWithBackingFixture([]int32{1, 2}, 2), ops.TensorWithBackingFixture([]int32{3, 4}, 2), @@ -85,6 +96,7 @@ func TestInputValidationLessOrEqual(t *testing.T) { nil, }, { + 12, []tensor.Tensor{ ops.TensorWithBackingFixture([]int64{1, 2}, 2), ops.TensorWithBackingFixture([]int64{3, 4}, 2), @@ -92,6 +104,7 @@ func TestInputValidationLessOrEqual(t *testing.T) { nil, }, { + 12, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{3, 4}, 2), @@ -99,6 +112,7 @@ func TestInputValidationLessOrEqual(t *testing.T) { nil, }, { + 12, []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), ops.TensorWithBackingFixture([]float64{3, 4}, 2), @@ -106,22 +120,24 @@ func TestInputValidationLessOrEqual(t *testing.T) { nil, }, { + 12, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputCount(1, &LessOrEqual{}), + ops.ErrInvalidInputCount(1, lessOrEqual12BaseOpFixture()), }, { + 12, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), }, - ops.ErrInvalidInputType(0, "int", &LessOrEqual{}), + ops.ErrInvalidInputType(0, "int", lessOrEqual12BaseOpFixture()), }, } for _, test := range tests { - lessOrEqual := &LessOrEqual{} + lessOrEqual := lessOrEqualVersions[test.version]() validated, err := 
lessOrEqual.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -131,3 +147,7 @@ func TestInputValidationLessOrEqual(t *testing.T) { } } } + +func lessOrEqual12BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(12, 2, 2, lessOrEqualTypeConstraints, "lessorequal") +} diff --git a/ops/lessorequal/versions.go b/ops/lessorequal/versions.go new file mode 100644 index 0000000..f70d30f --- /dev/null +++ b/ops/lessorequal/versions.go @@ -0,0 +1,11 @@ +package lessorequal + +import "github.com/advancedclimatesystems/gonnx/ops" + +var lessOrEqualVersions = ops.OperatorVersions{ + 12: ops.NewOperatorConstructor(newLessOrEqual, 12, lessOrEqualTypeConstraints), +} + +func GetLessOrEqualVersions() ops.OperatorVersions { + return lessOrEqualVersions +} diff --git a/ops/opset13/linear_regressor.go b/ops/linearregressor/linear_regressor.go similarity index 65% rename from ops/opset13/linear_regressor.go rename to ops/linearregressor/linear_regressor.go index ceb0cb1..eda82e1 100644 --- a/ops/opset13/linear_regressor.go +++ b/ops/linearregressor/linear_regressor.go @@ -1,4 +1,4 @@ -package opset13 +package linearregressor import ( "github.com/advancedclimatesystems/gonnx/onnx" @@ -6,10 +6,9 @@ import ( "gorgonia.org/tensor" ) -const ( - MinLinearRegressorInputs = 1 - MaxLinearRegressorInputs = 1 -) +var linearRegressorTypeConstraints = [][]tensor.Dtype{ + {tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, +} // PostTransformOption describes all possible post transform options for the // linear regressor operator. @@ -25,6 +24,8 @@ const ( // LinearRegressor represents the ONNX-ml linearRegressor operator. type LinearRegressor struct { + ops.BaseOperator + coefficients tensor.Tensor intercepts tensor.Tensor postTransform postTransformOption @@ -32,8 +33,15 @@ type LinearRegressor struct { } // newLinearRegressor creates a new linearRegressor operator. 
-func newLinearRegressor() ops.Operator { +func newLinearRegressor(version int, typeConstraints [][]tensor.Dtype) ops.Operator { return &LinearRegressor{ + BaseOperator: ops.NewBaseOperator( + version, + 1, + 1, + typeConstraints, + "linearregressor", + ), postTransform: noTransform, targets: 1, } @@ -87,31 +95,3 @@ func (l *LinearRegressor) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) return []tensor.Tensor{Y}, nil } - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (l *LinearRegressor) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(l, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (l *LinearRegressor) GetMinInputs() int { - return MinLinearRegressorInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (l *LinearRegressor) GetMaxInputs() int { - return MaxLinearRegressorInputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (l *LinearRegressor) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ - {tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, - } -} - -// String implements the stringer interface, and can be used to format errors or messages. 
-func (l *LinearRegressor) String() string { - return "linearRegressor operator" -} diff --git a/ops/opset13/linear_regressor_test.go b/ops/linearregressor/linear_regressor_test.go similarity index 90% rename from ops/opset13/linear_regressor_test.go rename to ops/linearregressor/linear_regressor_test.go index abaeb25..59c021c 100644 --- a/ops/opset13/linear_regressor_test.go +++ b/ops/linearregressor/linear_regressor_test.go @@ -1,4 +1,4 @@ -package opset13 +package linearregressor import ( "testing" @@ -37,6 +37,7 @@ func TestLinearRegressorInitFailInvalidAttribute(t *testing.T) { func TestLinearRegressor(t *testing.T) { tests := []struct { + version int64 attrs []*onnx.AttributeProto shape []int backing []float32 @@ -45,6 +46,7 @@ func TestLinearRegressor(t *testing.T) { description string }{ { + 1, []*onnx.AttributeProto{ {Name: "coefficients", Floats: []float32{-0.45977323}}, {Name: "intercepts", Floats: []float32{0.21509616}}, @@ -57,6 +59,7 @@ func TestLinearRegressor(t *testing.T) { "linear regressor with 1 input and 1 output variable, 1 sample", }, { + 1, []*onnx.AttributeProto{ {Name: "coefficients", Floats: []float32{-0.45977323}}, {Name: "intercepts", Floats: []float32{0.21509616}}, @@ -69,6 +72,7 @@ func TestLinearRegressor(t *testing.T) { "linear regressor with 1 input and 1 output variable, 5 samples", }, { + 1, []*onnx.AttributeProto{ {Name: "coefficients", Floats: []float32{0.24118852, 0.22617804, 0.27858477}}, {Name: "intercepts", Floats: []float32{-0.43156273}}, @@ -81,6 +85,7 @@ func TestLinearRegressor(t *testing.T) { "linear regressor with 3 inputs and 1 output variable, 1 sample", }, { + 1, []*onnx.AttributeProto{ {Name: "coefficients", Floats: []float32{0.24118852, 0.22617804, 0.27858477}}, {Name: "intercepts", Floats: []float32{-0.43156273}}, @@ -93,6 +98,7 @@ func TestLinearRegressor(t *testing.T) { "linear regressor with 3 inputs and 1 output variable, 2 samples", }, { + 1, []*onnx.AttributeProto{ {Name: "coefficients", Floats: []float32{ 
0.5384742, 0.36729308, 0.13292366, -0.03843413, @@ -109,6 +115,7 @@ func TestLinearRegressor(t *testing.T) { "linear regressor with 4 input and 3 output variables, 1 samples", }, { + 1, []*onnx.AttributeProto{ {Name: "coefficients", Floats: []float32{ 0.5384742, 0.36729308, 0.13292366, -0.03843413, @@ -131,7 +138,7 @@ func TestLinearRegressor(t *testing.T) { ops.TensorWithBackingFixture(test.backing, test.shape...), } - linearRegressor := newLinearRegressor() + linearRegressor := linearRegressorVersions[test.version]() err := linearRegressor.Init(&onnx.NodeProto{Attribute: test.attrs}) assert.Nil(t, err, test.description) @@ -144,37 +151,44 @@ func TestLinearRegressor(t *testing.T) { func TestInputValidationLinearRegressor(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 1, []tensor.Tensor{ops.TensorWithBackingFixture([]int32{1, 2}, 2)}, nil, }, { + 1, []tensor.Tensor{ops.TensorWithBackingFixture([]int64{1, 2}, 2)}, nil, }, { + 1, []tensor.Tensor{ops.TensorWithBackingFixture([]float32{1, 2}, 2)}, nil, }, { + 1, []tensor.Tensor{ops.TensorWithBackingFixture([]float64{1, 2}, 2)}, nil, }, { + 1, []tensor.Tensor{}, - ops.ErrInvalidInputCount(0, &LinearRegressor{}), + ops.ErrInvalidInputCount(0, linearRegressor1BaseOpFixture()), }, { + 1, []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, - ops.ErrInvalidInputType(0, "int", &LinearRegressor{}), + ops.ErrInvalidInputType(0, "int", linearRegressor1BaseOpFixture()), }, } for _, test := range tests { - linearRegressor := &LinearRegressor{} + linearRegressor := linearRegressorVersions[test.version]() validated, err := linearRegressor.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -194,3 +208,13 @@ func LinearRegressorOnnxNodeProtoFixture() *onnx.NodeProto { }, } } + +func linearRegressor1BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator( + 1, + 1, + 1, + linearRegressorTypeConstraints, + 
"linearregressor", + ) +} diff --git a/ops/linearregressor/versions.go b/ops/linearregressor/versions.go new file mode 100644 index 0000000..8808594 --- /dev/null +++ b/ops/linearregressor/versions.go @@ -0,0 +1,11 @@ +package linearregressor + +import "github.com/advancedclimatesystems/gonnx/ops" + +var linearRegressorVersions = ops.OperatorVersions{ + 1: ops.NewOperatorConstructor(newLinearRegressor, 1, linearRegressorTypeConstraints), +} + +func GetLinearRegressorVersions() ops.OperatorVersions { + return linearRegressorVersions +} diff --git a/ops/opset13/logsoftmax.go b/ops/logsoftmax/logsoftmax.go similarity index 54% rename from ops/opset13/logsoftmax.go rename to ops/logsoftmax/logsoftmax.go index 3b8d14a..7009488 100644 --- a/ops/opset13/logsoftmax.go +++ b/ops/logsoftmax/logsoftmax.go @@ -1,4 +1,4 @@ -package opset13 +package logsoftmax import ( "github.com/advancedclimatesystems/gonnx/onnx" @@ -6,15 +6,26 @@ import ( "gorgonia.org/tensor" ) +var logSoftmaxTypeConstraints = [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} + // LogSoftmax represents the ONNX logsoftmax operator. type LogSoftmax struct { + ops.BaseOperator + // The axis along which to perform the LogSoftmax operation. axis int } // newLogSoftmax creates a new logsoftmax operator. -func newLogSoftmax() ops.Operator { +func newLogSoftmax(version int, typeConstraints [][]tensor.Dtype) ops.Operator { return &LogSoftmax{ + BaseOperator: ops.NewBaseOperator( + version, + 1, + 1, + typeConstraints, + "logsoftmax", + ), axis: -1, } } @@ -56,31 +67,3 @@ func (l *LogSoftmax) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return []tensor.Tensor{out}, nil } - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (l *LogSoftmax) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(l, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. 
-func (l *LogSoftmax) GetMinInputs() int { - return 1 -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (l *LogSoftmax) GetMaxInputs() int { - return 1 -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (l *LogSoftmax) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ - {tensor.Float32, tensor.Float64}, - } -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (l *LogSoftmax) String() string { - return "logsoftmax operator" -} diff --git a/ops/opset13/logsoftmax_test.go b/ops/logsoftmax/logsoftmax_test.go similarity index 64% rename from ops/opset13/logsoftmax_test.go rename to ops/logsoftmax/logsoftmax_test.go index 80f5688..79d6709 100644 --- a/ops/opset13/logsoftmax_test.go +++ b/ops/logsoftmax/logsoftmax_test.go @@ -1,8 +1,9 @@ -package opset13 +package logsoftmax import ( "testing" + "github.com/advancedclimatesystems/gonnx/onnx" "github.com/advancedclimatesystems/gonnx/ops" "github.com/stretchr/testify/assert" "gorgonia.org/tensor" @@ -19,46 +20,62 @@ func TestLogSoftmaxInit(t *testing.T) { func TestLogSoftmax(t *testing.T) { tests := []struct { - logsoftmax *LogSoftmax - backing []float32 - shape []int - expected []float32 + version int64 + attrs *onnx.NodeProto + backing []float32 + shape []int + expected []float32 }{ { - &LogSoftmax{ - axis: -1, + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axis", I: -1}, + }, }, []float32{0, 1, 2, 3}, []int{1, 4}, []float32{-3.4401898, -2.4401898, -1.4401897, -0.44018975}, }, { - &LogSoftmax{ - axis: 1, + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axis", I: 1}, + }, }, []float32{0, 1, 2, 3}, []int{1, 4}, []float32{-3.4401898, -2.4401898, -1.4401897, -0.44018975}, }, { - &LogSoftmax{ - axis: -1, + 13, + &onnx.NodeProto{ + Attribute: 
[]*onnx.AttributeProto{ + {Name: "axis", I: -1}, + }, }, []float32{0, 1, 2, 3}, []int{2, 2}, []float32{-1.3132616, -0.31326166, -1.3132616, -0.31326166}, }, { - &LogSoftmax{ - axis: -1, + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axis", I: -1}, + }, }, []float32{0, 1, 2, 3, 4, 5}, []int{1, 2, 3}, []float32{-2.407606, -1.4076059, -0.40760595, -2.407606, -1.4076059, -0.40760595}, }, { - &LogSoftmax{ - axis: -1, + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axis", I: -1}, + }, }, []float32{0, 1, 2, 3}, []int{4, 1}, @@ -71,7 +88,11 @@ func TestLogSoftmax(t *testing.T) { ops.TensorWithBackingFixture(test.backing, test.shape...), } - res, err := test.logsoftmax.Apply(inputs) + logsoftmax := logSoftmaxVersions[test.version]() + err := logsoftmax.Init(test.attrs) + assert.Nil(t, err) + + res, err := logsoftmax.Apply(inputs) assert.Nil(t, err) assert.Equal(t, test.expected, res[0].Data()) @@ -96,38 +117,44 @@ func TestLogSoftmaxFail(t *testing.T) { func TestInputValidationLogSoftmax(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), }, nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputCount(2, &LogSoftmax{}), + ops.ErrInvalidInputCount(2, logSoftmax13BaseOpFixture()), }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputType(0, "int", &LogSoftmax{}), + ops.ErrInvalidInputType(0, "int", logSoftmax13BaseOpFixture()), }, } for _, test := range tests { - logsoftmax := &LogSoftmax{} + logsoftmax := logSoftmaxVersions[test.version]() + validated, err := logsoftmax.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -137,3 +164,7 @@ func 
TestInputValidationLogSoftmax(t *testing.T) { } } } + +func logSoftmax13BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(13, 1, 1, logSoftmaxTypeConstraints, "logsoftmax") +} diff --git a/ops/logsoftmax/versions.go b/ops/logsoftmax/versions.go new file mode 100644 index 0000000..dd719bc --- /dev/null +++ b/ops/logsoftmax/versions.go @@ -0,0 +1,13 @@ +package logsoftmax + +import "github.com/advancedclimatesystems/gonnx/ops" + +var logSoftmaxVersions = ops.OperatorVersions{ + 1: ops.NewOperatorConstructor(newLogSoftmax, 1, logSoftmaxTypeConstraints), + 11: ops.NewOperatorConstructor(newLogSoftmax, 11, logSoftmaxTypeConstraints), + 13: ops.NewOperatorConstructor(newLogSoftmax, 13, logSoftmaxTypeConstraints), +} + +func GetLogSoftmaxVersions() ops.OperatorVersions { + return logSoftmaxVersions +} diff --git a/ops/opset13/lstm.go b/ops/lstm/lstm.go similarity index 87% rename from ops/opset13/lstm.go rename to ops/lstm/lstm.go index 8b32b2a..396d11a 100644 --- a/ops/opset13/lstm.go +++ b/ops/lstm/lstm.go @@ -1,8 +1,9 @@ -package opset13 +package lstm import ( "github.com/advancedclimatesystems/gonnx/onnx" "github.com/advancedclimatesystems/gonnx/ops" + "github.com/advancedclimatesystems/gonnx/ops/gemm" "gorgonia.org/tensor" ) @@ -11,8 +12,21 @@ const ( MaxLSTMInputs = 8 ) +var lstmTypeConstraints = [][]tensor.Dtype{ + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Int32}, + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, +} + // LSTM represents the ONNX lstm operator. type LSTM struct { + ops.BaseOperator + activationAlpha []float32 activationBeta []float32 activations []string @@ -24,8 +38,15 @@ type LSTM struct { } // newLSTM creates a new lstm operator. 
-func newLSTM() ops.Operator { +func newLSTM(version int, typeConstraints [][]tensor.Dtype) ops.Operator { return &LSTM{ + BaseOperator: ops.NewBaseOperator( + version, + MinLSTMInputs, + MaxLSTMInputs, + typeConstraints, + "lstm", + ), activations: []string{"sigmoid", "tanh", "tanh"}, direction: ops.Forward, inputForget: false, @@ -72,7 +93,7 @@ func (l *LSTM) Init(n *onnx.NodeProto) error { // Apply applies the lstm operator. func (l *LSTM) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { if inputs[4] != nil { - return nil, ops.ErrUnsupportedInput("sequence_lens", l) + return nil, ops.ErrUnsupportedInput("sequence_lens", l.BaseOperator) } X := inputs[0] @@ -235,41 +256,6 @@ func (l *LSTM) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return result, nil } -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (l *LSTM) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(l, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (l *LSTM) GetMinInputs() int { - return MinLSTMInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (l *LSTM) GetMaxInputs() int { - return MaxLSTMInputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (l *LSTM) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ - {tensor.Float32, tensor.Float64}, - {tensor.Float32, tensor.Float64}, - {tensor.Float32, tensor.Float64}, - {tensor.Float32, tensor.Float64}, - {tensor.Int32}, - {tensor.Float32, tensor.Float64}, - {tensor.Float32, tensor.Float64}, - {tensor.Float32, tensor.Float64}, - } -} - -// String implements the stringer interface, and can be used to format errors or messages. 
-func (l *LSTM) String() string { - return "lstm operator" -} - // gateCalculation performs a standard gate calculation for an LSTM gate defined as: // // o = f(Xt*(W^T) + Wb + H*(R^T) + Rb + P (.) C) @@ -292,7 +278,21 @@ func (l *LSTM) String() string { func (l *LSTM) gateCalculation( Xt, W, Wb, H, R, Rb, P, C tensor.Tensor, activation ops.Activation, ) (tensor.Tensor, error) { - gemm := &Gemm{transA: false, transB: true, alpha: 1.0, beta: 1.0} + gemm := gemm.GetGemmVersions()[13]() + + err := gemm.Init( + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "alpha", F: 1.0}, + {Name: "beta", F: 1.0}, + {Name: "transA", I: 0}, + {Name: "transB", I: 1}, + }, + }, + ) + if err != nil { + return nil, err + } inputCalc, err := gemm.Apply([]tensor.Tensor{Xt, W, Wb}) if err != nil { diff --git a/ops/opset13/lstm_test.go b/ops/lstm/lstm_test.go similarity index 75% rename from ops/opset13/lstm_test.go rename to ops/lstm/lstm_test.go index 25ff642..4d66882 100644 --- a/ops/opset13/lstm_test.go +++ b/ops/lstm/lstm_test.go @@ -1,4 +1,4 @@ -package opset13 +package lstm import ( "math/rand" @@ -48,71 +48,87 @@ func TestLSTMInitUnkownAttr(t *testing.T) { func TestLSTM(t *testing.T) { tests := []struct { - lstm *LSTM + version int64 + attrs *onnx.NodeProto inputs ops.InputFixture expected []float32 err error }{ { - &LSTM{ - activationAlpha: []float32{}, - activationBeta: []float32{}, - activations: []string{"sigmoid", "tanh", "tanh"}, - direction: ops.Forward, - hiddenSize: 4, - outputs: []string{"Y", "Y_h", "Y_c"}, + 7, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "activation_alpha", Floats: []float32{}}, + {Name: "activation_beta", Floats: []float32{}}, + {Name: "activations", Strings: [][]byte{[]byte("sigmoid"), []byte("tanh"), []byte("tanh")}}, + {Name: "direction", S: []byte("forward")}, + {Name: "hidden_size", I: 4}, + }, + Output: []string{"Y", "Y_h", "Y_c"}, }, lstmInput0, []float32{0.9159305, 0.9356764, 0.87070554, 0.84180677}, nil, }, { 
- &LSTM{ - activationAlpha: []float32{}, - activationBeta: []float32{}, - activations: []string{"sigmoid", "tanh", "relu"}, - direction: ops.Forward, - hiddenSize: 4, - outputs: []string{"Y", "Y_h", "Y_c"}, + 7, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "activation_alpha", Floats: []float32{}}, + {Name: "activation_beta", Floats: []float32{}}, + {Name: "activations", Strings: [][]byte{[]byte("sigmoid"), []byte("tanh"), []byte("relu")}}, + {Name: "direction", S: []byte("forward")}, + {Name: "hidden_size", I: 4}, + }, + Output: []string{"Y", "Y_h", "Y_c"}, }, lstmInput0, []float32{1.7530097, 1.7829735, 1.6231446, 1.5197954}, nil, }, { - &LSTM{ - activationAlpha: []float32{}, - activationBeta: []float32{}, - activations: []string{"sigmoid", "tanh", "relu"}, - direction: ops.Forward, - hiddenSize: 4, - outputs: []string{"Y", "Y_h", "Y_c"}, + 7, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "activation_alpha", Floats: []float32{}}, + {Name: "activation_beta", Floats: []float32{}}, + {Name: "activations", Strings: [][]byte{[]byte("sigmoid"), []byte("tanh"), []byte("relu")}}, + {Name: "direction", S: []byte("forward")}, + {Name: "hidden_size", I: 4}, + }, + Output: []string{"Y", "Y_h", "Y_c"}, }, lstmInput1, []float32{10.598255, 10.547241, 10.214846, 10.267471}, nil, }, { - &LSTM{ - activationAlpha: []float32{}, - activationBeta: []float32{}, - activations: []string{"sigmoid", "tanh", "relu"}, - direction: ops.Forward, - hiddenSize: 4, - outputs: []string{"Y", "Y_h", "Y_c"}, + 7, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "activation_alpha", Floats: []float32{}}, + {Name: "activation_beta", Floats: []float32{}}, + {Name: "activations", Strings: [][]byte{[]byte("sigmoid"), []byte("tanh"), []byte("relu")}}, + {Name: "direction", S: []byte("forward")}, + {Name: "hidden_size", I: 4}, + }, + Output: []string{"Y", "Y_h", "Y_c"}, }, lstmInputNoBNoH, []float32{8.276371, 8.291079, 8.161418, 7.7900877}, nil, }, { - 
&LSTM{ - activationAlpha: []float32{}, - activationBeta: []float32{}, - activations: []string{"sigmoid", "tanh", "tanh"}, - direction: ops.Forward, - hiddenSize: 4, - outputs: []string{"Y", "Y_h", "Y_c"}, + 7, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "activation_alpha", Floats: []float32{}}, + {Name: "activation_beta", Floats: []float32{}}, + {Name: "activations", Strings: [][]byte{[]byte("sigmoid"), []byte("tanh"), []byte("tanh")}}, + {Name: "direction", S: []byte("forward")}, + {Name: "hidden_size", I: 4}, + }, + Output: []string{"Y", "Y_h", "Y_c"}, }, lstmInputPeepholes, []float32{0.99891853, 0.99994266, 0.9995524, 0.99171203}, @@ -122,7 +138,12 @@ func TestLSTM(t *testing.T) { for _, test := range tests { inputs := test.inputs() - res, err := test.lstm.Apply(inputs) + + lstm := lstmVersions[test.version]() + err := lstm.Init(test.attrs) + assert.Nil(t, err) + + res, err := lstm.Apply(inputs) assert.Equal(t, test.err, err) if err == nil { @@ -133,11 +154,13 @@ func TestLSTM(t *testing.T) { func TestInputValidationLSTM(t *testing.T) { tests := []struct { + version int64 inputs []tensor.Tensor expected []tensor.Tensor err error }{ { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), @@ -152,6 +175,7 @@ func TestInputValidationLSTM(t *testing.T) { nil, }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), @@ -170,38 +194,43 @@ func TestInputValidationLSTM(t *testing.T) { nil, }, { + 7, []tensor.Tensor{ops.TensorWithBackingFixture([]float32{1, 2}, 2)}, nil, - ops.ErrInvalidOptionalInputCount(1, &LSTM{}), + ops.ErrInvalidOptionalInputCount(1, lstm7BaseOpFixture()), }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, nil, - ops.ErrInvalidInputType(1, "int", &LSTM{}), + 
ops.ErrInvalidInputType(1, "int", lstm7BaseOpFixture()), }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, nil, - ops.ErrInvalidInputType(0, "int", &LSTM{}), + ops.ErrInvalidInputType(0, "int", lstm7BaseOpFixture()), }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]int{1, 2}, 2), }, nil, - ops.ErrInvalidInputType(2, "int", &LSTM{}), + ops.ErrInvalidInputType(2, "int", lstm7BaseOpFixture()), }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), @@ -209,9 +238,10 @@ func TestInputValidationLSTM(t *testing.T) { ops.TensorWithBackingFixture([]int{1, 2}, 2), }, nil, - ops.ErrInvalidInputType(3, "int", &LSTM{}), + ops.ErrInvalidInputType(3, "int", lstm7BaseOpFixture()), }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), @@ -220,9 +250,10 @@ func TestInputValidationLSTM(t *testing.T) { ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, nil, - ops.ErrInvalidInputType(4, "float32", &LSTM{}), + ops.ErrInvalidInputType(4, "float32", lstm7BaseOpFixture()), }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), @@ -232,9 +263,10 @@ func TestInputValidationLSTM(t *testing.T) { ops.TensorWithBackingFixture([]int{1, 2}, 2), }, nil, - ops.ErrInvalidInputType(5, "int", &LSTM{}), + ops.ErrInvalidInputType(5, "int", lstm7BaseOpFixture()), }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), @@ -245,9 +277,10 @@ func TestInputValidationLSTM(t *testing.T) { ops.TensorWithBackingFixture([]int{1, 2}, 2), }, nil, - ops.ErrInvalidInputType(6, "int", &LSTM{}), + 
ops.ErrInvalidInputType(6, "int", lstm7BaseOpFixture()), }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), @@ -259,12 +292,12 @@ func TestInputValidationLSTM(t *testing.T) { ops.TensorWithBackingFixture([]int{1, 2}, 2), }, nil, - ops.ErrInvalidInputType(7, "int", &LSTM{}), + ops.ErrInvalidInputType(7, "int", lstm7BaseOpFixture()), }, } for _, test := range tests { - lstm := &LSTM{} + lstm := lstmVersions[test.version]() validated, err := lstm.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -384,3 +417,7 @@ func LSTMOnnxNodeProtoFixture() *onnx.NodeProto { Output: []string{"Y", "Y_h"}, } } + +func lstm7BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(7, 3, 8, lstmTypeConstraints, "lstm") +} diff --git a/ops/lstm/versions.go b/ops/lstm/versions.go new file mode 100644 index 0000000..da9eec0 --- /dev/null +++ b/ops/lstm/versions.go @@ -0,0 +1,11 @@ +package lstm + +import "github.com/advancedclimatesystems/gonnx/ops" + +var lstmVersions = ops.OperatorVersions{ + 7: ops.NewOperatorConstructor(newLSTM, 7, lstmTypeConstraints), +} + +func GetLSTMVersions() ops.OperatorVersions { + return lstmVersions +} diff --git a/ops/opset13/matmul.go b/ops/matmul/matmul.go similarity index 82% rename from ops/opset13/matmul.go rename to ops/matmul/matmul.go index 1212233..2ae5b8c 100644 --- a/ops/opset13/matmul.go +++ b/ops/matmul/matmul.go @@ -1,4 +1,4 @@ -package opset13 +package matmul import ( "github.com/advancedclimatesystems/gonnx/onnx" @@ -6,17 +6,32 @@ import ( "gorgonia.org/tensor" ) -const ( - MinMatMulInputs = 2 - MaxMatMulInputs = 2 -) +var matmul1TypeConstraints = [][]tensor.Dtype{ + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, +} + +var matmulTypeConstraints = [][]tensor.Dtype{ + {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, + {tensor.Uint32, tensor.Uint64, tensor.Int32, 
tensor.Int64, tensor.Float32, tensor.Float64}, +} // MatMul represents the ONNX matmul operator. -type MatMul struct{} +type MatMul struct { + ops.BaseOperator +} // newMatMul returns a new MatMul operator. -func newMatMul() ops.Operator { - return &MatMul{} +func newMatMul(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &MatMul{ + BaseOperator: ops.NewBaseOperator( + version, + 2, + 2, + typeConstraints, + "matmul", + ), + } } // Init initializes the matmul operator. @@ -108,35 +123,6 @@ func (m *MatMul) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return []tensor.Tensor{out}, err } -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (m *MatMul) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(m, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (m *MatMul) GetMinInputs() int { - return MinMatMulInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (m *MatMul) GetMaxInputs() int { - return MaxMatMulInputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (m *MatMul) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ - {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, - {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, - } -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (m *MatMul) String() string { - return "matmul operator" -} - // broadcastTensors broadcasts both tensors for the matmul operator. It is almost identical // to multidirectional broadcast, but here we need to treat the 2 trailing dimensions as // matrices, and we do not want to broadcast those. 
All leading dimensions to the matrices diff --git a/ops/opset13/matmul_test.go b/ops/matmul/matmul_test.go similarity index 77% rename from ops/opset13/matmul_test.go rename to ops/matmul/matmul_test.go index fa8dcc2..9119068 100644 --- a/ops/opset13/matmul_test.go +++ b/ops/matmul/matmul_test.go @@ -1,4 +1,4 @@ -package opset13 +package matmul import ( "testing" @@ -9,34 +9,46 @@ import ( ) func TestMatMulInit(t *testing.T) { - s := newMatMul() - // since the matMul does not have any attributes we expect it to initialize even - // when nil is passed. - err := s.Init(nil) + tests := []struct { + version int64 + err error + }{ + {1, nil}, + {9, nil}, + {13, nil}, + } - assert.Nil(t, err) + for _, test := range tests { + r := matMulVersions[test.version]() + err := r.Init(nil) + assert.Equal(t, test.err, err) + } } func TestMatMul(t *testing.T) { tests := []struct { + version int64 backings [][]float32 shapes [][]int expected []float32 expectedShape tensor.Shape }{ { + 13, [][]float32{{3, 1, 4}, {4, 3, 2, 5, 6, 8}}, [][]int{{1, 3}, {3, 2}}, []float32{38, 46}, []int{1, 2}, }, { + 13, [][]float32{{3, 4, 7, 2, 5, 9}, {3, 1, 5, 6, 9, 7}}, [][]int{{3, 2}, {2, 3}}, []float32{33, 39, 43, 33, 25, 49, 69, 86, 88}, []int{3, 3}, }, { + 13, [][]float32{ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, @@ -46,6 +58,7 @@ func TestMatMul(t *testing.T) { []int{2, 2, 2}, }, { + 13, [][]float32{ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, {1, 2, 3, 4, 5, 6}, @@ -55,6 +68,7 @@ func TestMatMul(t *testing.T) { []int{2, 2, 2}, }, { + 13, [][]float32{ {0, 1, 2, 3, 4, 5}, {1, 2, 3, 4}, @@ -64,18 +78,21 @@ func TestMatMul(t *testing.T) { []int{2, 3, 4}, }, { + 13, [][]float32{{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, {1, 2}}, [][]int{{2, 2, 2, 2}, {2}}, []float32{5, 11, 17, 23, 29, 35, 41, 47}, []int{2, 2, 2}, }, { + 13, [][]float32{{1, 2}, {1, 2, 3, 4}}, [][]int{{2}, {2, 2}}, []float32{7, 10}, []int{2}, }, { + 13, [][]float32{{1, 2, 3, 4}, {1, 2}}, 
[][]int{{2, 2}, {2}}, []float32{5, 11}, @@ -84,12 +101,12 @@ func TestMatMul(t *testing.T) { } for i, test := range tests { - matmul := &MatMul{} inputs := []tensor.Tensor{ ops.TensorWithBackingFixture(test.backings[0], test.shapes[0]...), ops.TensorWithBackingFixture(test.backings[1], test.shapes[1]...), } + matmul := matMulVersions[test.version]() res, err := matmul.Apply(inputs) assert.Nil(t, err) assert.Equal(t, test.expected, res[0].Data(), "test number %d", i) @@ -116,7 +133,9 @@ func TestBroadcastTensors(t *testing.T) { } for _, test := range tests { - matmul := &MatMul{} + matmul, ok := matMulVersions[13]().(*MatMul) + assert.True(t, ok) + A := ops.Float32TensorFixture(test.shapes[0]...) B := ops.Float32TensorFixture(test.shapes[1]...) newA, newB, err := matmul.broadcastTensors(A, B) @@ -129,10 +148,12 @@ func TestBroadcastTensors(t *testing.T) { func TestInputValidationMatMul(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]uint32{1, 2}, 2), ops.TensorWithBackingFixture([]uint32{3, 4}, 2), @@ -140,6 +161,7 @@ func TestInputValidationMatMul(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]uint64{1, 2}, 2), ops.TensorWithBackingFixture([]uint64{3, 4}, 2), @@ -147,6 +169,7 @@ func TestInputValidationMatMul(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int32{1, 2}, 2), ops.TensorWithBackingFixture([]int32{3, 4}, 2), @@ -154,6 +177,7 @@ func TestInputValidationMatMul(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int64{1, 2}, 2), ops.TensorWithBackingFixture([]int64{3, 4}, 2), @@ -161,6 +185,7 @@ func TestInputValidationMatMul(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{3, 4}, 2), @@ -168,6 +193,7 @@ func TestInputValidationMatMul(t 
*testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), ops.TensorWithBackingFixture([]float64{3, 4}, 2), @@ -175,22 +201,32 @@ func TestInputValidationMatMul(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputCount(1, &MatMul{}), + ops.ErrInvalidInputCount(1, matmul13BaseOpFixture()), }, { + 1, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + ops.TensorWithBackingFixture([]int32{3, 4}, 2), + }, + ops.ErrInvalidInputType(0, "int32", matmul1BaseOpFixture()), + }, + { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), }, - ops.ErrInvalidInputType(0, "int", &MatMul{}), + ops.ErrInvalidInputType(0, "int", matmul13BaseOpFixture()), }, } for _, test := range tests { - matmul := &MatMul{} + matmul := matMulVersions[test.version]() validated, err := matmul.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -200,3 +236,23 @@ func TestInputValidationMatMul(t *testing.T) { } } } + +func matmul1BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator( + 1, + 2, + 2, + matmul1TypeConstraints, + "matmul", + ) +} + +func matmul13BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator( + 13, + 2, + 2, + matmulTypeConstraints, + "matmul", + ) +} diff --git a/ops/matmul/versions.go b/ops/matmul/versions.go new file mode 100644 index 0000000..11b5a51 --- /dev/null +++ b/ops/matmul/versions.go @@ -0,0 +1,13 @@ +package matmul + +import "github.com/advancedclimatesystems/gonnx/ops" + +var matMulVersions = ops.OperatorVersions{ + 1: ops.NewOperatorConstructor(newMatMul, 1, matmul1TypeConstraints), + 9: ops.NewOperatorConstructor(newMatMul, 9, matmulTypeConstraints), + 13: ops.NewOperatorConstructor(newMatMul, 13, matmulTypeConstraints), +} + +func GetMatMulVersions() ops.OperatorVersions { + return matMulVersions +} diff --git a/ops/mul/mul.go b/ops/mul/mul.go 
new file mode 100644 index 0000000..461c885 --- /dev/null +++ b/ops/mul/mul.go @@ -0,0 +1,45 @@ +package mul + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var mulTypeConstraints = [][]tensor.Dtype{ + {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, + {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, +} + +// Mul represents the ONNX mul operator. +type Mul struct { + ops.BaseOperator +} + +// newMul creates a new mul operator. +func newMul(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Mul{ + BaseOperator: ops.NewBaseOperator( + version, + 2, + 2, + typeConstraints, + "mul", + ), + } +} + +// Init initializes the mul operator. +func (m *Mul) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the mul operator. +func (m *Mul) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ApplyBinaryOperation( + inputs[0], + inputs[1], + ops.Mul, + ops.MultidirectionalBroadcasting, + ) +} diff --git a/ops/opset13/mul_test.go b/ops/mul/mul_test.go similarity index 78% rename from ops/opset13/mul_test.go rename to ops/mul/mul_test.go index e6d00e4..d922d67 100644 --- a/ops/opset13/mul_test.go +++ b/ops/mul/mul_test.go @@ -1,4 +1,4 @@ -package opset13 +package mul import ( "testing" @@ -19,25 +19,25 @@ func TestMulInit(t *testing.T) { func TestMul(t *testing.T) { tests := []struct { - mul *Mul + version int64 backings [][]float32 shapes [][]int expected []float32 }{ { - &Mul{}, + 13, [][]float32{{0, 1, 2, 3}, {1, 1, 1, 1}}, [][]int{{2, 2}, {2, 2}}, []float32{0, 1, 2, 3}, }, { - &Mul{}, + 13, [][]float32{{0, 1, 2, 3, 4, 5}, {2, 2, 2, 2, 2, 2}}, [][]int{{3, 2}, {3, 2}}, []float32{0, 2, 4, 6, 8, 10}, }, { - &Mul{}, + 13, [][]float32{{0, 1}, {0, 1, 2, 3}}, [][]int{{2}, {2, 2}}, []float32{0, 1, 0, 3}, @@ -50,7 +50,8 @@ func TestMul(t 
*testing.T) { ops.TensorWithBackingFixture(test.backings[1], test.shapes[1]...), } - res, err := test.mul.Apply(inputs) + mul := mulVersions[test.version]() + res, err := mul.Apply(inputs) assert.Nil(t, err) assert.Nil(t, err) @@ -75,10 +76,12 @@ func TestMulFail(t *testing.T) { func TestInputValidationMul(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]uint32{1, 2}, 2), ops.TensorWithBackingFixture([]uint32{3, 4}, 2), @@ -86,6 +89,7 @@ func TestInputValidationMul(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]uint64{1, 2}, 2), ops.TensorWithBackingFixture([]uint64{3, 4}, 2), @@ -93,6 +97,7 @@ func TestInputValidationMul(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int32{1, 2}, 2), ops.TensorWithBackingFixture([]int32{3, 4}, 2), @@ -100,6 +105,7 @@ func TestInputValidationMul(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int64{1, 2}, 2), ops.TensorWithBackingFixture([]int64{3, 4}, 2), @@ -107,6 +113,7 @@ func TestInputValidationMul(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{3, 4}, 2), @@ -114,6 +121,7 @@ func TestInputValidationMul(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), ops.TensorWithBackingFixture([]float64{3, 4}, 2), @@ -121,22 +129,31 @@ func TestInputValidationMul(t *testing.T) { nil, }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputCount(1, &Mul{}), + ops.ErrInvalidInputCount(1, mul7BaseOpFixture()), }, { + 13, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputCount(1, mul13BaseOpFixture()), + }, + { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), 
ops.TensorWithBackingFixture([]int{3, 4}, 2), }, - ops.ErrInvalidInputType(0, "int", &Mul{}), + ops.ErrInvalidInputType(0, "int", mul13BaseOpFixture()), }, } for _, test := range tests { - mul := &Mul{} + mul := mulVersions[test.version]() validated, err := mul.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -146,3 +163,11 @@ func TestInputValidationMul(t *testing.T) { } } } + +func mul7BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(7, 2, 2, mulTypeConstraints, "mul") +} + +func mul13BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(13, 2, 2, mulTypeConstraints, "mul") +} diff --git a/ops/mul/versions.go b/ops/mul/versions.go new file mode 100644 index 0000000..09bf698 --- /dev/null +++ b/ops/mul/versions.go @@ -0,0 +1,12 @@ +package mul + +import "github.com/advancedclimatesystems/gonnx/ops" + +var mulVersions = ops.OperatorVersions{ + 7: ops.NewOperatorConstructor(newMul, 7, mulTypeConstraints), + 13: ops.NewOperatorConstructor(newMul, 13, mulTypeConstraints), +} + +func GetMulVersions() ops.OperatorVersions { + return mulVersions +} diff --git a/ops/not/not.go b/ops/not/not.go new file mode 100644 index 0000000..863a099 --- /dev/null +++ b/ops/not/not.go @@ -0,0 +1,46 @@ +package not + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var notTypeConstraints = [][]tensor.Dtype{{tensor.Bool}} + +// Not represents the ONNX not operator. +type Not struct { + ops.BaseOperator +} + +// newNot creates a new not operator. +func newNot(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Not{ + BaseOperator: ops.NewBaseOperator( + version, + 1, + 1, + typeConstraints, + "not", + ), + } +} + +// Init initializes the not operator. +func (n *Not) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the not operator. 
+func (n *Not) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + out, err := inputs[0].Apply(not) + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +func not(x bool) bool { + return !x +} diff --git a/ops/opset13/not_test.go b/ops/not/not_test.go similarity index 75% rename from ops/opset13/not_test.go rename to ops/not/not_test.go index 6069622..ffffc27 100644 --- a/ops/opset13/not_test.go +++ b/ops/not/not_test.go @@ -1,4 +1,4 @@ -package opset13 +package not import ( "testing" @@ -19,25 +19,25 @@ func TestNotInit(t *testing.T) { func TestNot(t *testing.T) { tests := []struct { - not *Not + version int64 backing []bool shape []int expected []bool }{ { - &Not{}, + 1, []bool{true, false, true, false}, []int{2, 2}, []bool{false, true, false, true}, }, { - &Not{}, + 1, []bool{true, true, false, false}, []int{1, 4}, []bool{false, false, true, true}, }, { - &Not{}, + 1, []bool{false, false, false, false}, []int{4, 1}, []bool{true, true, true, true}, @@ -49,7 +49,8 @@ func TestNot(t *testing.T) { ops.TensorWithBackingFixture(test.backing, test.shape...), } - res, err := test.not.Apply(inputs) + not := notVersions[test.version]() + res, err := not.Apply(inputs) assert.Nil(t, err) assert.Nil(t, err) @@ -59,29 +60,33 @@ func TestNot(t *testing.T) { func TestInputValidationNot(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 1, []tensor.Tensor{ ops.TensorWithBackingFixture([]bool{false, false}, 2), }, nil, }, { + 1, []tensor.Tensor{}, - ops.ErrInvalidInputCount(0, &Not{}), + ops.ErrInvalidInputCount(0, not1BaseOpFixture()), }, { + 1, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputType(0, "int", &Not{}), + ops.ErrInvalidInputType(0, "int", not1BaseOpFixture()), }, } for _, test := range tests { - not := &Not{} + not := notVersions[test.version]() validated, err := not.ValidateInputs(test.inputs) assert.Equal(t, 
test.err, err) @@ -91,3 +96,13 @@ func TestInputValidationNot(t *testing.T) { } } } + +func not1BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator( + 1, + 1, + 1, + notTypeConstraints, + "not", + ) +} diff --git a/ops/not/versions.go b/ops/not/versions.go new file mode 100644 index 0000000..0b6769b --- /dev/null +++ b/ops/not/versions.go @@ -0,0 +1,11 @@ +package not + +import "github.com/advancedclimatesystems/gonnx/ops" + +var notVersions = ops.OperatorVersions{ + 1: ops.NewOperatorConstructor(newNot, 1, notTypeConstraints), +} + +func GetNotVersions() ops.OperatorVersions { + return notVersions +} diff --git a/ops/operator.go b/ops/operator.go index 7f26e4d..1d8a9b9 100644 --- a/ops/operator.go +++ b/ops/operator.go @@ -5,6 +5,18 @@ import ( "gorgonia.org/tensor" ) +type OperatorVersions map[int64]OperatorFactory + +type OperatorFactory func() Operator + +type Constructor func(int, [][]tensor.Dtype) Operator + +func NewOperatorConstructor(fn Constructor, version int, typeContstraint [][]tensor.Dtype) OperatorFactory { + return func() Operator { + return fn(version, typeContstraint) + } +} + // Operator is the base interface for all operators. type Operator interface { // String should return a simple string describing the operator @@ -33,4 +45,7 @@ type Operator interface { // ValidateInputs should validate the list of input tensors. It should check for both // the right amount of inputs and the correct dtypes of the tensors. ValidateInputs([]tensor.Tensor) ([]tensor.Tensor, error) + + // Version returns the version of this operator. 
+ Version() int } diff --git a/ops/opset13/abs.go b/ops/opset13/abs.go deleted file mode 100644 index 482d80e..0000000 --- a/ops/opset13/abs.go +++ /dev/null @@ -1,63 +0,0 @@ -package opset13 - -import ( - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -const ( - MinAbsInputs = 1 - MaxAbsInputs = 1 -) - -// Abs represents the ONNX abs operator. -type Abs struct{} - -// newAbs creates a new abs operator. -func newAbs() ops.Operator { - return &Abs{} -} - -// Init initializes the abs operator. -func (a *Abs) Init(*onnx.NodeProto) error { - return nil -} - -// Apply applies the abs operator. -func (a *Abs) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - out, err := tensor.Abs(inputs[0]) - if err != nil { - return nil, err - } - - return []tensor.Tensor{out}, nil -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (a *Abs) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(a, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (a *Abs) GetMinInputs() int { - return MinAbsInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (a *Abs) GetMaxInputs() int { - return MaxAbsInputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (a *Abs) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ - {tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64, tensor.Int8, tensor.Int16, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, - } -} - -// String implements the stringer interface, and can be used to format errors or messages. 
-func (a *Abs) String() string { - return "abs operator" -} diff --git a/ops/opset13/acos.go b/ops/opset13/acos.go deleted file mode 100644 index 139c1ed..0000000 --- a/ops/opset13/acos.go +++ /dev/null @@ -1,75 +0,0 @@ -package opset13 - -import ( - "math" - - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -// Acos represents the ONNX acos operator. -type Acos struct{} - -// newAcos creates a new acos operator. -func newAcos() ops.Operator { - return &Acos{} -} - -// Init initializes the acos operator. -func (c *Acos) Init(*onnx.NodeProto) error { - return nil -} - -// Apply applies the acos operator. -func (c *Acos) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - var ( - out tensor.Tensor - err error - ) - - switch inputs[0].Dtype() { - case tensor.Float32: - out, err = inputs[0].Apply(acos[float32]) - case tensor.Float64: - out, err = inputs[0].Apply(acos[float64]) - default: - return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), c) - } - - if err != nil { - return nil, err - } - - return []tensor.Tensor{out}, nil -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (c *Acos) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(c, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (c *Acos) GetMinInputs() int { - return 1 -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (c *Acos) GetMaxInputs() int { - return 1 -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. 
-func (c *Acos) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (c *Acos) String() string { - return "acos operator" -} - -func acos[T ops.FloatType](x T) T { - return T(math.Acos(float64(x))) -} diff --git a/ops/opset13/acosh.go b/ops/opset13/acosh.go deleted file mode 100644 index 0e1404c..0000000 --- a/ops/opset13/acosh.go +++ /dev/null @@ -1,75 +0,0 @@ -package opset13 - -import ( - "math" - - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -// Acosh represents the ONNX acosh operator. -type Acosh struct{} - -// newAcosh creates a new acosh operator. -func newAcosh() ops.Operator { - return &Acosh{} -} - -// Init initializes the acosh operator. -func (c *Acosh) Init(*onnx.NodeProto) error { - return nil -} - -// Apply applies the acosh operator. -func (c *Acosh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - var ( - out tensor.Tensor - err error - ) - - switch inputs[0].Dtype() { - case tensor.Float32: - out, err = inputs[0].Apply(acosh[float32]) - case tensor.Float64: - out, err = inputs[0].Apply(acosh[float64]) - default: - return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), c) - } - - if err != nil { - return nil, err - } - - return []tensor.Tensor{out}, nil -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (c *Acosh) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(c, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (c *Acosh) GetMinInputs() int { - return 1 -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. 
-func (c *Acosh) GetMaxInputs() int { - return 1 -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (c *Acosh) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (c *Acosh) String() string { - return "acosh operator" -} - -func acosh[T ops.FloatType](x T) T { - return T(math.Acosh(float64(x))) -} diff --git a/ops/opset13/add.go b/ops/opset13/add.go deleted file mode 100644 index cf5b566..0000000 --- a/ops/opset13/add.go +++ /dev/null @@ -1,64 +0,0 @@ -package opset13 - -import ( - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -const ( - MinAddInputs = 2 - MaxAddInputs = 2 -) - -// Add represents the ONNX add operator. -type Add struct{} - -// newAdd creates a new add operator. -func newAdd() ops.Operator { - return &Add{} -} - -// Init initializes the add operator. -func (a *Add) Init(*onnx.NodeProto) error { - return nil -} - -// Apply applies the add operator. -func (a *Add) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ApplyBinaryOperation( - inputs[0], - inputs[1], - ops.Add, - ops.MultidirectionalBroadcasting, - ) -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (a *Add) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(a, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (a *Add) GetMinInputs() int { - return MinAddInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (a *Add) GetMaxInputs() int { - return MaxAddInputs -} - -// GetInputTypeConstraints returns a list. 
Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (a *Add) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ - {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, - {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, - } -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (a *Add) String() string { - return "add operator" -} diff --git a/ops/opset13/and.go b/ops/opset13/and.go deleted file mode 100644 index 68b2a22..0000000 --- a/ops/opset13/and.go +++ /dev/null @@ -1,61 +0,0 @@ -package opset13 - -import ( - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -var ( - MinAndInputs = 2 - MaxAndInputs = 2 -) - -// And represents the ONNX and operator. -type And struct{} - -// newAnd creates a new and operator. -func newAnd() ops.Operator { - return &And{} -} - -// Init initializes the and operator. -func (a *And) Init(*onnx.NodeProto) error { - return nil -} - -// Apply applies the and operator. -func (a *And) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ApplyBinaryOperation( - inputs[0], - inputs[1], - ops.And, - ops.MultidirectionalBroadcasting, - ) -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (a *And) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(a, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (a *And) GetMinInputs() int { - return MinAndInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (a *And) GetMaxInputs() int { - return MaxAndInputs -} - -// GetInputTypeConstraints returns a list. 
Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (a *And) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{{tensor.Bool}, {tensor.Bool}} -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (a *And) String() string { - return "and operator" -} diff --git a/ops/opset13/asin.go b/ops/opset13/asin.go deleted file mode 100644 index 0dae65f..0000000 --- a/ops/opset13/asin.go +++ /dev/null @@ -1,75 +0,0 @@ -package opset13 - -import ( - "math" - - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -// Asin represents the ONNX asin operator. -type Asin struct{} - -// newSin creates a new asin operator. -func newAsin() ops.Operator { - return &Asin{} -} - -// Init initializes the asin operator. -func (s *Asin) Init(*onnx.NodeProto) error { - return nil -} - -// Apply applies the asin operator. -func (s *Asin) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - var ( - out tensor.Tensor - err error - ) - - switch inputs[0].Dtype() { - case tensor.Float32: - out, err = inputs[0].Apply(asin[float32]) - case tensor.Float64: - out, err = inputs[0].Apply(asin[float64]) - default: - return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), s) - } - - if err != nil { - return nil, err - } - - return []tensor.Tensor{out}, nil -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (s *Asin) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(s, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (s *Asin) GetMinInputs() int { - return 1 -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (s *Asin) GetMaxInputs() int { - return 1 -} - -// GetInputTypeConstraints returns a list. 
Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (s *Asin) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (s *Asin) String() string { - return "asin operator" -} - -func asin[T ops.FloatType](x T) T { - return T(math.Asin(float64(x))) -} diff --git a/ops/opset13/asinh.go b/ops/opset13/asinh.go deleted file mode 100644 index 8490711..0000000 --- a/ops/opset13/asinh.go +++ /dev/null @@ -1,75 +0,0 @@ -package opset13 - -import ( - "math" - - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -// Asinh represents the ONNX asinh operator. -type Asinh struct{} - -// newAsinh creates a new asinh operator. -func newAsinh() ops.Operator { - return &Asinh{} -} - -// Init initializes the asinh operator. -func (a *Asinh) Init(*onnx.NodeProto) error { - return nil -} - -// Apply applies the asinh operator. -func (a *Asinh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - var ( - out tensor.Tensor - err error - ) - - switch inputs[0].Dtype() { - case tensor.Float32: - out, err = inputs[0].Apply(asinh[float32]) - case tensor.Float64: - out, err = inputs[0].Apply(asinh[float64]) - default: - return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), a) - } - - if err != nil { - return nil, err - } - - return []tensor.Tensor{out}, nil -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (a *Asinh) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(a, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (a *Asinh) GetMinInputs() int { - return 1 -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. 
-func (a *Asinh) GetMaxInputs() int { - return 1 -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (a *Asinh) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (a *Asinh) String() string { - return "asinh operator" -} - -func asinh[T ops.FloatType](x T) T { - return T(math.Asinh(float64(x))) -} diff --git a/ops/opset13/atan.go b/ops/opset13/atan.go deleted file mode 100644 index d373d65..0000000 --- a/ops/opset13/atan.go +++ /dev/null @@ -1,75 +0,0 @@ -package opset13 - -import ( - "math" - - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -// Atan represents the ONNX atan operator. -type Atan struct{} - -// newAtan creates a new atan operator. -func newAtan() ops.Operator { - return &Atan{} -} - -// Init initializes the atan operator. -func (a *Atan) Init(*onnx.NodeProto) error { - return nil -} - -// Apply applies the atan operator. -func (a *Atan) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - var ( - out tensor.Tensor - err error - ) - - switch inputs[0].Dtype() { - case tensor.Float32: - out, err = inputs[0].Apply(atan[float32]) - case tensor.Float64: - out, err = inputs[0].Apply(atan[float64]) - default: - return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), a) - } - - if err != nil { - return nil, err - } - - return []tensor.Tensor{out}, nil -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (a *Atan) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(a, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. 
-func (a *Atan) GetMinInputs() int { - return 1 -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (a *Atan) GetMaxInputs() int { - return 1 -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (a *Atan) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (a *Atan) String() string { - return "atan operator" -} - -func atan[T ops.FloatType](x T) T { - return T(math.Atan(float64(x))) -} diff --git a/ops/opset13/atanh.go b/ops/opset13/atanh.go deleted file mode 100644 index f60b6d1..0000000 --- a/ops/opset13/atanh.go +++ /dev/null @@ -1,75 +0,0 @@ -package opset13 - -import ( - "math" - - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -// Atanh represents the ONNX atanh operator. -type Atanh struct{} - -// newAtanh creates a new atanh operator. -func newAtanh() ops.Operator { - return &Atanh{} -} - -// Init initializes the atanh operator. -func (a *Atanh) Init(*onnx.NodeProto) error { - return nil -} - -// Apply applies the atanh operator. -func (a *Atanh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - var ( - out tensor.Tensor - err error - ) - - switch inputs[0].Dtype() { - case tensor.Float32: - out, err = inputs[0].Apply(atanh[float32]) - case tensor.Float64: - out, err = inputs[0].Apply(atanh[float64]) - default: - return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), a) - } - - if err != nil { - return nil, err - } - - return []tensor.Tensor{out}, nil -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. 
-func (a *Atanh) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(a, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (a *Atanh) GetMinInputs() int { - return 1 -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (a *Atanh) GetMaxInputs() int { - return 1 -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (a *Atanh) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (a *Atanh) String() string { - return "atanh operator" -} - -func atanh[T ops.FloatType](x T) T { - return T(math.Atanh(float64(x))) -} diff --git a/ops/opset13/cast.go b/ops/opset13/cast.go deleted file mode 100644 index 8a8a552..0000000 --- a/ops/opset13/cast.go +++ /dev/null @@ -1,81 +0,0 @@ -package opset13 - -import ( - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -const ( - MinCastInputs = 1 - MaxCastInputs = 1 -) - -// Cast represents the ONNX cast operator. -type Cast struct { - to int32 // DataType to cast to, as defined by TensorProto -} - -// newCast creates a new cast operator. -func newCast() ops.Operator { - return &Cast{} -} - -// Init initializes the cast operator. -func (c *Cast) Init(n *onnx.NodeProto) error { - attributes := n.GetAttribute() - - if len(attributes) != 1 { - return ops.ErrInvalidAttributeCount(1, len(attributes), c) - } - - attr := attributes[0] - if attr.GetName() == "to" { - c.to = int32(attr.GetI()) - } else { - return ops.ErrInvalidAttribute(attr.GetName(), c) - } - - return nil -} - -// Apply applies the cast operator. 
-func (c *Cast) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - out, err := ops.ConvertTensorDtype(inputs[0], c.to) - if err != nil { - return nil, err - } - - return []tensor.Tensor{out}, nil -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (c *Cast) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(c, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (c *Cast) GetMinInputs() int { - return MinCastInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (c *Cast) GetMaxInputs() int { - return MaxCastInputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (c *Cast) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ - { - tensor.Int16, tensor.Uint16, tensor.Int32, tensor.Uint32, tensor.Int64, tensor.Uint64, - tensor.Float32, tensor.Float64, - }, - } -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (c *Cast) String() string { - return "cast operator" -} diff --git a/ops/opset13/cos.go b/ops/opset13/cos.go deleted file mode 100644 index ad01f82..0000000 --- a/ops/opset13/cos.go +++ /dev/null @@ -1,75 +0,0 @@ -package opset13 - -import ( - "math" - - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -// Cos represents the ONNX cos operator. -type Cos struct{} - -// newCos creates a new cos operator. -func newCos() ops.Operator { - return &Cos{} -} - -// Init initializes the cos operator. -func (c *Cos) Init(*onnx.NodeProto) error { - return nil -} - -// Apply applies the cos operator. 
-func (c *Cos) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - var ( - out tensor.Tensor - err error - ) - - switch inputs[0].Dtype() { - case tensor.Float32: - out, err = inputs[0].Apply(cos[float32]) - case tensor.Float64: - out, err = inputs[0].Apply(cos[float64]) - default: - return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), c) - } - - if err != nil { - return nil, err - } - - return []tensor.Tensor{out}, nil -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (c *Cos) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(c, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (c *Cos) GetMinInputs() int { - return 1 -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (c *Cos) GetMaxInputs() int { - return 1 -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (c *Cos) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (c *Cos) String() string { - return "cos operator" -} - -func cos[T ops.FloatType](x T) T { - return T(math.Cos(float64(x))) -} diff --git a/ops/opset13/cosh.go b/ops/opset13/cosh.go deleted file mode 100644 index cddb129..0000000 --- a/ops/opset13/cosh.go +++ /dev/null @@ -1,75 +0,0 @@ -package opset13 - -import ( - "math" - - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -// Cosh represents the ONNX cosh operator. -type Cosh struct{} - -// newCosh creates a new cosh operator. -func newCosh() ops.Operator { - return &Cosh{} -} - -// Init initializes the cosh operator. 
-func (c *Cosh) Init(*onnx.NodeProto) error { - return nil -} - -// Apply applies the cosh operator. -func (c *Cosh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - var ( - out tensor.Tensor - err error - ) - - switch inputs[0].Dtype() { - case tensor.Float32: - out, err = inputs[0].Apply(cosh[float32]) - case tensor.Float64: - out, err = inputs[0].Apply(cosh[float64]) - default: - return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), c) - } - - if err != nil { - return nil, err - } - - return []tensor.Tensor{out}, nil -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (c *Cosh) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(c, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (c *Cosh) GetMinInputs() int { - return 1 -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (c *Cosh) GetMaxInputs() int { - return 1 -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (c *Cosh) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (c *Cosh) String() string { - return "cosh operator" -} - -func cosh[T ops.FloatType](x T) T { - return T(math.Cosh(float64(x))) -} diff --git a/ops/opset13/div.go b/ops/opset13/div.go deleted file mode 100644 index e918e7f..0000000 --- a/ops/opset13/div.go +++ /dev/null @@ -1,64 +0,0 @@ -package opset13 - -import ( - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -const ( - MinDivInputs = 2 - MaxDivInputs = 2 -) - -// Div represents the ONNX div operator. 
-type Div struct{} - -// newDiv creates a new div operator. -func newDiv() ops.Operator { - return &Div{} -} - -// Init initializes the div operator. -func (d *Div) Init(*onnx.NodeProto) error { - return nil -} - -// Apply applies the div operator. -func (d *Div) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ApplyBinaryOperation( - inputs[0], - inputs[1], - ops.Div, - ops.MultidirectionalBroadcasting, - ) -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (d *Div) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(d, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (d *Div) GetMinInputs() int { - return MinDivInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (d *Div) GetMaxInputs() int { - return MaxDivInputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (d *Div) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ - {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, - {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, - } -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (d *Div) String() string { - return "div operator" -} diff --git a/ops/opset13/equal.go b/ops/opset13/equal.go deleted file mode 100644 index db888b8..0000000 --- a/ops/opset13/equal.go +++ /dev/null @@ -1,61 +0,0 @@ -package opset13 - -import ( - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -var ( - MinEqualInputs = 2 - MaxEqualInputs = 2 -) - -// Equal represents the ONNX equal operator. 
-type Equal struct{} - -// newEqual creates a new equal operator. -func newEqual() ops.Operator { - return &Equal{} -} - -// Init initializes the equal operator. -func (e *Equal) Init(*onnx.NodeProto) error { - return nil -} - -// Apply applies the equal operator. -func (e *Equal) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ApplyBinaryOperation( - inputs[0], - inputs[1], - ops.Equal, - ops.MultidirectionalBroadcasting, - ) -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (e *Equal) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(e, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (e *Equal) GetMinInputs() int { - return MinEqualInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (e *Equal) GetMaxInputs() int { - return MaxEqualInputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (e *Equal) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ops.AllTypes, ops.AllTypes} -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (e *Equal) String() string { - return "equal operator" -} diff --git a/ops/opset13/flatten.go b/ops/opset13/flatten.go deleted file mode 100644 index 50e9039..0000000 --- a/ops/opset13/flatten.go +++ /dev/null @@ -1,95 +0,0 @@ -package opset13 - -import ( - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -const ( - MinFlattenInputs = 1 - MaxFlattenInputs = 1 -) - -// Flatten represents the ONNX flatten operator. -type Flatten struct { - axis int -} - -// newFlatten creates a new flatten operator. 
-func newFlatten() ops.Operator { - return &Flatten{ - axis: 1, - } -} - -// Init initializes the flatten operator. -func (f *Flatten) Init(n *onnx.NodeProto) error { - for _, attr := range n.GetAttribute() { - switch attr.GetName() { - case "axis": - f.axis = int(attr.GetI()) - default: - return ops.ErrInvalidAttribute(attr.GetName(), f) - } - } - - return nil -} - -// Apply applies the flatten operator. -func (f *Flatten) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - inputShape := inputs[0].Shape() - rank := len(inputShape) - - axis := f.axis - if axis < 0 { - axis = rank + axis - } - - out, ok := inputs[0].Clone().(tensor.Tensor) - if !ok { - return nil, ops.ErrTypeAssert("tensor.Tensor", inputs[0].Clone()) - } - - var err error - // In the special case where axis is 0, we reshape the tensor to shape - // (1, ). This is ONNX defined behaviour. - if axis == 0 { - err = out.Reshape(1, ops.NElements(inputShape...)) - } else { - err = out.Reshape(ops.NElements(inputShape[:axis]...), ops.NElements(inputShape[axis:]...)) - } - - if err != nil { - return nil, err - } - - return []tensor.Tensor{out}, nil -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (f *Flatten) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(f, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (f *Flatten) GetMinInputs() int { - return MinFlattenInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (f *Flatten) GetMaxInputs() int { - return MaxFlattenInputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. 
-func (f *Flatten) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ops.AllTypes} -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (f *Flatten) String() string { - return "flatten operator" -} diff --git a/ops/opset13/flatten_test.go b/ops/opset13/flatten_test.go deleted file mode 100644 index 4a750e1..0000000 --- a/ops/opset13/flatten_test.go +++ /dev/null @@ -1,162 +0,0 @@ -package opset13 - -import ( - "testing" - - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "github.com/stretchr/testify/assert" - "gorgonia.org/tensor" -) - -func TestFlattenInit(t *testing.T) { - f := &Flatten{} - - err := f.Init(&onnx.NodeProto{Attribute: []*onnx.AttributeProto{{Name: "axis", I: 2}}}) - assert.Nil(t, err) - - assert.Equal(t, 2, f.axis) -} - -func TestFlatten(t *testing.T) { - tests := []struct { - flatten *Flatten - backing []float32 - shape []int - expectedShape tensor.Shape - }{ - { - &Flatten{}, - []float32{0, 1, 2, 3}, - []int{2, 2}, - []int{1, 4}, - }, - { - &Flatten{}, - []float32{0, 1, 2, 3, 4, 5}, - []int{2, 3}, - []int{1, 6}, - }, - { - &Flatten{axis: 1}, - []float32{0, 1, 2, 3, 4, 5, 6, 7}, - []int{2, 2, 2}, - []int{2, 4}, - }, - { - &Flatten{axis: 2}, - []float32{0, 1, 2, 3, 4, 5, 6, 7}, - []int{2, 2, 2}, - []int{4, 2}, - }, - { - &Flatten{axis: -1}, - []float32{0, 1, 2, 3, 4, 5, 6, 7}, - []int{2, 2, 2}, - []int{4, 2}, - }, - { - &Flatten{axis: -2}, - []float32{0, 1, 2, 3, 4, 5, 6, 7}, - []int{2, 2, 2}, - []int{2, 4}, - }, - { - &Flatten{axis: -3}, - []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, - []int{3, 2, 3}, - []int{1, 18}, - }, - { - &Flatten{axis: 2}, - []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, - []int{3, 2, 3}, - []int{6, 3}, - }, - { - &Flatten{axis: 1}, - []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, - []int{3, 2, 3}, - 
[]int{3, 6}, - }, - } - - for _, test := range tests { - inputs := []tensor.Tensor{ - ops.TensorWithBackingFixture(test.backing, test.shape...), - } - - res, err := test.flatten.Apply(inputs) - assert.Nil(t, err) - - assert.Equal(t, test.expectedShape, res[0].Shape()) - } -} - -func TestInputValidationFlatten(t *testing.T) { - tests := []struct { - inputs []tensor.Tensor - err error - }{ - { - []tensor.Tensor{ - ops.TensorWithBackingFixture([]uint32{1, 2}, 2), - }, - nil, - }, - { - []tensor.Tensor{ - ops.TensorWithBackingFixture([]uint64{1, 2}, 2), - }, - nil, - }, - { - []tensor.Tensor{ - ops.TensorWithBackingFixture([]int32{1, 2}, 2), - }, - nil, - }, - { - []tensor.Tensor{ - ops.TensorWithBackingFixture([]int64{1, 2}, 2), - }, - nil, - }, - { - []tensor.Tensor{ - ops.TensorWithBackingFixture([]float32{1, 2}, 2), - }, - nil, - }, - { - []tensor.Tensor{ - ops.TensorWithBackingFixture([]float64{1, 2}, 2), - }, - nil, - }, - { - []tensor.Tensor{ - ops.TensorWithBackingFixture([]float32{1, 2}, 2), - ops.TensorWithBackingFixture([]float32{1, 2}, 2), - }, - ops.ErrInvalidInputCount(2, &Flatten{}), - }, - { - []tensor.Tensor{ - ops.TensorWithBackingFixture([]int{1, 2}, 2), - }, - ops.ErrInvalidInputType(0, "int", &Flatten{}), - }, - } - - for _, test := range tests { - flatten := &Flatten{} - validated, err := flatten.ValidateInputs(test.inputs) - - assert.Equal(t, test.err, err) - - if test.err == nil { - assert.Equal(t, test.inputs, validated) - } - } -} diff --git a/ops/opset13/greater.go b/ops/opset13/greater.go deleted file mode 100644 index 37e5af4..0000000 --- a/ops/opset13/greater.go +++ /dev/null @@ -1,61 +0,0 @@ -package opset13 - -import ( - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -var ( - MinGreaterInputs = 2 - MaxGreaterInputs = 2 -) - -// Greater represents the ONNX greater operator. -type Greater struct{} - -// newGreater creates a new greater operator. 
-func newGreater() ops.Operator { - return &Greater{} -} - -// Init initializes the greater operator. -func (g *Greater) Init(*onnx.NodeProto) error { - return nil -} - -// Apply applies the greater operator. -func (g *Greater) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ApplyBinaryOperation( - inputs[0], - inputs[1], - ops.Gt, - ops.MultidirectionalBroadcasting, - ) -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (g *Greater) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(g, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (g *Greater) GetMinInputs() int { - return MinGreaterInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (g *Greater) GetMaxInputs() int { - return MaxGreaterInputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (g *Greater) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ops.AllTypes, ops.AllTypes} -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (g *Greater) String() string { - return "greater operator" -} diff --git a/ops/opset13/greater_or_equal.go b/ops/opset13/greater_or_equal.go deleted file mode 100644 index 25eb27b..0000000 --- a/ops/opset13/greater_or_equal.go +++ /dev/null @@ -1,61 +0,0 @@ -package opset13 - -import ( - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -var ( - MinGreaterOrEqualInputs = 2 - MaxGreaterOrEqualInputs = 2 -) - -// GreaterOrEqual represents the ONNX greaterOrEqual operator. -type GreaterOrEqual struct{} - -// newGreaterOrEqual creates a new greaterOrEqual operator. 
-func newGreaterOrEqual() ops.Operator { - return &GreaterOrEqual{} -} - -// Init initializes the greaterOrEqual operator. -func (g *GreaterOrEqual) Init(*onnx.NodeProto) error { - return nil -} - -// Apply applies the greaterOrEqual operator. -func (g *GreaterOrEqual) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ApplyBinaryOperation( - inputs[0], - inputs[1], - ops.Gte, - ops.MultidirectionalBroadcasting, - ) -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (g *GreaterOrEqual) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(g, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (g *GreaterOrEqual) GetMinInputs() int { - return MinGreaterOrEqualInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (g *GreaterOrEqual) GetMaxInputs() int { - return MaxGreaterOrEqualInputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (g *GreaterOrEqual) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ops.AllTypes, ops.AllTypes} -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (g *GreaterOrEqual) String() string { - return "greaterOrEqual operator" -} diff --git a/ops/opset13/less.go b/ops/opset13/less.go deleted file mode 100644 index d8e271d..0000000 --- a/ops/opset13/less.go +++ /dev/null @@ -1,61 +0,0 @@ -package opset13 - -import ( - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -var ( - MinLessInputs = 2 - MaxLessInputs = 2 -) - -// Less represents the ONNX less operator. -type Less struct{} - -// newLess creates a new less operator. 
-func newLess() ops.Operator { - return &Less{} -} - -// Init initializes the less operator. -func (l *Less) Init(*onnx.NodeProto) error { - return nil -} - -// Apply applies the less operator. -func (l *Less) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ApplyBinaryOperation( - inputs[0], - inputs[1], - ops.Lt, - ops.MultidirectionalBroadcasting, - ) -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (l *Less) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(l, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (l *Less) GetMinInputs() int { - return MinLessInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (l *Less) GetMaxInputs() int { - return MaxLessInputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (l *Less) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ops.AllTypes, ops.AllTypes} -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (l *Less) String() string { - return "less operator" -} diff --git a/ops/opset13/less_or_equal.go b/ops/opset13/less_or_equal.go deleted file mode 100644 index 3fcb85f..0000000 --- a/ops/opset13/less_or_equal.go +++ /dev/null @@ -1,61 +0,0 @@ -package opset13 - -import ( - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -var ( - MinLessOrEqualInputs = 2 - MaxLessOrEqualInputs = 2 -) - -// LessOrEqual represents the ONNX lessOrEqual operator. -type LessOrEqual struct{} - -// newLessOrEqual creates a new lessOrEqual operator. -func newLessOrEqual() ops.Operator { - return &LessOrEqual{} -} - -// Init initializes the lessOrEqual operator. 
-func (l *LessOrEqual) Init(*onnx.NodeProto) error { - return nil -} - -// Apply applies the lessOrEqual operator. -func (l *LessOrEqual) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ApplyBinaryOperation( - inputs[0], - inputs[1], - ops.Lte, - ops.MultidirectionalBroadcasting, - ) -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (l *LessOrEqual) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(l, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (l *LessOrEqual) GetMinInputs() int { - return MinLessOrEqualInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (l *LessOrEqual) GetMaxInputs() int { - return MaxLessOrEqualInputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (l *LessOrEqual) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ops.AllTypes, ops.AllTypes} -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (l *LessOrEqual) String() string { - return "lessOrEqual operator" -} diff --git a/ops/opset13/mul.go b/ops/opset13/mul.go deleted file mode 100644 index 3d4db10..0000000 --- a/ops/opset13/mul.go +++ /dev/null @@ -1,64 +0,0 @@ -package opset13 - -import ( - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -const ( - MinMulInputs = 2 - MaxMulInputs = 2 -) - -// Mul represents the ONNX mul operator. -type Mul struct{} - -// newMul creates a new mul operator. -func newMul() ops.Operator { - return &Mul{} -} - -// Init initializes the mul operator. -func (m *Mul) Init(*onnx.NodeProto) error { - return nil -} - -// Apply applies the mul operator. 
-func (m *Mul) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ApplyBinaryOperation( - inputs[0], - inputs[1], - ops.Mul, - ops.MultidirectionalBroadcasting, - ) -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (m *Mul) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(m, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (m *Mul) GetMinInputs() int { - return MinMulInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (m *Mul) GetMaxInputs() int { - return MaxMulInputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (m *Mul) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ - {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, - {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, - } -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (m *Mul) String() string { - return "mul operator" -} diff --git a/ops/opset13/not.go b/ops/opset13/not.go deleted file mode 100644 index ba69c56..0000000 --- a/ops/opset13/not.go +++ /dev/null @@ -1,60 +0,0 @@ -package opset13 - -import ( - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -// Not represents the ONNX not operator. -type Not struct{} - -// newNot creates a new not operator. -func newNot() ops.Operator { - return &Not{} -} - -// Init initializes the not operator. -func (n *Not) Init(*onnx.NodeProto) error { - return nil -} - -// Apply applies the not operator. 
-func (n *Not) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - out, err := inputs[0].Apply(not) - if err != nil { - return nil, err - } - - return []tensor.Tensor{out}, nil -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (n *Not) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(n, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (n *Not) GetMinInputs() int { - return 1 -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (n *Not) GetMaxInputs() int { - return 1 -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (n *Not) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{{tensor.Bool}} -} - -// String implements the stringer interface, and can be used to format errors or messages. 
-func (n *Not) String() string { - return "not operator" -} - -func not(x bool) bool { - return !x -} diff --git a/ops/opset13/opset13.go b/ops/opset13/opset13.go deleted file mode 100644 index 3930ab3..0000000 --- a/ops/opset13/opset13.go +++ /dev/null @@ -1,83 +0,0 @@ -package opset13 - -import ( - "github.com/advancedclimatesystems/gonnx/ops" -) - -var operators13 = map[string]func() ops.Operator{ - "Abs": newAbs, - "Acos": newAcos, - "Acosh": newAcosh, - "Add": newAdd, - "And": newAnd, - "ArgMax": newArgMax, - "Asin": newAsin, - "Asinh": newAsinh, - "Atan": newAtan, - "Atanh": newAtanh, - "Cast": newCast, - "Concat": newConcat, - "Constant": newConstant, - "ConstantOfShape": newConstantOfShape, - "Conv": newConv, - "Cos": newCos, - "Cosh": newCosh, - "Div": newDiv, - "Equal": newEqual, - "Expand": newExpand, - "Flatten": newFlatten, - "Gather": newGather, - "Gemm": newGemm, - "Greater": newGreater, - "GreaterOrEqual": newGreaterOrEqual, - "GRU": newGRU, - "Less": newLess, - "LessOrEqual": newLessOrEqual, - "LinearRegressor": newLinearRegressor, - "LogSoftmax": newLogSoftmax, - "LSTM": newLSTM, - "MatMul": newMatMul, - "Mul": newMul, - "Not": newNot, - "Or": newOr, - "PRelu": newPRelu, - "ReduceMax": newReduceMax, - "ReduceMin": newReduceMin, - "Relu": newRelu, - "Reshape": newReshape, - "RNN": newRNN, - "Scaler": newScaler, - "Shape": newShape, - "Sigmoid": newSigmoid, - "Sin": newSin, - "Sinh": newSinh, - "Slice": newSlice, - "Softmax": newSoftmax, - "Squeeze": newSqueeze, - "Sub": newSub, - "Tan": newTan, - "Tanh": newTanh, - "Transpose": newTranspose, - "Unsqueeze": newUnsqueeze, - "Xor": newXor, -} - -// GetOperator maps strings as found in the ModelProto to Operators from opset 13. -func GetOperator(operatorType string) (ops.Operator, error) { - if opInit, ok := operators13[operatorType]; ok { - return opInit(), nil - } - - return nil, ops.ErrUnknownOperatorType(operatorType) -} - -// GetOpNames returns a list with operator names for opset 13. 
-func GetOpNames() []string { - opList := make([]string, 0, len(operators13)) - - for opName := range operators13 { - opList = append(opList, opName) - } - - return opList -} diff --git a/ops/opset13/opset13_test.go b/ops/opset13/opset13_test.go deleted file mode 100644 index 01008ee..0000000 --- a/ops/opset13/opset13_test.go +++ /dev/null @@ -1,305 +0,0 @@ -package opset13 - -import ( - "testing" - - "github.com/advancedclimatesystems/gonnx/ops" - "github.com/stretchr/testify/assert" -) - -func TestGetOperator(t *testing.T) { - tests := []struct { - opType string - expected ops.Operator - err error - }{ - { - "Abs", - newAbs(), - nil, - }, - { - "Acos", - newAcos(), - nil, - }, - { - "Acosh", - newAcosh(), - nil, - }, - { - "Add", - newAdd(), - nil, - }, - { - "And", - newAnd(), - nil, - }, - { - "ArgMax", - newArgMax(), - nil, - }, - - { - "Asin", - newAsin(), - nil, - }, - { - "Asinh", - newAsinh(), - nil, - }, - { - "Atan", - newAtan(), - nil, - }, - { - "Atanh", - newAtanh(), - nil, - }, - { - "Cast", - newCast(), - nil, - }, - { - "Concat", - newConcat(), - nil, - }, - { - "Constant", - newConstant(), - nil, - }, - { - "ConstantOfShape", - newConstantOfShape(), - nil, - }, - { - "Conv", - newConv(), - nil, - }, - { - "Cos", - newCos(), - nil, - }, - { - "Cosh", - newCosh(), - nil, - }, - { - "Div", - newDiv(), - nil, - }, - { - "Equal", - newEqual(), - nil, - }, - { - "Expand", - newExpand(), - nil, - }, - { - "Flatten", - newFlatten(), - nil, - }, - { - "Gather", - newGather(), - nil, - }, - { - "Gemm", - newGemm(), - nil, - }, - { - "Greater", - newGreater(), - nil, - }, - { - "GreaterOrEqual", - newGreaterOrEqual(), - nil, - }, - { - "GRU", - newGRU(), - nil, - }, - { - "Less", - newLess(), - nil, - }, - { - "LessOrEqual", - newLessOrEqual(), - nil, - }, - { - "LinearRegressor", - newLinearRegressor(), - nil, - }, - { - "LogSoftmax", - newLogSoftmax(), - nil, - }, - { - "LSTM", - newLSTM(), - nil, - }, - { - "MatMul", - newMatMul(), - nil, 
- }, - { - "Mul", - newMul(), - nil, - }, - { - "Not", - newNot(), - nil, - }, - { - "Or", - newOr(), - nil, - }, - { - "PRelu", - newPRelu(), - nil, - }, - { - "ReduceMax", - newReduceMax(), - nil, - }, - { - "ReduceMin", - newReduceMin(), - nil, - }, - { - "Relu", - newRelu(), - nil, - }, - { - "Reshape", - newReshape(), - nil, - }, - { - "RNN", - newRNN(), - nil, - }, - { - "Scaler", - newScaler(), - nil, - }, - { - "Shape", - newShape(), - nil, - }, - { - "Sigmoid", - newSigmoid(), - nil, - }, - { - "Sin", - newSin(), - nil, - }, - { - "Sinh", - newSinh(), - nil, - }, - { - "Slice", - newSlice(), - nil, - }, - { - "Softmax", - newSoftmax(), - nil, - }, - { - "Squeeze", - newSqueeze(), - nil, - }, - { - "Sub", - newSub(), - nil, - }, - { - "Tan", - newTan(), - nil, - }, - { - "Tanh", - newTanh(), - nil, - }, - { - "Transpose", - newTranspose(), - nil, - }, - { - "Unsqueeze", - newUnsqueeze(), - nil, - }, - { - "Xor", - newXor(), - nil, - }, - { - "NotYetImplemented", - nil, - ops.ErrUnknownOperatorType("NotYetImplemented"), - }, - } - - for _, test := range tests { - op, err := GetOperator(test.opType) - - assert.Equal(t, test.expected, op) - assert.Equal(t, test.err, err) - } -} diff --git a/ops/opset13/or.go b/ops/opset13/or.go deleted file mode 100644 index f660891..0000000 --- a/ops/opset13/or.go +++ /dev/null @@ -1,61 +0,0 @@ -package opset13 - -import ( - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -var ( - MinOrInputs = 2 - MaxOrInputs = 2 -) - -// Or represents the ONNX or operator. -type Or struct{} - -// newOr creates a new or operator. -func newOr() ops.Operator { - return &Or{} -} - -// Init initializes the or operator. -func (o *Or) Init(*onnx.NodeProto) error { - return nil -} - -// Apply applies the or operator. 
-func (o *Or) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ApplyBinaryOperation( - inputs[0], - inputs[1], - ops.Or, - ops.MultidirectionalBroadcasting, - ) -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (o *Or) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(o, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (o *Or) GetMinInputs() int { - return MinOrInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (o *Or) GetMaxInputs() int { - return MaxOrInputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (o *Or) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{{tensor.Bool}, {tensor.Bool}} -} - -// String implements the stringer interface, and can be used to format errors or messages. 
-func (o *Or) String() string { - return "or operator" -} diff --git a/ops/opset13/reduce_max_test.go b/ops/opset13/reduce_max_test.go deleted file mode 100644 index 508939c..0000000 --- a/ops/opset13/reduce_max_test.go +++ /dev/null @@ -1,220 +0,0 @@ -package opset13 - -import ( - "testing" - - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "github.com/stretchr/testify/assert" - "gorgonia.org/tensor" -) - -func TestReduceMaxInit(t *testing.T) { - r := &ReduceMax{} - err := r.Init(&onnx.NodeProto{ - Attribute: []*onnx.AttributeProto{ - {Name: "axes", Ints: []int64{1, 3}}, - {Name: "keepdims", I: 0}, - }, - }) - - assert.Nil(t, err) - assert.Equal(t, []int{1, 3}, r.axes) - assert.Equal(t, false, r.keepDims) -} - -func TestReduceMax(t *testing.T) { - tests := []struct { - reduceMax *ReduceMax - backing []float32 - shape []int - expectedBacking []float32 - expectedShape tensor.Shape - }{ - { - &ReduceMax{axes: []int{0}, keepDims: false}, - []float32{0, 1, 2, 3}, - []int{2, 2}, - []float32{2, 3}, - []int{2}, - }, - { - &ReduceMax{axes: []int{0}, keepDims: true}, - []float32{0, 1, 2, 3}, - []int{2, 2}, - []float32{2, 3}, - []int{1, 2}, - }, - { - &ReduceMax{axes: []int{1}, keepDims: false}, - []float32{0, 1, 2, 3}, - []int{2, 2}, - []float32{1, 3}, - []int{2}, - }, - { - &ReduceMax{axes: []int{1}, keepDims: true}, - []float32{0, 1, 2, 3}, - []int{2, 2}, - []float32{1, 3}, - []int{2, 1}, - }, - { - &ReduceMax{axes: []int{0}, keepDims: false}, - []float32{0, 1, 2, 3, 4, 5}, - []int{2, 3}, - []float32{3, 4, 5}, - []int{3}, - }, - { - &ReduceMax{axes: []int{0}, keepDims: true}, - []float32{0, 1, 2, 3, 4, 5}, - []int{2, 3}, - []float32{3, 4, 5}, - []int{1, 3}, - }, - { - &ReduceMax{axes: []int{1}, keepDims: false}, - []float32{0, 1, 2, 3, 4, 5}, - []int{2, 3}, - []float32{2, 5}, - []int{2}, - }, - { - &ReduceMax{axes: []int{1}, keepDims: true}, - []float32{0, 1, 2, 3, 4, 5}, - []int{2, 3}, - 
[]float32{2, 5}, - []int{2, 1}, - }, - { - &ReduceMax{axes: []int{1}, keepDims: false}, - []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, - []int{2, 2, 3}, - []float32{3, 4, 5, 9, 10, 11}, - []int{2, 3}, - }, - { - &ReduceMax{axes: []int{1}, keepDims: true}, - []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, - []int{2, 2, 3}, - []float32{3, 4, 5, 9, 10, 11}, - []int{2, 1, 3}, - }, - { - &ReduceMax{axes: []int{0, 1}, keepDims: false}, - []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, - []int{2, 2, 3}, - []float32{9, 10, 11}, - []int{3}, - }, - { - &ReduceMax{axes: []int{0, 1}, keepDims: true}, - []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, - []int{2, 2, 3}, - []float32{9, 10, 11}, - []int{1, 1, 3}, - }, - { - &ReduceMax{axes: []int{1, 2}, keepDims: false}, - []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, - []int{2, 2, 3}, - []float32{5, 11}, - []int{2}, - }, - { - &ReduceMax{axes: []int{1, 2}, keepDims: true}, - []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, - []int{2, 2, 3}, - []float32{5, 11}, - []int{2, 1, 1}, - }, - { - &ReduceMax{axes: []int{-1}, keepDims: true}, - []float32{0, 1, 2, 3}, - []int{2, 2}, - []float32{1, 3}, - []int{2, 1}, - }, - } - - for _, test := range tests { - inputs := []tensor.Tensor{ - ops.TensorWithBackingFixture(test.backing, test.shape...), - } - - res, err := test.reduceMax.Apply(inputs) - assert.Nil(t, err) - - assert.Equal(t, test.expectedShape, res[0].Shape()) - assert.Equal(t, test.expectedBacking, res[0].Data()) - } -} - -func TestInputValidationReduceMax(t *testing.T) { - tests := []struct { - inputs []tensor.Tensor - err error - }{ - { - []tensor.Tensor{ - ops.TensorWithBackingFixture([]uint32{1, 2}, 2), - }, - nil, - }, - { - []tensor.Tensor{ - ops.TensorWithBackingFixture([]uint64{1, 2}, 2), - }, - nil, - }, - { - []tensor.Tensor{ - ops.TensorWithBackingFixture([]int32{1, 2}, 2), - }, - nil, - }, - { - []tensor.Tensor{ - ops.TensorWithBackingFixture([]int64{1, 2}, 2), - }, - nil, - }, - { - []tensor.Tensor{ - 
ops.TensorWithBackingFixture([]float32{1, 2}, 2), - }, - nil, - }, - { - []tensor.Tensor{ - ops.TensorWithBackingFixture([]float64{1, 2}, 2), - }, - nil, - }, - { - []tensor.Tensor{ - ops.TensorWithBackingFixture([]int{1, 2}, 2), - ops.TensorWithBackingFixture([]int{3, 4}, 2), - }, - ops.ErrInvalidInputCount(2, &ReduceMax{}), - }, - { - []tensor.Tensor{ - ops.TensorWithBackingFixture([]int{1, 2}, 2), - }, - ops.ErrInvalidInputType(0, "int", &ReduceMax{}), - }, - } - - for _, test := range tests { - reduceMax := &ReduceMax{} - validated, err := reduceMax.ValidateInputs(test.inputs) - - assert.Equal(t, test.err, err) - - if test.err == nil { - assert.Equal(t, test.inputs, validated) - } - } -} diff --git a/ops/opset13/reduce_min_test.go b/ops/opset13/reduce_min_test.go deleted file mode 100644 index 572e36d..0000000 --- a/ops/opset13/reduce_min_test.go +++ /dev/null @@ -1,220 +0,0 @@ -package opset13 - -import ( - "testing" - - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "github.com/stretchr/testify/assert" - "gorgonia.org/tensor" -) - -func TestReduceMinInit(t *testing.T) { - r := &ReduceMin{} - err := r.Init(&onnx.NodeProto{ - Attribute: []*onnx.AttributeProto{ - {Name: "axes", Ints: []int64{1, 3}}, - {Name: "keepdims", I: 0}, - }, - }) - - assert.Nil(t, err) - assert.Equal(t, []int{1, 3}, r.axes) - assert.Equal(t, false, r.keepDims) -} - -func TestReduceMin(t *testing.T) { - tests := []struct { - reduceMin *ReduceMin - backing []float32 - shape []int - expectedBacking []float32 - expectedShape tensor.Shape - }{ - { - &ReduceMin{axes: []int{0}, keepDims: false}, - []float32{0, 1, 2, 3}, - []int{2, 2}, - []float32{0, 1}, - []int{2}, - }, - { - &ReduceMin{axes: []int{0}, keepDims: true}, - []float32{0, 1, 2, 3}, - []int{2, 2}, - []float32{0, 1}, - []int{1, 2}, - }, - { - &ReduceMin{axes: []int{1}, keepDims: false}, - []float32{0, 1, 2, 3}, - []int{2, 2}, - []float32{0, 2}, - []int{2}, - }, - 
{ - &ReduceMin{axes: []int{1}, keepDims: true}, - []float32{0, 1, 2, 3}, - []int{2, 2}, - []float32{0, 2}, - []int{2, 1}, - }, - { - &ReduceMin{axes: []int{0}, keepDims: false}, - []float32{0, 1, 2, 3, 4, 5}, - []int{2, 3}, - []float32{0, 1, 2}, - []int{3}, - }, - { - &ReduceMin{axes: []int{0}, keepDims: true}, - []float32{0, 1, 2, 3, 4, 5}, - []int{2, 3}, - []float32{0, 1, 2}, - []int{1, 3}, - }, - { - &ReduceMin{axes: []int{1}, keepDims: false}, - []float32{0, 1, 2, 3, 4, 5}, - []int{2, 3}, - []float32{0, 3}, - []int{2}, - }, - { - &ReduceMin{axes: []int{1}, keepDims: true}, - []float32{0, 1, 2, 3, 4, 5}, - []int{2, 3}, - []float32{0, 3}, - []int{2, 1}, - }, - { - &ReduceMin{axes: []int{1}, keepDims: false}, - []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, - []int{2, 2, 3}, - []float32{0, 1, 2, 6, 7, 8}, - []int{2, 3}, - }, - { - &ReduceMin{axes: []int{1}, keepDims: true}, - []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, - []int{2, 2, 3}, - []float32{0, 1, 2, 6, 7, 8}, - []int{2, 1, 3}, - }, - { - &ReduceMin{axes: []int{0, 1}, keepDims: false}, - []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, - []int{2, 2, 3}, - []float32{0, 1, 2}, - []int{3}, - }, - { - &ReduceMin{axes: []int{0, 1}, keepDims: true}, - []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, - []int{2, 2, 3}, - []float32{0, 1, 2}, - []int{1, 1, 3}, - }, - { - &ReduceMin{axes: []int{1, 2}, keepDims: false}, - []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, - []int{2, 2, 3}, - []float32{0, 6}, - []int{2}, - }, - { - &ReduceMin{axes: []int{1, 2}, keepDims: true}, - []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, - []int{2, 2, 3}, - []float32{0, 6}, - []int{2, 1, 1}, - }, - { - &ReduceMin{axes: []int{-1}, keepDims: true}, - []float32{0, 1, 2, 3}, - []int{2, 2}, - []float32{0, 2}, - []int{2, 1}, - }, - } - - for _, test := range tests { - inputs := []tensor.Tensor{ - ops.TensorWithBackingFixture(test.backing, test.shape...), - } - - res, err := test.reduceMin.Apply(inputs) - assert.Nil(t, err) - - 
assert.Equal(t, test.expectedShape, res[0].Shape()) - assert.Equal(t, test.expectedBacking, res[0].Data()) - } -} - -func TestInputValidationReduceMin(t *testing.T) { - tests := []struct { - inputs []tensor.Tensor - err error - }{ - { - []tensor.Tensor{ - ops.TensorWithBackingFixture([]uint32{1, 2}, 2), - }, - nil, - }, - { - []tensor.Tensor{ - ops.TensorWithBackingFixture([]uint64{1, 2}, 2), - }, - nil, - }, - { - []tensor.Tensor{ - ops.TensorWithBackingFixture([]int32{1, 2}, 2), - }, - nil, - }, - { - []tensor.Tensor{ - ops.TensorWithBackingFixture([]int64{1, 2}, 2), - }, - nil, - }, - { - []tensor.Tensor{ - ops.TensorWithBackingFixture([]float32{1, 2}, 2), - }, - nil, - }, - { - []tensor.Tensor{ - ops.TensorWithBackingFixture([]float64{1, 2}, 2), - }, - nil, - }, - { - []tensor.Tensor{ - ops.TensorWithBackingFixture([]int{1, 2}, 2), - ops.TensorWithBackingFixture([]int{3, 4}, 2), - }, - ops.ErrInvalidInputCount(2, &ReduceMin{}), - }, - { - []tensor.Tensor{ - ops.TensorWithBackingFixture([]int{1, 2}, 2), - }, - ops.ErrInvalidInputType(0, "int", &ReduceMin{}), - }, - } - - for _, test := range tests { - reduceMin := &ReduceMin{} - validated, err := reduceMin.ValidateInputs(test.inputs) - - assert.Equal(t, test.err, err) - - if test.err == nil { - assert.Equal(t, test.inputs, validated) - } - } -} diff --git a/ops/opset13/relu.go b/ops/opset13/relu.go deleted file mode 100644 index 370940a..0000000 --- a/ops/opset13/relu.go +++ /dev/null @@ -1,58 +0,0 @@ -package opset13 - -import ( - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -// Relu represents the ONNX relu operator. -type Relu struct{} - -// newRelu creates a new relu operator. -func newRelu() ops.Operator { - return &Relu{} -} - -// Init initializes the relu operator. -func (r *Relu) Init(*onnx.NodeProto) error { - return nil -} - -// Apply applies the relu operator. 
-func (r *Relu) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - out, err := ops.ReLU(inputs[0]) - if err != nil { - return nil, err - } - - return []tensor.Tensor{out}, nil -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (r *Relu) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(r, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (r *Relu) GetMinInputs() int { - return 1 -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (r *Relu) GetMaxInputs() int { - return 1 -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (r *Relu) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ - {tensor.Float32, tensor.Float64}, - } -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (r *Relu) String() string { - return "relu operator" -} diff --git a/ops/opset13/shape.go b/ops/opset13/shape.go deleted file mode 100644 index bb99709..0000000 --- a/ops/opset13/shape.go +++ /dev/null @@ -1,66 +0,0 @@ -package opset13 - -import ( - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -const ( - MinShapeInputs = 1 - MaxShapeInputs = 1 -) - -// Shape represents the ONNX shape operator. -type Shape struct{} - -// newShape creates a new shape operator. -func newShape() ops.Operator { - return &Shape{} -} - -// Init initializes the shape operator. -func (s *Shape) Init(*onnx.NodeProto) error { - return nil -} - -// Apply the shape operator to the graph. It creates a node that holds the shape of the -// input node as 1D int64 tensor. 
-func (s *Shape) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - nodeShape := inputs[0].Shape() - shape := make([]int64, len(nodeShape)) - - for i, dimSize := range nodeShape { - shape[i] = int64(dimSize) - } - - out := tensor.New(tensor.WithShape(len(nodeShape)), tensor.WithBacking(shape)) - - return []tensor.Tensor{out}, nil -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (s *Shape) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(s, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (s *Shape) GetMinInputs() int { - return MinShapeInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (s *Shape) GetMaxInputs() int { - return MaxShapeInputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (s *Shape) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ops.AllTypes} -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (s *Shape) String() string { - return "shape operator" -} diff --git a/ops/opset13/sigmoid.go b/ops/opset13/sigmoid.go deleted file mode 100644 index b8bc077..0000000 --- a/ops/opset13/sigmoid.go +++ /dev/null @@ -1,55 +0,0 @@ -package opset13 - -import ( - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -// Sigmoid represents the ONNX sigmoid operator. -type Sigmoid struct{} - -// newSigmoid returns a new sigmoid operator. -func newSigmoid() ops.Operator { - return &Sigmoid{} -} - -// Init initializes the sigmoid operator. -func (s *Sigmoid) Init(*onnx.NodeProto) error { - return nil -} - -// Apply the sigmoid operator to the input node. 
-func (s *Sigmoid) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - out, err := ops.Sigmoid(inputs[0]) - - return []tensor.Tensor{out}, err -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (s *Sigmoid) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(s, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (s *Sigmoid) GetMinInputs() int { - return 1 -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (s *Sigmoid) GetMaxInputs() int { - return 1 -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (s *Sigmoid) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ - {tensor.Float32, tensor.Float64}, - } -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (s *Sigmoid) String() string { - return "sigmoid operator" -} diff --git a/ops/opset13/sin.go b/ops/opset13/sin.go deleted file mode 100644 index ff61a71..0000000 --- a/ops/opset13/sin.go +++ /dev/null @@ -1,75 +0,0 @@ -package opset13 - -import ( - "math" - - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -// Sin represents the ONNX sin operator. -type Sin struct{} - -// newSin creates a new sin operator. -func newSin() ops.Operator { - return &Sin{} -} - -// Init initializes the sin operator. -func (s *Sin) Init(*onnx.NodeProto) error { - return nil -} - -// Apply applies the sin operator. 
-func (s *Sin) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - var ( - out tensor.Tensor - err error - ) - - switch inputs[0].Dtype() { - case tensor.Float32: - out, err = inputs[0].Apply(sin[float32]) - case tensor.Float64: - out, err = inputs[0].Apply(sin[float64]) - default: - return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), s) - } - - if err != nil { - return nil, err - } - - return []tensor.Tensor{out}, nil -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (s *Sin) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(s, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (s *Sin) GetMinInputs() int { - return 1 -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (s *Sin) GetMaxInputs() int { - return 1 -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (s *Sin) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (s *Sin) String() string { - return "sin operator" -} - -func sin[T ops.FloatType](x T) T { - return T(math.Sin(float64(x))) -} diff --git a/ops/opset13/sinh.go b/ops/opset13/sinh.go deleted file mode 100644 index 19d81e7..0000000 --- a/ops/opset13/sinh.go +++ /dev/null @@ -1,75 +0,0 @@ -package opset13 - -import ( - "math" - - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -// Sinh represents the ONNX sinh operator. -type Sinh struct{} - -// newSin creates a new sinh operator. -func newSinh() ops.Operator { - return &Sinh{} -} - -// Init initializes the sinh operator. 
-func (s *Sinh) Init(*onnx.NodeProto) error { - return nil -} - -// Apply applies the sinh operator. -func (s *Sinh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - var ( - out tensor.Tensor - err error - ) - - switch inputs[0].Dtype() { - case tensor.Float32: - out, err = inputs[0].Apply(sinh[float32]) - case tensor.Float64: - out, err = inputs[0].Apply(sinh[float64]) - default: - return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), s) - } - - if err != nil { - return nil, err - } - - return []tensor.Tensor{out}, nil -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (s *Sinh) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(s, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (s *Sinh) GetMinInputs() int { - return 1 -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (s *Sinh) GetMaxInputs() int { - return 1 -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (s *Sinh) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (s *Sinh) String() string { - return "sinh operator" -} - -func sinh[T ops.FloatType](x T) T { - return T(math.Sinh(float64(x))) -} diff --git a/ops/opset13/sub.go b/ops/opset13/sub.go deleted file mode 100644 index 9c59508..0000000 --- a/ops/opset13/sub.go +++ /dev/null @@ -1,64 +0,0 @@ -package opset13 - -import ( - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -const ( - MinSubInputs = 2 - MaxSubInputs = 2 -) - -// Sub represents the ONNX sub operator. 
-type Sub struct{} - -// newSub creates a new sub operator. -func newSub() ops.Operator { - return &Sub{} -} - -// Init initializes the sub operator. -func (s *Sub) Init(*onnx.NodeProto) error { - return nil -} - -// Apply applies the sub operator. -func (s *Sub) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ApplyBinaryOperation( - inputs[0], - inputs[1], - ops.Sub, - ops.MultidirectionalBroadcasting, - ) -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (s *Sub) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(s, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (s *Sub) GetMinInputs() int { - return MinSubInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (s *Sub) GetMaxInputs() int { - return MaxSubInputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (s *Sub) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ - {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, - {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, - } -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (s *Sub) String() string { - return "sub operator" -} diff --git a/ops/opset13/tan.go b/ops/opset13/tan.go deleted file mode 100644 index a7b4a3b..0000000 --- a/ops/opset13/tan.go +++ /dev/null @@ -1,75 +0,0 @@ -package opset13 - -import ( - "math" - - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -// Tan represents the ONNX tan operator. -type Tan struct{} - -// newTan creates a new tan operator. 
-func newTan() ops.Operator { - return &Tan{} -} - -// Init initializes the tan operator. -func (t *Tan) Init(*onnx.NodeProto) error { - return nil -} - -// Apply applies the tan operator. -func (t *Tan) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - var ( - out tensor.Tensor - err error - ) - - switch inputs[0].Dtype() { - case tensor.Float32: - out, err = inputs[0].Apply(tan[float32]) - case tensor.Float64: - out, err = inputs[0].Apply(tan[float64]) - default: - return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), t) - } - - if err != nil { - return nil, err - } - - return []tensor.Tensor{out}, nil -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (t *Tan) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(t, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (t *Tan) GetMinInputs() int { - return 1 -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (t *Tan) GetMaxInputs() int { - return 1 -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (t *Tan) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (t *Tan) String() string { - return "tan operator" -} - -func tan[T ops.FloatType](x T) T { - return T(math.Tan(float64(x))) -} diff --git a/ops/opset13/tanh.go b/ops/opset13/tanh.go deleted file mode 100644 index b435fb9..0000000 --- a/ops/opset13/tanh.go +++ /dev/null @@ -1,54 +0,0 @@ -package opset13 - -import ( - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -// Tanh represents the tanh operator. 
-type Tanh struct{} - -// newTanh returns a new tanh operator. -func newTanh() ops.Operator { - return &Tanh{} -} - -// Init initializes the sigmoid operator. -func (t *Tanh) Init(*onnx.NodeProto) error { - return nil -} - -// Apply the sigmoid operator to the input node. -func (t *Tanh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - out, err := ops.Tanh(inputs[0]) - - return []tensor.Tensor{out}, err -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (t *Tanh) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(t, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (t *Tanh) GetMinInputs() int { - return 1 -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (t *Tanh) GetMaxInputs() int { - return 1 -} - -// GetInputTypeConstraints returns a list with for every input tensor a list of allowed types. -func (t *Tanh) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ - {tensor.Float32, tensor.Float64}, - } -} - -// String returns a small name of the operator that can be used in formatting errors or logs. -func (t *Tanh) String() string { - return "tanh operator" -} diff --git a/ops/opset13/transpose.go b/ops/opset13/transpose.go deleted file mode 100644 index c89aa67..0000000 --- a/ops/opset13/transpose.go +++ /dev/null @@ -1,80 +0,0 @@ -package opset13 - -import ( - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -const ( - MinTransposeInputs = 1 - MaxTransposeInputs = 1 -) - -// Transpose represents the ONNX transpose operator. -type Transpose struct { - perm []int -} - -// newTranspose creates a new transpose operator. -func newTranspose() ops.Operator { - return &Transpose{} -} - -// Init initializes the transpose operator. 
-func (t *Transpose) Init(n *onnx.NodeProto) error { - attributes := n.GetAttribute() - - if len(attributes) != 1 { - return ops.ErrInvalidAttributeCount(1, len(attributes), t) - } - - attr := attributes[0] - - if attr.GetName() != "perm" { - return ops.ErrInvalidAttribute(attr.GetName(), t) - } - - attrPerm := attr.GetInts() - for _, val := range attrPerm { - t.perm = append(t.perm, int(val)) - } - - return nil -} - -// Apply applies the transpose operator. -func (t *Transpose) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - out, err := tensor.Transpose(inputs[0], t.perm...) - if err != nil { - return nil, err - } - - return []tensor.Tensor{out}, nil -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (t *Transpose) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(t, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (t *Transpose) GetMinInputs() int { - return MinTransposeInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (t *Transpose) GetMaxInputs() int { - return MaxTransposeInputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (t *Transpose) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ops.AllTypes} -} - -// String implements the stringer interface, and can be used to format errors or messages. 
-func (t *Transpose) String() string { - return "transpose operator" -} diff --git a/ops/opset13/xor.go b/ops/opset13/xor.go deleted file mode 100644 index f668a69..0000000 --- a/ops/opset13/xor.go +++ /dev/null @@ -1,61 +0,0 @@ -package opset13 - -import ( - "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops" - "gorgonia.org/tensor" -) - -var ( - MinXorInputs = 2 - MaxXorInputs = 2 -) - -// Xor represents the ONNX xor operator. -type Xor struct{} - -// newXor creates a new xor operator. -func newXor() ops.Operator { - return &Xor{} -} - -// Init initializes the xor operator. -func (x *Xor) Init(*onnx.NodeProto) error { - return nil -} - -// Apply applies the xor operator. -func (x *Xor) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ApplyBinaryOperation( - inputs[0], - inputs[1], - ops.Xor, - ops.MultidirectionalBroadcasting, - ) -} - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (x *Xor) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(x, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (x *Xor) GetMinInputs() int { - return MinXorInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (x *Xor) GetMaxInputs() int { - return MaxXorInputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (x *Xor) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{{tensor.Bool}, {tensor.Bool}} -} - -// String implements the stringer interface, and can be used to format errors or messages. 
-func (x *Xor) String() string { - return "xor operator" -} diff --git a/ops/or/or.go b/ops/or/or.go new file mode 100644 index 0000000..7e48791 --- /dev/null +++ b/ops/or/or.go @@ -0,0 +1,42 @@ +package or + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var orTypeConstraints = [][]tensor.Dtype{{tensor.Bool}, {tensor.Bool}} + +// Or represents the ONNX or operator. +type Or struct { + ops.BaseOperator +} + +// newOr creates a new or operator. +func newOr(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Or{ + BaseOperator: ops.NewBaseOperator( + version, + 2, + 2, + typeConstraints, + "or", + ), + } +} + +// Init initializes the or operator. +func (o *Or) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the or operator. +func (o *Or) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ApplyBinaryOperation( + inputs[0], + inputs[1], + ops.Or, + ops.MultidirectionalBroadcasting, + ) +} diff --git a/ops/opset13/or_test.go b/ops/or/or_test.go similarity index 80% rename from ops/opset13/or_test.go rename to ops/or/or_test.go index 1c370a2..6604298 100644 --- a/ops/opset13/or_test.go +++ b/ops/or/or_test.go @@ -1,4 +1,4 @@ -package opset13 +package or import ( "testing" @@ -19,31 +19,31 @@ func TestOrInit(t *testing.T) { func TestOr(t *testing.T) { tests := []struct { - or *Or + version int64 backings [][]bool shapes [][]int expected []bool }{ { - &Or{}, + 7, [][]bool{{true, false, true, false}, {true, true, true, false}}, [][]int{{2, 2}, {2, 2}}, []bool{true, true, true, false}, }, { - &Or{}, + 7, [][]bool{{true, false, true, false}, {true, false}}, [][]int{{2, 2}, {1, 2}}, []bool{true, false, true, false}, }, { - &Or{}, + 7, [][]bool{{true, false, true, false}, {true, false}}, [][]int{{2, 2}, {2, 1}}, []bool{true, true, true, false}, }, { - &Or{}, + 7, [][]bool{{true, false, true, false, true, false}, 
{false, false}}, [][]int{{3, 2}, {1, 2}}, []bool{true, false, true, false, true, false}, @@ -56,7 +56,8 @@ func TestOr(t *testing.T) { ops.TensorWithBackingFixture(test.backings[1], test.shapes[1]...), } - res, err := test.or.Apply(inputs) + or := orVersions[test.version]() + res, err := or.Apply(inputs) assert.Nil(t, err) assert.Nil(t, err) @@ -66,10 +67,12 @@ func TestOr(t *testing.T) { func TestInputValidationOr(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]bool{false, false}, 2), ops.TensorWithBackingFixture([]bool{false, false}, 2), @@ -77,22 +80,24 @@ func TestInputValidationOr(t *testing.T) { nil, }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]bool{false, false}, 2), }, - ops.ErrInvalidInputCount(1, &Or{}), + ops.ErrInvalidInputCount(1, or7BaseOpFixture()), }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]bool{false, false}, 2), ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputType(1, "int", &Or{}), + ops.ErrInvalidInputType(1, "int", or7BaseOpFixture()), }, } for _, test := range tests { - or := &Or{} + or := orVersions[test.version]() validated, err := or.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -102,3 +107,13 @@ func TestInputValidationOr(t *testing.T) { } } } + +func or7BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator( + 7, + 2, + 2, + orTypeConstraints, + "or", + ) +} diff --git a/ops/or/versions.go b/ops/or/versions.go new file mode 100644 index 0000000..f1c38d0 --- /dev/null +++ b/ops/or/versions.go @@ -0,0 +1,11 @@ +package or + +import "github.com/advancedclimatesystems/gonnx/ops" + +var orVersions = ops.OperatorVersions{ + 7: ops.NewOperatorConstructor(newOr, 7, orTypeConstraints), +} + +func GetOrVersions() ops.OperatorVersions { + return orVersions +} diff --git a/ops/opset13/prelu.go b/ops/prelu/prelu.go similarity index 55% 
rename from ops/opset13/prelu.go rename to ops/prelu/prelu.go index bfdc5d2..b107471 100644 --- a/ops/opset13/prelu.go +++ b/ops/prelu/prelu.go @@ -1,4 +1,4 @@ -package opset13 +package prelu import ( "github.com/advancedclimatesystems/gonnx/onnx" @@ -6,17 +6,32 @@ import ( "gorgonia.org/tensor" ) -const ( - PReluMinInputs = 2 - PReluMaxInputs = 2 -) +var prelu7TypeConstraints = [][]tensor.Dtype{ + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, +} + +var preluTypeConstraints = [][]tensor.Dtype{ + {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, + {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, +} // PRelu represents the ONNX prelu operator. -type PRelu struct{} +type PRelu struct { + ops.BaseOperator +} // newPRelu creates a new prelu operator. -func newPRelu() ops.Operator { - return &PRelu{} +func newPRelu(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &PRelu{ + BaseOperator: ops.NewBaseOperator( + version, + 2, + 2, + typeConstraints, + "prelu", + ), + } } // Init initializes the prelu operator. @@ -51,7 +66,7 @@ func (op *PRelu) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { case tensor.Int64: err = calcPRelu[int64](y.Data(), x.Data(), slope.Data()) default: - return nil, ops.ErrInvalidInputType(0, x.Dtype().String(), op) + return nil, ops.ErrInvalidInputType(0, x.Dtype().String(), op.BaseOperator) } if err != nil { @@ -61,45 +76,6 @@ func (op *PRelu) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return []tensor.Tensor{y}, nil } -// ValidateInputs validates the inputs that will be given to Apply for this operator. 
-func (op *PRelu) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - inputs, err := ops.ValidateInputs(op, inputs) - if err != nil { - return nil, err - } - - x, slope := inputs[0], inputs[1] - if x.Dtype() != slope.Dtype() { - return nil, ops.ErrInvalidTensor("DType of 'slope' does not match DType of 'x'", op) - } - - return inputs, nil -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (op *PRelu) GetMinInputs() int { - return PReluMinInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (op *PRelu) GetMaxInputs() int { - return PReluMaxInputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (op *PRelu) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ - {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, - {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, - } -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (op *PRelu) String() string { - return "prelu operator" -} - func calcPRelu[T float32 | float64 | uint32 | uint64 | int32 | int64](result any, input any, slope any) error { var convertedResult []T diff --git a/ops/opset13/prelu_test.go b/ops/prelu/prelu_test.go similarity index 57% rename from ops/opset13/prelu_test.go rename to ops/prelu/prelu_test.go index 763cbfb..76b2c61 100644 --- a/ops/opset13/prelu_test.go +++ b/ops/prelu/prelu_test.go @@ -1,4 +1,4 @@ -package opset13 +package prelu import ( "testing" @@ -9,31 +9,38 @@ import ( ) func TestPReluInit(t *testing.T) { - p := &PRelu{} + tests := []struct { + version int64 + err error + }{ + {7, nil}, + {9, nil}, + } - // since the prelu does not have any attributes we pass in nil. This should not - // fail initializing the prelu. 
- err := p.Init(nil) - assert.Nil(t, err) + for _, test := range tests { + r := preluVersions[test.version]() + err := r.Init(nil) + assert.Equal(t, test.err, err) + } } func TestPRelu(t *testing.T) { tests := []struct { - prelu *PRelu + version int64 backing []float32 slope []float32 shape []int expected []float32 }{ { - &PRelu{}, + 7, []float32{-4, -4, -4, -3, -2, -1}, []float32{2, 2, 4, 4, 0, 0}, []int{3, 2}, []float32{-8, -8, -16, -12, 0, 0}, }, { - &PRelu{}, + 9, []float32{-4, -4, -4, 3, 2, 1}, []float32{2, 2, 4, 4, 0, 0}, []int{3, 2}, @@ -46,7 +53,8 @@ func TestPRelu(t *testing.T) { ops.TensorWithBackingFixture(test.backing, test.shape...), ops.TensorWithBackingFixture(test.slope, test.shape...), } - res, err := test.prelu.Apply(inputs) + prelu := preluVersions[test.version]() + res, err := prelu.Apply(inputs) assert.Nil(t, err) assert.Equal(t, test.expected, res[0].Data()) } @@ -54,10 +62,12 @@ func TestPRelu(t *testing.T) { func TestInputValidationPRelu(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), @@ -65,20 +75,46 @@ func TestInputValidationPRelu(t *testing.T) { nil, }, { + 9, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + }, + nil, + }, + { + 9, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + }, + { + 7, []tensor.Tensor{}, - ops.ErrInvalidInputCount(0, &PRelu{}), + ops.ErrInvalidInputCount(0, prelu7BaseOpFixture()), }, { + 7, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int32", prelu7BaseOpFixture()), + }, + { + 9, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), 
ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, - ops.ErrInvalidInputType(0, "int", &PRelu{}), + ops.ErrInvalidInputType(0, "int", prelu9BaseOpFixture()), }, } for _, test := range tests { - prelu := &PRelu{} + prelu := preluVersions[test.version]() validated, err := prelu.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -106,3 +142,11 @@ func BenchmarkPRelu_Apply(b *testing.B) { _ = y } } + +func prelu7BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(7, 2, 2, prelu7TypeConstraints, "prelu") +} + +func prelu9BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(9, 2, 2, preluTypeConstraints, "prelu") +} diff --git a/ops/prelu/versions.go b/ops/prelu/versions.go new file mode 100644 index 0000000..4fe2bce --- /dev/null +++ b/ops/prelu/versions.go @@ -0,0 +1,15 @@ +package prelu + +import ( + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var preluVersions = ops.OperatorVersions{ + 7: ops.NewOperatorConstructor(newPRelu, 7, [][]tensor.Dtype{{tensor.Float32, tensor.Float64}, {tensor.Float32, tensor.Float64}}), + 9: ops.NewOperatorConstructor(newPRelu, 9, preluTypeConstraints), +} + +func GetPReluVersions() ops.OperatorVersions { + return preluVersions +} diff --git a/ops/reducemax/constants.go b/ops/reducemax/constants.go new file mode 100644 index 0000000..006e821 --- /dev/null +++ b/ops/reducemax/constants.go @@ -0,0 +1,6 @@ +package reducemax + +const ( + axes = "axes" + keepDims = "keepdims" +) diff --git a/ops/opset13/reduce_max.go b/ops/reducemax/reduce_max.go similarity index 58% rename from ops/opset13/reduce_max.go rename to ops/reducemax/reduce_max.go index 8276e4a..fd9cfe5 100644 --- a/ops/opset13/reduce_max.go +++ b/ops/reducemax/reduce_max.go @@ -1,4 +1,4 @@ -package opset13 +package reducemax import ( "github.com/advancedclimatesystems/gonnx/onnx" @@ -6,6 +6,14 @@ import ( "gorgonia.org/tensor" ) +var reduceMaxTypeConstraints = [][]tensor.Dtype{ + {tensor.Uint8, 
tensor.Int8, tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, +} + +var reduceMax11TypeConstraints = [][]tensor.Dtype{ + {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, +} + const ( MinReduceMaxAttributes = 1 MaxReduceMaxAttributes = 2 @@ -13,13 +21,22 @@ const ( // ReduceMax represents the ONNX reduceMax operator. type ReduceMax struct { + ops.BaseOperator + axes []int keepDims bool } // newReduceMax creates a new reduceMax operator. -func newReduceMax() ops.Operator { +func newReduceMax(version int, typeConstraints [][]tensor.Dtype) ops.Operator { return &ReduceMax{ + BaseOperator: ops.NewBaseOperator( + version, + 1, + 1, + typeConstraints, + "reducemax", + ), axes: []int{}, keepDims: true, } @@ -34,14 +51,14 @@ func (r *ReduceMax) Init(n *onnx.NodeProto) error { for _, attr := range attributes { switch attr.GetName() { - case "axes": - axes, err := ops.AnyToIntSlice(attr.GetInts()) + case axes: + value, err := ops.AnyToIntSlice(attr.GetInts()) if err != nil { return err } - r.axes = axes - case "keepdims": + r.axes = value + case keepDims: r.keepDims = attr.GetI() == 1 default: return ops.ErrInvalidAttribute(attr.GetName(), r) @@ -79,31 +96,3 @@ func (r *ReduceMax) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return []tensor.Tensor{out}, nil } - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (r *ReduceMax) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(r, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (r *ReduceMax) GetMinInputs() int { - return 1 -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (r *ReduceMax) GetMaxInputs() int { - return 1 -} - -// GetInputTypeConstraints returns a list. 
Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (r *ReduceMax) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ - {tensor.Uint8, tensor.Int8, tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, - } -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (r *ReduceMax) String() string { - return "reduceMax operator" -} diff --git a/ops/reducemax/reduce_max_test.go b/ops/reducemax/reduce_max_test.go new file mode 100644 index 0000000..41ec060 --- /dev/null +++ b/ops/reducemax/reduce_max_test.go @@ -0,0 +1,389 @@ +package reducemax + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestReduceMaxInit(t *testing.T) { + tests := []struct { + version int64 + err error + }{ + {1, nil}, + {11, nil}, + {12, nil}, + {13, nil}, + } + + for _, test := range tests { + r, ok := reduceMaxVersions[test.version]().(*ReduceMax) + assert.True(t, ok) + + err := r.Init(&onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{1, 3}}, + {Name: "keepdims", I: 0}, + }, + }) + + assert.Equal(t, test.err, err) + assert.Equal(t, []int{1, 3}, r.axes) + assert.Equal(t, false, r.keepDims) + } +} + +func TestReduceMax(t *testing.T) { + tests := []struct { + version int64 + attrs *onnx.NodeProto + backing []float32 + shape []int + expectedBacking []float32 + expectedShape tensor.Shape + }{ + { + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{0}}, + {Name: "keepdims", I: 0}, + }, + }, + []float32{0, 1, 2, 3}, + []int{2, 2}, + []float32{2, 3}, + []int{2}, + }, + { + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{0}}, + {Name: "keepdims", I: 1}, + }, + }, + 
[]float32{0, 1, 2, 3}, + []int{2, 2}, + []float32{2, 3}, + []int{1, 2}, + }, + { + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{1}}, + {Name: "keepdims", I: 0}, + }, + }, + []float32{0, 1, 2, 3}, + []int{2, 2}, + []float32{1, 3}, + []int{2}, + }, + { + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{1}}, + {Name: "keepdims", I: 1}, + }, + }, + []float32{0, 1, 2, 3}, + []int{2, 2}, + []float32{1, 3}, + []int{2, 1}, + }, + { + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{0}}, + {Name: "keepdims", I: 0}, + }, + }, + []float32{0, 1, 2, 3, 4, 5}, + []int{2, 3}, + []float32{3, 4, 5}, + []int{3}, + }, + { + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{0}}, + {Name: "keepdims", I: 1}, + }, + }, + []float32{0, 1, 2, 3, 4, 5}, + []int{2, 3}, + []float32{3, 4, 5}, + []int{1, 3}, + }, + { + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{1}}, + {Name: "keepdims", I: 0}, + }, + }, + []float32{0, 1, 2, 3, 4, 5}, + []int{2, 3}, + []float32{2, 5}, + []int{2}, + }, + { + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{1}}, + {Name: "keepdims", I: 1}, + }, + }, + []float32{0, 1, 2, 3, 4, 5}, + []int{2, 3}, + []float32{2, 5}, + []int{2, 1}, + }, + { + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{1}}, + {Name: "keepdims", I: 0}, + }, + }, + []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + []int{2, 2, 3}, + []float32{3, 4, 5, 9, 10, 11}, + []int{2, 3}, + }, + { + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{1}}, + {Name: "keepdims", I: 1}, + }, + }, + []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + []int{2, 2, 3}, + []float32{3, 4, 5, 9, 10, 11}, + []int{2, 1, 3}, + }, + { + 13, + &onnx.NodeProto{ + Attribute: 
[]*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{0, 1}}, + {Name: "keepdims", I: 0}, + }, + }, + []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + []int{2, 2, 3}, + []float32{9, 10, 11}, + []int{3}, + }, + { + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{0, 1}}, + {Name: "keepdims", I: 1}, + }, + }, + []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + []int{2, 2, 3}, + []float32{9, 10, 11}, + []int{1, 1, 3}, + }, + { + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{1, 2}}, + {Name: "keepdims", I: 0}, + }, + }, + []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + []int{2, 2, 3}, + []float32{5, 11}, + []int{2}, + }, + { + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{1, 2}}, + {Name: "keepdims", I: 1}, + }, + }, + []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + []int{2, 2, 3}, + []float32{5, 11}, + []int{2, 1, 1}, + }, + { + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{-1}}, + {Name: "keepdims", I: 1}, + }, + }, + []float32{0, 1, 2, 3}, + []int{2, 2}, + []float32{1, 3}, + []int{2, 1}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backing, test.shape...), + } + + reduceMax := reduceMaxVersions[test.version]() + err := reduceMax.Init(test.attrs) + assert.Nil(t, err) + + res, err := reduceMax.Apply(inputs) + assert.Nil(t, err) + + assert.Equal(t, test.expectedShape, res[0].Shape()) + assert.Equal(t, test.expectedBacking, res[0].Data()) + } +} + +func TestInputValidationReduceMax(t *testing.T) { + tests := []struct { + version int64 + inputs []tensor.Tensor + err error + }{ + { + 13, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int8{1, 2}, 2), + }, + nil, + }, + { + 13, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint8{1, 2}, 2), + }, + nil, + }, + { + 13, + []tensor.Tensor{ + 
ops.TensorWithBackingFixture([]uint32{1, 2}, 2), + }, + nil, + }, + { + 13, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint64{1, 2}, 2), + }, + nil, + }, + { + 13, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + }, + nil, + }, + { + 13, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int64{1, 2}, 2), + }, + nil, + }, + { + 13, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + }, + { + 13, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float64{1, 2}, 2), + }, + nil, + }, + { + 13, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + ops.TensorWithBackingFixture([]int{3, 4}, 2), + }, + ops.ErrInvalidInputCount(2, reduceMax13BaseOpFixture()), + }, + { + 1, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int8{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int8", reduceMax1BaseOpFixture()), + }, + { + 11, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint8{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "uint8", reduceMax11BaseOpFixture()), + }, + { + 12, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int", reduceMax12BaseOpFixture()), + }, + { + 13, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int", reduceMax13BaseOpFixture()), + }, + } + + for _, test := range tests { + reduceMax := reduceMaxVersions[test.version]() + validated, err := reduceMax.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + assert.Equal(t, test.inputs, validated) + } + } +} + +func reduceMax1BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(1, 1, 1, reduceMax11TypeConstraints, "reducemax") +} + +func reduceMax11BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(11, 1, 1, reduceMax11TypeConstraints, "reducemax") +} + +func reduceMax12BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(12, 1, 1, 
reduceMaxTypeConstraints, "reducemax") +} + +func reduceMax13BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(13, 1, 1, reduceMaxTypeConstraints, "reducemax") +} diff --git a/ops/reducemax/versions.go b/ops/reducemax/versions.go new file mode 100644 index 0000000..64b9806 --- /dev/null +++ b/ops/reducemax/versions.go @@ -0,0 +1,14 @@ +package reducemax + +import "github.com/advancedclimatesystems/gonnx/ops" + +var reduceMaxVersions = ops.OperatorVersions{ + 1: ops.NewOperatorConstructor(newReduceMax, 1, reduceMax11TypeConstraints), + 11: ops.NewOperatorConstructor(newReduceMax, 11, reduceMax11TypeConstraints), + 12: ops.NewOperatorConstructor(newReduceMax, 12, reduceMaxTypeConstraints), + 13: ops.NewOperatorConstructor(newReduceMax, 13, reduceMaxTypeConstraints), +} + +func GetReduceMaxVersions() ops.OperatorVersions { + return reduceMaxVersions +} diff --git a/ops/reducemin/constants.go b/ops/reducemin/constants.go new file mode 100644 index 0000000..4b152c6 --- /dev/null +++ b/ops/reducemin/constants.go @@ -0,0 +1,6 @@ +package reducemin + +const ( + axes = "axes" + keepDims = "keepdims" +) diff --git a/ops/opset13/reduce_min.go b/ops/reducemin/reduce_min.go similarity index 58% rename from ops/opset13/reduce_min.go rename to ops/reducemin/reduce_min.go index 38e9c49..34791f3 100644 --- a/ops/opset13/reduce_min.go +++ b/ops/reducemin/reduce_min.go @@ -1,4 +1,4 @@ -package opset13 +package reducemin import ( "github.com/advancedclimatesystems/gonnx/onnx" @@ -6,6 +6,14 @@ import ( "gorgonia.org/tensor" ) +var reduceMinTypeConstraints = [][]tensor.Dtype{ + {tensor.Uint8, tensor.Int8, tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, +} + +var reduceMin11TypeConstraints = [][]tensor.Dtype{ + {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, +} + const ( MinReduceMinAttributes = 1 MaxReduceMinAttributes = 2 @@ -13,13 +21,22 @@ const ( // ReduceMin 
represents the ONNX reduceMin operator. type ReduceMin struct { + ops.BaseOperator + axes []int keepDims bool } // newReduceMin creates a new reduceMin operator. -func newReduceMin() ops.Operator { +func newReduceMin(version int, typeConstraints [][]tensor.Dtype) ops.Operator { return &ReduceMin{ + BaseOperator: ops.NewBaseOperator( + version, + 1, + 1, + typeConstraints, + "reducemin", + ), axes: []int{}, keepDims: true, } @@ -34,14 +51,14 @@ func (r *ReduceMin) Init(n *onnx.NodeProto) error { for _, attr := range attributes { switch attr.GetName() { - case "axes": - axes, err := ops.AnyToIntSlice(attr.GetInts()) + case axes: + value, err := ops.AnyToIntSlice(attr.GetInts()) if err != nil { return err } - r.axes = axes - case "keepdims": + r.axes = value + case keepDims: r.keepDims = attr.GetI() == 1 default: return ops.ErrInvalidAttribute(attr.GetName(), r) @@ -79,31 +96,3 @@ func (r *ReduceMin) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return []tensor.Tensor{out}, nil } - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (r *ReduceMin) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(r, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (r *ReduceMin) GetMinInputs() int { - return 1 -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (r *ReduceMin) GetMaxInputs() int { - return 1 -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (r *ReduceMin) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ - {tensor.Uint8, tensor.Int8, tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, - } -} - -// String implements the stringer interface, and can be used to format errors or messages. 
-func (r *ReduceMin) String() string { - return "reduceMin operator" -} diff --git a/ops/reducemin/reduce_min_test.go b/ops/reducemin/reduce_min_test.go new file mode 100644 index 0000000..55926b4 --- /dev/null +++ b/ops/reducemin/reduce_min_test.go @@ -0,0 +1,389 @@ +package reducemin + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestReduceMinInit(t *testing.T) { + tests := []struct { + version int64 + err error + }{ + {1, nil}, + {11, nil}, + {12, nil}, + {13, nil}, + } + + for _, test := range tests { + r, ok := reduceMinVersions[test.version]().(*ReduceMin) + assert.True(t, ok) + + err := r.Init(&onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{1, 3}}, + {Name: "keepdims", I: 0}, + }, + }) + + assert.Equal(t, test.err, err) + assert.Equal(t, []int{1, 3}, r.axes) + assert.Equal(t, false, r.keepDims) + } +} + +func TestReduceMin(t *testing.T) { + tests := []struct { + version int64 + attrs *onnx.NodeProto + backing []float32 + shape []int + expectedBacking []float32 + expectedShape tensor.Shape + }{ + { + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{0}}, + {Name: "keepdims", I: 0}, + }, + }, + []float32{0, 1, 2, 3}, + []int{2, 2}, + []float32{0, 1}, + []int{2}, + }, + { + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{0}}, + {Name: "keepdims", I: 1}, + }, + }, + []float32{0, 1, 2, 3}, + []int{2, 2}, + []float32{0, 1}, + []int{1, 2}, + }, + { + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{1}}, + {Name: "keepdims", I: 0}, + }, + }, + []float32{0, 1, 2, 3}, + []int{2, 2}, + []float32{0, 2}, + []int{2}, + }, + { + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{1}}, + {Name: 
"keepdims", I: 1}, + }, + }, + []float32{0, 1, 2, 3}, + []int{2, 2}, + []float32{0, 2}, + []int{2, 1}, + }, + { + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{0}}, + {Name: "keepdims", I: 0}, + }, + }, + []float32{0, 1, 2, 3, 4, 5}, + []int{2, 3}, + []float32{0, 1, 2}, + []int{3}, + }, + { + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{0}}, + {Name: "keepdims", I: 1}, + }, + }, + []float32{0, 1, 2, 3, 4, 5}, + []int{2, 3}, + []float32{0, 1, 2}, + []int{1, 3}, + }, + { + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{1}}, + {Name: "keepdims", I: 0}, + }, + }, + []float32{0, 1, 2, 3, 4, 5}, + []int{2, 3}, + []float32{0, 3}, + []int{2}, + }, + { + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{1}}, + {Name: "keepdims", I: 1}, + }, + }, + []float32{0, 1, 2, 3, 4, 5}, + []int{2, 3}, + []float32{0, 3}, + []int{2, 1}, + }, + { + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{1}}, + {Name: "keepdims", I: 0}, + }, + }, + []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + []int{2, 2, 3}, + []float32{0, 1, 2, 6, 7, 8}, + []int{2, 3}, + }, + { + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{1}}, + {Name: "keepdims", I: 1}, + }, + }, + []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + []int{2, 2, 3}, + []float32{0, 1, 2, 6, 7, 8}, + []int{2, 1, 3}, + }, + { + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{0, 1}}, + {Name: "keepdims", I: 0}, + }, + }, + []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + []int{2, 2, 3}, + []float32{0, 1, 2}, + []int{3}, + }, + { + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{0, 1}}, + {Name: "keepdims", I: 1}, + }, + }, + []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + []int{2, 2, 3}, + 
[]float32{0, 1, 2}, + []int{1, 1, 3}, + }, + { + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{1, 2}}, + {Name: "keepdims", I: 0}, + }, + }, + []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + []int{2, 2, 3}, + []float32{0, 6}, + []int{2}, + }, + { + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{1, 2}}, + {Name: "keepdims", I: 1}, + }, + }, + []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + []int{2, 2, 3}, + []float32{0, 6}, + []int{2, 1, 1}, + }, + { + 13, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{-1}}, + {Name: "keepdims", I: 1}, + }, + }, + []float32{0, 1, 2, 3}, + []int{2, 2}, + []float32{0, 2}, + []int{2, 1}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backing, test.shape...), + } + + reduceMin := reduceMinVersions[test.version]() + err := reduceMin.Init(test.attrs) + assert.Nil(t, err) + + res, err := reduceMin.Apply(inputs) + assert.Nil(t, err) + + assert.Equal(t, test.expectedShape, res[0].Shape()) + assert.Equal(t, test.expectedBacking, res[0].Data()) + } +} + +func TestInputValidationReduceMin(t *testing.T) { + tests := []struct { + version int64 + inputs []tensor.Tensor + err error + }{ + { + 13, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int8{1, 2}, 2), + }, + nil, + }, + { + 13, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint8{1, 2}, 2), + }, + nil, + }, + { + 13, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint32{1, 2}, 2), + }, + nil, + }, + { + 13, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint64{1, 2}, 2), + }, + nil, + }, + { + 13, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int32{1, 2}, 2), + }, + nil, + }, + { + 13, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int64{1, 2}, 2), + }, + nil, + }, + { + 13, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + }, + { + 
13, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float64{1, 2}, 2), + }, + nil, + }, + { + 13, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + ops.TensorWithBackingFixture([]int{3, 4}, 2), + }, + ops.ErrInvalidInputCount(2, reduceMin13BaseOpFixture()), + }, + { + 1, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int8{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int8", reduceMin1BaseOpFixture()), + }, + { + 11, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint8{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "uint8", reduceMin11BaseOpFixture()), + }, + { + 12, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int", reduceMin12BaseOpFixture()), + }, + { + 13, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int", reduceMin13BaseOpFixture()), + }, + } + + for _, test := range tests { + reduceMin := reduceMinVersions[test.version]() + validated, err := reduceMin.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + assert.Equal(t, test.inputs, validated) + } + } +} + +func reduceMin1BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(1, 1, 1, reduceMin11TypeConstraints, "reducemin") +} + +func reduceMin11BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(11, 1, 1, reduceMin11TypeConstraints, "reducemin") +} + +func reduceMin12BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(12, 1, 1, reduceMinTypeConstraints, "reducemin") +} + +func reduceMin13BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(13, 1, 1, reduceMinTypeConstraints, "reducemin") +} diff --git a/ops/reducemin/versions.go b/ops/reducemin/versions.go new file mode 100644 index 0000000..9846f4e --- /dev/null +++ b/ops/reducemin/versions.go @@ -0,0 +1,14 @@ +package reducemin + +import "github.com/advancedclimatesystems/gonnx/ops" + +var reduceMinVersions = 
ops.OperatorVersions{ + 1: ops.NewOperatorConstructor(newReduceMin, 1, reduceMin11TypeConstraints), + 11: ops.NewOperatorConstructor(newReduceMin, 11, reduceMin11TypeConstraints), + 12: ops.NewOperatorConstructor(newReduceMin, 12, reduceMinTypeConstraints), + 13: ops.NewOperatorConstructor(newReduceMin, 13, reduceMinTypeConstraints), +} + +func GetReduceMinVersions() ops.OperatorVersions { + return reduceMinVersions +} diff --git a/ops/relu/relu.go b/ops/relu/relu.go new file mode 100644 index 0000000..0404b48 --- /dev/null +++ b/ops/relu/relu.go @@ -0,0 +1,42 @@ +package relu + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var reluTypeConstraints = [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} + +// Relu represents the ONNX relu operator. +type Relu struct { + ops.BaseOperator +} + +// newRelu creates a new relu operator. +func newRelu(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Relu{ + BaseOperator: ops.NewBaseOperator( + version, + 1, + 1, + typeConstraints, + "relu", + ), + } +} + +// Init initializes the relu operator. +func (r *Relu) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the relu operator. 
+func (r *Relu) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + out, err := ops.ReLU(inputs[0]) + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} diff --git a/ops/opset13/relu_test.go b/ops/relu/relu_test.go similarity index 52% rename from ops/opset13/relu_test.go rename to ops/relu/relu_test.go index b2d5fa0..3794d10 100644 --- a/ops/opset13/relu_test.go +++ b/ops/relu/relu_test.go @@ -1,4 +1,4 @@ -package opset13 +package relu import ( "testing" @@ -9,35 +9,42 @@ import ( ) func TestReluInit(t *testing.T) { - r := &Relu{} + tests := []struct { + version int64 + err error + }{ + {6, nil}, + {13, nil}, + } - // since the relu does not have any attributes we pass in nil. This should not - // fail initializing the relu. - err := r.Init(nil) - assert.Nil(t, err) + for _, test := range tests { + r := reluVersions[test.version]() + err := r.Init(nil) + assert.Equal(t, test.err, err) + } } func TestRelu(t *testing.T) { tests := []struct { - relu *Relu + version int64 backing []float32 shape []int expected []float32 }{ { - &Relu{}, + 6, []float32{-4, -4, -4, -3, -2, -1}, []int{3, 2}, []float32{0, 0, 0, 0, 0, 0}, }, { - &Relu{}, + 13, []float32{-4, -4, -4, 3, 2, 1}, []int{3, 2}, []float32{0, 0, 0, 3, 2, 1}, }, { - &Relu{}, + 13, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, []int{4, 3}, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, @@ -46,7 +53,10 @@ func TestRelu(t *testing.T) { for _, test := range tests { inputs := []tensor.Tensor{ops.TensorWithBackingFixture(test.backing, test.shape...)} - res, err := test.relu.Apply(inputs) + + relu := reluVersions[test.version]() + res, err := relu.Apply(inputs) + assert.Nil(t, err) assert.Equal(t, test.expected, res[0].Data()) } @@ -54,29 +64,55 @@ func TestRelu(t *testing.T) { func TestInputValidationRelu(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 6, 
[]tensor.Tensor{ops.TensorWithBackingFixture([]float32{1, 2}, 2)}, nil, }, { + 6, []tensor.Tensor{ops.TensorWithBackingFixture([]float64{1, 2}, 2)}, nil, }, { + 13, + []tensor.Tensor{ops.TensorWithBackingFixture([]float32{1, 2}, 2)}, + nil, + }, + { + 13, + []tensor.Tensor{ops.TensorWithBackingFixture([]float64{1, 2}, 2)}, + nil, + }, + { + 6, []tensor.Tensor{}, - ops.ErrInvalidInputCount(0, &Relu{}), + ops.ErrInvalidInputCount(0, relu6BaseOpFixture()), }, { + 6, []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, - ops.ErrInvalidInputType(0, "int", &Relu{}), + ops.ErrInvalidInputType(0, "int", relu6BaseOpFixture()), + }, + { + 13, + []tensor.Tensor{}, + ops.ErrInvalidInputCount(0, relu13BaseOpFixture()), + }, + { + 13, + []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, + ops.ErrInvalidInputType(0, "int", relu13BaseOpFixture()), }, } for _, test := range tests { - relu := &Relu{} + relu := reluVersions[test.version]() + validated, err := relu.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -86,3 +122,11 @@ func TestInputValidationRelu(t *testing.T) { } } } + +func relu6BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(6, 1, 1, reluTypeConstraints, "relu") +} + +func relu13BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(13, 1, 1, reluTypeConstraints, "relu") +} diff --git a/ops/relu/versions.go b/ops/relu/versions.go new file mode 100644 index 0000000..2ecb090 --- /dev/null +++ b/ops/relu/versions.go @@ -0,0 +1,12 @@ +package relu + +import "github.com/advancedclimatesystems/gonnx/ops" + +var reluVersions = ops.OperatorVersions{ + 6: ops.NewOperatorConstructor(newRelu, 6, reluTypeConstraints), + 13: ops.NewOperatorConstructor(newRelu, 13, reluTypeConstraints), +} + +func GetReluVersions() ops.OperatorVersions { + return reluVersions +} diff --git a/ops/opset13/reshape.go b/ops/reshape/reshape.go similarity index 62% rename from ops/opset13/reshape.go rename to ops/reshape/reshape.go index 
a2a8f59..446a041 100644 --- a/ops/opset13/reshape.go +++ b/ops/reshape/reshape.go @@ -1,4 +1,4 @@ -package opset13 +package reshape import ( "github.com/advancedclimatesystems/gonnx/onnx" @@ -6,17 +6,29 @@ import ( "gorgonia.org/tensor" ) +var reshapeTypeConstraints = [][]tensor.Dtype{ops.AllTypes, {tensor.Int64}} + const ( ReshapeMinInputs = 2 ReshapeMaxInputs = 2 ) // Reshape represents the ONNX reshape operator. -type Reshape struct{} +type Reshape struct { + ops.BaseOperator +} // newReshape creates a new reshape operator. -func newReshape() ops.Operator { - return &Reshape{} +func newReshape(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Reshape{ + BaseOperator: ops.NewBaseOperator( + version, + ReshapeMinInputs, + ReshapeMaxInputs, + typeConstraints, + "reshape", + ), + } } // Init initializes the reshape operator. @@ -48,32 +60,6 @@ func (r *Reshape) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return []tensor.Tensor{out}, err } -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (r *Reshape) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(r, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (r *Reshape) GetMinInputs() int { - return ReshapeMinInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (r *Reshape) GetMaxInputs() int { - return ReshapeMaxInputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (r *Reshape) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ops.AllTypes, {tensor.Int64}} -} - -// String implements the stringer interface, and can be used to format errors or messages. 
-func (r *Reshape) String() string { - return "reshape operator" -} - func processShape(newShape, currentShape []int) error { for i := 0; i < len(newShape); i++ { if newShape[i] == 0 { diff --git a/ops/opset13/reshape_test.go b/ops/reshape/reshape_test.go similarity index 63% rename from ops/opset13/reshape_test.go rename to ops/reshape/reshape_test.go index 8651d32..9f02eea 100644 --- a/ops/opset13/reshape_test.go +++ b/ops/reshape/reshape_test.go @@ -1,4 +1,4 @@ -package opset13 +package reshape import ( "testing" @@ -9,41 +9,54 @@ import ( ) func TestReshapeInit(t *testing.T) { - r := &Reshape{} + tests := []struct { + version int64 + err error + }{ + {5, nil}, + {13, nil}, + } - // since the reshape does not have any attributes we pass in nil. This should not - // fail initializing the reshape. - err := r.Init(nil) - assert.Nil(t, err) + for _, test := range tests { + r := reshapeVersions[test.version]() + err := r.Init(nil) + assert.Equal(t, test.err, err) + } } func TestReshape(t *testing.T) { tests := []struct { + version int64 inputShape []int newShape []int64 expected tensor.Shape }{ { + 5, []int{2, 3}, []int64{1, 6}, []int{1, 6}, }, { + 13, []int{1, 2, 3}, []int64{0, 2, 3}, []int{1, 2, 3}, }, { + 13, []int{1, 2, 3}, []int64{1, -1, 2}, []int{1, 3, 2}, }, { + 13, []int{1, 2, 3}, []int64{1, -1}, []int{1, 6}, }, { + 13, []int{3, 4, 2}, []int64{1, 0, -1}, []int{1, 4, 6}, @@ -51,7 +64,7 @@ func TestReshape(t *testing.T) { } for _, test := range tests { - reshape := &Reshape{} + reshape := reshapeVersions[test.version]() inputs := []tensor.Tensor{ ops.Float32TensorFixture(test.inputShape...), tensor.New(tensor.WithBacking(test.newShape)), @@ -64,10 +77,12 @@ func TestReshape(t *testing.T) { func TestInputValidationReshape(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 5, []tensor.Tensor{ ops.TensorWithBackingFixture([]uint32{1, 2}, 2), ops.TensorWithBackingFixture([]int64{3, 
4}, 2), @@ -75,6 +90,7 @@ func TestInputValidationReshape(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), ops.TensorWithBackingFixture([]int64{3, 4}, 2), @@ -82,20 +98,27 @@ func TestInputValidationReshape(t *testing.T) { nil, }, { + 5, []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, - ops.ErrInvalidInputCount(1, &Reshape{}), + ops.ErrInvalidInputCount(1, reshape5BaseOpFixture()), }, { + 13, + []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, + ops.ErrInvalidInputCount(1, reshape13BaseOpFixture()), + }, + { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), }, - ops.ErrInvalidInputType(1, "int", &Reshape{}), + ops.ErrInvalidInputType(1, "int", reshape13BaseOpFixture()), }, } for _, test := range tests { - reshape := &Reshape{} + reshape := reshapeVersions[test.version]() validated, err := reshape.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -105,3 +128,11 @@ func TestInputValidationReshape(t *testing.T) { } } } + +func reshape5BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(5, 2, 2, reshapeTypeConstraints, "reshape") +} + +func reshape13BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(13, 2, 2, reshapeTypeConstraints, "reshape") +} diff --git a/ops/reshape/versions.go b/ops/reshape/versions.go new file mode 100644 index 0000000..2985725 --- /dev/null +++ b/ops/reshape/versions.go @@ -0,0 +1,12 @@ +package reshape + +import "github.com/advancedclimatesystems/gonnx/ops" + +var reshapeVersions = ops.OperatorVersions{ + 5: ops.NewOperatorConstructor(newReshape, 5, reshapeTypeConstraints), + 13: ops.NewOperatorConstructor(newReshape, 13, reshapeTypeConstraints), +} + +func GetReshapeVersions() ops.OperatorVersions { + return reshapeVersions +} diff --git a/ops/opset13/rnn.go b/ops/rnn/rnn.go similarity index 80% rename from ops/opset13/rnn.go rename to ops/rnn/rnn.go 
index b3248d8..a3d9c13 100644 --- a/ops/opset13/rnn.go +++ b/ops/rnn/rnn.go @@ -1,11 +1,21 @@ -package opset13 +package rnn import ( "github.com/advancedclimatesystems/gonnx/onnx" "github.com/advancedclimatesystems/gonnx/ops" + "github.com/advancedclimatesystems/gonnx/ops/gemm" "gorgonia.org/tensor" ) +var rnnTypeConstraints = [][]tensor.Dtype{ + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Float32, tensor.Float64}, + {tensor.Int32}, + {tensor.Float32, tensor.Float64}, +} + const ( MinRNNInputs = 3 MaxRNNInputs = 6 @@ -13,6 +23,8 @@ const ( // RNN represents the ONNX rnn operator. type RNN struct { + ops.BaseOperator + activationAlpha []float32 activationBeta []float32 activations []string @@ -21,8 +33,15 @@ type RNN struct { } // newRNN creates a new rnn operator. -func newRNN() ops.Operator { +func newRNN(version int, typeConstraints [][]tensor.Dtype) ops.Operator { return &RNN{ + BaseOperator: ops.NewBaseOperator( + version, + MinRNNInputs, + MaxRNNInputs, + typeConstraints, + "rnn", + ), activations: []string{"tanh"}, direction: ops.Forward, } @@ -63,7 +82,7 @@ func (r *RNN) Init(n *onnx.NodeProto) error { // Apply applies the rnn operator. func (r *RNN) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { if inputs[4] != nil { - return nil, ops.ErrUnsupportedInput("sequence lens", r) + return nil, ops.ErrUnsupportedInput("sequence lens", r.BaseOperator) } X := inputs[0] @@ -154,39 +173,6 @@ func (r *RNN) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return []tensor.Tensor{Y, Yh}, nil } -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (r *RNN) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(r, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. 
-func (r *RNN) GetMinInputs() int { - return MinRNNInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (r *RNN) GetMaxInputs() int { - return MaxRNNInputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (r *RNN) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ - {tensor.Float32, tensor.Float64}, - {tensor.Float32, tensor.Float64}, - {tensor.Float32, tensor.Float64}, - {tensor.Float32, tensor.Float64}, - {tensor.Int32}, - {tensor.Float32, tensor.Float64}, - } -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (r *RNN) String() string { - return "rnn operator" -} - // layerCalculation performs the actual RNN calculation. By ONNX definition // this is: // @@ -197,7 +183,21 @@ func (r *RNN) String() string { func (r *RNN) layerCalculation( Xt, H, Wi, Ri, Wbi, Rbi tensor.Tensor, activation ops.Activation, ) (tensor.Tensor, error) { - gemm := &Gemm{transA: false, transB: true, alpha: 1.0, beta: 1.0} + gemm := gemm.GetGemmVersions()[13]() + + err := gemm.Init( + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "alpha", F: 1.0}, + {Name: "beta", F: 1.0}, + {Name: "transA", I: 0}, + {Name: "transB", I: 1}, + }, + }, + ) + if err != nil { + return nil, err + } inputCalc, err := gemm.Apply([]tensor.Tensor{Xt, Wi, Wbi}) if err != nil { diff --git a/ops/opset13/rnn_test.go b/ops/rnn/rnn_test.go similarity index 74% rename from ops/opset13/rnn_test.go rename to ops/rnn/rnn_test.go index f24b934..fb874ff 100644 --- a/ops/opset13/rnn_test.go +++ b/ops/rnn/rnn_test.go @@ -1,4 +1,4 @@ -package opset13 +package rnn import ( "math/rand" @@ -36,66 +36,83 @@ func TestRNNInitUnknownAttr(t *testing.T) { func TestRNN(t *testing.T) { tests := []struct { - rnn *RNN + version int64 + attrs *onnx.NodeProto inputs ops.InputFixture expected []float32 err 
error }{ { - &RNN{ - activationAlpha: []float32{}, - activationBeta: []float32{}, - activations: []string{"tanh"}, - direction: ops.Forward, - hiddenSize: 4, + 7, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "activation_alpha", Floats: []float32{}}, + {Name: "activation_beta", Floats: []float32{}}, + {Name: "activations", Strings: [][]byte{[]byte("tanh")}}, + {Name: "direction", S: []byte("forward")}, + {Name: "hidden_size", I: 4}, + }, }, rnnInput0, []float32{0.78036773, 0.97858655, 0.94110376, 0.90722954}, nil, }, { - &RNN{ - activationAlpha: []float32{}, - activationBeta: []float32{}, - activations: []string{"sigmoid"}, - direction: ops.Forward, - hiddenSize: 4, + 7, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "activation_alpha", Floats: []float32{}}, + {Name: "activation_beta", Floats: []float32{}}, + {Name: "activations", Strings: [][]byte{[]byte("sigmoid")}}, + {Name: "direction", S: []byte("forward")}, + {Name: "hidden_size", I: 4}, + }, }, rnnInput0, []float32{0.82048327, 0.922734, 0.89050114, 0.8620579}, nil, }, { - &RNN{ - activationAlpha: []float32{}, - activationBeta: []float32{}, - activations: []string{"relu"}, - direction: ops.Forward, - hiddenSize: 4, + 7, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "activation_alpha", Floats: []float32{}}, + {Name: "activation_beta", Floats: []float32{}}, + {Name: "activations", Strings: [][]byte{[]byte("relu")}}, + {Name: "direction", S: []byte("forward")}, + {Name: "hidden_size", I: 4}, + }, }, + rnnInput0, []float32{1.0667435, 2.328037, 1.7986122, 1.545068}, nil, }, { - &RNN{ - activationAlpha: []float32{}, - activationBeta: []float32{}, - activations: []string{"tanh"}, - direction: ops.Forward, - hiddenSize: 10, + 7, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "activation_alpha", Floats: []float32{}}, + {Name: "activation_beta", Floats: []float32{}}, + {Name: "activations", Strings: [][]byte{[]byte("tanh")}}, + {Name: 
"direction", S: []byte("forward")}, + {Name: "hidden_size", I: 10}, + }, }, rnnInput1, []float32{0.99996024, 0.9999855, 0.99998087, 0.9999288, 0.9997511, 0.99918234, 0.99999964, 0.9999981, 0.9997658, 0.9999618, 0.9998762, 0.9999353, 0.9999194, 0.9999428, 0.9997284, 0.9982606, 0.999999, 0.9999897, 0.99964744, 0.9998234, 0.99997497, 0.9999893, 0.9999906, 0.9999812, 0.99983937, 0.99967873, 0.9999998, 0.9999965, 0.9999516, 0.9999541}, nil, }, { - &RNN{ - activationAlpha: []float32{}, - activationBeta: []float32{}, - activations: []string{"tanh"}, - direction: ops.Forward, - hiddenSize: 4, + 7, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "activation_alpha", Floats: []float32{}}, + {Name: "activation_beta", Floats: []float32{}}, + {Name: "activations", Strings: [][]byte{[]byte("tanh")}}, + {Name: "direction", S: []byte("forward")}, + {Name: "hidden_size", I: 4}, + }, }, rnnInputNoB, // Same values as first test, but B is initialized automatically. @@ -103,12 +120,15 @@ func TestRNN(t *testing.T) { nil, }, { - &RNN{ - activationAlpha: []float32{}, - activationBeta: []float32{}, - activations: []string{"tanh"}, - direction: ops.Forward, - hiddenSize: 4, + 7, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "activation_alpha", Floats: []float32{}}, + {Name: "activation_beta", Floats: []float32{}}, + {Name: "activations", Strings: [][]byte{[]byte("tanh")}}, + {Name: "direction", S: []byte("forward")}, + {Name: "hidden_size", I: 4}, + }, }, rnnInputNoBNoH, // Same values as first test, but B and H are initialized automatically. 
@@ -119,7 +139,12 @@ func TestRNN(t *testing.T) { for _, test := range tests { inputs := test.inputs() - res, err := test.rnn.Apply(inputs) + + rnn := rnnVersions[test.version]() + err := rnn.Init(test.attrs) + assert.Nil(t, err) + + res, err := rnn.Apply(inputs) assert.Equal(t, test.err, err) if err == nil { @@ -130,11 +155,13 @@ func TestRNN(t *testing.T) { func TestInputValidationRNN(t *testing.T) { tests := []struct { + version int64 inputs []tensor.Tensor expected []tensor.Tensor err error }{ { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), @@ -147,6 +174,7 @@ func TestInputValidationRNN(t *testing.T) { nil, }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), @@ -163,38 +191,43 @@ func TestInputValidationRNN(t *testing.T) { nil, }, { + 7, []tensor.Tensor{ops.TensorWithBackingFixture([]float32{1, 2}, 2)}, nil, - ops.ErrInvalidOptionalInputCount(1, &RNN{}), + ops.ErrInvalidOptionalInputCount(1, rnn7BaseOpFixture()), }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, nil, - ops.ErrInvalidInputType(0, "int", &RNN{}), + ops.ErrInvalidInputType(0, "int", rnn7BaseOpFixture()), }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, nil, - ops.ErrInvalidInputType(1, "int", &RNN{}), + ops.ErrInvalidInputType(1, "int", rnn7BaseOpFixture()), }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]int{1, 2}, 2), }, nil, - ops.ErrInvalidInputType(2, "int", &RNN{}), + ops.ErrInvalidInputType(2, "int", rnn7BaseOpFixture()), }, { + 7, []tensor.Tensor{ 
ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), @@ -202,9 +235,10 @@ func TestInputValidationRNN(t *testing.T) { ops.TensorWithBackingFixture([]int{1, 2}, 2), }, nil, - ops.ErrInvalidInputType(3, "int", &RNN{}), + ops.ErrInvalidInputType(3, "int", rnn7BaseOpFixture()), }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), @@ -213,9 +247,10 @@ func TestInputValidationRNN(t *testing.T) { ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, nil, - ops.ErrInvalidInputType(4, "float32", &RNN{}), + ops.ErrInvalidInputType(4, "float32", rnn7BaseOpFixture()), }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{1, 2}, 2), @@ -225,12 +260,12 @@ func TestInputValidationRNN(t *testing.T) { ops.TensorWithBackingFixture([]int{1, 2}, 2), }, nil, - ops.ErrInvalidInputType(5, "int", &RNN{}), + ops.ErrInvalidInputType(5, "int", rnn7BaseOpFixture()), }, } for _, test := range tests { - rnn := &RNN{} + rnn := rnnVersions[test.version]() validated, err := rnn.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -332,3 +367,7 @@ func RNNOnnxNodeProtoFixture() *onnx.NodeProto { }, } } + +func rnn7BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(7, 3, 6, rnnTypeConstraints, "rnn") +} diff --git a/ops/rnn/versions.go b/ops/rnn/versions.go new file mode 100644 index 0000000..3d607e2 --- /dev/null +++ b/ops/rnn/versions.go @@ -0,0 +1,11 @@ +package rnn + +import "github.com/advancedclimatesystems/gonnx/ops" + +var rnnVersions = ops.OperatorVersions{ + 7: ops.NewOperatorConstructor(newRNN, 7, rnnTypeConstraints), +} + +func GetRNNVersions() ops.OperatorVersions { + return rnnVersions +} diff --git a/ops/opset13/scaler.go b/ops/scaler/scaler.go similarity index 58% rename from ops/opset13/scaler.go rename to ops/scaler/scaler.go index c5eb53b..15f05a7 100644 --- 
a/ops/opset13/scaler.go +++ b/ops/scaler/scaler.go @@ -1,4 +1,4 @@ -package opset13 +package scaler import ( "github.com/advancedclimatesystems/gonnx/onnx" @@ -6,21 +6,33 @@ import ( "gorgonia.org/tensor" ) +var scalerTypeConstraints = [][]tensor.Dtype{ + {tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, +} + const ( ScalerExpectedAttributes = 2 - MinScalerInputs = 1 - MaxScalerInputs = 1 ) // Scaler represents the ONNX-ml scaler operator. type Scaler struct { + ops.BaseOperator + offset tensor.Tensor scale tensor.Tensor } // newScaler creates a new scaler operator. -func newScaler() ops.Operator { - return &Scaler{} +func newScaler(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Scaler{ + BaseOperator: ops.NewBaseOperator( + version, + 1, + 1, + typeConstraints, + "scaler", + ), + } } // Init initializes the scaler operator. @@ -70,31 +82,3 @@ func (s *Scaler) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return []tensor.Tensor{Y}, nil } - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (s *Scaler) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(s, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (s *Scaler) GetMinInputs() int { - return MinScalerInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (s *Scaler) GetMaxInputs() int { - return MaxScalerInputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (s *Scaler) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ - {tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, - } -} - -// String implements the stringer interface, and can be used to format errors or messages. 
-func (s *Scaler) String() string { - return "scaler operator" -} diff --git a/ops/opset13/scaler_test.go b/ops/scaler/scaler_test.go similarity index 89% rename from ops/opset13/scaler_test.go rename to ops/scaler/scaler_test.go index 2a6dfae..7561d2f 100644 --- a/ops/opset13/scaler_test.go +++ b/ops/scaler/scaler_test.go @@ -1,4 +1,4 @@ -package opset13 +package scaler import ( "testing" @@ -87,37 +87,44 @@ func TestScaler(t *testing.T) { func TestInputValidationScaler(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 1, []tensor.Tensor{ops.TensorWithBackingFixture([]int32{1, 2}, 2)}, nil, }, { + 1, []tensor.Tensor{ops.TensorWithBackingFixture([]int64{1, 2}, 2)}, nil, }, { + 1, []tensor.Tensor{ops.TensorWithBackingFixture([]float32{1, 2}, 2)}, nil, }, { + 1, []tensor.Tensor{ops.TensorWithBackingFixture([]float64{1, 2}, 2)}, nil, }, { + 1, []tensor.Tensor{}, - ops.ErrInvalidInputCount(0, &Scaler{}), + ops.ErrInvalidInputCount(0, scaler1BaseOpFixture()), }, { + 1, []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, - ops.ErrInvalidInputType(0, "int", &Scaler{}), + ops.ErrInvalidInputType(0, "int", scaler1BaseOpFixture()), }, } for _, test := range tests { - scaler := &Scaler{} + scaler := scalerVersions[test.version]() validated, err := scaler.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -136,3 +143,7 @@ func ScalerOnnxNodeProtoFixture() *onnx.NodeProto { }, } } + +func scaler1BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(1, 1, 1, scalerTypeConstraints, "scaler") +} diff --git a/ops/scaler/versions.go b/ops/scaler/versions.go new file mode 100644 index 0000000..73b0b94 --- /dev/null +++ b/ops/scaler/versions.go @@ -0,0 +1,11 @@ +package scaler + +import "github.com/advancedclimatesystems/gonnx/ops" + +var scalerVersions = ops.OperatorVersions{ + 1: ops.NewOperatorConstructor(newScaler, 1, scalerTypeConstraints), +} + +func 
GetScalerVersions() ops.OperatorVersions { + return scalerVersions +} diff --git a/ops/shape/shape.go b/ops/shape/shape.go new file mode 100644 index 0000000..caf3f8c --- /dev/null +++ b/ops/shape/shape.go @@ -0,0 +1,47 @@ +package shape + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var shapeTypeConstraints = [][]tensor.Dtype{ops.AllTypes} + +// Shape represents the ONNX shape operator. +type Shape struct { + ops.BaseOperator +} + +// newShape creates a new shape operator. +func newShape(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Shape{ + BaseOperator: ops.NewBaseOperator( + version, + 1, + 1, + typeConstraints, + "shape", + ), + } +} + +// Init initializes the shape operator. +func (s *Shape) Init(*onnx.NodeProto) error { + return nil +} + +// Apply the shape operator to the graph. It creates a node that holds the shape of the +// input node as 1D int64 tensor. 
+func (s *Shape) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + nodeShape := inputs[0].Shape() + shape := make([]int64, len(nodeShape)) + + for i, dimSize := range nodeShape { + shape[i] = int64(dimSize) + } + + out := tensor.New(tensor.WithShape(len(nodeShape)), tensor.WithBacking(shape)) + + return []tensor.Tensor{out}, nil +} diff --git a/ops/opset13/shape_test.go b/ops/shape/shape_test.go similarity index 74% rename from ops/opset13/shape_test.go rename to ops/shape/shape_test.go index 1ab9382..8d65d3a 100644 --- a/ops/opset13/shape_test.go +++ b/ops/shape/shape_test.go @@ -1,4 +1,4 @@ -package opset13 +package shape import ( "testing" @@ -19,21 +19,24 @@ func TestShapeInit(t *testing.T) { func TestShape(t *testing.T) { tests := []struct { + version int64 inputShape []int expected []int64 }{ { + 1, []int{1, 2, 3, 4}, []int64{1, 2, 3, 4}, }, { + 13, []int{2, 3}, []int64{2, 3}, }, } for _, test := range tests { - shape := &Shape{} + shape := shapeVersions[test.version]() inputs := []tensor.Tensor{ ops.Float32TensorFixture(test.inputShape...), } @@ -46,29 +49,34 @@ func TestShape(t *testing.T) { func TestInputValidationShape(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 1, []tensor.Tensor{ops.TensorWithBackingFixture([]uint32{3, 4}, 2)}, nil, }, { + 13, []tensor.Tensor{ops.TensorWithBackingFixture([]float32{3, 4}, 2)}, nil, }, { + 13, []tensor.Tensor{}, - ops.ErrInvalidInputCount(0, &Shape{}), + ops.ErrInvalidInputCount(0, shape13BaseOpFixture()), }, { + 13, []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, - ops.ErrInvalidInputType(0, "int", &Shape{}), + ops.ErrInvalidInputType(0, "int", shape13BaseOpFixture()), }, } for _, test := range tests { - shape := &Shape{} + shape := shapeVersions[test.version]() validated, err := shape.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -78,3 +86,7 @@ func TestInputValidationShape(t *testing.T) { } } 
} + +func shape13BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(13, 1, 1, shapeTypeConstraints, "shape") +} diff --git a/ops/shape/versions.go b/ops/shape/versions.go new file mode 100644 index 0000000..9b140e0 --- /dev/null +++ b/ops/shape/versions.go @@ -0,0 +1,12 @@ +package shape + +import "github.com/advancedclimatesystems/gonnx/ops" + +var shapeVersions = ops.OperatorVersions{ + 1: ops.NewOperatorConstructor(newShape, 1, shapeTypeConstraints), + 13: ops.NewOperatorConstructor(newShape, 13, shapeTypeConstraints), +} + +func GetShapeVersions() ops.OperatorVersions { + return shapeVersions +} diff --git a/ops/sigmoid/sigmoid.go b/ops/sigmoid/sigmoid.go new file mode 100644 index 0000000..ff4015c --- /dev/null +++ b/ops/sigmoid/sigmoid.go @@ -0,0 +1,39 @@ +package sigmoid + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var sigmoidTypeConstraints = [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} + +// Sigmoid represents the ONNX sigmoid operator. +type Sigmoid struct { + ops.BaseOperator +} + +// newSigmoid returns a new sigmoid operator. +func newSigmoid(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Sigmoid{ + BaseOperator: ops.NewBaseOperator( + version, + 1, + 1, + typeConstraints, + "sigmoid", + ), + } +} + +// Init initializes the sigmoid operator. +func (s *Sigmoid) Init(*onnx.NodeProto) error { + return nil +} + +// Apply the sigmoid operator to the input node. 
+func (s *Sigmoid) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + out, err := ops.Sigmoid(inputs[0]) + + return []tensor.Tensor{out}, err +} diff --git a/ops/opset13/sigmoid_test.go b/ops/sigmoid/sigmoid_test.go similarity index 64% rename from ops/opset13/sigmoid_test.go rename to ops/sigmoid/sigmoid_test.go index 3277a6f..a835262 100644 --- a/ops/opset13/sigmoid_test.go +++ b/ops/sigmoid/sigmoid_test.go @@ -1,4 +1,4 @@ -package opset13 +package sigmoid import ( "testing" @@ -9,7 +9,7 @@ import ( ) func TestSigmoidInit(t *testing.T) { - s := newSigmoid() + s := &Sigmoid{} // Since the sigmoid does not have any attributes we expect it to initialize even // when nil is passed. err := s.Init(nil) @@ -18,33 +18,48 @@ func TestSigmoidInit(t *testing.T) { func TestSigmoid(t *testing.T) { tests := []struct { + version int64 backing []float32 shape []int expected []float32 }{ { + 6, + []float32{-2, -1, 0, 3}, + []int{2, 2}, + []float32{ + 0.11920292, + 0.26894143, + 0.5, + 0.95257413, + }, + }, + { + 13, []float32{-4, -3, -2, -1, 0, 12}, []int{3, 2}, []float32{ 0.01798620996209155802679, 0.04742587317756678087885, 0.1192029220221175559403, - 0.2689414213699951207488, 0.5, + 0.26894142699951207488, 0.5, 0.9999938558253977852822, }, }, { + 13, []float32{-4, -4, -4, 3, 2, 1}, []int{3, 2}, []float32{ 0.01798621, 0.01798621, 0.01798621, - 0.95257413, 0.8807971, 0.7310586, + 0.952574, 0.8807971, 0.7310586, }, }, { + 13, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, []int{4, 3}, []float32{ - 0.5, 0.7310586, 0.8807971, 0.95257413, + 0.5, 0.7310586, 0.8807971, 0.952574, 0.98201376, 0.9933072, 0.99752736, 0.99908894, 0.99966466, 0.9998766, 0.9999546, 0.9999833, }, @@ -52,7 +67,7 @@ func TestSigmoid(t *testing.T) { } for _, test := range tests { - sigmoid := &Sigmoid{} + sigmoid := sigmoidVersions[test.version]() inputs := []tensor.Tensor{ ops.TensorWithBackingFixture(test.backing, test.shape...), } @@ -65,29 +80,39 @@ func TestSigmoid(t *testing.T) { func 
TestInputValidationSigmoid(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 13, []tensor.Tensor{ops.TensorWithBackingFixture([]float32{1, 2}, 2)}, nil, }, { + 13, []tensor.Tensor{ops.TensorWithBackingFixture([]float64{1, 2}, 2)}, nil, }, { + 6, []tensor.Tensor{}, - ops.ErrInvalidInputCount(0, &Sigmoid{}), + ops.ErrInvalidInputCount(0, sigmoid6BaseOpFixture()), }, { + 13, + []tensor.Tensor{}, + ops.ErrInvalidInputCount(0, sigmoid13BaseOpFixture()), + }, + { + 13, []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, - ops.ErrInvalidInputType(0, "int", &Sigmoid{}), + ops.ErrInvalidInputType(0, "int", sigmoid13BaseOpFixture()), }, } for _, test := range tests { - sigmoid := &Sigmoid{} + sigmoid := sigmoidVersions[test.version]() validated, err := sigmoid.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -97,3 +122,11 @@ func TestInputValidationSigmoid(t *testing.T) { } } } + +func sigmoid6BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(6, 1, 1, sigmoidTypeConstraints, "sigmoid") +} + +func sigmoid13BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(13, 1, 1, sigmoidTypeConstraints, "sigmoid") +} diff --git a/ops/sigmoid/versions.go b/ops/sigmoid/versions.go new file mode 100644 index 0000000..31e8cf6 --- /dev/null +++ b/ops/sigmoid/versions.go @@ -0,0 +1,12 @@ +package sigmoid + +import "github.com/advancedclimatesystems/gonnx/ops" + +var sigmoidVersions = ops.OperatorVersions{ + 6: ops.NewOperatorConstructor(newSigmoid, 6, sigmoidTypeConstraints), + 13: ops.NewOperatorConstructor(newSigmoid, 13, sigmoidTypeConstraints), +} + +func GetSigmoidVersions() ops.OperatorVersions { + return sigmoidVersions +} diff --git a/ops/sin/sin.go b/ops/sin/sin.go new file mode 100644 index 0000000..a93f3f9 --- /dev/null +++ b/ops/sin/sin.go @@ -0,0 +1,61 @@ +package sin + +import ( + "math" + + 
"github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var sinTypeConstraints = [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} + +// Sin represents the ONNX sin operator. +type Sin struct { + ops.BaseOperator +} + +// newSin creates a new sin operator. +func newSin(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Sin{ + BaseOperator: ops.NewBaseOperator( + version, + 1, + 1, + typeConstraints, + "sin", + ), + } +} + +// Init initializes the sin operator. +func (s *Sin) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the sin operator. +func (s *Sin) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var ( + out tensor.Tensor + err error + ) + + switch inputs[0].Dtype() { + case tensor.Float32: + out, err = inputs[0].Apply(sin[float32]) + case tensor.Float64: + out, err = inputs[0].Apply(sin[float64]) + default: + return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), s.BaseOperator) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +func sin[T ops.FloatType](x T) T { + return T(math.Sin(float64(x))) +} diff --git a/ops/opset13/sin_test.go b/ops/sin/sin_test.go similarity index 77% rename from ops/opset13/sin_test.go rename to ops/sin/sin_test.go index 1ec4483..e62c455 100644 --- a/ops/opset13/sin_test.go +++ b/ops/sin/sin_test.go @@ -1,4 +1,4 @@ -package opset13 +package sin import ( "testing" @@ -19,25 +19,25 @@ func TestSinInit(t *testing.T) { func TestSin(t *testing.T) { tests := []struct { - sin *Sin + version int64 backing []float32 shape []int expected []float32 }{ { - &Sin{}, + 7, []float32{-2, -1, 0, 1}, []int{2, 2}, []float32{-0.9092974, -0.84147096, 0, 0.84147096}, }, { - &Sin{}, + 7, []float32{1, 3, 4, 5}, []int{1, 4}, []float32{0.84147096, 0.14112, -0.7568025, -0.9589243}, }, { - &Sin{}, + 7, []float32{-1, -1, -1, -1}, []int{1, 4}, []float32{-0.84147096, -0.84147096, 
-0.84147096, -0.84147096}, @@ -45,11 +45,12 @@ func TestSin(t *testing.T) { } for _, test := range tests { + sin := sinVersions[test.version]() inputs := []tensor.Tensor{ ops.TensorWithBackingFixture(test.backing, test.shape...), } - res, err := test.sin.Apply(inputs) + res, err := sin.Apply(inputs) assert.Nil(t, err) assert.Nil(t, err) @@ -59,35 +60,40 @@ func TestSin(t *testing.T) { func TestInputValidationSin(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, nil, }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), }, nil, }, { + 7, []tensor.Tensor{}, - ops.ErrInvalidInputCount(0, &Sin{}), + ops.ErrInvalidInputCount(0, sin7BaseOpFixture()), }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputType(0, "int", &Sin{}), + ops.ErrInvalidInputType(0, "int", sin7BaseOpFixture()), }, } for _, test := range tests { - sin := &Sin{} + sin := sinVersions[test.version]() validated, err := sin.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -97,3 +103,7 @@ func TestInputValidationSin(t *testing.T) { } } } + +func sin7BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(7, 1, 1, sinTypeConstraints, "sin") +} diff --git a/ops/sin/versions.go b/ops/sin/versions.go new file mode 100644 index 0000000..2b52a9c --- /dev/null +++ b/ops/sin/versions.go @@ -0,0 +1,11 @@ +package sin + +import "github.com/advancedclimatesystems/gonnx/ops" + +var sinVersions = ops.OperatorVersions{ + 7: ops.NewOperatorConstructor(newSin, 7, sinTypeConstraints), +} + +func GetSinVersions() ops.OperatorVersions { + return sinVersions +} diff --git a/ops/sinh/sinh.go b/ops/sinh/sinh.go new file mode 100644 index 0000000..4132bb7 --- /dev/null +++ b/ops/sinh/sinh.go @@ -0,0 +1,61 @@ +package sinh + +import ( + "math" + + 
"github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var sinhTypeConstraints = [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} + +// Sinh represents the ONNX sinh operator. +type Sinh struct { + ops.BaseOperator +} + +// newSinh creates a new sinh operator. +func newSinh(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Sinh{ + BaseOperator: ops.NewBaseOperator( + version, + 1, + 1, + typeConstraints, + "sinh", + ), + } +} + +// Init initializes the sinh operator. +func (s *Sinh) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the sinh operator. +func (s *Sinh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var ( + out tensor.Tensor + err error + ) + + switch inputs[0].Dtype() { + case tensor.Float32: + out, err = inputs[0].Apply(sinh[float32]) + case tensor.Float64: + out, err = inputs[0].Apply(sinh[float64]) + default: + return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), s.BaseOperator) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +func sinh[T ops.FloatType](x T) T { + return T(math.Sinh(float64(x))) +} diff --git a/ops/opset13/sinh_test.go b/ops/sinh/sinh_test.go similarity index 82% rename from ops/opset13/sinh_test.go rename to ops/sinh/sinh_test.go index 3288490..3d564c6 100644 --- a/ops/opset13/sinh_test.go +++ b/ops/sinh/sinh_test.go @@ -1,4 +1,4 @@ -package opset13 +package sinh import ( "testing" @@ -59,35 +59,40 @@ func TestSinh(t *testing.T) { func TestInputValidationSinh(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 9, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, nil, }, { + 9, []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), }, nil, }, { + 9, []tensor.Tensor{}, - ops.ErrInvalidInputCount(0, &Sinh{}), + 
ops.ErrInvalidInputCount(0, sinh9BaseOpFixture()), }, { + 9, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputType(0, "int", &Sinh{}), + ops.ErrInvalidInputType(0, "int", sinh9BaseOpFixture()), }, } for _, test := range tests { - sinh := &Sinh{} + sinh := sinhVersions[test.version]() validated, err := sinh.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -97,3 +102,7 @@ func TestInputValidationSinh(t *testing.T) { } } } + +func sinh9BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(9, 1, 1, sinhTypeConstraints, "sinh") +} diff --git a/ops/sinh/versions.go b/ops/sinh/versions.go new file mode 100644 index 0000000..44d9b61 --- /dev/null +++ b/ops/sinh/versions.go @@ -0,0 +1,11 @@ +package sinh + +import "github.com/advancedclimatesystems/gonnx/ops" + +var sinhVersions = ops.OperatorVersions{ + 9: ops.NewOperatorConstructor(newSinh, 9, sinhTypeConstraints), +} + +func GetSinhVersions() ops.OperatorVersions { + return sinhVersions +} diff --git a/ops/opset13/slice.go b/ops/slice/slice.go similarity index 57% rename from ops/opset13/slice.go rename to ops/slice/slice.go index d7589f5..402d32b 100644 --- a/ops/opset13/slice.go +++ b/ops/slice/slice.go @@ -1,4 +1,4 @@ -package opset13 +package slice import ( "github.com/advancedclimatesystems/gonnx/onnx" @@ -6,17 +6,35 @@ import ( "gorgonia.org/tensor" ) +var sliceTypeConstraints = [][]tensor.Dtype{ + ops.AllTypes, + {tensor.Int32, tensor.Int64}, + {tensor.Int32, tensor.Int64}, + {tensor.Int32, tensor.Int64}, + {tensor.Int32, tensor.Int64}, +} + const ( MinSliceInputs = 3 MaxSliceInputs = 5 ) // Slice represents the ONNX slice operator. -type Slice struct{} +type Slice struct { + ops.BaseOperator +} // newSlice creates a new slice operator. 
-func newSlice() ops.Operator { - return &Slice{} +func newSlice(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Slice{ + BaseOperator: ops.NewBaseOperator( + version, + MinSliceInputs, + MaxSliceInputs, + typeConstraints, + "slice", + ), + } } // Init initializes the slice operator. @@ -38,7 +56,7 @@ func (s *Slice) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return nil, err } - axes := s.getDefaultAxes(len(starts)) + axes := getDefaultAxes(len(starts)) if inputs[3] != nil { axes, err = ops.AnyToIntSlice(ops.IfScalarToSlice(inputs[3].Data())) if err != nil { @@ -46,7 +64,7 @@ func (s *Slice) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { } } - steps := s.getDefaultSteps(len(starts)) + steps := getDefaultSteps(len(starts)) if inputs[4] != nil { steps, err = ops.AnyToIntSlice(ops.IfScalarToSlice(inputs[4].Data())) if err != nil { @@ -54,7 +72,7 @@ func (s *Slice) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { } } - slices := s.constructSlices(starts, ends, steps, axes, len(data.Shape())) + slices := constructSlices(starts, ends, steps, axes, len(data.Shape())) out, err := data.Slice(slices...) if err != nil { @@ -64,41 +82,9 @@ func (s *Slice) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return []tensor.Tensor{out.Materialize()}, nil } -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (s *Slice) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(s, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (s *Slice) GetMinInputs() int { - return MinSliceInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (s *Slice) GetMaxInputs() int { - return MaxSliceInputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. 
-func (s *Slice) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ - ops.AllTypes, - {tensor.Int32, tensor.Int64}, - {tensor.Int32, tensor.Int64}, - {tensor.Int32, tensor.Int64}, - {tensor.Int32, tensor.Int64}, - } -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (s *Slice) String() string { - return "slice operator" -} - // constructSlice constructs a list with tensor.Slice objects. The list is initializes with nils. // The axes parameter determines at which indices tensor.Slice objects are placed. -func (s *Slice) constructSlices(starts, ends, steps, axes []int, nTotalSlices int) []tensor.Slice { +func constructSlices(starts, ends, steps, axes []int, nTotalSlices int) []tensor.Slice { slices := make([]tensor.Slice, nTotalSlices) for i := 0; i < nTotalSlices; i++ { slices[i] = nil @@ -116,7 +102,7 @@ func (s *Slice) constructSlices(starts, ends, steps, axes []int, nTotalSlices in } // getDefaultAxes returns the default axes parameter. By default the slices are in natural order. -func (s *Slice) getDefaultAxes(nSlices int) []int { +func getDefaultAxes(nSlices int) []int { axes := make([]int, nSlices) for i := 0; i < nSlices; i++ { axes[i] = i @@ -126,7 +112,7 @@ func (s *Slice) getDefaultAxes(nSlices int) []int { } // getDefaultSteps returns the default steps data. By default the steps are 1. 
-func (s *Slice) getDefaultSteps(nSlices int) []int { +func getDefaultSteps(nSlices int) []int { steps := make([]int, nSlices) for i := 0; i < nSlices; i++ { steps[i] = 1 diff --git a/ops/slice/slice_1.go b/ops/slice/slice_1.go new file mode 100644 index 0000000..2178265 --- /dev/null +++ b/ops/slice/slice_1.go @@ -0,0 +1,96 @@ +package slice + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +const ( + MinSlice1Attributes = 2 + MaxSlice1Attributes = 3 +) + +// Slice1 represents the ONNX slice operator. +type Slice1 struct { + ops.BaseOperator + + axes []int + ends []int + starts []int +} + +// newSlice1 creates a new slice operator. +func newSlice1() ops.Operator { + return &Slice1{ + BaseOperator: ops.NewBaseOperator( + 1, + MinSliceInputs, + MaxSliceInputs, + sliceTypeConstraints, + "slice", + ), + } +} + +// Init initializes the slice operator. +func (s *Slice1) Init(n *onnx.NodeProto) error { + nAttrs := len(n.GetAttribute()) + if nAttrs < MinSlice1Attributes || nAttrs > MaxSlice1Attributes { + return ops.ErrInvalidOptionalAttributeCount(MinSlice1Attributes, MaxSlice1Attributes, nAttrs, s) + } + + for _, attr := range n.GetAttribute() { + switch attr.GetName() { + case "axes": + axes, err := ops.AnyToIntSlice(attr.GetInts()) + if err != nil { + return err + } + + s.axes = axes + case "ends": + ends, err := ops.AnyToIntSlice(attr.GetInts()) + if err != nil { + return err + } + + s.ends = ends + case "starts": + starts, err := ops.AnyToIntSlice(attr.GetInts()) + if err != nil { + return err + } + + s.starts = starts + default: + return ops.ErrInvalidAttribute(attr.GetName(), s) + } + } + + return nil +} + +// Apply applies the slice operator. 
+func (s *Slice1) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + data := inputs[0] + + axes := s.axes + if len(s.axes) == 0 { + axes = getDefaultAxes(len(s.starts)) + } + + steps := make([]int, len(s.starts)) + for i := range steps { + steps[i] = 1 + } + + slices := constructSlices(s.starts, s.ends, steps, axes, len(data.Shape())) + + out, err := data.Slice(slices...) + if err != nil { + return nil, err + } + + return []tensor.Tensor{out.Materialize()}, nil +} diff --git a/ops/opset13/slice_test.go b/ops/slice/slice_test.go similarity index 59% rename from ops/opset13/slice_test.go rename to ops/slice/slice_test.go index 652608f..d3eac41 100644 --- a/ops/opset13/slice_test.go +++ b/ops/slice/slice_test.go @@ -1,24 +1,60 @@ -package opset13 +package slice import ( "testing" + "github.com/advancedclimatesystems/gonnx/onnx" "github.com/advancedclimatesystems/gonnx/ops" "github.com/stretchr/testify/assert" "gorgonia.org/tensor" ) func TestSliceInit(t *testing.T) { - s := &Slice{} + tests := []struct { + version int64 + attrs *onnx.NodeProto + err error + }{ + { + 1, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{1, 0}}, + {Name: "starts", Ints: []int64{1, 0}}, + {Name: "ends", Ints: []int64{2, 2}}, + }, + }, + nil, + }, + { + 10, + nil, + nil, + }, + { + 11, + nil, + nil, + }, + { + 13, + nil, + nil, + }, + } - // since the slice does not have any attributes we pass in nil. This should not - // fail initializing the slice. 
- err := s.Init(nil) - assert.Nil(t, err) + for _, test := range tests { + op := sliceVersions[test.version]() + err := op.Init(test.attrs) + assert.Equal(t, test.err, err) + } } func TestSlice(t *testing.T) { tests := []struct { + version int64 + attrs *onnx.NodeProto + shape []int starts []int64 ends []int64 @@ -28,6 +64,25 @@ func TestSlice(t *testing.T) { expectedBacking []float32 }{ { + 1, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{0, 1}}, + {Name: "starts", Ints: []int64{1, 0}}, + {Name: "ends", Ints: []int64{2, 3}}, + }, + }, + []int{2, 4}, + nil, + nil, + nil, + nil, + []int{3}, + []float32{4, 5, 6}, + }, + { + 13, + nil, []int{2, 3}, []int64{1, 0}, []int64{2, 2}, @@ -37,6 +92,19 @@ func TestSlice(t *testing.T) { []float32{3, 4}, }, { + 13, + nil, + []int{2, 3}, + []int64{1, 0}, + []int64{2, 2}, + nil, + nil, + []int{2}, + []float32{3, 4}, + }, + { + 13, + nil, []int{3, 3}, []int64{1}, []int64{3}, @@ -46,6 +114,8 @@ func TestSlice(t *testing.T) { []float32{3, 4, 5, 6, 7, 8}, }, { + 13, + nil, []int{3, 3}, []int64{1}, []int64{3}, @@ -55,6 +125,8 @@ func TestSlice(t *testing.T) { []float32{1, 2, 4, 5, 7, 8}, }, { + 13, + nil, []int{2, 3, 3}, []int64{0, 1, 1}, []int64{1, 3, 3}, @@ -64,6 +136,8 @@ func TestSlice(t *testing.T) { []float32{4, 5, 7, 8}, }, { + 13, + nil, []int{4, 4}, []int64{0}, []int64{4}, @@ -75,11 +149,21 @@ func TestSlice(t *testing.T) { } for _, test := range tests { - slice := &Slice{} - inputs := []tensor.Tensor{ - ops.Float32TensorFixture(test.shape...), - ops.TensorWithBackingFixture(test.starts, len(test.starts)), - ops.TensorWithBackingFixture(test.ends, len(test.ends)), + slice := sliceVersions[test.version]() + err := slice.Init(test.attrs) + assert.Nil(t, err) + + var inputs []tensor.Tensor + if test.version >= 10 { + inputs = []tensor.Tensor{ + ops.Float32TensorFixture(test.shape...), + ops.TensorWithBackingFixture(test.starts, len(test.starts)), + ops.TensorWithBackingFixture(test.ends, 
len(test.ends)), + } + } else { + inputs = []tensor.Tensor{ + ops.Float32TensorFixture(test.shape...), + } } if test.axes != nil { @@ -105,7 +189,6 @@ func TestSlice(t *testing.T) { func TestConstructSlices(t *testing.T) { tests := []struct { - slice *Slice starts []int ends []int axes []int @@ -114,7 +197,6 @@ func TestConstructSlices(t *testing.T) { expectedSlices []tensor.Slice }{ { - &Slice{}, []int{1, 0}, []int{2, 3}, []int{0, 1}, @@ -123,7 +205,6 @@ func TestConstructSlices(t *testing.T) { []tensor.Slice{ops.NewSlicer(1, 2, 1), ops.NewSlicer(0, 3, 1)}, }, { - &Slice{}, []int{0, 2}, []int{2, 5}, []int{2, 0}, @@ -134,7 +215,7 @@ func TestConstructSlices(t *testing.T) { } for _, test := range tests { - slices := test.slice.constructSlices( + slices := constructSlices( test.starts, test.ends, test.steps, test.axes, test.nSlices, ) @@ -153,24 +234,24 @@ func TestConstructSlices(t *testing.T) { } func TestGetDefaultAxes(t *testing.T) { - slice := &Slice{} - res := slice.getDefaultAxes(3) + res := getDefaultAxes(3) assert.Equal(t, []int{0, 1, 2}, res) } func TestGetDefaultSteps(t *testing.T) { - slice := &Slice{} - res := slice.getDefaultSteps(3) + res := getDefaultSteps(3) assert.Equal(t, []int{1, 1, 1}, res) } func TestInputValidationSlice(t *testing.T) { tests := []struct { + version int64 inputs []tensor.Tensor expected []tensor.Tensor err error }{ { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]int32{3, 4}, 2), @@ -182,6 +263,7 @@ func TestInputValidationSlice(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), ops.TensorWithBackingFixture([]int64{3, 4}, 2), @@ -197,23 +279,43 @@ func TestInputValidationSlice(t *testing.T) { nil, }, { + 1, []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, nil, - ops.ErrInvalidOptionalInputCount(1, &Slice{}), + ops.ErrInvalidOptionalInputCount(1, slice1BaseOpFixture()), }, { + 10, + 
[]tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, + nil, + ops.ErrInvalidOptionalInputCount(1, slice10BaseOpFixture()), + }, + { + 11, + []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, + nil, + ops.ErrInvalidOptionalInputCount(1, slice11BaseOpFixture()), + }, + { + 13, + []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, + nil, + ops.ErrInvalidOptionalInputCount(1, slice13BaseOpFixture()), + }, + { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), }, nil, - ops.ErrInvalidInputType(1, "int", &Slice{}), + ops.ErrInvalidInputType(1, "int", slice13BaseOpFixture()), }, } for _, test := range tests { - slice := &Slice{} + slice := sliceVersions[test.version]() validated, err := slice.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -227,3 +329,19 @@ func TestInputValidationSlice(t *testing.T) { } } } + +func slice1BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(1, 3, 5, sliceTypeConstraints, "slice") +} + +func slice10BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(10, 3, 5, sliceTypeConstraints, "slice") +} + +func slice11BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(11, 3, 5, sliceTypeConstraints, "slice") +} + +func slice13BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(13, 3, 5, sliceTypeConstraints, "slice") +} diff --git a/ops/slice/versions.go b/ops/slice/versions.go new file mode 100644 index 0000000..d52733f --- /dev/null +++ b/ops/slice/versions.go @@ -0,0 +1,14 @@ +package slice + +import "github.com/advancedclimatesystems/gonnx/ops" + +var sliceVersions = ops.OperatorVersions{ + 1: newSlice1, + 10: ops.NewOperatorConstructor(newSlice, 10, sliceTypeConstraints), + 11: ops.NewOperatorConstructor(newSlice, 11, sliceTypeConstraints), + 13: ops.NewOperatorConstructor(newSlice, 13, sliceTypeConstraints), +} + +func 
GetSliceVersions() ops.OperatorVersions { + return sliceVersions +} diff --git a/ops/opset13/softmax.go b/ops/softmax/softmax.go similarity index 52% rename from ops/opset13/softmax.go rename to ops/softmax/softmax.go index 8a2c0c0..e352c69 100644 --- a/ops/opset13/softmax.go +++ b/ops/softmax/softmax.go @@ -1,4 +1,4 @@ -package opset13 +package softmax import ( "github.com/advancedclimatesystems/gonnx/onnx" @@ -6,16 +6,27 @@ import ( "gorgonia.org/tensor" ) +var softmaxTypeConstraints = [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} + // Softmax represents the ONNX softmax operator. type Softmax struct { + ops.BaseOperator + // The axis along which to perform the Softmax operation. axis int } // newSoftmax creates a new softmax operator. -func newSoftmax() ops.Operator { +func newSoftmax(version int, typeConstraints [][]tensor.Dtype) ops.Operator { return &Softmax{ - axis: -1, // This is the default value by ONNX definition. + BaseOperator: ops.NewBaseOperator( + version, + 1, + 1, + typeConstraints, + "softmax", + ), + axis: -1, } } @@ -56,29 +67,3 @@ func (s *Softmax) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return []tensor.Tensor{out}, nil } - -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (s *Softmax) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(s, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (s *Softmax) GetMinInputs() int { - return 1 -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (s *Softmax) GetMaxInputs() int { - return 1 -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. 
-func (s *Softmax) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (s *Softmax) String() string { - return "softmax operator" -} diff --git a/ops/opset13/softmax_test.go b/ops/softmax/softmax_test.go similarity index 66% rename from ops/opset13/softmax_test.go rename to ops/softmax/softmax_test.go index 5c01bc0..25b462a 100644 --- a/ops/opset13/softmax_test.go +++ b/ops/softmax/softmax_test.go @@ -1,4 +1,4 @@ -package opset13 +package softmax import ( "testing" @@ -97,38 +97,73 @@ func TestSoftmaxFail(t *testing.T) { func TestInputValidationSoftmax(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 1, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, nil, }, { + 11, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + }, + { + 13, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + }, + { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), }, nil, }, { + 1, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputCount(2, softmax1BaseOpFixture()), + }, + { + 11, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputCount(2, &Softmax{}), + ops.ErrInvalidInputCount(2, softmax11BaseOpFixture()), }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), + ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputType(0, "int", &Softmax{}), + ops.ErrInvalidInputCount(2, softmax13BaseOpFixture()), + }, + { + 13, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int", 
softmax13BaseOpFixture()), }, } for _, test := range tests { - softmax := &Softmax{} + softmax := softmaxVersions[test.version]() validated, err := softmax.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -138,3 +173,15 @@ func TestInputValidationSoftmax(t *testing.T) { } } } + +func softmax1BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(1, 1, 1, softmaxTypeConstraints, "softmax") +} + +func softmax11BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(11, 1, 1, softmaxTypeConstraints, "softmax") +} + +func softmax13BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(13, 1, 1, softmaxTypeConstraints, "softmax") +} diff --git a/ops/softmax/versions.go b/ops/softmax/versions.go new file mode 100644 index 0000000..2f95b4b --- /dev/null +++ b/ops/softmax/versions.go @@ -0,0 +1,13 @@ +package softmax + +import "github.com/advancedclimatesystems/gonnx/ops" + +var softmaxVersions = ops.OperatorVersions{ + 1: ops.NewOperatorConstructor(newSoftmax, 1, softmaxTypeConstraints), + 11: ops.NewOperatorConstructor(newSoftmax, 11, softmaxTypeConstraints), + 13: ops.NewOperatorConstructor(newSoftmax, 13, softmaxTypeConstraints), +} + +func GetSoftmaxVersions() ops.OperatorVersions { + return softmaxVersions +} diff --git a/ops/opset13/squeeze.go b/ops/squeeze/squeeze.go similarity index 73% rename from ops/opset13/squeeze.go rename to ops/squeeze/squeeze.go index d4c9055..7458ab2 100644 --- a/ops/opset13/squeeze.go +++ b/ops/squeeze/squeeze.go @@ -1,4 +1,4 @@ -package opset13 +package squeeze import ( "github.com/advancedclimatesystems/gonnx/onnx" @@ -6,17 +6,32 @@ import ( "gorgonia.org/tensor" ) +var squeezeTypeConstraints = [][]tensor.Dtype{ + ops.AllTypes, + {tensor.Int64}, +} + const ( MinSqueezeInputs = 1 MaxSqueezeInputs = 2 ) // Squeeze represents the ONNX squeeze operator. -type Squeeze struct{} +type Squeeze struct { + ops.BaseOperator +} // newSqueeze creates a new squeeze operator. 
-func newSqueeze() ops.Operator { - return &Squeeze{} +func newSqueeze(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Squeeze{ + BaseOperator: ops.NewBaseOperator( + version, + MinSqueezeInputs, + MaxSqueezeInputs, + typeConstraints, + "squeeze", + ), + } } // Init initializes the squeeze operator. @@ -59,32 +74,6 @@ func (s *Squeeze) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return []tensor.Tensor{out}, err } -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (s *Squeeze) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(s, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (s *Squeeze) GetMinInputs() int { - return MinSqueezeInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (s *Squeeze) GetMaxInputs() int { - return MaxSqueezeInputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (s *Squeeze) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ops.AllTypes, {tensor.Int64}} -} - -// String implements the stringer interface, and can be used to format errors or messages. -func (s *Squeeze) String() string { - return "squeeze operator" -} - // getDimsToSqueezeFromTensor creates a list with ints representing the dimensions/axes to squeeze // based on a tensor. The tensor should contain dimensions/axes to squeeze. Negative dimensions // represent dimensions counting from the end of the shape, i.e. 
-2 repesents the second diff --git a/ops/squeeze/squeeze_1.go b/ops/squeeze/squeeze_1.go new file mode 100644 index 0000000..661a549 --- /dev/null +++ b/ops/squeeze/squeeze_1.go @@ -0,0 +1,74 @@ +package squeeze + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +// Squeeze1 represents the ONNX squeeze operator. +type Squeeze1 struct { + ops.BaseOperator + + axes []int +} + +// newSqueeze1 creates a new squeeze operator. +func newSqueeze1() ops.Operator { + return &Squeeze1{ + BaseOperator: ops.NewBaseOperator( + 1, + 1, + 1, + [][]tensor.Dtype{ops.AllTypes}, + "squeeze", + ), + } +} + +// Init initializes the squeeze operator. +func (s *Squeeze1) Init(n *onnx.NodeProto) error { + for _, attr := range n.GetAttribute() { + switch attr.GetName() { + case "axes": + axes, err := ops.AnyToIntSlice(attr.GetInts()) + if err != nil { + return err + } + + s.axes = axes + default: + return ops.ErrInvalidAttribute(attr.GetName(), s) + } + } + + return nil +} + +// Apply applies the squeeze operator. +func (s *Squeeze1) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var err error + + currentShape := inputs[0].Shape() + nDims := len(currentShape) + dimsToSqueeze := getDimsToSqueezeFromShape(currentShape) + + if !ops.AllInRange(dimsToSqueeze, 0, nDims-1) { + return nil, ops.ErrNotAllAxesInRange(nDims, nDims) + } + + if len(s.axes) > 0 { + dimsToSqueeze = getDimsToSqueezeFromList(s.axes, nDims) + } + + newShape := getNewShape(currentShape, dimsToSqueeze) + + out, ok := inputs[0].Clone().(tensor.Tensor) + if !ok { + return nil, ops.ErrTypeAssert("tensor.Tensor", inputs[0].Clone()) + } + + err = out.Reshape(newShape...) 
+ + return []tensor.Tensor{out}, err +} diff --git a/ops/squeeze/squeeze_11.go b/ops/squeeze/squeeze_11.go new file mode 100644 index 0000000..259e86c --- /dev/null +++ b/ops/squeeze/squeeze_11.go @@ -0,0 +1,95 @@ +package squeeze + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +// Squeeze11 represents the ONNX squeeze operator. +type Squeeze11 struct { + ops.BaseOperator + + axes []int +} + +// newSqueeze11 creates a new squeeze operator. +func newSqueeze11() ops.Operator { + return &Squeeze11{ + BaseOperator: ops.NewBaseOperator( + 11, + 1, + 1, + [][]tensor.Dtype{ops.AllTypes}, + "squeeze", + ), + } +} + +// Init initializes the squeeze operator. +func (s *Squeeze11) Init(n *onnx.NodeProto) error { + for _, attr := range n.GetAttribute() { + switch attr.GetName() { + case "axes": + axes, err := ops.AnyToIntSlice(attr.GetInts()) + if err != nil { + return err + } + + s.axes = axes + default: + return ops.ErrInvalidAttribute(attr.GetName(), s) + } + } + + return nil +} + +// Apply applies the squeeze operator. +func (s *Squeeze11) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var err error + + currentShape := inputs[0].Shape() + nDims := len(currentShape) + dimsToSqueeze := getDimsToSqueezeFromShape(currentShape) + + if !ops.AllInRange(dimsToSqueeze, -nDims, nDims-1) { + return nil, ops.ErrNotAllAxesInRange(nDims, nDims) + } + + // negative entries should be offset by the rank of the output tensor + // i.e. -1 -> nDims - 1, -nDims -> 0 + ops.OffsetArrayIfNegative(dimsToSqueeze, nDims) + + if len(s.axes) > 0 { + dimsToSqueeze = getDimsToSqueezeFromList(s.axes, nDims) + } + + newShape := getNewShape(currentShape, dimsToSqueeze) + + out, ok := inputs[0].Clone().(tensor.Tensor) + if !ok { + return nil, ops.ErrTypeAssert("tensor.Tensor", inputs[0].Clone()) + } + + err = out.Reshape(newShape...) 
+ + return []tensor.Tensor{out}, err +} + +// getDimsToSqueezeFromList creates a list with ints representing the dimensions/axes to squeeze +// based on a list of ints. The list should contain dimensions/axes to squeeze. Negative dimensions +// represent dimensions counting from the end of the shape, i.e. -2 represents the second +// last dimension. +func getDimsToSqueezeFromList(axes []int, nDims int) []int { + dimsToSqueeze := make([]int, len(axes)) + copy(dimsToSqueeze, axes) + + for i, val := range dimsToSqueeze { + if val < 0 { + dimsToSqueeze[i] = nDims + val + } + } + + return dimsToSqueeze +} diff --git a/ops/opset13/squeeze_test.go b/ops/squeeze/squeeze_test.go similarity index 81% rename from ops/opset13/squeeze_test.go rename to ops/squeeze/squeeze_test.go index bb160b5..99da0d3 100644 --- a/ops/opset13/squeeze_test.go +++ b/ops/squeeze/squeeze_test.go @@ -1,4 +1,4 @@ -package opset13 +package squeeze import ( "testing" @@ -19,26 +19,31 @@ func TestSqueezeInit(t *testing.T) { func TestSqueezeCustomDims(t *testing.T) { tests := []struct { + version int64 shape []int dimsToDrop []int64 expectedShape tensor.Shape }{ { + 13, []int{3, 1, 2}, []int64{1}, []int{3, 2}, }, { + 13, []int{3, 1, 2}, []int64{-2}, []int{3, 2}, }, { + 13, []int{1, 4, 3, 1}, []int64{0, -1}, []int{4, 3}, }, { + 13, []int{1, 4, 3, 1}, []int64{0}, []int{4, 3, 1}, @@ -46,7 +51,7 @@ func TestSqueezeCustomDims(t *testing.T) { } for _, test := range tests { - squeeze := &Squeeze{} + squeeze := squeezeVersions[test.version]() inputs := []tensor.Tensor{ ops.Float32TensorFixture(test.shape...), ops.TensorWithBackingFixture(test.dimsToDrop, len(test.dimsToDrop)), @@ -60,21 +65,24 @@ func TestSqueezeCustomDims(t *testing.T) { func TestSqueezeNoDims(t *testing.T) { tests := []struct { + version int64 shape []int expectedShape tensor.Shape }{ { + 13, []int{3, 1, 2}, []int{3, 2}, }, { + 13, []int{1, 4, 3, 1}, []int{4, 3}, }, } for _, test := range tests { - squeeze := &Squeeze{} + squeeze := 
squeezeVersions[test.version]() inputs := []tensor.Tensor{ops.Float32TensorFixture(test.shape...), nil} res, err := squeeze.Apply(inputs) @@ -128,11 +136,29 @@ func TestKeepDim(t *testing.T) { func TestInputValidationSqueeze(t *testing.T) { tests := []struct { + version int64 inputs []tensor.Tensor expected []tensor.Tensor err error }{ { + 1, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + nil, + }, + { + 11, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + nil, + }, + { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]int64{3, 4}, 2), @@ -141,36 +167,40 @@ func TestInputValidationSqueeze(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ops.TensorWithBackingFixture([]float32{1, 2}, 2)}, []tensor.Tensor{ops.TensorWithBackingFixture([]float32{1, 2}, 2), nil}, nil, }, { + 13, []tensor.Tensor{}, nil, - ops.ErrInvalidOptionalInputCount(0, &Squeeze{}), + ops.ErrInvalidOptionalInputCount(0, squeeze13BaseOpFixture()), }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), }, nil, - ops.ErrInvalidOptionalInputCount(3, &Squeeze{}), + ops.ErrInvalidOptionalInputCount(3, squeeze13BaseOpFixture()), }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), }, nil, - ops.ErrInvalidInputType(1, "int", &Squeeze{}), + ops.ErrInvalidInputType(1, "int", squeeze13BaseOpFixture()), }, } for _, test := range tests { - squeeze := &Squeeze{} + squeeze := squeezeVersions[test.version]() validated, err := squeeze.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -184,3 +214,7 @@ func TestInputValidationSqueeze(t *testing.T) { } } } + +func squeeze13BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(13, 1, 2, squeezeTypeConstraints, "squeeze") +} diff --git 
a/ops/squeeze/versions.go b/ops/squeeze/versions.go new file mode 100644 index 0000000..04a853a --- /dev/null +++ b/ops/squeeze/versions.go @@ -0,0 +1,13 @@ +package squeeze + +import "github.com/advancedclimatesystems/gonnx/ops" + +var squeezeVersions = ops.OperatorVersions{ + 1: newSqueeze1, + 11: newSqueeze11, + 13: ops.NewOperatorConstructor(newSqueeze, 13, squeezeTypeConstraints), +} + +func GetSqueezeVersions() ops.OperatorVersions { + return squeezeVersions +} diff --git a/ops/sub/sub.go b/ops/sub/sub.go new file mode 100644 index 0000000..42d955c --- /dev/null +++ b/ops/sub/sub.go @@ -0,0 +1,45 @@ +package sub + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var subTypeConstraints = [][]tensor.Dtype{ + {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, + {tensor.Uint32, tensor.Uint64, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, +} + +// Sub represents the ONNX sub operator. +type Sub struct { + ops.BaseOperator +} + +// newSub creates a new sub operator. +func newSub(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Sub{ + BaseOperator: ops.NewBaseOperator( + version, + 2, + 2, + typeConstraints, + "sub", + ), + } +} + +// Init initializes the sub operator. +func (s *Sub) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the sub operator. 
+func (s *Sub) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ApplyBinaryOperation( + inputs[0], + inputs[1], + ops.Sub, + ops.MultidirectionalBroadcasting, + ) +} diff --git a/ops/opset13/sub_test.go b/ops/sub/sub_test.go similarity index 74% rename from ops/opset13/sub_test.go rename to ops/sub/sub_test.go index 6812be0..7cdf803 100644 --- a/ops/opset13/sub_test.go +++ b/ops/sub/sub_test.go @@ -1,4 +1,4 @@ -package opset13 +package sub import ( "testing" @@ -54,10 +54,12 @@ func TestSub(t *testing.T) { func TestInputValidationSub(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]uint32{1, 2}, 2), ops.TensorWithBackingFixture([]uint32{3, 4}, 2), @@ -65,6 +67,7 @@ func TestInputValidationSub(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]uint64{1, 2}, 2), ops.TensorWithBackingFixture([]uint64{3, 4}, 2), @@ -72,6 +75,7 @@ func TestInputValidationSub(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int32{1, 2}, 2), ops.TensorWithBackingFixture([]int32{3, 4}, 2), @@ -79,6 +83,7 @@ func TestInputValidationSub(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int64{1, 2}, 2), ops.TensorWithBackingFixture([]int64{3, 4}, 2), @@ -86,6 +91,7 @@ func TestInputValidationSub(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]float32{3, 4}, 2), @@ -93,6 +99,7 @@ func TestInputValidationSub(t *testing.T) { nil, }, { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), ops.TensorWithBackingFixture([]float64{3, 4}, 2), @@ -100,22 +107,39 @@ func TestInputValidationSub(t *testing.T) { nil, }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputCount(1, &Sub{}), + ops.ErrInvalidInputCount(1, 
sub7BaseOpFixture()), }, { + 13, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputCount(1, sub13BaseOpFixture()), + }, + { + 7, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + ops.TensorWithBackingFixture([]int{3, 4}, 2), + }, + ops.ErrInvalidInputType(0, "int", sub7BaseOpFixture()), + }, + { + 13, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int{3, 4}, 2), }, - ops.ErrInvalidInputType(0, "int", &Sub{}), + ops.ErrInvalidInputType(0, "int", sub13BaseOpFixture()), }, } for _, test := range tests { - sub := &Sub{} + sub := subVersions[test.version]() validated, err := sub.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -125,3 +149,11 @@ func TestInputValidationSub(t *testing.T) { } } } + +func sub7BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(7, 2, 2, subTypeConstraints, "sub") +} + +func sub13BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(13, 2, 2, subTypeConstraints, "sub") +} diff --git a/ops/sub/versions.go b/ops/sub/versions.go new file mode 100644 index 0000000..9847251 --- /dev/null +++ b/ops/sub/versions.go @@ -0,0 +1,12 @@ +package sub + +import "github.com/advancedclimatesystems/gonnx/ops" + +var subVersions = ops.OperatorVersions{ + 7: ops.NewOperatorConstructor(newSub, 7, subTypeConstraints), + 13: ops.NewOperatorConstructor(newSub, 13, subTypeConstraints), +} + +func GetSubVersions() ops.OperatorVersions { + return subVersions +} diff --git a/ops/tan/tan.go b/ops/tan/tan.go new file mode 100644 index 0000000..4629613 --- /dev/null +++ b/ops/tan/tan.go @@ -0,0 +1,61 @@ +package tan + +import ( + "math" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var tanTypeConstraints = [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} + +// Tan represents the ONNX tan operator. 
+type Tan struct { + ops.BaseOperator +} + +// newTan creates a new tan operator. +func newTan(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Tan{ + BaseOperator: ops.NewBaseOperator( + version, + 1, + 1, + typeConstraints, + "tan", + ), + } +} + +// Init initializes the tan operator. +func (t *Tan) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the tan operator. +func (t *Tan) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + var ( + out tensor.Tensor + err error + ) + + switch inputs[0].Dtype() { + case tensor.Float32: + out, err = inputs[0].Apply(tan[float32]) + case tensor.Float64: + out, err = inputs[0].Apply(tan[float64]) + default: + return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), t.BaseOperator) + } + + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +func tan[T ops.FloatType](x T) T { + return T(math.Tan(float64(x))) +} diff --git a/ops/opset13/tan_test.go b/ops/tan/tan_test.go similarity index 82% rename from ops/opset13/tan_test.go rename to ops/tan/tan_test.go index 2fbaf88..020f93d 100644 --- a/ops/opset13/tan_test.go +++ b/ops/tan/tan_test.go @@ -1,4 +1,4 @@ -package opset13 +package tan import ( "testing" @@ -59,35 +59,40 @@ func TestTan(t *testing.T) { func TestInputValidationTan(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), }, nil, }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]float64{1, 2}, 2), }, nil, }, { + 7, []tensor.Tensor{}, - ops.ErrInvalidInputCount(0, &Tan{}), + ops.ErrInvalidInputCount(0, tan7BaseOpFixture()), }, { + 7, []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputType(0, "int", &Tan{}), + ops.ErrInvalidInputType(0, "int", tan7BaseOpFixture()), }, } for _, test := range tests { - tan := &Tan{} + tan := 
tanVersions[test.version]() validated, err := tan.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -97,3 +102,7 @@ func TestInputValidationTan(t *testing.T) { } } } + +func tan7BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(7, 1, 1, tanTypeConstraints, "tan") +} diff --git a/ops/tan/versions.go b/ops/tan/versions.go new file mode 100644 index 0000000..da379e4 --- /dev/null +++ b/ops/tan/versions.go @@ -0,0 +1,11 @@ +package tan + +import "github.com/advancedclimatesystems/gonnx/ops" + +var tanVersions = ops.OperatorVersions{ + 7: ops.NewOperatorConstructor(newTan, 7, tanTypeConstraints), +} + +func GetTanVersions() ops.OperatorVersions { + return tanVersions +} diff --git a/ops/tanh/tanh.go b/ops/tanh/tanh.go new file mode 100644 index 0000000..3aca593 --- /dev/null +++ b/ops/tanh/tanh.go @@ -0,0 +1,41 @@ +package tanh + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var tanhTypeConstraint = [][]tensor.Dtype{ + {tensor.Float32, tensor.Float64}, +} + +// Tanh represents the tanh operator. +type Tanh struct { + ops.BaseOperator +} + +// newTanh returns a new tanh operator. +func newTanh(version int, typeConstraint [][]tensor.Dtype) ops.Operator { + return &Tanh{ + BaseOperator: ops.NewBaseOperator( + version, + 1, + 1, + typeConstraint, + "tanh", + ), + } +} + +// Init initializes the tanh operator. +func (t *Tanh) Init(*onnx.NodeProto) error { + return nil +} + +// Apply the tanh operator to the input node.
+func (t *Tanh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + out, err := ops.Tanh(inputs[0]) + + return []tensor.Tensor{out}, err +} diff --git a/ops/opset13/tanh_test.go b/ops/tanh/tanh_test.go similarity index 67% rename from ops/opset13/tanh_test.go rename to ops/tanh/tanh_test.go index 44b5409..9d371bb 100644 --- a/ops/opset13/tanh_test.go +++ b/ops/tanh/tanh_test.go @@ -1,4 +1,4 @@ -package opset13 +package tanh import ( "testing" @@ -9,7 +9,7 @@ import ( ) func TestTanhInit(t *testing.T) { - tanh := newTanh() + tanh := &Tanh{} // Since the tanh does not have any attributes we expect it to initialize even // when nil is passed. err := tanh.Init(nil) @@ -57,29 +57,49 @@ func TestTanh(t *testing.T) { func TestInputValidationTanh(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 6, []tensor.Tensor{ops.TensorWithBackingFixture([]float32{1, 2}, 2)}, nil, }, { + 13, + []tensor.Tensor{ops.TensorWithBackingFixture([]float32{1, 2}, 2)}, + nil, + }, + { + 13, []tensor.Tensor{ops.TensorWithBackingFixture([]float64{1, 2}, 2)}, nil, }, { + 6, []tensor.Tensor{}, - ops.ErrInvalidInputCount(0, &Tanh{}), + ops.ErrInvalidInputCount(0, tanh6BaseOpFixture()), }, { + 13, + []tensor.Tensor{}, + ops.ErrInvalidInputCount(0, tanh13BaseOpFixture()), + }, + { + 6, []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, - ops.ErrInvalidInputType(0, "int", &Tanh{}), + ops.ErrInvalidInputType(0, "int", tanh6BaseOpFixture()), + }, + { + 13, + []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, + ops.ErrInvalidInputType(0, "int", tanh13BaseOpFixture()), }, } for _, test := range tests { - tanh := &Tanh{} + tanh := tanhVersions[test.version]() validated, err := tanh.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -89,3 +109,11 @@ func TestInputValidationTanh(t *testing.T) { } } } + +func tanh6BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(6, 1, 1, 
tanhTypeConstraint, "tanh") +} + +func tanh13BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(13, 1, 1, tanhTypeConstraint, "tanh") +} diff --git a/ops/tanh/versions.go b/ops/tanh/versions.go new file mode 100644 index 0000000..80e39ea --- /dev/null +++ b/ops/tanh/versions.go @@ -0,0 +1,12 @@ +package tanh + +import "github.com/advancedclimatesystems/gonnx/ops" + +var tanhVersions = ops.OperatorVersions{ + 6: ops.NewOperatorConstructor(newTanh, 6, tanhTypeConstraint), + 13: ops.NewOperatorConstructor(newTanh, 13, tanhTypeConstraint), +} + +func GetTanhVersions() ops.OperatorVersions { + return tanhVersions +} diff --git a/ops/transpose/transpose.go b/ops/transpose/transpose.go new file mode 100644 index 0000000..0f78e1f --- /dev/null +++ b/ops/transpose/transpose.go @@ -0,0 +1,63 @@ +package transpose + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var transposeTypeConstraint = [][]tensor.Dtype{ops.AllTypes} + +// Transpose represents the ONNX transpose operator. +type Transpose struct { + ops.BaseOperator + + perm []int +} + +// newTranspose creates a new transpose operator. +func newTranspose(version int, typeConstraint [][]tensor.Dtype) ops.Operator { + return &Transpose{ + BaseOperator: ops.NewBaseOperator( + version, + 1, + 1, + typeConstraint, + "transpose", + ), + } +} + +// Init initializes the transpose operator. +func (t *Transpose) Init(n *onnx.NodeProto) error { + attributes := n.GetAttribute() + + if len(attributes) == 1 { + attr := attributes[0] + + if attr.GetName() != "perm" { + return ops.ErrInvalidAttribute(attr.GetName(), t) + } + + attrPerm := attr.GetInts() + + perm := make([]int, 0) + for _, val := range attrPerm { + perm = append(perm, int(val)) + } + + t.perm = perm + } + + return nil +} + +// Apply applies the transpose operator. 
+func (t *Transpose) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + out, err := tensor.Transpose(inputs[0], t.perm...) + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} diff --git a/ops/opset13/transpose_test.go b/ops/transpose/transpose_test.go similarity index 70% rename from ops/opset13/transpose_test.go rename to ops/transpose/transpose_test.go index afe1c9e..40a4d29 100644 --- a/ops/opset13/transpose_test.go +++ b/ops/transpose/transpose_test.go @@ -1,4 +1,4 @@ -package opset13 +package transpose import ( "testing" @@ -25,14 +25,6 @@ func TestTransposeInitFailWrongAttribute(t *testing.T) { assert.Equal(t, expected, err) } -func TestTransposeInitFailAttrCount(t *testing.T) { - trans := &Transpose{} - err := trans.Init(ops.EmptyNodeProto()) - - expected := ops.ErrInvalidAttributeCount(1, 0, trans) - assert.Equal(t, expected, err) -} - func TestTranspose(t *testing.T) { tests := []struct { trans *Transpose @@ -67,33 +59,54 @@ func TestTranspose(t *testing.T) { func TestInputValidationTranspose(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + version int64 + inputs []tensor.Tensor + err error }{ { + 1, []tensor.Tensor{ops.TensorWithBackingFixture([]uint32{1, 2}, 2)}, nil, }, { + 13, + []tensor.Tensor{ops.TensorWithBackingFixture([]uint32{1, 2}, 2)}, + nil, + }, + { + 13, []tensor.Tensor{ops.TensorWithBackingFixture([]float32{1, 2}, 2)}, nil, }, { + 13, []tensor.Tensor{ops.TensorWithBackingFixture([]float64{1, 2}, 2)}, nil, }, { + 1, []tensor.Tensor{}, - ops.ErrInvalidInputCount(0, &Transpose{}), + ops.ErrInvalidInputCount(0, transpose1BaseOpFixture()), }, { + 13, + []tensor.Tensor{}, + ops.ErrInvalidInputCount(0, transpose13BaseOpFixture()), + }, + { + 1, + []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, + ops.ErrInvalidInputType(0, "int", transpose1BaseOpFixture()), + }, + { + 13, []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, - ops.ErrInvalidInputType(0, "int", 
&Transpose{}), + ops.ErrInvalidInputType(0, "int", transpose13BaseOpFixture()), }, } for _, test := range tests { - transpose := &Transpose{} + transpose := transposeVersions[test.version]() validated, err := transpose.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -111,3 +124,11 @@ func TransposeOnnxNodeProtoFixture() *onnx.NodeProto { }, } } + +func transpose1BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(1, 1, 1, transposeTypeConstraint, "transpose") +} + +func transpose13BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(13, 1, 1, transposeTypeConstraint, "transpose") +} diff --git a/ops/transpose/versions.go b/ops/transpose/versions.go new file mode 100644 index 0000000..fcd1cbf --- /dev/null +++ b/ops/transpose/versions.go @@ -0,0 +1,12 @@ +package transpose + +import "github.com/advancedclimatesystems/gonnx/ops" + +var transposeVersions = ops.OperatorVersions{ + 1: ops.NewOperatorConstructor(newTranspose, 1, transposeTypeConstraint), + 13: ops.NewOperatorConstructor(newTranspose, 13, transposeTypeConstraint), +} + +func GetTransposeVersions() ops.OperatorVersions { + return transposeVersions +} diff --git a/ops/opset13/unsqueeze.go b/ops/unsqueeze/unsqueeze.go similarity index 66% rename from ops/opset13/unsqueeze.go rename to ops/unsqueeze/unsqueeze.go index b7d4530..f4a8908 100644 --- a/ops/opset13/unsqueeze.go +++ b/ops/unsqueeze/unsqueeze.go @@ -1,4 +1,4 @@ -package opset13 +package unsqueeze import ( "sort" @@ -8,17 +8,32 @@ import ( "gorgonia.org/tensor" ) +var unsqueezeTypeConstraints = [][]tensor.Dtype{ + ops.AllTypes, + {tensor.Int64}, +} + const ( MinUnsqueezeInputs = 2 MaxUnsqueezeInputs = 2 ) // Unsqueeze represents the ONNX unsqueeze operator. -type Unsqueeze struct{} +type Unsqueeze struct { + ops.BaseOperator +} // newUnsqueeze creates a new unsqueeze operator. 
-func newUnsqueeze() ops.Operator { - return &Unsqueeze{} +func newUnsqueeze(version int, typeConstraint [][]tensor.Dtype) ops.Operator { + return &Unsqueeze{ + BaseOperator: ops.NewBaseOperator( + version, + MinUnsqueezeInputs, + MaxUnsqueezeInputs, + typeConstraint, + "unsqueeze", + ), + } } // Init initializes the unsqueeze operator. @@ -48,7 +63,7 @@ func (u *Unsqueeze) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { sort.Ints(axes) if ops.HasDuplicates(axes) { - return nil, ops.ErrInvalidInput("axes cannot have duplicate entries after offset", u) + return nil, ops.ErrInvalidInput("axes cannot have duplicate entries after offset", u.BaseOperator) } newShape := insertOnes(dataShape, axes) @@ -63,32 +78,6 @@ func (u *Unsqueeze) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return []tensor.Tensor{out}, err } -// ValidateInputs validates the inputs that will be given to Apply for this operator. -func (u *Unsqueeze) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return ops.ValidateInputs(u, inputs) -} - -// GetMinInputs returns the minimum number of input tensors this operator expects. -func (u *Unsqueeze) GetMinInputs() int { - return MinUnsqueezeInputs -} - -// GetMaxInputs returns the maximum number of input tensors this operator expects. -func (u *Unsqueeze) GetMaxInputs() int { - return MaxUnsqueezeInputs -} - -// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes -// for the corresponding input tensor. -func (u *Unsqueeze) GetInputTypeConstraints() [][]tensor.Dtype { - return [][]tensor.Dtype{ops.AllTypes, {tensor.Int64}} -} - -// String implements the stringer interface, and can be used to format errors or messages. 
-func (u *Unsqueeze) String() string { - return "unsqueeze operator" -} - // Creates a new array, which is `original` with ones added at the indices specified by `indices` // `indices` may not contain duplicates, the elements are assumed to be in the range 0 <= x < N // and should be sorted in increasing order. diff --git a/ops/unsqueeze/unsqueeze_1.go b/ops/unsqueeze/unsqueeze_1.go new file mode 100644 index 0000000..ff78879 --- /dev/null +++ b/ops/unsqueeze/unsqueeze_1.go @@ -0,0 +1,74 @@ +package unsqueeze + +import ( + "sort" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +// Unsqueeze1 represents version 1 of the ONNX unsqueeze operator. +type Unsqueeze1 struct { + ops.BaseOperator + + axes []int +} + +// newUnsqueeze1 creates a new unsqueeze operator. +func newUnsqueeze1() ops.Operator { + return &Unsqueeze1{ + BaseOperator: ops.NewBaseOperator( + 1, + 1, + 1, + [][]tensor.Dtype{ops.AllTypes}, + "unsqueeze", + ), + } +} + +// Init initializes the unsqueeze operator. +func (u *Unsqueeze1) Init(n *onnx.NodeProto) error { + attrs := n.GetAttribute() + if len(attrs) != 1 { + return ops.ErrInvalidAttributeCount(1, len(attrs), u) + } + + axes, err := ops.AnyToIntSlice(attrs[0].GetInts()) + if err != nil { + return err + } + + u.axes = axes + + return nil +} + +// Apply applies the unsqueeze operator. 
+func (u *Unsqueeze1) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + dataShape := inputs[0].Shape() + + outputRank := len(dataShape) + len(u.axes) + + if !ops.AllInRange(u.axes, 0, outputRank-1) { + return nil, ops.ErrNotAllAxesInRange(outputRank, outputRank) + } + + sort.Ints(u.axes) + + if ops.HasDuplicates(u.axes) { + return nil, ops.ErrInvalidInput("axes cannot have duplicate entries after offset", u.BaseOperator) + } + + newShape := insertOnes(dataShape, u.axes) + + out, ok := inputs[0].Clone().(tensor.Tensor) + if !ok { + return nil, ops.ErrTypeAssert("tensor.Tensor", inputs[0].Clone()) + } + + err := out.Reshape(newShape...) + + return []tensor.Tensor{out}, err +} diff --git a/ops/unsqueeze/unsqueeze_11.go b/ops/unsqueeze/unsqueeze_11.go new file mode 100644 index 0000000..383276d --- /dev/null +++ b/ops/unsqueeze/unsqueeze_11.go @@ -0,0 +1,78 @@ +package unsqueeze + +import ( + "sort" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +// Unsqueeze11 represents version 11 of the ONNX unsqueeze operator. +type Unsqueeze11 struct { + ops.BaseOperator + + axes []int +} + +// newUnsqueeze11 creates a new unsqueeze operator. +func newUnsqueeze11() ops.Operator { + return &Unsqueeze11{ + BaseOperator: ops.NewBaseOperator( + 11, + 1, + 1, + [][]tensor.Dtype{ops.AllTypes}, + "unsqueeze", + ), + } +} + +// Init initializes the unsqueeze operator. +func (u *Unsqueeze11) Init(n *onnx.NodeProto) error { + attrs := n.GetAttribute() + if len(attrs) != 1 { + return ops.ErrInvalidAttributeCount(1, len(attrs), u) + } + + axes, err := ops.AnyToIntSlice(attrs[0].GetInts()) + if err != nil { + return err + } + + u.axes = axes + + return nil +} + +// Apply applies the unsqueeze operator. 
+func (u *Unsqueeze11) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + dataShape := inputs[0].Shape() + + outputRank := len(dataShape) + len(u.axes) + + if !ops.AllInRange(u.axes, -outputRank, outputRank-1) { + return nil, ops.ErrNotAllAxesInRange(outputRank, outputRank) + } + + // negative entries should be offset by the rank of the output tensor + // i.e. -1 -> outputRank - 1, -outputrank -> 0 + ops.OffsetArrayIfNegative(u.axes, outputRank) + + sort.Ints(u.axes) + + if ops.HasDuplicates(u.axes) { + return nil, ops.ErrInvalidInput("axes cannot have duplicate entries after offset", u.BaseOperator) + } + + newShape := insertOnes(dataShape, u.axes) + + out, ok := inputs[0].Clone().(tensor.Tensor) + if !ok { + return nil, ops.ErrTypeAssert("tensor.Tensor", inputs[0].Clone()) + } + + err := out.Reshape(newShape...) + + return []tensor.Tensor{out}, err +} diff --git a/ops/opset13/unsqueeze_test.go b/ops/unsqueeze/unsqueeze_test.go similarity index 61% rename from ops/opset13/unsqueeze_test.go rename to ops/unsqueeze/unsqueeze_test.go index 445d0c5..9d27f3d 100644 --- a/ops/opset13/unsqueeze_test.go +++ b/ops/unsqueeze/unsqueeze_test.go @@ -1,24 +1,54 @@ -package opset13 +package unsqueeze import ( "testing" + "github.com/advancedclimatesystems/gonnx/onnx" "github.com/advancedclimatesystems/gonnx/ops" "github.com/stretchr/testify/assert" "gorgonia.org/tensor" ) func TestUnsqueezeInit(t *testing.T) { - s := &Unsqueeze{} + tests := []struct { + version int64 + attrs *onnx.NodeProto + err error + }{ + { + 1, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{1, 0}}, + }, + }, + nil, + }, + { + 11, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "axes", Ints: []int64{1, 0}}, + }, + }, + nil, + }, + { + 13, + nil, + nil, + }, + } - // since the unsqueeze does not have any attributes we pass in nil. This should not - // fail initializing the unsqueeze. 
- err := s.Init(nil) - assert.NoError(t, err) + for _, test := range tests { + op := unsqueezeVersions[test.version]() + err := op.Init(test.attrs) + assert.Equal(t, test.err, err) + } } func TestAxesOutRangeError(t *testing.T) { - op := &Unsqueeze{} + op := unsqueezeVersions[13]() err := op.Init(nil) assert.Nil(t, err) @@ -33,7 +63,7 @@ func TestAxesOutRangeError(t *testing.T) { } func TestDuplicateEntriesAfterOffsetNotAllowed(t *testing.T) { - op := &Unsqueeze{} + op := unsqueezeVersions[13]() err := op.Init(nil) assert.Nil(t, err) @@ -44,11 +74,11 @@ func TestDuplicateEntriesAfterOffsetNotAllowed(t *testing.T) { dataIn := ops.TensorWithBackingFixture(data, 3, 3) axesIn := ops.TensorWithBackingFixture(axes, len(axes)) _, err = op.Apply([]tensor.Tensor{dataIn, axesIn}) - assert.EqualError(t, err, "invalid input tensor for unsqueeze operator: axes cannot have duplicate entries after offset") + assert.EqualError(t, err, "invalid input tensor for unsqueeze v13: axes cannot have duplicate entries after offset") } func TestDuplicateEntriesNotAllowed(t *testing.T) { - op := &Unsqueeze{} + op := unsqueezeVersions[13]() err := op.Init(nil) assert.Nil(t, err) @@ -58,7 +88,7 @@ func TestDuplicateEntriesNotAllowed(t *testing.T) { dataIn := ops.TensorWithBackingFixture(data, 3, 3) axesIn := ops.TensorWithBackingFixture(axes, len(axes)) _, err = op.Apply([]tensor.Tensor{dataIn, axesIn}) - assert.EqualError(t, err, "invalid input tensor for unsqueeze operator: axes cannot have duplicate entries after offset") + assert.EqualError(t, err, "invalid input tensor for unsqueeze v13: axes cannot have duplicate entries after offset") } func TestUnsqueeze(t *testing.T) { @@ -129,14 +159,30 @@ func TestUnsqueeze(t *testing.T) { func TestInputValidationUnsqueeze(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + inputs []tensor.Tensor + version int64 + err error }{ + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + 1, + nil, + }, + { 
+ []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + 11, + nil, + }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int32{1, 2}, 2), ops.TensorWithBackingFixture([]int64{3, 4}, 2), }, + 13, nil, }, { @@ -144,30 +190,50 @@ func TestInputValidationUnsqueeze(t *testing.T) { ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]int64{3, 4}, 2), }, + 13, nil, }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int64{3, 4}, 2), + }, + 1, + ops.ErrInvalidInputCount(2, flatten1BaseOpFixture()), + }, + { + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]int64{3, 4}, 2), + }, + 11, + ops.ErrInvalidInputCount(2, flatten11BaseOpFixture()), + }, { []tensor.Tensor{ops.TensorWithBackingFixture([]int{1, 2}, 2)}, - ops.ErrInvalidInputCount(1, &Unsqueeze{}), + 13, + ops.ErrInvalidInputCount(1, flatten13BaseOpFixture()), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]int{1, 2}, 2), ops.TensorWithBackingFixture([]int64{3, 4}, 2), }, - ops.ErrInvalidInputType(0, "int", &Unsqueeze{}), + 13, + ops.ErrInvalidInputType(0, "int", flatten13BaseOpFixture()), }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]float32{1, 2}, 2), ops.TensorWithBackingFixture([]int32{3, 4}, 2), }, - ops.ErrInvalidInputType(1, "int32", &Unsqueeze{}), + 13, + ops.ErrInvalidInputType(1, "int32", flatten13BaseOpFixture()), }, } for _, test := range tests { - unsqueeze := &Unsqueeze{} + unsqueeze := unsqueezeVersions[test.version]() validated, err := unsqueeze.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) @@ -177,3 +243,15 @@ func TestInputValidationUnsqueeze(t *testing.T) { } } } + +func flatten1BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(1, 1, 1, [][]tensor.Dtype{ops.AllTypes}, "unsqueeze") +} + +func flatten11BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(11, 1, 1, 
[][]tensor.Dtype{ops.AllTypes}, "unsqueeze") +} + +func flatten13BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(13, 2, 2, [][]tensor.Dtype{ops.AllTypes, {tensor.Int64}}, "unsqueeze") +} diff --git a/ops/unsqueeze/versions.go b/ops/unsqueeze/versions.go new file mode 100644 index 0000000..08afd79 --- /dev/null +++ b/ops/unsqueeze/versions.go @@ -0,0 +1,13 @@ +package unsqueeze + +import "github.com/advancedclimatesystems/gonnx/ops" + +var unsqueezeVersions = ops.OperatorVersions{ + 1: newUnsqueeze1, + 11: newUnsqueeze11, + 13: ops.NewOperatorConstructor(newUnsqueeze, 13, unsqueezeTypeConstraints), +} + +func GetUnsqueezeVersions() ops.OperatorVersions { + return unsqueezeVersions +} diff --git a/ops/utils.go b/ops/utils.go index 4c88737..1dcee9e 100644 --- a/ops/utils.go +++ b/ops/utils.go @@ -19,9 +19,9 @@ func Int64ToBool(v int64) bool { } // AllInRange checks if all the entries in `arr` are in the inclusive range min <= x <= max. -func AllInRange(arr []int, min, max int) bool { +func AllInRange(arr []int, minVal, maxVal int) bool { for _, ax := range arr { - if ax < min || ax > max { + if ax < minVal || ax > maxVal { return false } } diff --git a/ops/validate_inputs.go b/ops/validate_inputs.go index fa82a48..1275ea1 100644 --- a/ops/validate_inputs.go +++ b/ops/validate_inputs.go @@ -8,7 +8,7 @@ import ( // When there are fewer input nodes then the given max, the list is padded with nils. // Expects either 1 requirement ==> the expected number of inputs, or 2 requirements, // the minimum and the maximum number of inputs. 
-func ValidateInputs(op Operator, inputs []tensor.Tensor) ([]tensor.Tensor, error) { +func ValidateInputs(op BaseOperator, inputs []tensor.Tensor) ([]tensor.Tensor, error) { padLength, err := checkNInputs(op, inputs) if err != nil { return inputs, err @@ -24,25 +24,25 @@ func ValidateInputs(op Operator, inputs []tensor.Tensor) ([]tensor.Tensor, error return inputs, nil } -func checkNInputs(op Operator, inputs []tensor.Tensor) (int, error) { +func checkNInputs(op BaseOperator, inputs []tensor.Tensor) (int, error) { nInputs := len(inputs) padLength := 0 - min := op.GetMinInputs() - max := op.GetMaxInputs() + minInputs := op.GetMinInputs() + maxInputs := op.GetMaxInputs() - if min == max { - if nInputs != min { + if minInputs == maxInputs { + if nInputs != minInputs { return 0, ErrInvalidInputCount(nInputs, op) } - padLength = min + padLength = minInputs } else { - if nInputs < min || nInputs > max { + if nInputs < minInputs || nInputs > maxInputs { return 0, ErrInvalidOptionalInputCount(nInputs, op) } - padLength = max + padLength = maxInputs } return padLength, nil @@ -57,7 +57,7 @@ func padInputs(inputs []tensor.Tensor, length int) []tensor.Tensor { return inputs } -func checkInputTypes(op Operator, inputs []tensor.Tensor) error { +func checkInputTypes(op BaseOperator, inputs []tensor.Tensor) error { typeConstraints := op.GetInputTypeConstraints() for i, input := range inputs { diff --git a/ops/validate_inputs_test.go b/ops/validate_inputs_test.go index 1edde37..f3e0162 100644 --- a/ops/validate_inputs_test.go +++ b/ops/validate_inputs_test.go @@ -10,16 +10,20 @@ import ( func TestValidateInputs(t *testing.T) { tests := []struct { - op Operator + op *MockOp inputs []tensor.Tensor expectedNil int err error }{ { &MockOp{ - minInputs: 1, - maxInputs: 1, - inputTypeConstraints: [][]tensor.Dtype{{tensor.Float32}}, + BaseOperator: BaseOperator{ + version: 1, + minInputs: 1, + maxInputs: 1, + inputTypeConstraints: [][]tensor.Dtype{{tensor.Float32}}, + name: "mockop", + }, 
}, PaddedInputsFixture(1, 0), 0, @@ -27,9 +31,13 @@ func TestValidateInputs(t *testing.T) { }, { &MockOp{ - minInputs: 2, - maxInputs: 2, - inputTypeConstraints: [][]tensor.Dtype{{tensor.Float32}, {tensor.Float32}}, + BaseOperator: BaseOperator{ + version: 1, + minInputs: 2, + maxInputs: 2, + inputTypeConstraints: [][]tensor.Dtype{{tensor.Float32}, {tensor.Float32}}, + name: "mockop", + }, }, PaddedInputsFixture(2, 0), 0, @@ -37,14 +45,18 @@ func TestValidateInputs(t *testing.T) { }, { &MockOp{ - minInputs: 3, - maxInputs: 5, - inputTypeConstraints: [][]tensor.Dtype{ - {tensor.Float32}, - {tensor.Float32}, - {tensor.Float32}, - {tensor.Float32}, - {tensor.Float32}, + BaseOperator: BaseOperator{ + version: 1, + minInputs: 3, + maxInputs: 5, + inputTypeConstraints: [][]tensor.Dtype{ + {tensor.Float32}, + {tensor.Float32}, + {tensor.Float32}, + {tensor.Float32}, + {tensor.Float32}, + }, + name: "mockop", }, }, PaddedInputsFixture(3, 0), @@ -53,14 +65,18 @@ func TestValidateInputs(t *testing.T) { }, { &MockOp{ - minInputs: 3, - maxInputs: 5, - inputTypeConstraints: [][]tensor.Dtype{ - {tensor.Float32}, - {tensor.Float32}, - {tensor.Float32}, - {tensor.Float32}, - {tensor.Float32}, + BaseOperator: BaseOperator{ + version: 1, + minInputs: 3, + maxInputs: 5, + inputTypeConstraints: [][]tensor.Dtype{ + {tensor.Float32}, + {tensor.Float32}, + {tensor.Float32}, + {tensor.Float32}, + {tensor.Float32}, + }, + name: "mockop", }, }, PaddedInputsFixture(4, 0), @@ -69,44 +85,59 @@ func TestValidateInputs(t *testing.T) { }, { &MockOp{ - minInputs: 2, - maxInputs: 2, - inputTypeConstraints: [][]tensor.Dtype{{tensor.Float32}, {tensor.Float32}}, + BaseOperator: BaseOperator{ + version: 1, + minInputs: 2, + maxInputs: 2, + inputTypeConstraints: [][]tensor.Dtype{ + {tensor.Float32}, + {tensor.Float32}, + }, + name: "mockop", + }, }, PaddedInputsFixture(1, 0), 0, - ErrInvalidInputCount(1, &MockOp{minInputs: 2, maxInputs: 2}), + ErrInvalidInputCount(1, NewBaseOperator(1, 2, 2, 
[][]tensor.Dtype{{tensor.Float32}, {tensor.Float32}}, "mockop")), }, { &MockOp{ - minInputs: 3, - maxInputs: 5, - inputTypeConstraints: [][]tensor.Dtype{ - {tensor.Float32}, - {tensor.Float32}, - {tensor.Float32}, - {tensor.Float32}, - {tensor.Float32}, + BaseOperator: BaseOperator{ + version: 1, + minInputs: 3, + maxInputs: 5, + inputTypeConstraints: [][]tensor.Dtype{ + {tensor.Float32}, + {tensor.Float32}, + {tensor.Float32}, + {tensor.Float32}, + {tensor.Float32}, + }, + name: "mockop", }, }, PaddedInputsFixture(7, 0), 0, - ErrInvalidOptionalInputCount(7, &MockOp{minInputs: 3, maxInputs: 5}), + ErrInvalidOptionalInputCount(7, NewBaseOperator(1, 3, 5, [][]tensor.Dtype{{tensor.Float32}, {tensor.Float32}, {tensor.Float32}, {tensor.Float32}, {tensor.Float32}}, "mockop")), }, { &MockOp{ - minInputs: 2, - maxInputs: 2, - inputTypeConstraints: [][]tensor.Dtype{{tensor.Float32}, {tensor.Float64}}, + BaseOperator: BaseOperator{ + version: 1, + minInputs: 2, + maxInputs: 2, + inputTypeConstraints: [][]tensor.Dtype{{tensor.Float32}, {tensor.Float64}}, + name: "mockop", + }, }, PaddedInputsFixture(2, 0), 0, - ErrInvalidInputType(1, "float32", &MockOp{}), + ErrInvalidInputType(1, "float32", NewBaseOperator(1, 2, 2, [][]tensor.Dtype{{tensor.Float32}, {tensor.Float32}}, "mockop")), }, } for _, test := range tests { - inputs, err := ValidateInputs(test.op, test.inputs) + inputs, err := ValidateInputs(test.op.BaseOperator, test.inputs) if test.err != nil { assert.EqualError(t, err, test.err.Error()) } @@ -150,9 +181,7 @@ func PaddedInputsFixture(nTensors, nNil int) []tensor.Tensor { } type MockOp struct { - minInputs int - maxInputs int - inputTypeConstraints [][]tensor.Dtype + BaseOperator } func (m *MockOp) Init(*onnx.NodeProto) error { @@ -162,23 +191,3 @@ func (m *MockOp) Init(*onnx.NodeProto) error { func (m *MockOp) Apply(_ []tensor.Tensor) ([]tensor.Tensor, error) { return nil, nil } - -func (m *MockOp) String() string { - return "mock op" -} - -func (m *MockOp) 
GetMinInputs() int { - return m.minInputs -} - -func (m *MockOp) GetMaxInputs() int { - return m.maxInputs -} - -func (m *MockOp) GetInputTypeConstraints() [][]tensor.Dtype { - return m.inputTypeConstraints -} - -func (m *MockOp) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) { - return inputs, nil -} diff --git a/ops/xor/versions.go b/ops/xor/versions.go new file mode 100644 index 0000000..872d440 --- /dev/null +++ b/ops/xor/versions.go @@ -0,0 +1,11 @@ +package xor + +import "github.com/advancedclimatesystems/gonnx/ops" + +var xorVersions = ops.OperatorVersions{ + 7: ops.NewOperatorConstructor(newXor, 7, xorTypeConstraints), +} + +func GetXorVersions() ops.OperatorVersions { + return xorVersions +} diff --git a/ops/xor/xor.go b/ops/xor/xor.go new file mode 100644 index 0000000..5b169a9 --- /dev/null +++ b/ops/xor/xor.go @@ -0,0 +1,42 @@ +package xor + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var xorTypeConstraints = [][]tensor.Dtype{{tensor.Bool}, {tensor.Bool}} + +// Xor represents the ONNX xor operator. +type Xor struct { + ops.BaseOperator +} + +// newXor creates a new xor operator. +func newXor(version int, typeConstraint [][]tensor.Dtype) ops.Operator { + return &Xor{ + BaseOperator: ops.NewBaseOperator( + version, + 2, + 2, + typeConstraint, + "xor", + ), + } +} + +// Init initializes the xor operator. +func (x *Xor) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the xor operator. 
+func (x *Xor) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + return ops.ApplyBinaryOperation( + inputs[0], + inputs[1], + ops.Xor, + ops.MultidirectionalBroadcasting, + ) +} diff --git a/ops/opset13/xor_test.go b/ops/xor/xor_test.go similarity index 84% rename from ops/opset13/xor_test.go rename to ops/xor/xor_test.go index 68658a4..194ed74 100644 --- a/ops/opset13/xor_test.go +++ b/ops/xor/xor_test.go @@ -1,4 +1,4 @@ -package opset13 +package xor import ( "testing" @@ -64,8 +64,9 @@ func TestXor(t *testing.T) { func TestInputValidationXor(t *testing.T) { tests := []struct { - inputs []tensor.Tensor - err error + inputs []tensor.Tensor + err error + version int64 }{ { []tensor.Tensor{ @@ -73,25 +74,28 @@ func TestInputValidationXor(t *testing.T) { ops.TensorWithBackingFixture([]bool{false, false}, 2), }, nil, + 7, }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]bool{false, false}, 2), }, - ops.ErrInvalidInputCount(1, &Xor{}), + ops.ErrInvalidInputCount(1, ops.NewBaseOperator(7, 2, 2, xorTypeConstraints, "xor")), + 7, }, { []tensor.Tensor{ ops.TensorWithBackingFixture([]bool{false, false}, 2), ops.TensorWithBackingFixture([]int{1, 2}, 2), }, - ops.ErrInvalidInputType(1, "int", &Xor{}), + ops.ErrInvalidInputType(1, "int", ops.NewBaseOperator(7, 2, 2, xorTypeConstraints, "xor")), + 7, }, } for _, test := range tests { - or := &Xor{} - validated, err := or.ValidateInputs(test.inputs) + xor := xorVersions[test.version]() + validated, err := xor.ValidateInputs(test.inputs) assert.Equal(t, test.err, err) diff --git a/ops_test.go b/ops_test.go index a07a351..07fb258 100644 --- a/ops_test.go +++ b/ops_test.go @@ -9,7 +9,6 @@ import ( "testing" "github.com/advancedclimatesystems/gonnx/onnx" - "github.com/advancedclimatesystems/gonnx/ops/opset13" "github.com/stretchr/testify/assert" "google.golang.org/protobuf/proto" "gorgonia.org/tensor" @@ -36,6 +35,7 @@ var ignoredTests = []string{ "test_logsoftmax_axis_2_expanded_ver18", // 
Opset18 "test_lstm_batchwise", // Opset14 "test_mul_uint8", // Opset14 + "test_reduce_max_empty_set", // Opset20 "test_reduce_max_do_not_keepdims_random", // Opset18 "test_reduce_max_keepdims_random", // Opset18 "test_reduce_max_default_axes_keepdims_random", // Opset18 @@ -76,9 +76,6 @@ var ignoredTests = []string{ "test_constant_pad", // Pad is not implemented yet. "test_constant_pad_axes", // Pad is not implemented yet. - "test_gemm_alpha", // For gemm in opset 11. - "test_gemm_default_no_bias", // For gemm in opset 11. - "test_gemm_default_scalar_bias", // For gemm in opset 11. "test_logsoftmax_large_number_expanded", // Requires 'Exp' operator. "test_logsoftmax_axis_0_expanded", // Requires 'Exp' operator. "test_logsoftmax_axis_1_expanded", // Requires 'Exp' operator. @@ -99,7 +96,6 @@ var ignoredTests = []string{ "test_slice_end_out_of_bounds", // ONNX expects nil output, but we throw an error. "test_slice_neg_steps", // ONNX expects nil output, but we throw an error. "test_slice_neg", // ONNX expects nil output, but we throw an error. - "test_transpose_default", // For transpose in opset 9. "test_equal_string", // Unsupported datatype String. "test_equal_string_broadcast", // Unsupported datatype String. @@ -146,7 +142,6 @@ var ignoredTests = []string{ "test_cast_no_saturate_FLOAT16_to_FLOAT8E4M3FN", // Unsupported datatype. "test_cast_no_saturate_FLOAT16_to_FLOAT8E5M2", // Unsupported datatype. 
- "test_unsqueeze_axis_3", // Tests an old version of Unsqueeze (<= 11) "test_constantofshape_int_shape_zero", // Empty tensors are not supported in gorgonia "test_gather_elements_0", // Operator GatherElements is not implemented "test_gather_elements_1", // Operator GatherElements is not implemented @@ -175,9 +170,8 @@ type ONNXTestCase struct { func TestOps(t *testing.T) { runnedTests := []string{} - opNames := opset13.GetOpNames() - for _, opName := range opNames { + for opName := range operators { tests, err := getTestCasesForOp(opName) assert.Nil(t, err) @@ -305,9 +299,13 @@ func readTestModel(folder string) (*Model, error) { return nil, err } - // Currently we only implemented Opset13, hence we enforce this in our tests. All + // Currently we support Opset 7-13, hence we enforce this in our tests. All // tests that fail because of this are ignored. - mp.OpsetImport[0].Version = 13 + if mp.OpsetImport[0].Version < MinSupportedOpset { + mp.OpsetImport[0].Version = MinSupportedOpset + } else if mp.OpsetImport[0].Version > MaxSupportedOpset { + mp.OpsetImport[0].Version = MaxSupportedOpset + } model, err := NewModel(mp) if err != nil { @@ -426,7 +424,10 @@ var expectedTests = []string{ "test_gather_negative_indices", "test_gemm_default_single_elem_vector_bias", "test_gemm_all_attributes", + "test_gemm_alpha", "test_gemm_default_matrix_bias", + "test_gemm_default_no_bias", + "test_gemm_default_scalar_bias", "test_gemm_default_vector_bias", "test_gemm_transposeA", "test_gemm_default_zero_bias", @@ -516,6 +517,7 @@ var expectedTests = []string{ "test_transpose_all_permutations_3", "test_transpose_all_permutations_4", "test_transpose_all_permutations_5", + "test_transpose_default", "test_unsqueeze_axis_0", "test_unsqueeze_axis_1", "test_unsqueeze_axis_2", diff --git a/opset.go b/opset.go index 2b83900..0bfcdb9 100644 --- a/opset.go +++ b/opset.go @@ -2,21 +2,159 @@ package gonnx import ( "github.com/advancedclimatesystems/gonnx/ops" - 
"github.com/advancedclimatesystems/gonnx/ops/opset13" + "github.com/advancedclimatesystems/gonnx/ops/abs" + "github.com/advancedclimatesystems/gonnx/ops/acos" + "github.com/advancedclimatesystems/gonnx/ops/acosh" + "github.com/advancedclimatesystems/gonnx/ops/add" + "github.com/advancedclimatesystems/gonnx/ops/and" + "github.com/advancedclimatesystems/gonnx/ops/argmax" + "github.com/advancedclimatesystems/gonnx/ops/asin" + "github.com/advancedclimatesystems/gonnx/ops/asinh" + "github.com/advancedclimatesystems/gonnx/ops/atan" + "github.com/advancedclimatesystems/gonnx/ops/atanh" + "github.com/advancedclimatesystems/gonnx/ops/cast" + "github.com/advancedclimatesystems/gonnx/ops/concat" + "github.com/advancedclimatesystems/gonnx/ops/constant" + "github.com/advancedclimatesystems/gonnx/ops/constantofshape" + "github.com/advancedclimatesystems/gonnx/ops/conv" + "github.com/advancedclimatesystems/gonnx/ops/cos" + "github.com/advancedclimatesystems/gonnx/ops/cosh" + "github.com/advancedclimatesystems/gonnx/ops/div" + "github.com/advancedclimatesystems/gonnx/ops/equal" + "github.com/advancedclimatesystems/gonnx/ops/expand" + "github.com/advancedclimatesystems/gonnx/ops/flatten" + "github.com/advancedclimatesystems/gonnx/ops/gather" + "github.com/advancedclimatesystems/gonnx/ops/gemm" + "github.com/advancedclimatesystems/gonnx/ops/greater" + "github.com/advancedclimatesystems/gonnx/ops/greaterorequal" + "github.com/advancedclimatesystems/gonnx/ops/gru" + "github.com/advancedclimatesystems/gonnx/ops/less" + "github.com/advancedclimatesystems/gonnx/ops/lessorequal" + "github.com/advancedclimatesystems/gonnx/ops/linearregressor" + "github.com/advancedclimatesystems/gonnx/ops/logsoftmax" + 
"github.com/advancedclimatesystems/gonnx/ops/lstm" + "github.com/advancedclimatesystems/gonnx/ops/matmul" + "github.com/advancedclimatesystems/gonnx/ops/mul" + "github.com/advancedclimatesystems/gonnx/ops/not" + "github.com/advancedclimatesystems/gonnx/ops/or" + "github.com/advancedclimatesystems/gonnx/ops/prelu" + "github.com/advancedclimatesystems/gonnx/ops/reducemax" + "github.com/advancedclimatesystems/gonnx/ops/reducemin" + "github.com/advancedclimatesystems/gonnx/ops/relu" + "github.com/advancedclimatesystems/gonnx/ops/reshape" + "github.com/advancedclimatesystems/gonnx/ops/rnn" + "github.com/advancedclimatesystems/gonnx/ops/scaler" + "github.com/advancedclimatesystems/gonnx/ops/shape" + "github.com/advancedclimatesystems/gonnx/ops/sigmoid" + "github.com/advancedclimatesystems/gonnx/ops/sin" + "github.com/advancedclimatesystems/gonnx/ops/sinh" + "github.com/advancedclimatesystems/gonnx/ops/slice" + "github.com/advancedclimatesystems/gonnx/ops/softmax" + "github.com/advancedclimatesystems/gonnx/ops/squeeze" + "github.com/advancedclimatesystems/gonnx/ops/sub" + "github.com/advancedclimatesystems/gonnx/ops/tan" + "github.com/advancedclimatesystems/gonnx/ops/tanh" + "github.com/advancedclimatesystems/gonnx/ops/transpose" + "github.com/advancedclimatesystems/gonnx/ops/unsqueeze" + "github.com/advancedclimatesystems/gonnx/ops/xor" ) -// OpGetter is a function that gets an operator based on a string. -type OpGetter func(string) (ops.Operator, error) +const ( + MinSupportedOpset = 7 + MaxSupportedOpset = 13 +) + +// Opset is a set of operators matching a certain opset version. 
+type Opset map[string]func() ops.Operator + +var operators = map[string]ops.OperatorVersions{ + "Abs": abs.GetAbsVersions(), + "Acos": acos.GetAcosVersions(), + "Acosh": acosh.GetAcoshVersions(), + "Add": add.GetAddVersions(), + "And": and.GetAndVersions(), + "ArgMax": argmax.GetArgMaxVersions(), + "Asin": asin.GetAsinVersions(), + "Asinh": asinh.GetAsinhVersions(), + "Atan": atan.GetAtanVersions(), + "Atanh": atanh.GetAtanhVersions(), + "Cast": cast.GetCastVersions(), + "Concat": concat.GetConcatVersions(), + "Constant": constant.GetConstantVersions(), + "ConstantOfShape": constantofshape.GetConstantOfShapeVersions(), + "Conv": conv.GetConvVersions(), + "Cos": cos.GetCosVersions(), + "Cosh": cosh.GetCoshVersions(), + "Div": div.GetDivVersions(), + "Equal": equal.GetEqualVersions(), + "Expand": expand.GetExpandVersions(), + "Flatten": flatten.GetFlattenVersions(), + "Gather": gather.GetGatherVersions(), + "Gemm": gemm.GetGemmVersions(), + "Greater": greater.GetGreaterVersions(), + "GreaterOrEqual": greaterorequal.GetGreaterOrEqualVersions(), + "GRU": gru.GetGRUVersions(), + "Less": less.GetLessVersions(), + "LessOrEqual": lessorequal.GetLessOrEqualVersions(), + "LinearRegressor": linearregressor.GetLinearRegressorVersions(), + "LogSoftmax": logsoftmax.GetLogSoftmaxVersions(), + "LSTM": lstm.GetLSTMVersions(), + "MatMul": matmul.GetMatMulVersions(), + "Mul": mul.GetMulVersions(), + "Not": not.GetNotVersions(), + "Or": or.GetOrVersions(), + "PRelu": prelu.GetPReluVersions(), + "ReduceMax": reducemax.GetReduceMaxVersions(), + "ReduceMin": reducemin.GetReduceMinVersions(), + "Relu": relu.GetReluVersions(), + "Reshape": reshape.GetReshapeVersions(), + "RNN": rnn.GetRNNVersions(), + "Scaler": scaler.GetScalerVersions(), + "Shape": shape.GetShapeVersions(), + "Sigmoid": sigmoid.GetSigmoidVersions(), + "Sin": sin.GetSinVersions(), + "Sinh": sinh.GetSinhVersions(), + "Slice": slice.GetSliceVersions(), + "Softmax": softmax.GetSoftmaxVersions(), + "Squeeze": 
squeeze.GetSqueezeVersions(), + "Sub": sub.GetSubVersions(), + "Tan": tan.GetTanVersions(), + "Tanh": tanh.GetTanhVersions(), + "Transpose": transpose.GetTransposeVersions(), + "Unsqueeze": unsqueeze.GetUnsqueezeVersions(), + "Xor": xor.GetXorVersions(), +} + +// GetClosestOperatorVersion resolves, given a certain opset version, the operator version that is closest +// to that version, going downwards. So if the opset version is 13, and an operator has version 13, this +// one is used. If the opset version is 13, and an operator has versions 7 and 14, version 7 is used, as +// it is the closest opset version going downwards. +func GetClosestOperatorVersion(opsetID int64, versions ops.OperatorVersions) func() ops.Operator { + for closestOpset := opsetID; closestOpset >= 1; closestOpset-- { + if operator, ok := versions[closestOpset]; ok { + return operator + } + } -var operatorGetters = map[int64]OpGetter{ - 13: opset13.GetOperator, + return nil } -// ResolveOperatorGetter resolves the getter for operators based on the opset version. -func ResolveOperatorGetter(opsetID int64) (OpGetter, error) { - if getOperator, ok := operatorGetters[opsetID]; ok { - return getOperator, nil +// ResolveOpset resolves the opset with all closest operator versions for the given opset version. 
+func ResolveOpset(opsetID int64) (Opset, error) { + if opsetID < MinSupportedOpset || opsetID > MaxSupportedOpset { + return nil, ops.ErrUnsupportedOpsetVersion + } + + opset := map[string]func() ops.Operator{} + + for operatorName, operatorVersions := range operators { + operator := GetClosestOperatorVersion(opsetID, operatorVersions) + if operator == nil { + continue + } + + opset[operatorName] = operator } - return nil, ops.ErrUnsupportedOpsetVersion + return opset, nil } diff --git a/opset_test.go b/opset_test.go index bd986c9..63b397b 100644 --- a/opset_test.go +++ b/opset_test.go @@ -7,8 +7,13 @@ import ( "github.com/stretchr/testify/assert" ) -func TestResolveOperatorGetterFail(t *testing.T) { - opGetter, err := ResolveOperatorGetter(12) - assert.Nil(t, opGetter) +func TestResolveOpset(t *testing.T) { + _, err := ResolveOpset(13) + assert.Nil(t, err) +} + +func TestResolveOpsetNotSupported(t *testing.T) { + opset, err := ResolveOpset(6) + assert.Nil(t, opset) assert.Equal(t, ops.ErrUnsupportedOpsetVersion, err) }