Skip to content

Commit

Permalink
feat: LLM - Made tuning asynchronous when tuning becomes GA
Browse files Browse the repository at this point in the history
Previously, `tune_model` waited for the tuning is complete, then modified the model in-place.
This behavior will change in the future GA (non-preview) classes:

In the future, `tune_model` will become asynchronous: It will start tuning job and return a job object immediately without waiting. This will allow the user to do other work while the model is being tuned. This will also allow the user to perform multiple tuning jobs in parallel.

Future breaking change: The model will no longer be updated in-place, so the user will need to get the tuned model from the job object.

To make the transition easier and avoid breaking changes, the `.tune_model(...)` method will start returning the job object even in preview classes (although it will still wait for the job completion and update the model in-place too). This makes it possible to start writing future-proof code immediately.

Usage:

```
tuning_job = model.tune_model(...)  # Returns tuning job. In preview: Waits for the tuning job to finish
tuned_model = tuning_job.get_tuned_model()  # Returns tuned model after waiting for the tuning job to finish.
```
PiperOrigin-RevId: 558554561
  • Loading branch information
Ark-kun authored and copybara-github committed Aug 20, 2023
1 parent e6d1e95 commit 226ab8b
Show file tree
Hide file tree
Showing 3 changed files with 200 additions and 26 deletions.
14 changes: 13 additions & 1 deletion tests/system/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def test_tuning(self, shared_state):
df=training_data, upload_gcs_path=dataset_uri
)

