Skip to content

Commit

Permalink
Add propeller type check and resolver
Browse files Browse the repository at this point in the history
Signed-off-by: byhsu <[email protected]>
  • Loading branch information
ByronHsu committed Oct 15, 2023
1 parent 5cebcdd commit 849e7c8
Show file tree
Hide file tree
Showing 8 changed files with 687 additions and 5 deletions.
11 changes: 11 additions & 0 deletions flytepropeller/pkg/compiler/errors/compiler_errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ const (

// A gate node is missing a condition.
NoConditionFound ErrorCode = "NoConditionFound"

// Field not found in the dataclass
FieldNotFoundError ErrorCode = "FieldNotFound"
)

func NewBranchNodeNotSpecified(branchNodeID string) *CompileError {
Expand Down Expand Up @@ -216,6 +219,14 @@ func NewMismatchingBindingsErr(nodeID, sinkParam, expectedType, receivedType str
)
}

func NewFieldNotFoundErr(nodeID, fromVar, fromType, key string) *CompileError {
return newError(
FieldNotFoundError,
fmt.Sprintf("Variable [%v] (type [%v]) doesn't have field [%v].", fromVar, fromType, key),
nodeID,
)

Check warning on line 227 in flytepropeller/pkg/compiler/errors/compiler_errors.go

View check run for this annotation

Codecov / codecov/patch

flytepropeller/pkg/compiler/errors/compiler_errors.go#L222-L227

Added lines #L222 - L227 were not covered by tests
}

func NewIllegalEnumValueError(nodeID, sinkParam, receivedVal string, expectedVals []string) *CompileError {
return newError(
IllegalEnumValue,
Expand Down
21 changes: 21 additions & 0 deletions flytepropeller/pkg/compiler/validators/bindings.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,27 @@ func validateBinding(w c.WorkflowBuilder, nodeID c.NodeID, nodeParam string, bin
sourceType = cType
}
}
var exist bool
var tmpType *flyte.LiteralType

// If the variable has an attribute path. Extract the type of the last attribute.
for _, attr := range val.Promise.AttrPath {
if sourceType.GetCollectionType() != nil {
sourceType = sourceType.GetCollectionType()
} else if sourceType.GetMapValueType() != nil {
sourceType = sourceType.GetMapValueType()
} else if sourceType.GetStructure() != nil && sourceType.GetStructure().GetDataclassType() != nil {

tmpType, exist = sourceType.GetStructure().GetDataclassType()[attr.GetStringValue()]

if !exist {
errs.Collect(errors.NewFieldNotFoundErr(nodeID, val.Promise.Var, sourceType.String(), attr.GetStringValue()))
return nil, nil, !errs.HasErrors()

Check warning on line 142 in flytepropeller/pkg/compiler/validators/bindings.go

View check run for this annotation

Codecov / codecov/patch

flytepropeller/pkg/compiler/validators/bindings.go#L141-L142

Added lines #L141 - L142 were not covered by tests
} else {
sourceType = tmpType
}
}
}

if !validateParamTypes || AreTypesCastable(sourceType, expectedType) {
val.Promise.NodeId = upNode.GetId()
Expand Down
149 changes: 149 additions & 0 deletions flytepropeller/pkg/compiler/validators/bindings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"testing"

c "github.com/flyteorg/flyte/flytepropeller/pkg/compiler/common"
structpb "github.com/golang/protobuf/ptypes/struct"

"github.com/flyteorg/flyte/flyteidl/clients/go/coreutils"
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
Expand Down Expand Up @@ -305,6 +306,154 @@ func TestValidateBindings(t *testing.T) {
}
})

t.Run("List/Dict Promises with attribute path", func(t *testing.T) {
// List/Dict with attribute path should conduct validation

n := &mocks.NodeBuilder{}
n.OnGetId().Return("node1")
n.OnGetInterface().Return(&core.TypedInterface{
Inputs: &core.VariableMap{
Variables: map[string]*core.Variable{},
},
Outputs: &core.VariableMap{
Variables: map[string]*core.Variable{},
},
})

n2 := &mocks.NodeBuilder{}
n2.OnGetId().Return("node2")
n2.OnGetOutputAliases().Return(nil)
n2.OnGetInterface().Return(&core.TypedInterface{
Inputs: &core.VariableMap{
Variables: map[string]*core.Variable{},
},
Outputs: &core.VariableMap{
Variables: map[string]*core.Variable{
"n2_out": {
Type: LiteralTypeForLiteral(coreutils.MustMakeLiteral(map[string]interface{}{"x": []interface{}{1, 3, 4}})),
},
},
},
})

wf := &mocks.WorkflowBuilder{}
wf.OnGetNode("n2").Return(n2, true)
wf.On("AddExecutionEdge", mock.Anything, mock.Anything).Return(nil)

bindings := []*core.Binding{
{
Var: "x",
Binding: &core.BindingData{
Value: &core.BindingData_Promise{
Promise: &core.OutputReference{
Var: "n2_out",
NodeId: "n2",
AttrPath: []*core.PromiseAttribute{
{
Value: &core.PromiseAttribute_StringValue{"x"},
},
{
Value: &core.PromiseAttribute_IntValue{0},
},
},
},
},
},
},
}

vars := &core.VariableMap{
Variables: map[string]*core.Variable{
"x": {
Type: LiteralTypeForLiteral(coreutils.MustMakeLiteral(1)),
},
},
}

compileErrors := compilerErrors.NewCompileErrors()
_, ok := ValidateBindings(wf, n, bindings, vars, true, c.EdgeDirectionBidirectional, compileErrors)
assert.True(t, ok)
if compileErrors.HasErrors() {
assert.NoError(t, compileErrors)
}
})

