Skip to content

Commit

Permalink
Add tags to execution (#1723)
Browse files Browse the repository at this point in the history
* wip

Signed-off-by: Kevin Su <[email protected]>

* Add tests

Signed-off-by: Kevin Su <[email protected]>

* Use JsonParamType instead

Signed-off-by: Kevin Su <[email protected]>

* update

Signed-off-by: Kevin Su <[email protected]>

* update idl

Signed-off-by: Kevin Su <[email protected]>

* update idl

Signed-off-by: Kevin Su <[email protected]>

* update idl

Signed-off-by: Kevin Su <[email protected]>

* update idl

Signed-off-by: Kevin Su <[email protected]>

* bump grpcio-status version

Signed-off-by: Kevin Su <[email protected]>

* bump grpcioversion

Signed-off-by: Kevin Su <[email protected]>

---------

Signed-off-by: Kevin Su <[email protected]>
Co-authored-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
pingsutw and eapolinario authored Aug 10, 2023
1 parent 6291301 commit edfa767
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 4 deletions.
2 changes: 1 addition & 1 deletion doc-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ flask==2.3.2
# via mlflow
flatbuffers==23.5.26
# via tensorflow
flyteidl==1.5.13
flyteidl==1.5.14
# via flytekit
fonttools==4.41.1
# via matplotlib
Expand Down
8 changes: 8 additions & 0 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,13 @@ def get_workflow_command_base_params() -> typing.List[click.Option]:
type=JsonParamType(),
help="Environment variables to set in the container",
),
click.Option(
param_decls=["--tag", "tag"],
required=False,
multiple=True,
type=str,
help="Tags to set for the execution",
),
]


Expand Down Expand Up @@ -708,6 +715,7 @@ def _run(*args, **kwargs):
type_hints=entity.python_interface.inputs,
overwrite_cache=run_level_params.get("overwrite_cache"),
envs=run_level_params.get("envs"),
tags=run_level_params.get("tag"),
)

