Skip to content

Commit

Permalink
feat: Add disable_retries option to custom jobs.
Browse files: browse the repository at this point in the history
PiperOrigin-RevId: 557870565
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Aug 17, 2023
1 parent 4e76a6e commit db518b0
Show file tree
Hide file tree
Showing 9 changed files with 137 additions and 4 deletions.
34 changes: 32 additions & 2 deletions google/cloud/aiplatform/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1629,6 +1629,7 @@ def run(
tensorboard: Optional[str] = None,
sync: bool = True,
create_request_timeout: Optional[float] = None,
disable_retries: bool = False,
) -> None:
"""Run this configured CustomJob.
Expand Down Expand Up @@ -1686,6 +1687,10 @@ def run(
will unblock and it will be executed in a concurrent Future.
create_request_timeout (float):
Optional. The timeout for the create request in seconds.
disable_retries (bool):
Optional. If True, the job is not retried for internal errors
after it starts running, and `restart_job_on_worker_restart` is
overridden to False. Defaults to False.
"""
network = network or initializer.global_config.network

Expand All @@ -1700,6 +1705,7 @@ def run(
tensorboard=tensorboard,
sync=sync,
create_request_timeout=create_request_timeout,
disable_retries=disable_retries,
)

@base.optional_sync()
Expand All @@ -1715,6 +1721,7 @@ def _run(
tensorboard: Optional[str] = None,
sync: bool = True,
create_request_timeout: Optional[float] = None,
disable_retries: bool = False,
) -> None:
"""Helper method to ensure network synchronization and to run the configured CustomJob.
Expand Down Expand Up @@ -1770,6 +1777,10 @@ def _run(
will unblock and it will be executed in a concurrent Future.
create_request_timeout (float):
Optional. The timeout for the create request in seconds.
disable_retries (bool):
Optional. If True, the job is not retried for internal errors
after it starts running, and `restart_job_on_worker_restart` is
overridden to False. Defaults to False.
"""
self.submit(
service_account=service_account,
Expand All @@ -1781,6 +1792,7 @@ def _run(
experiment_run=experiment_run,
tensorboard=tensorboard,
create_request_timeout=create_request_timeout,
disable_retries=disable_retries,
)

self._block_until_complete()
Expand All @@ -1797,6 +1809,7 @@ def submit(
experiment_run: Optional[Union["aiplatform.ExperimentRun", str]] = None,
tensorboard: Optional[str] = None,
create_request_timeout: Optional[float] = None,
disable_retries: bool = False,
) -> None:
"""Submit the configured CustomJob.
Expand Down Expand Up @@ -1849,6 +1862,10 @@ def submit(
https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
create_request_timeout (float):
Optional. The timeout for the create request in seconds.
disable_retries (bool):
Optional. If True, the job is not retried for internal errors
after it starts running, and `restart_job_on_worker_restart` is
overridden to False. Defaults to False.
Raises:
ValueError:
Expand All @@ -1869,11 +1886,12 @@ def submit(
if network:
self._gca_resource.job_spec.network = network

if timeout or restart_job_on_worker_restart:
if timeout or restart_job_on_worker_restart or disable_retries:
timeout = duration_pb2.Duration(seconds=timeout) if timeout else None
self._gca_resource.job_spec.scheduling = gca_custom_job_compat.Scheduling(
timeout=timeout,
restart_job_on_worker_restart=restart_job_on_worker_restart,
disable_retries=disable_retries,
)

if enable_web_access:
Expand Down Expand Up @@ -2287,6 +2305,7 @@ def run(
tensorboard: Optional[str] = None,
sync: bool = True,
create_request_timeout: Optional[float] = None,
disable_retries: bool = False,
) -> None:
"""Run this configured CustomJob.
Expand Down Expand Up @@ -2331,6 +2350,10 @@ def run(
will unblock and it will be executed in a concurrent Future.
create_request_timeout (float):
Optional. The timeout for the create request in seconds.
disable_retries (bool):
Optional. If True, the job is not retried for internal errors
after it starts running, and `restart_job_on_worker_restart` is
overridden to False. Defaults to False.
"""
network = network or initializer.global_config.network

Expand All @@ -2343,6 +2366,7 @@ def run(
tensorboard=tensorboard,
sync=sync,
create_request_timeout=create_request_timeout,
disable_retries=disable_retries,
)

@base.optional_sync()
Expand All @@ -2356,6 +2380,7 @@ def _run(
tensorboard: Optional[str] = None,
sync: bool = True,
create_request_timeout: Optional[float] = None,
disable_retries: bool = False,
) -> None:
"""Helper method to ensure network synchronization and to run the configured CustomJob.
Expand Down Expand Up @@ -2398,19 +2423,24 @@ def _run(
will unblock and it will be executed in a concurrent Future.
create_request_timeout (float):
Optional. The timeout for the create request in seconds.
disable_retries (bool):
Optional. If True, the job is not retried for internal errors
after it starts running, and `restart_job_on_worker_restart` is
overridden to False. Defaults to False.
"""
if service_account:
self._gca_resource.trial_job_spec.service_account = service_account

if network:
self._gca_resource.trial_job_spec.network = network

if timeout or restart_job_on_worker_restart:
if timeout or restart_job_on_worker_restart or disable_retries:
duration = duration_pb2.Duration(seconds=timeout) if timeout else None
self._gca_resource.trial_job_spec.scheduling = (
gca_custom_job_compat.Scheduling(
timeout=duration,
restart_job_on_worker_restart=restart_job_on_worker_restart,
disable_retries=disable_retries,
)
)

Expand Down
8 changes: 7 additions & 1 deletion google/cloud/aiplatform/preview/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ def submit(
experiment_run: Optional[Union["aiplatform.ExperimentRun", str]] = None,
tensorboard: Optional[str] = None,
create_request_timeout: Optional[float] = None,
disable_retries: bool = False,
) -> None:
"""Submit the configured CustomJob.
Expand Down Expand Up @@ -290,6 +291,10 @@ def submit(
https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
create_request_timeout (float):
Optional. The timeout for the create request in seconds.
disable_retries (bool):
Optional. If True, the job is not retried for internal errors
after it starts running, and `restart_job_on_worker_restart` is
overridden to False. Defaults to False.
Raises:
ValueError:
Expand All @@ -310,11 +315,12 @@ def submit(
if network:
self._gca_resource.job_spec.network = network

if timeout or restart_job_on_worker_restart:
if timeout or restart_job_on_worker_restart or disable_retries:
timeout = duration_pb2.Duration(seconds=timeout) if timeout else None
self._gca_resource.job_spec.scheduling = gca_custom_job_compat.Scheduling(
timeout=timeout,
restart_job_on_worker_restart=restart_job_on_worker_restart,
disable_retries=disable_retries,
)

if enable_web_access:
Expand Down
Loading

0 comments on commit db518b0

Please sign in to comment.