diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/it/ITChatSessionIntegrationTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/it/ITChatSessionIntegrationTest.java index f08444eda777..5dbf0ed507ec 100644 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/it/ITChatSessionIntegrationTest.java +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/it/ITChatSessionIntegrationTest.java @@ -23,13 +23,18 @@ import com.google.cloud.vertexai.api.GenerationConfig; import com.google.cloud.vertexai.api.HarmCategory; import com.google.cloud.vertexai.api.SafetySetting; +import com.google.cloud.vertexai.api.Tool; import com.google.cloud.vertexai.generativeai.ChatSession; import com.google.cloud.vertexai.generativeai.ContentMaker; +import com.google.cloud.vertexai.generativeai.FunctionDeclarationMaker; import com.google.cloud.vertexai.generativeai.GenerativeModel; +import com.google.cloud.vertexai.generativeai.PartMaker; +import com.google.cloud.vertexai.generativeai.ResponseHandler; import com.google.cloud.vertexai.generativeai.ResponseStream; +import com.google.common.collect.ImmutableList; +import com.google.gson.JsonObject; import java.io.IOException; -import java.util.Arrays; -import java.util.List; +import java.util.Collections; import java.util.logging.Logger; import org.junit.After; import org.junit.Before; @@ -60,6 +65,24 @@ public void tearDown() throws IOException { vertexAi.close(); } + private static void assertSizeAndAlternatingRolesInHistory( + String methodName, + ImmutableList history, + int expectedSize, + ImmutableList expectedUserContent) { + // GenAI output is flaky so we always print out the response. + // For the same reason, we don't do assertions much. + logger.info(String.format("%s: The whole history is:\n%s", methodName, history)); + assertThat(history.size()).isEqualTo(expectedSize); + for (int i = 1; i < expectedSize; i += 2) { + if (!expectedUserContent.isEmpty()) { + assertThat(history.get(i - 1)).isEqualTo(expectedUserContent.get((i - 1) / 2)); + } + assertThat(history.get(i - 1).getRole()).isEqualTo("user"); + assertThat(history.get(i).getRole()).isEqualTo("model"); + } + } + @Test public void sendMessageMixedStreamAndUnary_historyOfFour() throws IOException { // Arrange @@ -77,17 +100,14 @@ public void sendMessageMixedStreamAndUnary_historyOfFour() throws IOException { assertThat(resp.getCandidatesList()).isNotEmpty(); } GenerateContentResponse response = chat.sendMessage(secondMessage); - List history = chat.getHistory(); + ImmutableList history = chat.getHistory(); // Assert - // GenAI output is flaky so we always print out the response. - // For the same reason, we don't do assertions much. - logger.info(String.format("The whole history is:\n%s", history)); - assertThat(history.size()).isEqualTo(4); - assertThat(history.get(0)).isEqualTo(expectedFirstContent); - assertThat(history.get(1).getRole()).isEqualTo("model"); - assertThat(history.get(2)).isEqualTo(expectedThirdContent); - assertThat(history.get(3).getRole()).isEqualTo("model"); + assertSizeAndAlternatingRolesInHistory( + Thread.currentThread().getStackTrace()[1].getMethodName(), + history, + 4, + ImmutableList.of(expectedFirstContent, expectedThirdContent)); } @Test @@ -98,8 +118,8 @@ public void sendMessageWithNewConfigs_historyContainsFullConversation() throws I Content expectedFirstContent = ContentMaker.fromString(firstMessage); Content expectedThirdContent = ContentMaker.fromString(secondMessage); GenerationConfig config = GenerationConfig.newBuilder().setTemperature(0.7F).build(); - List safetySettings = - Arrays.asList( + ImmutableList safetySettings = + ImmutableList.of( SafetySetting.newBuilder() .setCategory(HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT) .setThreshold(SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH) @@ -122,14 +142,79 @@ public void sendMessageWithNewConfigs_historyContainsFullConversation() throws I .sendMessage(secondMessage); // Assert - List history = chat.getHistory(); - // GenAI output is flaky so we always print out the response. - // For the same reason, we don't do assertions much. - logger.info(String.format("The whole history is:\n%s", history)); - assertThat(history.size()).isEqualTo(4); - assertThat(history.get(0)).isEqualTo(expectedFirstContent); - assertThat(history.get(1).getRole()).isEqualTo("model"); - assertThat(history.get(2)).isEqualTo(expectedThirdContent); - assertThat(history.get(3).getRole()).isEqualTo("model"); + ImmutableList history = chat.getHistory(); + assertSizeAndAlternatingRolesInHistory( + Thread.currentThread().getStackTrace()[1].getMethodName(), + history, + 4, + ImmutableList.of(expectedFirstContent, expectedThirdContent)); + } + + @Test + public void sendMessageWithFunctionCalling_functionCallInResponse() throws IOException { + // Arrange + String firstMessage = "hello!"; + String secondMessage = "What is the weather in Boston?"; + // Making an Json object representing a function declaration + // The following code makes a function declaration + // { + // "name": "getCurrentWeather", + // "description": "Get the current weather in a given location", + // "parameters": { + // "type": "OBJECT", + // "properties": { + // "location": { + // "type": "STRING", + // "description": "location" + // } + // } + // } + // } + JsonObject locationJsonObject = new JsonObject(); + locationJsonObject.addProperty("type", "STRING"); + locationJsonObject.addProperty("description", "location"); + + JsonObject propertiesJsonObject = new JsonObject(); + propertiesJsonObject.add("location", locationJsonObject); + + JsonObject parametersJsonObject = new JsonObject(); + parametersJsonObject.addProperty("type", "OBJECT"); + parametersJsonObject.add("properties", propertiesJsonObject); + + JsonObject jsonObject = new JsonObject(); + jsonObject.addProperty("name", "getCurrentWeather"); + jsonObject.addProperty("description", "Get the current weather in a given location"); + jsonObject.add("parameters", parametersJsonObject); + Tool tool = + Tool.newBuilder() + .addFunctionDeclarations(FunctionDeclarationMaker.fromJsonObject(jsonObject)) + .build(); + ImmutableList tools = ImmutableList.of(tool); + + Content functionResponse = + ContentMaker.fromMultiModalData( + PartMaker.fromFunctionResponse( + "getCurrentWeather", Collections.singletonMap("currentWeather", "snowing"))); + + // Act + chat = model.startChat(); + GenerateContentResponse firstResponse = chat.sendMessage(firstMessage); + GenerateContentResponse secondResponse = chat.withTools(tools).sendMessage(secondMessage); + GenerateContentResponse thirdResponse = chat.sendMessage(functionResponse); + ImmutableList history = chat.getHistory(); + + // Assert + assertThat(firstResponse.getCandidatesList()).hasSize(1); + assertThat(secondResponse.getCandidatesList()).hasSize(1); + assertThat(ResponseHandler.getFunctionCalls(secondResponse).size()).isEqualTo(1); + assertThat(thirdResponse.getCandidatesList()).hasSize(1); + assertSizeAndAlternatingRolesInHistory( + Thread.currentThread().getStackTrace()[1].getMethodName(), + history, + 6, + ImmutableList.of( + ContentMaker.fromString(firstMessage), + ContentMaker.fromString(secondMessage), + functionResponse)); } } diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/it/ITGenerativeModelIntegrationTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/it/ITGenerativeModelIntegrationTest.java index eb1cc42ff03d..39a65f92b1c4 100644 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/it/ITGenerativeModelIntegrationTest.java +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/it/ITGenerativeModelIntegrationTest.java @@ -18,6 +18,7 @@ import static com.google.common.truth.Truth.assertThat; import com.google.api.core.ApiFuture; +import com.google.cloud.vertexai.Transport; import com.google.cloud.vertexai.VertexAI; import com.google.cloud.vertexai.api.Content; import com.google.cloud.vertexai.api.CountTokensResponse; @@ -41,7 +42,9 @@ import org.junit.After; import org.junit.Before; import org.junit.Ignore; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.TestName; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -70,6 +73,8 @@ public class ITGenerativeModelIntegrationTest { private GenerativeModel multiModalModel; private GenerativeModel latestGemini; + @Rule public TestName name = new TestName(); + @Before public void setUp() throws IOException { vertexAi = new VertexAI(PROJECT_ID, LOCATION); @@ -113,17 +118,30 @@ private static void assertNonEmptyAndLogTextContentOfResponseStream( } } + @Test + public void generateContent_restTransport_nonEmptyCandidateList() throws IOException { + try (VertexAI vertexAiViaRest = + new VertexAI.Builder() + .setProjectId(PROJECT_ID) + .setLocation(LOCATION) + .setTransport(Transport.REST) + .build()) { + GenerativeModel textModelWithRest = new GenerativeModel(MODEL_NAME_TEXT, vertexAiViaRest); + GenerateContentResponse response = textModelWithRest.generateContent(TEXT); + + assertNonEmptyAndLogResponse(name.getMethodName(), TEXT, response); + } + } + @Test public void generateContent_withPlainText_nonEmptyCandidateList() throws IOException { GenerateContentResponse response = textModel.generateContent(TEXT); - String methodName = Thread.currentThread().getStackTrace()[1].getMethodName(); - assertNonEmptyAndLogResponse(methodName, TEXT, response); + assertNonEmptyAndLogResponse(name.getMethodName(), TEXT, response); } @Test public void generateContent_withCompleteConfig_nonEmptyCandidateList() throws IOException { - logger.info(String.format("Generating response for question: %s", TEXT)); Integer maxOutputTokens = 50; GenerationConfig generationConfig = GenerationConfig.newBuilder() @@ -147,8 +165,7 @@ public void generateContent_withCompleteConfig_nonEmptyCandidateList() throws IO String contentText = ResponseHandler.getText(response); int numWords = contentText.split("\\s+").length; - String methodName = Thread.currentThread().getStackTrace()[1].getMethodName(); - assertNonEmptyAndLogResponse(methodName, TEXT, response); + assertNonEmptyAndLogResponse(name.getMethodName(), TEXT, response); // We avoid calling the countTokens service and just assert that the number of words should be // less than the maxOutputTokens since each word on average results in more than one tokens. assertThat(numWords).isAtMost(maxOutputTokens); @@ -159,8 +176,7 @@ public void generateContentAsync_withPlainText_nonEmptyCandidateList() throws Ex ApiFuture responseFuture = textModel.generateContentAsync(TEXT); GenerateContentResponse response = responseFuture.get(); - String methodName = Thread.currentThread().getStackTrace()[1].getMethodName(); - assertNonEmptyAndLogResponse(methodName, TEXT, response); + assertNonEmptyAndLogResponse(name.getMethodName(), TEXT, response); } @Test @@ -182,20 +198,18 @@ public void generateContent_withContentList_nonEmptyCandidate() throws IOExcepti // can it. Same for `fromMultiModalList` ContentMaker.fromString(followupPrompt))); - String methodName = Thread.currentThread().getStackTrace()[1].getMethodName(); - assertNonEmptyAndLogResponse(methodName, followupPrompt, response); + assertNonEmptyAndLogResponse(name.getMethodName(), followupPrompt, response); } @Test public void generateContentStream_withPlainText_nonEmptyCandidateList() throws IOException { ResponseStream stream = textModel.generateContentStream(TEXT); - String methodName = Thread.currentThread().getStackTrace()[1].getMethodName(); - assertNonEmptyAndLogTextContentOfResponseStream(methodName, TEXT, stream); + assertNonEmptyAndLogTextContentOfResponseStream(name.getMethodName(), TEXT, stream); } // TODO(b/333866041): Re-enable byteImage test - @Ignore("The test is not compatible with GraalVM native image test on GitHub.") + @Ignore("TODO(b/333866041):The test is not compatible with GraalVM native image test on GitHub.") @Test public void generateContentStream_withByteImage_nonEmptyCandidateList() throws Exception { ResponseStream stream = @@ -204,8 +218,7 @@ public void generateContentStream_withByteImage_nonEmptyCandidateList() throws E IMAGE_INQUIRY, PartMaker.fromMimeTypeAndData("image/jpeg", imageToBytes(new URL(IMAGE_URL))))); - String methodName = Thread.currentThread().getStackTrace()[1].getMethodName(); - assertNonEmptyAndLogTextContentOfResponseStream(methodName, IMAGE_INQUIRY, stream); + assertNonEmptyAndLogTextContentOfResponseStream(name.getMethodName(), IMAGE_INQUIRY, stream); } @Test @@ -215,8 +228,7 @@ public void generateContentStream_withGcsVideo_nonEmptyCandidateList() throws Ex ResponseStream stream = multiModalModel.generateContentStream(content); - String methodName = Thread.currentThread().getStackTrace()[1].getMethodName(); - assertNonEmptyAndLogTextContentOfResponseStream(methodName, VIDEO_INQUIRY, stream); + assertNonEmptyAndLogTextContentOfResponseStream(name.getMethodName(), VIDEO_INQUIRY, stream); } @Test @@ -226,8 +238,7 @@ public void generateContentStream_withGcsImage_nonEmptyCandidateList() throws Ex ResponseStream stream = multiModalModel.generateContentStream(content); - String methodName = Thread.currentThread().getStackTrace()[1].getMethodName(); - assertNonEmptyAndLogTextContentOfResponseStream(methodName, IMAGE_INQUIRY, stream); + assertNonEmptyAndLogTextContentOfResponseStream(name.getMethodName(), IMAGE_INQUIRY, stream); } @Test