console_url = remote.generate_console_url(execution)
Expand Down
9 changes: 9 additions & 0 deletions flytekit/models/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def __init__(
security_context: Optional[security.SecurityContext] = None,
overwrite_cache: Optional[bool] = None,
envs: Optional[_common_models.Envs] = None,
tags: Optional[typing.List[str]] = None,
):
"""
:param flytekit.models.core.identifier.Identifier launch_plan: Launch plan unique identifier to execute
Expand All @@ -194,6 +195,7 @@ def __init__(
:param security_context: Optional security context to use for this execution.
:param overwrite_cache: Optional flag to overwrite the cache for this execution.
:param envs: flytekit.models.common.Envs environment variables to set for this execution.
:param tags: Optional list of tags to apply to the execution.
"""
self._launch_plan = launch_plan
self._metadata = metadata
Expand All @@ -207,6 +209,7 @@ def __init__(
self._security_context = security_context
self._overwrite_cache = overwrite_cache
self._envs = envs
self._tags = tags

@property
def launch_plan(self):
Expand Down Expand Up @@ -281,6 +284,10 @@ def overwrite_cache(self) -> Optional[bool]:
def envs(self) -> Optional[_common_models.Envs]:
return self._envs

@property
def tags(self) -> Optional[typing.List[str]]:
return self._tags

def to_flyte_idl(self):
"""
:rtype: flyteidl.admin.execution_pb2.ExecutionSpec
Expand All @@ -300,6 +307,7 @@ def to_flyte_idl(self):
security_context=self.security_context.to_flyte_idl() if self.security_context else None,
overwrite_cache=self.overwrite_cache,
envs=self.envs.to_flyte_idl() if self.envs else None,
tags=self.tags,
)

@classmethod
Expand All @@ -325,6 +333,7 @@ def from_flyte_idl(cls, p):
else None,
overwrite_cache=p.overwrite_cache,
envs=_common_models.Envs.from_flyte_idl(p.envs) if p.HasField("envs") else None,
tags=p.tags,
)


Expand Down
23 changes: 23 additions & 0 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,6 +955,7 @@ def _execute(
type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None,
overwrite_cache: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
) -> FlyteWorkflowExecution:
"""Common method for execution across all entities.
Expand All @@ -970,6 +971,7 @@ def _execute(
for a single execution. If enabled, all calculations are performed even if cached results would
be available, overwriting the stored data once execution finishes successfully.
:param envs: Environment variables to set for the execution.
:param tags: Tags to set for the execution.
:returns: :class:`~flytekit.remote.workflow_execution.FlyteWorkflowExecution`
"""
execution_name = execution_name or "f" + uuid.uuid4().hex[:19]
Expand Down Expand Up @@ -1035,6 +1037,7 @@ def _execute(
max_parallelism=options.max_parallelism,
security_context=options.security_context,
envs=common_models.Envs(envs) if envs else None,
tags=tags,
),
literal_inputs,
)
Expand Down Expand Up @@ -1092,6 +1095,7 @@ def execute(
type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None,
overwrite_cache: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
) -> FlyteWorkflowExecution:
"""
Execute a task, workflow, or launchplan, either something that's been declared locally, or a fetched entity.
Expand Down Expand Up @@ -1129,6 +1133,7 @@ def execute(
for a single execution. If enabled, all calculations are performed even if cached results would
be available, overwriting the stored data once execution finishes successfully.
:param envs: Environment variables to be set for the execution.
:param tags: Tags to be set for the execution.
.. note:
Expand All @@ -1149,6 +1154,7 @@ def execute(
type_hints=type_hints,
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
)
if isinstance(entity, FlyteWorkflow):
return self.execute_remote_wf(
Expand All @@ -1162,6 +1168,7 @@ def execute(
type_hints=type_hints,
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
)
if isinstance(entity, PythonTask):
return self.execute_local_task(
Expand All @@ -1176,6 +1183,7 @@ def execute(
wait=wait,
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
)
if isinstance(entity, WorkflowBase):
return self.execute_local_workflow(
Expand All @@ -1191,6 +1199,7 @@ def execute(
wait=wait,
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
)
if isinstance(entity, LaunchPlan):
return self.execute_local_launch_plan(
Expand All @@ -1204,6 +1213,7 @@ def execute(
wait=wait,
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
)
raise NotImplementedError(f"entity type {type(entity)} not recognized for execution")

Expand All @@ -1222,6 +1232,7 @@ def execute_remote_task_lp(
type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None,
overwrite_cache: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
) -> FlyteWorkflowExecution:
"""Execute a FlyteTask, or FlyteLaunchplan.
Expand All @@ -1238,6 +1249,7 @@ def execute_remote_task_lp(
type_hints=type_hints,
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
)

def execute_remote_wf(
Expand All @@ -1252,6 +1264,7 @@ def execute_remote_wf(
type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None,
overwrite_cache: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
) -> FlyteWorkflowExecution:
"""Execute a FlyteWorkflow.
Expand All @@ -1269,6 +1282,7 @@ def execute_remote_wf(
type_hints=type_hints,
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
)

# Flytekit Entities
Expand All @@ -1287,6 +1301,7 @@ def execute_local_task(
wait: bool = False,
overwrite_cache: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
) -> FlyteWorkflowExecution:
"""
Execute a @task-decorated function or TaskTemplate task.
Expand All @@ -1302,6 +1317,7 @@ def execute_local_task(
:param wait: If True, will wait for the execution to complete before returning.
:param overwrite_cache: If True, will overwrite the cache.
:param envs: Environment variables to set for the execution.
:param tags: Tags to set for the execution.
:return: FlyteWorkflowExecution object.
"""
resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version)
Expand Down Expand Up @@ -1330,6 +1346,7 @@ def execute_local_task(
type_hints=entity.python_interface.inputs,
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
)

def execute_local_workflow(
Expand All @@ -1346,6 +1363,7 @@ def execute_local_workflow(
wait: bool = False,
overwrite_cache: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
) -> FlyteWorkflowExecution:
"""
Execute an @workflow decorated function.
Expand All @@ -1361,6 +1379,7 @@ def execute_local_workflow(
:param wait:
:param overwrite_cache:
:param envs:
:param tags:
:return:
"""
resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version)
Expand Down Expand Up @@ -1407,6 +1426,7 @@ def execute_local_workflow(
type_hints=entity.python_interface.inputs,
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
)

def execute_local_launch_plan(
Expand All @@ -1421,6 +1441,7 @@ def execute_local_launch_plan(
wait: bool = False,
overwrite_cache: typing.Optional[bool] = None,
envs: typing.Optional[typing.Dict[str, str]] = None,
tags: typing.Optional[typing.List[str]] = None,
) -> FlyteWorkflowExecution:
"""
Expand All @@ -1434,6 +1455,7 @@ def execute_local_launch_plan(
:param wait: If True, will wait for the execution to complete before returning.
:param overwrite_cache: If True, will overwrite the cache.
:param envs: Environment variables to be passed into the execution.
:param tags: Tags to be passed into the execution.
:return: FlyteWorkflowExecution object
"""
try:
Expand Down Expand Up @@ -1461,6 +1483,7 @@ def execute_local_launch_plan(
type_hints=entity.python_interface.inputs,
overwrite_cache=overwrite_cache,
envs=envs,
tags=tags,
)

###################################
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
},
install_requires=[
"googleapis-common-protos>=1.57",
"flyteidl>=1.5.12",
"flyteidl>=1.5.14",
"wheel>=0.30.0,<1.0.0",
"pandas>=1.0.0,<2.0.0",
"pyarrow>=4.0.0,<11.0.0",
Expand Down
10 changes: 9 additions & 1 deletion tests/flytekit/integration/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,18 @@ def test_execute_python_task(flyteclient, flyte_workflows_register, flyte_remote

remote = FlyteRemote(Config.auto(), PROJECT, "development")
execution = remote.execute(
t1, inputs={"a": 10}, version=f"v{VERSION}", wait=True, overwrite_cache=True, envs={"foo": "bar"}
t1,
inputs={"a": 10},
version=f"v{VERSION}",
wait=True,
overwrite_cache=True,
envs={"foo": "bar"},
tags=["flyte"],
)
assert execution.outputs["t1_int_output"] == 12
assert execution.outputs["c"] == "world"
assert execution.spec.envs == {"foo": "bar"}
assert execution.spec.tags == ["flyte"]


def test_execute_python_workflow_and_launch_plan(flyteclient, flyte_workflows_register, flyte_remote_env):
Expand Down
15 changes: 14 additions & 1 deletion tests/flytekit/unit/cli/pyflyte/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,20 @@ def test_union_type2(input):
env = '{"foo": "bar"}'
result = runner.invoke(
pyflyte.main,
["run", "--overwrite-cache", "--envs", env, os.path.join(DIR_NAME, "workflow.py"), "test_union2", "--a", input],
[
"run",
"--overwrite-cache",
"--envs",
env,
"--tag",
"flyte",
"--tag",
"hello",
os.path.join(DIR_NAME, "workflow.py"),
"test_union2",
"--a",
input,
],
catch_exceptions=False,
)
print(result.stdout)
Expand Down

0 comments on commit edfa767

Please sign in to comment.