feat: LLM - Added support for learning_rate in tuning
PiperOrigin-RevId: 542784145
Ark-kun authored and copybara-github committed Jun 23, 2023
1 parent 750e161 commit c6cdd10
Showing 3 changed files with 18 additions and 1 deletion.
1 change: 1 addition & 0 deletions tests/system/aiplatform/test_language_models.py
@@ -124,6 +124,7 @@ def test_tuning(self, shared_state):
             train_steps=1,
             tuning_job_location="europe-west4",
             tuned_model_location="us-central1",
+            learning_rate=2.0,
         )
         # According to the Pipelines design, external resources created by a pipeline
         # must not be modified or deleted. Otherwise caching will break next pipeline runs.
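For context, a minimal usage sketch of the new parameter from client code, mirroring the system test above. This is a hypothetical example, not from the commit: the project, model name, and data URI are placeholders.

    import vertexai
    from vertexai.language_models import TextGenerationModel

    # Placeholder project and location.
    vertexai.init(project="my-project", location="us-central1")

    model = TextGenerationModel.from_pretrained("text-bison@001")
    model.tune_model(
        # JSONL with "input_text" and "output_text" columns, per the tune_model docstring.
        training_data="gs://my-bucket/tuning_data.jsonl",
        train_steps=10,
        learning_rate=2.0,  # the newly supported knob exercised by this test
        tuning_job_location="europe-west4",
        tuned_model_location="us-central1",
    )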
5 changes: 5 additions & 0 deletions tests/unit/aiplatform/test_language_models.py
@@ -661,8 +661,13 @@ def test_tune_model(
             training_data=_TEST_TEXT_BISON_TRAINING_DF,
             tuning_job_location="europe-west4",
             tuned_model_location="us-central1",
+            learning_rate=0.1,
         )
         call_kwargs = mock_pipeline_service_create.call_args[1]
+        pipeline_arguments = call_kwargs[
+            "pipeline_job"
+        ].runtime_config.parameter_values
+        assert pipeline_arguments["learning_rate"] == 0.1
         assert (
             call_kwargs["pipeline_job"].encryption_spec.kms_key_name
             == _TEST_ENCRYPTION_KEY_NAME
13 changes: 12 additions & 1 deletion vertexai/language_models/_language_models.py
@@ -139,6 +139,7 @@ def tune_model(
         training_data: Union[str, "pandas.core.frame.DataFrame"],
         *,
         train_steps: int = 1000,
+        learning_rate: Optional[float] = None,
         tuning_job_location: Optional[str] = None,
         tuned_model_location: Optional[str] = None,
         model_display_name: Optional[str] = None,
@@ -151,6 +152,7 @@ def tune_model(
             training_data: A Pandas DataFrame or a URI pointing to data in JSON lines format.
                 The dataset must have the "input_text" and "output_text" columns.
             train_steps: Number of training steps to perform.
+            learning_rate: Learning rate for the tuning job.
             tuning_job_location: GCP location where the tuning job should be run. Only "europe-west4" is supported for now.
             tuned_model_location: GCP location where the tuned model should be deployed. Only "us-central1" is supported for now.
             model_display_name: Custom display name for the tuned model.
@@ -184,6 +186,7 @@ def tune_model(
             model_id=model_info.tuning_model_id,
             tuning_pipeline_uri=model_info.tuning_pipeline_uri,
             model_display_name=model_display_name,
+            learning_rate=learning_rate,
         )

         job = _LanguageModelTuningJob(
@@ -1041,6 +1044,7 @@ def _launch_tuning_job(
     tuning_pipeline_uri: str,
     train_steps: Optional[int] = None,
     model_display_name: Optional[str] = None,
+    learning_rate: Optional[float] = None,
 ) -> aiplatform.PipelineJob:
     output_dir_uri = _generate_tuned_model_dir_uri(model_id=model_id)
     if isinstance(training_data, str):
@@ -1062,6 +1066,7 @@
         train_steps=train_steps,
         tuning_pipeline_uri=tuning_pipeline_uri,
         model_display_name=model_display_name,
+        learning_rate=learning_rate,
     )
     return job

@@ -1071,11 +1076,15 @@ def _launch_tuning_job_on_jsonl_data(
     dataset_name_or_uri: str,
     tuning_pipeline_uri: str,
     train_steps: Optional[int] = None,
+    learning_rate: Optional[float] = None,
     model_display_name: Optional[str] = None,
 ) -> aiplatform.PipelineJob:
     if not model_display_name:
         # Creating a human-readable model display name
-        name = f"{model_id} tuned for {train_steps} steps on "
+        name = f"{model_id} tuned for {train_steps} steps"
+        if learning_rate:
+            name += f" with learning rate {learning_rate}"
+        name += " on "
         # Truncating the start of the dataset URI to keep total length <= 128.
         max_display_name_length = 128
         if len(dataset_name_or_uri + name) <= max_display_name_length:
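To illustrate the naming change in the hunk above, a quick sketch of the display name this branch now produces. All values are hypothetical, and the truncation branch itself is elided from the hunk:

    model_id = "text-bison@001"  # hypothetical
    train_steps = 10
    learning_rate = 0.1
    dataset_name_or_uri = "gs://bucket/data.jsonl"  # hypothetical

    name = f"{model_id} tuned for {train_steps} steps"
    if learning_rate:
        name += f" with learning rate {learning_rate}"
    name += " on "
    max_display_name_length = 128
    if len(dataset_name_or_uri + name) <= max_display_name_length:
        model_display_name = name + dataset_name_or_uri
    # model_display_name ==
    # "text-bison@001 tuned for 10 steps with learning rate 0.1 on gs://bucket/data.jsonl"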
@@ -1095,6 +1104,8 @@
         "large_model_reference": model_id,
         "model_display_name": model_display_name,
     }
+    if learning_rate:
+        pipeline_arguments["learning_rate"] = learning_rate

     if dataset_name_or_uri.startswith("projects/"):
         pipeline_arguments["dataset_name"] = dataset_name_or_uri
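Putting the last two hunks together, a sketch of the arguments the pipeline job ends up with, continuing the hypothetical values above. One design note: because the guard is a truthiness check rather than `is not None`, an explicit learning_rate=0.0 would also be omitted and fall back to the pipeline's default.

    learning_rate = 0.1  # hypothetical, as above

    pipeline_arguments = {
        "train_steps": 10,
        "large_model_reference": "text-bison@001",
        "model_display_name": "text-bison@001 tuned for 10 steps ...",
    }
    if learning_rate:  # truthiness check: None is skipped, but so is an explicit 0.0
        pipeline_arguments["learning_rate"] = learning_rate

    assert pipeline_arguments["learning_rate"] == 0.1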
