diff --git a/automl/cloud-client/src/main/java/com/example/automl/VisionClassificationCreateModel.java b/automl/cloud-client/src/main/java/com/example/automl/VisionClassificationCreateModel.java index 020975da611..3b18dc422ff 100644 --- a/automl/cloud-client/src/main/java/com/example/automl/VisionClassificationCreateModel.java +++ b/automl/cloud-client/src/main/java/com/example/automl/VisionClassificationCreateModel.java @@ -49,10 +49,7 @@ static void createModel(String projectId, String datasetId, String displayName) LocationName projectLocation = LocationName.of(projectId, "us-central1"); // Set model metadata. ImageClassificationModelMetadata metadata = - ImageClassificationModelMetadata.newBuilder() - .setTrainBudgetMilliNodeHours( - 8) // The train budget of creating this model, expressed in hours. - .build(); + ImageClassificationModelMetadata.newBuilder().setTrainBudgetMilliNodeHours(24000).build(); Model model = Model.newBuilder() .setDisplayName(displayName) diff --git a/automl/cloud-client/src/test/java/com/example/automl/BatchPredictTest.java b/automl/cloud-client/src/test/java/com/example/automl/BatchPredictTest.java index 4bc36436056..d46b669ac1d 100644 --- a/automl/cloud-client/src/test/java/com/example/automl/BatchPredictTest.java +++ b/automl/cloud-client/src/test/java/com/example/automl/BatchPredictTest.java @@ -44,7 +44,7 @@ public class BatchPredictTest { private static final String PROJECT_ID = System.getenv("AUTOML_PROJECT_ID"); private static final String BUCKET_ID = PROJECT_ID + "-lcm"; - private static final String MODEL_ID = System.getenv("ENTITY_EXTRACTION_MODEL_ID"); + private static final String MODEL_ID = "TEN0000000000000000000"; private ByteArrayOutputStream bout; private PrintStream out; @@ -62,19 +62,7 @@ public static void checkRequirements() { } @Before - public void setUp() throws IOException, ExecutionException, InterruptedException { - // Verify that the model is deployed for prediction - try (AutoMlClient client = AutoMlClient.create()) { - ModelName modelFullId = ModelName.of(PROJECT_ID, "us-central1", MODEL_ID); - Model model = client.getModel(modelFullId); - if (model.getDeploymentState() == Model.DeploymentState.UNDEPLOYED) { - // Deploy the model if not deployed - DeployModelRequest request = - DeployModelRequest.newBuilder().setName(modelFullId.toString()).build(); - client.deployModelAsync(request).get(); - } - } - + public void setUp() { bout = new ByteArrayOutputStream(); out = new PrintStream(bout); System.setOut(out); @@ -82,39 +70,23 @@ public void setUp() throws IOException, ExecutionException, InterruptedException @After public void tearDown() { - // Delete the created files from GCS - Storage storage = StorageOptions.getDefaultInstance().getService(); - Page blobs = - storage.list( - BUCKET_ID, - Storage.BlobListOption.currentDirectory(), - Storage.BlobListOption.prefix("TEST_BATCH_PREDICT/")); - - for (Blob blob : blobs.iterateAll()) { - Page fileBlobs = - storage.list( - BUCKET_ID, - Storage.BlobListOption.currentDirectory(), - Storage.BlobListOption.prefix(blob.getName())); - for (Blob fileBlob : fileBlobs.iterateAll()) { - if (!fileBlob.isDirectory()) { - fileBlob.delete(); - } - } - } - System.setOut(null); } @Test - public void testBatchPredict() throws IOException, ExecutionException, InterruptedException { - String inputUri = String.format("gs://%s/entity-extraction/input.jsonl", BUCKET_ID); - String outputUri = String.format("gs://%s/TEST_BATCH_PREDICT/", BUCKET_ID); - // Act - BatchPredict.batchPredict(PROJECT_ID, MODEL_ID, inputUri, outputUri); - - // Assert - String got = bout.toString(); - assertThat(got).contains("Batch Prediction results saved to specified Cloud Storage bucket"); + public void testBatchPredict() { + // As batch prediction can take a long time. Try to batch predict on a model and confirm that + // the model was not found, but other elements of the request were valid. + try { + String inputUri = String.format("gs://%s/entity-extraction/input.jsonl", BUCKET_ID); + String outputUri = String.format("gs://%s/TEST_BATCH_PREDICT/", BUCKET_ID); + BatchPredict.batchPredict(PROJECT_ID, MODEL_ID, inputUri, outputUri); + String got = bout.toString(); + assertThat(got) + .contains("The model is either not found or not supported for prediction yet."); + } catch (IOException | ExecutionException | InterruptedException e) { + assertThat(e.getMessage()) + .contains("The model is either not found or not supported for prediction yet."); + } } } diff --git a/automl/cloud-client/src/test/java/com/example/automl/LanguageEntityExtractionCreateModelTest.java b/automl/cloud-client/src/test/java/com/example/automl/LanguageEntityExtractionCreateModelTest.java new file mode 100644 index 00000000000..7165916f66a --- /dev/null +++ b/automl/cloud-client/src/test/java/com/example/automl/LanguageEntityExtractionCreateModelTest.java @@ -0,0 +1,85 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.automl; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; + +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +@SuppressWarnings("checkstyle:abbreviationaswordinname") +public class LanguageEntityExtractionCreateModelTest { + + private static final String PROJECT_ID = System.getenv("AUTOML_PROJECT_ID"); + private static final String DATASET_ID = "TEN0000000000000000000"; + private ByteArrayOutputStream bout; + private PrintStream out; + + private static void requireEnvVar(String varName) { + assertNotNull( + System.getenv(varName), + "Environment variable '%s' is required to perform these tests.".format(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("AUTOML_PROJECT_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + System.setOut(out); + } + + @After + public void tearDown() { + System.setOut(null); + } + + @Test + public void testLanguageEntityExtractionCreateModel() { + // As entity extraction does not let you cancel model creation, instead try to create a model + // from a nonexistent dataset, but other elements of the request were valid. + try { + // Create a random dataset name with a length of 32 characters (max allowed by AutoML) + // To prevent name collisions when running tests in multiple java versions at once. + // AutoML doesn't allow "-", but accepts "_" + String modelName = + String.format("test_%s", UUID.randomUUID().toString().replace("-", "_").substring(0, 26)); + LanguageEntityExtractionCreateModel.createModel(PROJECT_ID, DATASET_ID, modelName); + String got = bout.toString(); + assertThat(got).contains("Dataset does not exist"); + } catch (IOException | ExecutionException | InterruptedException e) { + assertThat(e.getMessage()).contains("Dataset does not exist"); + } + } +} diff --git a/automl/cloud-client/src/test/java/com/example/automl/LanguageSentimentAnalysisCreateModelTest.java b/automl/cloud-client/src/test/java/com/example/automl/LanguageSentimentAnalysisCreateModelTest.java new file mode 100644 index 00000000000..13d83d04da1 --- /dev/null +++ b/automl/cloud-client/src/test/java/com/example/automl/LanguageSentimentAnalysisCreateModelTest.java @@ -0,0 +1,92 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.automl; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import com.google.cloud.automl.v1.AutoMlClient; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; + +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +@SuppressWarnings("checkstyle:abbreviationaswordinname") +public class LanguageSentimentAnalysisCreateModelTest { + + private static final String PROJECT_ID = System.getenv("AUTOML_PROJECT_ID"); + private static final String DATASET_ID = System.getenv("SENTIMENT_ANALYSIS_DATASET_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private String operationId; + + private static void requireEnvVar(String varName) { + assertNotNull( + System.getenv(varName), + "Environment variable '%s' is required to perform these tests.".format(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("AUTOML_PROJECT_ID"); + requireEnvVar("SENTIMENT_ANALYSIS_DATASET_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + System.setOut(out); + } + + @After + public void tearDown() throws IOException { + // Cancel the operation + try (AutoMlClient client = AutoMlClient.create()) { + client.getOperationsClient().cancelOperation(operationId); + } + + System.setOut(null); + } + + @Test + public void testLanguageSentimentAnalysisCreateModel() + throws IOException, ExecutionException, InterruptedException { + // Create a random dataset name with a length of 32 characters (max allowed by AutoML) + // To prevent name collisions when running tests in multiple java versions at once. + // AutoML doesn't allow "-", but accepts "_" + String modelName = + String.format("test_%s", UUID.randomUUID().toString().replace("-", "_").substring(0, 26)); + LanguageSentimentAnalysisCreateModel.createModel(PROJECT_ID, DATASET_ID, modelName); + + String got = bout.toString(); + assertThat(got).contains("Training started"); + + operationId = got.split("Training operation name: ")[1].split("\n")[0]; + } +} diff --git a/automl/cloud-client/src/test/java/com/example/automl/LanguageTextClassificationCreateModelTest.java b/automl/cloud-client/src/test/java/com/example/automl/LanguageTextClassificationCreateModelTest.java new file mode 100644 index 00000000000..4e53eed6bb5 --- /dev/null +++ b/automl/cloud-client/src/test/java/com/example/automl/LanguageTextClassificationCreateModelTest.java @@ -0,0 +1,92 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.automl; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import com.google.cloud.automl.v1.AutoMlClient; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; + +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +@SuppressWarnings("checkstyle:abbreviationaswordinname") +public class LanguageTextClassificationCreateModelTest { + + private static final String PROJECT_ID = System.getenv("AUTOML_PROJECT_ID"); + private static final String DATASET_ID = System.getenv("TEXT_CLASSIFICATION_DATASET_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private String operationId; + + private static void requireEnvVar(String varName) { + assertNotNull( + System.getenv(varName), + "Environment variable '%s' is required to perform these tests.".format(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("AUTOML_PROJECT_ID"); + requireEnvVar("TEXT_CLASSIFICATION_DATASET_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + System.setOut(out); + } + + @After + public void tearDown() throws IOException { + // Cancel the operation + try (AutoMlClient client = AutoMlClient.create()) { + client.getOperationsClient().cancelOperation(operationId); + } + + System.setOut(null); + } + + @Test + public void testLanguageTextClassificationCreateModel() + throws IOException, ExecutionException, InterruptedException { + // Create a random dataset name with a length of 32 characters (max allowed by AutoML) + // To prevent name collisions when running tests in multiple java versions at once. + // AutoML doesn't allow "-", but accepts "_" + String modelName = + String.format("test_%s", UUID.randomUUID().toString().replace("-", "_").substring(0, 26)); + LanguageTextClassificationCreateModel.createModel(PROJECT_ID, DATASET_ID, modelName); + + String got = bout.toString(); + assertThat(got).contains("Training started"); + + operationId = got.split("Training operation name: ")[1].split("\n")[0]; + } +} diff --git a/automl/cloud-client/src/test/java/com/example/automl/TranslateCreateModelTest.java b/automl/cloud-client/src/test/java/com/example/automl/TranslateCreateModelTest.java new file mode 100644 index 00000000000..3a54765dde0 --- /dev/null +++ b/automl/cloud-client/src/test/java/com/example/automl/TranslateCreateModelTest.java @@ -0,0 +1,91 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.automl; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import com.google.cloud.automl.v1.AutoMlClient; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; + +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +// Tests for Automl translation models. +@RunWith(JUnit4.class) +public class TranslateCreateModelTest { + private static final String PROJECT_ID = System.getenv("AUTOML_PROJECT_ID"); + private static final String DATASET_ID = System.getenv("TRANSLATION_DATASET_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private String operationId; + + private static void requireEnvVar(String varName) { + assertNotNull( + System.getenv(varName), + "Environment variable '%s' is required to perform these tests.".format(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("AUTOML_PROJECT_ID"); + requireEnvVar("TRANSLATION_DATASET_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + System.setOut(out); + } + + @After + public void tearDown() throws IOException { + // Cancel the operation + try (AutoMlClient client = AutoMlClient.create()) { + client.getOperationsClient().cancelOperation(operationId); + } + + System.setOut(null); + } + + @Test + public void testTranslateCreateModel() + throws IOException, ExecutionException, InterruptedException { + // Create a random dataset name with a length of 32 characters (max allowed by AutoML) + // To prevent name collisions when running tests in multiple java versions at once. + // AutoML doesn't allow "-", but accepts "_" + String modelName = + String.format("test_%s", UUID.randomUUID().toString().replace("-", "_").substring(0, 26)); + TranslateCreateModel.createModel(PROJECT_ID, DATASET_ID, modelName); + + String got = bout.toString(); + assertThat(got).contains("Training started"); + + operationId = got.split("Training operation name: ")[1].split("\n")[0]; + } +} diff --git a/automl/cloud-client/src/test/java/com/example/automl/TranslateModelManagementIT.java b/automl/cloud-client/src/test/java/com/example/automl/TranslateModelManagementIT.java deleted file mode 100644 index 27ab477ae27..00000000000 --- a/automl/cloud-client/src/test/java/com/example/automl/TranslateModelManagementIT.java +++ /dev/null @@ -1,139 +0,0 @@ -/* - * Copyright 2019 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.example.automl; - -import static com.google.common.truth.Truth.assertThat; -import static junit.framework.TestCase.assertNotNull; - -import com.google.cloud.automl.v1.AutoMlClient; - -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.PrintStream; -import java.util.concurrent.ExecutionException; - -import org.junit.After; -import org.junit.Before; -import org.junit.BeforeClass; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -// Tests for Automl translation models. -@RunWith(JUnit4.class) -public class TranslateModelManagementIT { - private static final String PROJECT_ID = System.getenv("AUTOML_PROJECT_ID"); - private static final String DATASET_ID = System.getenv("TRANSLATION_DATASET_ID"); - private static final String MODEL_NAME = "translation_test_create_model"; - private ByteArrayOutputStream bout; - private PrintStream out; - private String modelId; - private String modelEvaluationId; - - private static void requireEnvVar(String varName) { - assertNotNull( - System.getenv(varName), - "Environment variable '%s' is required to perform these tests.".format(varName) - ); - } - - @BeforeClass - public static void checkRequirements() { - requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); - requireEnvVar("AUTOML_PROJECT_ID"); - requireEnvVar("TRANSLATION_DATASET_ID"); - } - - @Before - public void setUp() { - bout = new ByteArrayOutputStream(); - out = new PrintStream(bout); - System.setOut(out); - } - - @After - public void tearDown() { - System.setOut(null); - } - - @Test - public void testModelApi() throws IOException { - // LIST MODELS - ListModels.listModels(PROJECT_ID); - String got = bout.toString(); - modelId = got.split("Model id: ")[1].split("\n")[0]; - assertThat(got).contains("Model id:"); - - // GET MODEL - bout = new ByteArrayOutputStream(); - out = new PrintStream(bout); - System.setOut(out); - GetModel.getModel(PROJECT_ID, modelId); - got = bout.toString(); - assertThat(got).contains("Model id: " + modelId); - - // LIST MODEL EVALUATIONS - bout = new ByteArrayOutputStream(); - out = new PrintStream(bout); - System.setOut(out); - ListModelEvaluations.listModelEvaluations(PROJECT_ID, modelId); - got = bout.toString(); - modelEvaluationId = got.split(modelId + "/modelEvaluations/")[1].split("\n")[0]; - assertThat(got).contains("Model Evaluation Name:"); - - // Act - bout = new ByteArrayOutputStream(); - out = new PrintStream(bout); - System.setOut(out); - GetModelEvaluation.getModelEvaluation(PROJECT_ID, modelId, modelEvaluationId); - got = bout.toString(); - assertThat(got).contains("Model Evaluation Name:"); - } - - @Test - public void testOperationStatus() throws IOException { - // Act - ListOperationStatus.listOperationStatus(PROJECT_ID); - - // Assert - String got = bout.toString(); - String operationId = got.split("\n")[1].split(":")[1].trim(); - assertThat(got).contains("Operation details:"); - - // Act - bout.reset(); - GetOperationStatus.getOperationStatus(operationId); - - // Assert - got = bout.toString(); - assertThat(got).contains("Operation details:"); - } - - @Test - public void testCreateModel() throws IOException, ExecutionException, InterruptedException { - TranslateCreateModel.createModel(PROJECT_ID, DATASET_ID, MODEL_NAME); - - String got = bout.toString(); - assertThat(got).contains("Training started"); - - String operationId = got.split("Training operation name: ")[1].split("\n")[0]; - - try (AutoMlClient client = AutoMlClient.create()) { - client.getOperationsClient().cancelOperation(operationId); - } - } -} diff --git a/automl/cloud-client/src/test/java/com/example/automl/VisionClassificationCreateModelTest.java b/automl/cloud-client/src/test/java/com/example/automl/VisionClassificationCreateModelTest.java new file mode 100644 index 00000000000..1ed242d2dca --- /dev/null +++ b/automl/cloud-client/src/test/java/com/example/automl/VisionClassificationCreateModelTest.java @@ -0,0 +1,92 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.automl; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import com.google.cloud.automl.v1.AutoMlClient; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; + +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +@SuppressWarnings("checkstyle:abbreviationaswordinname") +public class VisionClassificationCreateModelTest { + + private static final String PROJECT_ID = System.getenv("AUTOML_PROJECT_ID"); + private static final String DATASET_ID = System.getenv("VISION_CLASSIFICATION_DATASET_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private String operationId; + + private static void requireEnvVar(String varName) { + assertNotNull( + System.getenv(varName), + "Environment variable '%s' is required to perform these tests.".format(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("AUTOML_PROJECT_ID"); + requireEnvVar("VISION_CLASSIFICATION_DATASET_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + System.setOut(out); + } + + @After + public void tearDown() throws IOException { + // Cancel the operation + try (AutoMlClient client = AutoMlClient.create()) { + client.getOperationsClient().cancelOperation(operationId); + } + + System.setOut(null); + } + + @Test + public void testVisionClassificationCreateModel() + throws IOException, ExecutionException, InterruptedException { + // Create a random dataset name with a length of 32 characters (max allowed by AutoML) + // To prevent name collisions when running tests in multiple java versions at once. + // AutoML doesn't allow "-", but accepts "_" + String modelName = + String.format("test_%s", UUID.randomUUID().toString().replace("-", "_").substring(0, 26)); + VisionClassificationCreateModel.createModel(PROJECT_ID, DATASET_ID, modelName); + + String got = bout.toString(); + assertThat(got).contains("Training started"); + + operationId = got.split("Training operation name: ")[1].split("\n")[0]; + } +} diff --git a/automl/cloud-client/src/test/java/com/example/automl/VisionObjectDetectionCreateModelTest.java b/automl/cloud-client/src/test/java/com/example/automl/VisionObjectDetectionCreateModelTest.java new file mode 100644 index 00000000000..7473b1a46c8 --- /dev/null +++ b/automl/cloud-client/src/test/java/com/example/automl/VisionObjectDetectionCreateModelTest.java @@ -0,0 +1,92 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.example.automl; + +import static com.google.common.truth.Truth.assertThat; +import static junit.framework.TestCase.assertNotNull; + +import com.google.cloud.automl.v1.AutoMlClient; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.UUID; +import java.util.concurrent.ExecutionException; + +import org.junit.After; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +@SuppressWarnings("checkstyle:abbreviationaswordinname") +public class VisionObjectDetectionCreateModelTest { + + private static final String PROJECT_ID = System.getenv("AUTOML_PROJECT_ID"); + private static final String DATASET_ID = System.getenv("OBJECT_DETECTION_DATASET_ID"); + private ByteArrayOutputStream bout; + private PrintStream out; + private String operationId; + + private static void requireEnvVar(String varName) { + assertNotNull( + System.getenv(varName), + "Environment variable '%s' is required to perform these tests.".format(varName)); + } + + @BeforeClass + public static void checkRequirements() { + requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS"); + requireEnvVar("AUTOML_PROJECT_ID"); + requireEnvVar("OBJECT_DETECTION_DATASET_ID"); + } + + @Before + public void setUp() { + bout = new ByteArrayOutputStream(); + out = new PrintStream(bout); + System.setOut(out); + } + + @After + public void tearDown() throws IOException { + // Cancel the operation + try (AutoMlClient client = AutoMlClient.create()) { + client.getOperationsClient().cancelOperation(operationId); + } + + System.setOut(null); + } + + @Test + public void testVisionObjectDetectionCreateModel() + throws IOException, ExecutionException, InterruptedException { + // Create a random dataset name with a length of 32 characters (max allowed by AutoML) + // To prevent name collisions when running tests in multiple java versions at once. + // AutoML doesn't allow "-", but accepts "_" + String modelName = + String.format("test_%s", UUID.randomUUID().toString().replace("-", "_").substring(0, 26)); + VisionObjectDetectionCreateModel.createModel(PROJECT_ID, DATASET_ID, modelName); + + String got = bout.toString(); + assertThat(got).contains("Training started"); + + operationId = got.split("Training operation name: ")[1].split("\n")[0]; + } +}