Skip to content

Commit

Permalink
cache ignore portion (#2275)
Browse files Browse the repository at this point in the history
Signed-off-by: troychiu <[email protected]>
  • Loading branch information
troychiu authored Mar 18, 2024
1 parent 3f45131 commit c8ac276
Show file tree
Hide file tree
Showing 12 changed files with 99 additions and 12 deletions.
23 changes: 19 additions & 4 deletions flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ class TaskMetadata(object):
cache (bool): Indicates if caching should be enabled. See :std:ref:`Caching <cookbook:caching>`
cache_serialize (bool): Indicates if identical (ie. same inputs) instances of this task should be executed in serial when caching is enabled. See :std:ref:`Caching <cookbook:caching>`
cache_version (str): Version to be used for the cached value
cache_ignore_input_vars (Tuple[str, ...]): Input variables that should not be included when calculating hash for cache
interruptible (Optional[bool]): Indicates that this task can be interrupted and/or scheduled on nodes with
lower QoS guarantees that can include pre-emption. This can reduce the monetary cost executions incur at the
cost of performance penalties due to potential interruptions
Expand All @@ -112,6 +113,7 @@ class TaskMetadata(object):
cache: bool = False
cache_serialize: bool = False
cache_version: str = ""
cache_ignore_input_vars: Tuple[str, ...] = ()
interruptible: Optional[bool] = None
deprecated: str = ""
retries: int = 0
Expand All @@ -128,6 +130,10 @@ def __post_init__(self):
raise ValueError("Caching is enabled ``cache=True`` but ``cache_version`` is not set.")
if self.cache_serialize and not self.cache:
raise ValueError("Cache serialize is enabled ``cache_serialize=True`` but ``cache`` is not enabled.")
if self.cache_ignore_input_vars and not self.cache:
raise ValueError(
f"Cache ignore input vars are specified ``cache_ignore_input_vars={self.cache_ignore_input_vars}`` but ``cache`` is not enabled."
)

@property
def retry_strategy(self) -> _literal_models.RetryStrategy:
Expand All @@ -151,6 +157,7 @@ def to_taskmetadata_model(self) -> _task_model.TaskMetadata:
deprecated_error_message=self.deprecated,
cache_serializable=self.cache_serialize,
pod_template_name=self.pod_template_name,
cache_ignore_input_vars=self.cache_ignore_input_vars,
)


