Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Adds ability to publish VertexAICustomTrainingJob block as a vertex…
Browse files Browse the repository at this point in the history
…-ai work pool (#238)

* Adds ability to publish VertexAICustomTrainingJob block as a vertex-ai work pool

* Use credentials block in default case

* Apply suggestions from code review

Co-authored-by: nate nowack <[email protected]>

* Updates changelog

---------

Co-authored-by: nate nowack <[email protected]>
  • Loading branch information
desertaxle and zzstoatzz authored Dec 11, 2023
1 parent d34b0fe commit a2c74df
Show file tree
Hide file tree
Showing 3 changed files with 214 additions and 2 deletions.
15 changes: 13 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Modified default command logic in `CloudRunWorkerJobV2Configuration` to utilize the `BaseJobConfiguration._base_flow_run_command` method.

### Security

## 0.5.5

Released December 11th, 2023.

### Added

- Ability to publish `CloudRun` blocks as cloud-run work pools - [#237](https://github.com/PrefectHQ/prefect-gcp/pull/237)
- Ability to publish `VertexAICustomTrainingJob` blocks as a vertex-ai work pool - [#238](https://github.com/PrefectHQ/prefect-gcp/pull/238)

### Fixed

- Modified default command logic in `CloudRunWorkerJobV2Configuration` to utilize the `BaseJobConfiguration._base_flow_run_command` method.

## 0.5.4

Released November 29th, 2023.
Expand Down
70 changes: 70 additions & 0 deletions prefect_gcp/aiplatform.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@

import datetime
import re
import shlex
import time
from typing import Dict, List, Optional, Tuple
from uuid import uuid4
Expand Down Expand Up @@ -92,6 +93,11 @@
except ModuleNotFoundError:
pass

from prefect.blocks.core import BlockNotSavedError
from prefect.workers.utilities import (
get_default_base_job_template_for_infrastructure_type,
)

from prefect_gcp.credentials import GcpCredentials

_DISALLOWED_GCP_LABEL_CHARACTERS = re.compile(r"[^-a-zA-Z0-9_]+")
Expand Down Expand Up @@ -248,6 +254,70 @@ def preview(self) -> str:
)
return str(custom_job) # outputs a json string

def get_corresponding_worker_type(self) -> str:
"""Return the corresponding worker type for this infrastructure block."""
return "vertex-ai"

async def generate_work_pool_base_job_template(self) -> dict:
"""
Generate a base job template for a `Vertex AI` work pool with the same
configuration as this block.
Returns:
- dict: a base job template for a `Vertex AI` work pool
"""
base_job_template = await get_default_base_job_template_for_infrastructure_type(
self.get_corresponding_worker_type(),
)
assert (
base_job_template is not None
), "Failed to generate default base job template for Cloud Run worker."
for key, value in self.dict(exclude_unset=True, exclude_defaults=True).items():
if key == "command":
base_job_template["variables"]["properties"]["command"][
"default"
] = shlex.join(value)
elif key in [
"type",
"block_type_slug",
"_block_document_id",
"_block_document_name",
"_is_anonymous",
]:
continue
elif key == "gcp_credentials":
if not self.gcp_credentials._block_document_id:
raise BlockNotSavedError(
"It looks like you are trying to use a block that"
" has not been saved. Please call `.save` on your block"
" before publishing it as a work pool."
)
base_job_template["variables"]["properties"]["credentials"][
"default"
] = {
"$ref": {
"block_document_id": str(
self.gcp_credentials._block_document_id
)
}
}
elif key == "maximum_run_time":
base_job_template["variables"]["properties"]["maximum_run_time_hours"][
"default"
] = round(value.total_seconds() / 3600)
elif key == "service_account":
base_job_template["variables"]["properties"]["service_account_name"][
"default"
] = value
elif key in base_job_template["variables"]["properties"]:
base_job_template["variables"]["properties"][key]["default"] = value
else:
self.logger.warning(
f"Variable {key!r} is not supported by `Vertex AI` work pools."
" Skipping."
)

return base_job_template

def _build_job_spec(self) -> "CustomJobSpec":
"""
Builds a job spec by gathering details.
Expand Down
131 changes: 131 additions & 0 deletions tests/test_aiplatform.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
from unittest.mock import MagicMock, patch

import pytest
Expand All @@ -10,6 +11,8 @@
VertexAICustomTrainingJob,
VertexAICustomTrainingJobResult,
)
from prefect_gcp.credentials import GcpCredentials
from prefect_gcp.workers.vertex import VertexAIWorker


class TestVertexAICustomTrainingJob:
Expand Down Expand Up @@ -177,3 +180,131 @@ def test_machine_spec(
job_spec.worker_pool_specs[0].machine_spec.accelerator_type
== AcceleratorType.NVIDIA_TESLA_T4
)


@pytest.fixture
def default_base_job_template():
return deepcopy(VertexAIWorker.get_default_base_job_template())


@pytest.fixture
async def credentials_block(service_account_info):
credentials_block = GcpCredentials(
service_account_info=service_account_info, project="my-project"
)
await credentials_block.save("test-for-publish", overwrite=True)
return credentials_block


@pytest.fixture
def base_job_template_with_defaults(default_base_job_template, credentials_block):
base_job_template_with_defaults = deepcopy(default_base_job_template)
base_job_template_with_defaults["variables"]["properties"]["command"][
"default"
] = "python my_script.py"
base_job_template_with_defaults["variables"]["properties"]["env"]["default"] = {
"VAR1": "value1",
"VAR2": "value2",
}
base_job_template_with_defaults["variables"]["properties"]["labels"]["default"] = {
"label1": "value1",
"label2": "value2",
}
base_job_template_with_defaults["variables"]["properties"]["name"][
"default"
] = "prefect-job"
base_job_template_with_defaults["variables"]["properties"]["image"][
"default"
] = "docker.io/my_image:latest"
base_job_template_with_defaults["variables"]["properties"]["credentials"][
"default"
] = {"$ref": {"block_document_id": str(credentials_block._block_document_id)}}
base_job_template_with_defaults["variables"]["properties"]["region"][
"default"
] = "us-central1"
base_job_template_with_defaults["variables"]["properties"]["machine_type"][
"default"
] = "n1-standard-4"
base_job_template_with_defaults["variables"]["properties"]["accelerator_count"][
"default"
] = 1
base_job_template_with_defaults["variables"]["properties"]["accelerator_type"][
"default"
] = "NVIDIA_TESLA_T4"
base_job_template_with_defaults["variables"]["properties"]["boot_disk_type"][
"default"
] = "pd-ssd"
base_job_template_with_defaults["variables"]["properties"]["boot_disk_size_gb"][
"default"
] = 200
base_job_template_with_defaults["variables"]["properties"][
"maximum_run_time_hours"
]["default"] = 24
base_job_template_with_defaults["variables"]["properties"]["network"][
"default"
] = "my-network"
base_job_template_with_defaults["variables"]["properties"]["reserved_ip_ranges"][
"default"
] = ["172.31.0.0/16", "192.168.0.0./16"]
base_job_template_with_defaults["variables"]["properties"]["service_account_name"][
"default"
] = "my-service-account"
base_job_template_with_defaults["variables"]["properties"][
"job_watch_poll_interval"
]["default"] = 60
return base_job_template_with_defaults


@pytest.mark.parametrize(
"job_config",
[
"default",
"custom",
],
)
async def test_generate_work_pool_base_job_template(
job_config,
base_job_template_with_defaults,
credentials_block,
default_base_job_template,
):
job = VertexAICustomTrainingJob(
image="docker.io/my_image:latest",
region="us-central1",
gcp_credentials=credentials_block,
)
expected_template = default_base_job_template
default_base_job_template["variables"]["properties"]["image"][
"default"
] = "docker.io/my_image:latest"
default_base_job_template["variables"]["properties"]["region"][
"default"
] = "us-central1"
default_base_job_template["variables"]["properties"]["credentials"]["default"] = {
"$ref": {"block_document_id": str(credentials_block._block_document_id)}
}
if job_config == "custom":
expected_template = base_job_template_with_defaults
job = VertexAICustomTrainingJob(
command=["python", "my_script.py"],
env={"VAR1": "value1", "VAR2": "value2"},
labels={"label1": "value1", "label2": "value2"},
name="prefect-job",
image="docker.io/my_image:latest",
gcp_credentials=credentials_block,
region="us-central1",
machine_type="n1-standard-4",
accelerator_count=1,
accelerator_type="NVIDIA_TESLA_T4",
boot_disk_type="pd-ssd",
boot_disk_size_gb=200,
maximum_run_time=60 * 60 * 24,
network="my-network",
reserved_ip_ranges=["172.31.0.0/16", "192.168.0.0./16"],
service_account="my-service-account",
job_watch_poll_interval=60,
)

template = await job.generate_work_pool_base_job_template()

assert template == expected_template

0 comments on commit a2c74df

Please sign in to comment.