Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Moving from flytepropeller - Check for None values in branch nodes #4154

Closed
wants to merge 11 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions flytepropeller/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,15 @@ require (
github.com/DiSiqueira/GoTree v1.0.1-0.20180907134536-53a8e837f295
github.com/benlaurie/objecthash v0.0.0-20180202135721-d1e3d6079fc1
github.com/fatih/color v1.13.0
<<<<<<< HEAD
github.com/flyteorg/flyte/flyteplugins v0.0.0-00010101000000-000000000000
github.com/flyteorg/flyte/flytestdlib v0.0.0-00010101000000-000000000000
github.com/flyteorg/flyteidl v0.0.0-00010101000000-000000000000
=======
github.com/flyteorg/flyteidl v1.5.16
github.com/flyteorg/flyteplugins v1.1.30
github.com/flyteorg/flytestdlib v1.0.24
>>>>>>> flytepropeller/is-none
github.com/ghodss/yaml v1.0.0
github.com/go-redis/redis v6.15.7+incompatible
github.com/go-test/deep v1.0.7
Expand Down
9 changes: 9 additions & 0 deletions flytepropeller/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,15 @@ github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5Kwzbycv
github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w=
github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk=
github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
<<<<<<< HEAD
=======
github.com/flyteorg/flyteidl v1.5.16 h1:S70wD7K99nKHZxmo8U16Jjhy1kZwoBh5ZQhZf3/6MPU=
github.com/flyteorg/flyteidl v1.5.16/go.mod h1:EtE/muM2lHHgBabjYcxqe9TWeJSL0kXwbI0RgVwI4Og=
github.com/flyteorg/flyteplugins v1.1.30 h1:AVqS6Eb9Nr9Z3Mb3CtP04ffAVS9LMx5Q1Z7AyFFk/e0=
github.com/flyteorg/flyteplugins v1.1.30/go.mod h1:FujFQdL/f9r1HvFR81JCiNYusDy9F0lExhyoyMHXXbg=
github.com/flyteorg/flytestdlib v1.0.24 h1:jDvymcjlsTRCwOtxPapro0WZBe3isTz+T3Tiq+mZUuk=
github.com/flyteorg/flytestdlib v1.0.24/go.mod h1:6nXa5g00qFIsgdvQ7jKQMJmDniqO0hG6Z5X5olfduqQ=
>>>>>>> flytepropeller/is-none
github.com/flyteorg/stow v0.3.7 h1:Cx7j8/Ux6+toD5hp5fy++927V+yAcAttDeQAlUD/864=
github.com/flyteorg/stow v0.3.7/go.mod h1:5dfBitPM004dwaZdoVylVjxFT4GWAgI0ghAndhNUzCo=
github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k=
Expand Down
9 changes: 8 additions & 1 deletion flytepropeller/pkg/compiler/validators/condition.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ func validateOperand(node c.NodeBuilder, paramName string, operand *flyte.Operan
} else if operand.GetPrimitive() != nil {
// no validation
literalType = literalTypeForPrimitive(operand.GetPrimitive())
} else if operand.GetScalar().GetPrimitive() != nil {
literalType = literalTypeForPrimitive(operand.GetPrimitive())
} else if operand.GetScalar().GetNoneType() != nil {
literalType = &flyte.LiteralType{Type: &flyte.LiteralType_Simple{Simple: flyte.SimpleType_NONE}}
} else if len(operand.GetVar()) > 0 {
if node.GetInterface() != nil {
if param, paramOk := validateInputVar(node, operand.GetVar(), requireParamType, errs.NewScope()); paramOk {
Expand All @@ -41,7 +45,10 @@ func ValidateBooleanExpression(w c.WorkflowBuilder, node c.NodeBuilder, expr *fl
expr.GetComparison().GetRightValue(), requireParamType, errs.NewScope())
op2Type, op2Valid := validateOperand(node, "LeftValue",
expr.GetComparison().GetLeftValue(), requireParamType, errs.NewScope())
if op1Valid && op2Valid && op1Type != nil && op2Type != nil {
// Valid expression
// 1. Both operands are primitive types and have the same types.
// 2. One of the operands is the None type.
if op1Valid && op2Valid && op1Type != nil && op2Type != nil && op1Type.GetSimple() != flyte.SimpleType_NONE && op2Type.GetSimple() != flyte.SimpleType_NONE {
if op1Type.String() != op2Type.String() {
errs.Collect(errors.NewMismatchingTypesErr(node.GetId(), "RightValue",
op1Type.String(), op2Type.String()))
Expand Down
56 changes: 34 additions & 22 deletions flytepropeller/pkg/controller/nodes/branch/comparator.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,21 @@ var perTypeComparators = map[string]comparators{
},
}

func Evaluate(lValue *core.Primitive, rValue *core.Primitive, op core.ComparisonExpression_Operator) (bool, error) {
lValueType := reflect.TypeOf(lValue.Value)
rValueType := reflect.TypeOf(rValue.Value)
func Evaluate(lValue *core.Scalar, rValue *core.Scalar, op core.ComparisonExpression_Operator) (bool, error) {
if lValue.GetNoneType() != nil || rValue.GetNoneType() != nil {
lIsNone := lValue.GetNoneType() != nil
rIsNone := rValue.GetNoneType() != nil
switch op {
case core.ComparisonExpression_EQ:
return lIsNone == rIsNone, nil
case core.ComparisonExpression_NEQ:
return lIsNone != rIsNone, nil
default:
return false, errors.Errorf(ErrorCodeMalformedBranch, "Comparison between nil and non-nil values with operator [%v] is not supported. lVal[%v]:rVal[%v]", op, lValue, rValue)
}
}
lValueType := reflect.TypeOf(lValue.GetPrimitive().Value)
rValueType := reflect.TypeOf(rValue.GetPrimitive().Value)
if lValueType != rValueType {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Comparison between different primitives types. lVal[%v]:rVal[%v]", lValueType, rValueType)
}
Expand All @@ -90,50 +102,50 @@ func Evaluate(lValue *core.Primitive, rValue *core.Primitive, op core.Comparison
if isBoolean {
return false, errors.Errorf(ErrorCodeMalformedBranch, "[GT] not defined for boolean operands.")
}
return comps.gt(lValue, rValue), nil
return comps.gt(lValue.GetPrimitive(), rValue.GetPrimitive()), nil
case core.ComparisonExpression_GTE:
if isBoolean {
return false, errors.Errorf(ErrorCodeMalformedBranch, "[GTE] not defined for boolean operands.")
}
return comps.eq(lValue, rValue) || comps.gt(lValue, rValue), nil
return comps.eq(lValue.GetPrimitive(), rValue.GetPrimitive()) || comps.gt(lValue.GetPrimitive(), rValue.GetPrimitive()), nil
case core.ComparisonExpression_LT:
if isBoolean {
return false, errors.Errorf(ErrorCodeMalformedBranch, "[LT] not defined for boolean operands.")
}
return !(comps.gt(lValue, rValue) || comps.eq(lValue, rValue)), nil
return !(comps.gt(lValue.GetPrimitive(), rValue.GetPrimitive()) || comps.eq(lValue.GetPrimitive(), rValue.GetPrimitive())), nil
case core.ComparisonExpression_LTE:
if isBoolean {
return false, errors.Errorf(ErrorCodeMalformedBranch, "[LTE] not defined for boolean operands.")
}
return !comps.gt(lValue, rValue), nil
return !comps.gt(lValue.GetPrimitive(), rValue.GetPrimitive()), nil
case core.ComparisonExpression_EQ:
return comps.eq(lValue, rValue), nil
return comps.eq(lValue.GetPrimitive(), rValue.GetPrimitive()), nil
case core.ComparisonExpression_NEQ:
return !comps.eq(lValue, rValue), nil
return !comps.eq(lValue.GetPrimitive(), rValue.GetPrimitive()), nil
}
return false, errors.Errorf(ErrorCodeMalformedBranch, "Unsupported operator type in Propeller. System error.")
}

func Evaluate1(lValue *core.Primitive, rValue *core.Literal, op core.ComparisonExpression_Operator) (bool, error) {
if rValue.GetScalar() == nil || rValue.GetScalar().GetPrimitive() == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. RHS Variable is non primitive.")
func Evaluate1(lValue *core.Scalar, rValue *core.Literal, op core.ComparisonExpression_Operator) (bool, error) {
if rValue.GetScalar() == nil || (rValue.GetScalar().GetPrimitive() == nil && rValue.GetScalar().GetNoneType() == nil) {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. RHS Variable [%v] is non primitive", rValue)
}
return Evaluate(lValue, rValue.GetScalar().GetPrimitive(), op)
return Evaluate(lValue, rValue.GetScalar(), op)
}

func Evaluate2(lValue *core.Literal, rValue *core.Primitive, op core.ComparisonExpression_Operator) (bool, error) {
if lValue.GetScalar() == nil || lValue.GetScalar().GetPrimitive() == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. LHS Variable is non primitive.")
func Evaluate2(lValue *core.Literal, rValue *core.Scalar, op core.ComparisonExpression_Operator) (bool, error) {
if lValue.GetScalar() == nil || (lValue.GetScalar().GetPrimitive() == nil && lValue.GetScalar().GetNoneType() == nil) {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. LHS Variable [%v] is non primitive.", lValue)
}
return Evaluate(lValue.GetScalar().GetPrimitive(), rValue, op)
return Evaluate(lValue.GetScalar(), rValue, op)
}

func EvaluateLiterals(lValue *core.Literal, rValue *core.Literal, op core.ComparisonExpression_Operator) (bool, error) {
if lValue.GetScalar() == nil || lValue.GetScalar().GetPrimitive() == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. LHS Variable is non primitive.")
if lValue.GetScalar() == nil || (lValue.GetScalar().GetPrimitive() == nil && lValue.GetScalar().GetNoneType() == nil) {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. LHS Variable [%v] is non primitive.", lValue)
}
if rValue.GetScalar() == nil || rValue.GetScalar().GetPrimitive() == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. RHS Variable is non primitive")
if rValue.GetScalar() == nil || (rValue.GetScalar().GetPrimitive() == nil && rValue.GetScalar().GetNoneType() == nil) {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. RHS Variable [%v] is non primitive", rValue)
}
return Evaluate(lValue.GetScalar().GetPrimitive(), rValue.GetScalar().GetPrimitive(), op)
return Evaluate(lValue.GetScalar(), rValue.GetScalar(), op)
}
24 changes: 12 additions & 12 deletions flytepropeller/pkg/controller/nodes/branch/comparator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import (
)

func TestEvaluate_int(t *testing.T) {
p1 := coreutils.MustMakePrimitive(1)
p2 := coreutils.MustMakePrimitive(2)
p1 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(1)}}
p2 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(2)}}
{
// p1 > p2 = false
b, err := Evaluate(p1, p2, core.ComparisonExpression_GT)
Expand Down Expand Up @@ -82,8 +82,8 @@ func TestEvaluate_int(t *testing.T) {
}

func TestEvaluate_float(t *testing.T) {
p1 := coreutils.MustMakePrimitive(1.0)
p2 := coreutils.MustMakePrimitive(2.0)
p1 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(1)}}
p2 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(2)}}
{
// p1 > p2 = false
b, err := Evaluate(p1, p2, core.ComparisonExpression_GT)
Expand Down Expand Up @@ -153,8 +153,8 @@ func TestEvaluate_float(t *testing.T) {
}

func TestEvaluate_string(t *testing.T) {
p1 := coreutils.MustMakePrimitive("a")
p2 := coreutils.MustMakePrimitive("b")
p1 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive("a")}}
p2 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive("b")}}
{
// p1 > p2 = false
b, err := Evaluate(p1, p2, core.ComparisonExpression_GT)
Expand Down Expand Up @@ -224,8 +224,8 @@ func TestEvaluate_string(t *testing.T) {
}

func TestEvaluate_datetime(t *testing.T) {
p1 := coreutils.MustMakePrimitive(time.Date(2018, 7, 4, 12, 00, 00, 00, time.UTC))
p2 := coreutils.MustMakePrimitive(time.Date(2018, 7, 4, 12, 00, 01, 00, time.UTC))
p1 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(time.Date(2018, 7, 4, 12, 00, 00, 00, time.UTC))}}
p2 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(time.Date(2018, 7, 4, 12, 00, 01, 00, time.UTC))}}
{
// p1 > p2 = false
b, err := Evaluate(p1, p2, core.ComparisonExpression_GT)
Expand Down Expand Up @@ -295,8 +295,8 @@ func TestEvaluate_datetime(t *testing.T) {
}

func TestEvaluate_duration(t *testing.T) {
p1 := coreutils.MustMakePrimitive(10 * time.Second)
p2 := coreutils.MustMakePrimitive(11 * time.Second)
p1 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(10 * time.Second)}}
p2 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(11 * time.Second)}}
{
// p1 > p2 = false
b, err := Evaluate(p1, p2, core.ComparisonExpression_GT)
Expand Down Expand Up @@ -366,8 +366,8 @@ func TestEvaluate_duration(t *testing.T) {
}

func TestEvaluate_boolean(t *testing.T) {
p1 := coreutils.MustMakePrimitive(true)
p2 := coreutils.MustMakePrimitive(false)
p1 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(true)}}
p2 := &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(false)}}
f := func(op core.ComparisonExpression_Operator) {
// GT/LT = false
msg := fmt.Sprintf("Evaluating: [%s]", op.String())
Expand Down
42 changes: 32 additions & 10 deletions flytepropeller/pkg/controller/nodes/branch/evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,31 +20,53 @@ const ErrorCodeFailedFetchOutputs = "FailedFetchOutputs"
func EvaluateComparison(expr *core.ComparisonExpression, nodeInputs *core.LiteralMap) (bool, error) {
var lValue *core.Literal
var rValue *core.Literal
var lPrim *core.Primitive
var rPrim *core.Primitive
var lPrim *core.Scalar
var rPrim *core.Scalar

if expr.GetLeftValue().GetPrimitive() == nil {
if nodeInputs == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar())
if expr.GetLeftValue().GetScalar().GetNoneType() != nil {
lValue = &core.Literal{Value: &core.Literal_Scalar{Scalar: expr.GetLeftValue().GetScalar()}}
} else if expr.GetLeftValue().GetScalar().GetUnion() != nil {
lValue = expr.GetLeftValue().GetScalar().GetUnion().GetValue()
} else {
if nodeInputs == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar())
}
input := nodeInputs.Literals[expr.GetLeftValue().GetVar()]
if input.GetScalar().GetUnion().GetValue() != nil {
lValue = input.GetScalar().GetUnion().GetValue()
} else {
lValue = input
}
}
lValue = nodeInputs.Literals[expr.GetLeftValue().GetVar()]
if lValue == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar())
}
} else {
lPrim = expr.GetLeftValue().GetPrimitive()
lPrim = &core.Scalar{Value: &core.Scalar_Primitive{Primitive: expr.GetLeftValue().GetPrimitive()}}
}

