Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Async tasks and eager revamp #2927

Merged
merged 48 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from 47 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
3e1d5a2
eod
wild-endeavor Nov 2, 2024
b45c4b7
notes
wild-endeavor Nov 4, 2024
999cea7
changes
wild-endeavor Nov 6, 2024
8be77c5
need to verify tests
wild-endeavor Nov 6, 2024
11f3242
quick lint pass and async
wild-endeavor Nov 6, 2024
8ceec0a
more tests
wild-endeavor Nov 6, 2024
528b063
add some assertions even though they're not correct
wild-endeavor Nov 6, 2024
2b6f698
nested eager in real execution calls the backend
wild-endeavor Nov 7, 2024
3ad46d1
comment
wild-endeavor Nov 7, 2024
38581a3
note
wild-endeavor Nov 7, 2024
469c53f
Merge remote-tracking branch 'origin/master' into async/tasks
wild-endeavor Nov 11, 2024
adbf189
comments, pre-worker queu
wild-endeavor Nov 13, 2024
b45cc87
replace queue
wild-endeavor Nov 14, 2024
52555eb
add turning back to native values
wild-endeavor Nov 14, 2024
9530595
remote
wild-endeavor Nov 14, 2024
809e6b6
remote
wild-endeavor Nov 14, 2024
88537e2
remote
wild-endeavor Nov 14, 2024
27e7ee2
return
wild-endeavor Nov 14, 2024
a0fc558
remove older comments
wild-endeavor Nov 14, 2024
4dfd857
Async/tasks cleanup (#2937)
wild-endeavor Nov 19, 2024
e0c6b7b
merge in consistent exec ids and signals
wild-endeavor Nov 21, 2024
f9b45d6
add cb to watch function, add exception handler, add try/catch around…
wild-endeavor Nov 22, 2024
d4fad13
fmt
wild-endeavor Nov 22, 2024
7c62e20
fix one test and skip rest for now
wild-endeavor Nov 22, 2024
8c80bb6
wrong marker
wild-endeavor Nov 22, 2024
bb04b64
remove old eager tests
wild-endeavor Nov 22, 2024
65df34a
Merge branch 'master' into async/tasks
wild-endeavor Nov 22, 2024
af631bc
Async/tasks agent (#2955)
wild-endeavor Nov 26, 2024
65ffbc4
Merge remote-tracking branch 'origin/master' into async/tasks
wild-endeavor Nov 26, 2024
330c3b1
update test
wild-endeavor Nov 26, 2024
45f07b0
Remove local execute (#2960)
wild-endeavor Nov 26, 2024
eea2605
Async/tasks remote (#2964)
wild-endeavor Nov 27, 2024
ea2cdf0
lint
wild-endeavor Nov 27, 2024
b214d77
merge master
wild-endeavor Dec 2, 2024
6ee16ca
bring mashumaro dep back in line with master
wild-endeavor Dec 2, 2024
be750f8
remove debug code and clean up comments
wild-endeavor Dec 2, 2024
222f65c
add gh issue links
wild-endeavor Dec 2, 2024
9931735
add one more gh issue
wild-endeavor Dec 2, 2024
95ae041
merge master
wild-endeavor Dec 2, 2024
09d473c
skip all sandbox tests for now, wait until released and add integrati…
wild-endeavor Dec 3, 2024
84cf0ff
try adding simplest integration test
wild-endeavor Dec 3, 2024
452105c
lint
wild-endeavor Dec 3, 2024
f114f53
Async/tasks torch (#2972)
wild-endeavor Dec 3, 2024
c4dc438
revert changes for base agent (#2973)
wild-endeavor Dec 4, 2024
7f66c92
de-duplicate call handler code
wild-endeavor Dec 4, 2024
802bae1
try redirecting back (#2978)
wild-endeavor Dec 5, 2024
7684ba2
merge master
wild-endeavor Dec 6, 2024
3071f56
docs changes
wild-endeavor Dec 9, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 0 additions & 38 deletions .github/workflows/pythonbuild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -162,44 +162,6 @@ jobs:
fail_ci_if_error: false
files: coverage.xml

test-hypothesis:
needs:
- detect-python-versions
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
python-version: ${{fromJson(needs.detect-python-versions.outputs.python-versions)}}
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Cache pip
uses: actions/cache@v3
with:
# This path is specific to Ubuntu
path: ~/.cache/pip
# Look to see if there is a cache hit for the corresponding requirements files
key: ${{ format('{0}-pip-{1}', runner.os, hashFiles('dev-requirements.in', 'requirements.in')) }}
- name: Install dependencies
run: |
pip install uv
make setup-global-uv
uv pip freeze
- name: Test with coverage
env:
FLYTEKIT_HYPOTHESIS_PROFILE: ci
run: |
make unit_test_hypothesis
- name: Codecov
uses: codecov/[email protected]
with:
fail_ci_if_error: false
files: coverage.xml

test-serialization:
needs:
- detect-python-versions
Expand Down
4 changes: 0 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,6 @@ unit_test:
# Run serial tests without any parallelism
$(PYTEST) -m "serial" tests/flytekit/unit/ --ignore=tests/flytekit/unit/extras/ --ignore=tests/flytekit/unit/models --ignore=tests/flytekit/unit/extend ${CODECOV_OPTS}

.PHONY: unit_test_hypothesis
unit_test_hypothesis:
$(PYTEST_AND_OPTS) -m "hypothesis" tests/flytekit/unit/experimental ${CODECOV_OPTS}

.PHONY: unit_test_extras
unit_test_extras:
PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python $(PYTEST_AND_OPTS) tests/flytekit/unit/extras tests/flytekit/unit/extend ${CODECOV_OPTS}
Expand Down
2 changes: 1 addition & 1 deletion flytekit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@
from flytekit.core.reference_entity import LaunchPlanReference, TaskReference, WorkflowReference
from flytekit.core.resources import Resources
from flytekit.core.schedule import CronSchedule, FixedRate
from flytekit.core.task import Secret, reference_task, task
from flytekit.core.task import Secret, eager, reference_task, task
from flytekit.core.type_engine import BatchSize
from flytekit.core.workflow import ImperativeWorkflow as Workflow
from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow
Expand Down
5 changes: 0 additions & 5 deletions flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import contextlib
import datetime
import inspect
import os
import pathlib
import signal
Expand Down Expand Up @@ -177,10 +176,6 @@ def _dispatch_execute(
# Step2
# Invoke task - dispatch_execute
outputs = task_def.dispatch_execute(ctx, idl_input_literals)
if inspect.iscoroutine(outputs):
# Handle eager-mode (async) tasks
logger.info("Output is a coroutine")
outputs = _get_working_loop().run_until_complete(outputs)

# Step3a
if isinstance(outputs, VoidPromise):
Expand Down
17 changes: 15 additions & 2 deletions flytekit/configuration/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
```
"""

import os
from typing import Optional, Protocol, runtime_checkable

from click import Group
Expand Down Expand Up @@ -59,10 +60,22 @@ def get_remote(
config: Optional[str], project: str, domain: str, data_upload_location: Optional[str] = None
) -> FlyteRemote:
"""Get FlyteRemote object for CLI session."""

cfg_file = get_config_file(config)

# The assumption here (if there's no config file that means we want sandbox) is too broad.
# todo: can improve this in the future, rather than just checking one env var, auto() with
# nothing configured should probably not return sandbox but can consider
if cfg_file is None:
cfg_obj = Config.for_sandbox()
logger.info("No config files found, creating remote with sandbox config")
# We really are just looking for endpoint, client_id, and client_secret. These correspond to the env vars
# FLYTE_PLATFORM_URL, FLYTE_CREDENTIALS_CLIENT_ID, FLYTE_CREDENTIALS_CLIENT_SECRET
# auto() should pick these up.
if "FLYTE_PLATFORM_URL" in os.environ:
cfg_obj = Config.auto(None)
logger.warning(f"Auto-created config object to pick up env vars {cfg_obj}")
else:
cfg_obj = Config.for_sandbox()
logger.info("No config files found, creating remote with sandbox config")
else: # pragma: no cover
cfg_obj = Config.auto(config)
logger.debug(f"Creating remote with config {cfg_obj}" + (f" with file {config}" if config else ""))
Expand Down
29 changes: 2 additions & 27 deletions flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import asyncio
import collections
import datetime
import inspect
import warnings
from abc import abstractmethod
from base64 import b64encode
Expand Down Expand Up @@ -142,6 +141,7 @@ class TaskMetadata(object):
retries: int = 0
timeout: Optional[Union[datetime.timedelta, int]] = None
pod_template_name: Optional[str] = None
is_eager: bool = False

def __post_init__(self):
if self.timeout:
Expand Down Expand Up @@ -181,6 +181,7 @@ def to_taskmetadata_model(self) -> _task_model.TaskMetadata:
cache_serializable=self.cache_serialize,
pod_template_name=self.pod_template_name,
cache_ignore_input_vars=self.cache_ignore_input_vars,
is_eager=self.is_eager,
)


Expand Down Expand Up @@ -340,9 +341,6 @@ def local_execute(
# if one is changed and not the other.
outputs_literal_map = self.sandbox_execute(ctx, input_literal_map)

if inspect.iscoroutine(outputs_literal_map):
return outputs_literal_map

outputs_literals = outputs_literal_map.literals

# TODO maybe this is the part that should be done for local execution, we pass the outputs to some special
Expand Down Expand Up @@ -759,29 +757,6 @@ def dispatch_execute(
raise
raise FlyteUserRuntimeException(e) from e

if inspect.iscoroutine(native_outputs):
# If native outputs is a coroutine, then this is an eager workflow.
if exec_ctx.execution_state:
if exec_ctx.execution_state.mode == ExecutionState.Mode.LOCAL_TASK_EXECUTION:
# Just return task outputs as a coroutine if the eager workflow is being executed locally,
# outside of a workflow. This preserves the expectation that the eager workflow is an async
# function.
return native_outputs
elif exec_ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
# If executed inside of a workflow being executed locally, then run the coroutine to get the
# actual results.
return asyncio.run(
self._async_execute(
native_inputs,
native_outputs,
ctx,
exec_ctx,
new_user_params,
)
)

return self._async_execute(native_inputs, native_outputs, ctx, exec_ctx, new_user_params)

# Lets run the post_execute method. This may result in a IgnoreOutputs Exception, which is
# bubbled up to be handled at the callee layer.
native_outputs = self.post_execute(new_user_params, native_outputs)
Expand Down
11 changes: 11 additions & 0 deletions flytekit/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,17 @@
# Set this environment variable to true to force the task to return non-zero exit code on failure.
FLYTE_FAIL_ON_ERROR = "FLYTE_FAIL_ON_ERROR"

# Executions launched by the current eager task will be tagged with this key:current_eager_exec_name
EAGER_TAG_KEY = "eager-exec"

# Executions launched by the current eager task will be tagged with this key:root_eager_exec_name, only relevant
# for nested eager tasks. This is how you identify the root execution.
EAGER_TAG_ROOT_KEY = "eager-root-exec"

# The environment variable that will be set to the root eager task execution name. This is how you pass down the
# root eager execution.
EAGER_ROOT_ENV_NAME = "_F_EE_ROOT"

# This is a special key used to store metadata about the cache key in a literal type.
CACHE_KEY_METADATA = "cache-key-metadata"

Expand Down
42 changes: 41 additions & 1 deletion flytekit/core/context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import logging as _logging
import os
import pathlib
import signal
import tempfile
import traceback
import typing
Expand All @@ -24,6 +25,7 @@
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from types import FrameType
from typing import Generator, List, Optional, Union

from flytekit.configuration import Config, SecretsConfig, SerializationSettings
Expand All @@ -37,8 +39,10 @@
from flytekit.models.core import identifier as _identifier

if typing.TYPE_CHECKING:
from flytekit import Deck
from flytekit.clients import friendly as friendly_client # noqa
from flytekit.clients.friendly import SynchronousFlyteClient
from flytekit.core.worker_queue import Controller
from flytekit.deck.deck import Deck

# TODO: resolve circular import from flytekit.core.python_auto_container import TaskResolverMixin

Expand Down Expand Up @@ -526,6 +530,10 @@ class Mode(Enum):
# This is the mode that is used to indicate a dynamic task
DYNAMIC_TASK_EXECUTION = 4

EAGER_EXECUTION = 5

EAGER_LOCAL_EXECUTION = 6

mode: Optional[ExecutionState.Mode]
working_dir: Union[os.PathLike, str]
engine_dir: Optional[Union[os.PathLike, str]]
Expand Down Expand Up @@ -586,6 +594,7 @@ def is_local_execution(self) -> bool:
return (
self.mode == ExecutionState.Mode.LOCAL_TASK_EXECUTION
or self.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION
or self.mode == ExecutionState.Mode.EAGER_LOCAL_EXECUTION
)


Expand Down Expand Up @@ -663,6 +672,7 @@ class FlyteContext(object):
in_a_condition: bool = False
origin_stackframe: Optional[traceback.FrameSummary] = None
output_metadata_tracker: Optional[OutputMetadataTracker] = None
worker_queue: Optional[Controller] = None

@property
def user_space_params(self) -> Optional[ExecutionParameters]:
Expand All @@ -689,6 +699,7 @@ def new_builder(self) -> Builder:
execution_state=self.execution_state,
in_a_condition=self.in_a_condition,
output_metadata_tracker=self.output_metadata_tracker,
worker_queue=self.worker_queue,
)

def enter_conditional_section(self) -> Builder:
Expand All @@ -713,6 +724,12 @@ def with_serialization_settings(self, ss: SerializationSettings) -> Builder:
def with_output_metadata_tracker(self, t: OutputMetadataTracker) -> Builder:
return self.new_builder().with_output_metadata_tracker(t)

def with_worker_queue(self, wq: Controller) -> Builder:
return self.new_builder().with_worker_queue(wq)

def with_client(self, c: SynchronousFlyteClient) -> Builder:
return self.new_builder().with_client(c)

def new_compilation_state(self, prefix: str = "") -> CompilationState:
"""
Creates and returns a default compilation state. For most of the code this should be the entrypoint
Expand Down Expand Up @@ -774,6 +791,7 @@ class Builder(object):
serialization_settings: Optional[SerializationSettings] = None
in_a_condition: bool = False
output_metadata_tracker: Optional[OutputMetadataTracker] = None
worker_queue: Optional[Controller] = None

def build(self) -> FlyteContext:
return FlyteContext(
Expand All @@ -785,6 +803,7 @@ def build(self) -> FlyteContext:
serialization_settings=self.serialization_settings,
in_a_condition=self.in_a_condition,
output_metadata_tracker=self.output_metadata_tracker,
worker_queue=self.worker_queue,
)

def enter_conditional_section(self) -> FlyteContext.Builder:
Expand Down Expand Up @@ -833,6 +852,14 @@ def with_output_metadata_tracker(self, t: OutputMetadataTracker) -> FlyteContext
self.output_metadata_tracker = t
return self

def with_worker_queue(self, wq: Controller) -> FlyteContext.Builder:
self.worker_queue = wq
return self

def with_client(self, c: SynchronousFlyteClient) -> FlyteContext.Builder:
self.flyte_client = c
return self

def new_compilation_state(self, prefix: str = "") -> CompilationState:
"""
Creates and returns a default compilation state. For most of the code this should be the entrypoint
Expand Down Expand Up @@ -871,6 +898,12 @@ class FlyteContextManager(object):
FlyteContextManager.pop_context()
"""

signal_handlers: typing.List[typing.Callable[[int, FrameType], typing.Any]] = []

@staticmethod
def add_signal_handler(handler: typing.Callable[[int, FrameType], typing.Any]):
FlyteContextManager.signal_handlers.append(handler)

@staticmethod
def get_origin_stackframe(limit=2) -> traceback.FrameSummary:
ss = traceback.extract_stack(limit=limit + 1)
Expand Down Expand Up @@ -954,6 +987,13 @@ def initialize():
user_space_path = os.path.join(cfg.local_sandbox_path, "user_space")
pathlib.Path(user_space_path).mkdir(parents=True, exist_ok=True)

def main_signal_handler(signum: int, frame: FrameType):
for handler in FlyteContextManager.signal_handlers:
handler(signum, frame)
exit(1)

signal.signal(signal.SIGINT, main_signal_handler)

# Note we use the SdkWorkflowExecution object purely for formatting into the ex:project:domain:name format users
# are already acquainted with
default_context = FlyteContext(file_access=default_local_file_access_provider)
Expand Down
4 changes: 0 additions & 4 deletions flytekit/core/options.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import typing
from dataclasses import dataclass
from typing import Callable, Optional

from flytekit.models import common as common_models
from flytekit.models import security
Expand Down Expand Up @@ -35,9 +34,6 @@ class Options(object):
notifications: typing.Optional[typing.List[common_models.Notification]] = None
disable_notifications: typing.Optional[bool] = None
overwrite_cache: typing.Optional[bool] = None
file_uploader: Optional[Callable] = (
None # This is used by the translator to upload task files, like pickled code etc
)

@classmethod
def default_from(
Expand Down
Loading
Loading