Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/master' into add-validation-flyteformat-type
Browse files: browse the repository at this point in the history
  • Loading branch information
jasonlai1218 committed Nov 9, 2023
2 parents bd4232a + c7c8289 commit 63e634a
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 23 deletions.
2 changes: 1 addition & 1 deletion flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ class RunLevelParams(PyFlyteParams):
)
remote: bool = make_field(
click.Option(
param_decls=["--remote"],
param_decls=["-r", "--remote"],
required=False,
is_flag=True,
default=False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@
TfJob
"""

from .task import PS, Chief, CleanPodPolicy, RestartPolicy, RunPolicy, TfJob, Worker
from .task import PS, Chief, CleanPodPolicy, Evaluator, RestartPolicy, RunPolicy, TfJob, Worker
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,15 @@ class Worker:
restart_policy: Optional[RestartPolicy] = None


# Configuration for the evaluator replica group of a TfJob.
#
# Carries the same five fields as the Worker replica config visible just above
# (image, requests, limits, replicas, restart_policy) and is accepted by
# `_convert_replica_spec` alongside Chief, PS and Worker.
#
# NOTE(review): indentation of the class body was lost in this diff view; the
# fields below belong to the class body.
@dataclass
class Evaluator:
# Optional container image for the evaluator replicas; None leaves it unset.
image: Optional[str] = None
# Optional resource requests/limits; both are passed to
# `convert_resources_to_resource_model` when the replica spec is built.
requests: Optional[Resources] = None
limits: Optional[Resources] = None
# Number of evaluator replicas. Defaults to 0 (an int, not None), which is why
# the task constructor compares it with `> 0` when checking for a conflict with
# the deprecated `num_evaluator_replicas` setting.
replicas: int = 0
# Optional restart policy for the evaluator replicas.
restart_policy: Optional[RestartPolicy] = None


@dataclass
class TfJob:
"""
Expand All @@ -95,6 +104,7 @@ class TfJob:
chief: Configuration for the chief replica group.
ps: Configuration for the parameter server (PS) replica group.
worker: Configuration for the worker replica group.
evaluator: Configuration for the evaluator replica group.
run_policy: Configuration for the run policy.
num_workers: [DEPRECATED] This argument is deprecated. Use `worker.replicas` instead.
num_ps_replicas: [DEPRECATED] This argument is deprecated. Use `ps.replicas` instead.
Expand All @@ -104,11 +114,13 @@ class TfJob:
chief: Chief = field(default_factory=lambda: Chief())
ps: PS = field(default_factory=lambda: PS())
worker: Worker = field(default_factory=lambda: Worker())
evaluator: Evaluator = field(default_factory=lambda: Evaluator())
run_policy: Optional[RunPolicy] = field(default_factory=lambda: None)
# Support v0 config for backwards compatibility
num_workers: Optional[int] = None
num_ps_replicas: Optional[int] = None
num_chief_replicas: Optional[int] = None
num_evaluator_replicas: Optional[int] = None


class TensorflowFunctionTask(PythonFunctionTask[TfJob]):
Expand All @@ -130,19 +142,23 @@ def __init__(self, task_config: TfJob, task_function: Callable, **kwargs):
)
if task_config.num_chief_replicas and task_config.chief.replicas:
raise ValueError(
"Cannot specify both `num_workers` and `chief.replicas`. Please use `chief.replicas` as `num_chief_replicas` is depreacated."
"Cannot specify both `num_chief_replicas` and `chief.replicas`. Please use `chief.replicas` as `num_chief_replicas` is depreacated."
)
if task_config.num_chief_replicas is None and task_config.chief.replicas is None:
raise ValueError(
"Must specify either `num_workers` or `chief.replicas`. Please use `chief.replicas` as `num_chief_replicas` is depreacated."
"Must specify either `num_chief_replicas` or `chief.replicas`. Please use `chief.replicas` as `num_chief_replicas` is depreacated."
)
if task_config.num_ps_replicas and task_config.ps.replicas:
raise ValueError(
"Cannot specify both `num_workers` and `ps.replicas`. Please use `ps.replicas` as `num_ps_replicas` is depreacated."
"Cannot specify both `num_ps_replicas` and `ps.replicas`. Please use `ps.replicas` as `num_ps_replicas` is depreacated."
)
if task_config.num_ps_replicas is None and task_config.ps.replicas is None:
raise ValueError(
"Must specify either `num_workers` or `ps.replicas`. Please use `ps.replicas` as `num_ps_replicas` is depreacated."
"Must specify either `num_ps_replicas` or `ps.replicas`. Please use `ps.replicas` as `num_ps_replicas` is depreacated."
)
if task_config.num_evaluator_replicas and task_config.evaluator.replicas > 0:
raise ValueError(
"Cannot specify both `num_evaluator_replicas` and `evaluator.replicas`. Please use `evaluator.replicas` as `num_evaluator_replicas` is depreacated."
)
super().__init__(
task_type=self._TF_JOB_TASK_TYPE,
Expand All @@ -153,7 +169,7 @@ def __init__(self, task_config: TfJob, task_function: Callable, **kwargs):
)

def _convert_replica_spec(
self, replica_config: Union[Chief, PS, Worker]
self, replica_config: Union[Chief, PS, Worker, Evaluator]
) -> tensorflow_task.DistributedTensorflowTrainingReplicaSpec:
resources = convert_resources_to_resource_model(requests=replica_config.requests, limits=replica_config.limits)
return tensorflow_task.DistributedTensorflowTrainingReplicaSpec(
Expand Down Expand Up @@ -184,11 +200,16 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
if self.task_config.num_ps_replicas:
ps.replicas = self.task_config.num_ps_replicas

evaluator = self._convert_replica_spec(self.task_config.evaluator)
if self.task_config.num_evaluator_replicas:
evaluator.replicas = self.task_config.num_evaluator_replicas

run_policy = self._convert_run_policy(self.task_config.run_policy) if self.task_config.run_policy else None
training_task = tensorflow_task.DistributedTensorflowTrainingTask(
chief_replicas=chief,
worker_replicas=worker,
ps_replicas=ps,
evaluator_replicas=evaluator,
run_policy=run_policy,
)

Expand Down
26 changes: 16 additions & 10 deletions plugins/flytekit-kf-tensorflow/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# This file is autogenerated by pip-compile with Python 3.8
# This file is autogenerated by pip-compile with Python 3.9
# by the following command:
#
# pip-compile requirements.in
Expand Down Expand Up @@ -75,6 +75,7 @@ cryptography==39.0.2
# msal
# pyjwt
# pyopenssl
# secretstorage
dataclasses-json==0.5.7
# via flytekit
decorator==5.1.1
Expand All @@ -89,8 +90,10 @@ docker-image-py==0.1.12
# via flytekit
docstring-parser==0.15
# via flytekit
flyteidl==1.5.5
# via flytekit
flyteidl==1.10.0
# via
# flytekit
# flytekitplugins-kftensorflow
flytekit==1.6.1
# via flytekitplugins-kftensorflow
frozenlist==1.3.3
Expand Down Expand Up @@ -151,12 +154,14 @@ importlib-metadata==6.1.0
# via
# flytekit
# keyring
importlib-resources==5.12.0
# via keyring
isodate==0.6.1
# via azure-storage-blob
jaraco-classes==3.2.3
# via keyring
jeepney==0.8.0
# via
# keyring
# secretstorage
jinja2==3.1.2
# via
# cookiecutter
Expand Down Expand Up @@ -240,7 +245,9 @@ pycparser==2.21
pygments==2.15.1
# via rich
pyjwt[crypto]==2.7.0
# via msal
# via
# msal
# pyjwt
pyopenssl==23.0.0
# via flytekit
python-dateutil==2.8.2
Expand Down Expand Up @@ -299,6 +306,8 @@ rsa==4.9
# via google-auth
s3fs==2023.5.0
# via flytekit
secretstorage==3.3.3
# via keyring
six==1.16.0
# via
# azure-core
Expand All @@ -323,7 +332,6 @@ typing-extensions==4.5.0
# azure-core
# azure-storage-blob
# flytekit
# rich
# typing-inspect
typing-inspect==0.8.0
# via dataclasses-json
Expand All @@ -350,9 +358,7 @@ wrapt==1.15.0
yarl==1.9.2
# via aiohttp
zipp==3.15.0
# via
# importlib-metadata
# importlib-resources
# via importlib-metadata

# The following packages are considered to be unsafe in a requirements file:
# setuptools
2 changes: 1 addition & 1 deletion plugins/flytekit-kf-tensorflow/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = ["flytekit>=1.6.1"]
plugin_requires = ["flyteidl>=1.10.0", "flytekit>=1.6.1"]

__version__ = "0.0.0+develop"

Expand Down
41 changes: 39 additions & 2 deletions plugins/flytekit-kf-tensorflow/tests/test_tensorflow_task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from flytekitplugins.kftensorflow import PS, Chief, CleanPodPolicy, RestartPolicy, RunPolicy, TfJob, Worker
from flytekitplugins.kftensorflow import PS, Chief, CleanPodPolicy, Evaluator, RestartPolicy, RunPolicy, TfJob, Worker

from flytekit import Resources, task
from flytekit.configuration import Image, ImageConfig, SerializationSettings
Expand All @@ -23,6 +23,7 @@ def test_tensorflow_task_with_default_config(serialization_settings: Serializati
worker=Worker(replicas=1),
chief=Chief(replicas=0),
ps=PS(replicas=0),
evaluator=Evaluator(replicas=0),
)

@task(
Expand Down Expand Up @@ -52,6 +53,9 @@ def my_tensorflow_task(x: int, y: str) -> int:
"psReplicas": {
"resources": {},
},
"evaluatorReplicas": {
"resources": {},
},
}
assert my_tensorflow_task.get_custom(serialization_settings) == expected_dict

Expand All @@ -75,6 +79,13 @@ def test_tensorflow_task_with_custom_config(serialization_settings: Serializatio
replicas=2,
restart_policy=RestartPolicy.ALWAYS,
),
evaluator=Evaluator(
replicas=5,
requests=Resources(cpu="2", mem="2Gi"),
limits=Resources(cpu="4", mem="2Gi"),
image="evaluator:latest",
restart_policy=RestartPolicy.FAILURE,
),
)

@task(
Expand Down Expand Up @@ -122,7 +133,23 @@ def my_tensorflow_task(x: int, y: str) -> int:
"replicas": 2,
"restartPolicy": "RESTART_POLICY_ALWAYS",
},
"evaluatorReplicas": {
"replicas": 5,
"image": "evaluator:latest",
"resources": {
"requests": [
{"name": "CPU", "value": "2"},
{"name": "MEMORY", "value": "2Gi"},
],
"limits": [
{"name": "CPU", "value": "4"},
{"name": "MEMORY", "value": "2Gi"},
],
},
"restartPolicy": "RESTART_POLICY_ON_FAILURE",
},
}

