Skip to content

Commit

Permalink
Added sqrt operator (#221)
Browse files Browse the repository at this point in the history
  • Loading branch information
Swopper050 authored Dec 22, 2024
1 parent 7c2b171 commit 7d3fe7a
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 0 deletions.
42 changes: 42 additions & 0 deletions ops/sqrt/sqrt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package sqrt

import (
"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

var sqrtTypeConstraints = [][]tensor.Dtype{{tensor.Float32, tensor.Float64}}

// Sqrt represents the ONNX sqrt operator.
type Sqrt struct {
ops.BaseOperator
}

// newSqrt creates a new sqrt operator.
func newSqrt(version int, typeConstraints [][]tensor.Dtype) ops.Operator {
return &Sqrt{
BaseOperator: ops.NewBaseOperator(
version,
1,
1,
typeConstraints,
"sqrt",
),
}
}

// Init initializes the sqrt operator.
func (s *Sqrt) Init(_ *onnx.NodeProto) error {
return nil
}

// Apply applies the sqrt operator.
func (s *Sqrt) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
out, err := tensor.Sqrt(inputs[0])
if err != nil {
return nil, err
}

return []tensor.Tensor{out}, nil
}
99 changes: 99 additions & 0 deletions ops/sqrt/sqrt_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package sqrt

import (
"testing"

"github.com/advancedclimatesystems/gonnx/ops"
"github.com/stretchr/testify/assert"
"gorgonia.org/tensor"
)

func TestSqrtInit(t *testing.T) {
s := &Sqrt{}
err := s.Init(nil)
assert.Nil(t, err)
}

func TestSqrt(t *testing.T) {
tests := []struct {
version int64
backing []float32
shape []int
expected []float32
}{
{
13,
[]float32{1, 2, 3, 4},
[]int{2, 2},
[]float32{1, 1.4142135, 1.7320508, 2},
},
{
6,
[]float32{1, 3, 4, 5},
[]int{1, 4},
[]float32{1, 1.7320508, 2, 2.236068},
},
{
13,
[]float32{1, 1, 1, 1},
[]int{1, 4},
[]float32{1, 1, 1, 1},
},
}

for _, test := range tests {
inputs := []tensor.Tensor{
ops.TensorWithBackingFixture(test.backing, test.shape...),
}

sqrt := sqrtVersions[test.version]()
res, err := sqrt.Apply(inputs)
assert.Nil(t, err)

assert.Nil(t, err)
assert.Equal(t, test.expected, res[0].Data())
}
}

func TestInputValidationSqrt(t *testing.T) {
tests := []struct {
version int64
inputs []tensor.Tensor
err error
}{
{
13,
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float32{1, 2}, 2),
},
nil,
},
{
13,
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float64{1, 2}, 2),
},
nil,
},
{
13,
[]tensor.Tensor{},
ops.ErrInvalidInputCount(0, ops.NewBaseOperator(13, 1, 1, sqrtTypeConstraints, "sqrt")),
},
{
13,
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int{1, 2}, 2),
},
ops.ErrInvalidInputType(0, "int", ops.NewBaseOperator(13, 1, 1, sqrtTypeConstraints, "sqrt")),
},
}

for _, test := range tests {
sqrt := sqrtVersions[test.version]()
validated, err := sqrt.ValidateInputs(test.inputs)

assert.Equal(t, test.err, err)
assert.Equal(t, test.inputs, validated)
}
}
12 changes: 12 additions & 0 deletions ops/sqrt/versions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package sqrt

import "github.com/advancedclimatesystems/gonnx/ops"

var sqrtVersions = ops.OperatorVersions{
6: ops.NewOperatorConstructor(newSqrt, 6, sqrtTypeConstraints),
13: ops.NewOperatorConstructor(newSqrt, 13, sqrtTypeConstraints),
}

func GetVersions() ops.OperatorVersions {
return sqrtVersions
}
2 changes: 2 additions & 0 deletions ops_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,8 @@ var expectedTests = []string{
"test_softmax_example",
"test_softmax_large_number",
"test_softmax_negative_axis",
"test_sqrt",
"test_sqrt_example",
"test_squeeze",
"test_sub",
"test_sub_bcast",
Expand Down
2 changes: 2 additions & 0 deletions opset.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ import (
"github.com/advancedclimatesystems/gonnx/ops/sinh"
"github.com/advancedclimatesystems/gonnx/ops/slice"
"github.com/advancedclimatesystems/gonnx/ops/softmax"
"github.com/advancedclimatesystems/gonnx/ops/sqrt"
"github.com/advancedclimatesystems/gonnx/ops/squeeze"
"github.com/advancedclimatesystems/gonnx/ops/sub"
"github.com/advancedclimatesystems/gonnx/ops/tan"
Expand Down Expand Up @@ -116,6 +117,7 @@ var operators = map[string]ops.OperatorVersions{
"Sinh": sinh.GetSinhVersions(),
"Slice": slice.GetSliceVersions(),
"Softmax": softmax.GetSoftmaxVersions(),
"Sqrt": sqrt.GetVersions(),
"Squeeze": squeeze.GetSqueezeVersions(),
"Sub": sub.GetSubVersions(),
"Tan": tan.GetTanVersions(),
Expand Down

0 comments on commit 7d3fe7a

Please sign in to comment.