Skip to content

Commit

Permalink
chore: [vertexai] Integration test for REST Transport
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 622942749
  • Loading branch information
Zhenyi Qi authored and copybara-github committed Apr 17, 2024
1 parent ae22f1c commit dae538c
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,18 @@
import com.google.cloud.vertexai.api.GenerationConfig;
import com.google.cloud.vertexai.api.HarmCategory;
import com.google.cloud.vertexai.api.SafetySetting;
import com.google.cloud.vertexai.api.Tool;
import com.google.cloud.vertexai.generativeai.ChatSession;
import com.google.cloud.vertexai.generativeai.ContentMaker;
import com.google.cloud.vertexai.generativeai.FunctionDeclarationMaker;
import com.google.cloud.vertexai.generativeai.GenerativeModel;
import com.google.cloud.vertexai.generativeai.PartMaker;
import com.google.cloud.vertexai.generativeai.ResponseHandler;
import com.google.cloud.vertexai.generativeai.ResponseStream;
import com.google.common.collect.ImmutableList;
import com.google.gson.JsonObject;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Collections;
import java.util.logging.Logger;
import org.junit.After;
import org.junit.Before;
Expand Down Expand Up @@ -60,6 +65,24 @@ public void tearDown() throws IOException {
vertexAi.close();
}

private static void assertSizeAndAlternatingRolesInHistory(
String methodName,
ImmutableList<Content> history,
int expectedSize,
ImmutableList<Content> expectedUserContent) {
// GenAI output is flaky so we always print out the response.
// For the same reason, we don't do assertions much.
logger.info(String.format("%s: The whole history is:\n%s", methodName, history));
assertThat(history.size()).isEqualTo(expectedSize);
for (int i = 1; i < expectedSize; i += 2) {
if (!expectedUserContent.isEmpty()) {
assertThat(history.get(i - 1)).isEqualTo(expectedUserContent.get((i - 1) / 2));
}
assertThat(history.get(i - 1).getRole()).isEqualTo("user");
assertThat(history.get(i).getRole()).isEqualTo("model");
}
}

@Test
public void sendMessageMixedStreamAndUnary_historyOfFour() throws IOException {
// Arrange
Expand All @@ -77,17 +100,14 @@ public void sendMessageMixedStreamAndUnary_historyOfFour() throws IOException {
assertThat(resp.getCandidatesList()).isNotEmpty();
}
GenerateContentResponse response = chat.sendMessage(secondMessage);
List<Content> history = chat.getHistory();
ImmutableList<Content> history = chat.getHistory();

// Assert
// GenAI output is flaky so we always print out the response.
// For the same reason, we don't do assertions much.
logger.info(String.format("The whole history is:\n%s", history));
assertThat(history.size()).isEqualTo(4);
assertThat(history.get(0)).isEqualTo(expectedFirstContent);
assertThat(history.get(1).getRole()).isEqualTo("model");
assertThat(history.get(2)).isEqualTo(expectedThirdContent);
assertThat(history.get(3).getRole()).isEqualTo("model");
assertSizeAndAlternatingRolesInHistory(
Thread.currentThread().getStackTrace()[1].getMethodName(),
history,
4,
ImmutableList.of(expectedFirstContent, expectedThirdContent));
}

