Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added cos operator #159

Merged
merged 11 commits into from
Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions ops/opset13/cos.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package opset13

import (
"math"

"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

// Cos represents the ONNX cos operator.
type Cos struct{}

// newCos creates a new cos operator.
func newCos() ops.Operator {
return &Cos{}
}

// Init initializes the cos operator.
func (c *Cos) Init(_ []*onnx.AttributeProto) error {
return nil
}

// Apply applies the cos operator.
func (c *Cos) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
var (
out tensor.Tensor
err error
)

switch inputs[0].Dtype() {
case tensor.Float32:
out, err = inputs[0].Apply(cos[float32])
wipsel marked this conversation as resolved.
Show resolved Hide resolved
case tensor.Float64:
out, err = inputs[0].Apply(cos[float64])
default:
return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), c)
}

if err != nil {
return nil, err
}

return []tensor.Tensor{out}, nil
}

// ValidateInputs validates the inputs that will be given to Apply for this operator.
func (c *Cos) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
return ops.ValidateInputs(c, inputs)
}

// GetMinInputs returns the minimum number of input tensors this operator expects.
func (c *Cos) GetMinInputs() int {
return 1
}

// GetMaxInputs returns the maximum number of input tensors this operator expects.
func (c *Cos) GetMaxInputs() int {
return 1
}

// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes
// for the corresponding input tensor.
func (c *Cos) GetInputTypeConstraints() [][]tensor.Dtype {
return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}
}

// String implements the stringer interface, and can be used to format errors or messages.
func (c *Cos) String() string {
return "cos operator"
}

func cos[T ops.FloatType](x T) T {
return T(math.Cos(float64(x)))
}
99 changes: 99 additions & 0 deletions ops/opset13/cos_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package opset13

import (
"testing"

"github.com/advancedclimatesystems/gonnx/ops"
"github.com/stretchr/testify/assert"
"gorgonia.org/tensor"
)

func TestCosInit(t *testing.T) {
c := &Cos{}

// since 'cos' does not have any attributes we pass in nil. This should not
// fail initializing the cos.
err := c.Init(nil)
assert.Nil(t, err)
}

func TestCos(t *testing.T) {
tests := []struct {
cos *Cos
backing []float32
shape []int
expected []float32
}{
{
&Cos{},
wipsel marked this conversation as resolved.
Show resolved Hide resolved
[]float32{-2, -1, 0, 1},
[]int{2, 2},
[]float32{-0.41614684, 0.5403023, 1, 0.5403023},
},
{
&Cos{},
[]float32{1, 3, 4, 5},
[]int{1, 4},
[]float32{0.5403023, -0.9899925, -0.6536436, 0.2836622},
},
{
&Cos{},
[]float32{-1, -1, -1, -1},
[]int{1, 4},
[]float32{0.5403023, 0.5403023, 0.5403023, 0.5403023},
},
}

for _, test := range tests {
inputs := []tensor.Tensor{
ops.TensorWithBackingFixture(test.backing, test.shape...),
}

res, err := test.cos.Apply(inputs)
assert.Nil(t, err)

assert.Nil(t, err)
assert.Equal(t, test.expected, res[0].Data())
}
}

func TestInputValidationCos(t *testing.T) {
tests := []struct {
inputs []tensor.Tensor
err error
}{
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float32{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float64{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{},
ops.ErrInvalidInputCount(0, &Cos{}),
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int{1, 2}, 2),
},
ops.ErrInvalidInputType(0, "int", &Cos{}),
},
}

for _, test := range tests {
cos := &Cos{}
validated, err := cos.ValidateInputs(test.inputs)

assert.Equal(t, test.err, err)

if test.err == nil {
assert.Equal(t, test.inputs, validated)
}
}
}
1 change: 1 addition & 0 deletions ops/opset13/opset13.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ var operators13 = map[string]func() ops.Operator{
"Concat": newConcat,
"Constant": newConstant,
"ConstantOfShape": newConstantOfShape,
"Cos": newCos,
"Div": newDiv,
"Gather": newGather,
"Gemm": newGemm,
Expand Down
5 changes: 5 additions & 0 deletions ops/opset13/opset13_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ func TestGetOperator(t *testing.T) {
newConstantOfShape(),
nil,
},
{
"Cos",
newCos(),
nil,
},
{
"Div",
newDiv(),
Expand Down
17 changes: 17 additions & 0 deletions ops/types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package ops

import "gorgonia.org/tensor"

type FloatType interface {
float32 | float64
}

// AllTypes is a type constraint which allows all types.
var AllTypes = []tensor.Dtype{
tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64,
tensor.Int8, tensor.Int16, tensor.Int32, tensor.Int64,
tensor.Float32, tensor.Float64,
tensor.Complex64, tensor.Complex128,
tensor.String,
tensor.Bool,
}
10 changes: 0 additions & 10 deletions ops/validate_inputs.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,6 @@ import (
"gorgonia.org/tensor"
)

// AllTypes is a type constraint which allows all types.
var AllTypes = []tensor.Dtype{
tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64,
tensor.Int8, tensor.Int16, tensor.Int32, tensor.Int64,
tensor.Float32, tensor.Float64,
tensor.Complex64, tensor.Complex128,
tensor.String,
tensor.Bool,
}

// ValidateInputs validates if a list of nodes has enough (not too few or too many) nodes.
// When there are fewer input nodes then the given max, the list is padded with nils.
// Expects either 1 requirement ==> the expected number of inputs, or 2 requirements,
Expand Down
2 changes: 2 additions & 0 deletions ops_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,8 @@ var expectedTests = []string{
"test_constant",
"test_constantofshape_float_ones",
"test_constantofshape_int_zeros",
"test_cos",
"test_cos_example",
"test_div",
"test_div_bcast",
"test_div_example",
Expand Down