Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

automl: separate batch predict test, verify model is deployed before prediction #1931

Merged
merged 10 commits into from
Jan 7, 2020
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2019 Google LLC
* Copyright 2020 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,6 +20,10 @@
import static junit.framework.TestCase.assertNotNull;

import com.google.api.gax.paging.Page;
import com.google.cloud.automl.v1.AutoMlClient;
import com.google.cloud.automl.v1.DeployModelRequest;
import com.google.cloud.automl.v1.Model;
import com.google.cloud.automl.v1.ModelName;
import com.google.cloud.storage.Blob;
import com.google.cloud.storage.Storage;
import com.google.cloud.storage.StorageOptions;
Expand All @@ -31,70 +35,54 @@
import org.junit.After;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

// Tests for automl natural language entity extraction "Predict" sample.
@RunWith(JUnit4.class)
@SuppressWarnings("checkstyle:abbreviationaswordinname")
public class LanguageEntityExtractionPredictIT {
public class BatchPredictTest {
private static final String PROJECT_ID = System.getenv("AUTOML_PROJECT_ID");
private static final String BUCKET_ID = System.getenv("GOOGLE_CLOUD_PROJECT") + "-lcm";
private static final String modelId = System.getenv("ENTITY_EXTRACTION_MODEL_ID");
private static final String BUCKET_ID = PROJECT_ID + "-lcm";
private static final String MODEL_ID = System.getenv("ENTITY_EXTRACTION_MODEL_ID");
private ByteArrayOutputStream bout;
private PrintStream out;

private static void requireEnvVar(String varName) {
assertNotNull(
System.getenv(varName),
"Environment variable '%s' is required to perform these tests.".format(varName)
);
System.getenv(varName),
"Environment variable '%s' is required to perform these tests.".format(varName));
}

@BeforeClass
public static void checkRequirements() {
requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
requireEnvVar("GOOGLE_CLOUD_PROJECT");
requireEnvVar("AUTOML_PROJECT_ID");
requireEnvVar("ENTITY_EXTRACTION_MODEL_ID");
}

@Before
public void setUp() {
public void setUp() throws IOException, ExecutionException, InterruptedException {
// Verify that the model is deployed for prediction
try (AutoMlClient client = AutoMlClient.create()) {
ModelName modelFullId = ModelName.of(PROJECT_ID, "us-central1", MODEL_ID);
Model model = client.getModel(modelFullId);
if (model.getDeploymentState() == Model.DeploymentState.UNDEPLOYED) {
// Deploy the model if not deployed
DeployModelRequest request =
DeployModelRequest.newBuilder().setName(modelFullId.toString()).build();
client.deployModelAsync(request).get();
}
}

bout = new ByteArrayOutputStream();
out = new PrintStream(bout);
System.setOut(out);
}

@After
public void tearDown() {
System.setOut(null);
}

@Test
public void testPredict() throws IOException {
String text = "Constitutional mutations in the WT1 gene in patients with Denys-Drash syndrome.";
// Act
LanguageEntityExtractionPredict.predict(PROJECT_ID, modelId, text);

// Assert
String got = bout.toString();
assertThat(got).contains("Text Extract Entity Type:");
}

@Ignore
public void testBatchPredict() throws IOException, ExecutionException, InterruptedException {
String inputUri = String.format("gs://%s/entity_extraction/input.jsonl", BUCKET_ID);
String outputUri = String.format("gs://%s/TEST_BATCH_PREDICT/", BUCKET_ID);
// Act
BatchPredict.batchPredict(PROJECT_ID, modelId, inputUri, outputUri);

// Assert
String got = bout.toString();
assertThat(got).contains("Batch Prediction results saved to specified Cloud Storage bucket");

// Delete the created files from GCS
Storage storage = StorageOptions.getDefaultInstance().getService();
Page<Blob> blobs =
storage.list(
Expand All @@ -114,5 +102,19 @@ public void testBatchPredict() throws IOException, ExecutionException, Interrupt
}
}
}

System.setOut(null);
}

