diff --git a/tests/unit/aiplatform/test_language_models.py b/tests/unit/aiplatform/test_language_models.py index e246f2cd19..fe04517fed 100644 --- a/tests/unit/aiplatform/test_language_models.py +++ b/tests/unit/aiplatform/test_language_models.py @@ -57,6 +57,7 @@ language_models as preview_language_models, ) from vertexai import language_models +from vertexai.language_models import _language_models from google.cloud.aiplatform_v1 import Execution as GapicExecution from google.cloud.aiplatform.compat.types import ( encryption_spec as gca_encryption_spec, @@ -471,7 +472,7 @@ def get_endpoint_mock(): @pytest.fixture def mock_get_tuned_model(get_endpoint_mock): with mock.patch.object( - preview_language_models.TextGenerationModel, "get_tuned_model" + _language_models._TunableModelMixin, "get_tuned_model" ) as mock_text_generation_model: mock_text_generation_model._model_id = ( test_constants.ModelConstants._TEST_MODEL_RESOURCE_NAME @@ -634,7 +635,7 @@ def test_text_generation_ga(self): ["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"], indirect=True, ) - def test_tune_model( + def test_tune_text_generation_model( self, mock_pipeline_service_create, mock_pipeline_job_get, @@ -680,6 +681,49 @@ def test_tune_model( == _TEST_ENCRYPTION_KEY_NAME ) + @pytest.mark.parametrize( + "job_spec", + [_TEST_PIPELINE_SPEC_JSON], + ) + @pytest.mark.parametrize( + "mock_request_urlopen", + ["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"], + indirect=True, + ) + def test_tune_chat_model( + self, + mock_pipeline_service_create, + mock_pipeline_job_get, + mock_pipeline_bucket_exists, + job_spec, + mock_load_yaml_and_json, + mock_gcs_from_string, + mock_gcs_upload, + mock_request_urlopen, + mock_get_tuned_model, + ): + """Tests tuning a chat model.""" + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + with mock.patch.object( + target=model_garden_service_client.ModelGardenServiceClient, + attribute="get_publisher_model", + return_value=gca_publisher_model.PublisherModel( + _CHAT_BISON_PUBLISHER_MODEL_DICT + ), + ): + model = preview_language_models.ChatModel.from_pretrained("chat-bison@001") + + model.tune_model( + training_data=_TEST_TEXT_BISON_TRAINING_DF, + tuning_job_location="europe-west4", + tuned_model_location="us-central1", + ) + call_kwargs = mock_pipeline_service_create.call_args[1] + pipeline_arguments = call_kwargs[ + "pipeline_job" + ].runtime_config.parameter_values + assert pipeline_arguments["large_model_reference"] == "chat-bison@001" + @pytest.mark.usefixtures( "get_model_with_tuned_version_label_mock", "get_endpoint_with_models_mock", diff --git a/vertexai/_model_garden/_model_garden_models.py b/vertexai/_model_garden/_model_garden_models.py index 30f71398e6..0fa77edb54 100644 --- a/vertexai/_model_garden/_model_garden_models.py +++ b/vertexai/_model_garden/_model_garden_models.py @@ -33,6 +33,7 @@ _SHORT_MODEL_ID_TO_TUNING_PIPELINE_MAP = { "text-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v2.0.0", "code-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-model/v3.0.0", + "chat-bison": "https://us-kfp.pkg.dev/ml-pipeline/large-language-model-pipelines/tune-large-chat-model/v3.0.0", } _SDK_PRIVATE_PREVIEW_LAUNCH_STAGE = frozenset( diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py index 2795d94bd8..8aa67ec92f 100644 --- a/vertexai/language_models/_language_models.py +++ b/vertexai/language_models/_language_models.py @@ -692,7 +692,7 @@ class ChatModel(_ChatModelBase): _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/chat_generation_1.0.0.yaml" -class _PreviewChatModel(ChatModel): +class _PreviewChatModel(ChatModel, _TunableModelMixin): _LAUNCH_STAGE = _model_garden_models._SDK_PUBLIC_PREVIEW_LAUNCH_STAGE