Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Check for None values in branch nodes #592

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions pkg/controller/nodes/branch/comparator.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,16 @@ var perTypeComparators = map[string]comparators{
}

func Evaluate(lValue *core.Primitive, rValue *core.Primitive, op core.ComparisonExpression_Operator) (bool, error) {
if lValue == nil || rValue == nil {
switch op {
case core.ComparisonExpression_EQ:
return lValue == rValue, nil
case core.ComparisonExpression_NEQ:
return lValue != rValue, nil
default:
return false, errors.Errorf(ErrorCodeMalformedBranch, "Comparison between nil and non-nil values with operator [%v] is not supported. lVal[%v]:rVal[%v]", op, lValue, rValue)
}
}
lValueType := reflect.TypeOf(lValue.Value)
rValueType := reflect.TypeOf(rValue.Value)
if lValueType != rValueType {
Expand Down Expand Up @@ -116,24 +126,32 @@ func Evaluate(lValue *core.Primitive, rValue *core.Primitive, op core.Comparison

func Evaluate1(lValue *core.Primitive, rValue *core.Literal, op core.ComparisonExpression_Operator) (bool, error) {
if rValue.GetScalar() == nil || rValue.GetScalar().GetPrimitive() == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. RHS Variable is non primitive.")
if rValue.GetScalar().GetNoneType() == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. RHS Variable is non primitive")
}
}
return Evaluate(lValue, rValue.GetScalar().GetPrimitive(), op)
}

func Evaluate2(lValue *core.Literal, rValue *core.Primitive, op core.ComparisonExpression_Operator) (bool, error) {
if lValue.GetScalar() == nil || lValue.GetScalar().GetPrimitive() == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. LHS Variable is non primitive.")
if lValue.GetScalar().GetNoneType() == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. LHS Variable is non primitive.")
}
}
return Evaluate(lValue.GetScalar().GetPrimitive(), rValue, op)
}

func EvaluateLiterals(lValue *core.Literal, rValue *core.Literal, op core.ComparisonExpression_Operator) (bool, error) {
if lValue.GetScalar() == nil || lValue.GetScalar().GetPrimitive() == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. LHS Variable is non primitive.")
if lValue.GetScalar().GetNoneType() == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. LHS Variable is non primitive.")
}
}
if rValue.GetScalar() == nil || rValue.GetScalar().GetPrimitive() == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. RHS Variable is non primitive")
if rValue.GetScalar().GetNoneType() == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Only primitives can be compared. RHS Variable is non primitive")
}
}
return Evaluate(lValue.GetScalar().GetPrimitive(), rValue.GetScalar().GetPrimitive(), op)
}
20 changes: 14 additions & 6 deletions pkg/controller/nodes/branch/evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,14 @@ func EvaluateComparison(expr *core.ComparisonExpression, nodeInputs *core.Litera
var rPrim *core.Primitive

if expr.GetLeftValue().GetPrimitive() == nil {
if nodeInputs == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar())
if len(expr.GetLeftValue().GetVar()) == 0 {
lValue = &core.Literal{Value: &core.Literal_Scalar{Scalar: &core.Scalar{Value: &core.Scalar_NoneType{NoneType: &core.Void{}}}}}
} else {
if nodeInputs == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar())
}
lValue = nodeInputs.Literals[expr.GetLeftValue().GetVar()]
}
lValue = nodeInputs.Literals[expr.GetLeftValue().GetVar()]
if lValue == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar())
}
Expand All @@ -36,10 +40,14 @@ func EvaluateComparison(expr *core.ComparisonExpression, nodeInputs *core.Litera
}

if expr.GetRightValue().GetPrimitive() == nil {
if nodeInputs == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar())
if len(expr.GetRightValue().GetVar()) == 0 {
rValue = &core.Literal{Value: &core.Literal_Scalar{Scalar: &core.Scalar{Value: &core.Scalar_NoneType{NoneType: &core.Void{}}}}}
} else {
if nodeInputs == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetLeftValue().GetVar())
}
rValue = nodeInputs.Literals[expr.GetRightValue().GetVar()]
}
rValue = nodeInputs.Literals[expr.GetRightValue().GetVar()]
if rValue == nil {
return false, errors.Errorf(ErrorCodeMalformedBranch, "Failed to find Value for Variable [%v]", expr.GetRightValue().GetVar())
}
Expand Down
51 changes: 51 additions & 0 deletions pkg/controller/nodes/branch/evaluator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,57 @@ func TestEvaluateComparison(t *testing.T) {
assert.NoError(t, err)
assert.False(t, v)
})
t.Run("CompareNoneAndLiteral", func(t *testing.T) {
// Compare lVal -> None and rVal -> literal
exp := &core.ComparisonExpression{
LeftValue: &core.Operand{},
Operator: core.ComparisonExpression_EQ,
RightValue: &core.Operand{
Val: &core.Operand_Primitive{
Primitive: coreutils.MustMakePrimitive(1),
},
},
}
v, err := EvaluateComparison(exp, nil)
assert.NoError(t, err)
assert.False(t, v)
})
t.Run("CompareLiteralAndNone", func(t *testing.T) {
// Compare lVal -> literal and rVal -> None
exp := &core.ComparisonExpression{
LeftValue: &core.Operand{
Val: &core.Operand_Primitive{
Primitive: coreutils.MustMakePrimitive(1),
},
},
Operator: core.ComparisonExpression_NEQ,
RightValue: &core.Operand{},
}
v, err := EvaluateComparison(exp, nil)
assert.NoError(t, err)
assert.True(t, v)
})
t.Run("CompareNoneAndNone", func(t *testing.T) {
// Compare lVal -> None and rVal -> None
exp := &core.ComparisonExpression{
LeftValue: &core.Operand{},
Operator: core.ComparisonExpression_EQ,
RightValue: &core.Operand{},
}
v, err := EvaluateComparison(exp, nil)
assert.NoError(t, err)
assert.True(t, v)
})
t.Run("CompareNoneAndNoneWithError", func(t *testing.T) {
// Compare lVal -> None and rVal -> None
exp := &core.ComparisonExpression{
LeftValue: &core.Operand{},
Operator: core.ComparisonExpression_GTE,
RightValue: &core.Operand{},
}
_, err := EvaluateComparison(exp, nil)
assert.Error(t, err)
})
t.Run("CompareLiteralAndPrimitive", func(t *testing.T) {

// Compare lVal -> literal and rVal -> primitive
Expand Down