Skip to content

Commit

Permalink
Added acos operator (#162)
Browse files Browse the repository at this point in the history
* Added acos operator

* Merge develop

* Group declarations

---------

Co-authored-by: Swopper050 <[email protected]>
  • Loading branch information
Swopper050 and Swopper050 authored Nov 24, 2023
1 parent 8431865 commit 1e529ec
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 0 deletions.
75 changes: 75 additions & 0 deletions ops/opset13/acos.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package opset13

import (
"math"

"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

// Acos represents the ONNX acos operator.
type Acos struct{}

// newAcos creates a new acos operator.
func newAcos() ops.Operator {
return &Acos{}
}

// Init initializes the acos operator.
func (c *Acos) Init(_ []*onnx.AttributeProto) error {
return nil
}

// Apply applies the acos operator.
func (c *Acos) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
var (
out tensor.Tensor
err error
)

switch inputs[0].Dtype() {
case tensor.Float32:
out, err = inputs[0].Apply(acos[float32])
case tensor.Float64:
out, err = inputs[0].Apply(acos[float64])
default:
return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), c)
}

if err != nil {
return nil, err
}

return []tensor.Tensor{out}, nil
}

// ValidateInputs validates the inputs that will be given to Apply for this operator.
func (c *Acos) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
return ops.ValidateInputs(c, inputs)
}

// GetMinInputs returns the minimum number of input tensors this operator expects.
func (c *Acos) GetMinInputs() int {
return 1
}

// GetMaxInputs returns the maximum number of input tensors this operator expects.
func (c *Acos) GetMaxInputs() int {
return 1
}

// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes
// for the corresponding input tensor.
func (c *Acos) GetInputTypeConstraints() [][]tensor.Dtype {
return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}
}

// String implements the stringer interface, and can be used to format errors or messages.
func (c *Acos) String() string {
return "acos operator"
}

func acos[T ops.FloatType](x T) T {
return T(math.Acos(float64(x)))
}
99 changes: 99 additions & 0 deletions ops/opset13/acos_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package opset13

import (
"testing"

"github.com/advancedclimatesystems/gonnx/ops"
"github.com/stretchr/testify/assert"
"gorgonia.org/tensor"
)

func TestAcosInit(t *testing.T) {
c := &Acos{}

// since 'acos' does not have any attributes we pass in nil. This should not
// fail initializing the acos.
err := c.Init(nil)
assert.Nil(t, err)
}

func TestAcos(t *testing.T) {
tests := []struct {
acos *Acos
backing []float32
shape []int
expected []float32
}{
{
&Acos{},
[]float32{-1, -1, 0, 1},
[]int{2, 2},
[]float32{3.1415927, 3.1415927, 1.5707964, 0},
},
{
&Acos{},
[]float32{1, 0.5, 0.0, -0.5},
[]int{1, 4},
[]float32{0, 1.0471976, 1.5707964, 2.0943952},
},
{
&Acos{},
[]float32{-1, -1, -1, -1},
[]int{1, 4},
[]float32{3.1415927, 3.1415927, 3.1415927, 3.1415927},
},
}

for _, test := range tests {
inputs := []tensor.Tensor{
ops.TensorWithBackingFixture(test.backing, test.shape...),
}

res, err := test.acos.Apply(inputs)
assert.Nil(t, err)

assert.Nil(t, err)
assert.Equal(t, test.expected, res[0].Data())
}
}

func TestInputValidationAcos(t *testing.T) {
tests := []struct {
inputs []tensor.Tensor
err error
}{
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float32{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float64{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{},
ops.ErrInvalidInputCount(0, &Acos{}),
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int{1, 2}, 2),
},
ops.ErrInvalidInputType(0, "int", &Acos{}),
},
}

for _, test := range tests {
acos := &Acos{}
validated, err := acos.ValidateInputs(test.inputs)

assert.Equal(t, test.err, err)

if test.err == nil {
assert.Equal(t, test.inputs, validated)
}
}
}
1 change: 1 addition & 0 deletions ops/opset13/opset13.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

var operators13 = map[string]func() ops.Operator{
"Abs": newAbs,
"Acos": newAcos,
"Acosh": newAcosh,
"Add": newAdd,
"Cast": newCast,
Expand Down
5 changes: 5 additions & 0 deletions ops/opset13/opset13_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ func TestGetOperator(t *testing.T) {
newAbs(),
nil,
},
{
"Acos",
newAcos(),
nil,
},
{
"Acosh",
newAcosh(),
Expand Down
2 changes: 2 additions & 0 deletions ops_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,8 @@ func readTestTensors(basePath, baseFile string, inputs []*onnx.ValueInfoProto) (
// With this we check if we truly run all tests we expected from the integration test.
var expectedTests = []string{
"test_abs",
"test_acos",
"test_acos_example",
"test_acosh",
"test_acosh_example",
"test_add",
Expand Down

0 comments on commit 1e529ec

Please sign in to comment.