diff --git a/doc-requirements.txt b/doc-requirements.txt index 6ca7fbc1ee..10479245ab 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -249,7 +249,7 @@ flask==2.3.2 # via mlflow flatbuffers==23.5.26 # via tensorflow -flyteidl==1.5.14 +flyteidl==1.5.16 # via flytekit fonttools==4.41.1 # via matplotlib diff --git a/flytekit/core/condition.py b/flytekit/core/condition.py index 37c4afc88f..b9bda12650 100644 --- a/flytekit/core/condition.py +++ b/flytekit/core/condition.py @@ -391,6 +391,8 @@ def transform_to_conj_expr( def transform_to_operand(v: Union[Promise, Literal]) -> Tuple[_core_cond.Operand, Optional[Promise]]: if isinstance(v, Promise): return _core_cond.Operand(var=create_branch_node_promise_var(v.ref.node_id, v.var)), v + if v.scalar.none_type: + return _core_cond.Operand(scalar=v.scalar), None return _core_cond.Operand(primitive=v.scalar.primitive), None diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 35b72b5a56..c28c357930 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -136,12 +136,26 @@ def __init__(self, lhs: Union["Promise", Any], op: ComparisonOps, rhs: Union["Pr self._lhs = lhs if lhs.is_ready: if lhs.val.scalar is None or lhs.val.scalar.primitive is None: - raise ValueError("Only primitive values can be used in comparison") + union = lhs.val.scalar.union + if union and union.value.scalar: + if union.value.scalar.primitive or union.value.scalar.none_type: + self._lhs = union.value + else: + raise ValueError("Only primitive values can be used in comparison") + else: + raise ValueError("Only primitive values can be used in comparison") if isinstance(rhs, Promise): self._rhs = rhs if rhs.is_ready: if rhs.val.scalar is None or rhs.val.scalar.primitive is None: - raise ValueError("Only primitive values can be used in comparison") + union = rhs.val.scalar.union + if union and union.value.scalar: + if union.value.scalar.primitive or union.value.scalar.none_type: + self._rhs = union.value + else: + raise ValueError("Only primitive values can be used in comparison") + else: + raise ValueError("Only primitive values can be used in comparison") if self._lhs is None: self._lhs = type_engine.TypeEngine.to_literal(FlyteContextManager.current_context(), lhs, type(lhs), None) if self._rhs is None: @@ -162,11 +176,15 @@ def op(self) -> ComparisonOps: def eval(self) -> bool: if isinstance(self.lhs, Promise): lhs = self.lhs.eval() + elif self.lhs.scalar.none_type: + lhs = None else: lhs = get_primitive_val(self.lhs.scalar.primitive) if isinstance(self.rhs, Promise): rhs = self.rhs.eval() + elif self.rhs.scalar.none_type: + rhs = None else: rhs = get_primitive_val(self.rhs.scalar.primitive) @@ -350,9 +368,12 @@ def is_(self, v: bool) -> ComparisonExpression: def is_false(self) -> ComparisonExpression: return self.is_(False) - def is_true(self): + def is_true(self) -> ComparisonExpression: return self.is_(True) + def is_none(self) -> ComparisonExpression: + return ComparisonExpression(self, ComparisonOps.EQ, None) + def __eq__(self, other) -> ComparisonExpression: # type: ignore return ComparisonExpression(self, ComparisonOps.EQ, other) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index f1203c7fc7..8a9d9114c8 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -862,7 +862,7 @@ def to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type "actual attribute that you want to use. For example, in NamedTuple('OP', x=int) then" "return v.x, instead of v, even if this has a single element" ) - if python_val is None and expected.union_type is None: + if python_val is None and expected and expected.union_type is None: raise TypeTransformerFailedError(f"Python value cannot be None, expected {python_type}/{expected}") transformer = cls.get_transformer(python_type) if transformer.type_assertions_enabled: diff --git a/flytekit/models/core/condition.py b/flytekit/models/core/condition.py index 845b3b4f79..27e0bc505b 100644 --- a/flytekit/models/core/condition.py +++ b/flytekit/models/core/condition.py @@ -134,15 +134,17 @@ def from_flyte_idl(cls, pb2_object): class Operand(_common.FlyteIdlEntity): - def __init__(self, primitive=None, var=None): + def __init__(self, primitive=None, var=None, scalar=None): """ Defines an operand to a comparison expression. - :param flytekit.models.literals.Primitive primitive: - :param Text var: + :param flytekit.models.literals.Primitive primitive: A primitive value + :param Text var: A variable name + :param flytekit.models.literals.Scalar scalar: A scalar value """ self._primitive = primitive self._var = var + self._scalar = scalar @property def primitive(self): @@ -160,6 +162,14 @@ def var(self): return self._var + @property + def scalar(self): + """ + :rtype: flytekit.models.literals.Scalar + """ + + return self._scalar + def to_flyte_idl(self): """ :rtype: flyteidl.core.condition_pb2.Operand @@ -167,6 +177,7 @@ def to_flyte_idl(self): return _condition.Operand( primitive=self.primitive.to_flyte_idl() if self.primitive else None, var=self.var if self.var else None, + scalar=self.scalar.to_flyte_idl() if self.scalar else None, ) @classmethod @@ -176,6 +187,7 @@ def from_flyte_idl(cls, pb2_object): if pb2_object.HasField("primitive") else None, var=pb2_object.var if pb2_object.HasField("var") else None, + scalar=_literals.Scalar.from_flyte_idl(pb2_object.scalar) if pb2_object.HasField("scalar") else None, ) diff --git a/setup.py b/setup.py index b758a3178d..02f906a813 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ }, install_requires=[ "googleapis-common-protos>=1.57", - "flyteidl>=1.5.14", + "flyteidl>=1.5.16", "wheel>=0.30.0,<1.0.0", "pandas>=1.0.0,<2.0.0", "pyarrow>=4.0.0,<11.0.0", diff --git a/tests/flytekit/unit/core/test_conditions.py b/tests/flytekit/unit/core/test_conditions.py index 7b0b292baa..0da2467109 100644 --- a/tests/flytekit/unit/core/test_conditions.py +++ b/tests/flytekit/unit/core/test_conditions.py @@ -194,6 +194,25 @@ def decompose() -> int: assert decompose() == 20 +def test_condition_is_none(): + @task + def return_true() -> typing.Optional[None]: + return None + + @workflow + def failed() -> int: + return 10 + + @workflow + def success() -> int: + return 20 + + @workflow + def decompose_unary() -> int: + result = return_true() + return conditional("test").if_(result.is_none()).then(success()).else_().then(failed()) + + def test_subworkflow_condition_serialization(): """Test that subworkflows are correctly extracted from serialized workflows with condiationals.""" diff --git a/tests/flytekit/unit/models/core/test_workflow.py b/tests/flytekit/unit/models/core/test_workflow.py index de83f66f78..6775d58940 100644 --- a/tests/flytekit/unit/models/core/test_workflow.py +++ b/tests/flytekit/unit/models/core/test_workflow.py @@ -228,6 +228,73 @@ def test_branch_node(): assert bn.if_else.case.then_node == obj +def test_branch_node_with_none(): + nm = _get_sample_node_metadata() + task = _workflow.TaskNode(reference_id=_generic_id) + bd = _literals.BindingData(scalar=_literals.Scalar(none_type=_literals.Void())) + lt = _literals.Literal(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=99))) + bd2 = _literals.BindingData( + scalar=_literals.Scalar( + union=_literals.Union(value=lt, stored_type=_types.LiteralType(_types.SimpleType.INTEGER)) + ) + ) + binding = _literals.Binding(var="myvar", binding=bd) + binding2 = _literals.Binding(var="myothervar", binding=bd2) + + obj = _workflow.Node( + id="some:node:id", + metadata=nm, + inputs=[binding, binding2], + upstream_node_ids=[], + output_aliases=[], + task_node=task, + ) + + bn = _workflow.BranchNode( + _workflow.IfElseBlock( + case=_workflow.IfBlock( + condition=_condition.BooleanExpression( + comparison=_condition.ComparisonExpression( + _condition.ComparisonExpression.Operator.EQ, + _condition.Operand(scalar=_literals.Scalar(none_type=_literals.Void())), + _condition.Operand(primitive=_literals.Primitive(integer=2)), + ) + ), + then_node=obj, + ), + other=[ + _workflow.IfBlock( + condition=_condition.BooleanExpression( + conjunction=_condition.ConjunctionExpression( + _condition.ConjunctionExpression.LogicalOperator.AND, + _condition.BooleanExpression( + comparison=_condition.ComparisonExpression( + _condition.ComparisonExpression.Operator.EQ, + _condition.Operand(scalar=_literals.Scalar(none_type=_literals.Void())), + _condition.Operand(primitive=_literals.Primitive(integer=2)), + ) + ), + _condition.BooleanExpression( + comparison=_condition.ComparisonExpression( + _condition.ComparisonExpression.Operator.EQ, + _condition.Operand(scalar=_literals.Scalar(none_type=_literals.Void())), + _condition.Operand(primitive=_literals.Primitive(integer=2)), + ) + ), + ) + ), + then_node=obj, + ) + ], + else_node=obj, + ) + ) + + bn2 = _workflow.BranchNode.from_flyte_idl(bn.to_flyte_idl()) + assert bn == bn2 + assert bn.if_else.case.then_node == obj + + def test_task_node_overrides(): overrides = _workflow.TaskNodeOverrides( Resources(