if expr.GetRightValue().GetPrimitive() == nil {
if nodeInputs == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar())
if expr.GetRightValue().GetScalar().GetNoneType() != nil {
rValue = &core.Literal{Value: &core.Literal_Scalar{Scalar: expr.GetRightValue().GetScalar()}}
} else if expr.GetRightValue().GetScalar().GetUnion() != nil {
rValue = expr.GetRightValue().GetScalar().GetUnion().GetValue()
} else {
if nodeInputs == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar())
}
input := nodeInputs.Literals[expr.GetRightValue().GetVar()]
if input.GetScalar().GetUnion().GetValue() != nil {
rValue = input.GetScalar().GetUnion().GetValue()
} else {
rValue = input
}
}
rValue = nodeInputs.Literals[expr.GetRightValue().GetVar()]
if rValue == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetRightValue().GetVar())
}
} else {
rPrim = expr.GetRightValue().GetPrimitive()
rPrim = &core.Scalar{Value: &core.Scalar_Primitive{Primitive: expr.GetRightValue().GetPrimitive()}}
}

if lValue != nil && rValue != nil {
Expand Down
84 changes: 84 additions & 0 deletions flytepropeller/pkg/controller/nodes/branch/evaluator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,16 @@ func createUnaryConjunction(l *core.ComparisonExpression, op core.ConjunctionExp
}
}

