Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unbounded foreach #450

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 30 additions & 9 deletions metaflow/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from . import metaflow_version
from . import namespace
from .current import current
from .cli_args import cli_args
from .util import resolve_identity, decompress_list, write_latest_run_id, get_latest_run_id
from .task import MetaflowTask
from .exception import CommandException, MetaflowException
Expand All @@ -29,6 +30,7 @@
from .event_logger import EventLogger
from .monitor import Monitor
from .R import use_r, metaflow_r_version
from .unbounded_foreach import UBF_CONTROL, UBF_TASK

ERASE_TO_EOL = '\033[K'
HIGHLIGHT = 'red'
Expand Down Expand Up @@ -325,14 +327,14 @@ def logs(obj, input_path, stdout=None, stderr=None, both=None):
show_default=True,
help='Index of this foreach split.')
@click.option('--tag',
'tags',
'opt_tag',
multiple=True,
default=None,
help="Annotate this run with the given tag. You can specify "
"this option multiple times to attach multiple tags in "
"the task.")
@click.option('--namespace',
'user_namespace',
'opt_namespace',
default=None,
help="Change namespace from the default (your username) to "
"the specified tag.")
Expand All @@ -356,22 +358,30 @@ def logs(obj, input_path, stdout=None, stderr=None, both=None):
help="Add a decorator to this task. You can specify this "
"option multiple times to attach multiple decorators "
"to this task.")
@click.option('--ubf-context',
default='none',
type=click.Choice(['none', UBF_CONTROL, UBF_TASK]),
help="Provides additional context if it belongs to an unbounded "
"foreach.")
@click.pass_context
@click.pass_obj
def step(obj,
ctx,
step_name,
tags=None,
opt_tag=None,
run_id=None,
task_id=None,
input_paths=None,
split_index=None,
user_namespace=None,
opt_namespace=None,
retry_count=None,
max_user_code_retries=None,
clone_only=None,
clone_run_id=None,
decospecs=None):
if user_namespace is not None:
namespace(user_namespace or None)
decospecs=None,
ubf_context=None):
if opt_namespace is not None:
namespace(opt_namespace or None)