assert my_tensorflow_task.get_custom(serialization_settings) == expected_custom_dict


Expand All @@ -131,6 +158,7 @@ def test_tensorflow_task_with_run_policy(serialization_settings: SerializationSe
worker=Worker(replicas=1),
ps=PS(replicas=0),
chief=Chief(replicas=0),
evaluator=Evaluator(replicas=0),
run_policy=RunPolicy(
clean_pod_policy=CleanPodPolicy.RUNNING,
backoff_limit=5,
Expand Down Expand Up @@ -166,19 +194,23 @@ def my_tensorflow_task(x: int, y: str) -> int:
"psReplicas": {
"resources": {},
},
"evaluatorReplicas": {
"resources": {},
},
"runPolicy": {
"cleanPodPolicy": "CLEANPOD_POLICY_RUNNING",
"backoffLimit": 5,
"activeDeadlineSeconds": 100,
"ttlSecondsAfterFinished": 100,
},
}

assert my_tensorflow_task.get_custom(serialization_settings) == expected_dict


def test_tensorflow_task():
@task(
task_config=TfJob(num_workers=10, num_ps_replicas=1, num_chief_replicas=1),
task_config=TfJob(num_workers=10, num_ps_replicas=1, num_chief_replicas=1, num_evaluator_replicas=1),
cache=True,
requests=Resources(cpu="1"),
cache_version="1",
Expand Down Expand Up @@ -212,7 +244,12 @@ def my_tensorflow_task(x: int, y: str) -> int:
"replicas": 1,
"resources": {},
},
"evaluatorReplicas": {
"replicas": 1,
"resources": {},
},
}