func getNoneOperand() *core.Operand {
return &core.Operand{
Val: &core.Operand_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_NoneType{NoneType: &core.Void{}},
},
},
}
}

func TestEvaluateComparison(t *testing.T) {
t.Run("ComparePrimitives", func(t *testing.T) {
// Compare primitives
Expand Down Expand Up @@ -100,6 +110,80 @@ func TestEvaluateComparison(t *testing.T) {
assert.NoError(t, err)
assert.False(t, v)
})
t.Run("CompareNoneAndLiteral", func(t *testing.T) {
// Compare lVal -> None and rVal -> literal
exp := &core.ComparisonExpression{
LeftValue: getNoneOperand(),
Operator: core.ComparisonExpression_EQ,
RightValue: &core.Operand{
Val: &core.Operand_Primitive{
Primitive: coreutils.MustMakePrimitive(1),
},
},
}
v, err := EvaluateComparison(exp, nil)
assert.NoError(t, err)
assert.False(t, v)
})
t.Run("CompareLiteralAndNone", func(t *testing.T) {
// Compare lVal -> literal and rVal -> None
exp := &core.ComparisonExpression{
LeftValue: &core.Operand{
Val: &core.Operand_Primitive{
Primitive: coreutils.MustMakePrimitive(1),
},
},
Operator: core.ComparisonExpression_NEQ,
RightValue: getNoneOperand(),
}
v, err := EvaluateComparison(exp, nil)
assert.NoError(t, err)
assert.True(t, v)
})
t.Run("CompareUnionLiteralAndNone", func(t *testing.T) {
// Compare lVal -> literal and rVal -> None
exp := &core.ComparisonExpression{
LeftValue: &core.Operand{
Val: &core.Operand_Scalar{
Scalar: &core.Scalar{
Value: &core.Scalar_Union{
Union: &core.Union{
Value: &core.Literal{
Value: &core.Literal_Scalar{Scalar: &core.Scalar{Value: &core.Scalar_Primitive{Primitive: coreutils.MustMakePrimitive(1)}}},
},
},
},
},
},
},
Operator: core.ComparisonExpression_NEQ,
RightValue: getNoneOperand(),
}
v, err := EvaluateComparison(exp, nil)
assert.NoError(t, err)
assert.True(t, v)
})
t.Run("CompareNoneAndNone", func(t *testing.T) {
// Compare lVal -> None and rVal -> None
exp := &core.ComparisonExpression{
LeftValue: getNoneOperand(),
Operator: core.ComparisonExpression_EQ,
RightValue: getNoneOperand(),
}
v, err := EvaluateComparison(exp, nil)
assert.NoError(t, err)
assert.True(t, v)
})
t.Run("CompareNoneAndNoneWithError", func(t *testing.T) {
// Compare lVal -> None and rVal -> None
exp := &core.ComparisonExpression{
LeftValue: getNoneOperand(),
Operator: core.ComparisonExpression_GTE,
RightValue: getNoneOperand(),
}
_, err := EvaluateComparison(exp, nil)
assert.Error(t, err)
})
t.Run("CompareLiteralAndPrimitive", func(t *testing.T) {

// Compare lVal -> literal and rVal -> primitive
Expand Down
Loading