model.tune_model(
tuning_job = model.tune_model(
training_data=training_data,
train_steps=1,
tuning_job_location="europe-west4",
Expand All @@ -211,6 +211,18 @@ def test_tuning(self, shared_state):
)
# Deleting the Endpoint is a little less bad since the LLM SDK will recreate it, but it's not advised for the same reason.

# Testing the new model returned by the `tuning_job.get_tuned_model` method
tuned_model1 = tuning_job.get_tuned_model()
response1 = tuned_model1.predict(
"What is the best recipe for banana bread? Recipe:",
max_output_tokens=128,
temperature=0,
top_p=1,
top_k=5,
)
assert response1.text

# Testing the model updated in-place (Deprecated. Preview only)
response = model.predict(
"What is the best recipe for banana bread? Recipe:",
max_output_tokens=128,
Expand Down
24 changes: 19 additions & 5 deletions tests/unit/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,13 +1039,13 @@ def mock_get_tuned_model(get_endpoint_mock):
with mock.patch.object(
_language_models._TunableModelMixin, "get_tuned_model"
) as mock_text_generation_model:
mock_text_generation_model._model_id = (
mock_text_generation_model.return_value._model_id = (
test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME
)
mock_text_generation_model._endpoint_name = (
mock_text_generation_model.return_value._endpoint_name = (
test_constants.EndpointConstants._TEST_ENDPOINT_NAME
)
mock_text_generation_model._endpoint = get_endpoint_mock
mock_text_generation_model.return_value._endpoint = get_endpoint_mock
yield mock_text_generation_model


Expand Down Expand Up @@ -1344,7 +1344,7 @@ def test_tune_text_generation_model(
enable_early_stopping = True
tensorboard_name = f"projects/{_TEST_PROJECT}/locations/{tuning_job_location}/tensorboards/123"

model.tune_model(
tuning_job = model.tune_model(
training_data=_TEST_TEXT_BISON_TRAINING_DF,
tuning_job_location=tuning_job_location,
tuned_model_location="us-central1",
Expand Down Expand Up @@ -1375,6 +1375,13 @@ def test_tune_text_generation_model(
== _TEST_ENCRYPTION_KEY_NAME
)

# Testing the tuned model
tuned_model = tuning_job.get_tuned_model()
assert (
tuned_model._endpoint_name
== test_constants.EndpointConstants._TEST_ENDPOINT_NAME
)

@pytest.mark.parametrize(
"job_spec",
[_TEST_PIPELINE_SPEC_JSON],
Expand Down Expand Up @@ -1408,7 +1415,7 @@ def test_tune_chat_model(
model = preview_language_models.ChatModel.from_pretrained("chat-bison@001")

default_context = "Default context"
model.tune_model(
tuning_job = model.tune_model(
training_data=_TEST_TEXT_BISON_TRAINING_DF,
tuning_job_location="europe-west4",
tuned_model_location="us-central1",
Expand All @@ -1421,6 +1428,13 @@ def test_tune_chat_model(
assert pipeline_arguments["large_model_reference"] == "chat-bison@001"
assert pipeline_arguments["default_context"] == default_context

# Testing the tuned model
tuned_model = tuning_job.get_tuned_model()
assert (
tuned_model._endpoint_name
== test_constants.EndpointConstants._TEST_ENDPOINT_NAME
)

@pytest.mark.parametrize(
"job_spec",
[_TEST_PIPELINE_SPEC_JSON],
Expand Down
188 changes: 168 additions & 20 deletions vertexai/language_models/_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer as aiplatform_initializer
from google.cloud.aiplatform import utils as aiplatform_utils
from google.cloud.aiplatform.compat import types as aiplatform_types
from google.cloud.aiplatform.utils import gcs_utils
from vertexai._model_garden import _model_garden_models
from vertexai.language_models import (
Expand Down Expand Up @@ -148,18 +149,24 @@ def tune_model(
self,
training_data: Union[str, "pandas.core.frame.DataFrame"],
*,
train_steps: int = 1000,
train_steps: Optional[int] = None,
learning_rate: Optional[float] = None,
learning_rate_multiplier: Optional[float] = None,
tuning_job_location: Optional[str] = None,
tuned_model_location: Optional[str] = None,
model_display_name: Optional[str] = None,
tuning_evaluation_spec: Optional["TuningEvaluationSpec"] = None,
default_context: Optional[str] = None,
):
) -> "_LanguageModelTuningJob":
"""Tunes a model based on training data.
This method launches a model tuning job that can take some time.
This method launches and returns an asynchronous model tuning job.
Usage:
```
tuning_job = model.tune_model(...)
... do some other work
tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete
```
Args:
training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
Expand Down Expand Up @@ -303,16 +310,68 @@ def _tune_model(
base_model=self,
job=pipeline_job,
)
self._job = job
tuned_model = job.result()
# The UXR study attendees preferred to tune model in place
self._endpoint = tuned_model._endpoint
self._endpoint_name = tuned_model._endpoint_name
return job


class _TunableTextModelMixin(_TunableModelMixin):
"""Text model that can be tuned."""

def tune_model(
self,
training_data: Union[str, "pandas.core.frame.DataFrame"],
*,
train_steps: Optional[int] = None,
learning_rate_multiplier: Optional[float] = None,
tuning_job_location: Optional[str] = None,
tuned_model_location: Optional[str] = None,
model_display_name: Optional[str] = None,
tuning_evaluation_spec: Optional["TuningEvaluationSpec"] = None,
) -> "_LanguageModelTuningJob":
"""Tunes a model based on training data.
This method launches and returns an asynchronous model tuning job.
Usage:
```
tuning_job = model.tune_model(...)
... do some other work
tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete
Args:
training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
The dataset schema is model-specific.
See https://cloud.google.com/vertex-ai/docs/generative-ai/models/tune-models#dataset_format
train_steps: Number of training batches to tune on (batch size is 8 samples).
learning_rate_multiplier: Learning rate multiplier to use in tuning.
tuning_job_location: GCP location where the tuning job should be run.
Only "europe-west4" and "us-central1" locations are supported for now.
tuned_model_location: GCP location where the tuned model should be deployed. Only "us-central1" is supported for now.
model_display_name: Custom display name for the tuned model.
tuning_evaluation_spec: Specification for the model evaluation during tuning.
Returns:
A `LanguageModelTuningJob` object that represents the tuning job.
Calling `job.result()` blocks until the tuning is complete and returns a `LanguageModel` object.
Raises:
ValueError: If the "tuning_job_location" value is not supported
ValueError: If the "tuned_model_location" value is not supported
RuntimeError: If the model does not support tuning
"""
# Note: Chat models do not support default_context
return super().tune_model(
training_data=training_data,
train_steps=train_steps,
learning_rate_multiplier=learning_rate_multiplier,
tuning_job_location=tuning_job_location,
tuned_model_location=tuned_model_location,
model_display_name=model_display_name,
tuning_evaluation_spec=tuning_evaluation_spec,
)


class _PreviewTunableTextModelMixin(_TunableModelMixin):
"""Text model that can be tuned."""

def tune_model(
self,
training_data: Union[str, "pandas.core.frame.DataFrame"],
Expand All @@ -324,10 +383,20 @@ def tune_model(
tuned_model_location: Optional[str] = None,
model_display_name: Optional[str] = None,
tuning_evaluation_spec: Optional["TuningEvaluationSpec"] = None,
):
) -> "_LanguageModelTuningJob":
"""Tunes a model based on training data.
This method launches a model tuning job that can take some time.
This method launches a model tuning job, waits for completion,
updates the model in-place. This method returns job object for forward
compatibility.
In the future (GA), this method will become asynchronous and will stop
updating the model in-place.
Usage:
```
tuning_job = model.tune_model(...) # Blocks until tuning is complete
tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete
```
Args:
training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
Expand All @@ -353,7 +422,7 @@ def tune_model(
RuntimeError: If the model does not support tuning
"""
# Note: Chat models do not support default_context
return super().tune_model(
job = super().tune_model(
training_data=training_data,
train_steps=train_steps,
learning_rate=learning_rate,
Expand All @@ -363,11 +432,74 @@ def tune_model(
model_display_name=model_display_name,
tuning_evaluation_spec=tuning_evaluation_spec,
)
tuned_model = job.get_tuned_model()
self._endpoint = tuned_model._endpoint
self._endpoint_name = tuned_model._endpoint_name
return job


class _TunableChatModelMixin(_TunableModelMixin):
"""Chat model that can be tuned."""

def tune_model(
self,
training_data: Union[str, "pandas.core.frame.DataFrame"],
*,
train_steps: Optional[int] = None,
learning_rate_multiplier: Optional[float] = None,
tuning_job_location: Optional[str] = None,
tuned_model_location: Optional[str] = None,
model_display_name: Optional[str] = None,
default_context: Optional[str] = None,
) -> "_LanguageModelTuningJob":
"""Tunes a model based on training data.
This method launches and returns an asynchronous model tuning job.
Usage:
```
tuning_job = model.tune_model(...)
... do some other work
tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete
```
Args:
training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
The dataset schema is model-specific.
See https://cloud.google.com/vertex-ai/docs/generative-ai/models/tune-models#dataset_format
train_steps: Number of training batches to tune on (batch size is 8 samples).
learning_rate: Deprecated. Use learning_rate_multiplier instead.
Learning rate to use in tuning.
learning_rate_multiplier: Learning rate multiplier to use in tuning.
tuning_job_location: GCP location where the tuning job should be run.
Only "europe-west4" and "us-central1" locations are supported for now.
tuned_model_location: GCP location where the tuned model should be deployed. Only "us-central1" is supported for now.
model_display_name: Custom display name for the tuned model.
default_context: The context to use for all training samples by default.
Returns:
A `LanguageModelTuningJob` object that represents the tuning job.
Calling `job.result()` blocks until the tuning is complete and returns a `LanguageModel` object.
Raises:
ValueError: If the "tuning_job_location" value is not supported
ValueError: If the "tuned_model_location" value is not supported
RuntimeError: If the model does not support tuning
"""
# Note: Chat models do not support tuning_evaluation_spec
return super().tune_model(
training_data=training_data,
train_steps=train_steps,
learning_rate_multiplier=learning_rate_multiplier,
tuning_job_location=tuning_job_location,
tuned_model_location=tuned_model_location,
model_display_name=model_display_name,
default_context=default_context,
)


class _PreviewTunableChatModelMixin(_TunableModelMixin):
"""Chat model that can be tuned."""

def tune_model(
self,
training_data: Union[str, "pandas.core.frame.DataFrame"],
Expand All @@ -379,10 +511,20 @@ def tune_model(
tuned_model_location: Optional[str] = None,
model_display_name: Optional[str] = None,
default_context: Optional[str] = None,
):
) -> "_LanguageModelTuningJob":
"""Tunes a model based on training data.
This method launches a model tuning job that can take some time.
This method launches a model tuning job, waits for completion,
updates the model in-place. This method returns job object for forward
compatibility.
In the future (GA), this method will become asynchronous and will stop
updating the model in-place.
Usage:
```
tuning_job = model.tune_model(...) # Blocks until tuning is complete
tuned_model = tuning_job.get_tuned_model() # Blocks until tuning is complete
```
Args:
training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
Expand All @@ -408,7 +550,7 @@ def tune_model(
RuntimeError: If the model does not support tuning
"""
# Note: Chat models do not support tuning_evaluation_spec
return super().tune_model(
job = super().tune_model(
training_data=training_data,
train_steps=train_steps,
learning_rate=learning_rate,
Expand All @@ -418,6 +560,10 @@ def tune_model(
model_display_name=model_display_name,
default_context=default_context,
)
tuned_model = job.get_tuned_model()
self._endpoint = tuned_model._endpoint
self._endpoint_name = tuned_model._endpoint_name
return job


@dataclasses.dataclass
Expand Down Expand Up @@ -746,7 +892,7 @@ class TextGenerationModel(_TextGenerationModel, _ModelWithBatchPredict):

class _PreviewTextGenerationModel(
_TextGenerationModel,
_TunableTextModelMixin,
_PreviewTunableTextModelMixin,
_PreviewModelWithBatchPredict,
_evaluatable_language_models._EvaluatableLanguageModel,
):
Expand Down Expand Up @@ -1076,7 +1222,7 @@ class ChatModel(_ChatModelBase):
_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/chat_generation_1.0.0.yaml"


class _PreviewChatModel(ChatModel, _TunableChatModelMixin):
class _PreviewChatModel(ChatModel, _PreviewTunableChatModelMixin):
_LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE


Expand Down Expand Up @@ -1650,11 +1796,12 @@ def __init__(
base_model: _LanguageModel,
job: aiplatform.PipelineJob,
):
"""Internal constructor. Do not call directly."""
self._base_model = base_model
self._job = job
self._model: Optional[_LanguageModel] = None

def result(self) -> "_LanguageModel":
def get_tuned_model(self) -> "_LanguageModel":
"""Blocks until the tuning is complete and returns a `LanguageModel` object."""
if self._model:
return self._model
Expand All @@ -1681,11 +1828,12 @@ def result(self) -> "_LanguageModel":
return self._model

@property
def status(self):
"""Job status"""
def _status(self) -> Optional[aiplatform_types.pipeline_state.PipelineState]:
"""Job status."""
return self._job.state

def cancel(self):
def _cancel(self):
"""Cancels the job."""
self._job.cancel()


Expand Down

0 comments on commit 226ab8b

Please sign in to comment.