Skip to content

Commit

Permalink
Simplify CLI implementation (#2023)
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasjpfan authored Dec 6, 2023
1 parent 1e30eb1 commit 9e8331f
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 34 deletions.
13 changes: 4 additions & 9 deletions flytekit/clis/sdk_in_container/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing_extensions import OrderedDict

from flytekit.clis.sdk_in_container.run import RunCommand, RunLevelParams, WorkflowCommand
from flytekit.clis.sdk_in_container.utils import make_field
from flytekit.clis.sdk_in_container.utils import make_click_option_field
from flytekit.configuration import ImageConfig, SerializationSettings
from flytekit.core.base_task import PythonTask
from flytekit.core.workflow import PythonFunctionWorkflow
Expand All @@ -14,7 +14,7 @@

@dataclass
class BuildParams(RunLevelParams):
fast: bool = make_field(
fast: bool = make_click_option_field(
click.Option(
param_decls=["--fast"],
required=False,
Expand Down Expand Up @@ -75,18 +75,13 @@ class BuildCommand(RunCommand):
A click command group for building a image for flyte workflows & tasks in a file.
"""

def __init__(self, *args, **kwargs):
params = BuildParams.options()
kwargs["params"] = params
super().__init__(*args, **kwargs)
_run_params = BuildParams

def list_commands(self, ctx, *args, **kwargs):
return super().list_commands(ctx, add_remote=False)

def get_command(self, ctx, filename):
if ctx.obj is None:
ctx.obj = {}
ctx.obj = BuildParams.from_dict(ctx.params)
super().get_command(ctx, filename)
return BuildWorkflowCommand(filename, name=filename, help=f"Build an image for [workflow|task] from {filename}")


Expand Down
50 changes: 26 additions & 24 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
PyFlyteParams,
domain_option,
get_option_from_metadata,
make_field,
make_click_option_field,
pretty_print_exception,
project_option,
)
Expand Down Expand Up @@ -58,9 +58,9 @@ class RunLevelParams(PyFlyteParams):
This class is used to store the parameters that are used to run a workflow / task / launchplan.
"""

project: str = make_field(project_option)
domain: str = make_field(domain_option)
destination_dir: str = make_field(
project: str = make_click_option_field(project_option)
domain: str = make_click_option_field(domain_option)
destination_dir: str = make_click_option_field(
click.Option(
param_decls=["--destination-dir", "destination_dir"],
required=False,
Expand All @@ -70,7 +70,7 @@ class RunLevelParams(PyFlyteParams):
help="Directory inside the image where the tar file containing the code will be copied to",
)
)
copy_all: bool = make_field(
copy_all: bool = make_click_option_field(
click.Option(
param_decls=["--copy-all", "copy_all"],
required=False,
Expand All @@ -80,7 +80,7 @@ class RunLevelParams(PyFlyteParams):
help="Copy all files in the source root directory to the destination directory",
)
)
image_config: ImageConfig = make_field(
image_config: ImageConfig = make_click_option_field(
click.Option(
param_decls=["-i", "--image", "image_config"],
required=False,
Expand All @@ -92,7 +92,7 @@ class RunLevelParams(PyFlyteParams):
help="Image used to register and run.",
)
)
service_account: str = make_field(
service_account: str = make_click_option_field(
click.Option(
param_decls=["--service-account", "service_account"],
required=False,
Expand All @@ -101,7 +101,7 @@ class RunLevelParams(PyFlyteParams):
help="Service account used when executing this workflow",
)
)
wait_execution: bool = make_field(
wait_execution: bool = make_click_option_field(
click.Option(
param_decls=["--wait-execution", "wait_execution"],
required=False,
Expand All @@ -111,7 +111,7 @@ class RunLevelParams(PyFlyteParams):
help="Whether to wait for the execution to finish",
)
)
dump_snippet: bool = make_field(
dump_snippet: bool = make_click_option_field(
click.Option(
param_decls=["--dump-snippet", "dump_snippet"],
required=False,
Expand All @@ -121,7 +121,7 @@ class RunLevelParams(PyFlyteParams):
help="Whether to dump a code snippet instructing how to load the workflow execution using flyteremote",
)
)
overwrite_cache: bool = make_field(
overwrite_cache: bool = make_click_option_field(
click.Option(
param_decls=["--overwrite-cache", "overwrite_cache"],
required=False,
Expand All @@ -131,7 +131,7 @@ class RunLevelParams(PyFlyteParams):
help="Whether to overwrite the cache if it already exists",
)
)
envvars: typing.Dict[str, str] = make_field(
envvars: typing.Dict[str, str] = make_click_option_field(
click.Option(
param_decls=["--envvars", "--env"],
required=False,
Expand All @@ -142,7 +142,7 @@ class RunLevelParams(PyFlyteParams):
help="Environment variables to set in the container, of the format `ENV_NAME=ENV_VALUE`",
)
)
tags: typing.List[str] = make_field(
tags: typing.List[str] = make_click_option_field(
click.Option(
param_decls=["--tags", "--tag"],
required=False,
Expand All @@ -152,7 +152,7 @@ class RunLevelParams(PyFlyteParams):
help="Tags to set for the execution",
)
)
name: str = make_field(
name: str = make_click_option_field(
click.Option(
param_decls=["--name"],
required=False,
Expand All @@ -161,7 +161,7 @@ class RunLevelParams(PyFlyteParams):
help="Name to assign to this execution",
)
)
labels: typing.Dict[str, str] = make_field(
labels: typing.Dict[str, str] = make_click_option_field(
click.Option(
param_decls=["--labels", "--label"],
required=False,
Expand All @@ -172,7 +172,7 @@ class RunLevelParams(PyFlyteParams):
help="Labels to be attached to the execution of the format `label_key=label_value`.",
)
)
annotations: typing.Dict[str, str] = make_field(
annotations: typing.Dict[str, str] = make_click_option_field(
click.Option(
param_decls=["--annotations", "--annotation"],
required=False,
Expand All @@ -183,7 +183,7 @@ class RunLevelParams(PyFlyteParams):
help="Annotations to be attached to the execution of the format `key=value`.",
)
)
raw_output_data_prefix: str = make_field(
raw_output_data_prefix: str = make_click_option_field(
click.Option(
param_decls=["--raw-output-data-prefix", "--raw-data-prefix"],
required=False,
Expand All @@ -200,7 +200,7 @@ class RunLevelParams(PyFlyteParams):
),
)
)
max_parallelism: int = make_field(
max_parallelism: int = make_click_option_field(
click.Option(
param_decls=["--max-parallelism"],
required=False,
Expand All @@ -210,7 +210,7 @@ class RunLevelParams(PyFlyteParams):
" project/domain defaults are used. If 0 then it is unlimited.",
)
)
disable_notifications: bool = make_field(
disable_notifications: bool = make_click_option_field(
click.Option(
param_decls=["--disable-notifications"],
required=False,
Expand All @@ -220,7 +220,7 @@ class RunLevelParams(PyFlyteParams):
help="Should notifications be disabled for this execution.",
)
)
remote: bool = make_field(
remote: bool = make_click_option_field(
click.Option(
param_decls=["-r", "--remote"],
required=False,
Expand All @@ -231,7 +231,7 @@ class RunLevelParams(PyFlyteParams):
help="Whether to register and run the workflow on a Flyte deployment",
)
)
limit: int = make_field(
limit: int = make_click_option_field(
click.Option(
param_decls=["--limit", "limit"],
required=False,
Expand All @@ -242,7 +242,7 @@ class RunLevelParams(PyFlyteParams):
"if `from-server` option is used",
)
)
cluster_pool: str = make_field(
cluster_pool: str = make_click_option_field(
click.Option(
param_decls=["--cluster-pool", "cluster_pool"],
required=False,
Expand Down Expand Up @@ -735,9 +735,11 @@ class RunCommand(click.RichGroup):
A click command group for registering and executing flyte workflows & tasks in a file.
"""

_run_params: typing.Type[RunLevelParams] = RunLevelParams

def __init__(self, *args, **kwargs):
if "params" not in kwargs:
params = RunLevelParams.options()
params = self._run_params.options()
kwargs["params"] = params
super().__init__(*args, **kwargs)
self._files = []
Expand All @@ -754,11 +756,11 @@ def list_commands(self, ctx, add_remote: bool = True):
def get_command(self, ctx, filename):
if ctx.obj is None:
ctx.obj = {}
if not isinstance(ctx.obj, RunLevelParams):
if not isinstance(ctx.obj, self._run_params):
params = {}
params.update(ctx.params)
params.update(ctx.obj)
ctx.obj = RunLevelParams.from_dict(params)
ctx.obj = self._run_params.from_dict(params)
if filename == RemoteLaunchPlanGroup.COMMAND_NAME:
return RemoteLaunchPlanGroup()
return WorkflowCommand(filename, name=filename, help=f"Run a [workflow|task] from {filename}")
Expand Down
2 changes: 1 addition & 1 deletion flytekit/clis/sdk_in_container/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def invoke(self, ctx: click.Context) -> typing.Any:
raise SystemExit(e) from e


def make_field(o: click.Option) -> Field:
def make_click_option_field(o: click.Option) -> Field:
if o.multiple:
o.help = click.style("Multiple values allowed.", bold=True) + f"{o.help}"
return field(default_factory=lambda: o.default, metadata={"click.option": o})
Expand Down

0 comments on commit 9e8331f

Please sign in to comment.