From 9e8331f259195049ef0ca76d136cec442a0b5d0f Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Tue, 5 Dec 2023 19:23:04 -0500 Subject: [PATCH] Simplify CLI implementation (#2023) --- flytekit/clis/sdk_in_container/build.py | 13 ++----- flytekit/clis/sdk_in_container/run.py | 50 +++++++++++++------------ flytekit/clis/sdk_in_container/utils.py | 2 +- 3 files changed, 31 insertions(+), 34 deletions(-) diff --git a/flytekit/clis/sdk_in_container/build.py b/flytekit/clis/sdk_in_container/build.py index eadbcc2987..d11865fc8e 100644 --- a/flytekit/clis/sdk_in_container/build.py +++ b/flytekit/clis/sdk_in_container/build.py @@ -5,7 +5,7 @@ from typing_extensions import OrderedDict from flytekit.clis.sdk_in_container.run import RunCommand, RunLevelParams, WorkflowCommand -from flytekit.clis.sdk_in_container.utils import make_field +from flytekit.clis.sdk_in_container.utils import make_click_option_field from flytekit.configuration import ImageConfig, SerializationSettings from flytekit.core.base_task import PythonTask from flytekit.core.workflow import PythonFunctionWorkflow @@ -14,7 +14,7 @@ @dataclass class BuildParams(RunLevelParams): - fast: bool = make_field( + fast: bool = make_click_option_field( click.Option( param_decls=["--fast"], required=False, @@ -75,18 +75,13 @@ class BuildCommand(RunCommand): A click command group for building a image for flyte workflows & tasks in a file. """ - def __init__(self, *args, **kwargs): - params = BuildParams.options() - kwargs["params"] = params - super().__init__(*args, **kwargs) + _run_params = BuildParams def list_commands(self, ctx, *args, **kwargs): return super().list_commands(ctx, add_remote=False) def get_command(self, ctx, filename): - if ctx.obj is None: - ctx.obj = {} - ctx.obj = BuildParams.from_dict(ctx.params) + super().get_command(ctx, filename) return BuildWorkflowCommand(filename, name=filename, help=f"Build an image for [workflow|task] from {filename}") diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 1b2568e04b..403e5120d1 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -18,7 +18,7 @@ PyFlyteParams, domain_option, get_option_from_metadata, - make_field, + make_click_option_field, pretty_print_exception, project_option, ) @@ -58,9 +58,9 @@ class RunLevelParams(PyFlyteParams): This class is used to store the parameters that are used to run a workflow / task / launchplan. """ - project: str = make_field(project_option) - domain: str = make_field(domain_option) - destination_dir: str = make_field( + project: str = make_click_option_field(project_option) + domain: str = make_click_option_field(domain_option) + destination_dir: str = make_click_option_field( click.Option( param_decls=["--destination-dir", "destination_dir"], required=False, @@ -70,7 +70,7 @@ class RunLevelParams(PyFlyteParams): help="Directory inside the image where the tar file containing the code will be copied to", ) ) - copy_all: bool = make_field( + copy_all: bool = make_click_option_field( click.Option( param_decls=["--copy-all", "copy_all"], required=False, @@ -80,7 +80,7 @@ class RunLevelParams(PyFlyteParams): help="Copy all files in the source root directory to the destination directory", ) ) - image_config: ImageConfig = make_field( + image_config: ImageConfig = make_click_option_field( click.Option( param_decls=["-i", "--image", "image_config"], required=False, @@ -92,7 +92,7 @@ class RunLevelParams(PyFlyteParams): help="Image used to register and run.", ) ) - service_account: str = make_field( + service_account: str = make_click_option_field( click.Option( param_decls=["--service-account", "service_account"], required=False, @@ -101,7 +101,7 @@ class RunLevelParams(PyFlyteParams): help="Service account used when executing this workflow", ) ) - wait_execution: bool = make_field( + wait_execution: bool = make_click_option_field( click.Option( param_decls=["--wait-execution", "wait_execution"], required=False, @@ -111,7 +111,7 @@ class RunLevelParams(PyFlyteParams): help="Whether to wait for the execution to finish", ) ) - dump_snippet: bool = make_field( + dump_snippet: bool = make_click_option_field( click.Option( param_decls=["--dump-snippet", "dump_snippet"], required=False, @@ -121,7 +121,7 @@ class RunLevelParams(PyFlyteParams): help="Whether to dump a code snippet instructing how to load the workflow execution using flyteremote", ) ) - overwrite_cache: bool = make_field( + overwrite_cache: bool = make_click_option_field( click.Option( param_decls=["--overwrite-cache", "overwrite_cache"], required=False, @@ -131,7 +131,7 @@ class RunLevelParams(PyFlyteParams): help="Whether to overwrite the cache if it already exists", ) ) - envvars: typing.Dict[str, str] = make_field( + envvars: typing.Dict[str, str] = make_click_option_field( click.Option( param_decls=["--envvars", "--env"], required=False, @@ -142,7 +142,7 @@ class RunLevelParams(PyFlyteParams): help="Environment variables to set in the container, of the format `ENV_NAME=ENV_VALUE`", ) ) - tags: typing.List[str] = make_field( + tags: typing.List[str] = make_click_option_field( click.Option( param_decls=["--tags", "--tag"], required=False, @@ -152,7 +152,7 @@ class RunLevelParams(PyFlyteParams): help="Tags to set for the execution", ) ) - name: str = make_field( + name: str = make_click_option_field( click.Option( param_decls=["--name"], required=False, @@ -161,7 +161,7 @@ class RunLevelParams(PyFlyteParams): help="Name to assign to this execution", ) ) - labels: typing.Dict[str, str] = make_field( + labels: typing.Dict[str, str] = make_click_option_field( click.Option( param_decls=["--labels", "--label"], required=False, @@ -172,7 +172,7 @@ class RunLevelParams(PyFlyteParams): help="Labels to be attached to the execution of the format `label_key=label_value`.", ) ) - annotations: typing.Dict[str, str] = make_field( + annotations: typing.Dict[str, str] = make_click_option_field( click.Option( param_decls=["--annotations", "--annotation"], required=False, @@ -183,7 +183,7 @@ class RunLevelParams(PyFlyteParams): help="Annotations to be attached to the execution of the format `key=value`.", ) ) - raw_output_data_prefix: str = make_field( + raw_output_data_prefix: str = make_click_option_field( click.Option( param_decls=["--raw-output-data-prefix", "--raw-data-prefix"], required=False, @@ -200,7 +200,7 @@ class RunLevelParams(PyFlyteParams): ), ) ) - max_parallelism: int = make_field( + max_parallelism: int = make_click_option_field( click.Option( param_decls=["--max-parallelism"], required=False, @@ -210,7 +210,7 @@ class RunLevelParams(PyFlyteParams): " project/domain defaults are used. If 0 then it is unlimited.", ) ) - disable_notifications: bool = make_field( + disable_notifications: bool = make_click_option_field( click.Option( param_decls=["--disable-notifications"], required=False, @@ -220,7 +220,7 @@ class RunLevelParams(PyFlyteParams): help="Should notifications be disabled for this execution.", ) ) - remote: bool = make_field( + remote: bool = make_click_option_field( click.Option( param_decls=["-r", "--remote"], required=False, @@ -231,7 +231,7 @@ class RunLevelParams(PyFlyteParams): help="Whether to register and run the workflow on a Flyte deployment", ) ) - limit: int = make_field( + limit: int = make_click_option_field( click.Option( param_decls=["--limit", "limit"], required=False, @@ -242,7 +242,7 @@ class RunLevelParams(PyFlyteParams): "if `from-server` option is used", ) ) - cluster_pool: str = make_field( + cluster_pool: str = make_click_option_field( click.Option( param_decls=["--cluster-pool", "cluster_pool"], required=False, @@ -735,9 +735,11 @@ class RunCommand(click.RichGroup): A click command group for registering and executing flyte workflows & tasks in a file. """ + _run_params: typing.Type[RunLevelParams] = RunLevelParams + def __init__(self, *args, **kwargs): if "params" not in kwargs: - params = RunLevelParams.options() + params = self._run_params.options() kwargs["params"] = params super().__init__(*args, **kwargs) self._files = [] @@ -754,11 +756,11 @@ def list_commands(self, ctx, add_remote: bool = True): def get_command(self, ctx, filename): if ctx.obj is None: ctx.obj = {} - if not isinstance(ctx.obj, RunLevelParams): + if not isinstance(ctx.obj, self._run_params): params = {} params.update(ctx.params) params.update(ctx.obj) - ctx.obj = RunLevelParams.from_dict(params) + ctx.obj = self._run_params.from_dict(params) if filename == RemoteLaunchPlanGroup.COMMAND_NAME: return RemoteLaunchPlanGroup() return WorkflowCommand(filename, name=filename, help=f"Run a [workflow|task] from {filename}") diff --git a/flytekit/clis/sdk_in_container/utils.py b/flytekit/clis/sdk_in_container/utils.py index 3f975913e0..cb5e94d1c9 100644 --- a/flytekit/clis/sdk_in_container/utils.py +++ b/flytekit/clis/sdk_in_container/utils.py @@ -130,7 +130,7 @@ def invoke(self, ctx: click.Context) -> typing.Any: raise SystemExit(e) from e -def make_field(o: click.Option) -> Field: +def make_click_option_field(o: click.Option) -> Field: if o.multiple: o.help = click.style("Multiple values allowed.", bold=True) + f"{o.help}" return field(default_factory=lambda: o.default, metadata={"click.option": o})