-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Added CumSum operator * Clean up comments * Rewrite to switch
- Loading branch information
1 parent
c97d70c
commit c33e5d2
Showing
7 changed files
with
326 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
package cumsum | ||
|
||
import ( | ||
"github.com/advancedclimatesystems/gonnx/onnx" | ||
"github.com/advancedclimatesystems/gonnx/ops" | ||
"gorgonia.org/tensor" | ||
) | ||
|
||
var cumsumTypeConstraints = [][]tensor.Dtype{ | ||
{tensor.Int32, tensor.Int64, tensor.Uint32, tensor.Uint64, tensor.Float32, tensor.Float64}, | ||
{tensor.Int32, tensor.Int64}, | ||
} | ||
|
||
// CumSum represents the ONNX cumsum operator. | ||
type CumSum struct { | ||
ops.BaseOperator | ||
|
||
exclusive bool | ||
reverse bool | ||
} | ||
|
||
// newCumSum creates a new cumsum operator. | ||
func newCumSum(version int, typeConstraints [][]tensor.Dtype) ops.Operator { | ||
return &CumSum{ | ||
BaseOperator: ops.NewBaseOperator( | ||
version, | ||
2, | ||
2, | ||
typeConstraints, | ||
"cumsum", | ||
), | ||
exclusive: false, | ||
reverse: false, | ||
} | ||
} | ||
|
||
// Init initializes the cumsum operator. | ||
func (c *CumSum) Init(n *onnx.NodeProto) error { | ||
for _, attr := range n.GetAttribute() { | ||
switch attr.GetName() { | ||
case "exclusive": | ||
c.exclusive = attr.GetI() == 1 | ||
case "reverse": | ||
c.reverse = attr.GetI() == 1 | ||
default: | ||
return ops.ErrInvalidAttribute(attr.GetName(), c) | ||
} | ||
} | ||
|
||
return nil | ||
} | ||
|
||
// Apply applies the cumsum operator. | ||
func (c *CumSum) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { | ||
axis, err := ops.AnyToInt(inputs[1].ScalarValue()) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
out, err := cumsum(inputs[0], axis, c.exclusive, c.reverse) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return []tensor.Tensor{out}, nil | ||
} | ||
|
||
// Performs cumulative sum of the input elements along the given axis. | ||
// Exclusive means the the cumsum for position j will not include the j-th element. | ||
// Reverse means the cumsum will be performed in reverse order. | ||
func cumsum(x tensor.Tensor, axis int, exclusive, reverse bool) (tensor.Tensor, error) { | ||
out, ok := x.Clone().(tensor.Tensor) | ||
if !ok { | ||
return nil, ops.ErrCast | ||
} | ||
|
||
nDims := len(x.Shape()) | ||
axis = ops.ConvertNegativeAxis(axis, nDims) | ||
|
||
if axis < 0 || axis >= nDims { | ||
return nil, ops.ErrAxisOutOfRange(0, nDims, axis) | ||
} | ||
|
||
axisSize := x.Shape()[axis] | ||
|
||
var startValue int | ||
if reverse { | ||
startValue = axisSize - 1 | ||
} else { | ||
startValue = 0 | ||
} | ||
|
||
slices := make([]tensor.Slice, nDims) | ||
slices[axis] = ops.NewSlicer(startValue, startValue+1) | ||
|
||
prevView, err := x.Slice(slices...) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
prevValues := prevView.Materialize() | ||
|
||
for i := startValue; endValueReached(i, axisSize, reverse); { | ||
slices[axis] = ops.NewSlicer(i, i+1) | ||
|
||
currentView, err := out.Slice(slices...) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
currentValues := currentView.Materialize() | ||
|
||
switch { | ||
// If exclusive is true, the first result in the cumsum opertaion is zero. | ||
// We can achieve this by subtracting the current values from the current values. | ||
// This way we don't have to infer the underlying type of the tensor. | ||
case i == startValue && exclusive: | ||
zeroValues, err := ops.Sub(currentValues, currentValues) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
err = tensor.Copy(currentView, zeroValues) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
case i != startValue && exclusive: | ||
err = tensor.Copy(currentView, prevValues) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
newValues, err := ops.Add(currentValues, prevValues) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
prevValues = newValues | ||
case i != startValue: | ||
newValues, err := ops.Add(currentValues, prevValues) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
err = tensor.Copy(currentView, newValues) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
prevValues = newValues | ||
} | ||
|
||
if reverse { | ||
i-- | ||
} else { | ||
i++ | ||
} | ||
} | ||
|
||
return out, nil | ||
} | ||
|
||
func endValueReached(i, axisSize int, reverse bool) bool { | ||
if reverse { | ||
return i >= 0 | ||
} | ||
|
||
return i < axisSize | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
package cumsum | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/advancedclimatesystems/gonnx/onnx" | ||
"github.com/advancedclimatesystems/gonnx/ops" | ||
"github.com/stretchr/testify/assert" | ||
"gorgonia.org/tensor" | ||
) | ||
|
||
func TestCumSumInit(t *testing.T) { | ||
c := &CumSum{} | ||
err := c.Init( | ||
&onnx.NodeProto{ | ||
Attribute: []*onnx.AttributeProto{ | ||
{Name: "exclusive", I: 1}, | ||
{Name: "reverse", I: 1}, | ||
}, | ||
}, | ||
) | ||
|
||
assert.Nil(t, err) | ||
assert.Equal(t, true, c.exclusive) | ||
assert.Equal(t, true, c.reverse) | ||
} | ||
|
||
func TestCumSumInitDefaults(t *testing.T) { | ||
c := &CumSum{} | ||
err := c.Init( | ||
&onnx.NodeProto{ | ||
Attribute: []*onnx.AttributeProto{}, | ||
}, | ||
) | ||
|
||
assert.Nil(t, err) | ||
assert.Equal(t, false, c.exclusive) | ||
assert.Equal(t, false, c.reverse) | ||
} | ||
|
||
func TestCumSum(t *testing.T) { | ||
tests := []struct { | ||
version int64 | ||
node *onnx.NodeProto | ||
backing []float32 | ||
axis int32 | ||
shape []int | ||
expected []float32 | ||
}{ | ||
{ | ||
11, | ||
&onnx.NodeProto{ | ||
Attribute: []*onnx.AttributeProto{ | ||
{Name: "exclusive", I: 0}, | ||
{Name: "reverse", I: 0}, | ||
}, | ||
}, | ||
[]float32{1, 2, 3, 4}, | ||
0, | ||
[]int{2, 2}, | ||
[]float32{1, 2, 4, 6}, | ||
}, | ||
{ | ||
11, | ||
&onnx.NodeProto{ | ||
Attribute: []*onnx.AttributeProto{ | ||
{Name: "exclusive", I: 0}, | ||
{Name: "reverse", I: 0}, | ||
}, | ||
}, | ||
[]float32{1, 2, 3, 4}, | ||
1, | ||
[]int{2, 2}, | ||
[]float32{1, 3, 3, 7}, | ||
}, | ||
{ | ||
11, | ||
&onnx.NodeProto{ | ||
Attribute: []*onnx.AttributeProto{ | ||
{Name: "exclusive", I: 1}, | ||
{Name: "reverse", I: 0}, | ||
}, | ||
}, | ||
[]float32{1, 2, 3}, | ||
0, | ||
[]int{3}, | ||
[]float32{0, 1, 3}, | ||
}, | ||
{ | ||
11, | ||
&onnx.NodeProto{ | ||
Attribute: []*onnx.AttributeProto{ | ||
{Name: "exclusive", I: 0}, | ||
{Name: "reverse", I: 1}, | ||
}, | ||
}, | ||
[]float32{1, 2, 3}, | ||
0, | ||
[]int{3}, | ||
[]float32{6, 5, 3}, | ||
}, | ||
} | ||
|
||
for _, test := range tests { | ||
inputs := []tensor.Tensor{ | ||
ops.TensorWithBackingFixture(test.backing, test.shape...), | ||
tensor.New(tensor.FromScalar(test.axis)), | ||
} | ||
|
||
cumsum := cumsumVersions[test.version]() | ||
err := cumsum.Init(test.node) | ||
assert.Nil(t, err) | ||
|
||
res, err := cumsum.Apply(inputs) | ||
assert.Nil(t, err) | ||
assert.Equal(t, test.expected, res[0].Data()) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
package cumsum | ||
|
||
import ( | ||
"github.com/advancedclimatesystems/gonnx/ops" | ||
) | ||
|
||
var cumsumVersions = ops.OperatorVersions{ | ||
11: ops.NewOperatorConstructor(newCumSum, 11, cumsumTypeConstraints), | ||
} | ||
|
||
func GetVersions() ops.OperatorVersions { | ||
return cumsumVersions | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters