-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Scalable operator set implementations (#219)
* WIP on dynamic operator sets * WIP on migrating operators * WIP on operator migration * Refactored versions * Finished all operator refactors * Fix lint errors * Add constants file * Remove print statement * Fix lint last ones * Proposal for base operators * POC: new design for multiple operator versions * Rewrote abs operator so it shares code * Rewrote cos into base operator * Refactor acosh into base operator * Refactor xor * Refactored unsqueeze operator * Refactored Transpose operator * Refactored tanh operator * Refactored Tan operator * Refactored input validation tests * Refactored Sub operator * Refactored squeeze operator * Refactored Softmax operator * Refactored slice, sinh and sin operators * Refactored gemm, relu, reshape, rnn, scaler, shape and sigmoid operator * Refactored reducemin operator * Refactored ReduceMax operator * Refactored PRelu operator * Refactor Mul, Not and Or operator * Refactored MatMul operator * Refactored LessOrEqual operator * Refacotred LSTM and LogSoftmax operators * Refactored LinearRegressor operator * Refactored Less operator * Refactor GRU, GreaterOrEqual, Greater, Gather and Concat operators * Refactored Expand and Equal operators * Refactored Div operator * Refactored Cosh operator * Refactored cos operator * Wip on constant of shape * Refactored remaining operators * Changed the way we use NewOperatorConstructor * Small fixes * Reinitialize operators on multiple uses * Fix lint * Fix lstm tests --------- Co-authored-by: wisse <[email protected]>
- Loading branch information
1 parent
7ed9d6f
commit 7c2b171
Showing
229 changed files
with
7,253 additions
and
4,921 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,7 +20,6 @@ linters: | |
- godot | ||
- godox | ||
- goerr113 | ||
- gomnd | ||
- goprintffuncname | ||
- govet | ||
- ineffassign | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
package abs | ||
|
||
import ( | ||
"github.com/advancedclimatesystems/gonnx/onnx" | ||
"github.com/advancedclimatesystems/gonnx/ops" | ||
"gorgonia.org/tensor" | ||
) | ||
|
||
var absTypeConstraint = [][]tensor.Dtype{ | ||
{tensor.Uint8, tensor.Uint16, tensor.Uint32, tensor.Uint64, tensor.Int8, tensor.Int16, tensor.Int32, tensor.Int64, tensor.Float32, tensor.Float64}, | ||
} | ||
|
||
// Abs represents the ONNX abs operator. | ||
type Abs struct { | ||
ops.BaseOperator | ||
} | ||
|
||
// newAbs creates a new abs operator. | ||
func newAbs(version int, typeConstraint [][]tensor.Dtype) ops.Operator { | ||
return &Abs{ | ||
BaseOperator: ops.NewBaseOperator( | ||
version, | ||
1, | ||
1, | ||
typeConstraint, | ||
"abs", | ||
), | ||
} | ||
} | ||
|
||
// Init initializes the abs operator. | ||
func (a *Abs) Init(*onnx.NodeProto) error { | ||
return nil | ||
} | ||
|
||
// Apply applies the abs operator. | ||
func (a *Abs) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { | ||
out, err := tensor.Abs(inputs[0]) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return []tensor.Tensor{out}, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
package abs | ||
|
||
import ( | ||
"github.com/advancedclimatesystems/gonnx/ops" | ||
) | ||
|
||
var absVersions = ops.OperatorVersions{ | ||
6: ops.NewOperatorConstructor(newAbs, 6, absTypeConstraint), | ||
13: ops.NewOperatorConstructor(newAbs, 13, absTypeConstraint), | ||
} | ||
|
||
func GetAbsVersions() ops.OperatorVersions { | ||
return absVersions | ||
} |
Oops, something went wrong.