Skip to content

Commit

Permalink
merge master
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor committed Feb 27, 2024
2 parents 240c7e0 + 95ea92d commit 19e8e56
Show file tree
Hide file tree
Showing 25 changed files with 300 additions and 39 deletions.
3 changes: 0 additions & 3 deletions .github/workflows/pythonbuild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@ on:

env:
FLYTE_SDK_LOGGING_LEVEL: 10 # debug
# --dist loadscope guarantees that tests in the same module are picked up by the same worker (i.e. run serially in a worker),
# as per https://pytest-xdist.readthedocs.io/en/stable/distribution.html.
PYTEST_OPTS: -n2 --dist loadscope

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
Expand Down
16 changes: 10 additions & 6 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ export REPOSITORY=flytekit

PIP_COMPILE = pip-compile --upgrade --verbose --resolver=backtracking
MOCK_FLYTE_REPO=tests/flytekit/integration/remote/mock_flyte_repo/workflows
PYTEST_OPTS ?= -n auto --dist=loadscope
PYTEST = pytest ${PYTEST_OPTS}
PYTEST_OPTS ?= -n auto --dist=loadfile
PYTEST_AND_OPTS = pytest ${PYTEST_OPTS}
PYTEST = pytest

.SILENT: help
.PHONY: help
Expand Down Expand Up @@ -62,19 +63,22 @@ unit_test_extras_codecov:
unit_test:
# Skip all extra tests and run them with the necessary env var set so that a working (albeit slower)
# library is used to serialize/deserialize protobufs.
$(PYTEST) -m "not sandbox_test" tests/flytekit/unit/ --ignore=tests/flytekit/unit/extras/ --ignore=tests/flytekit/unit/models ${CODECOV_OPTS}
$(PYTEST_AND_OPTS) -m "not (serial or sandbox_test)" tests/flytekit/unit/ --ignore=tests/flytekit/unit/extras/ --ignore=tests/flytekit/unit/models ${CODECOV_OPTS}
# Run serial tests without any parallelism
$(PYTEST) -m "serial" tests/flytekit/unit/ --ignore=tests/flytekit/unit/extras/ --ignore=tests/flytekit/unit/models ${CODECOV_OPTS}


.PHONY: unit_test_extras
unit_test_extras:
PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python $(PYTEST) tests/flytekit/unit/extras ${CODECOV_OPTS}
PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python $(PYTEST_AND_OPTS) tests/flytekit/unit/extras ${CODECOV_OPTS}

.PHONY: test_serialization_codecov
test_serialization_codecov:
$(MAKE) CODECOV_OPTS="--cov=./ --cov-report=xml --cov-append" test_serialization

.PHONY: test_serialization
test_serialization:
$(PYTEST) tests/flytekit/unit/models ${CODECOV_OPTS}
$(PYTEST_AND_OPTS) tests/flytekit/unit/models ${CODECOV_OPTS}


.PHONY: integration_test_codecov
Expand All @@ -83,7 +87,7 @@ integration_test_codecov:

.PHONY: integration_test
integration_test:
$(PYTEST) tests/flytekit/integration ${CODECOV_OPTS}
$(PYTEST_AND_OPTS) tests/flytekit/integration ${CODECOV_OPTS}

