From 87bfcc79332ab00533cfbe38b3ef9f7ffe96e21a Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Mon, 29 Jul 2024 16:02:46 -0700 Subject: [PATCH] [Artifacts/Elastic] Skip partitions (#2620) Signed-off-by: Yee Hing Tong Signed-off-by: mao3267 --- flytekit/core/artifact.py | 2 ++ .../flytekitplugins/kfpytorch/task.py | 4 ++-- tests/flytekit/unit/core/test_artifacts.py | 12 ++++++++++++ 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/flytekit/core/artifact.py b/flytekit/core/artifact.py index 954151504f..fba84187b3 100644 --- a/flytekit/core/artifact.py +++ b/flytekit/core/artifact.py @@ -318,6 +318,8 @@ def set_reference_artifact(self, artifact: Artifact): p.reference_artifact = artifact def __getattr__(self, item): + if item == "partitions" or item == "_partitions": + raise AttributeError("Partitions in an uninitialized state, skipping partitions") if self.partitions and item in self.partitions: return self.partitions[item] raise AttributeError(f"Partition {item} not found in {self}") diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py index ad9b5368b0..3384c9cacc 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py @@ -241,7 +241,7 @@ class ElasticWorkerResult(NamedTuple): return_value: Any decks: List[flytekit.Deck] - om: OutputMetadata + om: Optional[OutputMetadata] = None def spawn_helper( @@ -435,7 +435,7 @@ def fn_partial(): if isinstance(e, FlyteRecoverableException): create_recoverable_error_file() raise - return ElasticWorkerResult(return_value=return_val, decks=flytekit.current_context().decks) + return ElasticWorkerResult(return_value=return_val, decks=flytekit.current_context().decks, om=None) launcher_target_func = fn_partial launcher_args = () diff --git a/tests/flytekit/unit/core/test_artifacts.py b/tests/flytekit/unit/core/test_artifacts.py index 2eccdf52d5..9437d16add 100644 --- a/tests/flytekit/unit/core/test_artifacts.py +++ b/tests/flytekit/unit/core/test_artifacts.py @@ -619,3 +619,15 @@ def test_lims(): # test an artifact with 11 partition keys with pytest.raises(ValueError): Artifact(name="test artifact", time_partitioned=True, partition_keys=[f"key_{i}" for i in range(11)]) + + +def test_cloudpickle(): + a1_b = Artifact(name="my_data", partition_keys=["b"]) + + spec = a1_b(b="my_b_value") + import cloudpickle + + d = cloudpickle.dumps(spec) + spec2 = cloudpickle.loads(d) + + assert spec2.partitions.b.value.static_value == "my_b_value"