From 7b1c9e2b220dc4660511bc5a9a4ea2bd097850b1 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 19 Jul 2023 15:26:54 -0700 Subject: [PATCH 1/7] Add is none function Signed-off-by: Kevin Su --- flytekit/core/promise.py | 25 ++++++++++++++++++++++--- flytekit/core/type_engine.py | 2 +- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 5dfd1a6b40..0a8bb0817c 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -85,7 +85,9 @@ def my_wf(in1: int, in2: int) -> int: return result -def get_primitive_val(prim: Primitive) -> Any: +def get_primitive_val(prim: Optional[Primitive] = None) -> Any: + if prim is None: + return None for value in [ prim.integer, prim.float_value, @@ -136,12 +138,24 @@ def __init__(self, lhs: Union["Promise", Any], op: ComparisonOps, rhs: Union["Pr self._lhs = lhs if lhs.is_ready: if lhs.val.scalar is None or lhs.val.scalar.primitive is None: - raise ValueError("Only primitive values can be used in comparison") + if lhs.val.scalar.union and lhs.val.scalar.union.value.scalar: + if lhs.val.scalar.union.value.scalar.primitive or lhs.val.scalar.union.value.scalar.none_type: + self._lhs = lhs.val.scalar.union.value + else: + raise ValueError("Only primitive values can be used in comparison") + else: + raise ValueError("Only primitive values can be used in comparison") if isinstance(rhs, Promise): self._rhs = rhs if rhs.is_ready: if rhs.val.scalar is None or rhs.val.scalar.primitive is None: - raise ValueError("Only primitive values can be used in comparison") + if rhs.val.scalar.union and rhs.val.scalar.union.value.scalar: + if rhs.val.scalar.union.value.scalar.primitive or rhs.val.scalar.union.value.scalar.none_type: + self._rhs = rhs.val.scalar.union.value + else: + raise ValueError("Only primitive values can be used in comparison") + else: + raise ValueError("Only primitive values can be used in comparison") if self._lhs is None: self._lhs = type_engine.TypeEngine.to_literal(FlyteContextManager.current_context(), lhs, type(lhs), None) if self._rhs is None: @@ -167,6 +181,8 @@ def eval(self) -> bool: if isinstance(self.rhs, Promise): rhs = self.rhs.eval() + elif self.rhs.scalar.none_type: + rhs = None else: rhs = get_primitive_val(self.rhs.scalar.primitive) @@ -353,6 +369,9 @@ def is_false(self) -> ComparisonExpression: def is_true(self): return self.is_(True) + def is_none(self): + return self.is_(None) + def __eq__(self, other) -> ComparisonExpression: # type: ignore return ComparisonExpression(self, ComparisonOps.EQ, other) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 8966564b28..bc557cbe00 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -862,7 +862,7 @@ def to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type "actual attribute that you want to use. For example, in NamedTuple('OP', x=int) then" "return v.x, instead of v, even if this has a single element" ) - if python_val is None and expected.union_type is None: + if python_val is None and expected and expected.union_type is None: raise TypeTransformerFailedError(f"Python value cannot be None, expected {python_type}/{expected}") transformer = cls.get_transformer(python_type) if transformer.type_assertions_enabled: From c3558d3104d32dbdfca1f0689fd799e9d2069555 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 21 Jul 2023 12:22:29 -0700 Subject: [PATCH 2/7] wip Signed-off-by: Kevin Su --- flytekit/core/condition.py | 2 ++ flytekit/core/promise.py | 24 +++++++++++++----------- flytekit/models/core/condition.py | 13 ++++++++++++- 3 files changed, 27 insertions(+), 12 deletions(-) diff --git a/flytekit/core/condition.py b/flytekit/core/condition.py index 76553db702..4b8167ffff 100644 --- a/flytekit/core/condition.py +++ b/flytekit/core/condition.py @@ -391,6 +391,8 @@ def transform_to_conj_expr( def transform_to_operand(v: Union[Promise, Literal]) -> Tuple[_core_cond.Operand, Optional[Promise]]: if isinstance(v, Promise): return _core_cond.Operand(var=create_branch_node_promise_var(v.ref.node_id, v.var)), v + if v.scalar.none_type: + return _core_cond.Operand(scalar=v.scalar), None return _core_cond.Operand(primitive=v.scalar.primitive), None diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 0a8bb0817c..eeab06b174 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -85,9 +85,7 @@ def my_wf(in1: int, in2: int) -> int: return result -def get_primitive_val(prim: Optional[Primitive] = None) -> Any: - if prim is None: - return None +def get_primitive_val(prim: Primitive) -> Any: for value in [ prim.integer, prim.float_value, @@ -138,9 +136,10 @@ def __init__(self, lhs: Union["Promise", Any], op: ComparisonOps, rhs: Union["Pr self._lhs = lhs if lhs.is_ready: if lhs.val.scalar is None or lhs.val.scalar.primitive is None: - if lhs.val.scalar.union and lhs.val.scalar.union.value.scalar: - if lhs.val.scalar.union.value.scalar.primitive or lhs.val.scalar.union.value.scalar.none_type: - self._lhs = lhs.val.scalar.union.value + union = lhs.val.scalar.union + if union and union.value.scalar: + if union.value.scalar.primitive or union.value.scalar.none_type: + self._lhs = union.value else: raise ValueError("Only primitive values can be used in comparison") else: @@ -149,9 +148,10 @@ def __init__(self, lhs: Union["Promise", Any], op: ComparisonOps, rhs: Union["Pr self._rhs = rhs if rhs.is_ready: if rhs.val.scalar is None or rhs.val.scalar.primitive is None: - if rhs.val.scalar.union and rhs.val.scalar.union.value.scalar: - if rhs.val.scalar.union.value.scalar.primitive or rhs.val.scalar.union.value.scalar.none_type: - self._rhs = rhs.val.scalar.union.value + union = rhs.val.scalar.union + if union and union.value.scalar: + if union.value.scalar.primitive or union.value.scalar.none_type: + self._rhs = union.value else: raise ValueError("Only primitive values can be used in comparison") else: @@ -176,6 +176,8 @@ def op(self) -> ComparisonOps: def eval(self) -> bool: if isinstance(self.lhs, Promise): lhs = self.lhs.eval() + elif self.lhs.scalar.none_type: + lhs = None else: lhs = get_primitive_val(self.lhs.scalar.primitive) @@ -366,10 +368,10 @@ def is_(self, v: bool) -> ComparisonExpression: def is_false(self) -> ComparisonExpression: return self.is_(False) - def is_true(self): + def is_true(self) -> ComparisonExpression: return self.is_(True) - def is_none(self): + def is_none(self) -> ComparisonExpression: return self.is_(None) def __eq__(self, other) -> ComparisonExpression: # type: ignore diff --git a/flytekit/models/core/condition.py b/flytekit/models/core/condition.py index 845b3b4f79..fc87497f35 100644 --- a/flytekit/models/core/condition.py +++ b/flytekit/models/core/condition.py @@ -134,7 +134,7 @@ def from_flyte_idl(cls, pb2_object): class Operand(_common.FlyteIdlEntity): - def __init__(self, primitive=None, var=None): + def __init__(self, primitive=None, var=None, scalar=None): """ Defines an operand to a comparison expression. :param flytekit.models.literals.Primitive primitive: @@ -143,6 +143,7 @@ def __init__(self, primitive=None, var=None): self._primitive = primitive self._var = var + self._scalar = scalar @property def primitive(self): @@ -160,6 +161,14 @@ def var(self): return self._var + @property + def scalar(self): + """ + :rtype: flytekit.models.literals.Scalar + """ + + return self._scalar + def to_flyte_idl(self): """ :rtype: flyteidl.core.condition_pb2.Operand @@ -167,6 +176,7 @@ def to_flyte_idl(self): return _condition.Operand( primitive=self.primitive.to_flyte_idl() if self.primitive else None, var=self.var if self.var else None, + scalar=self.scalar.to_flyte_idl() if self.scalar else None, ) @classmethod @@ -176,6 +186,7 @@ def from_flyte_idl(cls, pb2_object): if pb2_object.HasField("primitive") else None, var=pb2_object.var if pb2_object.HasField("var") else None, + scalar=pb2_object.scalar if pb2_object.HasField("scalar") else None, ) From 9664523bd635cf8b53a8d656ae1719d45aac35eb Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 21 Jul 2023 12:29:28 -0700 Subject: [PATCH 3/7] wip Signed-off-by: Kevin Su --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 9aefaf91d6..2e419c389c 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ }, install_requires=[ "googleapis-common-protos>=1.57", - "flyteidl>=1.5.10", + # "flyteidl>=1.5.10", "wheel>=0.30.0,<1.0.0", "pandas>=1.0.0,<2.0.0", "pyarrow>=4.0.0,<11.0.0", From c8df21d076b7bc231fb6197187e4cf4c0059019f Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 21 Jul 2023 14:40:14 -0700 Subject: [PATCH 4/7] add Signed-off-by: Kevin Su --- flytekit/core/promise.py | 2 +- flytekit/models/core/condition.py | 2 +- tests/flytekit/unit/core/test_conditions.py | 19 ++++++ .../unit/models/core/test_workflow.py | 67 +++++++++++++++++++ 4 files changed, 88 insertions(+), 2 deletions(-) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index eeab06b174..cb1062db9a 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -372,7 +372,7 @@ def is_true(self) -> ComparisonExpression: return self.is_(True) def is_none(self) -> ComparisonExpression: - return self.is_(None) + return ComparisonExpression(self, ComparisonOps.EQ, None) def __eq__(self, other) -> ComparisonExpression: # type: ignore return ComparisonExpression(self, ComparisonOps.EQ, other) diff --git a/flytekit/models/core/condition.py b/flytekit/models/core/condition.py index fc87497f35..ab63232cb7 100644 --- a/flytekit/models/core/condition.py +++ b/flytekit/models/core/condition.py @@ -186,7 +186,7 @@ def from_flyte_idl(cls, pb2_object): if pb2_object.HasField("primitive") else None, var=pb2_object.var if pb2_object.HasField("var") else None, - scalar=pb2_object.scalar if pb2_object.HasField("scalar") else None, + scalar=_literals.Scalar.from_flyte_idl(pb2_object.scalar) if pb2_object.HasField("scalar") else None, ) diff --git a/tests/flytekit/unit/core/test_conditions.py b/tests/flytekit/unit/core/test_conditions.py index 7b0b292baa..0da2467109 100644 --- a/tests/flytekit/unit/core/test_conditions.py +++ b/tests/flytekit/unit/core/test_conditions.py @@ -194,6 +194,25 @@ def decompose() -> int: assert decompose() == 20 +def test_condition_is_none(): + @task + def return_true() -> typing.Optional[None]: + return None + + @workflow + def failed() -> int: + return 10 + + @workflow + def success() -> int: + return 20 + + @workflow + def decompose_unary() -> int: + result = return_true() + return conditional("test").if_(result.is_none()).then(success()).else_().then(failed()) + + def test_subworkflow_condition_serialization(): """Test that subworkflows are correctly extracted from serialized workflows with condiationals.""" diff --git a/tests/flytekit/unit/models/core/test_workflow.py b/tests/flytekit/unit/models/core/test_workflow.py index de83f66f78..6775d58940 100644 --- a/tests/flytekit/unit/models/core/test_workflow.py +++ b/tests/flytekit/unit/models/core/test_workflow.py @@ -228,6 +228,73 @@ def test_branch_node(): assert bn.if_else.case.then_node == obj +def test_branch_node_with_none(): + nm = _get_sample_node_metadata() + task = _workflow.TaskNode(reference_id=_generic_id) + bd = _literals.BindingData(scalar=_literals.Scalar(none_type=_literals.Void())) + lt = _literals.Literal(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=99))) + bd2 = _literals.BindingData( + scalar=_literals.Scalar( + union=_literals.Union(value=lt, stored_type=_types.LiteralType(_types.SimpleType.INTEGER)) + ) + ) + binding = _literals.Binding(var="myvar", binding=bd) + binding2 = _literals.Binding(var="myothervar", binding=bd2) + + obj = _workflow.Node( + id="some:node:id", + metadata=nm, + inputs=[binding, binding2], + upstream_node_ids=[], + output_aliases=[], + task_node=task, + ) + + bn = _workflow.BranchNode( + _workflow.IfElseBlock( + case=_workflow.IfBlock( + condition=_condition.BooleanExpression( + comparison=_condition.ComparisonExpression( + _condition.ComparisonExpression.Operator.EQ, + _condition.Operand(scalar=_literals.Scalar(none_type=_literals.Void())), + _condition.Operand(primitive=_literals.Primitive(integer=2)), + ) + ), + then_node=obj, + ), + other=[ + _workflow.IfBlock( + condition=_condition.BooleanExpression( + conjunction=_condition.ConjunctionExpression( + _condition.ConjunctionExpression.LogicalOperator.AND, + _condition.BooleanExpression( + comparison=_condition.ComparisonExpression( + _condition.ComparisonExpression.Operator.EQ, + _condition.Operand(scalar=_literals.Scalar(none_type=_literals.Void())), + _condition.Operand(primitive=_literals.Primitive(integer=2)), + ) + ), + _condition.BooleanExpression( + comparison=_condition.ComparisonExpression( + _condition.ComparisonExpression.Operator.EQ, + _condition.Operand(scalar=_literals.Scalar(none_type=_literals.Void())), + _condition.Operand(primitive=_literals.Primitive(integer=2)), + ) + ), + ) + ), + then_node=obj, + ) + ], + else_node=obj, + ) + ) + + bn2 = _workflow.BranchNode.from_flyte_idl(bn.to_flyte_idl()) + assert bn == bn2 + assert bn.if_else.case.then_node == obj + + def test_task_node_overrides(): overrides = _workflow.TaskNodeOverrides( Resources( From d042877657ee173d597ad2cd00cf8c7ed89355fe Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 17 Aug 2023 09:47:49 -0700 Subject: [PATCH 5/7] update idl Signed-off-by: Kevin Su --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 2e419c389c..60f9aab888 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ }, install_requires=[ "googleapis-common-protos>=1.57", - # "flyteidl>=1.5.10", + "flyteidl>=1.5.16", "wheel>=0.30.0,<1.0.0", "pandas>=1.0.0,<2.0.0", "pyarrow>=4.0.0,<11.0.0", From 2a344df55a41a45103640f9c1f5ee67e9c9af4b2 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 17 Aug 2023 09:54:39 -0700 Subject: [PATCH 6/7] update idl Signed-off-by: Kevin Su --- doc-requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc-requirements.txt b/doc-requirements.txt index 6ca7fbc1ee..10479245ab 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -249,7 +249,7 @@ flask==2.3.2 # via mlflow flatbuffers==23.5.26 # via tensorflow -flyteidl==1.5.14 +flyteidl==1.5.16 # via flytekit fonttools==4.41.1 # via matplotlib From 2eddcf6b42f79d4f6caf6d5976b0af3597f33154 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Mon, 28 Aug 2023 13:34:02 -0700 Subject: [PATCH 7/7] Update comment Signed-off-by: Kevin Su --- flytekit/models/core/condition.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/flytekit/models/core/condition.py b/flytekit/models/core/condition.py index ab63232cb7..27e0bc505b 100644 --- a/flytekit/models/core/condition.py +++ b/flytekit/models/core/condition.py @@ -137,8 +137,9 @@ class Operand(_common.FlyteIdlEntity): def __init__(self, primitive=None, var=None, scalar=None): """ Defines an operand to a comparison expression. - :param flytekit.models.literals.Primitive primitive: - :param Text var: + :param flytekit.models.literals.Primitive primitive: A primitive value + :param Text var: A variable name + :param flytekit.models.literals.Scalar scalar: A scalar value """ self._primitive = primitive