diff --git a/dev-requirements.in b/dev-requirements.in index ce4171018b..27c17ac6d0 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -55,6 +55,8 @@ pyarrow scikit-learn types-requests prometheus-client +jupyter-client +ipykernel orjson kubernetes>=12.0.1 diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index 97a9940425..6dab3e0cb0 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -828,6 +828,7 @@ class SerializationSettings(DataClassJsonMixin): can be fast registered (and thus omit building a Docker image) this object contains additional parameters for serialization. source_root (Optional[str]): The root directory of the source code. + interactive_mode_enabled (bool): Whether or not the task is being serialized in interactive mode. """ image_config: ImageConfig @@ -840,6 +841,7 @@ class SerializationSettings(DataClassJsonMixin): flytekit_virtualenv_root: Optional[str] = None fast_serialization_settings: Optional[FastSerializationSettings] = None source_root: Optional[str] = None + interactive_mode_enabled: bool = False def __post_init__(self): if self.flytekit_virtualenv_root is None: @@ -914,6 +916,7 @@ def new_builder(self) -> Builder: python_interpreter=self.python_interpreter, fast_serialization_settings=self.fast_serialization_settings, source_root=self.source_root, + interactive_mode_enabled=self.interactive_mode_enabled, ) def should_fast_serialize(self) -> bool: @@ -965,6 +968,7 @@ class Builder(object): python_interpreter: Optional[str] = None fast_serialization_settings: Optional[FastSerializationSettings] = None source_root: Optional[str] = None + interactive_mode_enabled: bool = False def with_fast_serialization_settings(self, fss: fast_serialization_settings) -> SerializationSettings.Builder: self.fast_serialization_settings = fss @@ -982,4 +986,5 @@ def build(self) -> SerializationSettings: python_interpreter=self.python_interpreter, 
fast_serialization_settings=self.fast_serialization_settings, source_root=self.source_root, + interactive_mode_enabled=self.interactive_mode_enabled, ) diff --git a/flytekit/core/options.py b/flytekit/core/options.py index 79d46c2039..ad35bc3ea1 100644 --- a/flytekit/core/options.py +++ b/flytekit/core/options.py @@ -1,5 +1,6 @@ import typing from dataclasses import dataclass +from typing import Callable, Optional from flytekit.models import common as common_models from flytekit.models import security @@ -34,6 +35,9 @@ class Options(object): notifications: typing.Optional[typing.List[common_models.Notification]] = None disable_notifications: typing.Optional[bool] = None overwrite_cache: typing.Optional[bool] = None + file_uploader: Optional[Callable] = ( + None # This is used by the translator to upload task files, like pickled code etc + ) @classmethod def default_from( diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 0d1554d44f..cea96ac013 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -660,6 +660,14 @@ def _append_attr(self, key) -> Promise: return new_promise + def __getstate__(self) -> Dict[str, Any]: + # This func is used to pickle the object. + return vars(self) + + def __setstate__(self, state: Dict[str, Any]) -> None: + # This func is used to unpickle the object without infinite recursion. 
+ vars(self).update(state) + def create_native_named_tuple( ctx: FlyteContext, diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index b1bc0052b7..6bfaded0a3 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -24,6 +24,7 @@ T = TypeVar("T") _PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name" +PICKLE_FILE_PATH = "pkl.gz" class PythonAutoContainerTask(PythonTask[T], ABC, metaclass=FlyteTrackedABC): @@ -163,6 +164,13 @@ def get_default_command(self, settings: SerializationSettings) -> List[str]: return container_args + def set_resolver(self, resolver: TaskResolverMixin): + """ + By default, flytekit uses the DefaultTaskResolver to resolve the task. This method allows the user to set a custom + task resolver. It can be useful to override the task resolver for specific cases like running tasks in the jupyter notebook. + """ + self._task_resolver = resolver + def set_command_fn(self, get_command_fn: Optional[Callable[[SerializationSettings], List[str]]] = None): """ By default, the task will run on the Flyte platform using the pyflyte-execute command. @@ -274,6 +282,34 @@ def get_all_tasks(self) -> List[PythonAutoContainerTask]: # type: ignore default_task_resolver = DefaultTaskResolver() +class DefaultNotebookTaskResolver(TrackedInstance, TaskResolverMixin): + """ + This resolver is used when the task is defined in a notebook. It is used to load the task from the notebook.
+ """ + + def name(self) -> str: + return "DefaultNotebookTaskResolver" + + @timeit("Load task") + def load_task(self, loader_args: List[str]) -> PythonAutoContainerTask: + import gzip + + import cloudpickle + + with gzip.open(PICKLE_FILE_PATH, "r") as f: + return cloudpickle.load(f) + + def loader_args(self, settings: SerializationSettings, task: PythonAutoContainerTask) -> List[str]: # type:ignore + _, m, t, _ = extract_task_module(task) + return ["task-module", m, "task-name", t] + + def get_all_tasks(self) -> List[PythonAutoContainerTask]: # type: ignore + raise NotImplementedError + + +default_notebook_task_resolver = DefaultNotebookTaskResolver() + + def update_image_spec_copy_handling(image_spec: ImageSpec, settings: SerializationSettings): """ This helper function is where the relationship between fast register and ImageSpec is codified. diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index a080193809..7319975aa9 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -244,6 +244,14 @@ def compile_into_workflow( # See comment on reference entity checking a bit down below in this function. # This is the only circular dependency between the translator.py module and the rest of the flytekit # authoring experience. + + # TODO: After the backend supports pickling dynamic tasks, add fast_register_file_uploader to the FlyteContext, + # and pass the fast_register_file_uploader to the serializer via the options. + # If during runtime we are executing a dynamic function that is pickled, all subsequent sub-tasks in + # dynamic should also be pickled. As this is not possible to do during static compilation, we will have to + # upload the pickled file to the metadata store directly during runtime. + # If at runtime we are in a dynamic task, we will automatically have the fast_register_file_uploader set, + # so we can use that to pass the file uploader to the translator.
workflow_spec: admin_workflow_models.WorkflowSpec = get_serializable( model_entities, ctx.serialization_settings, wf ) diff --git a/flytekit/core/tracker.py b/flytekit/core/tracker.py index ad38405148..382ca4b234 100644 --- a/flytekit/core/tracker.py +++ b/flytekit/core/tracker.py @@ -10,6 +10,7 @@ from flytekit.configuration.feature_flags import FeatureFlags from flytekit.exceptions import system as _system_exceptions from flytekit.loggers import developer_logger, logger +from flytekit.tools.interactive import ipython_check def import_module_from_file(module_name, file): @@ -248,6 +249,24 @@ def istestfunction(func) -> bool: return False +def is_ipython_or_pickle_exists() -> bool: + """ + Returns true if the code is running in an IPython notebook or if a pickle file exists. + + We skip module path resolution in both cases due to the following reasons: + + 1. In an IPython notebook, we cannot resolve the module path in the local file system. + 2. When the code is serialized (pickled) and executed in a remote environment, only + the pickled file exists at PICKLE_FILE_PATH. The remote environment won't have the + plain python file and module path resolution will fail. + + This check ensures we avoid attempting module path resolution in both environments. + """ + from flytekit.core.python_auto_container import PICKLE_FILE_PATH + + return ipython_check() or os.path.exists(PICKLE_FILE_PATH) + + class _ModuleSanitizer(object): """ Sanitizes and finds the absolute module path irrespective of the import location. @@ -278,6 +297,18 @@ def _resolve_abs_module_name(self, path: str, package_root: typing.Optional[str] if dirname == package_root: return basename + # Execution in a Jupyter notebook, we cannot resolve the module path + if not os.path.exists(dirname): + logger.debug( + f"Directory {dirname} does not exist. It is likely that we are in a Jupyter notebook or a pickle file was received." 
+ ) + + if not is_ipython_or_pickle_exists(): + raise AssertionError( + f"Directory {dirname} does not exist, and we are not in a Jupyter notebook or received a pickle file." + ) + return basename + # If we have reached a directory with no __init__, ignore if "__init__.py" not in os.listdir(dirname): return basename @@ -326,7 +357,12 @@ def extract_task_module(f: Union[Callable, TrackedInstance]) -> Tuple[str, str, mod, mod_name, name = _task_module_from_callable(f) if mod is None: - raise AssertionError(f"Unable to determine module of {f}") + if not is_ipython_or_pickle_exists(): + raise AssertionError(f"Unable to determine module of {f}") + logger.debug( + "Could not determine module of function. It is likely that we are in a Jupyter notebook or received a pickle file." + ) + return f"{mod_name}.{name}", mod_name, name, "" if mod_name == "__main__": if hasattr(f, "task_function"): diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 379943d36d..4605480542 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -202,6 +202,7 @@ def __init__( default_project: typing.Optional[str] = None, default_domain: typing.Optional[str] = None, data_upload_location: str = "flyte://my-s3-bucket/", + interactive_mode_enabled: bool = False, **kwargs, ): """Initialize a FlyteRemote object. @@ -212,10 +213,14 @@ def __init__( :param default_domain: default domain to use when fetching or executing flyte entities. :param data_upload_location: this is where all the default data will be uploaded when providing inputs. The default location - `s3://my-s3-bucket/data` works for sandbox/demo environment. Please override this for non-sandbox cases. + :param interactive_mode_enabled: If set to True, the FlyteRemote will pickle the task/workflow. 
""" if config is None or config.platform is None or config.platform.endpoint is None: raise user_exceptions.FlyteAssertion("Flyte endpoint should be provided.") + if interactive_mode_enabled is True: + logger.warning("Jupyter notebook and interactive task support is still alpha.") + if data_upload_location is None: data_upload_location = FlyteContext.current_context().file_access.raw_output_prefix self._kwargs = kwargs @@ -235,6 +240,7 @@ def __init__( # Save the file access object locally, build a context for it and save that as well. self._ctx = FlyteContextManager.current_context().with_file_access(self._file_access).build() + self._interactive_mode_enabled = interactive_mode_enabled @property def context(self) -> FlyteContext: @@ -268,6 +274,11 @@ def file_access(self) -> FileAccessProvider: """File access provider to use for offloading non-literal inputs/outputs.""" return self._file_access + @property + def interactive_mode_enabled(self) -> bool: + """If set to True, the FlyteRemote will pickle the task/workflow.""" + return self._interactive_mode_enabled + def get( self, flyte_uri: typing.Optional[str] = None ) -> typing.Optional[typing.Union[LiteralsResolver, Literal, HTML, bytes]]: @@ -758,6 +769,10 @@ async def _serialize_and_register( ) if serialization_settings.version is None: serialization_settings.version = version + serialization_settings.interactive_mode_enabled = self.interactive_mode_enabled + + options = options or Options() + options.file_uploader = options.file_uploader or self.upload_file _ = get_serializable(m, settings=serialization_settings, entity=entity, options=options) # concurrent register @@ -862,6 +877,7 @@ def register_workflow( ident = run_sync( self._serialize_and_register, entity, serialization_settings, version, options, default_launch_plan ) + fwf = self.fetch_workflow(ident.project, ident.domain, ident.name, ident.version) fwf._python_interface = entity.python_interface return fwf @@ -1811,14 +1827,15 @@ def 
execute_local_task( """ resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version) resolved_identifiers_dict = asdict(resolved_identifiers) + not_found = False try: flyte_task: FlyteTask = self.fetch_task(**resolved_identifiers_dict) except FlyteEntityNotExistException: - if isinstance(entity, PythonAutoContainerTask): - if not image_config: - raise ValueError(f"PythonTask {entity.name} not already registered, but image_config missing") + not_found = True + + if not_found: ss = SerializationSettings( - image_config=image_config, + image_config=image_config or ImageConfig.auto_default_image(), project=project or self.default_project, domain=domain or self._default_domain, version=version, @@ -1881,6 +1898,9 @@ def execute_local_workflow( """ resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version) resolved_identifiers_dict = asdict(resolved_identifiers) + if not image_config: + image_config = ImageConfig.auto_default_image() + ss = SerializationSettings( image_config=image_config, project=resolved_identifiers.project, @@ -1893,8 +1913,6 @@ def execute_local_workflow( self.fetch_workflow(**resolved_identifiers_dict) except FlyteEntityNotExistException: logger.info("Registering workflow because it wasn't found in Flyte Admin.") - if not image_config: - raise ValueError("Need image config since we are registering") self.register_workflow(entity, ss, version=version, options=options) try: diff --git a/flytekit/tools/fast_registration.py b/flytekit/tools/fast_registration.py index 6fbefc6809..125b8024d4 100644 --- a/flytekit/tools/fast_registration.py +++ b/flytekit/tools/fast_registration.py @@ -20,6 +20,7 @@ from flytekit.constants import CopyFileDetection from flytekit.core.context_manager import FlyteContextManager +from flytekit.core.python_auto_container import PICKLE_FILE_PATH from flytekit.core.utils import timeit from flytekit.exceptions.user import FlyteDataNotFoundException from 
flytekit.loggers import logger @@ -242,12 +243,13 @@ def download_distribution(additional_distribution: str, destination: str): except FlyteDataNotFoundException as ex: raise RuntimeError("task execution code was not found") from ex tarfile_name = os.path.basename(additional_distribution) - if not tarfile_name.endswith(".tar.gz"): + if tarfile_name.endswith(".tar.gz"): + # This will overwrite the existing user flyte workflow code in the current working code dir. + result = subprocess.run( + ["tar", "-xvf", os.path.join(destination, tarfile_name), "-C", destination], + stdout=subprocess.PIPE, + ) + result.check_returncode() + elif tarfile_name != PICKLE_FILE_PATH: + # The distribution is not a pickled file. raise RuntimeError("Unrecognized additional distribution format for {}".format(additional_distribution)) - - # This will overwrite the existing user flyte workflow code in the current working code dir. - result = subprocess.run( - ["tar", "-xvf", os.path.join(destination, tarfile_name), "-C", destination], - stdout=subprocess.PIPE, - ) - result.check_returncode() diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index e27bdbe09a..9800d0eee3 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -1,4 +1,7 @@ +import os +import pathlib import sys +import tempfile import typing from collections import OrderedDict from typing import Callable, Dict, List, Optional, Tuple, Union @@ -19,11 +22,17 @@ from flytekit.core.legacy_map_task import MapPythonTask from flytekit.core.node import Node from flytekit.core.options import Options -from flytekit.core.python_auto_container import PythonAutoContainerTask +from flytekit.core.python_auto_container import ( + PICKLE_FILE_PATH, + PythonAutoContainerTask, + default_notebook_task_resolver, +) from flytekit.core.reference_entity import ReferenceEntity, ReferenceSpec, ReferenceTemplate from flytekit.core.task import ReferenceTask from flytekit.core.utils import ClassDecorator, 
_dnsify from flytekit.core.workflow import ReferenceWorkflow, WorkflowBase +from flytekit.exceptions.user import FlyteAssertion +from flytekit.loggers import logger from flytekit.models import common as _common_models from flytekit.models import interface as interface_models from flytekit.models import launch_plan as _launch_plan_models @@ -118,6 +127,52 @@ def fn(settings: SerializationSettings) -> List[str]: return fn +def _update_serialization_settings_for_ipython( + entity: FlyteLocalEntity, + serialization_settings: SerializationSettings, + options: Optional[Options] = None, +): + # We are in an interactive environment. We will serialize the task as a pickled object and upload it to remote + # storage. + if isinstance(entity, PythonFunctionTask): + if entity.execution_mode == PythonFunctionTask.ExecutionBehavior.DYNAMIC: + raise FlyteAssertion( + f"Dynamic tasks are not supported in interactive mode. {entity.name} is a dynamic task." + ) + + if options is None or options.file_uploader is None: + raise FlyteAssertion("To work interactively with Flyte, a code transporter/uploader should be configured.") + + # For map tasks, we need to serialize the actual task, not the map task itself + if isinstance(entity, ArrayNodeMapTask): + entity._run_task.set_resolver(default_notebook_task_resolver) + actual_task = entity._run_task + else: + entity.set_resolver(default_notebook_task_resolver) + actual_task = entity + + import gzip + + import cloudpickle + + from flytekit.configuration import FastSerializationSettings + + with tempfile.TemporaryDirectory() as tmp_dir: + dest = pathlib.Path(tmp_dir, PICKLE_FILE_PATH) + with gzip.GzipFile(filename=dest, mode="wb", mtime=0) as gzipped: + cloudpickle.dump(actual_task, gzipped) + if os.path.getsize(dest) > 150 * 1024 * 1024: + raise ValueError( + "The size of the task to pickled exceeds the limit of 150MB. Please reduce the size of the task." 
+ ) + logger.debug(f"Uploading Pickled representation of Task `{actual_task.name}` to remote storage...") + _, native_url = options.file_uploader(dest) + + serialization_settings.fast_serialization_settings = FastSerializationSettings( + enabled=True, distribution_location=native_url, destination_dir="." + ) + + def get_serializable_task( entity_mapping: OrderedDict, settings: SerializationSettings, @@ -132,6 +187,14 @@ def get_serializable_task( settings.version, ) + # Try to update the serialization settings for ipython / jupyter notebook / interactive mode if we are in an + # interactive environment like Jupyter notebook + if settings.interactive_mode_enabled is True: + # If the entity is not a PythonAutoContainerTask, we don't need to do anything, as only Tasks with container | + # user code in container needs to be serialized as pickled objects. + if isinstance(entity, (PythonAutoContainerTask, ArrayNodeMapTask)): + _update_serialization_settings_for_ipython(entity, settings, options) + if isinstance(entity, PythonFunctionTask) and entity.execution_mode == PythonFunctionTask.ExecutionBehavior.DYNAMIC: for e in context_manager.FlyteEntities.entities: if isinstance(e, PythonAutoContainerTask): @@ -737,7 +800,7 @@ def get_serializable( cp_entity = get_reference_spec(entity_mapping, settings, entity) elif isinstance(entity, PythonTask): - cp_entity = get_serializable_task(entity_mapping, settings, entity) + cp_entity = get_serializable_task(entity_mapping, settings, entity, options) elif isinstance(entity, WorkflowBase): cp_entity = get_serializable_workflow(entity_mapping, settings, entity, options) diff --git a/tests/flytekit/integration/jupyter/test_notebook_run.py b/tests/flytekit/integration/jupyter/test_notebook_run.py new file mode 100644 index 0000000000..e72cd44328 --- /dev/null +++ b/tests/flytekit/integration/jupyter/test_notebook_run.py @@ -0,0 +1,95 @@ +from datetime import timedelta +import pathlib +import os +import pytest +from 
jupyter_client.manager import KernelManager, BlockingKernelClient + +from flytekit.configuration import Config +from flytekit.remote import FlyteRemote +import ipykernel.kernelspec + + +CONFIG = os.environ.get("FLYTECTL_CONFIG", str(pathlib.Path.home() / ".flyte" / "config-sandbox.yaml")) +# Run `make build-dev` to build and push the image to the local registry. +IMAGE = os.environ.get("FLYTEKIT_IMAGE", "localhost:30000/flytekit:dev") +PROJECT = "flytesnacks" +DOMAIN = "development" +VERSION = f"v{os.getpid()}" +KERNEL_NAME = "python3-flytekit-integration" + + +@pytest.fixture(scope="module", autouse=True) +def install_kernel(): + ipykernel.kernelspec.install(user=True, kernel_name=KERNEL_NAME) + + +@pytest.fixture +def jupyter_kernel(): + km = KernelManager(kernel_name=KERNEL_NAME) + km.start_kernel() + kc = km.client() + kc.start_channels() + kc.wait_for_ready() + yield kc + kc.stop_channels() + km.shutdown_kernel() + + +def execute_code_in_kernel(kc: BlockingKernelClient, code: str): + kc.execute(code) + reply = kc.get_shell_msg(timeout=5) + if reply['content']['status'] == 'error': + raise RuntimeError(f"Error executing code: {reply['content']}") + + output = [] + while True: + msg = kc.get_iopub_msg(timeout=5) + if msg['msg_type'] == 'error': + raise RuntimeError(f"Error in execution: {msg['content']}") + elif msg['msg_type'] == 'stream': # print(...) 
streams the output out + output.append(msg['content']['text'].strip()) + if msg['msg_type'] == 'status' and msg['content']['execution_state'] == 'idle': + break # nothing is running anymore so we break + + return output + + +NOTEBOOK_CODE = f""" +from flytekit import task, workflow +from flytekit.configuration import Config +from flytekit.remote import FlyteRemote + +remote = FlyteRemote( + Config.auto("{CONFIG}"), + default_project="{PROJECT}", + default_domain="{DOMAIN}", + interactive_mode_enabled=True, +) + +@task(container_image="{IMAGE}") +def hello(name: str) -> str: + return f"Hello {{name}}" + +@task(container_image="{IMAGE}") +def world(pre: str) -> str: + return f"{{pre}}, Welcome to the world!" + +@workflow +def wf(name: str) -> str: + return world(pre=hello(name=name)) + +out = remote.execute(wf, inputs={{"name": "flytekit"}}, version="{VERSION}") +print(out.id.name) +""" + + +def test_jupyter_code_execution(jupyter_kernel): + output = execute_code_in_kernel(jupyter_kernel, NOTEBOOK_CODE) + assert len(output) == 1 + + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) + execution_id = output[0] + execution = remote.fetch_execution(name=execution_id) + execution = remote.wait(execution, sync_nodes=True, poll_interval=timedelta(seconds=5)) + + assert execution.outputs["o0"] == "Hello flytekit, Welcome to the world!" 
diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index cfbd1c16ff..f80a76b4c5 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -14,6 +14,7 @@ from urllib.parse import urlparse import uuid import pytest +from mock import mock, patch from flytekit import LaunchPlan, kwtypes from flytekit.configuration import Config, ImageConfig, SerializationSettings @@ -160,7 +161,7 @@ def test_monitor_workflow_execution(register): break with pytest.raises( - FlyteAssertion, match="Please wait until the execution has completed before requesting the outputs.", + FlyteAssertion, match="Please wait until the execution has completed before requesting the outputs.", ): execution.outputs @@ -410,7 +411,8 @@ def test_execute_with_default_launch_plan(register): from workflows.basic.subworkflows import parent_wf remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) - execution = remote.execute(parent_wf, inputs={"a": 101}, version=VERSION, wait=True, image_config=ImageConfig.auto(img_name=IMAGE)) + execution = remote.execute(parent_wf, inputs={"a": 101}, version=VERSION, wait=True, + image_config=ImageConfig.auto(img_name=IMAGE)) # check node execution inputs and outputs assert execution.node_executions["n0"].inputs == {"a": 101} assert execution.node_executions["n0"].outputs == {"t1_int_output": 103, "c": "world"} @@ -524,6 +526,7 @@ def test_execute_workflow_with_maptask(register): ) assert execution.outputs["o0"] == [4, 5, 6] + @pytest.mark.lftransfers class TestLargeFileTransfers: """A class to capture tests and helper functions for large file transfers.""" @@ -562,7 +565,7 @@ def _ephemeral_minio_project_domain_filename_root(s3_client, project, domain): """An ephemeral minio S3 path which is wiped upon the context manager's exit""" # Generate a random path in our Minio s3 bucket, under /PROJECT/DOMAIN/ buckets = 
s3_client.list_buckets()["Buckets"] - assert len(buckets) == 1 # We expect just the default sandbox bucket + assert len(buckets) == 1 # We expect just the default sandbox bucket bucket = buckets[0]["Name"] root = str(uuid.uuid4()) key = f"{PROJECT}/{DOMAIN}/{root}/" @@ -573,7 +576,6 @@ def _ephemeral_minio_project_domain_filename_root(s3_client, project, domain): for obj in response["Contents"]: TestLargeFileTransfers._delete_s3_file(s3_client, bucket, obj["Key"]) - @staticmethod @pytest.mark.parametrize("gigabytes", [2, 3]) def test_flyteremote_uploads_large_file(gigabytes): @@ -614,6 +616,125 @@ def test_flyteremote_uploads_large_file(gigabytes): assert s3_md5_bytes == md5_bytes +def test_workflow_remote_func(): + """Test the logic of the remote execution of workflows and tasks.""" + from workflows.basic.child_workflow import parent_wf, double + + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN, interactive_mode_enabled=True) + + out0 = remote.execute( + double, + inputs={"a": 3}, + wait=True, + version=VERSION, + image_config=ImageConfig.from_images(IMAGE), + ) + out1 = remote.execute( + parent_wf, + inputs={"a": 3}, + wait=True, + version=VERSION + "-1", + image_config=ImageConfig.from_images(IMAGE), + ) + out2 = remote.execute( + parent_wf, + inputs={"a": 2}, + wait=True, + version=VERSION + "-2", + image_config=ImageConfig.from_images(IMAGE), + ) + + assert out0.outputs["o0"] == 6 + assert out1.outputs["o0"] == 18 + assert out2.outputs["o0"] == 12 + + +def test_execute_task_remote_func_list_of_floats(): + """Test remote execution of a @task-decorated python function with a list of floats.""" + from workflows.basic.list_float_wf import concat_list + + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN, interactive_mode_enabled=True) + + xs: typing.List[float] = [0.1, 0.2, 0.3, 0.4, -99999.7] + out = remote.execute( + concat_list, + inputs={"xs": xs}, + wait=True, + version=VERSION, + 
image_config=ImageConfig.from_images(IMAGE), + ) + assert out.outputs["o0"] == "[0.1, 0.2, 0.3, 0.4, -99999.7]" + + +def test_execute_task_remote_func_convert_dict(): + """Test remote execution of a @task-decorated python function with a dict of strings.""" + from workflows.basic.dict_str_wf import convert_to_string + + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN, interactive_mode_enabled=True) + + d: typing.Dict[str, str] = {"key1": "value1", "key2": "value2"} + out = remote.execute( + convert_to_string, + inputs={"d": d}, + wait=True, + version=VERSION, + image_config=ImageConfig.from_images(IMAGE), + ) + assert json.loads(out.outputs["o0"]) == {"key1": "value1", "key2": "value2"} + + +def test_execute_python_workflow_remote_func_dict_of_string_to_string(): + """Test remote execution of a @workflow-decorated python function with a dict of strings.""" + from workflows.basic.dict_str_wf import my_wf + + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN, interactive_mode_enabled=True) + + d: typing.Dict[str, str] = {"k1": "v1", "k2": "v2"} + out = remote.execute( + my_wf, + inputs={"d": d}, + wait=True, + version=VERSION + "dict_str_wf", + image_config=ImageConfig.from_images(IMAGE), + ) + assert json.loads(out.outputs["o0"]) == {"k1": "v1", "k2": "v2"} + + +def test_execute_python_workflow_remote_func_list_of_floats(): + """Test remote execution of a @workflow-decorated python function with a list of floats.""" + """Test execution of a @workflow-decorated python function.""" + from workflows.basic.list_float_wf import my_wf + + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN, interactive_mode_enabled=True) + + xs: typing.List[float] = [42.24, 999.1, 0.0001] + out = remote.execute( + my_wf, + inputs={"xs": xs}, + wait=True, + version=VERSION + "list_float_wf", + image_config=ImageConfig.from_images(IMAGE), + ) + assert out.outputs["o0"] == "[42.24, 999.1, 0.0001]" + + +def 
test_execute_workflow_remote_fn_with_maptask(): + """Test remote execution of a @workflow-decorated python function with a map task.""" + from workflows.basic.array_map import workflow_with_maptask + + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN, interactive_mode_enabled=True) + + d: typing.List[int] = [1, 2, 3] + out = remote.execute( + workflow_with_maptask, + inputs={"data": d, "y": 3}, + wait=True, + version=VERSION, + image_config=ImageConfig.from_images(IMAGE), + ) + assert out.outputs["o0"] == [4, 5, 6] + + def test_register_wf_fast(register): from workflows.basic.subworkflows import parent_wf diff --git a/tests/flytekit/unit/core/test_array_node_map_task.py b/tests/flytekit/unit/core/test_array_node_map_task.py index fae81d1355..1d365d6629 100644 --- a/tests/flytekit/unit/core/test_array_node_map_task.py +++ b/tests/flytekit/unit/core/test_array_node_map_task.py @@ -1,5 +1,6 @@ import functools import os +import pathlib import typing from collections import OrderedDict from typing import List @@ -13,6 +14,7 @@ from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings from flytekit.core import context_manager from flytekit.core.array_node_map_task import ArrayNodeMapTask, ArrayNodeMapTaskResolver +from flytekit.core.python_auto_container import PICKLE_FILE_PATH from flytekit.core.task import TaskMetadata from flytekit.core.type_engine import TypeEngine from flytekit.extras.accelerators import GPUAccelerator @@ -22,7 +24,7 @@ LiteralMap, LiteralOffloadedMetadata, ) -from flytekit.tools.translator import get_serializable +from flytekit.tools.translator import get_serializable, Options from flytekit.types.pickle import BatchSize @@ -38,6 +40,19 @@ def serialization_settings(): ) +@pytest.fixture +def interactive_serialization_settings(): + default_img = Image(name="default", fqn="test", tag="tag") + return SerializationSettings( + project="project", + domain="domain", + version="version", + 
env=None, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + interactive_mode_enabled=True, + ) + + def test_map(serialization_settings): @task def say_hello(name: str) -> str: @@ -139,6 +154,55 @@ def t1(a: int) -> int: ] +def test_interactive_serialization(interactive_serialization_settings): + @task + def t1(a: int) -> int: + return a + 1 + + def mock_file_uploader(dest: pathlib.Path): + return (0, dest.name) + + arraynode_maptask = map_task(t1, metadata=TaskMetadata(retries=2)) + option = Options() + option.file_uploader = mock_file_uploader + task_spec = get_serializable(OrderedDict(), interactive_serialization_settings, arraynode_maptask, options=option) + + assert task_spec.template.metadata.retries.retries == 2 + assert task_spec.template.custom["minSuccessRatio"] == 1.0 + assert task_spec.template.type == "python-task" + assert task_spec.template.task_type_version == 1 + assert task_spec.template.container.args == [ + "pyflyte-fast-execute", + "--additional-distribution", + PICKLE_FILE_PATH, + "--dest-dir", + ".", + "--", + "pyflyte-map-execute", + "--inputs", + "{{.input}}", + "--output-prefix", + "{{.outputPrefix}}", + "--raw-output-data-prefix", + "{{.rawOutputDataPrefix}}", + "--checkpoint-path", + "{{.checkpointOutputPrefix}}", + "--prev-checkpoint", + "{{.prevCheckpointPrefix}}", + "--resolver", + "flytekit.core.array_node_map_task.ArrayNodeMapTaskResolver", + "--", + "vars", + "", + "resolver", + "flytekit.core.python_auto_container.default_notebook_task_resolver", + "task-module", + "tests.flytekit.unit.core.test_array_node_map_task", + "task-name", + "t1", + ] + + def test_fast_serialization(serialization_settings): serialization_settings.fast_serialization_settings = FastSerializationSettings(enabled=True) diff --git a/tests/flytekit/unit/core/test_context_manager.py b/tests/flytekit/unit/core/test_context_manager.py index 70a2d552d8..8379a3d3eb 100644 --- a/tests/flytekit/unit/core/test_context_manager.py +++ 
b/tests/flytekit/unit/core/test_context_manager.py @@ -267,7 +267,7 @@ def test_serialization_settings_transport(): ss = SerializationSettings.from_transport(tp) assert ss is not None assert ss == serialization_settings - assert len(tp) == 408 + assert len(tp) == 432 def test_exec_params(): diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py index 6101fc2429..ca8d6e95c2 100644 --- a/tests/flytekit/unit/core/test_promise.py +++ b/tests/flytekit/unit/core/test_promise.py @@ -263,3 +263,21 @@ def test_prom_with_union_literals(): assert bd.scalar.union.stored_type.structure.tag == "int" bd = binding_data_from_python_std(ctx, lt, "hello", pt, []) assert bd.scalar.union.stored_type.structure.tag == "str" + +def test_pickling_promise_object(): + @task + def t1(a: int) -> int: + return a + + ctx = context_manager.FlyteContext.current_context().with_compilation_state(CompilationState(prefix="")) + p = create_and_link_node(ctx, t1, a=3) + assert p.ref.node_id == "n0" + assert p.ref.var == "o0" + assert len(p.ref.node.bindings) == 1 + + import cloudpickle + + p2 = cloudpickle.loads(cloudpickle.dumps(p)) + assert p2.ref.node_id == "n0" + assert p2.ref.var == "o0" + assert len(p2.ref.node.bindings) == 1 diff --git a/tests/flytekit/unit/core/test_python_auto_container.py b/tests/flytekit/unit/core/test_python_auto_container.py index 2749d52cec..2f05c1227b 100644 --- a/tests/flytekit/unit/core/test_python_auto_container.py +++ b/tests/flytekit/unit/core/test_python_auto_container.py @@ -1,3 +1,4 @@ +import pathlib from collections import OrderedDict from typing import Any @@ -7,10 +8,10 @@ from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core.base_task import TaskMetadata from flytekit.core.pod_template import PodTemplate -from flytekit.core.python_auto_container import PythonAutoContainerTask, get_registerable_container_image +from flytekit.core.python_auto_container import 
PythonAutoContainerTask, get_registerable_container_image, PICKLE_FILE_PATH from flytekit.core.resources import Resources from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec -from flytekit.tools.translator import get_serializable_task +from flytekit.tools.translator import get_serializable_task, Options @pytest.fixture @@ -42,6 +43,13 @@ def minimal_serialization_settings_no_default_image(no_default_image_config): return SerializationSettings(project="p", domain="d", version="v", image_config=no_default_image_config) +@pytest.fixture +def interactive_serialization_settings(default_image_config): + return SerializationSettings( + project="p", domain="d", version="v", image_config=default_image_config, env={"FOO": "bar"}, interactive_mode_enabled=True + ) + + @pytest.fixture( params=[ "default_serialization_settings", @@ -100,6 +108,7 @@ def test_default_command(default_serialization_settings): ] + def test_get_container(default_serialization_settings): c = task.get_container(default_serialization_settings) assert c.image == "docker.io/xyz:some-git-hash" @@ -133,6 +142,26 @@ def test_get_container_without_serialization_settings_envvars(minimal_serializat assert ts.template.container.env == {"HAM": "spam"} +def test_get_container_with_interactive_settings(interactive_serialization_settings): + c = task_with_env_vars.get_container(interactive_serialization_settings) + assert c.image == "docker.io/xyz:some-git-hash" + assert c.env == {"FOO": "bar", "HAM": "spam"} + + def mock_file_uploader(dest: pathlib.Path): + return (0, dest.name) + + option = Options() + option.file_uploader = mock_file_uploader + ts = get_serializable_task(OrderedDict(), interactive_serialization_settings, task_with_env_vars, options=option) + assert ts.template.container.image == "docker.io/xyz:some-git-hash" + assert ts.template.container.env == {"FOO": "bar", "HAM": "spam"} + assert 'flytekit.core.python_auto_container.default_notebook_task_resolver' in 
ts.template.container.args + assert interactive_serialization_settings.fast_serialization_settings is not None + assert interactive_serialization_settings.fast_serialization_settings.enabled is True + assert interactive_serialization_settings.fast_serialization_settings.destination_dir == "." + assert interactive_serialization_settings.fast_serialization_settings.distribution_location == PICKLE_FILE_PATH + + task_with_pod_template = DummyAutoContainerTask( name="x", metadata=TaskMetadata( diff --git a/tests/flytekit/unit/test_translator.py b/tests/flytekit/unit/test_translator.py index 792ca0b131..213a267611 100644 --- a/tests/flytekit/unit/test_translator.py +++ b/tests/flytekit/unit/test_translator.py @@ -11,7 +11,7 @@ from flytekit.core.workflow import ReferenceWorkflow, workflow from flytekit.models.core import identifier as identifier_models from flytekit.models.task import Resources as resource_model -from flytekit.tools.translator import get_serializable +from flytekit.tools.translator import get_serializable, Options default_img = Image(name="default", fqn="test", tag="tag") serialization_settings = flytekit.configuration.SerializationSettings( @@ -84,6 +84,23 @@ def my_wf(a: int, b: str) -> (int, str): assert lp_model.id.name == "testlp" +def test_interactive(): + @task + def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str): + return a + 2, "world" + + b = serialization_settings.new_builder() + b.interactive_mode_enabled = True + ssettings = b.build() + + fake_file_uploader = lambda dest: (0, dest) + options = Options(file_uploader=fake_file_uploader) + + task_spec = get_serializable(OrderedDict(), ssettings, t1, options) + assert "--dest-dir" in task_spec.template.container.args + assert task_spec.template.container.args[task_spec.template.container.args.index("--dest-dir") + 1] == "." + + def test_fast(): @task def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):