Skip to content

Commit

Permalink
fix: LLM - Make tuning use the global staging bucket if specified
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 574646855
  • Loading branch information
Ark-kun authored and copybara-github committed Oct 18, 2023
1 parent 19dd980 commit d9ced10
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 0 deletions.
47 changes: 47 additions & 0 deletions tests/unit/aiplatform/test_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1723,6 +1723,53 @@ def test_tune_text_generation_model_evaluation_with_only_tensorboard(
].runtime_config.parameter_values
assert pipeline_arguments["tensorboard_resource_id"] == tensorboard_name

@pytest.mark.parametrize(
    "job_spec",
    [_TEST_PIPELINE_SPEC_JSON],
)
@pytest.mark.parametrize(
    "mock_request_urlopen",
    ["https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"],
    indirect=True,
)
def test_tune_text_generation_model_staging_bucket(
    self,
    mock_pipeline_service_create,
    mock_pipeline_job_get,
    mock_pipeline_bucket_exists,
    job_spec,
    mock_load_yaml_and_json,
    mock_gcs_from_string,
    mock_gcs_upload,
    mock_request_urlopen,
    mock_get_tuned_model,
):
    """Checks that tuning stages its dataset under the globally configured bucket.

    The SDK is initialised with an explicit ``staging_bucket``; after
    ``tune_model`` runs, the ``dataset_uri`` pipeline parameter handed to the
    (mocked) pipeline service must start with that bucket URI.
    """
    staging_bucket_uri = "gs://test_staging_bucket/path/"
    aiplatform.init(staging_bucket=staging_bucket_uri)

    # Stub response returned by the patched Model Garden lookup below.
    publisher_model = gca_publisher_model.PublisherModel(
        _TEXT_BISON_PUBLISHER_MODEL_DICT
    )

    # NOTE(review): the scraped diff lost its indentation; everything after
    # the patch is kept inside the ``with`` so the patch stays active for the
    # whole scenario.
    with mock.patch.object(
        model_garden_service_client.ModelGardenServiceClient,
        "get_publisher_model",
        return_value=publisher_model,
    ):
        model = language_models.TextGenerationModel.from_pretrained("text-bison@001")

        model.tune_model(
            training_data=_TEST_TEXT_BISON_TRAINING_DF,
            tuning_job_location="europe-west4",
            tuned_model_location="us-central1",
        )

        # ``call_args`` unpacks as (positional args, keyword args).
        _, create_kwargs = mock_pipeline_service_create.call_args
        runtime_parameters = create_kwargs[
            "pipeline_job"
        ].runtime_config.parameter_values
        dataset_uri = runtime_parameters["dataset_uri"]
        assert dataset_uri.startswith(staging_bucket_uri)

@pytest.mark.parametrize(
"job_spec",
[_TEST_PIPELINE_SPEC_JSON],
Expand Down
6 changes: 6 additions & 0 deletions vertexai/language_models/_language_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2534,6 +2534,12 @@ def _cancel(self):


def _get_tuned_models_dir_uri(model_id: str) -> str:
if aiplatform_initializer.global_config.staging_bucket:
return (
aiplatform_initializer.global_config.staging_bucket
+ "/tuned_language_models/"
+ model_id
)
staging_gcs_bucket = (
gcs_utils.create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist()
)
Expand Down

0 comments on commit d9ced10

Please sign in to comment.