@Test
Expand All @@ -98,8 +118,8 @@ public void sendMessageWithNewConfigs_historyContainsFullConversation() throws I
Content expectedFirstContent = ContentMaker.fromString(firstMessage);
Content expectedThirdContent = ContentMaker.fromString(secondMessage);
GenerationConfig config = GenerationConfig.newBuilder().setTemperature(0.7F).build();
List<SafetySetting> safetySettings =
Arrays.asList(
ImmutableList<SafetySetting> safetySettings =
ImmutableList.of(
SafetySetting.newBuilder()
.setCategory(HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT)
.setThreshold(SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH)
Expand All @@ -122,14 +142,79 @@ public void sendMessageWithNewConfigs_historyContainsFullConversation() throws I
.sendMessage(secondMessage);

// Assert
List<Content> history = chat.getHistory();
// GenAI output is flaky so we always print out the response.
// For the same reason, we don't do assertions much.
logger.info(String.format("The whole history is:\n%s", history));
assertThat(history.size()).isEqualTo(4);
assertThat(history.get(0)).isEqualTo(expectedFirstContent);
assertThat(history.get(1).getRole()).isEqualTo("model");
assertThat(history.get(2)).isEqualTo(expectedThirdContent);
assertThat(history.get(3).getRole()).isEqualTo("model");
ImmutableList<Content> history = chat.getHistory();
assertSizeAndAlternatingRolesInHistory(
Thread.currentThread().getStackTrace()[1].getMethodName(),
history,
4,
ImmutableList.of(expectedFirstContent, expectedThirdContent));
}

@Test
public void sendMessageWithFunctionCalling_functionCallInResponse() throws IOException {
// Arrange
String firstMessage = "hello!";
String secondMessage = "What is the weather in Boston?";
// Making an Json object representing a function declaration
// The following code makes a function declaration
// {
// "name": "getCurrentWeather",
// "description": "Get the current weather in a given location",
// "parameters": {
// "type": "OBJECT",
// "properties": {
// "location": {
// "type": "STRING",
// "description": "location"
// }
// }
// }
// }
JsonObject locationJsonObject = new JsonObject();
locationJsonObject.addProperty("type", "STRING");
locationJsonObject.addProperty("description", "location");

JsonObject propertiesJsonObject = new JsonObject();
propertiesJsonObject.add("location", locationJsonObject);

JsonObject parametersJsonObject = new JsonObject();
parametersJsonObject.addProperty("type", "OBJECT");
parametersJsonObject.add("properties", propertiesJsonObject);

JsonObject jsonObject = new JsonObject();
jsonObject.addProperty("name", "getCurrentWeather");
jsonObject.addProperty("description", "Get the current weather in a given location");
jsonObject.add("parameters", parametersJsonObject);
Tool tool =
Tool.newBuilder()
.addFunctionDeclarations(FunctionDeclarationMaker.fromJsonObject(jsonObject))
.build();
ImmutableList<Tool> tools = ImmutableList.of(tool);

Content functionResponse =
ContentMaker.fromMultiModalData(
PartMaker.fromFunctionResponse(
"getCurrentWeather", Collections.singletonMap("currentWeather", "snowing")));

// Act
chat = model.startChat();
GenerateContentResponse firstResponse = chat.sendMessage(firstMessage);
GenerateContentResponse secondResponse = chat.withTools(tools).sendMessage(secondMessage);
GenerateContentResponse thirdResponse = chat.sendMessage(functionResponse);
ImmutableList<Content> history = chat.getHistory();

// Assert
assertThat(firstResponse.getCandidatesList()).hasSize(1);
assertThat(secondResponse.getCandidatesList()).hasSize(1);
assertThat(ResponseHandler.getFunctionCalls(secondResponse).size()).isEqualTo(1);
assertThat(thirdResponse.getCandidatesList()).hasSize(1);
assertSizeAndAlternatingRolesInHistory(
Thread.currentThread().getStackTrace()[1].getMethodName(),
history,
6,
ImmutableList.of(
ContentMaker.fromString(firstMessage),
ContentMaker.fromString(secondMessage),
functionResponse));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import static com.google.common.truth.Truth.assertThat;

import com.google.api.core.ApiFuture;
import com.google.cloud.vertexai.Transport;
import com.google.cloud.vertexai.VertexAI;
import com.google.cloud.vertexai.api.Content;
import com.google.cloud.vertexai.api.CountTokensResponse;
Expand All @@ -41,7 +42,9 @@
import org.junit.After;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TestName;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

Expand Down Expand Up @@ -70,6 +73,8 @@ public class ITGenerativeModelIntegrationTest {
private GenerativeModel multiModalModel;
private GenerativeModel latestGemini;

@Rule public TestName name = new TestName();

@Before
public void setUp() throws IOException {
vertexAi = new VertexAI(PROJECT_ID, LOCATION);
Expand Down Expand Up @@ -113,17 +118,30 @@ private static void assertNonEmptyAndLogTextContentOfResponseStream(
}
}

@Test
public void generateContent_restTransport_nonEmptyCandidateList() throws IOException {
try (VertexAI vertexAiViaRest =
new VertexAI.Builder()
.setProjectId(PROJECT_ID)
.setLocation(LOCATION)
.setTransport(Transport.REST)
.build()) {
GenerativeModel textModelWithRest = new GenerativeModel(MODEL_NAME_TEXT, vertexAiViaRest);
GenerateContentResponse response = textModelWithRest.generateContent(TEXT);

assertNonEmptyAndLogResponse(name.getMethodName(), TEXT, response);
}
}

@Test
public void generateContent_withPlainText_nonEmptyCandidateList() throws IOException {
GenerateContentResponse response = textModel.generateContent(TEXT);

String methodName = Thread.currentThread().getStackTrace()[1].getMethodName();
assertNonEmptyAndLogResponse(methodName, TEXT, response);
assertNonEmptyAndLogResponse(name.getMethodName(), TEXT, response);
}

@Test
public void generateContent_withCompleteConfig_nonEmptyCandidateList() throws IOException {
logger.info(String.format("Generating response for question: %s", TEXT));
Integer maxOutputTokens = 50;
GenerationConfig generationConfig =
GenerationConfig.newBuilder()
Expand All @@ -147,8 +165,7 @@ public void generateContent_withCompleteConfig_nonEmptyCandidateList() throws IO
String contentText = ResponseHandler.getText(response);
int numWords = contentText.split("\\s+").length;

String methodName = Thread.currentThread().getStackTrace()[1].getMethodName();
assertNonEmptyAndLogResponse(methodName, TEXT, response);
assertNonEmptyAndLogResponse(name.getMethodName(), TEXT, response);
// We avoid calling the countTokens service and just assert that the number of words should be
// less than the maxOutputTokens since each word on average results in more than one tokens.
assertThat(numWords).isAtMost(maxOutputTokens);
Expand All @@ -159,8 +176,7 @@ public void generateContentAsync_withPlainText_nonEmptyCandidateList() throws Ex
ApiFuture<GenerateContentResponse> responseFuture = textModel.generateContentAsync(TEXT);
GenerateContentResponse response = responseFuture.get();

String methodName = Thread.currentThread().getStackTrace()[1].getMethodName();
assertNonEmptyAndLogResponse(methodName, TEXT, response);
assertNonEmptyAndLogResponse(name.getMethodName(), TEXT, response);
}

@Test
Expand All @@ -182,20 +198,18 @@ public void generateContent_withContentList_nonEmptyCandidate() throws IOExcepti
// can it. Same for `fromMultiModalList`
ContentMaker.fromString(followupPrompt)));

String methodName = Thread.currentThread().getStackTrace()[1].getMethodName();
assertNonEmptyAndLogResponse(methodName, followupPrompt, response);
assertNonEmptyAndLogResponse(name.getMethodName(), followupPrompt, response);
}

