diff --git a/aiplatform/snippets/src/main/java/aiplatform/CancelBatchPredictionJobSample.java b/aiplatform/snippets/src/main/java/aiplatform/CancelBatchPredictionJobSample.java index 61931a9fd2e..495f0f88598 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CancelBatchPredictionJobSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CancelBatchPredictionJobSample.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 Google LLC + * Copyright 2021 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,9 +18,9 @@ // [START aiplatform_cancel_batch_prediction_job_sample] -import com.google.cloud.aiplatform.v1beta1.BatchPredictionJobName; -import com.google.cloud.aiplatform.v1beta1.JobServiceClient; -import com.google.cloud.aiplatform.v1beta1.JobServiceSettings; +import com.google.cloud.aiplatform.v1.BatchPredictionJobName; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; import java.io.IOException; public class CancelBatchPredictionJobSample { diff --git a/aiplatform/snippets/src/main/java/aiplatform/CancelDataLabelingJobSample.java b/aiplatform/snippets/src/main/java/aiplatform/CancelDataLabelingJobSample.java index 9483c07e4dc..eb540687edf 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CancelDataLabelingJobSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CancelDataLabelingJobSample.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 Google LLC + * Copyright 2021 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,9 +18,9 @@ // [START aiplatform_cancel_data_labeling_job_sample] -import com.google.cloud.aiplatform.v1beta1.DataLabelingJobName; -import com.google.cloud.aiplatform.v1beta1.JobServiceClient; -import com.google.cloud.aiplatform.v1beta1.JobServiceSettings; +import com.google.cloud.aiplatform.v1.DataLabelingJobName; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; import java.io.IOException; public class CancelDataLabelingJobSample { diff --git a/aiplatform/snippets/src/main/java/aiplatform/CancelTrainingPipelineSample.java b/aiplatform/snippets/src/main/java/aiplatform/CancelTrainingPipelineSample.java index 4dd2902f328..a689ae24625 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CancelTrainingPipelineSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CancelTrainingPipelineSample.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 Google LLC + * Copyright 2021 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,9 +18,9 @@ // [START aiplatform_cancel_training_pipeline_sample] -import com.google.cloud.aiplatform.v1beta1.PipelineServiceClient; -import com.google.cloud.aiplatform.v1beta1.PipelineServiceSettings; -import com.google.cloud.aiplatform.v1beta1.TrainingPipelineName; +import com.google.cloud.aiplatform.v1.PipelineServiceClient; +import com.google.cloud.aiplatform.v1.PipelineServiceSettings; +import com.google.cloud.aiplatform.v1.TrainingPipelineName; import java.io.IOException; public class CancelTrainingPipelineSample { diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobBigquerySample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobBigquerySample.java index 5ccad051aaa..105268f2e8b 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobBigquerySample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobBigquerySample.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 Google LLC + * Copyright 2021 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,13 +17,13 @@ package aiplatform; // [START aiplatform_create_batch_prediction_job_bigquery_sample] -import com.google.cloud.aiplatform.v1beta1.BatchPredictionJob; -import com.google.cloud.aiplatform.v1beta1.BigQueryDestination; -import com.google.cloud.aiplatform.v1beta1.BigQuerySource; -import com.google.cloud.aiplatform.v1beta1.JobServiceClient; -import com.google.cloud.aiplatform.v1beta1.JobServiceSettings; -import com.google.cloud.aiplatform.v1beta1.LocationName; -import com.google.cloud.aiplatform.v1beta1.ModelName; +import com.google.cloud.aiplatform.v1.BatchPredictionJob; +import com.google.cloud.aiplatform.v1.BigQueryDestination; +import com.google.cloud.aiplatform.v1.BigQuerySource; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.ModelName; import com.google.gson.JsonObject; import com.google.protobuf.Value; import com.google.protobuf.util.JsonFormat; @@ -95,8 +95,6 @@ static void createBatchPredictionJobBigquerySample( .setModelParameters(modelParameters) .setInputConfig(inputConfig) .setOutputConfig(outputConfig) - // optional - .setGenerateExplanation(true) .build(); LocationName parent = LocationName.of(project, location); BatchPredictionJob response = client.createBatchPredictionJob(parent, batchPredictionJob); diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobSample.java index cdac97ba47e..12bab04e13b 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobSample.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 Google LLC + * Copyright 2021 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,19 +17,18 @@ package aiplatform; // [START aiplatform_create_batch_prediction_job_sample] -import com.google.cloud.aiplatform.v1beta1.AcceleratorType; -import com.google.cloud.aiplatform.v1beta1.BatchDedicatedResources; -import com.google.cloud.aiplatform.v1beta1.BatchPredictionJob; -import com.google.cloud.aiplatform.v1beta1.GcsDestination; -import com.google.cloud.aiplatform.v1beta1.GcsSource; -import com.google.cloud.aiplatform.v1beta1.JobServiceClient; -import com.google.cloud.aiplatform.v1beta1.JobServiceSettings; -import com.google.cloud.aiplatform.v1beta1.LocationName; -import com.google.cloud.aiplatform.v1beta1.MachineSpec; -import com.google.cloud.aiplatform.v1beta1.ModelName; -import com.google.gson.JsonObject; +import com.google.cloud.aiplatform.util.ValueConverter; +import com.google.cloud.aiplatform.v1.AcceleratorType; +import com.google.cloud.aiplatform.v1.BatchDedicatedResources; +import com.google.cloud.aiplatform.v1.BatchPredictionJob; +import com.google.cloud.aiplatform.v1.GcsDestination; +import com.google.cloud.aiplatform.v1.GcsSource; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.MachineSpec; +import com.google.cloud.aiplatform.v1.ModelName; import com.google.protobuf.Value; -import com.google.protobuf.util.JsonFormat; import java.io.IOException; public class CreateBatchPredictionJobSample { @@ -74,10 +73,7 @@ static void createBatchPredictionJobSample( try (JobServiceClient client = JobServiceClient.create(settings)) { // Passing in an empty Value object for model parameters - JsonObject jsonModelParameters = new JsonObject(); - Value.Builder modelParametersBuilder = Value.newBuilder(); - JsonFormat.parser().merge(jsonModelParameters.toString(), modelParametersBuilder); - Value modelParameters = modelParametersBuilder.build(); + Value modelParameters = ValueConverter.EMPTY_VALUE; GcsSource gcsSource = GcsSource.newBuilder().addUris(gcsSourceUri).build(); BatchPredictionJob.InputConfig inputConfig = diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobTextClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobTextClassificationSample.java new file mode 100644 index 00000000000..ba79bf14b02 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobTextClassificationSample.java @@ -0,0 +1,94 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_batch_prediction_job_text_classification_sample] +import com.google.api.gax.rpc.ApiException; +import com.google.cloud.aiplatform.v1.BatchPredictionJob; +import com.google.cloud.aiplatform.v1.GcsDestination; +import com.google.cloud.aiplatform.v1.GcsSource; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.ModelName; +import java.io.IOException; + +public class CreateBatchPredictionJobTextClassificationSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String location = "us-central1"; + String displayName = "DISPLAY_NAME"; + String modelId = "MODEL_ID"; + String gcsSourceUri = "GCS_SOURCE_URI"; + String gcsDestinationOutputUriPrefix = "GCS_DESTINATION_OUTPUT_URI_PREFIX"; + createBatchPredictionJobTextClassificationSample( + project, location, displayName, modelId, gcsSourceUri, gcsDestinationOutputUriPrefix); + } + + static void createBatchPredictionJobTextClassificationSample( + String project, + String location, + String displayName, + String modelId, + String gcsSourceUri, + String gcsDestinationOutputUriPrefix) + throws IOException { + // The AI Platform services require regional API endpoints. + JobServiceSettings settings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient client = JobServiceClient.create(settings)) { + try { + String modelName = ModelName.of(project, location, modelId).toString(); + GcsSource gcsSource = GcsSource.newBuilder().addUris(gcsSourceUri).build(); + BatchPredictionJob.InputConfig inputConfig = + BatchPredictionJob.InputConfig.newBuilder() + .setInstancesFormat("jsonl") + .setGcsSource(gcsSource) + .build(); + GcsDestination gcsDestination = + GcsDestination.newBuilder().setOutputUriPrefix(gcsDestinationOutputUriPrefix).build(); + BatchPredictionJob.OutputConfig outputConfig = + BatchPredictionJob.OutputConfig.newBuilder() + .setPredictionsFormat("jsonl") + .setGcsDestination(gcsDestination) + .build(); + BatchPredictionJob batchPredictionJob = + BatchPredictionJob.newBuilder() + .setDisplayName(displayName) + .setModel(modelName) + .setInputConfig(inputConfig) + .setOutputConfig(outputConfig) + .build(); + LocationName parent = LocationName.of(project, location); + BatchPredictionJob response = client.createBatchPredictionJob(parent, batchPredictionJob); + System.out.format("response: %s\n", response); + } catch (ApiException ex) { + System.out.format("Exception: %s\n", ex.getLocalizedMessage()); + } + } + } +} + +// [END aiplatform_create_batch_prediction_job_text_classification_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobTextEntityExtractionSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobTextEntityExtractionSample.java new file mode 100644 index 00000000000..e753da2ed04 --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobTextEntityExtractionSample.java @@ -0,0 +1,95 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_batch_prediction_job_text_entity_extraction_sample] +import com.google.api.gax.rpc.ApiException; +import com.google.cloud.aiplatform.v1.BatchPredictionJob; +import com.google.cloud.aiplatform.v1.GcsDestination; +import com.google.cloud.aiplatform.v1.GcsSource; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.ModelName; +import java.io.IOException; + +public class CreateBatchPredictionJobTextEntityExtractionSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String location = "us-central1"; + String displayName = "DISPLAY_NAME"; + String modelId = "MODEL_ID"; + String gcsSourceUri = "GCS_SOURCE_URI"; + String gcsDestinationOutputUriPrefix = "GCS_DESTINATION_OUTPUT_URI_PREFIX"; + createBatchPredictionJobTextEntityExtractionSample( + project, location, displayName, modelId, gcsSourceUri, gcsDestinationOutputUriPrefix); + } + + static void createBatchPredictionJobTextEntityExtractionSample( + String project, + String location, + String displayName, + String modelId, + String gcsSourceUri, + String gcsDestinationOutputUriPrefix) + throws IOException { + // The AI Platform services require regional API endpoints. + JobServiceSettings settings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient client = JobServiceClient.create(settings)) { + try { + String modelName = ModelName.of(project, location, modelId).toString(); + GcsSource gcsSource = GcsSource.newBuilder().addUris(gcsSourceUri).build(); + BatchPredictionJob.InputConfig inputConfig = + BatchPredictionJob.InputConfig.newBuilder() + .setInstancesFormat("jsonl") + .setGcsSource(gcsSource) + .build(); + GcsDestination gcsDestination = + GcsDestination.newBuilder().setOutputUriPrefix(gcsDestinationOutputUriPrefix).build(); + BatchPredictionJob.OutputConfig outputConfig = + BatchPredictionJob.OutputConfig.newBuilder() + .setPredictionsFormat("jsonl") + .setGcsDestination(gcsDestination) + .build(); + BatchPredictionJob batchPredictionJob = + BatchPredictionJob.newBuilder() + .setDisplayName(displayName) + .setModel(modelName) + .setInputConfig(inputConfig) + .setOutputConfig(outputConfig) + .build(); + LocationName parent = LocationName.of(project, location); + BatchPredictionJob response = client.createBatchPredictionJob(parent, batchPredictionJob); + System.out.format("response: %s\n", response); + System.out.format("\tname:%s\n", response.getName()); + } catch (ApiException ex) { + System.out.format("Exception: %s\n", ex.getLocalizedMessage()); + } + } + } +} + +// [END aiplatform_create_batch_prediction_job_text_entity_extraction_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobTextSentimentAnalysisSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobTextSentimentAnalysisSample.java new file mode 100644 index 00000000000..8191618c9fe --- /dev/null +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobTextSentimentAnalysisSample.java @@ -0,0 +1,94 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +// [START aiplatform_create_batch_prediction_job_text_sentiment_analysis_sample] +import com.google.api.gax.rpc.ApiException; +import com.google.cloud.aiplatform.v1.BatchPredictionJob; +import com.google.cloud.aiplatform.v1.GcsDestination; +import com.google.cloud.aiplatform.v1.GcsSource; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.ModelName; +import java.io.IOException; + +public class CreateBatchPredictionJobTextSentimentAnalysisSample { + + public static void main(String[] args) throws IOException { + // TODO(developer): Replace these variables before running the sample. + String project = "PROJECT"; + String location = "us-central1"; + String displayName = "DISPLAY_NAME"; + String modelId = "MODEL_ID"; + String gcsSourceUri = "GCS_SOURCE_URI"; + String gcsDestinationOutputUriPrefix = "GCS_DESTINATION_OUTPUT_URI_PREFIX"; + createBatchPredictionJobTextSentimentAnalysisSample( + project, location, displayName, modelId, gcsSourceUri, gcsDestinationOutputUriPrefix); + } + + static void createBatchPredictionJobTextSentimentAnalysisSample( + String project, + String location, + String displayName, + String modelId, + String gcsSourceUri, + String gcsDestinationOutputUriPrefix) + throws IOException { + // The AI Platform services require regional API endpoints. + JobServiceSettings settings = + JobServiceSettings.newBuilder() + .setEndpoint("us-central1-aiplatform.googleapis.com:443") + .build(); + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (JobServiceClient client = JobServiceClient.create(settings)) { + try { + String modelName = ModelName.of(project, location, modelId).toString(); + GcsSource gcsSource = GcsSource.newBuilder().addUris(gcsSourceUri).build(); + BatchPredictionJob.InputConfig inputConfig = + BatchPredictionJob.InputConfig.newBuilder() + .setInstancesFormat("jsonl") + .setGcsSource(gcsSource) + .build(); + GcsDestination gcsDestination = + GcsDestination.newBuilder().setOutputUriPrefix(gcsDestinationOutputUriPrefix).build(); + BatchPredictionJob.OutputConfig outputConfig = + BatchPredictionJob.OutputConfig.newBuilder() + .setPredictionsFormat("jsonl") + .setGcsDestination(gcsDestination) + .build(); + BatchPredictionJob batchPredictionJob = + BatchPredictionJob.newBuilder() + .setDisplayName(displayName) + .setModel(modelName) + .setInputConfig(inputConfig) + .setOutputConfig(outputConfig) + .build(); + LocationName parent = LocationName.of(project, location); + BatchPredictionJob response = client.createBatchPredictionJob(parent, batchPredictionJob); + System.out.format("response: %s\n", response); + } catch (ApiException ex) { + System.out.format("Exception: %s\n", ex.getLocalizedMessage()); + } + } + } +} + +// [END aiplatform_create_batch_prediction_job_text_sentiment_analysis_sample] diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobVideoActionRecognitionSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobVideoActionRecognitionSample.java index b255b625ccd..0d0f68e5418 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobVideoActionRecognitionSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobVideoActionRecognitionSample.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 Google LLC + * Copyright 2021 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,16 +17,15 @@ package aiplatform; // [START aiplatform_create_batch_prediction_job_video_action_recognition_sample] -import com.google.cloud.aiplatform.v1beta1.BatchPredictionJob; -import com.google.cloud.aiplatform.v1beta1.GcsDestination; -import com.google.cloud.aiplatform.v1beta1.GcsSource; -import com.google.cloud.aiplatform.v1beta1.JobServiceClient; -import com.google.cloud.aiplatform.v1beta1.JobServiceSettings; -import com.google.cloud.aiplatform.v1beta1.LocationName; -import com.google.cloud.aiplatform.v1beta1.ModelName; -import com.google.gson.JsonObject; +import com.google.cloud.aiplatform.util.ValueConverter; +import com.google.cloud.aiplatform.v1.BatchPredictionJob; +import com.google.cloud.aiplatform.v1.GcsDestination; +import com.google.cloud.aiplatform.v1.GcsSource; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.ModelName; import com.google.protobuf.Value; -import com.google.protobuf.util.JsonFormat; import java.io.IOException; public class CreateBatchPredictionJobVideoActionRecognitionSample { @@ -59,11 +58,7 @@ static void createBatchPredictionJobVideoActionRecognitionSample( // once, and can be reused for multiple requests. After completing all of your requests, call // the "close" method on the client to safely clean up any remaining background resources. try (JobServiceClient client = JobServiceClient.create(settings)) { - JsonObject jsonModelParameters = new JsonObject(); - jsonModelParameters.addProperty("confidenceThreshold", 0.5); - Value.Builder modelParametersBuilder = Value.newBuilder(); - JsonFormat.parser().merge(jsonModelParameters.toString(), modelParametersBuilder); - Value modelParameters = modelParametersBuilder.build(); + Value modelParameters = ValueConverter.EMPTY_VALUE; GcsSource gcsSource = GcsSource.newBuilder().addUris(gcsSourceUri).build(); BatchPredictionJob.InputConfig inputConfig = BatchPredictionJob.InputConfig.newBuilder() diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobVideoClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobVideoClassificationSample.java index a89f2bfe3d5..905ab46b7c5 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobVideoClassificationSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobVideoClassificationSample.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 Google LLC + * Copyright 2021 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,26 +18,27 @@ // [START aiplatform_create_batch_prediction_job_video_classification_sample] -import com.google.cloud.aiplatform.v1beta1.BatchDedicatedResources; -import com.google.cloud.aiplatform.v1beta1.BatchPredictionJob; -import com.google.cloud.aiplatform.v1beta1.BatchPredictionJob.InputConfig; -import com.google.cloud.aiplatform.v1beta1.BatchPredictionJob.OutputConfig; -import com.google.cloud.aiplatform.v1beta1.BatchPredictionJob.OutputInfo; -import com.google.cloud.aiplatform.v1beta1.BigQueryDestination; -import com.google.cloud.aiplatform.v1beta1.BigQuerySource; -import com.google.cloud.aiplatform.v1beta1.CompletionStats; -import com.google.cloud.aiplatform.v1beta1.GcsDestination; -import com.google.cloud.aiplatform.v1beta1.GcsSource; -import com.google.cloud.aiplatform.v1beta1.JobServiceClient; -import com.google.cloud.aiplatform.v1beta1.JobServiceSettings; -import com.google.cloud.aiplatform.v1beta1.LocationName; -import com.google.cloud.aiplatform.v1beta1.MachineSpec; -import com.google.cloud.aiplatform.v1beta1.ManualBatchTuningParameters; -import com.google.cloud.aiplatform.v1beta1.ModelName; -import com.google.cloud.aiplatform.v1beta1.ResourcesConsumed; +import com.google.cloud.aiplatform.util.ValueConverter; +import com.google.cloud.aiplatform.v1.BatchDedicatedResources; +import com.google.cloud.aiplatform.v1.BatchPredictionJob; +import com.google.cloud.aiplatform.v1.BatchPredictionJob.InputConfig; +import com.google.cloud.aiplatform.v1.BatchPredictionJob.OutputConfig; +import com.google.cloud.aiplatform.v1.BatchPredictionJob.OutputInfo; +import com.google.cloud.aiplatform.v1.BigQueryDestination; +import com.google.cloud.aiplatform.v1.BigQuerySource; +import com.google.cloud.aiplatform.v1.CompletionStats; +import com.google.cloud.aiplatform.v1.GcsDestination; +import com.google.cloud.aiplatform.v1.GcsSource; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.MachineSpec; +import com.google.cloud.aiplatform.v1.ManualBatchTuningParameters; +import com.google.cloud.aiplatform.v1.ModelName; +import com.google.cloud.aiplatform.v1.ResourcesConsumed; +import com.google.cloud.aiplatform.v1.schema.predict.params.VideoClassificationPredictionParams; import com.google.protobuf.Any; import com.google.protobuf.Value; -import com.google.protobuf.util.JsonFormat; import com.google.rpc.Status; import java.io.IOException; import java.util.List; @@ -75,11 +76,16 @@ static void createBatchPredictionJobVideoClassification( String location = "us-central1"; LocationName locationName = LocationName.of(project, location); - String jsonString = - "{\"confidenceThreshold\": 0.5,\"maxPredictions\": 10000,\"segmentClassification\":" - + " True,\"shotClassification\": True,\"oneSecIntervalClassification\": True}"; - Value.Builder modelParameters = Value.newBuilder(); - JsonFormat.parser().merge(jsonString, modelParameters); + VideoClassificationPredictionParams modelParamsObj = + VideoClassificationPredictionParams.newBuilder() + .setConfidenceThreshold(((float) 0.5)) + .setMaxPredictions(10000) + .setSegmentClassification(true) + .setShotClassification(true) + .setOneSecIntervalClassification(true) + .build(); + + Value modelParameters = ValueConverter.toValue(modelParamsObj); ModelName modelName = ModelName.of(project, location, modelId); GcsSource.Builder gcsSource = GcsSource.newBuilder(); @@ -112,8 +118,6 @@ static void createBatchPredictionJobVideoClassification( System.out.format("\tModel %s\n", batchPredictionJobResponse.getModel()); System.out.format( "\tModel Parameters: %s\n", batchPredictionJobResponse.getModelParameters()); - System.out.format( - "\tGenerate Explanation: %s\n", batchPredictionJobResponse.getGenerateExplanation()); System.out.format("\tState: %s\n", batchPredictionJobResponse.getState()); System.out.format("\tCreate Time: %s\n", batchPredictionJobResponse.getCreateTime()); diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobVideoObjectTrackingSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobVideoObjectTrackingSample.java index da0550b2607..860bc8da82a 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobVideoObjectTrackingSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateBatchPredictionJobVideoObjectTrackingSample.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 Google LLC + * Copyright 2021 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,26 +18,27 @@ // [START aiplatform_create_batch_prediction_job_video_object_tracking_sample] -import com.google.cloud.aiplatform.v1beta1.BatchDedicatedResources; -import com.google.cloud.aiplatform.v1beta1.BatchPredictionJob; -import com.google.cloud.aiplatform.v1beta1.BatchPredictionJob.InputConfig; -import com.google.cloud.aiplatform.v1beta1.BatchPredictionJob.OutputConfig; -import com.google.cloud.aiplatform.v1beta1.BatchPredictionJob.OutputInfo; -import com.google.cloud.aiplatform.v1beta1.BigQueryDestination; -import com.google.cloud.aiplatform.v1beta1.BigQuerySource; -import com.google.cloud.aiplatform.v1beta1.CompletionStats; -import com.google.cloud.aiplatform.v1beta1.GcsDestination; -import com.google.cloud.aiplatform.v1beta1.GcsSource; -import com.google.cloud.aiplatform.v1beta1.JobServiceClient; -import com.google.cloud.aiplatform.v1beta1.JobServiceSettings; -import com.google.cloud.aiplatform.v1beta1.LocationName; -import com.google.cloud.aiplatform.v1beta1.MachineSpec; -import com.google.cloud.aiplatform.v1beta1.ManualBatchTuningParameters; -import com.google.cloud.aiplatform.v1beta1.ModelName; -import com.google.cloud.aiplatform.v1beta1.ResourcesConsumed; +import com.google.cloud.aiplatform.util.ValueConverter; +import com.google.cloud.aiplatform.v1.BatchDedicatedResources; +import com.google.cloud.aiplatform.v1.BatchPredictionJob; +import com.google.cloud.aiplatform.v1.BatchPredictionJob.InputConfig; +import com.google.cloud.aiplatform.v1.BatchPredictionJob.OutputConfig; +import com.google.cloud.aiplatform.v1.BatchPredictionJob.OutputInfo; +import com.google.cloud.aiplatform.v1.BigQueryDestination; +import com.google.cloud.aiplatform.v1.BigQuerySource; +import com.google.cloud.aiplatform.v1.CompletionStats; +import com.google.cloud.aiplatform.v1.GcsDestination; +import com.google.cloud.aiplatform.v1.GcsSource; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.MachineSpec; +import com.google.cloud.aiplatform.v1.ManualBatchTuningParameters; +import com.google.cloud.aiplatform.v1.ModelName; +import com.google.cloud.aiplatform.v1.ResourcesConsumed; +import com.google.cloud.aiplatform.v1.schema.predict.params.VideoObjectTrackingPredictionParams; import com.google.protobuf.Any; import com.google.protobuf.Value; -import com.google.protobuf.util.JsonFormat; import com.google.rpc.Status; import java.io.IOException; import java.util.List; @@ -77,9 +78,12 @@ static void batchPredictionJobVideoObjectTracking( LocationName locationName = LocationName.of(project, location); ModelName modelName = ModelName.of(project, location, modelId); - String jsonString = "{\"confidenceThreshold\": 0.0}"; - Value.Builder modelParameters = Value.newBuilder(); - JsonFormat.parser().merge(jsonString, modelParameters); + VideoObjectTrackingPredictionParams modelParamsObj = + VideoObjectTrackingPredictionParams.newBuilder() + .setConfidenceThreshold(((float) 0.5)) + .build(); + + Value modelParameters = ValueConverter.toValue(modelParamsObj); GcsSource.Builder gcsSource = GcsSource.newBuilder(); gcsSource.addUris(gcsSourceUri); @@ -111,8 +115,6 @@ static void batchPredictionJobVideoObjectTracking( System.out.format("\tModel %s\n", batchPredictionJobResponse.getModel()); System.out.format( "\tModel Parameters: %s\n", batchPredictionJobResponse.getModelParameters()); - System.out.format( - "\tGenerate Explanation: %s\n", batchPredictionJobResponse.getGenerateExplanation()); System.out.format("\tState: %s\n", batchPredictionJobResponse.getState()); System.out.format("\tCreate Time: %s\n", batchPredictionJobResponse.getCreateTime()); diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobActiveLearningSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobActiveLearningSample.java index d9f069e408f..1a0076fbc4b 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobActiveLearningSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobActiveLearningSample.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 Google LLC + * Copyright 2021 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,12 +17,12 @@ package aiplatform; // [START aiplatform_create_data_labeling_job_active_learning_sample] -import com.google.cloud.aiplatform.v1beta1.ActiveLearningConfig; -import com.google.cloud.aiplatform.v1beta1.DataLabelingJob; -import com.google.cloud.aiplatform.v1beta1.DatasetName; -import com.google.cloud.aiplatform.v1beta1.JobServiceClient; -import com.google.cloud.aiplatform.v1beta1.JobServiceSettings; -import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.cloud.aiplatform.v1.ActiveLearningConfig; +import com.google.cloud.aiplatform.v1.DataLabelingJob; +import com.google.cloud.aiplatform.v1.DatasetName; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; import com.google.gson.JsonArray; import com.google.gson.JsonObject; import com.google.protobuf.Value; diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobImageSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobImageSample.java index 5ea70a42fef..8d9dced5ec7 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobImageSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobImageSample.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 Google LLC + * Copyright 2021 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,11 +18,11 @@ // [START aiplatform_create_data_labeling_job_image_sample] -import com.google.cloud.aiplatform.v1beta1.DataLabelingJob; -import com.google.cloud.aiplatform.v1beta1.DatasetName; -import com.google.cloud.aiplatform.v1beta1.JobServiceClient; -import com.google.cloud.aiplatform.v1beta1.JobServiceSettings; -import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.cloud.aiplatform.v1.DataLabelingJob; +import com.google.cloud.aiplatform.v1.DatasetName; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; import com.google.protobuf.Value; import com.google.protobuf.util.JsonFormat; import com.google.type.Money; diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobSample.java index 2858a7e80d9..a677169d7bc 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobSample.java @@ -18,11 +18,11 @@ // [START aiplatform_create_data_labeling_job_sample] -import com.google.cloud.aiplatform.v1beta1.DataLabelingJob; -import com.google.cloud.aiplatform.v1beta1.DatasetName; -import com.google.cloud.aiplatform.v1beta1.JobServiceClient; -import com.google.cloud.aiplatform.v1beta1.JobServiceSettings; -import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.cloud.aiplatform.v1.DataLabelingJob; +import com.google.cloud.aiplatform.v1.DatasetName; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; import com.google.protobuf.Value; import com.google.protobuf.util.JsonFormat; import com.google.type.Money; diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobSpecialistPoolSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobSpecialistPoolSample.java index 04a3c421634..528e4b2d0f5 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobSpecialistPoolSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobSpecialistPoolSample.java @@ -17,12 +17,12 @@ package aiplatform; // [START aiplatform_create_data_labeling_job_specialist_pool_sample] -import com.google.cloud.aiplatform.v1beta1.DataLabelingJob; -import com.google.cloud.aiplatform.v1beta1.DatasetName; -import com.google.cloud.aiplatform.v1beta1.JobServiceClient; -import com.google.cloud.aiplatform.v1beta1.JobServiceSettings; -import com.google.cloud.aiplatform.v1beta1.LocationName; -import com.google.cloud.aiplatform.v1beta1.SpecialistPoolName; +import com.google.cloud.aiplatform.v1.DataLabelingJob; +import com.google.cloud.aiplatform.v1.DatasetName; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; +import com.google.cloud.aiplatform.v1.SpecialistPoolName; import com.google.gson.JsonArray; import com.google.gson.JsonObject; import com.google.protobuf.Value; @@ -78,8 +78,8 @@ static void createDataLabelingJobSpecialistPoolSample( Value inputs = inputsBuilder.build(); String datasetName = DatasetName.of(project, location, dataset).toString(); - String specialistPoolName = SpecialistPoolName.of(project, location, specialistPool) - .toString(); + String specialistPoolName = + SpecialistPoolName.of(project, location, specialistPool).toString(); DataLabelingJob dataLabelingJob = DataLabelingJob.newBuilder() diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobVideoSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobVideoSample.java index ae0e451ba52..cabf2399735 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobVideoSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateDataLabelingJobVideoSample.java @@ -18,11 +18,11 @@ // [START aiplatform_create_data_labeling_job_video_sample] -import com.google.cloud.aiplatform.v1beta1.DataLabelingJob; -import com.google.cloud.aiplatform.v1beta1.DatasetName; -import com.google.cloud.aiplatform.v1beta1.JobServiceClient; -import com.google.cloud.aiplatform.v1beta1.JobServiceSettings; -import com.google.cloud.aiplatform.v1beta1.LocationName; +import com.google.cloud.aiplatform.v1.DataLabelingJob; +import com.google.cloud.aiplatform.v1.DatasetName; +import com.google.cloud.aiplatform.v1.JobServiceClient; +import com.google.cloud.aiplatform.v1.JobServiceSettings; +import com.google.cloud.aiplatform.v1.LocationName; import com.google.protobuf.Value; import com.google.protobuf.util.JsonFormat; import com.google.type.Money; diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineCustomTrainingManagedDatasetSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineCustomTrainingManagedDatasetSample.java index 739d15cf8ee..ea624de5b9d 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineCustomTrainingManagedDatasetSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineCustomTrainingManagedDatasetSample.java @@ -113,7 +113,7 @@ static void createTrainingPipelineCustomTrainingManagedDatasetSample( .build(); GcsDestination gcsDestination = GcsDestination.newBuilder().setOutputUriPrefix(baseOutputUriPrefix).build(); - + // input_data_config InputDataConfig inputDataConfig = InputDataConfig.newBuilder() diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSample.java index 9136cec90c4..78181e448ba 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageObjectDetectionSample.java @@ -39,11 +39,8 @@ import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution; import com.google.cloud.aiplatform.v1beta1.TimestampSplit; import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; -import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageClassification; import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageObjectDetectionInputs; import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageObjectDetectionInputs.ModelType; -import com.google.protobuf.Value; -import com.google.protobuf.util.JsonFormat; import com.google.rpc.Status; import java.io.IOException; diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularClassificationSample.java index d37bbf6eb6b..0ee0392dbea 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularClassificationSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularClassificationSample.java @@ -41,8 +41,6 @@ import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTablesInputs; import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation; import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation.AutoTransformation; -import com.google.protobuf.Value; -import com.google.protobuf.util.JsonFormat; import com.google.rpc.Status; import java.io.IOException; import java.util.ArrayList; @@ -55,15 +53,11 @@ public static void main(String[] args) throws IOException { String modelDisplayName = "YOUR_DATASET_DISPLAY_NAME"; String datasetId = "YOUR_DATASET_ID"; String targetColumn = "TARGET_COLUMN"; - createTrainingPipelineTableClassification( - project, modelDisplayName, datasetId, targetColumn); + createTrainingPipelineTableClassification(project, modelDisplayName, datasetId, targetColumn); } static void createTrainingPipelineTableClassification( - String project, - String modelDisplayName, - String datasetId, - String targetColumn) + String project, String modelDisplayName, String datasetId, String targetColumn) throws IOException { PipelineServiceSettings pipelineServiceSettings = PipelineServiceSettings.newBuilder() @@ -81,18 +75,22 @@ static void createTrainingPipelineTableClassification( "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml"; // Set the columns used for training and their data types - Transformation transformation1 = Transformation.newBuilder() - .setAuto(AutoTransformation.newBuilder().setColumnName("sepal_width").build()) - .build(); - Transformation transformation2 = Transformation.newBuilder() - .setAuto(AutoTransformation.newBuilder().setColumnName("sepal_length").build()) - .build(); - Transformation transformation3 = Transformation.newBuilder() - .setAuto(AutoTransformation.newBuilder().setColumnName("petal_length").build()) - .build(); - Transformation transformation4 = Transformation.newBuilder() - .setAuto(AutoTransformation.newBuilder().setColumnName("petal_width").build()) - .build(); + Transformation transformation1 = + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("sepal_width").build()) + .build(); + Transformation transformation2 = + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("sepal_length").build()) + .build(); + Transformation transformation3 = + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("petal_length").build()) + .build(); + Transformation transformation4 = + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("petal_width").build()) + .build(); ArrayList transformationArrayList = new ArrayList<>(); transformationArrayList.add(transformation1); diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularRegressionSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularRegressionSample.java index ce5fff48089..f9f6ade398d 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularRegressionSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTabularRegressionSample.java @@ -38,13 +38,10 @@ import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution; import com.google.cloud.aiplatform.v1beta1.TimestampSplit; import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; -import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTables; import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTablesInputs; import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation; import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation.AutoTransformation; import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation.TimestampTransformation; -import com.google.protobuf.Value; -import com.google.protobuf.util.JsonFormat; import com.google.rpc.Status; import java.io.IOException; import java.util.ArrayList; @@ -57,15 +54,11 @@ public static void main(String[] args) throws IOException { String modelDisplayName = "YOUR_DATASET_DISPLAY_NAME"; String datasetId = "YOUR_DATASET_ID"; String targetColumn = "TARGET_COLUMN"; - createTrainingPipelineTableRegression( - project, modelDisplayName, datasetId, targetColumn); + createTrainingPipelineTableRegression(project, modelDisplayName, datasetId, targetColumn); } static void createTrainingPipelineTableRegression( - String project, - String modelDisplayName, - String datasetId, - String targetColumn) + String project, String modelDisplayName, String datasetId, String targetColumn) throws IOException { PipelineServiceSettings pipelineServiceSettings = PipelineServiceSettings.newBuilder() @@ -84,79 +77,106 @@ static void createTrainingPipelineTableRegression( // Set the columns used for training and their data types ArrayList tranformations = new ArrayList<>(); - tranformations.add(Transformation.newBuilder() - .setAuto(AutoTransformation.newBuilder().setColumnName("STRING_5000unique_NULLABLE")) - .build()); - tranformations.add(Transformation.newBuilder() - .setAuto(AutoTransformation.newBuilder().setColumnName("INTEGER_5000unique_NULLABLE")) - .build()); - tranformations.add(Transformation.newBuilder() - .setAuto(AutoTransformation.newBuilder().setColumnName("FLOAT_5000unique_NULLABLE")) - .build()); - tranformations.add(Transformation.newBuilder() - .setAuto(AutoTransformation.newBuilder().setColumnName("FLOAT_5000unique_REPEATED")) - .build()); - tranformations.add(Transformation.newBuilder() - .setAuto(AutoTransformation.newBuilder().setColumnName("NUMERIC_5000unique_NULLABLE")) - .build()); - tranformations.add(Transformation.newBuilder() - .setAuto(AutoTransformation.newBuilder().setColumnName("BOOLEAN_2unique_NULLABLE")) - .build()); - tranformations.add(Transformation.newBuilder() - .setTimestamp(TimestampTransformation.newBuilder() - .setColumnName("TIMESTAMP_1unique_NULLABLE") - .setInvalidValuesAllowed(true)) - .build()); - tranformations.add(Transformation.newBuilder() - .setAuto(AutoTransformation.newBuilder().setColumnName("DATE_1unique_NULLABLE")) - .build()); - tranformations.add(Transformation.newBuilder() - .setAuto(AutoTransformation.newBuilder().setColumnName("TIME_1unique_NULLABLE")) - .build()); - tranformations.add(Transformation.newBuilder() - .setTimestamp(TimestampTransformation.newBuilder() - .setColumnName("DATETIME_1unique_NULLABLE") - .setInvalidValuesAllowed(true)) - .build()); - tranformations.add(Transformation.newBuilder() - .setAuto(AutoTransformation.newBuilder() - .setColumnName("STRUCT_NULLABLE.STRING_5000unique_NULLABLE")) - .build()); - tranformations.add(Transformation.newBuilder() - .setAuto(AutoTransformation.newBuilder() - .setColumnName("STRUCT_NULLABLE.INTEGER_5000unique_NULLABLE")) - .build()); - tranformations.add(Transformation.newBuilder() - .setAuto(AutoTransformation.newBuilder() - .setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_NULLABLE")) - .build()); - tranformations.add(Transformation.newBuilder() - .setAuto(AutoTransformation.newBuilder() - .setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_REQUIRED")) - .build()); - tranformations.add(Transformation.newBuilder() - .setAuto(AutoTransformation.newBuilder() - .setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_REPEATED")) - .build()); - tranformations.add(Transformation.newBuilder() - .setAuto(AutoTransformation.newBuilder() - .setColumnName("STRUCT_NULLABLE.NUMERIC_5000unique_NULLABLE")) - .build()); - tranformations.add(Transformation.newBuilder() - .setAuto(AutoTransformation.newBuilder() - .setColumnName("STRUCT_NULLABLE.TIMESTAMP_1unique_NULLABLE")) - .build()); - - AutoMlTablesInputs trainingTaskInputs = AutoMlTablesInputs.newBuilder() - .addAllTransformations(tranformations) - .setTargetColumn(targetColumn) - .setPredictionType("regression") - .setTrainBudgetMilliNodeHours(8000) - .setDisableEarlyStopping(false) - // supported regression optimisation objectives: minimize-rmse, - // minimize-mae, minimize-rmsle - .setOptimizationObjective("minimize-rmse") - .build(); + tranformations.add( + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("STRING_5000unique_NULLABLE")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("INTEGER_5000unique_NULLABLE")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("FLOAT_5000unique_NULLABLE")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("FLOAT_5000unique_REPEATED")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("NUMERIC_5000unique_NULLABLE")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("BOOLEAN_2unique_NULLABLE")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setTimestamp( + TimestampTransformation.newBuilder() + .setColumnName("TIMESTAMP_1unique_NULLABLE") + .setInvalidValuesAllowed(true)) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("DATE_1unique_NULLABLE")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto(AutoTransformation.newBuilder().setColumnName("TIME_1unique_NULLABLE")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setTimestamp( + TimestampTransformation.newBuilder() + .setColumnName("DATETIME_1unique_NULLABLE") + .setInvalidValuesAllowed(true)) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto( + AutoTransformation.newBuilder() + .setColumnName("STRUCT_NULLABLE.STRING_5000unique_NULLABLE")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto( + AutoTransformation.newBuilder() + .setColumnName("STRUCT_NULLABLE.INTEGER_5000unique_NULLABLE")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto( + AutoTransformation.newBuilder() + .setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_NULLABLE")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto( + AutoTransformation.newBuilder() + .setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_REQUIRED")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto( + AutoTransformation.newBuilder() + .setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_REPEATED")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto( + AutoTransformation.newBuilder() + .setColumnName("STRUCT_NULLABLE.NUMERIC_5000unique_NULLABLE")) + .build()); + tranformations.add( + Transformation.newBuilder() + .setAuto( + AutoTransformation.newBuilder() + .setColumnName("STRUCT_NULLABLE.TIMESTAMP_1unique_NULLABLE")) + .build()); + + AutoMlTablesInputs trainingTaskInputs = + AutoMlTablesInputs.newBuilder() + .addAllTransformations(tranformations) + .setTargetColumn(targetColumn) + .setPredictionType("regression") + .setTrainBudgetMilliNodeHours(8000) + .setDisableEarlyStopping(false) + // supported regression optimisation objectives: minimize-rmse, + // minimize-mae, minimize-rmsle + .setOptimizationObjective("minimize-rmse") + .build(); FractionSplit fractionSplit = FractionSplit.newBuilder() diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextClassificationSample.java index 194f2a71dfb..dadd642c26a 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextClassificationSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextClassificationSample.java @@ -77,9 +77,7 @@ static void createTrainingPipelineTextClassificationSample( LocationName locationName = LocationName.of(project, location); AutoMlTextClassificationInputs trainingTaskInputs = - AutoMlTextClassificationInputs.newBuilder() - .setMultiLabel(false) - .build(); + AutoMlTextClassificationInputs.newBuilder().setMultiLabel(false).build(); InputDataConfig trainingInputDataConfig = InputDataConfig.newBuilder().setDatasetId(datasetId).build(); diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSample.java index 577dc865a95..c62606c9886 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineTextEntityExtractionSample.java @@ -39,8 +39,6 @@ import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution; import com.google.cloud.aiplatform.v1beta1.TimestampSplit; import com.google.cloud.aiplatform.v1beta1.TrainingPipeline; -import com.google.protobuf.Value; -import com.google.protobuf.util.JsonFormat; import com.google.rpc.Status; import java.io.IOException; diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoActionRecognitionSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoActionRecognitionSample.java index bd7320813b8..9b3d83e7738 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoActionRecognitionSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoActionRecognitionSample.java @@ -41,10 +41,7 @@ public static void main(String[] args) throws IOException { } static void createTrainingPipelineVideoActionRecognitionSample( - String project, - String displayName, - String datasetId, - String modelDisplayName) + String project, String displayName, String datasetId, String modelDisplayName) throws IOException { PipelineServiceSettings settings = PipelineServiceSettings.newBuilder() @@ -57,9 +54,7 @@ static void createTrainingPipelineVideoActionRecognitionSample( // the "close" method on the client to safely clean up any remaining background resources. try (PipelineServiceClient client = PipelineServiceClient.create(settings)) { AutoMlVideoActionRecognitionInputs trainingTaskInputs = - AutoMlVideoActionRecognitionInputs.newBuilder() - .setModelType(ModelType.CLOUD) - .build(); + AutoMlVideoActionRecognitionInputs.newBuilder().setModelType(ModelType.CLOUD).build(); InputDataConfig inputDataConfig = InputDataConfig.newBuilder().setDatasetId(datasetId).build(); diff --git a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSample.java b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSample.java index 5cc64fe719b..03cf2a522c4 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/CreateTrainingPipelineVideoObjectTrackingSample.java @@ -69,9 +69,7 @@ static void createTrainingPipelineVideoObjectTracking( LocationName locationName = LocationName.of(project, location); AutoMlVideoObjectTrackingInputs trainingTaskInputs = - AutoMlVideoObjectTrackingInputs.newBuilder() - .setModelType(ModelType.CLOUD) - .build(); + AutoMlVideoObjectTrackingInputs.newBuilder().setModelType(ModelType.CLOUD).build(); InputDataConfig inputDataConfig = InputDataConfig.newBuilder().setDatasetId(datasetId).build(); diff --git a/aiplatform/snippets/src/main/java/aiplatform/PredictImageClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/PredictImageClassificationSample.java index 15519675f34..7a6961b725b 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/PredictImageClassificationSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/PredictImageClassificationSample.java @@ -64,22 +64,20 @@ static void predictImageClassification(String project, String fileName, String e String content = new String(contents, StandardCharsets.UTF_8); ImageClassificationPredictionInstance predictionInstance = - ImageClassificationPredictionInstance.newBuilder() - .setContent(content) - .build(); + ImageClassificationPredictionInstance.newBuilder().setContent(content).build(); List instances = new ArrayList<>(); instances.add(ValueConverter.toValue(predictionInstance)); ImageClassificationPredictionParams predictionParams = ImageClassificationPredictionParams.newBuilder() - .setConfidenceThreshold((float) 0.5) - .setMaxPredictions(5) - .build(); + .setConfidenceThreshold((float) 0.5) + .setMaxPredictions(5) + .build(); PredictResponse predictResponse = - predictionServiceClient.predict(endpointName, instances, - ValueConverter.toValue(predictionParams)); + predictionServiceClient.predict( + endpointName, instances, ValueConverter.toValue(predictionParams)); System.out.println("Predict Image Classification Response"); System.out.format("\tDeployed Model Id: %s\n", predictResponse.getDeployedModelId()); diff --git a/aiplatform/snippets/src/main/java/aiplatform/PredictImageObjectDetectionSample.java b/aiplatform/snippets/src/main/java/aiplatform/PredictImageObjectDetectionSample.java index da8d4b8d737..7d5a8f1ccca 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/PredictImageObjectDetectionSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/PredictImageObjectDetectionSample.java @@ -28,7 +28,6 @@ import com.google.cloud.aiplatform.v1beta1.schema.predict.params.ImageObjectDetectionPredictionParams; import com.google.cloud.aiplatform.v1beta1.schema.predict.prediction.ImageObjectDetectionPredictionResult; import com.google.protobuf.Value; -import com.google.protobuf.util.JsonFormat; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.nio.file.Files; @@ -71,9 +70,7 @@ static void predictImageObjectDetection(String project, String fileName, String .build(); ImageObjectDetectionPredictionInstance instance = - ImageObjectDetectionPredictionInstance.newBuilder() - .setContent(content) - .build(); + ImageObjectDetectionPredictionInstance.newBuilder().setContent(content).build(); List instances = new ArrayList<>(); instances.add(ValueConverter.toValue(instance)); @@ -90,8 +87,8 @@ static void predictImageObjectDetection(String project, String fileName, String ImageObjectDetectionPredictionResult.newBuilder(); ImageObjectDetectionPredictionResult result = - (ImageObjectDetectionPredictionResult) ValueConverter - .fromValue(resultBuilder, prediction); + (ImageObjectDetectionPredictionResult) + ValueConverter.fromValue(resultBuilder, prediction); for (int i = 0; i < result.getIdsCount(); i++) { System.out.printf("\tDisplay name: %s\n", result.getDisplayNames(i)); diff --git a/aiplatform/snippets/src/main/java/aiplatform/PredictTabularClassificationSample.java b/aiplatform/snippets/src/main/java/aiplatform/PredictTabularClassificationSample.java index f5c42327c6e..c6fc5a8cb07 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/PredictTabularClassificationSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/PredictTabularClassificationSample.java @@ -70,8 +70,8 @@ static void predictTabularClassification(String instance, String project, String TabularClassificationPredictionResult.Builder resultBuilder = TabularClassificationPredictionResult.newBuilder(); TabularClassificationPredictionResult result = - (TabularClassificationPredictionResult) ValueConverter - .fromValue(resultBuilder, prediction); + (TabularClassificationPredictionResult) + ValueConverter.fromValue(resultBuilder, prediction); for (int i = 0; i < result.getClassesCount(); i++) { System.out.printf("\tClass: %s", result.getClasses(i)); diff --git a/aiplatform/snippets/src/main/java/aiplatform/PredictTabularRegressionSample.java b/aiplatform/snippets/src/main/java/aiplatform/PredictTabularRegressionSample.java index fd5ec9e68d1..bf728f84735 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/PredictTabularRegressionSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/PredictTabularRegressionSample.java @@ -71,8 +71,7 @@ static void predictTabularRegression(String instance, String project, String end TabularRegressionPredictionResult.newBuilder(); TabularRegressionPredictionResult result = - (TabularRegressionPredictionResult) ValueConverter - .fromValue(resultBuilder, prediction); + (TabularRegressionPredictionResult) ValueConverter.fromValue(resultBuilder, prediction); System.out.printf("\tUpper bound: %f\n", result.getUpperBound()); System.out.printf("\tLower bound: %f\n", result.getLowerBound()); diff --git a/aiplatform/snippets/src/main/java/aiplatform/PredictTextClassificationSingleLabelSample.java b/aiplatform/snippets/src/main/java/aiplatform/PredictTextClassificationSingleLabelSample.java index e8989906c5b..d3384dc72bd 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/PredictTextClassificationSingleLabelSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/PredictTextClassificationSingleLabelSample.java @@ -55,10 +55,8 @@ static void predictTextClassificationSingleLabel( String location = "us-central1"; EndpointName endpointName = EndpointName.of(project, location, endpointId); - TextClassificationPredictionInstance predictionInstance = TextClassificationPredictionInstance - .newBuilder() - .setContent(content) - .build(); + TextClassificationPredictionInstance predictionInstance = + TextClassificationPredictionInstance.newBuilder().setContent(content).build(); List instances = new ArrayList<>(); instances.add(ValueConverter.toValue(predictionInstance)); diff --git a/aiplatform/snippets/src/main/java/aiplatform/PredictTextEntityExtractionSample.java b/aiplatform/snippets/src/main/java/aiplatform/PredictTextEntityExtractionSample.java index 6995e34b2da..47f4b9eabdc 100644 --- a/aiplatform/snippets/src/main/java/aiplatform/PredictTextEntityExtractionSample.java +++ b/aiplatform/snippets/src/main/java/aiplatform/PredictTextEntityExtractionSample.java @@ -26,7 +26,6 @@ import com.google.cloud.aiplatform.v1beta1.schema.predict.instance.TextExtractionPredictionInstance; import com.google.cloud.aiplatform.v1beta1.schema.predict.prediction.TextExtractionPredictionResult; import com.google.protobuf.Value; -import com.google.protobuf.util.JsonFormat; import java.io.IOException; import java.util.ArrayList; import java.util.List; @@ -60,9 +59,7 @@ static void predictTextEntityExtraction(String project, String content, String e EndpointName endpointName = EndpointName.of(project, location, endpointId); TextExtractionPredictionInstance instance = - TextExtractionPredictionInstance.newBuilder() - .setContent(content) - .build(); + TextExtractionPredictionInstance.newBuilder().setContent(content).build(); List instances = new ArrayList<>(); instances.add(ValueConverter.toValue(instance)); diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobTextClassificationSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobTextClassificationSampleTest.java new file mode 100644 index 00000000000..5c8e9b0ba18 --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobTextClassificationSampleTest.java @@ -0,0 +1,113 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class CreateBatchPredictionJobTextClassificationSampleTest { + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String LOCATION = "us-central1"; + private static final String MODEL_ID = System.getenv("TEXT_CLASS_MODEL_ID"); + private static final String GCS_SOURCE_URI = + "gs://ucaip-samples-test-output/inputs/batch_predict_TCN/tcn_inputs.jsonl"; + private static final String GCS_OUTPUT_URI = "gs://ucaip-samples-test-output/"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String got; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TEXT_CLASS_MODEL_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + + String batchPredictionJobId = + got.split("name:")[1].split("batchPredictionJobs/")[1].split("\"\n")[0]; + + CancelBatchPredictionJobSample.cancelBatchPredictionJobSample(PROJECT, batchPredictionJobId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Batch Prediction Job"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Batch Prediction Job + DeleteBatchPredictionJobSample.deleteBatchPredictionJobSample(PROJECT, batchPredictionJobId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Batch"); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateBatchPredictionJobTextClassificationSample() throws IOException { + // Act + String batchPredictionDisplayName = + String.format( + "temp_java_create_batch_prediction_TCN_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateBatchPredictionJobTextClassificationSample + .createBatchPredictionJobTextClassificationSample( + PROJECT, + LOCATION, + batchPredictionDisplayName, + MODEL_ID, + GCS_SOURCE_URI, + GCS_OUTPUT_URI); + + // Assert + got = bout.toString(); + assertThat(got).contains(batchPredictionDisplayName); + assertThat(got).contains("response:"); + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobTextEntityExtractionSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobTextEntityExtractionSampleTest.java new file mode 100644 index 00000000000..3cc135491b5 --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobTextEntityExtractionSampleTest.java @@ -0,0 +1,111 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class CreateBatchPredictionJobTextEntityExtractionSampleTest { + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String LOCATION = "us-central1"; + private static final String MODEL_ID = System.getenv("TEXT_ENTITY_MODEL_ID"); + private static final String GCS_SOURCE_URI = + "gs://ucaip-samples-test-output/inputs/batch_predict_TEN/ten_inputs.jsonl"; + private static final String GCS_OUTPUT_URI = "gs://ucaip-samples-test-output/"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String got; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TEXT_ENTITY_MODEL_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + String batchPredictionJobId = + got.split("name:")[1].split("batchPredictionJobs/")[1].split("\"\n")[0]; + CancelBatchPredictionJobSample.cancelBatchPredictionJobSample(PROJECT, batchPredictionJobId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Batch Prediction Job"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Batch Prediction Job + DeleteBatchPredictionJobSample.deleteBatchPredictionJobSample(PROJECT, batchPredictionJobId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Batch"); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateBatchPredictionJobTextEntityExtractionSample() throws IOException { + // Act + String batchPredictionDisplayName = + String.format( + "temp_java_create_batch_prediction_TEN_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateBatchPredictionJobTextEntityExtractionSample + .createBatchPredictionJobTextEntityExtractionSample( + PROJECT, + LOCATION, + batchPredictionDisplayName, + MODEL_ID, + GCS_SOURCE_URI, + GCS_OUTPUT_URI); + + // Assert + got = bout.toString(); + assertThat(got).contains(batchPredictionDisplayName); + assertThat(got).contains("response:"); + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobTextSentimentAnalysisSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobTextSentimentAnalysisSampleTest.java new file mode 100644 index 00000000000..8d70db48796 --- /dev/null +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobTextSentimentAnalysisSampleTest.java @@ -0,0 +1,111 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package aiplatform; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +public class CreateBatchPredictionJobTextSentimentAnalysisSampleTest { + private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); + private static final String LOCATION = "us-central1"; + private static final String MODEL_ID = System.getenv("TEXT_SENTI_MODEL_ID"); + private static final String GCS_SOURCE_URI = + "gs://ucaip-samples-test-output/inputs/batch_predict_TSN/tsn_inputs.jsonl"; + private static final String GCS_OUTPUT_URI = "gs://ucaip-samples-test-output/"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + private String got; + + private static void requireEnvVar(String varName) { + String errorMessage = + String.format("Environment variable '%s' is required to perform these tests.", varName); + assertNotNull(errorMessage, System.getenv(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("TEXT_SENTI_MODEL_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() + throws InterruptedException, ExecutionException, IOException, TimeoutException { + String batchPredictionJobId = + got.split("name:")[1].split("batchPredictionJobs/")[1].split("\"\n")[0]; + CancelBatchPredictionJobSample.cancelBatchPredictionJobSample(PROJECT, batchPredictionJobId); + + // Assert + String cancelResponse = bout.toString(); + assertThat(cancelResponse).contains("Cancelled the Batch Prediction Job"); + TimeUnit.MINUTES.sleep(2); + + // Delete the Batch Prediction Job + DeleteBatchPredictionJobSample.deleteBatchPredictionJobSample(PROJECT, batchPredictionJobId); + + // Assert + String deleteResponse = bout.toString(); + assertThat(deleteResponse).contains("Deleted Batch"); + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testCreateBatchPredictionJobTextSentimentAnalysisSample() throws IOException { + // Act + String batchPredictionDisplayName = + String.format( + "temp_java_create_batch_prediction_TSN_%s", + UUID.randomUUID().toString().replaceAll("-", "_").substring(0, 26)); + + CreateBatchPredictionJobTextSentimentAnalysisSample + .createBatchPredictionJobTextSentimentAnalysisSample( + PROJECT, + LOCATION, + batchPredictionDisplayName, + MODEL_ID, + GCS_SOURCE_URI, + GCS_OUTPUT_URI); + + // Assert + got = bout.toString(); + assertThat(got).contains(batchPredictionDisplayName); + assertThat(got).contains("response:"); + } +} diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobVideoClassificationSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobVideoClassificationSampleTest.java index fc6c217e85b..d2a0fef2292 100644 --- a/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobVideoClassificationSampleTest.java +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobVideoClassificationSampleTest.java @@ -34,7 +34,7 @@ public class CreateBatchPredictionJobVideoClassificationSampleTest { private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); - private static final String MODEL_ID = "8596984660557299712"; + private static final String MODEL_ID = System.getenv("VIDEO_CLASS_MODEL_ID"); private static final String GCS_SOURCE_URI = "gs://ucaip-samples-test-output/inputs/vcn_40_batch_prediction_input.jsonl"; private static final String GCS_DESTINATION_OUTPUT_URI_PREFIX = "gs://ucaip-samples-test-output/"; @@ -53,6 +53,7 @@ private static void requireEnvVar(String varName) { public static void checkRequirements() { requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); requireEnvVar("UCAIP_PROJECT_ID"); + requireEnvVar("VIDEO_CLASS_MODEL_ID"); } @Before diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobVideoObjectTrackingSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobVideoObjectTrackingSampleTest.java index b6890db18a4..c2f2855e402 100644 --- a/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobVideoObjectTrackingSampleTest.java +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateBatchPredictionJobVideoObjectTrackingSampleTest.java @@ -34,7 +34,7 @@ public class CreateBatchPredictionJobVideoObjectTrackingSampleTest { private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID"); - private static final String MODEL_ID = "8609932509485989888"; + private static final String MODEL_ID = System.getenv("VIDEO_OBJECT_DETECT_MODEL_ID"); private static final String GCS_SOURCE_URI = "gs://ucaip-samples-test-output/inputs/vot_batch_prediction_input.jsonl"; private static final String GCS_DESTINATION_OUTPUT_URI_PREFIX = "gs://ucaip-samples-test-output/"; diff --git a/aiplatform/snippets/src/test/java/aiplatform/CreateDataLabelingJobVideoSampleTest.java b/aiplatform/snippets/src/test/java/aiplatform/CreateDataLabelingJobVideoSampleTest.java index 3dd56d5612a..2c6ee822278 100644 --- a/aiplatform/snippets/src/test/java/aiplatform/CreateDataLabelingJobVideoSampleTest.java +++ b/aiplatform/snippets/src/test/java/aiplatform/CreateDataLabelingJobVideoSampleTest.java @@ -87,7 +87,7 @@ public void tearDown() } @Test - @Ignore + @Ignore("Avoid creating actual data labeling job for humans") public void testCreateDataLabelingJobVideoSample() throws IOException { // Act String dataLabelingDisplayName =