t.Run("pb.Struct Promises with attribute path", func(t *testing.T) {
// Dataclass with attribute path should skip validation

n := &mocks.NodeBuilder{}
n.OnGetId().Return("node1")
n.OnGetInterface().Return(&core.TypedInterface{
Inputs: &core.VariableMap{
Variables: map[string]*core.Variable{},
},
Outputs: &core.VariableMap{
Variables: map[string]*core.Variable{},
},
})

n2 := &mocks.NodeBuilder{}
n2.OnGetId().Return("node2")
n2.OnGetOutputAliases().Return(nil)
literalType := LiteralTypeForLiteral(coreutils.MustMakeLiteral(&structpb.Struct{}))
literalType.Structure = &core.TypeStructure{}
literalType.Structure.DataclassType = map[string]*core.LiteralType{"x": LiteralTypeForLiteral(coreutils.MustMakeLiteral(1))}

n2.OnGetInterface().Return(&core.TypedInterface{
Inputs: &core.VariableMap{
Variables: map[string]*core.Variable{},
},
Outputs: &core.VariableMap{
Variables: map[string]*core.Variable{
"n2_out": {
Type: literalType,
},
},
},
})

wf := &mocks.WorkflowBuilder{}
wf.OnGetNode("n2").Return(n2, true)
wf.On("AddExecutionEdge", mock.Anything, mock.Anything).Return(nil)

bindings := []*core.Binding{
{
Var: "x",
Binding: &core.BindingData{
Value: &core.BindingData_Promise{
Promise: &core.OutputReference{
Var: "n2_out",
NodeId: "n2",
AttrPath: []*core.PromiseAttribute{
{
Value: &core.PromiseAttribute_StringValue{"x"},
},
{
Value: &core.PromiseAttribute_IntValue{0},
},
},
},
},
},
},
}

vars := &core.VariableMap{
Variables: map[string]*core.Variable{
"x": {
Type: LiteralTypeForLiteral(coreutils.MustMakeLiteral(1)),
},
},
}

compileErrors := compilerErrors.NewCompileErrors()
_, ok := ValidateBindings(wf, n, bindings, vars, true, c.EdgeDirectionBidirectional, compileErrors)
assert.True(t, ok)
if compileErrors.HasErrors() {
assert.NoError(t, compileErrors)
}
})

