Skip to content

Commit

Permalink
feat: LLM - Support tuning in the "us-central1" location
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 547655421
  • Loading branch information
Ark-kun authored and copybara-github committed Jul 13, 2023
1 parent c903e7d commit 4aa7745
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions vertexai/language_models/_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ def tune_model(
The dataset must have the "input_text" and "output_text" columns.
train_steps: Number of training batches to tune on (batch size is 8 samples).
learning_rate: Learning rate for the tuning
tuning_job_location: GCP location where the tuning job should be run. Only "europe-west4" is supported for now.
tuning_job_location: GCP location where the tuning job should be run.
Only "europe-west4" and "us-central1" locations are supported for now.
tuned_model_location: GCP location where the tuned model should be deployed. Only "us-central1" is supported for now.
model_display_name: Custom display name for the tuned model.
Expand All @@ -166,9 +167,10 @@ def tune_model(
ValueError: If the "tuned_model_location" value is not supported
RuntimeError: If the model does not support tuning
"""
if tuning_job_location != _TUNING_LOCATION:
if tuning_job_location not in _TUNING_LOCATIONS:
raise ValueError(
f'Tuning is only supported in the following locations: tuning_job_location="{_TUNING_LOCATION}"'
"Please specify the tuning job location (`tuning_job_location`)."
f"Tuning is supported in the following locations: {_TUNING_LOCATIONS}"
)
if tuned_model_location != _TUNED_MODEL_LOCATION:
raise ValueError(
Expand All @@ -187,6 +189,7 @@ def tune_model(
tuning_pipeline_uri=model_info.tuning_pipeline_uri,
model_display_name=model_display_name,
learning_rate=learning_rate,
tuning_job_location=tuning_job_location,
)

job = _LanguageModelTuningJob(
Expand Down Expand Up @@ -965,7 +968,7 @@ def predict(

###### Model tuning
# Currently, tuning can only work in this location
_TUNING_LOCATION = "europe-west4"
_TUNING_LOCATIONS = ("europe-west4", "us-central1")
# Currently, deployment can only work in this location
_TUNED_MODEL_LOCATION = "us-central1"

Expand Down Expand Up @@ -1051,6 +1054,7 @@ def _launch_tuning_job(
train_steps: Optional[int] = None,
model_display_name: Optional[str] = None,
learning_rate: Optional[float] = None,
tuning_job_location: str = _TUNING_LOCATIONS[0],
) -> aiplatform.PipelineJob:
output_dir_uri = _generate_tuned_model_dir_uri(model_id=model_id)
if isinstance(training_data, str):
Expand All @@ -1073,6 +1077,7 @@ def _launch_tuning_job(
tuning_pipeline_uri=tuning_pipeline_uri,
model_display_name=model_display_name,
learning_rate=learning_rate,
tuning_job_location=tuning_job_location,
)
return job

Expand All @@ -1084,6 +1089,7 @@ def _launch_tuning_job_on_jsonl_data(
train_steps: Optional[int] = None,
learning_rate: Optional[float] = None,
model_display_name: Optional[str] = None,
tuning_job_location: str = _TUNING_LOCATIONS[0],
) -> aiplatform.PipelineJob:
if not model_display_name:
# Creating a human-readable model display name
Expand Down Expand Up @@ -1126,7 +1132,7 @@ def _launch_tuning_job_on_jsonl_data(
display_name=None,
parameter_values=pipeline_arguments,
# TODO(b/275444101): Remove the explicit location once model can be deployed in all regions
location=_TUNING_LOCATION,
location=tuning_job_location,
)
job.submit()
return job
Expand Down

0 comments on commit 4aa7745

Please sign in to comment.