Expand Down Expand Up @@ -281,13 +288,15 @@ def local_execute(
# TODO: how to get a nice `native_inputs` here?
logger.info(
f"Checking cache for task named {self.name}, cache version {self.metadata.cache_version} "
f"and inputs: {input_literal_map}"
f", inputs: {input_literal_map}, and ignore input vars: {self.metadata.cache_ignore_input_vars}"
)
if local_config.cache_overwrite:
outputs_literal_map = None
logger.info("Cache overwrite, task will be executed now")
else:
outputs_literal_map = LocalTaskCache.get(self.name, self.metadata.cache_version, input_literal_map)
outputs_literal_map = LocalTaskCache.get(
self.name, self.metadata.cache_version, input_literal_map, self.metadata.cache_ignore_input_vars
)
# The cache returns None iff the key does not exist in the cache
if outputs_literal_map is None:
logger.info("Cache miss, task will be executed now")
Expand All @@ -296,10 +305,16 @@ def local_execute(
if outputs_literal_map is None:
outputs_literal_map = self.sandbox_execute(ctx, input_literal_map)
# TODO: need `native_inputs`
LocalTaskCache.set(self.name, self.metadata.cache_version, input_literal_map, outputs_literal_map)
LocalTaskCache.set(
self.name,
self.metadata.cache_version,
input_literal_map,
self.metadata.cache_ignore_input_vars,
outputs_literal_map,
)
logger.info(
f"Cache set for task named {self.name}, cache version {self.metadata.cache_version} "
f"and inputs: {input_literal_map}"
f", inputs: {input_literal_map}, and ignore input vars: {self.metadata.cache_ignore_input_vars}"
)
else:
# This code should mirror the call to `sandbox_execute` in the above cache case.
Expand Down
28 changes: 22 additions & 6 deletions flytekit/core/local_cache.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Tuple

from diskcache import Cache

Expand Down Expand Up @@ -28,10 +28,14 @@ def _recursive_hash_placement(literal: Literal) -> Literal:
return literal


def _calculate_cache_key(task_name: str, cache_version: str, input_literal_map: LiteralMap) -> str:
def _calculate_cache_key(
task_name: str, cache_version: str, input_literal_map: LiteralMap, cache_ignore_input_vars: Tuple[str, ...] = ()
) -> str:
# Traverse the literals and replace the literal with a new literal that only contains the hash
literal_map_overridden = {}
for key, literal in input_literal_map.literals.items():
if key in cache_ignore_input_vars:
continue
literal_map_overridden[key] = _recursive_hash_placement(literal)

# Generate a stable representation of the underlying protobuf by passing `deterministic=True` to the
Expand Down Expand Up @@ -61,13 +65,25 @@ def clear():
LocalTaskCache._cache.clear()

@staticmethod
def get(task_name: str, cache_version: str, input_literal_map: LiteralMap) -> Optional[LiteralMap]:
def get(
task_name: str, cache_version: str, input_literal_map: LiteralMap, cache_ignore_input_vars: Tuple[str, ...]
) -> Optional[LiteralMap]:
if not LocalTaskCache._initialized:
LocalTaskCache.initialize()
return LocalTaskCache._cache.get(_calculate_cache_key(task_name, cache_version, input_literal_map))
return LocalTaskCache._cache.get(
_calculate_cache_key(task_name, cache_version, input_literal_map, cache_ignore_input_vars)
)

@staticmethod
def set(task_name: str, cache_version: str, input_literal_map: LiteralMap, value: LiteralMap) -> None:
def set(
task_name: str,
cache_version: str,
input_literal_map: LiteralMap,
cache_ignore_input_vars: Tuple[str, ...],
value: LiteralMap,
) -> None:
if not LocalTaskCache._initialized:
LocalTaskCache.initialize()
LocalTaskCache._cache.set(_calculate_cache_key(task_name, cache_version, input_literal_map), value)
LocalTaskCache._cache.set(
_calculate_cache_key(task_name, cache_version, input_literal_map, cache_ignore_input_vars), value
)
7 changes: 6 additions & 1 deletion flytekit/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import datetime as _datetime
from functools import update_wrapper
from typing import Any, Callable, Dict, Iterable, List, Optional, Type, TypeVar, Union, overload
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, overload

from flytekit.core import launch_plan as _annotated_launchplan
from flytekit.core import workflow as _annotated_workflow
Expand Down Expand Up @@ -91,6 +91,7 @@ def task(
cache: bool = ...,
cache_serialize: bool = ...,
cache_version: str = ...,
cache_ignore_input_vars: Tuple[str, ...] = ...,
retries: int = ...,
interruptible: Optional[bool] = ...,
deprecated: str = ...,
Expand Down Expand Up @@ -122,6 +123,7 @@ def task(
cache: bool = ...,
cache_serialize: bool = ...,
cache_version: str = ...,
cache_ignore_input_vars: Tuple[str, ...] = ...,
retries: int = ...,
interruptible: Optional[bool] = ...,
deprecated: str = ...,
Expand Down Expand Up @@ -152,6 +154,7 @@ def task(
cache: bool = False,
cache_serialize: bool = False,
cache_version: str = "",
cache_ignore_input_vars: Tuple[str, ...] = (),
retries: int = 0,
interruptible: Optional[bool] = None,
deprecated: str = "",
Expand Down Expand Up @@ -213,6 +216,7 @@ def my_task(x: int, y: typing.Dict[str, str]) -> str:
:param cache_version: Cache version to use. Changes to the task signature will automatically trigger a cache miss,
but you can always manually update this field as well to force a cache miss. You should also manually bump
this version if the function body/business logic has changed, but the signature hasn't.
:param cache_ignore_input_vars: Input variables that should not be included when calculating hash for cache.
:param retries: Number of times to retry this task during a workflow execution.
:param interruptible: [Optional] Boolean that indicates that this task can be interrupted and/or scheduled on nodes
with lower QoS guarantees. This will directly reduce the `$`/`execution cost` associated,
Expand Down Expand Up @@ -295,6 +299,7 @@ def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]:
cache=cache,
cache_serialize=cache_serialize,
cache_version=cache_version,
cache_ignore_input_vars=cache_ignore_input_vars,
retries=retries,
interruptible=interruptible,
deprecated=deprecated,
Expand Down
13 changes: 13 additions & 0 deletions flytekit/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ def __init__(
deprecated_error_message,
cache_serializable,
pod_template_name,
cache_ignore_input_vars,
):
"""
Information needed at runtime to determine behavior such as whether or not outputs are discoverable, timeouts,
Expand All @@ -197,6 +198,7 @@ def __init__(
:param bool cache_serializable: Whether or not caching operations are executed in serial. This means only a
single instance over identical inputs is executed, other concurrent executions wait for the cached results.
:param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
:param cache_ignore_input_vars: Input variables that should not be included when calculating hash for cache.
"""
self._discoverable = discoverable
self._runtime = runtime
Expand All @@ -207,6 +209,7 @@ def __init__(
self._deprecated_error_message = deprecated_error_message
self._cache_serializable = cache_serializable
self._pod_template_name = pod_template_name
self._cache_ignore_input_vars = cache_ignore_input_vars

@property
def discoverable(self):
Expand Down Expand Up @@ -284,6 +287,14 @@ def pod_template_name(self):
"""
return self._pod_template_name

@property
def cache_ignore_input_vars(self):
"""
Input variables that should not be included when calculating hash for cache.
:rtype: tuple[Text]
"""
return self._cache_ignore_input_vars

def to_flyte_idl(self):
"""
:rtype: flyteidl.admin.task_pb2.TaskMetadata
Expand All @@ -297,6 +308,7 @@ def to_flyte_idl(self):
deprecated_error_message=self.deprecated_error_message,
cache_serializable=self.cache_serializable,
pod_template_name=self.pod_template_name,
cache_ignore_input_vars=self.cache_ignore_input_vars,
)
if self.timeout:
tm.timeout.FromTimedelta(self.timeout)
Expand All @@ -318,6 +330,7 @@ def from_flyte_idl(cls, pb2_object):
deprecated_error_message=pb2_object.deprecated_error_message,
cache_serializable=pb2_object.cache_serializable,
pod_template_name=pb2_object.pod_template_name,
cache_ignore_input_vars=pb2_object.cache_ignore_input_vars,
)


Expand Down
1 change: 1 addition & 0 deletions plugins/flytekit-bigquery/tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(self):
"This is deprecated!",
True,
"A",
(),
)
task_config = {
"Location": "us-central1",
Expand Down
1 change: 1 addition & 0 deletions plugins/flytekit-openai/tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ async def test_chatgpt_agent():
"This is deprecated!",
True,
"A",
(),
)
tmp = TaskTemplate(
id=task_id,
Expand Down
1 change: 1 addition & 0 deletions plugins/flytekit-snowflake/tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ async def test_snowflake_agent(mock_get_private_key):
"This is deprecated!",
True,
"A",
(),
)

task_config = {
Expand Down
1 change: 1 addition & 0 deletions plugins/flytekit-spark/tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ async def test_databricks_agent():
"This is deprecated!",
True,
"A",
(),
)
task_config = {
"sparkConf": {
Expand Down
4 changes: 3 additions & 1 deletion tests/flytekit/common/parameterizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,9 @@
deprecated,
cache_serializable,
pod_template_name,
cache_ignore_input_vars,
)
for discoverable, runtime_metadata, timeout, retry_strategy, interruptible, discovery_version, deprecated, cache_serializable, pod_template_name in product(
for discoverable, runtime_metadata, timeout, retry_strategy, interruptible, discovery_version, deprecated, cache_serializable, pod_template_name, cache_ignore_input_vars in product(
[True, False],
LIST_OF_RUNTIME_METADATA,
[timedelta(days=i) for i in range(3)],
Expand All @@ -135,6 +136,7 @@
["deprecated"],
[True, False],
["A", "B"],
[()],
)
]

Expand Down
28 changes: 28 additions & 0 deletions tests/flytekit/unit/core/test_local_cache.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import re
import sys
import typing
from dataclasses import dataclass
Expand Down Expand Up @@ -597,3 +598,30 @@ def t2(n: int) -> int:
@pytest.mark.serial
def test_checkpoint_cached_task():
assert t2(n=5) == 6


def test_cache_ignore_input_vars():
@task(cache=True, cache_version="v1", cache_ignore_input_vars=["a"])
def add(a: int, b: int) -> int:
return a + b

@workflow
def add_wf(a: int, b: int) -> int:
return add(a=a, b=b)

assert add_wf(a=10, b=5) == 15
assert add_wf(a=20, b=5) == 15 # since a is ignored, this line will hit cache of a=10, b=5
assert add_wf(a=20, b=8) == 28


def test_set_cache_ignore_input_vars_without_set_cache():
with pytest.raises(
ValueError,
match=re.escape(
"Cache ignore input vars are specified ``cache_ignore_input_vars=['a']`` but ``cache`` is not enabled."
),
):

@task(cache_ignore_input_vars=["a"])
def add(a: int, b: int) -> int:
return a + b
3 changes: 3 additions & 0 deletions tests/flytekit/unit/models/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def test_task_metadata():
"This is deprecated!",
True,
"A",
(),
)

assert obj.discoverable is True
Expand Down Expand Up @@ -142,6 +143,7 @@ def test_task_spec():
"This is deprecated!",
True,
"A",
(),
)

int_type = types.LiteralType(types.SimpleType.INTEGER)
Expand Down Expand Up @@ -202,6 +204,7 @@ def test_task_template_k8s_pod_target():
"deprecated",
False,
"A",
(),
),
interface_models.TypedInterface(
# inputs
Expand Down
1 change: 1 addition & 0 deletions tests/flytekit/unit/models/test_workflow_closure.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def test_workflow_closure():
"This is deprecated!",
True,
"A",
(),
)

cpu_resource = _task.Resources.ResourceEntry(_task.Resources.ResourceName.CPU, "1")
Expand Down

0 comments on commit c8ac276

Please sign in to comment.