t.Run("Nil Binding Value", func(t *testing.T) {
n := &mocks.NodeBuilder{}
n.OnGetId().Return("node1")
Expand Down
156 changes: 156 additions & 0 deletions flytepropeller/pkg/controller/nodes/attr_path_resolver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
package nodes

import (
"context"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
"github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/errors"
"google.golang.org/protobuf/types/known/structpb"
)

// resolveAttrPathInPromise resolves the literal with attribute path
// If the promise is chained with attributes (e.g. promise.a["b"][0]), then we need to resolve the promise
func resolveAttrPathInPromise(ctx context.Context, nodeID string, literal *core.Literal, bindAttrPath []*core.PromiseAttribute) (*core.Literal, error) {
var currVal *core.Literal = literal
var tmpVal *core.Literal
var err error
var exist bool
count := 0

for _, attr := range bindAttrPath {
switch currVal.GetValue().(type) {
case *core.Literal_Map:
tmpVal, exist = currVal.GetMap().GetLiterals()[attr.GetStringValue()]
if exist == false {
return nil, errors.Errorf(errors.PromiseAttributeResolveError, nodeID, "key [%v] does not exist in literal %v", attr.GetStringValue(), currVal.GetMap().GetLiterals())
}
currVal = tmpVal
count += 1
case *core.Literal_Collection:
if int(attr.GetIntValue()) >= len(currVal.GetCollection().GetLiterals()) {
return nil, errors.Errorf(errors.PromiseAttributeResolveError, nodeID, "index [%v] is out of range of %v", attr.GetIntValue(), currVal.GetCollection().GetLiterals())
}
currVal = currVal.GetCollection().GetLiterals()[attr.GetIntValue()]
count += 1
// scalar is always the leaf, so we can break here
case *core.Literal_Scalar:
break
}
}

// resolve dataclass
if currVal.GetScalar() != nil && currVal.GetScalar().GetGeneric() != nil {
st := currVal.GetScalar().GetGeneric()
// start from index "count"
currVal, err = resolveAttrPathInPbStruct(ctx, nodeID, st, bindAttrPath[count:])
if err != nil {
return nil, err
}
}

return currVal, nil
}

// resolveAttrPathInPbStruct resolves the protobuf struct (e.g. dataclass) with attribute path
func resolveAttrPathInPbStruct(ctx context.Context, nodeID string, st *structpb.Struct, bindAttrPath []*core.PromiseAttribute) (*core.Literal, error) {

var currVal interface{}
var tmpVal interface{}
var exist bool

currVal = st.AsMap()

// Turn the current value to a map so it can be resolved more easily
for _, attr := range bindAttrPath {
switch currVal.(type) {
// map
case map[string]interface{}:
tmpVal, exist = currVal.(map[string]interface{})[attr.GetStringValue()]
if exist == false {
return nil, errors.Errorf(errors.PromiseAttributeResolveError, nodeID, "key [%v] does not exist in literal %v", attr.GetStringValue(), currVal)
}
currVal = tmpVal
// list
case []interface{}:
if int(attr.GetIntValue()) >= len(currVal.([]interface{})) {
return nil, errors.Errorf(errors.PromiseAttributeResolveError, nodeID, "index [%v] is out of range of %v", attr.GetIntValue(), currVal)
}
currVal = currVal.([]interface{})[attr.GetIntValue()]
}
}

// After resolve, convert the interface to literal
literal, err := convertInterfaceToLiteral(ctx, nodeID, currVal)

return literal, err
}

// convertInterfaceToLiteral converts the protobuf struct (e.g. dataclass) to literal
func convertInterfaceToLiteral(ctx context.Context, nodeID string, obj interface{}) (*core.Literal, error) {

literal := &core.Literal{}

switch obj.(type) {
case map[string]interface{}:
new_st, err := structpb.NewStruct(obj.(map[string]interface{}))
if err != nil {
return nil, err
}
literal.Value = &core.Literal_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Generic{
Generic: new_st,
},
},
}

Check warning on line 105 in flytepropeller/pkg/controller/nodes/attr_path_resolver.go

View check run for this annotation

Codecov / codecov/patch

flytepropeller/pkg/controller/nodes/attr_path_resolver.go#L94-L105

Added lines #L94 - L105 were not covered by tests
case []interface{}:
literals := []*core.Literal{}
for _, v := range obj.([]interface{}) {
// recursively convert the interface to literal
literal, err := convertInterfaceToLiteral(ctx, nodeID, v)
if err != nil {
return nil, err
}

Check warning on line 113 in flytepropeller/pkg/controller/nodes/attr_path_resolver.go

View check run for this annotation

Codecov / codecov/patch

flytepropeller/pkg/controller/nodes/attr_path_resolver.go#L112-L113

Added lines #L112 - L113 were not covered by tests
literals = append(literals, literal)
}
literal.Value = &core.Literal_Collection{
Collection: &core.LiteralCollection{
Literals: literals,
},
}
case interface{}:
scalar, err := convertInterfaceToLiteralScalar(ctx, nodeID, obj)
if err != nil {
return nil, err
}

Check warning on line 125 in flytepropeller/pkg/controller/nodes/attr_path_resolver.go

View check run for this annotation

Codecov / codecov/patch

flytepropeller/pkg/controller/nodes/attr_path_resolver.go#L124-L125

Added lines #L124 - L125 were not covered by tests
literal.Value = scalar
}

return literal, nil
}

// convertInterfaceToLiteralScalar converts the a single value to a literal scalar
func convertInterfaceToLiteralScalar(ctx context.Context, nodeID string, obj interface{}) (*core.Literal_Scalar, error) {
value := &core.Primitive{}

switch obj.(type) {
case string:
value.Value = &core.Primitive_StringValue{StringValue: obj.(string)}
case int:
value.Value = &core.Primitive_Integer{Integer: int64(obj.(int))}
case float64:
value.Value = &core.Primitive_FloatValue{FloatValue: obj.(float64)}
case bool:
value.Value = &core.Primitive_Boolean{Boolean: obj.(bool)}
default:
return nil, errors.Errorf(errors.PromiseAttributeResolveError, nodeID, "Failed to resolve interface to literal scalar")

Check warning on line 146 in flytepropeller/pkg/controller/nodes/attr_path_resolver.go

View check run for this annotation

Codecov / codecov/patch

flytepropeller/pkg/controller/nodes/attr_path_resolver.go#L139-L146

Added lines #L139 - L146 were not covered by tests
}

return &core.Literal_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Primitive{
Primitive: value,
},
},
}, nil
}
Loading

0 comments on commit 849e7c8

Please sign in to comment.