This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

fix tests
parkedwards committed Sep 22, 2023
1 parent c11c3f6 · commit 35996c4
Showing 2 changed files with 27 additions and 15 deletions.
prefect_gcp/workers/vertex.py (23 changes: 12 additions & 11 deletions)

@@ -113,15 +113,15 @@ class VertexAIWorkerVariables(BaseVariables):
             "The type of accelerator to attach to the machine. "
             "See https://cloud.google.com/vertex-ai/docs/reference/rest/v1/MachineSpec"
         ),
-        default="NVIDIA_TESLA_K80",
+        example="NVIDIA_TESLA_K80",
     )
-    accelerator_count: int = Field(
+    accelerator_count: Optional[int] = Field(
         title="Accelerator Count",
         description=(
             "The number of accelerators to attach to the machine. "
             "See https://cloud.google.com/vertex-ai/docs/reference/rest/v1/MachineSpec"
         ),
-        default=0,
+        example=1,
     )
     boot_disk_type: str = Field(
         title="Boot Disk Type",
@@ -190,8 +190,6 @@ def _get_base_job_spec() -> Dict[str, Any]:
                 },
                 "machine_spec": {
                     "machine_type": "n1-standard-4",
-                    "accelerator_type": "NVIDIA_TESLA_K80",
-                    "accelerator_count": "1",
                 },
                 "disk_spec": {
                     "boot_disk_type": "pd-ssd",
@@ -412,7 +410,7 @@ async def run(
         )

         if task_status:
-            task_status.started(job_name)
+            task_status.started(job_run.name)

         final_job_run = await self._watch_job_run(
             job_name=job_name,
@@ -435,7 +433,10 @@
         )

         error_msg = final_job_run.error.message
-        if error_msg:
+
+        # Vertex will include an error message upon valid
+        # flow cancellations, so we'll avoid raising an error in that case
+        if error_msg and "CANCELED" not in error_msg:
             raise RuntimeError(error_msg)

         status_code = 0 if final_job_run.state == JobState.JOB_STATE_SUCCEEDED else 1
@@ -619,22 +620,22 @@ async def kill_infrastructure(
             await run_sync_in_worker_thread(
                 self._stop_job,
                 client=job_service_client,
-                job_name=infrastructure_pid,
+                vertex_job_name=infrastructure_pid,
             )

-    def _stop_job(self, client: "JobServiceClient", job_name: str):
+    def _stop_job(self, client: "JobServiceClient", vertex_job_name: str):
         """
         Calls the `cancel_custom_job` method on the Vertex AI Job Service Client.
         """
-        cancel_custom_job_request = CancelCustomJobRequest(name=job_name)
+        cancel_custom_job_request = CancelCustomJobRequest(name=vertex_job_name)
         try:
             client.cancel_custom_job(
                 request=cancel_custom_job_request,
             )
         except Exception as exc:
             if "does not exist" in str(exc):
                 raise InfrastructureNotFound(
-                    f"Cannot stop Vertex AI job; the job name {job_name!r} "
+                    f"Cannot stop Vertex AI job; the job name {vertex_job_name!r} "
                     "could not be found."
                 ) from exc
             raise
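
The new guard in run relies on the fact that a cancelled Vertex AI custom job still reports an error message. Below is a minimal, self-contained sketch of that behavior, assuming only that the job-run object exposes .error.message; the raise_unless_cancelled helper and the sample message strings are illustrative, not part of prefect_gcp.

from types import SimpleNamespace


def raise_unless_cancelled(final_job_run) -> None:
    # Mirrors the guard above: re-raise genuine failures, but stay quiet
    # when the reported "error" is only a cancellation notice.
    error_msg = final_job_run.error.message
    if error_msg and "CANCELED" not in error_msg:
        raise RuntimeError(error_msg)


# A cancelled run still carries an error message, yet should not raise.
# (The exact message text here is an assumption for illustration.)
cancelled = SimpleNamespace(
    error=SimpleNamespace(message="CANCELED: the job was cancelled")
)
raise_unless_cancelled(cancelled)

# A genuine failure still surfaces as a RuntimeError.
failed = SimpleNamespace(error=SimpleNamespace(message="Internal error"))
try:
    raise_unless_cancelled(failed)
except RuntimeError as exc:
    print(f"re-raised: {exc}")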
tests/test_vertex_worker.py (19 changes: 15 additions & 4 deletions)

@@ -1,4 +1,5 @@
 import uuid
+from types import SimpleNamespace
 from unittest.mock import MagicMock

 import anyio
@@ -101,7 +102,6 @@ async def test_validate_incomplete_worker_pool_spec(self, gcp_credentials):
                     "Job is missing required attributes at the following paths: "
                     "/worker_pool_specs/0/container_spec/image_uri, "
                     "/worker_pool_specs/0/disk_spec, "
-                    "/worker_pool_specs/0/machine_spec/accelerator_count, "
                     "/worker_pool_specs/0/machine_spec/machine_type"
                 ),
                 "type": "value_error",
@@ -210,21 +210,32 @@ async def test_cancelled_worker_run(self, flow_run, job_config):
         )

     async def test_kill_infrastructure(self, flow_run, job_config):
+        mock = job_config.credentials.job_service_client.create_custom_job
+        # the CancelCustomJobRequest class seems to reject a MagicMock value
+        # so here, we'll use a SimpleNamespace as the mocked return values
+        mock.return_value = SimpleNamespace(
+            name="foobar", state=JobState.JOB_STATE_PENDING
+        )
+
         async with VertexAIWorker("test-pool") as worker:
             with anyio.fail_after(10):
                 async with anyio.create_task_group() as tg:
-                    identifier = await tg.start(worker.run, flow_run, job_config)
-                    await worker.kill_infrastructure(identifier, job_config)
+                    result = await tg.start(worker.run, flow_run, job_config)
+                    await worker.kill_infrastructure(result, job_config)

         mock = job_config.credentials.job_service_client.cancel_custom_job
         assert mock.call_count == 1
         assert mock.call_args.kwargs == {
-            "request": CancelCustomJobRequest(name=identifier)
+            "request": CancelCustomJobRequest(name="foobar")
         }

     async def test_kill_infrastructure_no_grace_seconds(
         self, flow_run, job_config, caplog
     ):
+        mock = job_config.credentials.job_service_client.create_custom_job
+        mock.return_value = SimpleNamespace(
+            name="bazzbar", state=JobState.JOB_STATE_PENDING
+        )
         async with VertexAIWorker("test-pool") as worker:

             input_grace_period = 32
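
The switch from a bare MagicMock return value to a SimpleNamespace is worth spelling out: attributes looked up on a MagicMock are themselves MagicMock objects, whereas a SimpleNamespace hands back the concrete values it was built with, which lines up with the in-test comment that CancelCustomJobRequest seems to reject MagicMock values. A small illustration follows, using a plain string in place of the real JobState enum (an assumption made only to keep the snippet self-contained).

from types import SimpleNamespace
from unittest.mock import MagicMock

# Attribute access on a MagicMock yields another MagicMock, not a real value.
mocked_run = MagicMock()
print(isinstance(mocked_run.state, MagicMock))  # True

# A SimpleNamespace returns exactly the values it was constructed with, so
# code expecting a real job name and state sees concrete data.
fake_run = SimpleNamespace(name="foobar", state="JOB_STATE_PENDING")
print(fake_run.name)   # foobar
print(fake_run.state)  # JOB_STATE_PENDING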
