Skip to content

Commit

Permalink
Added CumSum operator (#225)
Browse files Browse the repository at this point in the history
* Added CumSum operator

* Clean up comments

* Rewrite to switch
  • Loading branch information
Swopper050 authored Dec 22, 2024
1 parent c97d70c commit c33e5d2
Show file tree
Hide file tree
Showing 7 changed files with 326 additions and 2 deletions.
2 changes: 0 additions & 2 deletions model_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package gonnx

import (
"fmt"
"testing"

"github.com/advancedclimatesystems/gonnx/onnx"
Expand Down Expand Up @@ -97,7 +96,6 @@ func TestModel(t *testing.T) {
}

for _, test := range tests {
fmt.Println(test.path)
model, err := NewModelFromFile(test.path)
assert.Nil(t, err)

Expand Down
170 changes: 170 additions & 0 deletions ops/cumsum/cumsum.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
package cumsum

import (
"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

var cumsumTypeConstraints = [][]tensor.Dtype{
{tensor.Int32, tensor.Int64, tensor.Uint32, tensor.Uint64, tensor.Float32, tensor.Float64},
{tensor.Int32, tensor.Int64},
}

// CumSum represents the ONNX cumsum operator.
type CumSum struct {
ops.BaseOperator

exclusive bool
reverse bool
}

// newCumSum creates a new cumsum operator.
func newCumSum(version int, typeConstraints [][]tensor.Dtype) ops.Operator {
return &CumSum{
BaseOperator: ops.NewBaseOperator(
version,
2,
2,
typeConstraints,
"cumsum",
),
exclusive: false,
reverse: false,
}
}

// Init initializes the cumsum operator.
func (c *CumSum) Init(n *onnx.NodeProto) error {
for _, attr := range n.GetAttribute() {
switch attr.GetName() {
case "exclusive":
c.exclusive = attr.GetI() == 1
case "reverse":
c.reverse = attr.GetI() == 1
default:
return ops.ErrInvalidAttribute(attr.GetName(), c)
}
}

return nil
}

// Apply applies the cumsum operator.
func (c *CumSum) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
axis, err := ops.AnyToInt(inputs[1].ScalarValue())
if err != nil {
return nil, err
}

out, err := cumsum(inputs[0], axis, c.exclusive, c.reverse)
if err != nil {
return nil, err
}

return []tensor.Tensor{out}, nil
}

// Performs cumulative sum of the input elements along the given axis.
// Exclusive means the the cumsum for position j will not include the j-th element.
// Reverse means the cumsum will be performed in reverse order.
func cumsum(x tensor.Tensor, axis int, exclusive, reverse bool) (tensor.Tensor, error) {
out, ok := x.Clone().(tensor.Tensor)
if !ok {
return nil, ops.ErrCast
}

nDims := len(x.Shape())
axis = ops.ConvertNegativeAxis(axis, nDims)

if axis < 0 || axis >= nDims {
return nil, ops.ErrAxisOutOfRange(0, nDims, axis)
}

axisSize := x.Shape()[axis]

var startValue int
if reverse {
startValue = axisSize - 1
} else {
startValue = 0
}

slices := make([]tensor.Slice, nDims)
slices[axis] = ops.NewSlicer(startValue, startValue+1)

prevView, err := x.Slice(slices...)
if err != nil {
return nil, err
}

prevValues := prevView.Materialize()

for i := startValue; endValueReached(i, axisSize, reverse); {
slices[axis] = ops.NewSlicer(i, i+1)

currentView, err := out.Slice(slices...)
if err != nil {
return nil, err
}

currentValues := currentView.Materialize()

switch {
// If exclusive is true, the first result in the cumsum opertaion is zero.
// We can achieve this by subtracting the current values from the current values.
// This way we don't have to infer the underlying type of the tensor.
case i == startValue && exclusive:
zeroValues, err := ops.Sub(currentValues, currentValues)
if err != nil {
return nil, err
}

err = tensor.Copy(currentView, zeroValues)
if err != nil {
return nil, err
}

case i != startValue && exclusive:
err = tensor.Copy(currentView, prevValues)
if err != nil {
return nil, err
}

newValues, err := ops.Add(currentValues, prevValues)
if err != nil {
return nil, err
}

prevValues = newValues
case i != startValue:
newValues, err := ops.Add(currentValues, prevValues)
if err != nil {
return nil, err
}

err = tensor.Copy(currentView, newValues)
if err != nil {
return nil, err
}

prevValues = newValues
}

if reverse {
i--
} else {
i++
}
}

return out, nil
}

func endValueReached(i, axisSize int, reverse bool) bool {
if reverse {
return i >= 0
}

return i < axisSize
}
118 changes: 118 additions & 0 deletions ops/cumsum/cumsum_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package cumsum

import (
"testing"

"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"github.com/stretchr/testify/assert"
"gorgonia.org/tensor"
)

func TestCumSumInit(t *testing.T) {
c := &CumSum{}
err := c.Init(
&onnx.NodeProto{
Attribute: []*onnx.AttributeProto{
{Name: "exclusive", I: 1},
{Name: "reverse", I: 1},
},
},
)

assert.Nil(t, err)
assert.Equal(t, true, c.exclusive)
assert.Equal(t, true, c.reverse)
}

func TestCumSumInitDefaults(t *testing.T) {
c := &CumSum{}
err := c.Init(
&onnx.NodeProto{
Attribute: []*onnx.AttributeProto{},
},
)

assert.Nil(t, err)
assert.Equal(t, false, c.exclusive)
assert.Equal(t, false, c.reverse)
}

func TestCumSum(t *testing.T) {
tests := []struct {
version int64
node *onnx.NodeProto
backing []float32
axis int32
shape []int
expected []float32
}{
{
11,
&onnx.NodeProto{
Attribute: []*onnx.AttributeProto{
{Name: "exclusive", I: 0},
{Name: "reverse", I: 0},
},
},
[]float32{1, 2, 3, 4},
0,
[]int{2, 2},
[]float32{1, 2, 4, 6},
},
{
11,
&onnx.NodeProto{
Attribute: []*onnx.AttributeProto{
{Name: "exclusive", I: 0},
{Name: "reverse", I: 0},
},
},
[]float32{1, 2, 3, 4},
1,
[]int{2, 2},
[]float32{1, 3, 3, 7},
},
{
11,
&onnx.NodeProto{
Attribute: []*onnx.AttributeProto{
{Name: "exclusive", I: 1},
{Name: "reverse", I: 0},
},
},
[]float32{1, 2, 3},
0,
[]int{3},
[]float32{0, 1, 3},
},
{
11,
&onnx.NodeProto{
Attribute: []*onnx.AttributeProto{
{Name: "exclusive", I: 0},
{Name: "reverse", I: 1},
},
},
[]float32{1, 2, 3},
0,
[]int{3},
[]float32{6, 5, 3},
},
}

for _, test := range tests {
inputs := []tensor.Tensor{
ops.TensorWithBackingFixture(test.backing, test.shape...),
tensor.New(tensor.FromScalar(test.axis)),
}

cumsum := cumsumVersions[test.version]()
err := cumsum.Init(test.node)
assert.Nil(t, err)

res, err := cumsum.Apply(inputs)
assert.Nil(t, err)
assert.Equal(t, test.expected, res[0].Data())
}
}
13 changes: 13 additions & 0 deletions ops/cumsum/versions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package cumsum

import (
"github.com/advancedclimatesystems/gonnx/ops"
)

var cumsumVersions = ops.OperatorVersions{
11: ops.NewOperatorConstructor(newCumSum, 11, cumsumTypeConstraints),
}

func GetVersions() ops.OperatorVersions {
return cumsumVersions
}
16 changes: 16 additions & 0 deletions ops/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,22 @@ func OffsetTensorIfNegative(t tensor.Tensor, offset int) error {
return nil
}

// AnyToInt casts the given data to an int, but only if the data is of some sort of int type.
func AnyToInt(value interface{}) (int, error) {
switch data := value.(type) {
case int8:
return int(data), nil
case int16:
return int(data), nil
case int32:
return int(data), nil
case int64:
return int(data), nil
default:
return 0, ErrCast
}
}

// AnyToIntSlice casts the data of a node to an int list. This will only
// be done if the data is of some sort of int type.
func AnyToIntSlice(value interface{}) ([]int, error) {
Expand Down
7 changes: 7 additions & 0 deletions ops_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,13 @@ var expectedTests = []string{
"test_cos_example",
"test_cosh",
"test_cosh_example",
"test_cumsum_1d",
"test_cumsum_1d_exclusive",
"test_cumsum_1d_reverse",
"test_cumsum_1d_reverse_exclusive",
"test_cumsum_2d_axis_0",
"test_cumsum_2d_axis_1",
"test_cumsum_2d_negative_axis",
"test_div",
"test_div_bcast",
"test_div_example",
Expand Down
2 changes: 2 additions & 0 deletions opset.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/advancedclimatesystems/gonnx/ops/conv"
"github.com/advancedclimatesystems/gonnx/ops/cos"
"github.com/advancedclimatesystems/gonnx/ops/cosh"
"github.com/advancedclimatesystems/gonnx/ops/cumsum"
"github.com/advancedclimatesystems/gonnx/ops/div"
"github.com/advancedclimatesystems/gonnx/ops/equal"
"github.com/advancedclimatesystems/gonnx/ops/erf"
Expand Down Expand Up @@ -89,6 +90,7 @@ var operators = map[string]ops.OperatorVersions{
"Conv": conv.GetConvVersions(),
"Cos": cos.GetCosVersions(),
"Cosh": cosh.GetCoshVersions(),
"CumSum": cumsum.GetVersions(),
"Div": div.GetDivVersions(),
"Equal": equal.GetEqualVersions(),
"Erf": erf.GetVersions(),
Expand Down

0 comments on commit c33e5d2

Please sign in to comment.