func = None
try:
Expand All @@ -387,7 +397,16 @@ def step(obj,
if decospecs:
decorators._attach_decorators_to_step(func, decospecs)

obj.metadata.add_sticky_tags(tags=tags)
step_kwargs = ctx.params
# Remove argument `step_name` from `step_kwargs`.
step_kwargs.pop('step_name', None)
# Remove `opt_*` prefix from (some) option keys.
step_kwargs = dict([(k[4:], v) if k.startswith('opt_') else (k, v)
for k, v in step_kwargs.items()])
cli_args._set_step_kwargs(step_kwargs)

obj.metadata.add_sticky_tags(tags=opt_tag)

paths = decompress_list(input_paths) if input_paths else []

task = MetaflowTask(obj.flow,
Expand All @@ -396,7 +415,8 @@ def step(obj,
obj.environment,
obj.echo,
obj.event_logger,
obj.monitor)
obj.monitor,
ubf_context)
if clone_only:
task.clone_only(step_name,
run_id,
Expand Down Expand Up @@ -735,6 +755,7 @@ def start(ctx,
branch=True)
cov.start()

cli_args._set_top_kwargs(ctx.params)
ctx.obj.echo = echo
ctx.obj.echo_always = echo_always
ctx.obj.graph = FlowGraph(ctx.obj.flow.__class__)
Expand Down
48 changes: 48 additions & 0 deletions metaflow/cli_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# This module provides a global singleton `cli_args` which stores the `top` and
# `step` level options for the metaflow CLI.
# This allows decorators to have access to the CLI options instead of relying
# (solely) on the click context.
# TODO(crk): Fold `dict_to_cli_options` as a private method of this `CLIArgs`
# once all other callers of `step` [titus, meson etc.] are unified.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment not relevant to OSS or at least needs to be modded to remove irrelevant references.

from .util import dict_to_cli_options

class CLIArgs(object):
def __init__(self):
self._top_kwargs = {}
self._step_kwargs = {}

def _set_step_kwargs(self, kwargs):
self._step_kwargs = kwargs

def _set_top_kwargs(self, kwargs):
self._top_kwargs = kwargs

@property
def top_kwargs(self):
return self._top_kwargs

@property
def step_kwargs(self):
return self._step_kwargs

def step_command(self,
executable,
script,
step_name,
top_kwargs=None,
step_kwargs=None):
cmd = [executable, '-u', script]
if top_kwargs is None:
top_kwargs = self._top_kwargs
if step_kwargs is None:
step_kwargs = self._step_kwargs

top_args_list = [arg for arg in dict_to_cli_options(top_kwargs)]
cmd.extend(top_args_list)
cmd.extend(['step', step_name])
step_args_list = [arg for arg in dict_to_cli_options(step_kwargs)]
cmd.extend(step_args_list)

return cmd

cli_args = CLIArgs()
24 changes: 23 additions & 1 deletion metaflow/client/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from metaflow.metaflow_config import DEFAULT_METADATA
from metaflow.plugins import ENVIRONMENTS, METADATA_PROVIDERS

from metaflow.unbounded_foreach import CONTROL_TASK_TAG
from metaflow.util import cached_property, resolve_identity

from .filecache import FileCache
Expand Down Expand Up @@ -1141,7 +1142,8 @@ def task(self):
A task in the step
"""
for t in self:
return t
if CONTROL_TASK_TAG not in t.tags:
return t

def tasks(self, *tags):
"""
Expand All @@ -1164,6 +1166,26 @@ def tasks(self, *tags):
return self._filtered_children(*tags)

@property
def control_task(self):
children = super(Step, self).__iter__()
for t in children:
if CONTROL_TASK_TAG in t.tags:
return t

def control_tasks(self, *tags):
children = super(Step, self).__iter__()
filter_tags = [CONTROL_TASK_TAG]
filter_tags.extend(tags)
for child in children:
if all(tag in child.tags for tag in filter_tags):
yield child

def __iter__(self):
children = super(Step, self).__iter__()
for t in children:
if CONTROL_TASK_TAG not in t.tags:
yield t
@property
def finished_at(self):
"""
Returns the datetime object of when the step finished (successfully or not).
Expand Down
15 changes: 10 additions & 5 deletions metaflow/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def step_task_retry_count(self):
Returns a tuple of (user_code_retries, error_retries). Error retries
are attempts to run the process after the user code has failed all
its retries.
Return None, None to disable all retries by the (native) runtime.
"""
return 0, 0

Expand All @@ -235,19 +236,21 @@ def runtime_task_created(self,
task_id,
split_index,
input_paths,
is_cloned):
is_cloned,
ubf_context):
"""
Called when the runtime has created a task related to this step.
"""
pass

def runtime_finished(self, exception):
"""
Called when the runtime created task finishes or encounters an interrupt/exception.
Called when the runtime finishes the flow or encounters an
interrupt/exception.
"""
pass

def runtime_step_cli(self, cli_args, retry_count, max_user_code_retries):
def runtime_step_cli(self, cli_args, retry_count, max_user_code_retries, ubf_context):
"""
Access the command line for a step execution in the runtime context.
"""
Expand All @@ -262,7 +265,8 @@ def task_pre_step(self,
flow,
graph,
retry_count,
max_user_code_retries):
max_user_code_retries,
ubf_context):
"""
Run before the step function in the task context.
"""
Expand All @@ -273,7 +277,8 @@ def task_decorate(self,
flow,
graph,
retry_count,
max_user_code_retries):
max_user_code_retries,
ubf_context):
return step_func

