diff --git a/ops/sqrt/sqrt.go b/ops/sqrt/sqrt.go new file mode 100644 index 0000000..927c5a0 --- /dev/null +++ b/ops/sqrt/sqrt.go @@ -0,0 +1,42 @@ +package sqrt + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var sqrtTypeConstraints = [][]tensor.Dtype{{tensor.Float32, tensor.Float64}} + +// Sqrt represents the ONNX sqrt operator. +type Sqrt struct { + ops.BaseOperator +} + +// newSqrt creates a new sqrt operator. +func newSqrt(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Sqrt{ + BaseOperator: ops.NewBaseOperator( + version, + 1, + 1, + typeConstraints, + "sqrt", + ), + } +} + +// Init initializes the sqrt operator. +func (s *Sqrt) Init(_ *onnx.NodeProto) error { + return nil +} + +// Apply applies the sqrt operator. +func (s *Sqrt) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + out, err := tensor.Sqrt(inputs[0]) + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} diff --git a/ops/sqrt/sqrt_test.go b/ops/sqrt/sqrt_test.go new file mode 100644 index 0000000..0b6d202 --- /dev/null +++ b/ops/sqrt/sqrt_test.go @@ -0,0 +1,99 @@ +package sqrt + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestSqrtInit(t *testing.T) { + s := &Sqrt{} + err := s.Init(nil) + assert.Nil(t, err) +} + +func TestSqrt(t *testing.T) { + tests := []struct { + version int64 + backing []float32 + shape []int + expected []float32 + }{ + { + 13, + []float32{1, 2, 3, 4}, + []int{2, 2}, + []float32{1, 1.4142135, 1.7320508, 2}, + }, + { + 6, + []float32{1, 3, 4, 5}, + []int{1, 4}, + []float32{1, 1.7320508, 2, 2.236068}, + }, + { + 13, + []float32{1, 1, 1, 1}, + []int{1, 4}, + []float32{1, 1, 1, 1}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backing, test.shape...), + } + + sqrt := sqrtVersions[test.version]() + res, err := sqrt.Apply(inputs) + assert.Nil(t, err) + + assert.Nil(t, err) + assert.Equal(t, test.expected, res[0].Data()) + } +} + +func TestInputValidationSqrt(t *testing.T) { + tests := []struct { + version int64 + inputs []tensor.Tensor + err error + }{ + { + 13, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + }, + { + 13, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float64{1, 2}, 2), + }, + nil, + }, + { + 13, + []tensor.Tensor{}, + ops.ErrInvalidInputCount(0, ops.NewBaseOperator(13, 1, 1, sqrtTypeConstraints, "sqrt")), + }, + { + 13, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int", ops.NewBaseOperator(13, 1, 1, sqrtTypeConstraints, "sqrt")), + }, + } + + for _, test := range tests { + sqrt := sqrtVersions[test.version]() + validated, err := sqrt.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + assert.Equal(t, test.inputs, validated) + } +} diff --git a/ops/sqrt/versions.go b/ops/sqrt/versions.go new file mode 100644 index 0000000..b00f93b --- /dev/null +++ b/ops/sqrt/versions.go @@ -0,0 +1,12 @@ +package sqrt + +import "github.com/advancedclimatesystems/gonnx/ops" + +var sqrtVersions = ops.OperatorVersions{ + 6: ops.NewOperatorConstructor(newSqrt, 6, sqrtTypeConstraints), + 13: ops.NewOperatorConstructor(newSqrt, 13, sqrtTypeConstraints), +} + +func GetVersions() ops.OperatorVersions { + return sqrtVersions +} diff --git a/ops_test.go b/ops_test.go index 07fb258..8d1ece4 100644 --- a/ops_test.go +++ b/ops_test.go @@ -503,6 +503,8 @@ var expectedTests = []string{ "test_softmax_example", "test_softmax_large_number", "test_softmax_negative_axis", + "test_sqrt", + "test_sqrt_example", "test_squeeze", "test_sub", "test_sub_bcast", diff --git a/opset.go b/opset.go index 0bfcdb9..2d35935 100644 --- a/opset.go +++ b/opset.go @@ -50,6 +50,7 @@ import ( "github.com/advancedclimatesystems/gonnx/ops/sinh" "github.com/advancedclimatesystems/gonnx/ops/slice" "github.com/advancedclimatesystems/gonnx/ops/softmax" + "github.com/advancedclimatesystems/gonnx/ops/sqrt" "github.com/advancedclimatesystems/gonnx/ops/squeeze" "github.com/advancedclimatesystems/gonnx/ops/sub" "github.com/advancedclimatesystems/gonnx/ops/tan" @@ -116,6 +117,7 @@ var operators = map[string]ops.OperatorVersions{ "Sinh": sinh.GetSinhVersions(), "Slice": slice.GetSliceVersions(), "Softmax": softmax.GetSoftmaxVersions(), + "Sqrt": sqrt.GetVersions(), "Squeeze": squeeze.GetSqueezeVersions(), "Sub": sub.GetSubVersions(), "Tan": tan.GetTanVersions(),