diff --git a/vertexai/language_models/_language_models.py b/vertexai/language_models/_language_models.py index 6e4bb76032..a0a7fe60d8 100644 --- a/vertexai/language_models/_language_models.py +++ b/vertexai/language_models/_language_models.py @@ -42,7 +42,7 @@ def _get_model_id_from_tuning_model_id(tuning_model_id: str) -> str: """Gets the base model ID for the model ID labels used the tuned models. Args: - tuning_model_id: The model ID used in tuning + tuning_model_id: The model ID used in tuning. E.g. `text-bison-001` Returns: The publisher model ID @@ -50,13 +50,9 @@ def _get_model_id_from_tuning_model_id(tuning_model_id: str) -> str: Raises: ValueError: If tuning model ID is unsupported """ - if tuning_model_id.startswith("text-bison-"): - return tuning_model_id.replace( - "text-bison-", "publishers/google/models/text-bison@" - ) - if "/" not in tuning_model_id: - return "publishers/google/models/" + tuning_model_id - return tuning_model_id + model_name, _, version = tuning_model_id.rpartition("-") + # "publishers/google/models/text-bison@001" + return f"publishers/google/models/{model_name}@{version}" class _LanguageModel(_model_garden_models._ModelGardenModel): @@ -203,6 +199,7 @@ def tune_model( tuned_model = job.result() # The UXR study attendees preferred to tune model in place self._endpoint = tuned_model._endpoint + self._endpoint_name = tuned_model._endpoint_name @dataclasses.dataclass