assert my_tensorflow_task.get_custom(settings) == expected_dict
assert my_tensorflow_task.resources.limits == Resources()
assert my_tensorflow_task.resources.requests == Resources(cpu="1")
Expand Down
6 changes: 5 additions & 1 deletion plugins/flytekit-vaex/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "vaex-core>=4.13.0,<4.14"]
plugin_requires = [
"flytekit>=1.3.0b2,<2.0.0",
"vaex-core>=4.13.0,<4.14; python_version < '3.10'",
"vaex-core>=4.16.0; python_version >= '3.10'",
]

__version__ = "0.0.0+develop"

Expand Down
13 changes: 11 additions & 2 deletions tests/flytekit/unit/cli/pyflyte/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,20 @@ def remote():
return flyte_remote


def test_pyflyte_run_wf(remote):
@pytest.mark.parametrize(
"remote_flag",
[
"-r",
"--remote",
],
)
def test_pyflyte_run_wf(remote, remote_flag):
with mock.patch("flytekit.clis.sdk_in_container.helpers.get_remote"):
runner = CliRunner()
module_path = WORKFLOW_FILE
result = runner.invoke(pyflyte.main, ["run", module_path, "my_wf", "--help"], catch_exceptions=False)
result = runner.invoke(
pyflyte.main, ["run", remote_flag, module_path, "my_wf", "--help"], catch_exceptions=False
)

assert result.exit_code == 0

Expand Down

0 comments on commit 63e634a

Please sign in to comment.