Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add is none function #1757

Merged
merged 8 commits into from
Aug 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ flask==2.3.2
# via mlflow
flatbuffers==23.5.26
# via tensorflow
flyteidl==1.5.14
flyteidl==1.5.16
# via flytekit
fonttools==4.41.1
# via matplotlib
Expand Down
2 changes: 2 additions & 0 deletions flytekit/core/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,8 @@ def transform_to_conj_expr(
def transform_to_operand(v: Union[Promise, Literal]) -> Tuple[_core_cond.Operand, Optional[Promise]]:
if isinstance(v, Promise):
return _core_cond.Operand(var=create_branch_node_promise_var(v.ref.node_id, v.var)), v
if v.scalar.none_type:
return _core_cond.Operand(scalar=v.scalar), None
return _core_cond.Operand(primitive=v.scalar.primitive), None


Expand Down
27 changes: 24 additions & 3 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,26 @@ def __init__(self, lhs: Union["Promise", Any], op: ComparisonOps, rhs: Union["Pr
self._lhs = lhs
if lhs.is_ready:
if lhs.val.scalar is None or lhs.val.scalar.primitive is None:
raise ValueError("Only primitive values can be used in comparison")
union = lhs.val.scalar.union
if union and union.value.scalar:
if union.value.scalar.primitive or union.value.scalar.none_type:
self._lhs = union.value
else:
raise ValueError("Only primitive values can be used in comparison")
else:
raise ValueError("Only primitive values can be used in comparison")
if isinstance(rhs, Promise):
self._rhs = rhs
if rhs.is_ready:
if rhs.val.scalar is None or rhs.val.scalar.primitive is None:
raise ValueError("Only primitive values can be used in comparison")
union = rhs.val.scalar.union
if union and union.value.scalar:
if union.value.scalar.primitive or union.value.scalar.none_type:
self._rhs = union.value
else:
raise ValueError("Only primitive values can be used in comparison")
else:
raise ValueError("Only primitive values can be used in comparison")
if self._lhs is None:
self._lhs = type_engine.TypeEngine.to_literal(FlyteContextManager.current_context(), lhs, type(lhs), None)
if self._rhs is None:
Expand All @@ -162,11 +176,15 @@ def op(self) -> ComparisonOps:
def eval(self) -> bool:
if isinstance(self.lhs, Promise):
lhs = self.lhs.eval()
elif self.lhs.scalar.none_type:
lhs = None
else:
lhs = get_primitive_val(self.lhs.scalar.primitive)

if isinstance(self.rhs, Promise):
rhs = self.rhs.eval()
elif self.rhs.scalar.none_type:
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
rhs = None
else:
rhs = get_primitive_val(self.rhs.scalar.primitive)

Expand Down Expand Up @@ -350,9 +368,12 @@ def is_(self, v: bool) -> ComparisonExpression:
def is_false(self) -> ComparisonExpression:
return self.is_(False)

def is_true(self):
def is_true(self) -> ComparisonExpression:
return self.is_(True)

def is_none(self) -> ComparisonExpression:
return ComparisonExpression(self, ComparisonOps.EQ, None)

def __eq__(self, other) -> ComparisonExpression: # type: ignore
return ComparisonExpression(self, ComparisonOps.EQ, other)

Expand Down
2 changes: 1 addition & 1 deletion flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,7 +862,7 @@ def to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type
"actual attribute that you want to use. For example, in NamedTuple('OP', x=int) then"
"return v.x, instead of v, even if this has a single element"
)
if python_val is None and expected.union_type is None:
if python_val is None and expected and expected.union_type is None:
raise TypeTransformerFailedError(f"Python value cannot be None, expected {python_type}/{expected}")
transformer = cls.get_transformer(python_type)
if transformer.type_assertions_enabled:
Expand Down
18 changes: 15 additions & 3 deletions flytekit/models/core/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,15 +134,17 @@ def from_flyte_idl(cls, pb2_object):


class Operand(_common.FlyteIdlEntity):
def __init__(self, primitive=None, var=None):
def __init__(self, primitive=None, var=None, scalar=None):
"""
Defines an operand to a comparison expression.
:param flytekit.models.literals.Primitive primitive:
:param Text var:
:param flytekit.models.literals.Primitive primitive: A primitive value
:param Text var: A variable name
:param flytekit.models.literals.Scalar scalar: A scalar value
"""

self._primitive = primitive
self._var = var
self._scalar = scalar

@property
def primitive(self):
Expand All @@ -160,13 +162,22 @@ def var(self):

return self._var

@property
def scalar(self):
"""
:rtype: flytekit.models.literals.Scalar
"""

return self._scalar

def to_flyte_idl(self):
"""
:rtype: flyteidl.core.condition_pb2.Operand
"""
return _condition.Operand(
primitive=self.primitive.to_flyte_idl() if self.primitive else None,
var=self.var if self.var else None,
scalar=self.scalar.to_flyte_idl() if self.scalar else None,
)

@classmethod
Expand All @@ -176,6 +187,7 @@ def from_flyte_idl(cls, pb2_object):
if pb2_object.HasField("primitive")
else None,
var=pb2_object.var if pb2_object.HasField("var") else None,
scalar=_literals.Scalar.from_flyte_idl(pb2_object.scalar) if pb2_object.HasField("scalar") else None,
)


Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
},
install_requires=[
"googleapis-common-protos>=1.57",
"flyteidl>=1.5.14",
"flyteidl>=1.5.16",
"wheel>=0.30.0,<1.0.0",
"pandas>=1.0.0,<2.0.0",
"pyarrow>=4.0.0,<11.0.0",
Expand Down
19 changes: 19 additions & 0 deletions tests/flytekit/unit/core/test_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,25 @@ def decompose() -> int:
assert decompose() == 20


def test_condition_is_none():
@task
def return_true() -> typing.Optional[None]:
return None

@workflow
def failed() -> int:
return 10

@workflow
def success() -> int:
return 20

@workflow
def decompose_unary() -> int:
result = return_true()
return conditional("test").if_(result.is_none()).then(success()).else_().then(failed())


def test_subworkflow_condition_serialization():
"""Test that subworkflows are correctly extracted from serialized workflows with condiationals."""

Expand Down
67 changes: 67 additions & 0 deletions tests/flytekit/unit/models/core/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,73 @@ def test_branch_node():
assert bn.if_else.case.then_node == obj


def test_branch_node_with_none():
nm = _get_sample_node_metadata()
task = _workflow.TaskNode(reference_id=_generic_id)
bd = _literals.BindingData(scalar=_literals.Scalar(none_type=_literals.Void()))
lt = _literals.Literal(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=99)))
bd2 = _literals.BindingData(
scalar=_literals.Scalar(
union=_literals.Union(value=lt, stored_type=_types.LiteralType(_types.SimpleType.INTEGER))
)
)
binding = _literals.Binding(var="myvar", binding=bd)
binding2 = _literals.Binding(var="myothervar", binding=bd2)

