Pickle remote task for Jupyter Notebook Environment (#2733)
* Use Jupyter notebooks to execute tasks

Signed-off-by: Ketan Umare <[email protected]>

* More updates

Signed-off-by: Ketan Umare <[email protected]>

* map task update

Signed-off-by: Ketan Umare <[email protected]>

* entrypoint fix

Signed-off-by: Ketan Umare <[email protected]>

* Working map task

Signed-off-by: Ketan Umare <[email protected]>

* updated for map tasks and removed unnecessary prints

Signed-off-by: Ketan Umare <[email protected]>

* Almost working, dynamic needs work in remote

Signed-off-by: Ketan Umare <[email protected]>

* Dynamic task trial

Signed-off-by: Ketan Umare <[email protected]>

* dynamic update

Signed-off-by: Ketan Umare <[email protected]>

* Updates

Signed-off-by: Ketan Umare <[email protected]>

* map task fix

Signed-off-by: Ketan Umare <[email protected]>

* create interface for task & workflow for jupyter notebook

Signed-off-by: Mecoli1219 <[email protected]>

* lint & docstring

Signed-off-by: Mecoli1219 <[email protected]>

* enable computed version

Signed-off-by: Mecoli1219 <[email protected]>

* add version for workflow

Signed-off-by: Mecoli1219 <[email protected]>

* disable file upload for workflow

Signed-off-by: Mecoli1219 <[email protected]>

* lint

Signed-off-by: Mecoli1219 <[email protected]>

* Add testing jupyter interaction

Signed-off-by: Mecoli1219 <[email protected]>

* lint & fix import script error

Signed-off-by: Mecoli1219 <[email protected]>

* Update unit test

Signed-off-by: Mecoli1219 <[email protected]>

* Make it thread safe

Signed-off-by: Mecoli1219 <[email protected]>

* Enable options for both init_remote() and remote() & restrict one-time init_remote() & update doc

Signed-off-by: Mecoli1219 <[email protected]>

* [Still Trying...] Store the interactive_mode_enabled to avoid using ipython_check

Signed-off-by: Mecoli1219 <[email protected]>

* Add interactive mode enable for init_remote & fix unit test

Signed-off-by: Mecoli1219 <[email protected]>

* Fix integration test

Signed-off-by: Mecoli1219 <[email protected]>

* Fix unit-test

Signed-off-by: Mecoli1219 <[email protected]>

* Update integration test names & skip the unsupported test

Signed-off-by: Mecoli1219 <[email protected]>

* lint

Signed-off-by: Mecoli1219 <[email protected]>

* enable verbosity & small update & update docstring

Signed-off-by: Mecoli1219 <[email protected]>

* Enable map_task with partial_fn & Add unit test

Signed-off-by: Mecoli1219 <[email protected]>

* save from merge

Signed-off-by: Mecoli1219 <[email protected]>

* Add translator unit test for interactive

Signed-off-by: Mecoli1219 <[email protected]>

* Remove fast_register_file_uploader

Signed-off-by: Mecoli1219 <[email protected]>

* Fix merge error

Signed-off-by: Mecoli1219 <[email protected]>

* Remove future

Signed-off-by: Mecoli1219 <[email protected]>

* Fix unit-test

Signed-off-by: Mecoli1219 <[email protected]>

* Fix entrypoint

Signed-off-by: Mecoli1219 <[email protected]>

* lint

Signed-off-by: Kevin Su <[email protected]>

* update tracker

Signed-off-by: Kevin Su <[email protected]>

* Fix I/O on closed file error

Signed-off-by: Mecoli1219 <[email protected]>

* Update integration test

Signed-off-by: Mecoli1219 <[email protected]>

* Update integration test

Signed-off-by: Mecoli1219 <[email protected]>

* Remove unused try & error for event loop

Signed-off-by: Mecoli1219 <[email protected]>

* Set size limit to pickled file

Signed-off-by: Mecoli1219 <[email protected]>

* Remove nest_asyncio

Signed-off-by: Mecoli1219 <[email protected]>

* Remove some args for pyflyte-execute

Signed-off-by: Mecoli1219 <[email protected]>

* Remove changes to entrypoint

Signed-off-by: Mecoli1219 <[email protected]>

* Fix map task

Signed-off-by: Mecoli1219 <[email protected]>

* Remove unnecessary code

Signed-off-by: Mecoli1219 <[email protected]>

* Remove unnecessary code

Signed-off-by: Mecoli1219 <[email protected]>

* Add more test

Signed-off-by: Mecoli1219 <[email protected]>

* fix unit test

Signed-off-by: Mecoli1219 <[email protected]>

* notebook task resolver

Signed-off-by: Kevin Su <[email protected]>

* set resolver

Signed-off-by: Kevin Su <[email protected]>

* Update test

Signed-off-by: Mecoli1219 <[email protected]>

* Update test

Signed-off-by: Mecoli1219 <[email protected]>

* Update test

Signed-off-by: Mecoli1219 <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

* fix unit test

Signed-off-by: Kevin Su <[email protected]>

* fix unit test

Signed-off-by: Kevin Su <[email protected]>

* interactive_mode_enabled=True

Signed-off-by: Kevin Su <[email protected]>

* display_ipython_warning

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Mecoli1219 <[email protected]>

* fix unit test

Signed-off-by: Mecoli1219 <[email protected]>

* Move pkl.gz to global variable

Signed-off-by: Mecoli1219 <[email protected]>

* change logging level

Signed-off-by: Mecoli1219 <[email protected]>

* Fix map task not using notebook resolver

Signed-off-by: Mecoli1219 <[email protected]>

* Add map task serialization unit test

Signed-off-by: Mecoli1219 <[email protected]>

* moving some codes & fix map_task problem

Signed-off-by: Mecoli1219 <[email protected]>

* Fix map task extract pickled task module error

Signed-off-by: Mecoli1219 <[email protected]>

* wrap the ipython check into a function

Signed-off-by: Mecoli1219 <[email protected]>

* Fix test

Signed-off-by: Thomas J. Fan <[email protected]>

* Adds integration test for jupyter

Signed-off-by: Thomas J. Fan <[email protected]>

* Install kernel before using KernelManager

Signed-off-by: Thomas J. Fan <[email protected]>

* Add a module scope kernel install

Signed-off-by: Thomas J. Fan <[email protected]>

---------

Signed-off-by: Ketan Umare <[email protected]>
Signed-off-by: Mecoli1219 <[email protected]>
Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: Thomas J. Fan <[email protected]>
Co-authored-by: Ketan Umare <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
Co-authored-by: Thomas J. Fan <[email protected]>
4 people authored Oct 4, 2024
1 parent fdf93da commit a3131f2
Showing 17 changed files with 552 additions and 26 deletions.
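In short: when a `FlyteRemote` is created with `interactive_mode_enabled=True`, tasks and workflows defined in a Jupyter notebook are pickled and shipped to the cluster instead of being resolved from a module path on disk. A minimal usage sketch under assumed values (the endpoint, project, and domain names below are placeholders, and the registration flow is inferred from the diffs that follow):

```python
from flytekit import task
from flytekit.configuration import Config
from flytekit.remote import FlyteRemote

# Placeholder endpoint/project/domain; substitute your own.
remote = FlyteRemote(
    config=Config.for_endpoint("flyte.example.com"),
    default_project="flytesnacks",
    default_domain="development",
    interactive_mode_enabled=True,  # pickle notebook-defined entities instead of importing them
)

@task
def double(x: int) -> int:
    return x * 2

# In interactive mode the task is pickled, uploaded, and registered on the fly.
execution = remote.execute(double, inputs={"x": 21}, wait=True)
print(execution.outputs)
```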
2 changes: 2 additions & 0 deletions dev-requirements.in
@@ -55,6 +55,8 @@ pyarrow
scikit-learn
types-requests
prometheus-client
jupyter-client
ipykernel

orjson
kubernetes>=12.0.1
5 changes: 5 additions & 0 deletions flytekit/configuration/__init__.py
@@ -828,6 +828,7 @@ class SerializationSettings(DataClassJsonMixin):
can be fast registered (and thus omit building a Docker image) this object contains additional parameters
for serialization.
source_root (Optional[str]): The root directory of the source code.
interactive_mode_enabled (bool): Whether or not the task is being serialized in interactive mode.
"""

image_config: ImageConfig
@@ -840,6 +841,7 @@ class SerializationSettings(DataClassJsonMixin):
flytekit_virtualenv_root: Optional[str] = None
fast_serialization_settings: Optional[FastSerializationSettings] = None
source_root: Optional[str] = None
interactive_mode_enabled: bool = False

def __post_init__(self):
if self.flytekit_virtualenv_root is None:
@@ -914,6 +916,7 @@ def new_builder(self) -> Builder:
python_interpreter=self.python_interpreter,
fast_serialization_settings=self.fast_serialization_settings,
source_root=self.source_root,
interactive_mode_enabled=self.interactive_mode_enabled,
)

def should_fast_serialize(self) -> bool:
@@ -965,6 +968,7 @@ class Builder(object):
python_interpreter: Optional[str] = None
fast_serialization_settings: Optional[FastSerializationSettings] = None
source_root: Optional[str] = None
interactive_mode_enabled: bool = False

def with_fast_serialization_settings(self, fss: fast_serialization_settings) -> SerializationSettings.Builder:
self.fast_serialization_settings = fss
@@ -982,4 +986,5 @@ def build(self) -> SerializationSettings:
python_interpreter=self.python_interpreter,
fast_serialization_settings=self.fast_serialization_settings,
source_root=self.source_root,
interactive_mode_enabled=self.interactive_mode_enabled,
)
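For context, a sketch of how the new flag might be set when building settings by hand; the project/domain/version values are illustrative, and in practice FlyteRemote fills this in for you during registration:

```python
from flytekit.configuration import ImageConfig, SerializationSettings

ss = SerializationSettings(
    image_config=ImageConfig.auto_default_image(),  # placeholder image config
    project="flytesnacks",
    domain="development",
    version="v1",
    interactive_mode_enabled=True,
)

# The flag survives a round trip through the builder, since the diff above
# threads it through new_builder() and build().
assert ss.new_builder().build().interactive_mode_enabled is True
```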
4 changes: 4 additions & 0 deletions flytekit/core/options.py
@@ -1,5 +1,6 @@
import typing
from dataclasses import dataclass
from typing import Callable, Optional

from flytekit.models import common as common_models
from flytekit.models import security
@@ -34,6 +35,9 @@ class Options(object):
notifications: typing.Optional[typing.List[common_models.Notification]] = None
disable_notifications: typing.Optional[bool] = None
overwrite_cache: typing.Optional[bool] = None
file_uploader: Optional[Callable] = (
None # This is used by the translator to upload task files, like pickled code etc
)

@classmethod
def default_from(
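A sketch of how the new field is expected to be filled in, based on the remote.py change later in this diff (`options.file_uploader = options.file_uploader or self.upload_file`); the uploader's exact call signature is not spelled out here, so treat the wiring as an assumption:

```python
from flytekit.core.options import Options
from flytekit.remote import FlyteRemote

def build_options(remote: FlyteRemote) -> Options:
    # If no uploader is supplied, _serialize_and_register falls back to
    # remote.upload_file, so passing it explicitly is equivalent.
    return Options(file_uploader=remote.upload_file)
```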
8 changes: 8 additions & 0 deletions flytekit/core/promise.py
@@ -660,6 +660,14 @@ def _append_attr(self, key) -> Promise:

return new_promise

def __getstate__(self) -> Dict[str, Any]:
# This func is used to pickle the object.
return vars(self)

def __setstate__(self, state: Dict[str, Any]) -> None:
# This func is used to unpickle the object without infinite recursion.
vars(self).update(state)


def create_native_named_tuple(
ctx: FlyteContext,
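The reason for `vars(self).update(state)` rather than setting attributes one at a time: a class that defines `__getattr__` (as `Promise` does for attribute-style access on promises) can recurse infinitely during unpickling, because pickle probes for special methods before `__dict__` has been restored. A minimal standalone sketch of the same pattern, assuming that is the failure mode the comment refers to (not flytekit code):

```python
import pickle
from typing import Any, Dict

class LazyAttrs:
    """A stand-in for Promise: __getattr__ synthesizes missing attributes."""

    def __init__(self, name: str):
        self.name = name

    def __getattr__(self, key: str) -> str:
        # Called for *missing* attributes; before __dict__ is restored,
        # every lookup lands here, which is how the recursion can start.
        return f"{self.name}.{key}"

    def __getstate__(self) -> Dict[str, Any]:
        return vars(self)

    def __setstate__(self, state: Dict[str, Any]) -> None:
        # Update __dict__ directly instead of going through setattr/getattr,
        # so nothing hits __getattr__ while the object is half-initialized.
        vars(self).update(state)

restored = pickle.loads(pickle.dumps(LazyAttrs("n0")))
assert restored.name == "n0"
```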
36 changes: 36 additions & 0 deletions flytekit/core/python_auto_container.py
@@ -24,6 +24,7 @@

T = TypeVar("T")
_PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name"
PICKLE_FILE_PATH = "pkl.gz"


class PythonAutoContainerTask(PythonTask[T], ABC, metaclass=FlyteTrackedABC):
@@ -163,6 +164,13 @@ def get_default_command(self, settings: SerializationSettings) -> List[str]:

return container_args

def set_resolver(self, resolver: TaskResolverMixin):
"""
By default, flytekit uses the DefaultTaskResolver to resolve the task. This method allows the user to set a custom
task resolver. It can be useful to override the task resolver for specific cases, such as running tasks in a Jupyter notebook.
"""
self._task_resolver = resolver

def set_command_fn(self, get_command_fn: Optional[Callable[[SerializationSettings], List[str]]] = None):
"""
By default, the task will run on the Flyte platform using the pyflyte-execute command.
@@ -274,6 +282,34 @@ def get_all_tasks(self) -> List[PythonAutoContainerTask]: # type: ignore
default_task_resolver = DefaultTaskResolver()


class DefaultNotebookTaskResolver(TrackedInstance, TaskResolverMixin):
"""
This resolver is used when the task is defined in a Jupyter notebook. It loads the task from a pickled file instead of importing it from a module on disk.
"""

def name(self) -> str:
return "DefaultNotebookTaskResolver"

@timeit("Load task")
def load_task(self, loader_args: List[str]) -> PythonAutoContainerTask:
import gzip

import cloudpickle

with gzip.open(PICKLE_FILE_PATH, "r") as f:
return cloudpickle.load(f)

def loader_args(self, settings: SerializationSettings, task: PythonAutoContainerTask) -> List[str]: # type:ignore
_, m, t, _ = extract_task_module(task)
return ["task-module", m, "task-name", t]

def get_all_tasks(self) -> List[PythonAutoContainerTask]: # type: ignore
raise NotImplementedError


default_notebook_task_resolver = DefaultNotebookTaskResolver()


def update_image_spec_copy_handling(image_spec: ImageSpec, settings: SerializationSettings):
"""
This helper function is where the relationship between fast register and ImageSpec is codified.
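A sketch of how the new resolver and the `pkl.gz` convention fit together. Wiring the resolver and writing the pickle by hand, as below, is only for illustration; in interactive mode the registration path is expected to do this for you:

```python
import gzip

import cloudpickle
from flytekit import task
from flytekit.core.python_auto_container import (
    PICKLE_FILE_PATH,
    default_notebook_task_resolver,
)

@task
def add_one(x: int) -> int:
    return x + 1

# Point the task at the notebook resolver instead of the default module-based one.
add_one.set_resolver(default_notebook_task_resolver)

# Registration presumably pickles the task to pkl.gz and ships it with the container...
with gzip.open(PICKLE_FILE_PATH, "wb") as f:
    cloudpickle.dump(add_one, f)

# ...and at execution time load_task() simply unpickles it again.
restored = default_notebook_task_resolver.load_task(
    loader_args=["task-module", "__main__", "task-name", "add_one"]
)
assert restored.name == add_one.name
```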
8 changes: 8 additions & 0 deletions flytekit/core/python_function_task.py
@@ -244,6 +244,14 @@ def compile_into_workflow(
# See comment on reference entity checking a bit down below in this function.
# This is the only circular dependency between the translator.py module and the rest of the flytekit
# authoring experience.

# TODO: Once the backend supports pickling dynamic tasks, add fast_register_file_uploader to the FlyteContext,
# and pass the fast_register_file_uploader to the serializer via the options.
# If at runtime we are executing a dynamic function that is pickled, all subsequent sub-tasks in the
# dynamic workflow should also be pickled. As this is not possible to do during static compilation, we will have to
# upload the pickled file to the metadata store directly at runtime.
# If at runtime we are in a dynamic task, we will automatically have the fast_register_file_uploader set,
# so we can use that to pass the file uploader to the translator.
workflow_spec: admin_workflow_models.WorkflowSpec = get_serializable(
model_entities, ctx.serialization_settings, wf
)
38 changes: 37 additions & 1 deletion flytekit/core/tracker.py
@@ -10,6 +10,7 @@
from flytekit.configuration.feature_flags import FeatureFlags
from flytekit.exceptions import system as _system_exceptions
from flytekit.loggers import developer_logger, logger
from flytekit.tools.interactive import ipython_check


def import_module_from_file(module_name, file):
@@ -248,6 +249,24 @@ def istestfunction(func) -> bool:
return False


def is_ipython_or_pickle_exists() -> bool:
"""
Returns true if the code is running in an IPython notebook or if a pickle file exists.
We skip module path resolution in both cases due to the following reasons:
1. In an IPython notebook, we cannot resolve the module path in the local file system.
2. When the code is serialized (pickled) and executed in a remote environment, only
the pickled file exists at PICKLE_FILE_PATH. The remote environment won't have the
plain python file and module path resolution will fail.
This check ensures we avoid attempting module path resolution in both environments.
"""
from flytekit.core.python_auto_container import PICKLE_FILE_PATH

return ipython_check() or os.path.exists(PICKLE_FILE_PATH)


class _ModuleSanitizer(object):
"""
Sanitizes and finds the absolute module path irrespective of the import location.
@@ -278,6 +297,18 @@ def _resolve_abs_module_name(self, path: str, package_root: typing.Optional[str]
if dirname == package_root:
return basename

# When executing in a Jupyter notebook, we cannot resolve the module path
if not os.path.exists(dirname):
logger.debug(
f"Directory {dirname} does not exist. It is likely that we are in a Jupyter notebook or a pickle file was received."
)

if not is_ipython_or_pickle_exists():
raise AssertionError(
f"Directory {dirname} does not exist, and we are not in a Jupyter notebook or received a pickle file."
)
return basename

# If we have reached a directory with no __init__, ignore
if "__init__.py" not in os.listdir(dirname):
return basename
@@ -326,7 +357,12 @@ def extract_task_module(f: Union[Callable, TrackedInstance]) -> Tuple[str, str,
mod, mod_name, name = _task_module_from_callable(f)

if mod is None:
raise AssertionError(f"Unable to determine module of {f}")
if not is_ipython_or_pickle_exists():
raise AssertionError(f"Unable to determine module of {f}")
logger.debug(
"Could not determine module of function. It is likely that we are in a Jupyter notebook or received a pickle file."
)
return f"{mod_name}.{name}", mod_name, name, ""

if mod_name == "__main__":
if hasattr(f, "task_function"):
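A small illustration of the fallback: when running under IPython (or when `pkl.gz` is present), the module path cannot be resolved on disk, so `extract_task_module` now returns the callable's own module and name with an empty source file instead of raising. The values in the comment are what the code above suggests, not output captured from a real run:

```python
from flytekit import task
from flytekit.core.tracker import extract_task_module

@task
def hello(name: str) -> str:
    return f"Hello, {name}"

entity_name, module_name, func_name, source_file = extract_task_module(hello)
# In a notebook this is expected to be roughly:
#   ("__main__.hello", "__main__", "hello", "")
# whereas in a normal script, source_file points at the resolved module path.
```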
30 changes: 24 additions & 6 deletions flytekit/remote/remote.py
@@ -202,6 +202,7 @@ def __init__(
default_project: typing.Optional[str] = None,
default_domain: typing.Optional[str] = None,
data_upload_location: str = "flyte://my-s3-bucket/",
interactive_mode_enabled: bool = False,
**kwargs,
):
"""Initialize a FlyteRemote object.
@@ -212,10 +213,14 @@
:param default_domain: default domain to use when fetching or executing flyte entities.
:param data_upload_location: this is where all the default data will be uploaded when providing inputs.
The default location - `s3://my-s3-bucket/data` works for sandbox/demo environment. Please override this for non-sandbox cases.
:param interactive_mode_enabled: If set to True, the FlyteRemote will pickle the task/workflow.
"""
if config is None or config.platform is None or config.platform.endpoint is None:
raise user_exceptions.FlyteAssertion("Flyte endpoint should be provided.")

if interactive_mode_enabled is True:
logger.warning("Jupyter notebook and interactive task support is still alpha.")

if data_upload_location is None:
data_upload_location = FlyteContext.current_context().file_access.raw_output_prefix
self._kwargs = kwargs
@@ -235,6 +240,7 @@

# Save the file access object locally, build a context for it and save that as well.
self._ctx = FlyteContextManager.current_context().with_file_access(self._file_access).build()
self._interactive_mode_enabled = interactive_mode_enabled

@property
def context(self) -> FlyteContext:
@@ -268,6 +274,11 @@ def file_access(self) -> FileAccessProvider:
"""File access provider to use for offloading non-literal inputs/outputs."""
return self._file_access

@property
def interactive_mode_enabled(self) -> bool:
"""If set to True, the FlyteRemote will pickle the task/workflow."""
return self._interactive_mode_enabled

def get(
self, flyte_uri: typing.Optional[str] = None
) -> typing.Optional[typing.Union[LiteralsResolver, Literal, HTML, bytes]]:
@@ -758,6 +769,10 @@ async def _serialize_and_register(
)
if serialization_settings.version is None:
serialization_settings.version = version
serialization_settings.interactive_mode_enabled = self.interactive_mode_enabled

options = options or Options()
options.file_uploader = options.file_uploader or self.upload_file

_ = get_serializable(m, settings=serialization_settings, entity=entity, options=options)
# concurrent register
@@ -862,6 +877,7 @@ def register_workflow(
ident = run_sync(
self._serialize_and_register, entity, serialization_settings, version, options, default_launch_plan
)

fwf = self.fetch_workflow(ident.project, ident.domain, ident.name, ident.version)
fwf._python_interface = entity.python_interface
return fwf
@@ -1811,14 +1827,15 @@ def execute_local_task(
"""
resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version)
resolved_identifiers_dict = asdict(resolved_identifiers)
not_found = False
try:
flyte_task: FlyteTask = self.fetch_task(**resolved_identifiers_dict)
except FlyteEntityNotExistException:
if isinstance(entity, PythonAutoContainerTask):
if not image_config:
raise ValueError(f"PythonTask {entity.name} not already registered, but image_config missing")
not_found = True

if not_found:
ss = SerializationSettings(
image_config=image_config,
image_config=image_config or ImageConfig.auto_default_image(),
project=project or self.default_project,
domain=domain or self._default_domain,
version=version,
@@ -1881,6 +1898,9 @@ def execute_local_workflow(
"""
resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version)
resolved_identifiers_dict = asdict(resolved_identifiers)
if not image_config:
image_config = ImageConfig.auto_default_image()

ss = SerializationSettings(
image_config=image_config,
project=resolved_identifiers.project,
@@ -1893,8 +1913,6 @@
self.fetch_workflow(**resolved_identifiers_dict)
except FlyteEntityNotExistException:
logger.info("Registering workflow because it wasn't found in Flyte Admin.")
if not image_config:
raise ValueError("Need image config since we are registering")
self.register_workflow(entity, ss, version=version, options=options)

try:
18 changes: 10 additions & 8 deletions flytekit/tools/fast_registration.py
@@ -20,6 +20,7 @@

from flytekit.constants import CopyFileDetection
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.python_auto_container import PICKLE_FILE_PATH
from flytekit.core.utils import timeit
from flytekit.exceptions.user import FlyteDataNotFoundException
from flytekit.loggers import logger
@@ -242,12 +243,13 @@ def download_distribution(additional_distribution: str, destination: str):
except FlyteDataNotFoundException as ex:
raise RuntimeError("task execution code was not found") from ex
tarfile_name = os.path.basename(additional_distribution)
if not tarfile_name.endswith(".tar.gz"):
if tarfile_name.endswith(".tar.gz"):
# This will overwrite the existing user flyte workflow code in the current working code dir.
result = subprocess.run(
["tar", "-xvf", os.path.join(destination, tarfile_name), "-C", destination],
stdout=subprocess.PIPE,
)
result.check_returncode()
elif tarfile_name != PICKLE_FILE_PATH:
# The distribution is not a pickled file.
raise RuntimeError("Unrecognized additional distribution format for {}".format(additional_distribution))

# This will overwrite the existing user flyte workflow code in the current working code dir.
result = subprocess.run(
["tar", "-xvf", os.path.join(destination, tarfile_name), "-C", destination],
stdout=subprocess.PIPE,
)
result.check_returncode()
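The new branch in `download_distribution` means a distribution whose basename is `pkl.gz` is downloaded but deliberately left unextracted, and the unpickling is left to `DefaultNotebookTaskResolver`. A hypothetical helper that mirrors that decision, for illustration only (the function name is not part of the diff):

```python
import os

from flytekit.core.python_auto_container import PICKLE_FILE_PATH

def classify_distribution(additional_distribution: str) -> str:
    """Hypothetical helper mirroring the check above: tarballs get extracted,
    a pkl.gz pickle is kept as-is, anything else is rejected."""
    name = os.path.basename(additional_distribution)
    if name.endswith(".tar.gz"):
        return "extract"
    if name == PICKLE_FILE_PATH:
        return "keep-pickled"
    raise RuntimeError(f"Unrecognized additional distribution format for {additional_distribution}")

assert classify_distribution("s3://bucket/fast/abc.tar.gz") == "extract"
assert classify_distribution("s3://bucket/interactive/pkl.gz") == "keep-pickled"
```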