Skip to content

Commit

Permalink
Added Sinh operator (#158)
Browse files Browse the repository at this point in the history
* Added sinh operator

* Updated comments

* Use FloatType

---------

Co-authored-by: Swopper050 <[email protected]>
  • Loading branch information
Swopper050 and Swopper050 authored Nov 26, 2023
1 parent 22aba5b commit 62247cb
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 0 deletions.
1 change: 1 addition & 0 deletions ops/opset13/opset13.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ var operators13 = map[string]func() ops.Operator{
"Shape": newShape,
"Sigmoid": newSigmoid,
"Sin": newSin,
"Sinh": newSinh,
"Slice": newSlice,
"Squeeze": newSqueeze,
"Sub": newSub,
Expand Down
5 changes: 5 additions & 0 deletions ops/opset13/opset13_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@ func TestGetOperator(t *testing.T) {
newSin(),
nil,
},
{
"Sinh",
newSinh(),
nil,
},
{
"Slice",
newSlice(),
Expand Down
75 changes: 75 additions & 0 deletions ops/opset13/sinh.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package opset13

import (
"math"

"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

// Sinh represents the ONNX sinh operator.
type Sinh struct{}

// newSin creates a new sinh operator.
func newSinh() ops.Operator {
return &Sinh{}
}

// Init initializes the sinh operator.
func (s *Sinh) Init(_ []*onnx.AttributeProto) error {
return nil
}

// Apply applies the sinh operator.
func (s *Sinh) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
var (
out tensor.Tensor
err error
)

switch inputs[0].Dtype() {
case tensor.Float32:
out, err = inputs[0].Apply(sinh[float32])
case tensor.Float64:
out, err = inputs[0].Apply(sinh[float64])
default:
return nil, ops.ErrInvalidInputType(0, inputs[0].Dtype().String(), s)
}

if err != nil {
return nil, err
}

return []tensor.Tensor{out}, nil
}

// ValidateInputs validates the inputs that will be given to Apply for this operator.
func (s *Sinh) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
return ops.ValidateInputs(s, inputs)
}

// GetMinInputs returns the minimum number of input tensors this operator expects.
func (s *Sinh) GetMinInputs() int {
return 1
}

// GetMaxInputs returns the maximum number of input tensors this operator expects.
func (s *Sinh) GetMaxInputs() int {
return 1
}

// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes
// for the corresponding input tensor.
func (s *Sinh) GetInputTypeConstraints() [][]tensor.Dtype {
return [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}
}

// String implements the stringer interface, and can be used to format errors or messages.
func (s *Sinh) String() string {
return "sinh operator"
}

func sinh[T ops.FloatType](x T) T {
return T(math.Sinh(float64(x)))
}
99 changes: 99 additions & 0 deletions ops/opset13/sinh_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package opset13

import (
"testing"

"github.com/advancedclimatesystems/gonnx/ops"
"github.com/stretchr/testify/assert"
"gorgonia.org/tensor"
)

func TestSinhInit(t *testing.T) {
s := &Sinh{}

// since 'sinh' does not have any attributes we pass in nil. This should not
// fail initializing the sinh.
err := s.Init(nil)
assert.Nil(t, err)
}

func TestSinh(t *testing.T) {
tests := []struct {
sinh *Sinh
backing []float32
shape []int
expected []float32
}{
{
&Sinh{},
[]float32{-2, -1, 0, 1},
[]int{2, 2},
[]float32{-3.6268604, -1.1752012, 0, 1.1752012},
},
{
&Sinh{},
[]float32{1, 3, 4, 5},
[]int{1, 4},
[]float32{1.1752012, 10.017875, 27.289917, 74.20321},
},
{
&Sinh{},
[]float32{-1, -1, -1, -1},
[]int{1, 4},
[]float32{-1.1752012, -1.1752012, -1.1752012, -1.1752012},
},
}

for _, test := range tests {
inputs := []tensor.Tensor{
ops.TensorWithBackingFixture(test.backing, test.shape...),
}

res, err := test.sinh.Apply(inputs)
assert.Nil(t, err)

assert.Nil(t, err)
assert.Equal(t, test.expected, res[0].Data())
}
}

func TestInputValidationSinh(t *testing.T) {
tests := []struct {
inputs []tensor.Tensor
err error
}{
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float32{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float64{1, 2}, 2),
},
nil,
},
{
[]tensor.Tensor{},
ops.ErrInvalidInputCount(0, &Sinh{}),
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int{1, 2}, 2),
},
ops.ErrInvalidInputType(0, "int", &Sinh{}),
},
}

for _, test := range tests {
sinh := &Sinh{}
validated, err := sinh.ValidateInputs(test.inputs)

assert.Equal(t, test.err, err)

if test.err == nil {
assert.Equal(t, test.inputs, validated)
}
}
}
2 changes: 2 additions & 0 deletions ops_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,8 @@ var expectedTests = []string{
"test_sin_example",
"test_sigmoid_example",
"test_sigmoid",
"test_sinh",
"test_sinh_example",
"test_slice_negative_axes",
"test_slice_default_steps",
"test_slice",
Expand Down

0 comments on commit 62247cb

Please sign in to comment.