diff --git a/automl/snippets/resources/dandelion.jpg b/automl/snippets/resources/dandelion.jpg new file mode 100644 index 00000000000..326e4c1bf53 Binary files /dev/null and b/automl/snippets/resources/dandelion.jpg differ diff --git a/automl/snippets/src/main/java/com/google/cloud/vision/ClassificationDeployModel.java b/automl/snippets/src/main/java/com/google/cloud/vision/ClassificationDeployModel.java new file mode 100644 index 00000000000..6824c07dc71 --- /dev/null +++ b/automl/snippets/src/main/java/com/google/cloud/vision/ClassificationDeployModel.java @@ -0,0 +1,60 @@ +/* + * Copyright 2019 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.vision; + +// [START automl_vision_classification_deploy_model] +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.automl.v1beta1.AutoMlClient; +import com.google.cloud.automl.v1beta1.DeployModelRequest; +import com.google.cloud.automl.v1beta1.ModelName; +import com.google.cloud.automl.v1beta1.OperationMetadata; +import com.google.protobuf.Empty; +import java.io.IOException; +import java.util.concurrent.ExecutionException; + +class ClassificationDeployModel { + + // Deploy a model + static void classificationDeployModel(String projectId, String modelId) + throws IOException, ExecutionException, InterruptedException { + // String projectId = "YOUR_PROJECT_ID"; + // String modelId = "YOUR_MODEL_ID"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (AutoMlClient client = AutoMlClient.create()) { + + // Get the full path of the model. + ModelName modelFullId = ModelName.of(projectId, "us-central1", modelId); + + // Build deploy model request. + DeployModelRequest deployModelRequest = + DeployModelRequest.newBuilder().setName(modelFullId.toString()).build(); + + // Deploy a model with the deploy model request. + OperationFuture future = + client.deployModelAsync(deployModelRequest); + + future.get(); + + // Display the deployment details of model. + System.out.println("Model deployment finished"); + } + } +} +// [END automl_vision_classification_deploy_model] diff --git a/automl/snippets/src/main/java/com/google/cloud/vision/ClassificationDeployModelNodeCount.java b/automl/snippets/src/main/java/com/google/cloud/vision/ClassificationDeployModelNodeCount.java new file mode 100644 index 00000000000..07b91d9be44 --- /dev/null +++ b/automl/snippets/src/main/java/com/google/cloud/vision/ClassificationDeployModelNodeCount.java @@ -0,0 +1,61 @@ +/* + * Copyright 2019 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.vision; + +// [START automl_vision_classification_deploy_model_node_count] +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.automl.v1beta1.AutoMlClient; +import com.google.cloud.automl.v1beta1.DeployModelRequest; +import com.google.cloud.automl.v1beta1.ImageClassificationModelDeploymentMetadata; +import com.google.cloud.automl.v1beta1.ModelName; +import com.google.cloud.automl.v1beta1.OperationMetadata; +import com.google.protobuf.Empty; +import java.io.IOException; +import java.util.concurrent.ExecutionException; + +class ClassificationDeployModelNodeCount { + + // Deploy a model with a specified node count + static void classificationDeployModelNodeCount(String projectId, String modelId) + throws IOException, ExecutionException, InterruptedException { + // String projectId = "YOUR_PROJECT_ID"; + // String modelId = "YOUR_MODEL_ID"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (AutoMlClient client = AutoMlClient.create()) { + // Get the full path of the model. + ModelName modelFullId = ModelName.of(projectId, "us-central1", modelId); + + // Set how many nodes the model is deployed on + ImageClassificationModelDeploymentMetadata deploymentMetadata = + ImageClassificationModelDeploymentMetadata.newBuilder().setNodeCount(2).build(); + + DeployModelRequest request = + DeployModelRequest.newBuilder() + .setName(modelFullId.toString()) + .setImageClassificationModelDeploymentMetadata(deploymentMetadata) + .build(); + // Deploy the model + OperationFuture future = client.deployModelAsync(request); + future.get(); + System.out.println("Model deployment on 2 nodes finished"); + } + } +} +// [END automl_vision_classification_deploy_model_node_count] diff --git a/automl/snippets/src/main/java/com/google/cloud/vision/ClassificationUndeployModel.java b/automl/snippets/src/main/java/com/google/cloud/vision/ClassificationUndeployModel.java new file mode 100644 index 00000000000..73ff19bef00 --- /dev/null +++ b/automl/snippets/src/main/java/com/google/cloud/vision/ClassificationUndeployModel.java @@ -0,0 +1,60 @@ +/* + * Copyright 2019 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.vision; + +// [START automl_vision_classification_undeploy_model] +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.automl.v1beta1.AutoMlClient; +import com.google.cloud.automl.v1beta1.ModelName; +import com.google.cloud.automl.v1beta1.OperationMetadata; +import com.google.cloud.automl.v1beta1.UndeployModelRequest; +import com.google.protobuf.Empty; +import java.io.IOException; +import java.util.concurrent.ExecutionException; + +class ClassificationUndeployModel { + + // Deploy a model + static void classificationUndeployModel(String projectId, String modelId) + throws IOException, ExecutionException, InterruptedException { + // String projectId = "YOUR_PROJECT_ID"; + // String modelId = "YOUR_MODEL_ID"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (AutoMlClient client = AutoMlClient.create()) { + + // Get the full path of the model. + ModelName modelFullId = ModelName.of(projectId, "us-central1", modelId); + + // Build deploy model request. + UndeployModelRequest undeployModelRequest = + UndeployModelRequest.newBuilder().setName(modelFullId.toString()).build(); + + // Deploy a model with the deploy model request. + OperationFuture future = + client.undeployModelAsync(undeployModelRequest); + + future.get(); + + // Display the deployment details of model. + System.out.println("Model undeploy finished"); + } + } +} +// [END automl_vision_classification_undeploy_model] diff --git a/automl/snippets/src/main/java/com/google/cloud/vision/ModelApi.java b/automl/snippets/src/main/java/com/google/cloud/vision/ModelApi.java new file mode 100644 index 00000000000..7ce19be0caa --- /dev/null +++ b/automl/snippets/src/main/java/com/google/cloud/vision/ModelApi.java @@ -0,0 +1,143 @@ +/* + * Copyright 2018 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.vision; + +// Imports the Google Cloud client library +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.automl.v1beta1.AutoMlClient; +import com.google.cloud.automl.v1beta1.ClassificationProto.ClassificationEvaluationMetrics; +import com.google.cloud.automl.v1beta1.ClassificationProto.ClassificationEvaluationMetrics.ConfidenceMetricsEntry; +import com.google.cloud.automl.v1beta1.ImageClassificationModelMetadata; +import com.google.cloud.automl.v1beta1.ListModelEvaluationsRequest; +import com.google.cloud.automl.v1beta1.ListModelsRequest; +import com.google.cloud.automl.v1beta1.LocationName; +import com.google.cloud.automl.v1beta1.Model; +import com.google.cloud.automl.v1beta1.ModelEvaluation; +import com.google.cloud.automl.v1beta1.ModelEvaluationName; +import com.google.cloud.automl.v1beta1.ModelName; +import com.google.cloud.automl.v1beta1.OperationMetadata; +import com.google.longrunning.Operation; +import com.google.protobuf.Empty; +import java.io.IOException; +import java.util.List; +import java.util.concurrent.ExecutionException; +import net.sourceforge.argparse4j.ArgumentParsers; +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.ArgumentParserException; +import net.sourceforge.argparse4j.inf.Namespace; +import net.sourceforge.argparse4j.inf.Subparser; +import net.sourceforge.argparse4j.inf.Subparsers; + +/** + * Google Cloud AutoML Vision API sample application. Example usage: mvn package exec:java + * -Dexec.mainClass ='com.google.cloud.vision.samples.automl.ModelApi' -Dexec.args='create_model + * [datasetId] test_model' + */ +public class ModelApi { + + // [START automl_vision_create_model] + /** + * Demonstrates using the AutoML client to create a model. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param dataSetId the Id of the dataset to which model is created. + * @param modelName the Name of the model. + * @param trainBudget the Budget for training the model. + */ + static void createModel( + String projectId, + String computeRegion, + String dataSetId, + String modelName, + String trainBudget) { + // Instantiates a client + try (AutoMlClient client = AutoMlClient.create()) { + + // A resource that represents Google Cloud Platform location. + LocationName projectLocation = LocationName.of(projectId, computeRegion); + + // Set model metadata. + ImageClassificationModelMetadata imageClassificationModelMetadata = + Long.valueOf(trainBudget) == 0 + ? ImageClassificationModelMetadata.newBuilder().build() + : ImageClassificationModelMetadata.newBuilder() + .setTrainBudget(Long.valueOf(trainBudget)) + .build(); + + // Set model name and model metadata for the image dataset. + Model myModel = + Model.newBuilder() + .setDisplayName(modelName) + .setDatasetId(dataSetId) + .setImageClassificationModelMetadata(imageClassificationModelMetadata) + .build(); + + // Create a model with the model metadata in the region. + OperationFuture response = + client.createModelAsync(projectLocation, myModel); + + System.out.println( + String.format( + "Training operation name: %s", response.getInitialFuture().get().getName())); + System.out.println("Training started..."); + } catch (IOException | ExecutionException | InterruptedException e) { + e.printStackTrace(); + } + } + // [END automl_vision_create_model] + + public static void main(String[] args) { + argsHelper(args); + } + + static void argsHelper(String[] args) { + ArgumentParser parser = + ArgumentParsers.newFor("ModelApi") + .build() + .defaultHelp(true) + .description("Model API operations."); + Subparsers subparsers = parser.addSubparsers().dest("command"); + + Subparser createModelParser = subparsers.addParser("create_model"); + createModelParser.addArgument("datasetId"); + createModelParser.addArgument("modelName"); + createModelParser.addArgument("trainBudget"); + + String projectId = System.getenv("GOOGLE_CLOUD_PROJECT"); + String computeRegion = System.getenv("REGION_NAME"); + + if (projectId == null || computeRegion == null) { + System.out.println("Set `GOOGLE_CLOUD_PROJECT` and `REGION_NAME` as specified in the README"); + System.exit(-1); + } + + try { + Namespace ns = parser.parseArgs(args); + if (ns.get("command").equals("create_model")) { + createModel( + projectId, + computeRegion, + ns.getString("datasetId"), + ns.getString("modelName"), + ns.getString("trainBudget")); + } + } catch (ArgumentParserException e) { + parser.handleError(e); + } + } +} diff --git a/automl/snippets/src/main/java/com/google/cloud/vision/ObjectDetectionDeployModelNodeCount.java b/automl/snippets/src/main/java/com/google/cloud/vision/ObjectDetectionDeployModelNodeCount.java new file mode 100644 index 00000000000..a26137a67ff --- /dev/null +++ b/automl/snippets/src/main/java/com/google/cloud/vision/ObjectDetectionDeployModelNodeCount.java @@ -0,0 +1,60 @@ +/* + * Copyright 2019 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.vision; + +// [START automl_vision_object_detection_deploy_model_node_count] +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.automl.v1beta1.AutoMlClient; +import com.google.cloud.automl.v1beta1.DeployModelRequest; +import com.google.cloud.automl.v1beta1.ImageObjectDetectionModelDeploymentMetadata; +import com.google.cloud.automl.v1beta1.ModelName; +import com.google.cloud.automl.v1beta1.OperationMetadata; +import com.google.protobuf.Empty; +import java.io.IOException; +import java.util.concurrent.ExecutionException; + +class ObjectDetectionDeployModelNodeCount { + + static void objectDetectionDeployModelNodeCount(String projectId, String modelId) + throws IOException, ExecutionException, InterruptedException { + // String projectId = "YOUR_PROJECT_ID"; + // String modelId = "YOUR_MODEL_ID"; + + // Initialize client that will be used to send requests. This client only needs to be created + // once, and can be reused for multiple requests. After completing all of your requests, call + // the "close" method on the client to safely clean up any remaining background resources. + try (AutoMlClient client = AutoMlClient.create()) { + // Get the full path of the model. + ModelName modelFullId = ModelName.of(projectId, "us-central1", modelId); + + // Set how many nodes the model is deployed on + ImageObjectDetectionModelDeploymentMetadata deploymentMetadata = + ImageObjectDetectionModelDeploymentMetadata.newBuilder().setNodeCount(2).build(); + + DeployModelRequest request = + DeployModelRequest.newBuilder() + .setName(modelFullId.toString()) + .setImageObjectDetectionModelDeploymentMetadata(deploymentMetadata) + .build(); + // Deploy the model + OperationFuture future = client.deployModelAsync(request); + future.get(); + System.out.println("Model deployment on 2 nodes finished"); + } + } +} +// [END automl_vision_object_detection_deploy_model_node_count] diff --git a/automl/snippets/src/main/java/com/google/cloud/vision/PredictionApi.java b/automl/snippets/src/main/java/com/google/cloud/vision/PredictionApi.java new file mode 100644 index 00000000000..404ee287765 --- /dev/null +++ b/automl/snippets/src/main/java/com/google/cloud/vision/PredictionApi.java @@ -0,0 +1,136 @@ +/* + * Copyright 2018 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * This application demonstrates how to perform basic operations on prediction + * with the Google AutoML Vision API. + * + * For more information, the documentation at + * https://cloud.google.com/vision/automl/docs. + */ + +package com.google.cloud.vision; + +// Imports the Google Cloud client library +import com.google.cloud.automl.v1beta1.AnnotationPayload; +import com.google.cloud.automl.v1beta1.ExamplePayload; +import com.google.cloud.automl.v1beta1.Image; +import com.google.cloud.automl.v1beta1.ModelName; +import com.google.cloud.automl.v1beta1.PredictResponse; +import com.google.cloud.automl.v1beta1.PredictionServiceClient; +import com.google.protobuf.ByteString; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.HashMap; +import java.util.Map; +import net.sourceforge.argparse4j.ArgumentParsers; +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.ArgumentParserException; +import net.sourceforge.argparse4j.inf.Namespace; + +/** + * Google Cloud AutoML Vision API sample application. Example usage: mvn package exec:java + * -Dexec.mainClass ='com.google.cloud.vision.samples.automl.PredictionApi' -Dexec.args='predict + * [modelId] [path-to-image] [scoreThreshold]' + */ +public class PredictionApi { + + // [START automl_vision_predict] + /** + * Demonstrates using the AutoML client to predict an image. + * + * @param projectId the Id of the project. + * @param computeRegion the Region name. + * @param modelId the Id of the model which will be used for text classification. + * @param filePath the Local text file path of the content to be classified. + * @param scoreThreshold the Confidence score. Only classifications with confidence score above + * scoreThreshold are displayed. + */ + static void predict( + String projectId, + String computeRegion, + String modelId, + String filePath, + String scoreThreshold) { + + // Instantiate client for prediction service. + try (PredictionServiceClient predictionClient = PredictionServiceClient.create()) { + + // Get the full path of the model. + ModelName name = ModelName.of(projectId, computeRegion, modelId); + + // Read the image and assign to payload. + ByteString content = ByteString.copyFrom(Files.readAllBytes(Paths.get(filePath))); + Image image = Image.newBuilder().setImageBytes(content).build(); + ExamplePayload examplePayload = ExamplePayload.newBuilder().setImage(image).build(); + + // Additional parameters that can be provided for prediction e.g. Score Threshold + Map params = new HashMap<>(); + if (scoreThreshold != null) { + params.put("score_threshold", scoreThreshold); + } + // Perform the AutoML Prediction request + PredictResponse response = predictionClient.predict(name, examplePayload, params); + + System.out.println("Prediction results:"); + for (AnnotationPayload annotationPayload : response.getPayloadList()) { + System.out.println("Predicted class name :" + annotationPayload.getDisplayName()); + System.out.println( + "Predicted class score :" + annotationPayload.getClassification().getScore()); + } + } catch (IOException e) { + e.printStackTrace(); + } + } + // [END automl_vision_predict] + + public static void main(String[] args) { + argsHelper(args); + } + + static void argsHelper(String[] args) { + ArgumentParser parser = + ArgumentParsers.newFor("PredictionApi") + .build() + .defaultHelp(true) + .description("Prediction API Operation"); + + parser.addArgument("modelId").required(true); + parser.addArgument("filePath").required(true); + parser.addArgument("scoreThreshold").nargs("?").type(String.class).setDefault(""); + + String projectId = System.getenv("GOOGLE_CLOUD_PROJECT"); + String computeRegion = System.getenv("REGION_NAME"); + + if (projectId == null || computeRegion == null) { + System.out.println("Set `GOOGLE_CLOUD_PROJECT` and `REGION_NAME` as specified in the README"); + System.exit(-1); + } + + try { + Namespace ns = parser.parseArgs(args); + predict( + projectId, + computeRegion, + ns.getString("modelId"), + ns.getString("filePath"), + ns.getString("scoreThreshold")); + } catch (ArgumentParserException e) { + parser.handleError(e); + } + } +} diff --git a/automl/snippets/src/test/java/com/google/cloud/vision/ClassificationDeployModelIT.java b/automl/snippets/src/test/java/com/google/cloud/vision/ClassificationDeployModelIT.java new file mode 100644 index 00000000000..d0683a01384 --- /dev/null +++ b/automl/snippets/src/test/java/com/google/cloud/vision/ClassificationDeployModelIT.java @@ -0,0 +1,92 @@ +/* + * Copyright 2019 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.vision; + +import static com.google.common.truth.Truth.assertThat; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.concurrent.ExecutionException; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +@SuppressWarnings("checkstyle:AbbreviationAsWordInName") +public class ClassificationDeployModelIT { + private static final String PROJECT_ID = System.getenv("GOOGLE_CLOUD_PROJECT"); + private static final String MODEL_ID = "ICN0000000000000000000"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testClassificationDeployModelApi() { + // As model deployment can take a long time, instead try to deploy a + // nonexistent model and confirm that the model was not found, but other + // elements of the request were valid. + try { + ClassificationDeployModel.classificationDeployModel(PROJECT_ID, MODEL_ID); + String got = bout.toString(); + assertThat(got).contains("The model does not exist"); + } catch (IOException | ExecutionException | InterruptedException e) { + assertThat(e.getMessage()).contains("The model does not exist"); + } + } + + @Test + public void testClassificationUndeployModelApi() { + // As model deployment can take a long time, instead try to deploy a + // nonexistent model and confirm that the model was not found, but other + // elements of the request were valid. + try { + ClassificationUndeployModel.classificationUndeployModel(PROJECT_ID, MODEL_ID); + String got = bout.toString(); + assertThat(got).contains("The model does not exist"); + } catch (IOException | ExecutionException | InterruptedException e) { + assertThat(e.getMessage()).contains("The model does not exist"); + } + } + + @Test + public void testClassificationDeployModelNodeCountApi() { + // As model deployment can take a long time, instead try to deploy a + // nonexistent model and confirm that the model was not found, but other + // elements of the request were valid. + try { + ClassificationDeployModelNodeCount.classificationDeployModelNodeCount(PROJECT_ID, MODEL_ID); + String got = bout.toString(); + assertThat(got).contains("The model does not exist"); + } catch (IOException | ExecutionException | InterruptedException e) { + assertThat(e.getMessage()).contains("The model does not exist"); + } + } +} diff --git a/automl/snippets/src/test/java/com/google/cloud/vision/ObjectDetectionDeployModelNodeCountIT.java b/automl/snippets/src/test/java/com/google/cloud/vision/ObjectDetectionDeployModelNodeCountIT.java new file mode 100644 index 00000000000..80e0254caf2 --- /dev/null +++ b/automl/snippets/src/test/java/com/google/cloud/vision/ObjectDetectionDeployModelNodeCountIT.java @@ -0,0 +1,68 @@ +/* + * Copyright 2019 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.vision; + +import static com.google.common.truth.Truth.assertThat; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.concurrent.ExecutionException; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for vision "Deploy Model Node Count" sample. */ +@RunWith(JUnit4.class) +@SuppressWarnings("checkstyle:abbreviationaswordinname") +public class ObjectDetectionDeployModelNodeCountIT { + private static final String PROJECT_ID = System.getenv("GOOGLE_CLOUD_PROJECT"); + private static final String MODEL_ID = "0000000000000000000000"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testObjectDetectionDeployModelNodeCountApi() { + // As model deployment can take a long time, instead try to deploy a + // nonexistent model and confirm that the model was not found, but other + // elements of the request were valid. + try { + ObjectDetectionDeployModelNodeCount.objectDetectionDeployModelNodeCount(PROJECT_ID, MODEL_ID); + String got = bout.toString(); + assertThat(got).contains("The model does not exist"); + } catch (IOException | ExecutionException | InterruptedException e) { + assertThat(e.getMessage()).contains("The model does not exist"); + } + } +} diff --git a/automl/snippets/src/test/java/com/google/cloud/vision/PredictionApiIT.java b/automl/snippets/src/test/java/com/google/cloud/vision/PredictionApiIT.java new file mode 100644 index 00000000000..0464db3b033 --- /dev/null +++ b/automl/snippets/src/test/java/com/google/cloud/vision/PredictionApiIT.java @@ -0,0 +1,85 @@ +/* + * Copyright 2018 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.vision; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.cloud.automl.v1beta1.AutoMlClient; +import com.google.cloud.automl.v1beta1.DeployModelRequest; +import com.google.cloud.automl.v1beta1.Model; +import com.google.cloud.automl.v1beta1.ModelName; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for vision "PredictionAPI" sample. */ +@RunWith(JUnit4.class) +@SuppressWarnings("checkstyle:abbreviationaswordinname") +public class PredictionApiIT { + private static final String COMPUTE_REGION = "us-central1"; + private static final String PROJECT_ID = "java-docs-samples-testing"; + private static final String modelId = "ICN620201829169141520"; + private static final String filePath = "./resources/dandelion.jpg"; + private static final String scoreThreshold = "0.7"; + private ByteArrayOutputStream bout; + private PrintStream out; + private PrintStream originalPrintStream; + + @Before + public void setUp() + throws IOException, ExecutionException, InterruptedException, TimeoutException { + // Verify that the model is deployed for prediction + try (AutoMlClient client = AutoMlClient.create()) { + ModelName modelFullId = ModelName.of(PROJECT_ID, "us-central1", modelId); + Model model = client.getModel(modelFullId); + if (model.getDeploymentState() == Model.DeploymentState.UNDEPLOYED) { + // Deploy the model if not deployed + DeployModelRequest request = + DeployModelRequest.newBuilder().setName(modelFullId.toString()).build(); + Future future = client.deployModelAsync(request); + future.get(30, TimeUnit.MINUTES); + } + } + + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + originalPrintStream = System.out; + System.setOut(out); + } + + @After + public void tearDown() { + System.out.flush(); + System.setOut(originalPrintStream); + } + + @Test + public void testPredict() { + PredictionApi.predict(PROJECT_ID, COMPUTE_REGION, modelId, filePath, scoreThreshold); + String got = bout.toString(); + assertThat(got).contains("dandelion"); + } +}