From 849e7c8605d62d1675019dd62deff4d411125604 Mon Sep 17 00:00:00 2001 From: byhsu Date: Sat, 14 Oct 2023 17:37:47 -0700 Subject: [PATCH] Add propeller type check and resolver Signed-off-by: byhsu --- .../pkg/compiler/errors/compiler_errors.go | 11 + .../pkg/compiler/validators/bindings.go | 21 ++ .../pkg/compiler/validators/bindings_test.go | 149 ++++++++ .../controller/nodes/attr_path_resolver.go | 156 +++++++++ .../nodes/attr_path_resolver_test.go | 329 ++++++++++++++++++ .../pkg/controller/nodes/errors/codes.go | 1 + .../pkg/controller/nodes/output_resolver.go | 22 +- .../pkg/controller/nodes/resolve.go | 3 +- 8 files changed, 687 insertions(+), 5 deletions(-) create mode 100644 flytepropeller/pkg/controller/nodes/attr_path_resolver.go create mode 100644 flytepropeller/pkg/controller/nodes/attr_path_resolver_test.go diff --git a/flytepropeller/pkg/compiler/errors/compiler_errors.go b/flytepropeller/pkg/compiler/errors/compiler_errors.go index f3fd02f96f2..2e0e15367a2 100755 --- a/flytepropeller/pkg/compiler/errors/compiler_errors.go +++ b/flytepropeller/pkg/compiler/errors/compiler_errors.go @@ -93,6 +93,9 @@ const ( // A gate node is missing a condition. NoConditionFound ErrorCode = "NoConditionFound" + + // Field not found in the dataclass + FieldNotFoundError ErrorCode = "FieldNotFound" ) func NewBranchNodeNotSpecified(branchNodeID string) *CompileError { @@ -216,6 +219,14 @@ func NewMismatchingBindingsErr(nodeID, sinkParam, expectedType, receivedType str ) } +func NewFieldNotFoundErr(nodeID, fromVar, fromType, key string) *CompileError { + return newError( + FieldNotFoundError, + fmt.Sprintf("Variable [%v] (type [%v]) doesn't have field [%v].", fromVar, fromType, key), + nodeID, + ) +} + func NewIllegalEnumValueError(nodeID, sinkParam, receivedVal string, expectedVals []string) *CompileError { return newError( IllegalEnumValue, diff --git a/flytepropeller/pkg/compiler/validators/bindings.go b/flytepropeller/pkg/compiler/validators/bindings.go index 7c6f0e60e6c..ab5986ae44c 100644 --- a/flytepropeller/pkg/compiler/validators/bindings.go +++ b/flytepropeller/pkg/compiler/validators/bindings.go @@ -124,6 +124,27 @@ func validateBinding(w c.WorkflowBuilder, nodeID c.NodeID, nodeParam string, bin sourceType = cType } } + var exist bool + var tmpType *flyte.LiteralType + + // If the variable has an attribute path. Extract the type of the last attribute. + for _, attr := range val.Promise.AttrPath { + if sourceType.GetCollectionType() != nil { + sourceType = sourceType.GetCollectionType() + } else if sourceType.GetMapValueType() != nil { + sourceType = sourceType.GetMapValueType() + } else if sourceType.GetStructure() != nil && sourceType.GetStructure().GetDataclassType() != nil { + + tmpType, exist = sourceType.GetStructure().GetDataclassType()[attr.GetStringValue()] + + if !exist { + errs.Collect(errors.NewFieldNotFoundErr(nodeID, val.Promise.Var, sourceType.String(), attr.GetStringValue())) + return nil, nil, !errs.HasErrors() + } else { + sourceType = tmpType + } + } + } if !validateParamTypes || AreTypesCastable(sourceType, expectedType) { val.Promise.NodeId = upNode.GetId() diff --git a/flytepropeller/pkg/compiler/validators/bindings_test.go b/flytepropeller/pkg/compiler/validators/bindings_test.go index 9d90e1c843c..0fa56e4a2ec 100644 --- a/flytepropeller/pkg/compiler/validators/bindings_test.go +++ b/flytepropeller/pkg/compiler/validators/bindings_test.go @@ -4,6 +4,7 @@ import ( "testing" c "github.com/flyteorg/flyte/flytepropeller/pkg/compiler/common" + structpb "github.com/golang/protobuf/ptypes/struct" "github.com/flyteorg/flyte/flyteidl/clients/go/coreutils" "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" @@ -305,6 +306,154 @@ func TestValidateBindings(t *testing.T) { } }) + t.Run("List/Dict Promises with attribute path", func(t *testing.T) { + // List/Dict with attribute path should conduct validation + + n := &mocks.NodeBuilder{} + n.OnGetId().Return("node1") + n.OnGetInterface().Return(&core.TypedInterface{ + Inputs: &core.VariableMap{ + Variables: map[string]*core.Variable{}, + }, + Outputs: &core.VariableMap{ + Variables: map[string]*core.Variable{}, + }, + }) + + n2 := &mocks.NodeBuilder{} + n2.OnGetId().Return("node2") + n2.OnGetOutputAliases().Return(nil) + n2.OnGetInterface().Return(&core.TypedInterface{ + Inputs: &core.VariableMap{ + Variables: map[string]*core.Variable{}, + }, + Outputs: &core.VariableMap{ + Variables: map[string]*core.Variable{ + "n2_out": { + Type: LiteralTypeForLiteral(coreutils.MustMakeLiteral(map[string]interface{}{"x": []interface{}{1, 3, 4}})), + }, + }, + }, + }) + + wf := &mocks.WorkflowBuilder{} + wf.OnGetNode("n2").Return(n2, true) + wf.On("AddExecutionEdge", mock.Anything, mock.Anything).Return(nil) + + bindings := []*core.Binding{ + { + Var: "x", + Binding: &core.BindingData{ + Value: &core.BindingData_Promise{ + Promise: &core.OutputReference{ + Var: "n2_out", + NodeId: "n2", + AttrPath: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{"x"}, + }, + { + Value: &core.PromiseAttribute_IntValue{0}, + }, + }, + }, + }, + }, + }, + } + + vars := &core.VariableMap{ + Variables: map[string]*core.Variable{ + "x": { + Type: LiteralTypeForLiteral(coreutils.MustMakeLiteral(1)), + }, + }, + } + + compileErrors := compilerErrors.NewCompileErrors() + _, ok := ValidateBindings(wf, n, bindings, vars, true, c.EdgeDirectionBidirectional, compileErrors) + assert.True(t, ok) + if compileErrors.HasErrors() { + assert.NoError(t, compileErrors) + } + }) + + t.Run("pb.Struct Promises with attribute path", func(t *testing.T) { + // Dataclass with attribute path should skip validation + + n := &mocks.NodeBuilder{} + n.OnGetId().Return("node1") + n.OnGetInterface().Return(&core.TypedInterface{ + Inputs: &core.VariableMap{ + Variables: map[string]*core.Variable{}, + }, + Outputs: &core.VariableMap{ + Variables: map[string]*core.Variable{}, + }, + }) + + n2 := &mocks.NodeBuilder{} + n2.OnGetId().Return("node2") + n2.OnGetOutputAliases().Return(nil) + literalType := LiteralTypeForLiteral(coreutils.MustMakeLiteral(&structpb.Struct{})) + literalType.Structure = &core.TypeStructure{} + literalType.Structure.DataclassType = map[string]*core.LiteralType{"x": LiteralTypeForLiteral(coreutils.MustMakeLiteral(1))} + + n2.OnGetInterface().Return(&core.TypedInterface{ + Inputs: &core.VariableMap{ + Variables: map[string]*core.Variable{}, + }, + Outputs: &core.VariableMap{ + Variables: map[string]*core.Variable{ + "n2_out": { + Type: literalType, + }, + }, + }, + }) + + wf := &mocks.WorkflowBuilder{} + wf.OnGetNode("n2").Return(n2, true) + wf.On("AddExecutionEdge", mock.Anything, mock.Anything).Return(nil) + + bindings := []*core.Binding{ + { + Var: "x", + Binding: &core.BindingData{ + Value: &core.BindingData_Promise{ + Promise: &core.OutputReference{ + Var: "n2_out", + NodeId: "n2", + AttrPath: []*core.PromiseAttribute{ + { + Value: &core.PromiseAttribute_StringValue{"x"}, + }, + { + Value: &core.PromiseAttribute_IntValue{0}, + }, + }, + }, + }, + }, + }, + } + + vars := &core.VariableMap{ + Variables: map[string]*core.Variable{ + "x": { + Type: LiteralTypeForLiteral(coreutils.MustMakeLiteral(1)), + }, + }, + } + + compileErrors := compilerErrors.NewCompileErrors() + _, ok := ValidateBindings(wf, n, bindings, vars, true, c.EdgeDirectionBidirectional, compileErrors) + assert.True(t, ok) + if compileErrors.HasErrors() { + assert.NoError(t, compileErrors) + } + }) + t.Run("Nil Binding Value", func(t *testing.T) { n := &mocks.NodeBuilder{} n.OnGetId().Return("node1") diff --git a/flytepropeller/pkg/controller/nodes/attr_path_resolver.go b/flytepropeller/pkg/controller/nodes/attr_path_resolver.go new file mode 100644 index 00000000000..7d60f54d893 --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/attr_path_resolver.go @@ -0,0 +1,156 @@ +package nodes + +import ( + "context" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/errors" + "google.golang.org/protobuf/types/known/structpb" +) + +// resolveAttrPathInPromise resolves the literal with attribute path +// If the promise is chained with attributes (e.g. promise.a["b"][0]), then we need to resolve the promise +func resolveAttrPathInPromise(ctx context.Context, nodeID string, literal *core.Literal, bindAttrPath []*core.PromiseAttribute) (*core.Literal, error) { + var currVal *core.Literal = literal + var tmpVal *core.Literal + var err error + var exist bool + count := 0 + + for _, attr := range bindAttrPath { + switch currVal.GetValue().(type) { + case *core.Literal_Map: + tmpVal, exist = currVal.GetMap().GetLiterals()[attr.GetStringValue()] + if exist == false { + return nil, errors.Errorf(errors.PromiseAttributeResolveError, nodeID, "key [%v] does not exist in literal %v", attr.GetStringValue(), currVal.GetMap().GetLiterals()) + } + currVal = tmpVal + count += 1 + case *core.Literal_Collection: + if int(attr.GetIntValue()) >= len(currVal.GetCollection().GetLiterals()) { + return nil, errors.Errorf(errors.PromiseAttributeResolveError, nodeID, "index [%v] is out of range of %v", attr.GetIntValue(), currVal.GetCollection().GetLiterals()) + } + currVal = currVal.GetCollection().GetLiterals()[attr.GetIntValue()] + count += 1 + // scalar is always the leaf, so we can break here + case *core.Literal_Scalar: + break + } + } + + // resolve dataclass + if currVal.GetScalar() != nil && currVal.GetScalar().GetGeneric() != nil { + st := currVal.GetScalar().GetGeneric() + // start from index "count" + currVal, err = resolveAttrPathInPbStruct(ctx, nodeID, st, bindAttrPath[count:]) + if err != nil { + return nil, err + } + } + + return currVal, nil +} + +// resolveAttrPathInPbStruct resolves the protobuf struct (e.g. dataclass) with attribute path +func resolveAttrPathInPbStruct(ctx context.Context, nodeID string, st *structpb.Struct, bindAttrPath []*core.PromiseAttribute) (*core.Literal, error) { + + var currVal interface{} + var tmpVal interface{} + var exist bool + + currVal = st.AsMap() + + // Turn the current value to a map so it can be resolved more easily + for _, attr := range bindAttrPath { + switch currVal.(type) { + // map + case map[string]interface{}: + tmpVal, exist = currVal.(map[string]interface{})[attr.GetStringValue()] + if exist == false { + return nil, errors.Errorf(errors.PromiseAttributeResolveError, nodeID, "key [%v] does not exist in literal %v", attr.GetStringValue(), currVal) + } + currVal = tmpVal + // list + case []interface{}: + if int(attr.GetIntValue()) >= len(currVal.([]interface{})) { + return nil, errors.Errorf(errors.PromiseAttributeResolveError, nodeID, "index [%v] is out of range of %v", attr.GetIntValue(), currVal) + } + currVal = currVal.([]interface{})[attr.GetIntValue()] + } + } + + // After resolve, convert the interface to literal + literal, err := convertInterfaceToLiteral(ctx, nodeID, currVal) + + return literal, err +} + +// convertInterfaceToLiteral converts the protobuf struct (e.g. dataclass) to literal +func convertInterfaceToLiteral(ctx context.Context, nodeID string, obj interface{}) (*core.Literal, error) { + + literal := &core.Literal{} + + switch obj.(type) { + case map[string]interface{}: + new_st, err := structpb.NewStruct(obj.(map[string]interface{})) + if err != nil { + return nil, err + } + literal.Value = &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Generic{ + Generic: new_st, + }, + }, + } + case []interface{}: + literals := []*core.Literal{} + for _, v := range obj.([]interface{}) { + // recursively convert the interface to literal + literal, err := convertInterfaceToLiteral(ctx, nodeID, v) + if err != nil { + return nil, err + } + literals = append(literals, literal) + } + literal.Value = &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: literals, + }, + } + case interface{}: + scalar, err := convertInterfaceToLiteralScalar(ctx, nodeID, obj) + if err != nil { + return nil, err + } + literal.Value = scalar + } + + return literal, nil +} + +// convertInterfaceToLiteralScalar converts the a single value to a literal scalar +func convertInterfaceToLiteralScalar(ctx context.Context, nodeID string, obj interface{}) (*core.Literal_Scalar, error) { + value := &core.Primitive{} + + switch obj.(type) { + case string: + value.Value = &core.Primitive_StringValue{StringValue: obj.(string)} + case int: + value.Value = &core.Primitive_Integer{Integer: int64(obj.(int))} + case float64: + value.Value = &core.Primitive_FloatValue{FloatValue: obj.(float64)} + case bool: + value.Value = &core.Primitive_Boolean{Boolean: obj.(bool)} + default: + return nil, errors.Errorf(errors.PromiseAttributeResolveError, nodeID, "Failed to resolve interface to literal scalar") + } + + return &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: value, + }, + }, + }, nil +} diff --git a/flytepropeller/pkg/controller/nodes/attr_path_resolver_test.go b/flytepropeller/pkg/controller/nodes/attr_path_resolver_test.go new file mode 100644 index 00000000000..4ba90061afd --- /dev/null +++ b/flytepropeller/pkg/controller/nodes/attr_path_resolver_test.go @@ -0,0 +1,329 @@ +package nodes + +import ( + "testing" + + "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyte/flytepropeller/pkg/controller/nodes/errors" + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/types/known/structpb" +) + +func NewScalarLiteral(value string) *core.Literal { + return &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: &core.Primitive{ + Value: &core.Primitive_StringValue{ + StringValue: value, + }, + }, + }, + }, + }, + } +} + +func NewStructFromMap(m map[string]interface{}) *structpb.Struct { + st, _ := structpb.NewStruct(m) + return st +} + +func TestResolveAttrPathIn(t *testing.T) { + // - map {"foo": "bar"} + // - collection ["foo", "bar"] + // - struct1 {"foo": "bar"} + // - struct2 {"foo": ["bar1", "bar2"]} + // - nested list struct {"foo": [["bar1", "bar2"]]} + // - map+collection+struct {"foo": [{"bar": "car"}]} + // - exception key error with map + // - exception out of range with collection + // - exception key error with struct + // - exception out of range with struct + + args := []struct { + literal *core.Literal + path []*core.PromiseAttribute + expected *core.Literal + hasError bool + }{ + { + literal: &core.Literal{ + Value: &core.Literal_Map{ + Map: &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "foo": NewScalarLiteral("bar"), + }, + }, + }, + }, + path: []*core.PromiseAttribute{ + &core.PromiseAttribute{ + Value: &core.PromiseAttribute_StringValue{ + StringValue: "foo", + }, + }, + }, + expected: NewScalarLiteral("bar"), + hasError: false, + }, + { + literal: &core.Literal{ + Value: &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: []*core.Literal{ + NewScalarLiteral("foo"), + NewScalarLiteral("bar"), + }, + }, + }, + }, + path: []*core.PromiseAttribute{ + &core.PromiseAttribute{ + Value: &core.PromiseAttribute_IntValue{ + IntValue: 1, + }, + }, + }, + expected: NewScalarLiteral("bar"), + hasError: false, + }, + { + literal: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Generic{ + Generic: NewStructFromMap(map[string]interface{}{"foo": "bar"}), + }, + }, + }, + }, + path: []*core.PromiseAttribute{ + &core.PromiseAttribute{ + Value: &core.PromiseAttribute_StringValue{ + StringValue: "foo", + }, + }, + }, + expected: NewScalarLiteral("bar"), + hasError: false, + }, + { + literal: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Generic{ + Generic: NewStructFromMap( + map[string]interface{}{ + "foo": []interface{}{"bar1", "bar2"}, + }, + ), + }, + }, + }, + }, + path: []*core.PromiseAttribute{ + &core.PromiseAttribute{ + Value: &core.PromiseAttribute_StringValue{ + StringValue: "foo", + }, + }, + &core.PromiseAttribute{ + Value: &core.PromiseAttribute_IntValue{ + IntValue: 1, + }, + }, + }, + expected: NewScalarLiteral("bar2"), + hasError: false, + }, + { + literal: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Generic{ + Generic: NewStructFromMap( + map[string]interface{}{ + "foo": []interface{}{[]interface{}{"bar1", "bar2"}}, + }, + ), + }, + }, + }, + }, + path: []*core.PromiseAttribute{ + &core.PromiseAttribute{ + Value: &core.PromiseAttribute_StringValue{ + StringValue: "foo", + }, + }, + }, + expected: &core.Literal{ + Value: &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: []*core.Literal{ + &core.Literal{ + Value: &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: []*core.Literal{ + NewScalarLiteral("bar1"), + NewScalarLiteral("bar2"), + }, + }, + }, + }, + }, + }, + }, + }, + hasError: false, + }, + { + literal: &core.Literal{ + Value: &core.Literal_Map{ + Map: &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "foo": &core.Literal{ + Value: &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: []*core.Literal{ + &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Generic{ + Generic: NewStructFromMap(map[string]interface{}{"bar": "car"}), + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + path: []*core.PromiseAttribute{ + &core.PromiseAttribute{ + Value: &core.PromiseAttribute_StringValue{ + StringValue: "foo", + }, + }, + &core.PromiseAttribute{ + Value: &core.PromiseAttribute_IntValue{ + IntValue: 0, + }, + }, + &core.PromiseAttribute{ + Value: &core.PromiseAttribute_StringValue{ + StringValue: "bar", + }, + }, + }, + expected: NewScalarLiteral("car"), + hasError: false, + }, + { + literal: &core.Literal{ + Value: &core.Literal_Map{ + Map: &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "foo": NewScalarLiteral("bar"), + }, + }, + }, + }, + path: []*core.PromiseAttribute{ + &core.PromiseAttribute{ + Value: &core.PromiseAttribute_StringValue{ + StringValue: "random", + }, + }, + }, + expected: &core.Literal{}, + hasError: true, + }, + { + literal: &core.Literal{ + Value: &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: []*core.Literal{ + NewScalarLiteral("foo"), + NewScalarLiteral("bar"), + }, + }, + }, + }, + path: []*core.PromiseAttribute{ + &core.PromiseAttribute{ + Value: &core.PromiseAttribute_IntValue{ + IntValue: 2, + }, + }, + }, + expected: &core.Literal{}, + hasError: true, + }, + { + literal: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Generic{ + Generic: NewStructFromMap(map[string]interface{}{"foo": "bar"}), + }, + }, + }, + }, + path: []*core.PromiseAttribute{ + &core.PromiseAttribute{ + Value: &core.PromiseAttribute_StringValue{ + StringValue: "random", + }, + }, + }, + expected: &core.Literal{}, + hasError: true, + }, + { + literal: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Generic{ + Generic: NewStructFromMap( + map[string]interface{}{ + "foo": []interface{}{"bar1", "bar2"}, + }, + ), + }, + }, + }, + }, + path: []*core.PromiseAttribute{ + &core.PromiseAttribute{ + Value: &core.PromiseAttribute_StringValue{ + StringValue: "foo", + }, + }, + &core.PromiseAttribute{ + Value: &core.PromiseAttribute_IntValue{ + IntValue: 100, + }, + }, + }, + expected: &core.Literal{}, + hasError: true, + }, + } + + for i, arg := range args { + resolved, err := resolveAttrPathInPromise(nil, "", arg.literal, arg.path) + if arg.hasError { + assert.Error(t, err, i) + assert.ErrorContains(t, err, errors.PromiseAttributeResolveError, i) + } else { + assert.Equal(t, arg.expected, resolved, i) + } + } +} diff --git a/flytepropeller/pkg/controller/nodes/errors/codes.go b/flytepropeller/pkg/controller/nodes/errors/codes.go index a3f877517dc..53d3bbc8d7f 100644 --- a/flytepropeller/pkg/controller/nodes/errors/codes.go +++ b/flytepropeller/pkg/controller/nodes/errors/codes.go @@ -26,4 +26,5 @@ const ( EventRecordingFailed ErrorCode = "EventRecordingFailed" CatalogCallFailed ErrorCode = "CatalogCallFailed" InvalidArrayLength ErrorCode = "InvalidArrayLength" + PromiseAttributeResolveError ErrorCode = "PromiseAttributeResolveError" ) diff --git a/flytepropeller/pkg/controller/nodes/output_resolver.go b/flytepropeller/pkg/controller/nodes/output_resolver.go index 4a771c3172b..93b00a9213a 100644 --- a/flytepropeller/pkg/controller/nodes/output_resolver.go +++ b/flytepropeller/pkg/controller/nodes/output_resolver.go @@ -22,7 +22,7 @@ type VarName = string type OutputResolver interface { // Extracts a subset of node outputs to literals. ExtractOutput(ctx context.Context, nl executors.NodeLookup, n v1alpha1.ExecutableNode, - bindToVar VarName) (values *core.Literal, err error) + bindToVar VarName, bindAttrPath []*core.PromiseAttribute) (values *core.Literal, err error) } func CreateAliasMap(aliases []v1alpha1.Alias) map[string]string { @@ -39,7 +39,7 @@ type remoteFileOutputResolver struct { } func (r remoteFileOutputResolver) ExtractOutput(ctx context.Context, nl executors.NodeLookup, n v1alpha1.ExecutableNode, - bindToVar VarName) (values *core.Literal, err error) { + bindToVar VarName, bindAttrPath []*core.PromiseAttribute) (values *core.Literal, err error) { nodeStatus := nl.GetNodeExecutionStatus(ctx, n.GetID()) outputsFileRef := v1alpha1.GetOutputsFile(nodeStatus.GetOutputDir()) @@ -54,11 +54,25 @@ func (r remoteFileOutputResolver) ExtractOutput(ctx context.Context, nl executor actualVar = variable } + var output *core.Literal + + // retrieving task output if index == nil { - return resolveSingleOutput(ctx, r.store, n.GetID(), outputsFileRef, actualVar) + output, err = resolveSingleOutput(ctx, r.store, n.GetID(), outputsFileRef, actualVar) + } else { + output, err = resolveSubtaskOutput(ctx, r.store, n.GetID(), outputsFileRef, *index, actualVar) + } + + if err != nil { + return nil, err + } + + // resolving binding attribute path if exist + if len(bindAttrPath) > 0 { + output, err = resolveAttrPathInPromise(ctx, n.GetID(), output, bindAttrPath) } - return resolveSubtaskOutput(ctx, r.store, n.GetID(), outputsFileRef, *index, actualVar) + return output, err } func resolveSubtaskOutput(ctx context.Context, store storage.ProtobufStore, nodeID string, outputsFileRef storage.DataReference, diff --git a/flytepropeller/pkg/controller/nodes/resolve.go b/flytepropeller/pkg/controller/nodes/resolve.go index 48935226981..2dd322c601e 100644 --- a/flytepropeller/pkg/controller/nodes/resolve.go +++ b/flytepropeller/pkg/controller/nodes/resolve.go @@ -61,6 +61,7 @@ func ResolveBindingData(ctx context.Context, outputResolver OutputResolver, nl e upstreamNodeID := bindingData.GetPromise().GetNodeId() bindToVar := bindingData.GetPromise().GetVar() + bindAttrPath := bindingData.GetPromise().GetAttrPath() if nl == nil { return nil, errors.Errorf(errors.IllegalStateError, upstreamNodeID, @@ -79,7 +80,7 @@ func ResolveBindingData(ctx context.Context, outputResolver OutputResolver, nl e "Undefined node in Workflow") } - return outputResolver.ExtractOutput(ctx, nl, n, bindToVar) + return outputResolver.ExtractOutput(ctx, nl, n, bindToVar, bindAttrPath) case *core.BindingData_Scalar: logger.Debugf(ctx, "bindingData.GetValue() [%v] is of type Scalar", bindingData.GetValue()) literal.Value = &core.Literal_Scalar{Scalar: bindingData.GetScalar()}