diff --git a/automl/cloud-client/src/test/java/com/example/automl/LanguageEntityExtractionPredictIT.java b/automl/cloud-client/src/test/java/com/example/automl/BatchPredictTest.java similarity index 68% rename from automl/cloud-client/src/test/java/com/example/automl/LanguageEntityExtractionPredictIT.java rename to automl/cloud-client/src/test/java/com/example/automl/BatchPredictTest.java index 701143c0e86..4bc36436056 100644 --- a/automl/cloud-client/src/test/java/com/example/automl/LanguageEntityExtractionPredictIT.java +++ b/automl/cloud-client/src/test/java/com/example/automl/BatchPredictTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2019 Google LLC + * Copyright 2020 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,10 @@ import static junit.framework.TestCase.assertNotNull; import com.google.api.gax.paging.Page; +import com.google.cloud.automl.v1.AutoMlClient; +import com.google.cloud.automl.v1.DeployModelRequest; +import com.google.cloud.automl.v1.Model; +import com.google.cloud.automl.v1.ModelName; import com.google.cloud.storage.Blob; import com.google.cloud.storage.Storage; import com.google.cloud.storage.StorageOptions; @@ -31,38 +35,46 @@ import org.junit.After; import org.junit.Before; import org.junit.BeforeClass; -import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -// Tests for automl natural language entity extraction "Predict" sample. @RunWith(JUnit4.class) @SuppressWarnings("checkstyle:abbreviationaswordinname") -public class LanguageEntityExtractionPredictIT { +public class BatchPredictTest { private static final String PROJECT_ID = System.getenv("AUTOML_PROJECT_ID"); - private static final String BUCKET_ID = System.getenv("GOOGLE_CLOUD_PROJECT") + "-lcm"; - private static final String modelId = System.getenv("ENTITY_EXTRACTION_MODEL_ID"); + private static final String BUCKET_ID = PROJECT_ID + "-lcm"; + private static final String MODEL_ID = System.getenv("ENTITY_EXTRACTION_MODEL_ID"); private ByteArrayOutputStream bout; private PrintStream out; private static void requireEnvVar(String varName) { assertNotNull( - System.getenv(varName), - "Environment variable '%s' is required to perform these tests.".format(varName) - ); + System.getenv(varName), + "Environment variable '%s' is required to perform these tests.".format(varName)); } @BeforeClass public static void checkRequirements() { requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); - requireEnvVar("GOOGLE_CLOUD_PROJECT"); requireEnvVar("AUTOML_PROJECT_ID"); requireEnvVar("ENTITY_EXTRACTION_MODEL_ID"); } @Before - public void setUp() { + public void setUp() throws IOException, ExecutionException, InterruptedException { + // Verify that the model is deployed for prediction + try (AutoMlClient client = AutoMlClient.create()) { + ModelName modelFullId = ModelName.of(PROJECT_ID, "us-central1", MODEL_ID); + Model model = client.getModel(modelFullId); + if (model.getDeploymentState() == Model.DeploymentState.UNDEPLOYED) { + // Deploy the model if not deployed + DeployModelRequest request = + DeployModelRequest.newBuilder().setName(modelFullId.toString()).build(); + client.deployModelAsync(request).get(); + } + } + bout = new ByteArrayOutputStream(); out = new PrintStream(bout); System.setOut(out); @@ -70,31 +82,7 @@ public void setUp() { @After public void tearDown() { - System.setOut(null); - } - - @Test - public void testPredict() throws IOException { - String text = "Constitutional mutations in the WT1 gene in patients with Denys-Drash syndrome."; - // Act - LanguageEntityExtractionPredict.predict(PROJECT_ID, modelId, text); - - // Assert - String got = bout.toString(); - assertThat(got).contains("Text Extract Entity Type:"); - } - - @Ignore - public void testBatchPredict() throws IOException, ExecutionException, InterruptedException { - String inputUri = String.format("gs://%s/entity_extraction/input.jsonl", BUCKET_ID); - String outputUri = String.format("gs://%s/TEST_BATCH_PREDICT/", BUCKET_ID); - // Act - BatchPredict.batchPredict(PROJECT_ID, modelId, inputUri, outputUri); - - // Assert - String got = bout.toString(); - assertThat(got).contains("Batch Prediction results saved to specified Cloud Storage bucket"); - + // Delete the created files from GCS Storage storage = StorageOptions.getDefaultInstance().getService(); Page blobs = storage.list( @@ -114,5 +102,19 @@ public void testBatchPredict() throws IOException, ExecutionException, Interrupt } } } + + System.setOut(null); + } + + @Test + public void testBatchPredict() throws IOException, ExecutionException, InterruptedException { + String inputUri = String.format("gs://%s/entity-extraction/input.jsonl", BUCKET_ID); + String outputUri = String.format("gs://%s/TEST_BATCH_PREDICT/", BUCKET_ID); + // Act + BatchPredict.batchPredict(PROJECT_ID, MODEL_ID, inputUri, outputUri); + + // Assert + String got = bout.toString(); + assertThat(got).contains("Batch Prediction results saved to specified Cloud Storage bucket"); } } diff --git a/automl/cloud-client/src/test/java/com/example/automl/ExportDatasetTest.java b/automl/cloud-client/src/test/java/com/example/automl/ExportDatasetTest.java index 9f4b221e720..0e5966db85d 100644 --- a/automl/cloud-client/src/test/java/com/example/automl/ExportDatasetTest.java +++ b/automl/cloud-client/src/test/java/com/example/automl/ExportDatasetTest.java @@ -41,7 +41,7 @@ public class ExportDatasetTest { private static final String PROJECT_ID = System.getenv("AUTOML_PROJECT_ID"); - private static final String DATASET_ID = System.getenv("ENTITY_EXTRACTION_DATASET_ID"); + private static final String DATASET_ID = "TEN0000000000000000000"; private static final String BUCKET_ID = PROJECT_ID + "-lcm"; private static final String BUCKET = "gs://" + BUCKET_ID; private ByteArrayOutputStream bout; @@ -69,34 +69,21 @@ public void setUp() { @After public void tearDown() { - // Delete the created files from GCS - Storage storage = StorageOptions.getDefaultInstance().getService(); - Page blobs = - storage.list( - BUCKET_ID, - Storage.BlobListOption.currentDirectory(), - Storage.BlobListOption.prefix("TEST_EXPORT_OUTPUT/")); - - for (Blob blob : blobs.iterateAll()) { - Page fileBlobs = - storage.list( - BUCKET_ID, - Storage.BlobListOption.currentDirectory(), - Storage.BlobListOption.prefix(blob.getName())); - for (Blob fileBlob : fileBlobs.iterateAll()) { - if (!fileBlob.isDirectory()) { - fileBlob.delete(); - } - } - } - System.setOut(null); } @Test public void testExportDataset() throws IOException, ExecutionException, InterruptedException { - ExportDataset.exportDataset(PROJECT_ID, DATASET_ID, BUCKET + "/TEST_EXPORT_OUTPUT/"); - String got = bout.toString(); - assertThat(got).contains("Dataset exported."); + // As exporting a dataset can take a long time and only one operation can be run on a dataset + // at once. Try to export a nonexistent dataset and confirm that the dataset was not found, but + // other elements of the request were valid. + try { + ExportDataset.exportDataset(PROJECT_ID, DATASET_ID, BUCKET + "/TEST_EXPORT_OUTPUT/"); + String got = bout.toString(); + assertThat(got).contains("The Dataset doesn't exist or is inaccessible for use with AutoMl."); + } catch (IOException | ExecutionException | InterruptedException e) { + assertThat(e.getMessage()) + .contains("The Dataset doesn't exist or is inaccessible for use with AutoMl."); + } } } diff --git a/automl/cloud-client/src/test/java/com/example/automl/LanguageEntityExtractionPredictTest.java b/automl/cloud-client/src/test/java/com/example/automl/LanguageEntityExtractionPredictTest.java new file mode 100644 index 00000000000..46c7938284c --- /dev/null +++ b/automl/cloud-client/src/test/java/com/example/automl/LanguageEntityExtractionPredictTest.java @@ -0,0 +1,92 @@ +/* + * Copyright 2019 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.automl; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import com.google.cloud.automl.v1.AutoMlClient; +import com.google.cloud.automl.v1.DeployModelRequest; +import com.google.cloud.automl.v1.Model; +import com.google.cloud.automl.v1.ModelName; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.concurrent.ExecutionException; + +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +@SuppressWarnings("checkstyle:abbreviationaswordinname") +public class LanguageEntityExtractionPredictTest { + private static final String PROJECT_ID = System.getenv("AUTOML_PROJECT_ID"); + private static final String MODEL_ID = System.getenv("ENTITY_EXTRACTION_MODEL_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + + private static void requireEnvVar(String varName) { + assertNotNull( + System.getenv(varName), + "Environment variable '%s' is required to perform these tests.".format(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("GOOGLE_CLOUD_PROJECT"); + requireEnvVar("AUTOML_PROJECT_ID"); + requireEnvVar("ENTITY_EXTRACTION_MODEL_ID"); + } + + @Before + public void setUp() throws IOException, ExecutionException, InterruptedException { + // Verify that the model is deployed for prediction + try (AutoMlClient client = AutoMlClient.create()) { + ModelName modelFullId = ModelName.of(PROJECT_ID, "us-central1", MODEL_ID); + Model model = client.getModel(modelFullId); + if (model.getDeploymentState() == Model.DeploymentState.UNDEPLOYED) { + // Deploy the model if not deployed + DeployModelRequest request = + DeployModelRequest.newBuilder().setName(modelFullId.toString()).build(); + client.deployModelAsync(request).get(); + } + } + + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + System.setOut(out); + } + + @After + public void tearDown() { + System.setOut(null); + } + + @Test + public void testPredict() throws IOException { + String text = "Constitutional mutations in the WT1 gene in patients with Denys-Drash syndrome."; + LanguageEntityExtractionPredict.predict(PROJECT_ID, MODEL_ID, text); + String got = bout.toString(); + assertThat(got).contains("Text Extract Entity Type:"); + } +} diff --git a/automl/cloud-client/src/test/java/com/example/automl/LanguageSentimentAnalysisPredictIT.java b/automl/cloud-client/src/test/java/com/example/automl/LanguageSentimentAnalysisPredictTest.java similarity index 61% rename from automl/cloud-client/src/test/java/com/example/automl/LanguageSentimentAnalysisPredictIT.java rename to automl/cloud-client/src/test/java/com/example/automl/LanguageSentimentAnalysisPredictTest.java index 13627760f92..72600a81e78 100644 --- a/automl/cloud-client/src/test/java/com/example/automl/LanguageSentimentAnalysisPredictIT.java +++ b/automl/cloud-client/src/test/java/com/example/automl/LanguageSentimentAnalysisPredictTest.java @@ -19,9 +19,15 @@ import static com.google.common.truth.Truth.assertThat; import static junit.framework.TestCase.assertNotNull; +import com.google.cloud.automl.v1.AutoMlClient; +import com.google.cloud.automl.v1.DeployModelRequest; +import com.google.cloud.automl.v1.Model; +import com.google.cloud.automl.v1.ModelName; + import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.PrintStream; +import java.util.concurrent.ExecutionException; import org.junit.After; import org.junit.Before; @@ -30,20 +36,18 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -// Tests for automl natural language sentiment analysis "Predict" sample. @RunWith(JUnit4.class) @SuppressWarnings("checkstyle:abbreviationaswordinname") -public class LanguageSentimentAnalysisPredictIT { +public class LanguageSentimentAnalysisPredictTest { private static final String PROJECT_ID = System.getenv("AUTOML_PROJECT_ID"); - private static final String modelId = System.getenv("SENTIMENT_ANALYSIS_MODEL_ID"); + private static final String MODEL_ID = System.getenv("SENTIMENT_ANALYSIS_MODEL_ID"); private ByteArrayOutputStream bout; private PrintStream out; private static void requireEnvVar(String varName) { assertNotNull( - System.getenv(varName), - "Environment variable '%s' is required to perform these tests.".format(varName) - ); + System.getenv(varName), + "Environment variable '%s' is required to perform these tests.".format(varName)); } @BeforeClass @@ -54,7 +58,19 @@ public static void checkRequirements() { } @Before - public void setUp() { + public void setUp() throws IOException, ExecutionException, InterruptedException { + // Verify that the model is deployed for prediction + try (AutoMlClient client = AutoMlClient.create()) { + ModelName modelFullId = ModelName.of(PROJECT_ID, "us-central1", MODEL_ID); + Model model = client.getModel(modelFullId); + if (model.getDeploymentState() == Model.DeploymentState.UNDEPLOYED) { + // Deploy the model if not deployed + DeployModelRequest request = + DeployModelRequest.newBuilder().setName(modelFullId.toString()).build(); + client.deployModelAsync(request).get(); + } + } + bout = new ByteArrayOutputStream(); out = new PrintStream(bout); System.setOut(out); @@ -68,10 +84,7 @@ public void tearDown() { @Test public void testPredict() throws IOException { String text = "Hopefully this Claritin kicks in soon"; - // Act - LanguageSentimentAnalysisPredict.predict(PROJECT_ID, modelId, text); - - // Assert + LanguageSentimentAnalysisPredict.predict(PROJECT_ID, MODEL_ID, text); String got = bout.toString(); assertThat(got).contains("Predicted sentiment score:"); } diff --git a/automl/cloud-client/src/test/java/com/example/automl/LanguageTextClassificationPredictIT.java b/automl/cloud-client/src/test/java/com/example/automl/LanguageTextClassificationPredictTest.java similarity index 60% rename from automl/cloud-client/src/test/java/com/example/automl/LanguageTextClassificationPredictIT.java rename to automl/cloud-client/src/test/java/com/example/automl/LanguageTextClassificationPredictTest.java index 59dac164afc..75f4f10a270 100644 --- a/automl/cloud-client/src/test/java/com/example/automl/LanguageTextClassificationPredictIT.java +++ b/automl/cloud-client/src/test/java/com/example/automl/LanguageTextClassificationPredictTest.java @@ -19,9 +19,15 @@ import static com.google.common.truth.Truth.assertThat; import static junit.framework.TestCase.assertNotNull; +import com.google.cloud.automl.v1.AutoMlClient; +import com.google.cloud.automl.v1.DeployModelRequest; +import com.google.cloud.automl.v1.Model; +import com.google.cloud.automl.v1.ModelName; + import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.PrintStream; +import java.util.concurrent.ExecutionException; import org.junit.After; import org.junit.Before; @@ -30,20 +36,18 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -// Tests for automl natural language text classification "Predict" sample. @RunWith(JUnit4.class) @SuppressWarnings("checkstyle:abbreviationaswordinname") -public class LanguageTextClassificationPredictIT { +public class LanguageTextClassificationPredictTest { private static final String PROJECT_ID = System.getenv("AUTOML_PROJECT_ID"); - private static final String modelId = System.getenv("TEXT_CLASSIFICATION_MODEL_ID"); + private static final String MODEL_ID = System.getenv("TEXT_CLASSIFICATION_MODEL_ID"); private ByteArrayOutputStream bout; private PrintStream out; private static void requireEnvVar(String varName) { assertNotNull( - System.getenv(varName), - "Environment variable '%s' is required to perform these tests.".format(varName) - ); + System.getenv(varName), + "Environment variable '%s' is required to perform these tests.".format(varName)); } @BeforeClass @@ -54,7 +58,19 @@ public static void checkRequirements() { } @Before - public void setUp() { + public void setUp() throws IOException, ExecutionException, InterruptedException { + // Verify that the model is deployed for prediction + try (AutoMlClient client = AutoMlClient.create()) { + ModelName modelFullId = ModelName.of(PROJECT_ID, "us-central1", MODEL_ID); + Model model = client.getModel(modelFullId); + if (model.getDeploymentState() == Model.DeploymentState.UNDEPLOYED) { + // Deploy the model if not deployed + DeployModelRequest request = + DeployModelRequest.newBuilder().setName(modelFullId.toString()).build(); + client.deployModelAsync(request).get(); + } + } + bout = new ByteArrayOutputStream(); out = new PrintStream(bout); System.setOut(out); @@ -68,10 +84,7 @@ public void tearDown() { @Test public void testPredict() throws IOException { String text = "Fruit and nut flavour"; - // Act - LanguageTextClassificationPredict.predict(PROJECT_ID, modelId, text); - - // Assert + LanguageTextClassificationPredict.predict(PROJECT_ID, MODEL_ID, text); String got = bout.toString(); assertThat(got).contains("Predicted class name:"); } diff --git a/automl/cloud-client/src/test/java/com/example/automl/TranslatePredictIT.java b/automl/cloud-client/src/test/java/com/example/automl/TranslatePredictTest.java similarity index 98% rename from automl/cloud-client/src/test/java/com/example/automl/TranslatePredictIT.java rename to automl/cloud-client/src/test/java/com/example/automl/TranslatePredictTest.java index 9f903230537..4bb3039aa27 100644 --- a/automl/cloud-client/src/test/java/com/example/automl/TranslatePredictIT.java +++ b/automl/cloud-client/src/test/java/com/example/automl/TranslatePredictTest.java @@ -34,7 +34,7 @@ // Tests for translation "Predict" sample. @RunWith(JUnit4.class) @SuppressWarnings("checkstyle:abbreviationaswordinname") -public class TranslatePredictIT { +public class TranslatePredictTest { private static final String PROJECT_ID = System.getenv("AUTOML_PROJECT_ID"); private static final String modelId = System.getenv("TRANSLATION_MODEL_ID"); private static final String filePath = "./resources/input.txt"; diff --git a/automl/cloud-client/src/test/java/com/example/automl/VisionClassificationPredictIT.java b/automl/cloud-client/src/test/java/com/example/automl/VisionClassificationPredictTest.java similarity index 50% rename from automl/cloud-client/src/test/java/com/example/automl/VisionClassificationPredictIT.java rename to automl/cloud-client/src/test/java/com/example/automl/VisionClassificationPredictTest.java index c9d7ed4bc93..fe8ea336806 100644 --- a/automl/cloud-client/src/test/java/com/example/automl/VisionClassificationPredictIT.java +++ b/automl/cloud-client/src/test/java/com/example/automl/VisionClassificationPredictTest.java @@ -19,10 +19,11 @@ import static com.google.common.truth.Truth.assertThat; import static junit.framework.TestCase.assertNotNull; -import com.google.api.gax.paging.Page; -import com.google.cloud.storage.Blob; -import com.google.cloud.storage.Storage; -import com.google.cloud.storage.StorageOptions; +import com.google.cloud.automl.v1.AutoMlClient; +import com.google.cloud.automl.v1.DeployModelRequest; +import com.google.cloud.automl.v1.Model; +import com.google.cloud.automl.v1.ModelName; + import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.PrintStream; @@ -31,38 +32,45 @@ import org.junit.After; import org.junit.Before; import org.junit.BeforeClass; -import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -// Tests for automl vision image classification "Predict" sample. @RunWith(JUnit4.class) @SuppressWarnings("checkstyle:abbreviationaswordinname") -public class VisionClassificationPredictIT { +public class VisionClassificationPredictTest { private static final String PROJECT_ID = System.getenv("AUTOML_PROJECT_ID"); - private static final String BUCKET_ID = System.getenv("GOOGLE_CLOUD_PROJECT") + "-vcm"; - private static final String modelId = System.getenv("VISION_CLASSIFICATION_MODEL_ID"); + private static final String MODEL_ID = System.getenv("VISION_CLASSIFICATION_MODEL_ID"); private ByteArrayOutputStream bout; private PrintStream out; private static void requireEnvVar(String varName) { assertNotNull( - System.getenv(varName), - "Environment variable '%s' is required to perform these tests.".format(varName) - ); + System.getenv(varName), + "Environment variable '%s' is required to perform these tests.".format(varName)); } @BeforeClass public static void checkRequirements() { requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); - requireEnvVar("GOOGLE_CLOUD_PROJECT"); requireEnvVar("AUTOML_PROJECT_ID"); requireEnvVar("VISION_CLASSIFICATION_MODEL_ID"); } @Before - public void setUp() { + public void setUp() throws IOException, ExecutionException, InterruptedException { + // Verify that the model is deployed for prediction + try (AutoMlClient client = AutoMlClient.create()) { + ModelName modelFullId = ModelName.of(PROJECT_ID, "us-central1", MODEL_ID); + Model model = client.getModel(modelFullId); + if (model.getDeploymentState() == Model.DeploymentState.UNDEPLOYED) { + // Deploy the model if not deployed + DeployModelRequest request = + DeployModelRequest.newBuilder().setName(modelFullId.toString()).build(); + client.deployModelAsync(request).get(); + } + } + bout = new ByteArrayOutputStream(); out = new PrintStream(bout); System.setOut(out); @@ -76,43 +84,8 @@ public void tearDown() { @Test public void testPredict() throws IOException { String filePath = "resources/test.png"; - // Act - VisionClassificationPredict.predict(PROJECT_ID, modelId, filePath); - - // Assert + VisionClassificationPredict.predict(PROJECT_ID, MODEL_ID, filePath); String got = bout.toString(); assertThat(got).contains("Predicted class name:"); } - - @Ignore - public void testBatchPredict() throws IOException, ExecutionException, InterruptedException { - String inputUri = String.format("gs://%s/batch_predict_test.csv", BUCKET_ID); - String outputUri = String.format("gs://%s/TEST_BATCH_PREDICT/", BUCKET_ID); - // Act - BatchPredict.batchPredict(PROJECT_ID, modelId, inputUri, outputUri); - - // Assert - String got = bout.toString(); - assertThat(got).contains("Batch Prediction results saved to specified Cloud Storage bucket"); - - Storage storage = StorageOptions.getDefaultInstance().getService(); - Page blobs = - storage.list( - BUCKET_ID, - Storage.BlobListOption.currentDirectory(), - Storage.BlobListOption.prefix("TEST_BATCH_PREDICT/")); - - for (Blob blob : blobs.iterateAll()) { - Page fileBlobs = - storage.list( - BUCKET_ID, - Storage.BlobListOption.currentDirectory(), - Storage.BlobListOption.prefix(blob.getName())); - for (Blob fileBlob : fileBlobs.iterateAll()) { - if (!fileBlob.isDirectory()) { - fileBlob.delete(); - } - } - } - } } diff --git a/automl/cloud-client/src/test/java/com/example/automl/VisionObjectDetectionPredictIT.java b/automl/cloud-client/src/test/java/com/example/automl/VisionObjectDetectionPredictTest.java similarity index 50% rename from automl/cloud-client/src/test/java/com/example/automl/VisionObjectDetectionPredictIT.java rename to automl/cloud-client/src/test/java/com/example/automl/VisionObjectDetectionPredictTest.java index 347779037e8..593a867bf53 100644 --- a/automl/cloud-client/src/test/java/com/example/automl/VisionObjectDetectionPredictIT.java +++ b/automl/cloud-client/src/test/java/com/example/automl/VisionObjectDetectionPredictTest.java @@ -19,10 +19,11 @@ import static com.google.common.truth.Truth.assertThat; import static junit.framework.TestCase.assertNotNull; -import com.google.api.gax.paging.Page; -import com.google.cloud.storage.Blob; -import com.google.cloud.storage.Storage; -import com.google.cloud.storage.StorageOptions; +import com.google.cloud.automl.v1.AutoMlClient; +import com.google.cloud.automl.v1.DeployModelRequest; +import com.google.cloud.automl.v1.Model; +import com.google.cloud.automl.v1.ModelName; + import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.PrintStream; @@ -31,38 +32,45 @@ import org.junit.After; import org.junit.Before; import org.junit.BeforeClass; -import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -// Tests for automl vision object detection "Predict" sample. @RunWith(JUnit4.class) @SuppressWarnings("checkstyle:abbreviationaswordinname") -public class VisionObjectDetectionPredictIT { +public class VisionObjectDetectionPredictTest { private static final String PROJECT_ID = System.getenv("AUTOML_PROJECT_ID"); - private static final String BUCKET_ID = System.getenv("GOOGLE_CLOUD_PROJECT") + "-vcm"; - private static final String modelId = System.getenv("OBJECT_DETECTION_MODEL_ID"); + private static final String MODEL_ID = System.getenv("OBJECT_DETECTION_MODEL_ID"); private ByteArrayOutputStream bout; private PrintStream out; private static void requireEnvVar(String varName) { assertNotNull( - System.getenv(varName), - "Environment variable '%s' is required to perform these tests.".format(varName) - ); + System.getenv(varName), + "Environment variable '%s' is required to perform these tests.".format(varName)); } @BeforeClass public static void checkRequirements() { requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); - requireEnvVar("GOOGLE_CLOUD_PROJECT"); requireEnvVar("AUTOML_PROJECT_ID"); requireEnvVar("OBJECT_DETECTION_MODEL_ID"); } @Before - public void setUp() { + public void setUp() throws IOException, ExecutionException, InterruptedException { + // Verify that the model is deployed for prediction + try (AutoMlClient client = AutoMlClient.create()) { + ModelName modelFullId = ModelName.of(PROJECT_ID, "us-central1", MODEL_ID); + Model model = client.getModel(modelFullId); + if (model.getDeploymentState() == Model.DeploymentState.UNDEPLOYED) { + // Deploy the model if not deployed + DeployModelRequest request = + DeployModelRequest.newBuilder().setName(modelFullId.toString()).build(); + client.deployModelAsync(request).get(); + } + } + bout = new ByteArrayOutputStream(); out = new PrintStream(bout); System.setOut(out); @@ -76,45 +84,9 @@ public void tearDown() { @Test public void testPredict() throws IOException { String filePath = "resources/salad.jpg"; - // Act - VisionObjectDetectionPredict.predict(PROJECT_ID, modelId, filePath); - - // Assert + VisionObjectDetectionPredict.predict(PROJECT_ID, MODEL_ID, filePath); String got = bout.toString(); assertThat(got).contains("X:"); assertThat(got).contains("Y:"); } - - @Ignore - public void testBatchPredict() throws IOException, ExecutionException, InterruptedException { - String inputUri = - String.format("gs://%s/vision_object_detection_batch_predict_test.csv", BUCKET_ID); - String outputUri = String.format("gs://%s/TEST_BATCH_PREDICT/", BUCKET_ID); - // Act - BatchPredict.batchPredict(PROJECT_ID, modelId, inputUri, outputUri); - - // Assert - String got = bout.toString(); - assertThat(got).contains("Batch Prediction results saved to specified Cloud Storage bucket"); - - Storage storage = StorageOptions.getDefaultInstance().getService(); - Page blobs = - storage.list( - BUCKET_ID, - Storage.BlobListOption.currentDirectory(), - Storage.BlobListOption.prefix("TEST_BATCH_PREDICT/")); - - for (Blob blob : blobs.iterateAll()) { - Page fileBlobs = - storage.list( - BUCKET_ID, - Storage.BlobListOption.currentDirectory(), - Storage.BlobListOption.prefix(blob.getName())); - for (Blob fileBlob : fileBlobs.iterateAll()) { - if (!fileBlob.isDirectory()) { - fileBlob.delete(); - } - } - } - } }