From fc6f98c4c894b0bbfcff7c1fae0ef304a1e6687e Mon Sep 17 00:00:00 2001 From: Paul Dittamo <37558497+pvditt@users.noreply.github.com> Date: Wed, 4 Dec 2024 20:21:26 -0800 Subject: [PATCH 1/2] Set map task metadata only for subnode (#2979) * set metadata for subnode only Signed-off-by: Paul Dittamo * update unit test Signed-off-by: Paul Dittamo --------- Signed-off-by: Paul Dittamo Signed-off-by: Eduardo Apolinario --- flytekit/core/array_node_map_task.py | 5 +++++ flytekit/tools/translator.py | 2 +- .../unit/core/test_array_node_map_task.py | 17 ++++++++++++++--- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/flytekit/core/array_node_map_task.py b/flytekit/core/array_node_map_task.py index 94454f417b..256d5dae13 100644 --- a/flytekit/core/array_node_map_task.py +++ b/flytekit/core/array_node_map_task.py @@ -139,6 +139,11 @@ def python_interface(self): def construct_node_metadata(self) -> NodeMetadata: # TODO: add support for other Flyte entities + return NodeMetadata( + name=self.name, + ) + + def construct_sub_node_metadata(self) -> NodeMetadata: nm = super().construct_node_metadata() nm._name = self.name return nm diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index b357ae3385..e792dcc74d 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -659,7 +659,7 @@ def get_serializable_array_node_map_task( ) node = workflow_model.Node( id=entity.name, - metadata=entity.construct_node_metadata(), + metadata=entity.construct_sub_node_metadata(), inputs=node.bindings, upstream_node_ids=[], output_aliases=[], diff --git a/tests/flytekit/unit/core/test_array_node_map_task.py b/tests/flytekit/unit/core/test_array_node_map_task.py index fae81d1355..f716a6b5ef 100644 --- a/tests/flytekit/unit/core/test_array_node_map_task.py +++ b/tests/flytekit/unit/core/test_array_node_map_task.py @@ -1,4 +1,5 @@ import functools +from datetime import timedelta import os import typing from collections import OrderedDict @@ -377,7 +378,12 @@ def test_serialization_metadata2(serialization_settings): def t1(a: int) -> int: return a + 1 - arraynode_maptask = map_task(t1, metadata=TaskMetadata(retries=2, interruptible=True)) + arraynode_maptask = map_task( + t1, + min_success_ratio=0.9, + concurrency=10, + metadata=TaskMetadata(retries=2, interruptible=True, timeout=timedelta(seconds=10)) + ) assert arraynode_maptask.metadata.interruptible @workflow @@ -387,11 +393,16 @@ def wf(x: typing.List[int]): od = OrderedDict() wf_spec = get_serializable(od, serialization_settings, wf) - assert arraynode_maptask.construct_node_metadata().interruptible - assert wf_spec.template.nodes[0].metadata.interruptible + array_node = wf_spec.template.nodes[0] + assert array_node.metadata.timeout == timedelta() + assert array_node.array_node._min_success_ratio == 0.9 + assert array_node.array_node._parallelism == 10 + assert not array_node.array_node._is_original_sub_node_interface + assert array_node.array_node._execution_mode == _core_workflow.ArrayNode.MINIMAL_STATE task_spec = od[arraynode_maptask] assert task_spec.template.metadata.retries.retries == 2 assert task_spec.template.metadata.interruptible + assert task_spec.template.metadata.timeout == timedelta(seconds=10) def test_serialization_extended_resources(serialization_settings): From 1c95a4641d7297a2944adba38749fe397ce24616 Mon Sep 17 00:00:00 2001 From: Paul Dittamo <37558497+pvditt@users.noreply.github.com> Date: Wed, 4 Dec 2024 20:21:26 -0800 Subject: [PATCH 2/2] Set map task metadata only for subnode (#2979) * set metadata for subnode only Signed-off-by: Paul Dittamo * update unit test Signed-off-by: Paul Dittamo --------- Signed-off-by: Paul Dittamo Signed-off-by: Eduardo Apolinario --- tests/flytekit/unit/core/test_array_node_map_task.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/flytekit/unit/core/test_array_node_map_task.py b/tests/flytekit/unit/core/test_array_node_map_task.py index f716a6b5ef..2d7169eaba 100644 --- a/tests/flytekit/unit/core/test_array_node_map_task.py +++ b/tests/flytekit/unit/core/test_array_node_map_task.py @@ -393,12 +393,7 @@ def wf(x: typing.List[int]): od = OrderedDict() wf_spec = get_serializable(od, serialization_settings, wf) - array_node = wf_spec.template.nodes[0] - assert array_node.metadata.timeout == timedelta() - assert array_node.array_node._min_success_ratio == 0.9 - assert array_node.array_node._parallelism == 10 - assert not array_node.array_node._is_original_sub_node_interface - assert array_node.array_node._execution_mode == _core_workflow.ArrayNode.MINIMAL_STATE + assert wf_spec.template.nodes[0].metadata.timeout == timedelta() task_spec = od[arraynode_maptask] assert task_spec.template.metadata.retries.retries == 2 assert task_spec.template.metadata.interruptible