doc-requirements.txt: export CUSTOM_COMPILE_COMMAND := make doc-requirements.txt
doc-requirements.txt: doc-requirements.in install-piptools
Expand Down
2 changes: 2 additions & 0 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,8 @@ def _run(*args, **kwargs):
if run_level_params.envvars:
for env_var, value in run_level_params.envvars.items():
os.environ[env_var] = value
if run_level_params.overwrite_cache:
os.environ["FLYTE_LOCAL_CACHE_OVERWRITE"] = "true"
output = entity(**inputs)
if inspect.iscoroutine(output):
# TODO: make eager mode workflows run with local-mode
Expand Down
2 changes: 2 additions & 0 deletions flytekit/configuration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,12 +618,14 @@ class LocalConfig(object):
"""

cache_enabled: bool = True
cache_overwrite: bool = False

@classmethod
def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> LocalConfig:
config_file = get_config_file(config_file)
kwargs = {}
kwargs = set_if_exists(kwargs, "cache_enabled", _internal.Local.CACHE_ENABLED.read(config_file))
kwargs = set_if_exists(kwargs, "cache_overwrite", _internal.Local.CACHE_OVERWRITE.read(config_file))
return LocalConfig(**kwargs)


Expand Down
1 change: 1 addition & 0 deletions flytekit/configuration/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class AZURE(object):
class Local(object):
SECTION = "local"
CACHE_ENABLED = ConfigEntry(LegacyConfigEntry(SECTION, "cache_enabled", bool))
CACHE_OVERWRITE = ConfigEntry(LegacyConfigEntry(SECTION, "cache_overwrite", bool))


class Credentials(object):
Expand Down
15 changes: 10 additions & 5 deletions flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,19 +272,24 @@ def local_execute(
f"Checking cache for task named {self.name}, cache version {self.metadata.cache_version} "
f"and inputs: {input_literal_map}"
)
outputs_literal_map = LocalTaskCache.get(self.name, self.metadata.cache_version, input_literal_map)
# The cache returns None iff the key does not exist in the cache
if local_config.cache_overwrite:
outputs_literal_map = None
logger.info("Cache overwrite, task will be executed now")
else:
outputs_literal_map = LocalTaskCache.get(self.name, self.metadata.cache_version, input_literal_map)
# The cache returns None iff the key does not exist in the cache
if outputs_literal_map is None:
logger.info("Cache miss, task will be executed now")
else:
logger.info("Cache hit")
if outputs_literal_map is None:
logger.info("Cache miss, task will be executed now")
outputs_literal_map = self.sandbox_execute(ctx, input_literal_map)
# TODO: need `native_inputs`
LocalTaskCache.set(self.name, self.metadata.cache_version, input_literal_map, outputs_literal_map)
logger.info(
f"Cache set for task named {self.name}, cache version {self.metadata.cache_version} "
f"and inputs: {input_literal_map}"
)
else:
logger.info("Cache hit")
else:
# This code should mirror the call to `sandbox_execute` in the above cache case.
# Code is simpler with duplication and less metaprogramming, but introduces regressions
Expand Down
17 changes: 17 additions & 0 deletions flytekit/core/launch_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def create(
security_context: Optional[security.SecurityContext] = None,
auth_role: Optional[_common_models.AuthRole] = None,
trigger: Optional[LaunchPlanTriggerBase] = None,
overwrite_cache: Optional[bool] = None,
) -> LaunchPlan:
ctx = FlyteContextManager.current_context()
default_inputs = default_inputs or {}
Expand Down Expand Up @@ -176,6 +177,7 @@ def create(
max_parallelism=max_parallelism,
security_context=security_context,
trigger=trigger,
overwrite_cache=overwrite_cache,
)

# This is just a convenience - we'll need the fixed inputs LiteralMap for when serializing the Launch Plan out
Expand Down Expand Up @@ -205,6 +207,7 @@ def get_or_create(
security_context: Optional[security.SecurityContext] = None,
auth_role: Optional[_common_models.AuthRole] = None,
trigger: Optional[LaunchPlanTriggerBase] = None,
overwrite_cache: Optional[bool] = None,
) -> LaunchPlan:
"""
This function offers a friendlier interface for creating launch plans. If the name for the launch plan is not
Expand Down Expand Up @@ -243,6 +246,8 @@ def get_or_create(
or auth_role is not None
or max_parallelism is not None
or security_context is not None
or trigger is not None
or overwrite_cache is not None
):
raise ValueError(
"Only named launchplans can be created that have other properties. Drop the name if you want to create a default launchplan. Default launchplans cannot have any other associations"
Expand Down Expand Up @@ -274,6 +279,7 @@ def get_or_create(
or raw_output_data_config != cached_outputs["_raw_output_data_config"]
or max_parallelism != cached_outputs["_max_parallelism"]
or security_context != cached_outputs["_security_context"]
or overwrite_cache != cached_outputs["_overwrite_cache"]
):
raise AssertionError("The cached values aren't the same as the current call arguments")

Expand All @@ -300,6 +306,7 @@ def get_or_create(
auth_role=auth_role,
security_context=security_context,
trigger=trigger,
overwrite_cache=overwrite_cache,
)
LaunchPlan.CACHE[name or workflow.name] = lp
return lp
Expand All @@ -318,6 +325,8 @@ def __init__(
max_parallelism: Optional[int] = None,
security_context: Optional[security.SecurityContext] = None,
trigger: Optional[LaunchPlanTriggerBase] = None,
overwrite_cache: Optional[bool] = None,
additional_metadata: Optional[Any] = None,
):
self._name = name
self._workflow = workflow
Expand All @@ -336,6 +345,8 @@ def __init__(
self._max_parallelism = max_parallelism
self._security_context = security_context
self._trigger = trigger
self._overwrite_cache = overwrite_cache
self._additional_metadata = additional_metadata

FlyteEntities.entities.append(self)

Expand All @@ -352,6 +363,7 @@ def clone_with(
max_parallelism: Optional[int] = None,
security_context: Optional[security.SecurityContext] = None,
trigger: Optional[LaunchPlanTriggerBase] = None,
overwrite_cache: Optional[bool] = None,
) -> LaunchPlan:
return LaunchPlan(
name=name,
Expand All @@ -366,8 +378,13 @@ def clone_with(
max_parallelism=max_parallelism or self.max_parallelism,
security_context=security_context or self.security_context,
trigger=trigger,
overwrite_cache=overwrite_cache or self.overwrite_cache,
)

@property
def overwrite_cache(self) -> Optional[bool]:
return self._overwrite_cache

@property
def python_interface(self) -> Interface:
return self.workflow.python_interface
Expand Down
2 changes: 1 addition & 1 deletion flytekit/core/local_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,4 @@ def get(task_name: str, cache_version: str, input_literal_map: LiteralMap) -> Op
def set(task_name: str, cache_version: str, input_literal_map: LiteralMap, value: LiteralMap) -> None:
if not LocalTaskCache._initialized:
LocalTaskCache.initialize()
LocalTaskCache._cache.add(_calculate_cache_key(task_name, cache_version, input_literal_map), value)
LocalTaskCache._cache.set(_calculate_cache_key(task_name, cache_version, input_literal_map), value)
15 changes: 15 additions & 0 deletions flytekit/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,21 @@ def with_overrides(self, *args, **kwargs):
assert_not_promise(v, "accelerator")
self._extended_resources = tasks_pb2.ExtendedResources(gpu_accelerator=v.to_flyte_idl())

if "cache" in kwargs:
v = kwargs["cache"]
assert_not_promise(v, "cache")
self._metadata._cacheable = kwargs["cache"]

if "cache_version" in kwargs:
v = kwargs["cache_version"]
assert_not_promise(v, "cache_version")
self._metadata._cache_version = kwargs["cache_version"]

if "cache_serialize" in kwargs:
v = kwargs["cache_serialize"]
assert_not_promise(v, "cache_serialize")
self._metadata._cache_serializable = kwargs["cache_serialize"]

return self


Expand Down
29 changes: 22 additions & 7 deletions flytekit/extras/tasks/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def interpolate(
T = typing.TypeVar("T")


def _run_script(script) -> typing.Tuple[int, str, str]:
def _run_script(script: str, shell: str) -> typing.Tuple[int, str, str]:
"""
Run script as a subprocess and return the returncode, stdout, and stderr.
Expand All @@ -154,10 +154,20 @@ def _run_script(script) -> typing.Tuple[int, str, str]:
:param script: script to be executed
:type script: str
:param shell: shell to use to run the script
:type shell: str
:return: tuple containing the process returncode, stdout, and stderr
:rtype: typing.Tuple[int, str, str]
"""
process = subprocess.Popen(script, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=0, shell=True, text=True)
process = subprocess.Popen(
script,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
bufsize=0,
shell=True,
text=True,
executable=shell,
)

process_stdout, process_stderr = process.communicate()
out = ""
Expand All @@ -179,6 +189,7 @@ def __init__(
script: typing.Optional[str] = None,
script_file: typing.Optional[str] = None,
task_config: T = None,
shell: str = "/bin/sh",
inputs: typing.Optional[typing.Dict[str, typing.Type]] = None,
output_locs: typing.Optional[typing.List[OutputLocation]] = None,
**kwargs,
Expand All @@ -189,7 +200,8 @@ def __init__(
debug: bool Print the generated script and other debugging information
script: The actual script specified as a string
script_file: A path to the file that contains the script (Only script or script_file) can be provided
task_config: T Configuration for the task, can be either a Pod (or coming soon, BatchJob) config
task_config: Configuration for the task, can be either a Pod (or coming soon, BatchJob) config
shell: Shell to use to run the script
inputs: A Dictionary of input names to types
output_locs: A list of :py:class:`OutputLocations`
**kwargs: Other arguments that can be passed to
Expand Down Expand Up @@ -223,6 +235,7 @@ def __init__(
self._script = script
self._script_file = script_file
self._debug = debug
self._shell = shell
self._output_locs = output_locs if output_locs else []
self._interpolizer = _PythonFStringInterpolizer()
outputs = self._validate_output_locs()
Expand Down Expand Up @@ -284,11 +297,13 @@ def execute(self, **kwargs) -> typing.Any:
print(gen_script)
print("\n==============================================\n")

if platform.system() == "Windows" and os.environ.get("ComSpec") is None:
# https://github.com/python/cpython/issues/101283
os.environ["ComSpec"] = "C:\\Windows\\System32\\cmd.exe"
if platform.system() == "Windows":
if os.environ.get("ComSpec") is None:
# https://github.com/python/cpython/issues/101283
os.environ["ComSpec"] = "C:\\Windows\\System32\\cmd.exe"
self._shell = os.environ["ComSpec"]

returncode, stdout, stderr = _run_script(gen_script)
returncode, stdout, stderr = _run_script(gen_script, self._shell)
if returncode != 0:
files = os.listdir(".")
fstr = "\n-".join(files)
Expand Down
43 changes: 37 additions & 6 deletions flytekit/models/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,19 +162,34 @@ def from_flyte_idl(cls, pb2_objct):


class NodeMetadata(_common.FlyteIdlEntity):
def __init__(self, name, timeout=None, retries=None, interruptible=None):
def __init__(
self,
name,
timeout=None,
retries=None,
interruptible: typing.Optional[bool] = None,
cacheable: typing.Optional[bool] = None,
cache_version: typing.Optional[str] = None,
cache_serializable: typing.Optional[bool] = None,
):
"""
Defines extra information about the Node.
:param Text name: Friendly name for the Node.
:param datetime.timedelta timeout: [Optional] Overall timeout for a task.
:param flytekit.models.literals.RetryStrategy retries: [Optional] Number of retries per task.
:param bool interruptible: [Optional] Can be safely interrupted during execution.
:param bool interruptible: Can be safely interrupted during execution.
:param cacheable: Indicates that this node's outputs should be cached.
:param cache_version: The version of the cached data.
:param cache_serializable: Indicates that cache operations on this node should be serialized.
"""
self._name = name
self._timeout = timeout if timeout is not None else datetime.timedelta()
self._retries = retries if retries is not None else _RetryStrategy(0)
self._interruptible = interruptible
self._cacheable = cacheable
self._cache_version = cache_version
self._cache_serializable = cache_serializable

@property
def name(self):
Expand All @@ -198,12 +213,21 @@ def retries(self):
return self._retries

@property
def interruptible(self):
"""
:rtype: flytekit.models
"""
def interruptible(self) -> typing.Optional[bool]:
return self._interruptible

@property
def cacheable(self) -> typing.Optional[bool]:
return self._cacheable

@property
def cache_version(self) -> typing.Optional[str]:
return self._cache_version

@property
def cache_serializable(self) -> typing.Optional[bool]:
return self._cache_serializable

def to_flyte_idl(self):
"""
:rtype: flyteidl.core.workflow_pb2.NodeMetadata
Expand All @@ -212,6 +236,9 @@ def to_flyte_idl(self):
name=self.name,
retries=self.retries.to_flyte_idl(),
interruptible=self.interruptible,
cacheable=self.cacheable,
cache_version=self.cache_version,
cache_serializable=self.cache_serializable,
)
if self.timeout:
node_metadata.timeout.FromTimedelta(self.timeout)
Expand All @@ -223,6 +250,10 @@ def from_flyte_idl(cls, pb2_object):
pb2_object.name,
pb2_object.timeout.ToTimedelta(),
_RetryStrategy.from_flyte_idl(pb2_object.retries),
pb2_object.interruptible if pb2_object.HasField("interruptible") else None,
pb2_object.cacheable if pb2_object.HasField("cacheable") else None,
pb2_object.cache_version if pb2_object.HasField("cache_version") else None,
pb2_object.cache_serializable if pb2_object.HasField("cache_serializable") else None,
)


Expand Down
Loading

0 comments on commit 19e8e56

Please sign in to comment.