Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added cos operator #159

Merged
merged 11 commits into from
Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions ops/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,6 @@ const AxisOutOfRangeErrTemplate = "axis argument must be in the range -%d <= x <
// AxesNotAllInRangeErrTemplate is used to format an error when not all indices
// are within a given range.
const AxesNotAllInRangeErrTemplate = "all indices entries must be in the range -%d <= x < %d"

// UnsupportedDTypeError is used when the DType of a tensor is not supported.
const UnsupportedDtypeErrTemplate = "dtype %v is not supported for operator %v"
Swopper050 marked this conversation as resolved.
Show resolved Hide resolved
76 changes: 76 additions & 0 deletions ops/opset13/cos.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package opset13

import (
"fmt"
"math"

"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

// Cos represents the ONNX cos operator.
type Cos struct{}

// newCos creates a new cos operator.
func newCos() ops.Operator {
return &Cos{}
}

// Init initializes the cos operator.
func (s *Cos) Init(attributes []*onnx.AttributeProto) error {
return nil
}

type CosDType interface {
Swopper050 marked this conversation as resolved.
Show resolved Hide resolved
float32 | float64
}

// Apply applies the sin operator.
func (s *Cos) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
var out tensor.Tensor
var err error
if inputs[0].Dtype() == tensor.Float32 {
Swopper050 marked this conversation as resolved.
Show resolved Hide resolved
out, err = inputs[0].Apply(cos[float32])
wipsel marked this conversation as resolved.
Show resolved Hide resolved
} else if inputs[0].Dtype() == tensor.Float64 {
out, err = inputs[0].Apply(cos[float64])
} else {
return nil, fmt.Errorf(ops.UnsupportedDtypeErrTemplate, inputs[0].Dtype(), s)
}

if err != nil {
return nil, err
}

return []tensor.Tensor{out}, nil
}

// ValidateInputs validates the inputs that will be given to Apply for this operator.
func (s *Cos) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
return ops.ValidateInputs(s, inputs)
}

// GetMinInputs returns the minimum number of input tensors this operator expects.
func (s *Cos) GetMinInputs() int {
return 1
}

// GetMaxInputs returns the maximum number of input tensors this operator expects.
func (s *Cos) GetMaxInputs() int {
return 1
}

// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes
// for the corresponding input tensor.
func (s *Cos) GetInputTypeConstraints() [][]tensor.Dtype {
return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}
}

// String implements the stringer interface, and can be used to format errors or messages.
func (s *Cos) String() string {
Swopper050 marked this conversation as resolved.
Show resolved Hide resolved
return "cos operator"
}

func cos[T CosDType](x T) T {
return T(math.Cos(float64(x)))
}
99 changes: 99 additions & 0 deletions ops/opset13/cos_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package opset13

import (
"fmt"
"testing"

"github.com/advancedclimatesystems/gonnx/ops"
"github.com/stretchr/testify/assert"
"gorgonia.org/tensor"
)

func TestCosInit(t *testing.T) {
c := &Cos{}

// since 'cos' does not have any attributes we pass in nil. This should not
// fail initializing the cos.
err := c.Init(nil)
assert.Nil(t, err)
}

func TestCos(t *testing.T) {
tests := []struct {
cos *Cos
backing []float32
shape []int
expected []float32
}{
{
&Cos{},
wipsel marked this conversation as resolved.
Show resolved Hide resolved
[]float32{-2, -1, 0, 1},
[]int{2, 2},
[]float32{-0.41614684, 0.5403023, 1, 0.5403023},
},
{
&Cos{},
[]float32{1, 3, 4, 5},
[]int{1, 4},
[]float32{0.5403023, -0.9899925, -0.6536436, 0.2836622},
},
{
&Cos{},
[]float32{-1, -1, -1, -1},
[]int{1, 4},
[]float32{0.5403023, 0.5403023, 0.5403023, 0.5403023},
},
}

for _, test := range tests {
inputs := []tensor.Tensor{
ops.TensorWithBackingFixture(test.backing, test.shape...),
}

res, err := test.cos.Apply(inputs)
assert.Nil(t, err)

assert.Nil(t, err)
assert.Equal(t, test.expected, res[0].Data())
}
}

func TestInputValidationCos(t *testing.T) {
tests := []struct {
inputs []tensor.Tensor
err error
}{
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float32{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float64{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{},
fmt.Errorf("cos operator: expected 1 input tensors, got 0"),
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int{1, 2}, 2),
},
fmt.Errorf("cos operator: input 0 does not allow type int"),
},
}

for _, test := range tests {
cos := &Cos{}
validated, err := cos.ValidateInputs(test.inputs)

assert.Equal(t, test.err, err)
if test.err == nil {
assert.Equal(t, test.inputs, validated)
}
}
}
1 change: 1 addition & 0 deletions ops/opset13/opset13.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ var operators13 = map[string]func() ops.Operator{
"Concat": newConcat,
"Constant": newConstant,
"ConstantOfShape": newConstantOfShape,
"Cos": newCos,
"Div": newDiv,
"Gather": newGather,
"Gemm": newGemm,
Expand Down
5 changes: 5 additions & 0 deletions ops/opset13/opset13_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ func TestGetOperator(t *testing.T) {
newConstantOfShape(),
nil,
},
{
"Cos",
newCos(),
nil,
},
{
"Div",
newDiv(),
Expand Down
2 changes: 2 additions & 0 deletions ops_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,8 @@ var expectedTests = []string{
"test_constant",
"test_constantofshape_float_ones",
"test_constantofshape_int_zeros",
"test_cos",
"test_cos_example",
"test_div",
"test_div_bcast",
"test_div_example",
Expand Down