Skip to content

Commit

Permalink
feat: LLM - Support model evaluation when tuning chat models (`ChatMo…
Browse files Browse the repository at this point in the history
…del`, `CodeChatModel`)

PiperOrigin-RevId: 580611746
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Nov 8, 2023
1 parent dcb6205 commit 755c3f9
Showing 2 changed files with 74 additions and 1 deletion.
53 changes: 53 additions & 0 deletions tests/unit/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
@@ -2125,12 +2125,18 @@ def test_tune_chat_model(
):
model = language_models.ChatModel.from_pretrained("chat-bison@001")

tuning_job_location = "europe-west4"
tensorboard_name = f"projects/{_TEST_PROJECT}/locations/{tuning_job_location}/tensorboards/123"

default_context = "Default context"
tuning_job = model.tune_model(
training_data=_TEST_TEXT_BISON_TRAINING_DF,
tuning_job_location="europe-west4",
tuned_model_location="us-central1",
default_context=default_context,
tuning_evaluation_spec=preview_language_models.TuningEvaluationSpec(
tensorboard=tensorboard_name,
),
accelerator_type="TPU",
)
call_kwargs = mock_pipeline_service_create.call_args[1]
@@ -2140,6 +2146,7 @@ def test_tune_chat_model(
assert pipeline_arguments["large_model_reference"] == "chat-bison@001"
assert pipeline_arguments["default_context"] == default_context
assert pipeline_arguments["accelerator_type"] == "TPU"
assert pipeline_arguments["tensorboard_resource_id"] == tensorboard_name

# Testing the tuned model
tuned_model = tuning_job.get_tuned_model()
@@ -2148,6 +2155,26 @@ def test_tune_chat_model(
== test_constants.EndpointConstants._TEST_ENDPOINT_NAME
)

unsupported_tuning_evaluation_spec_att = (
{"evaluation_data": "gs://bucket/eval.jsonl"},
{"evaluation_interval": 37},
{"enable_early_stopping": True},
{"enable_checkpoint_selection": True},
)
for unsupported_att in unsupported_tuning_evaluation_spec_att:
unsupported_tuning_evaluation_spec = (
preview_language_models.TuningEvaluationSpec(**unsupported_att)
)
with pytest.raises(AttributeError):
model.tune_model(
training_data=_TEST_TEXT_BISON_TRAINING_DF,
tuning_job_location="europe-west4",
tuned_model_location="us-central1",
default_context=default_context,
tuning_evaluation_spec=unsupported_tuning_evaluation_spec,
accelerator_type="TPU",
)

@pytest.mark.parametrize(
"job_spec",
[_TEST_PIPELINE_SPEC_JSON],
@@ -2228,12 +2255,18 @@ def test_tune_code_chat_model(
):
model = language_models.CodeChatModel.from_pretrained("codechat-bison@001")

tuning_job_location = "europe-west4"
tensorboard_name = f"projects/{_TEST_PROJECT}/locations/{tuning_job_location}/tensorboards/123"

# The tune_model call needs to be inside the PublisherModel mock
# since it gets a new PublisherModel when tuning completes.
model.tune_model(
training_data=_TEST_TEXT_BISON_TRAINING_DF,
tuning_job_location="europe-west4",
tuned_model_location="us-central1",
tuning_evaluation_spec=preview_language_models.TuningEvaluationSpec(
tensorboard=tensorboard_name,
),
accelerator_type="TPU",
)
call_kwargs = mock_pipeline_service_create.call_args[1]
@@ -2242,6 +2275,26 @@ def test_tune_code_chat_model(
].runtime_config.parameter_values
assert pipeline_arguments["large_model_reference"] == "codechat-bison@001"
assert pipeline_arguments["accelerator_type"] == "TPU"
assert pipeline_arguments["tensorboard_resource_id"] == tensorboard_name

unsupported_tuning_evaluation_spec_att = (
{"evaluation_data": "gs://bucket/eval.jsonl"},
{"evaluation_interval": 37},
{"enable_early_stopping": True},
{"enable_checkpoint_selection": True},
)
for unsupported_att in unsupported_tuning_evaluation_spec_att:
unsupported_tuning_evaluation_spec = (
preview_language_models.TuningEvaluationSpec(**unsupported_att)
)
with pytest.raises(AttributeError):
model.tune_model(
training_data=_TEST_TEXT_BISON_TRAINING_DF,
tuning_job_location="europe-west4",
tuned_model_location="us-central1",
tuning_evaluation_spec=unsupported_tuning_evaluation_spec,
accelerator_type="TPU",
)

@pytest.mark.usefixtures(
"get_model_with_tuned_version_label_mock",
22 changes: 21 additions & 1 deletion vertexai/language_models/_language_models.py
Original file line number Diff line number Diff line change
@@ -496,6 +496,7 @@ def tune_model(
model_display_name: Optional[str] = None,
default_context: Optional[str] = None,
accelerator_type: Optional[_ACCELERATOR_TYPE_TYPE] = None,
tuning_evaluation_spec: Optional["TuningEvaluationSpec"] = None,
) -> "_LanguageModelTuningJob":
"""Tunes a model based on training data.
@@ -520,6 +521,7 @@ def tune_model(
model_display_name: Custom display name for the tuned model.
default_context: The context to use for all training samples by default.
accelerator_type: Type of accelerator to use. Can be "TPU" or "GPU".
tuning_evaluation_spec: Specification for the model evaluation during tuning.
Returns:
A `LanguageModelTuningJob` object that represents the tuning job.
@@ -529,8 +531,25 @@ def tune_model(
ValueError: If the "tuning_job_location" value is not supported
ValueError: If the "tuned_model_location" value is not supported
RuntimeError: If the model does not support tuning
AttributeError: If any attribute in the "tuning_evaluation_spec" is not supported
"""
# Note: Chat models do not support tuning_evaluation_spec

if tuning_evaluation_spec is not None:
unsupported_chat_model_tuning_eval_spec = {
"evaluation_data": tuning_evaluation_spec.evaluation_data,
"evaluation_interval": tuning_evaluation_spec.evaluation_interval,
"enable_early_stopping": tuning_evaluation_spec.enable_early_stopping,
"enable_checkpoint_selection": tuning_evaluation_spec.enable_checkpoint_selection,
}

for att_name, att_value in unsupported_chat_model_tuning_eval_spec.items():
if not att_value is None:
raise AttributeError(
(
f"ChatModel and CodeChatModel only support tensorboard as attribute for TuningEvaluationSpec"
f"found attribute name {att_name} with value {att_value}, please leave {att_name} to None"
)
)
return super().tune_model(
training_data=training_data,
train_steps=train_steps,
@@ -540,6 +559,7 @@ def tune_model(
model_display_name=model_display_name,
default_context=default_context,
accelerator_type=accelerator_type,
tuning_evaluation_spec=tuning_evaluation_spec,
)


0 comments on commit 755c3f9

Please sign in to comment.