Skip to content

Commit

Permalink
feat: LLM - Added tuning support for chat-bison models
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 555782339
  • Loading branch information
Ark-kun authored and copybara-github committed Aug 11, 2023
1 parent 06c9d18 commit 3a97c52
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 3 deletions.
48 changes: 46 additions & 2 deletions tests/unit/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
language_models as preview_language_models,
)
from vertexai import language_models
from vertexai.language_models import _language_models
from google.cloud.aiplatform_v1 import Execution as GapicExecution
from google.cloud.aiplatform.compat.types import (
encryption_spec as gca_encryption_spec,
Expand Down Expand Up @@ -471,7 +472,7 @@ def get_endpoint_mock():
@pytest.fixture
def mock_get_tuned_model(get_endpoint_mock):
with mock.patch.object(
preview_language_models.TextGenerationModel, "get_tuned_model"
_language_models._TunableModelMixin, "get_tuned_model"
) as mock_text_generation_model:
mock_text_generation_model._model_id = (
test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME
Expand Down Expand Up @@ -634,7 +635,7 @@ def test_text_generation_ga(self):
["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"],
indirect=True,
)
def test_tune_model(
def test_tune_text_generation_model(
self,
mock_pipeline_service_create,
mock_pipeline_job_get,
Expand Down Expand Up @@ -680,6 +681,49 @@ def test_tune_model(
== _TEST_ENCRYPTION_KEY_NAME
)

# Run once with the canned pipeline spec JSON as the job_spec fixture value.
@pytest.mark.parametrize(
"job_spec",
[_TEST_PIPELINE_SPEC_JSON],
)
# Route urlopen through a mock (indirect fixture parametrization) so that
# resolving the tuning pipeline template from the KFP artifact registry URL
# does not hit the network during the test.
@pytest.mark.parametrize(
"mock_request_urlopen",
["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"],
indirect=True,
)
def test_tune_chat_model(
self,
mock_pipeline_service_create,
mock_pipeline_job_get,
mock_pipeline_bucket_exists,
job_spec,
mock_load_yaml_and_json,
mock_gcs_from_string,
mock_gcs_upload,
mock_request_urlopen,
mock_get_tuned_model,
):
"""Tests that tuning a chat model submits the expected pipeline job."""
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
# Stub the Model Garden service so from_pretrained() resolves
# "chat-bison@001" to the test publisher-model proto without an RPC.
with mock.patch.object(
target=model_garden_service_client.ModelGardenServiceClient,
attribute="get_publisher_model",
return_value=gca_publisher_model.PublisherModel(
_CHAT_BISON_PUBLISHER_MODEL_DICT
),
):
model = preview_language_models.ChatModel.from_pretrained("chat-bison@001")

# Kick off tuning. Pipeline-job creation is mocked, so this only records
# the create request (tuning runs in europe-west4, the tuned model is
# placed in us-central1).
model.tune_model(
training_data=_TEST_TEXT_BISON_TRAINING_DF,
tuning_job_location="europe-west4",
tuned_model_location="us-central1",
)
# Inspect the PipelineJob captured by the mocked create call and verify
# the base-model reference forwarded to the tuning pipeline parameters.
call_kwargs = mock_pipeline_service_create.call_args[1]
pipeline_arguments = call_kwargs[
"pipeline_job"
].runtime_config.parameter_values
assert pipeline_arguments["large_model_reference"] == "chat-bison@001"

@pytest.mark.usefixtures(
"get_model_with_tuned_version_label_mock",
"get_endpoint_with_models_mock",
Expand Down
1 change: 1 addition & 0 deletions vertexai/_model_garden/_model_garden_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
# Maps a model's short ID (publisher model name without the "@<version>"
# suffix) to the KFP tuning-pipeline template URL used to tune it.
# NOTE(review): chat-bison points at a dedicated "tune-large-chat-model"
# template rather than the generic "tune-large-model" one used by the
# text and code models.
_SHORT_MODEL_ID_TO_TUNING_PIPELINE_MAP = {
"text-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v2.0.0",
"code-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v3.0.0",
"chat-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-chat-model/v3.0.0",
}

_SDK_PRIVATE_PREVIEW_LAUNCH_STAGE = frozenset(
Expand Down
2 changes: 1 addition & 1 deletion vertexai/language_models/_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,7 @@ class ChatModel(_ChatModelBase):
_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/chat_generation_1.0.0.yaml"


class _PreviewChatModel(ChatModel):
class _PreviewChatModel(ChatModel, _TunableModelMixin):
"""Preview variant of ChatModel that also mixes in tuning support.

Inherits chat behavior from ChatModel and adds _TunableModelMixin, whose
tune_model / get_tuned_model entry points are exercised by the unit tests.
"""
_LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE


Expand Down

0 comments on commit 3a97c52

Please sign in to comment.