diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 2b21f7c40b..2cf8032a6f 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -258,7 +258,7 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr # The cache returns None iff the key does not exist in the cache if outputs_literal_map is None: logger.info("Cache miss, task will be executed now") - outputs_literal_map = self.dispatch_execute(ctx, input_literal_map) + outputs_literal_map = self.sandbox_execute(ctx, input_literal_map) # TODO: need `native_inputs` LocalTaskCache.set(self.name, self.metadata.cache_version, input_literal_map, outputs_literal_map) logger.info( @@ -268,10 +268,10 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr else: logger.info("Cache hit") else: - es = ctx.execution_state - b = es.user_space_params.with_task_sandbox() - ctx = ctx.current_context().with_execution_state(es.with_params(user_space_params=b.build())).build() - outputs_literal_map = self.dispatch_execute(ctx, input_literal_map) + # This code should mirror the call to `sandbox_execute` in the above cache case. + # Code is simpler with duplication and less metaprogramming, but introduces regressions + # if one is changed and not the other. + outputs_literal_map = self.sandbox_execute(ctx, input_literal_map) outputs_literals = outputs_literal_map.literals # TODO maybe this is the part that should be done for local execution, we pass the outputs to some special @@ -326,6 +326,19 @@ def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str] """ return None + def sandbox_execute( + self, + ctx: FlyteContext, + input_literal_map: _literal_models.LiteralMap, + ) -> _literal_models.LiteralMap: + """ + Call dispatch_execute, in the context of a local sandbox execution. Not invoked during runtime. + """ + es = ctx.execution_state + b = es.user_space_params.with_task_sandbox() + ctx = ctx.current_context().with_execution_state(es.with_params(user_space_params=b.build())).build() + return self.dispatch_execute(ctx, input_literal_map) + @abstractmethod def dispatch_execute( self, diff --git a/tests/flytekit/unit/core/test_checkpoint.py b/tests/flytekit/unit/core/test_checkpoint.py index f1dbbbd5ef..2add1b9e7d 100644 --- a/tests/flytekit/unit/core/test_checkpoint.py +++ b/tests/flytekit/unit/core/test_checkpoint.py @@ -4,6 +4,7 @@ import flytekit from flytekit.core.checkpointer import SyncCheckpoint +from flytekit.core.local_cache import LocalTaskCache def test_sync_checkpoint_write(tmpdir): @@ -123,5 +124,23 @@ def t1(n: int) -> int: return n + 1 +@flytekit.task(cache=True, cache_version="v0") +def t2(n: int) -> int: + ctx = flytekit.current_context() + cp = ctx.checkpoint + cp.write(bytes(n + 1)) + return n + 1 + + +@pytest.fixture(scope="function", autouse=True) +def setup(): + LocalTaskCache.initialize() + LocalTaskCache.clear() + + def test_checkpoint_task(): assert t1(n=5) == 6 + + +def test_checkpoint_cached_task(): + assert t2(n=5) == 6