obj = _workflow.Node(
id="some:node:id",
metadata=nm,
inputs=[binding, binding2],
upstream_node_ids=[],
output_aliases=[],
task_node=task,
)

bn = _workflow.BranchNode(
_workflow.IfElseBlock(
case=_workflow.IfBlock(
condition=_condition.BooleanExpression(
comparison=_condition.ComparisonExpression(
_condition.ComparisonExpression.Operator.EQ,
_condition.Operand(scalar=_literals.Scalar(none_type=_literals.Void())),
_condition.Operand(primitive=_literals.Primitive(integer=2)),
)
),
then_node=obj,
),
other=[
_workflow.IfBlock(
condition=_condition.BooleanExpression(
conjunction=_condition.ConjunctionExpression(
_condition.ConjunctionExpression.LogicalOperator.AND,
_condition.BooleanExpression(
comparison=_condition.ComparisonExpression(
_condition.ComparisonExpression.Operator.EQ,
_condition.Operand(scalar=_literals.Scalar(none_type=_literals.Void())),
_condition.Operand(primitive=_literals.Primitive(integer=2)),
)
),
_condition.BooleanExpression(
comparison=_condition.ComparisonExpression(
_condition.ComparisonExpression.Operator.EQ,
_condition.Operand(scalar=_literals.Scalar(none_type=_literals.Void())),
_condition.Operand(primitive=_literals.Primitive(integer=2)),
)
),
)
),
then_node=obj,
)
],
else_node=obj,
)
)

bn2 = _workflow.BranchNode.from_flyte_idl(bn.to_flyte_idl())
assert bn == bn2
assert bn.if_else.case.then_node == obj


def test_task_node_overrides():
overrides = _workflow.TaskNodeOverrides(
Resources(
Expand Down