@Test
public void testBatchPredict() throws IOException, ExecutionException, InterruptedException {
String inputUri = String.format("gs://%s/entity-extraction/input.jsonl", BUCKET_ID);
String outputUri = String.format("gs://%s/TEST_BATCH_PREDICT/", BUCKET_ID);
// Act
BatchPredict.batchPredict(PROJECT_ID, MODEL_ID, inputUri, outputUri);

// Assert
String got = bout.toString();
assertThat(got).contains("Batch Prediction results saved to specified Cloud Storage bucket");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
public class ExportDatasetTest {

private static final String PROJECT_ID = System.getenv("AUTOML_PROJECT_ID");
private static final String DATASET_ID = System.getenv("ENTITY_EXTRACTION_DATASET_ID");
private static final String DATASET_ID = "TEN0000000000000000000";
private static final String BUCKET_ID = PROJECT_ID + "-lcm";
private static final String BUCKET = "gs://" + BUCKET_ID;
private ByteArrayOutputStream bout;
Expand Down Expand Up @@ -69,34 +69,21 @@ public void setUp() {

@After
public void tearDown() {
// Delete the created files from GCS
Storage storage = StorageOptions.getDefaultInstance().getService();
Page<Blob> blobs =
storage.list(
BUCKET_ID,
Storage.BlobListOption.currentDirectory(),
Storage.BlobListOption.prefix("TEST_EXPORT_OUTPUT/"));

for (Blob blob : blobs.iterateAll()) {
Page<Blob> fileBlobs =
storage.list(
BUCKET_ID,
Storage.BlobListOption.currentDirectory(),
Storage.BlobListOption.prefix(blob.getName()));
for (Blob fileBlob : fileBlobs.iterateAll()) {
if (!fileBlob.isDirectory()) {
fileBlob.delete();
}
}
}

System.setOut(null);
}

@Test
public void testExportDataset() throws IOException, ExecutionException, InterruptedException {
ExportDataset.exportDataset(PROJECT_ID, DATASET_ID, BUCKET + "/TEST_EXPORT_OUTPUT/");
String got = bout.toString();
assertThat(got).contains("Dataset exported.");
// As exporting a dataset can take a long time and only one operation can be run on a dataset
// at once. Try to export a nonexistent dataset and confirm that the dataset was not found, but
// other elements of the request were valid.
try {
ExportDataset.exportDataset(PROJECT_ID, DATASET_ID, BUCKET + "/TEST_EXPORT_OUTPUT/");
String got = bout.toString();
assertThat(got).contains("The Dataset doesn't exist or is inaccessible for use with AutoMl.");
} catch (IOException | ExecutionException | InterruptedException e) {
assertThat(e.getMessage())
.contains("The Dataset doesn't exist or is inaccessible for use with AutoMl.");
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Copyright 2019 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.example.automl;

import static com.google.common.truth.Truth.assertThat;
import static junit.framework.TestCase.assertNotNull;

import com.google.cloud.automl.v1.AutoMlClient;
import com.google.cloud.automl.v1.DeployModelRequest;
import com.google.cloud.automl.v1.Model;
import com.google.cloud.automl.v1.ModelName;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.util.concurrent.ExecutionException;

import org.junit.After;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
@SuppressWarnings("checkstyle:abbreviationaswordinname")
public class LanguageEntityExtractionPredictTest {
private static final String PROJECT_ID = System.getenv("AUTOML_PROJECT_ID");
private static final String MODEL_ID = System.getenv("ENTITY_EXTRACTION_MODEL_ID");
private ByteArrayOutputStream bout;
private PrintStream out;

private static void requireEnvVar(String varName) {
assertNotNull(
System.getenv(varName),
"Environment variable '%s' is required to perform these tests.".format(varName));
}

@BeforeClass
public static void checkRequirements() {
requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
requireEnvVar("GOOGLE_CLOUD_PROJECT");
requireEnvVar("AUTOML_PROJECT_ID");
requireEnvVar("ENTITY_EXTRACTION_MODEL_ID");
}

@Before
public void setUp() throws IOException, ExecutionException, InterruptedException {
// Verify that the model is deployed for prediction
try (AutoMlClient client = AutoMlClient.create()) {
ModelName modelFullId = ModelName.of(PROJECT_ID, "us-central1", MODEL_ID);
Model model = client.getModel(modelFullId);
if (model.getDeploymentState() == Model.DeploymentState.UNDEPLOYED) {
// Deploy the model if not deployed
DeployModelRequest request =
DeployModelRequest.newBuilder().setName(modelFullId.toString()).build();
client.deployModelAsync(request).get();
}
}

bout = new ByteArrayOutputStream();
out = new PrintStream(bout);
System.setOut(out);
}

@After
public void tearDown() {
System.setOut(null);
}

@Test
public void testPredict() throws IOException {
String text = "Constitutional mutations in the WT1 gene in patients with Denys-Drash syndrome.";
LanguageEntityExtractionPredict.predict(PROJECT_ID, MODEL_ID, text);
String got = bout.toString();
assertThat(got).contains("Text Extract Entity Type:");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,15 @@
import static com.google.common.truth.Truth.assertThat;
import static junit.framework.TestCase.assertNotNull;

import com.google.cloud.automl.v1.AutoMlClient;
import com.google.cloud.automl.v1.DeployModelRequest;
import com.google.cloud.automl.v1.Model;
import com.google.cloud.automl.v1.ModelName;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.util.concurrent.ExecutionException;

import org.junit.After;
import org.junit.Before;
Expand All @@ -30,20 +36,18 @@
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

// Tests for automl natural language sentiment analysis "Predict" sample.
@RunWith(JUnit4.class)
@SuppressWarnings("checkstyle:abbreviationaswordinname")
public class LanguageSentimentAnalysisPredictIT {
public class LanguageSentimentAnalysisPredictTest {
private static final String PROJECT_ID = System.getenv("AUTOML_PROJECT_ID");
private static final String modelId = System.getenv("SENTIMENT_ANALYSIS_MODEL_ID");
private static final String MODEL_ID = System.getenv("SENTIMENT_ANALYSIS_MODEL_ID");
private ByteArrayOutputStream bout;
private PrintStream out;

private static void requireEnvVar(String varName) {
assertNotNull(
System.getenv(varName),
"Environment variable '%s' is required to perform these tests.".format(varName)
);
System.getenv(varName),
"Environment variable '%s' is required to perform these tests.".format(varName));
}

@BeforeClass
Expand All @@ -54,7 +58,19 @@ public static void checkRequirements() {
}

@Before
public void setUp() {
public void setUp() throws IOException, ExecutionException, InterruptedException {
// Verify that the model is deployed for prediction
try (AutoMlClient client = AutoMlClient.create()) {
ModelName modelFullId = ModelName.of(PROJECT_ID, "us-central1", MODEL_ID);
Model model = client.getModel(modelFullId);
if (model.getDeploymentState() == Model.DeploymentState.UNDEPLOYED) {
// Deploy the model if not deployed
DeployModelRequest request =
DeployModelRequest.newBuilder().setName(modelFullId.toString()).build();
client.deployModelAsync(request).get();
}
}

bout = new ByteArrayOutputStream();
out = new PrintStream(bout);
System.setOut(out);
Expand All @@ -68,10 +84,7 @@ public void tearDown() {
@Test
public void testPredict() throws IOException {
String text = "Hopefully this Claritin kicks in soon";
// Act
LanguageSentimentAnalysisPredict.predict(PROJECT_ID, modelId, text);

// Assert
LanguageSentimentAnalysisPredict.predict(PROJECT_ID, MODEL_ID, text);
String got = bout.toString();
assertThat(got).contains("Predicted sentiment score:");
}
Expand Down
Loading