def task_post_step(self,
Expand Down
51 changes: 38 additions & 13 deletions metaflow/flowspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .exception import MetaflowException, MetaflowInternalError, \
MissingInMergeArtifactsException, UnhandledInMergeArtifactsException
from .graph import FlowGraph
from .unbounded_foreach import UnboundedForeachInput

# For Python 3 compatibility
try:
Expand Down Expand Up @@ -355,6 +356,25 @@ def merge_artifacts(self, inputs, exclude=[], include=[]):
for var, (inp, _) in to_merge.items():
setattr(self, var, getattr(inp, var))

def _validate_ubf_step(self, step_name):
join_list = self._graph[step_name].out_funcs
if len(join_list) != 1:
msg = "UnboundedForeach is supported over a single node, "\
"not an arbitrary DAG. Specify a single `join` node"\
" instead of multiple:{join_list}."\
.format(join_list=join_list)
raise InvalidNextException(msg)
join_step = join_list[0]
join_node = self._graph[join_step]
join_type = join_node.type

if join_type != 'join':
msg = "UnboundedForeach found for:{node} -> {join}."\
" The join type isn't valid."\
.format(node=step_name,
join=join_step)
raise InvalidNextException(msg)

def next(self, *dsts, **kwargs):
"""
Indicates the next step to execute at the end of this step
Expand Down Expand Up @@ -443,19 +463,24 @@ def next(self, *dsts, **kwargs):
.format(step=step, var=foreach)
raise InvalidNextException(msg)

try:
self._foreach_num_splits = sum(1 for _ in foreach_iter)
except TypeError:
msg = "Foreach variable *self.{var}* in step *{step}* "\
"is not iterable. Check your variable."\
.format(step=step, var=foreach)
raise InvalidNextException(msg)

if self._foreach_num_splits == 0:
msg = "Foreach iterator over *{var}* in step *{step}* "\
"produced zero splits. Check your variable."\
.format(step=step, var=foreach)
raise InvalidNextException(msg)
if issubclass(type(foreach_iter), UnboundedForeachInput):
self._unbounded_foreach = True
self._foreach_num_splits = None
self._validate_ubf_step(funcs[0])
else:
try:
self._foreach_num_splits = sum(1 for _ in foreach_iter)
except TypeError:
msg = "Foreach variable *self.{var}* in step *{step}* "\
"is not iterable. Check your variable."\
.format(step=step, var=foreach)
raise InvalidNextException(msg)

if self._foreach_num_splits == 0:
msg = "Foreach iterator over *{var}* in step *{step}* "\
"produced zero splits. Check your variable."\
.format(step=step, var=foreach)
raise InvalidNextException(msg)

self._foreach_var = foreach

Expand Down
5 changes: 4 additions & 1 deletion metaflow/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def _merge_lists(base, overrides, attr):
from .aws.batch.batch_decorator import BatchDecorator, ResourcesDecorator
from .aws.step_functions.step_functions_decorator import StepFunctionsInternalDecorator
from .conda.conda_step_decorator import CondaStepDecorator
from .test_unbounded_foreach.test_unbounded_foreach_decorator\
import TestUnboundedForeachDecorator

STEP_DECORATORS = _merge_lists([CatchDecorator,
TimeoutDecorator,
Expand All @@ -61,7 +63,8 @@ def _merge_lists(base, overrides, attr):
RetryDecorator,
BatchDecorator,
StepFunctionsInternalDecorator,
CondaStepDecorator], ext_plugins.STEP_DECORATORS, 'name')
CondaStepDecorator,
TestUnboundedForeachDecorator], ext_plugins.STEP_DECORATORS, 'name')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we should expose this decorator right now.


# Add Conda environment
from .conda.conda_environment import CondaEnvironment
Expand Down
6 changes: 6 additions & 0 deletions metaflow/plugins/aws/batch/batch_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from metaflow.datastore import FlowDataStore
from metaflow.datastore.local_backend import LocalBackend
from metaflow.metaflow_config import DATASTORE_LOCAL_DIR
from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK
from metaflow import util
from metaflow import R
from metaflow.exception import (
Expand Down Expand Up @@ -154,6 +155,11 @@ def kill(ctx, run_id, user, my_runs):
@click.option("--shared_memory", help="Shared Memory requirement for AWS Batch.")
@click.option("--max_swap", help="Max Swap requirement for AWS Batch.")
@click.option("--swappiness", help="Swappiness requirement for AWS Batch.")
@click.option('--ubf-context',
default=None,
type=click.Choice([None, UBF_CONTROL, UBF_TASK]),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems to be a discrepancy between the 'none' option in cli.py and None here. Is this intended?

help="Provides additional context if it belongs to an unbounded "
"foreach.")
@click.pass_context
def step(
ctx,
Expand Down
10 changes: 6 additions & 4 deletions metaflow/plugins/aws/batch/batch_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,13 @@ def runtime_init(self, flow, graph, package, run_id):
self.run_id = run_id

def runtime_task_created(
self, task_datastore, task_id, split_index, input_paths, is_cloned):
self, task_datastore, task_id, split_index, input_paths, is_cloned,
ubf_context):
if not is_cloned:
self._save_package_once(self.flow_datastore, self.package)

def runtime_step_cli(self, cli_args, retry_count, max_user_code_retries):
def runtime_step_cli(self, cli_args, retry_count, max_user_code_retries,
ubf_context):
if retry_count <= max_user_code_retries:
# after all attempts to run the user code have failed, we don't need
# Batch anymore. We can execute possible fallback code locally.
Expand All @@ -187,7 +189,7 @@ def runtime_step_cli(self, cli_args, retry_count, max_user_code_retries):

def task_pre_step(
self, step_name, task_datastore, metadata, run_id, task_id, flow, graph, retry_count,
max_retries):
max_retries, ubf_context):
if metadata.TYPE == 'local':
self.task_ds = task_datastore
else:
Expand All @@ -196,7 +198,7 @@ def task_pre_step(
meta['aws-batch-job-id'] = os.environ['AWS_BATCH_JOB_ID']
meta['aws-batch-job-attempt'] = os.environ['AWS_BATCH_JOB_ATTEMPT']
meta['aws-batch-ce-name'] = os.environ['AWS_BATCH_CE_NAME']
meta['aws-batch-jq-name'] = os.environ['AWS_BATCH_JQ_NAME']
meta['aws-batch-jq-name'] = os.environ['AWS_BATCH_JQ_NAME']
entries = [MetaDatum(field=k, value=v, type=k, tags=[]) for k, v in meta.items()]
# Register book-keeping metadata for debugging.
metadata.register_metadata(run_id, step_name, task_id, entries)
Expand Down
Loading