@Test
public void generateContentStream_withPlainText_nonEmptyCandidateList() throws IOException {
ResponseStream<GenerateContentResponse> stream = textModel.generateContentStream(TEXT);

String methodName = Thread.currentThread().getStackTrace()[1].getMethodName();
assertNonEmptyAndLogTextContentOfResponseStream(methodName, TEXT, stream);
assertNonEmptyAndLogTextContentOfResponseStream(name.getMethodName(), TEXT, stream);
}

// TODO(b/333866041): Re-enable byteImage test
@Ignore("The test is not compatible with GraalVM native image test on GitHub.")
@Ignore("TODO(b/333866041):The test is not compatible with GraalVM native image test on GitHub.")
@Test
public void generateContentStream_withByteImage_nonEmptyCandidateList() throws Exception {
ResponseStream<GenerateContentResponse> stream =
Expand All @@ -204,8 +218,7 @@ public void generateContentStream_withByteImage_nonEmptyCandidateList() throws E
IMAGE_INQUIRY,
PartMaker.fromMimeTypeAndData("image/jpeg", imageToBytes(new URL(IMAGE_URL)))));

String methodName = Thread.currentThread().getStackTrace()[1].getMethodName();
assertNonEmptyAndLogTextContentOfResponseStream(methodName, IMAGE_INQUIRY, stream);
assertNonEmptyAndLogTextContentOfResponseStream(name.getMethodName(), IMAGE_INQUIRY, stream);
}

@Test
Expand All @@ -215,8 +228,7 @@ public void generateContentStream_withGcsVideo_nonEmptyCandidateList() throws Ex

ResponseStream<GenerateContentResponse> stream = multiModalModel.generateContentStream(content);

String methodName = Thread.currentThread().getStackTrace()[1].getMethodName();
assertNonEmptyAndLogTextContentOfResponseStream(methodName, VIDEO_INQUIRY, stream);
assertNonEmptyAndLogTextContentOfResponseStream(name.getMethodName(), VIDEO_INQUIRY, stream);
}

@Test
Expand All @@ -226,8 +238,7 @@ public void generateContentStream_withGcsImage_nonEmptyCandidateList() throws Ex

ResponseStream<GenerateContentResponse> stream = multiModalModel.generateContentStream(content);

String methodName = Thread.currentThread().getStackTrace()[1].getMethodName();
assertNonEmptyAndLogTextContentOfResponseStream(methodName, IMAGE_INQUIRY, stream);
assertNonEmptyAndLogTextContentOfResponseStream(name.getMethodName(), IMAGE_INQUIRY, stream);
}

@Test
Expand Down

0 comments on commit dae538c

Please sign in to comment.