Skip to content

Commit

Permalink
chore: [vertexai]Integration test for function calling in ChatSession.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 625428210
  • Loading branch information
Zhenyi Qi authored and copybara-github committed Apr 16, 2024
1 parent ae22f1c commit 500d24b
Showing 1 changed file with 107 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,18 @@
import com.google.cloud.vertexai.api.GenerationConfig;
import com.google.cloud.vertexai.api.HarmCategory;
import com.google.cloud.vertexai.api.SafetySetting;
import com.google.cloud.vertexai.api.Tool;
import com.google.cloud.vertexai.generativeai.ChatSession;
import com.google.cloud.vertexai.generativeai.ContentMaker;
import com.google.cloud.vertexai.generativeai.FunctionDeclarationMaker;
import com.google.cloud.vertexai.generativeai.GenerativeModel;
import com.google.cloud.vertexai.generativeai.PartMaker;
import com.google.cloud.vertexai.generativeai.ResponseHandler;
import com.google.cloud.vertexai.generativeai.ResponseStream;
import com.google.common.collect.ImmutableList;
import com.google.gson.JsonObject;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Collections;
import java.util.logging.Logger;
import org.junit.After;
import org.junit.Before;
Expand Down Expand Up @@ -60,6 +65,24 @@ public void tearDown() throws IOException {
vertexAi.close();
}

private static void assertSizeAndAlternatingRolesInHistory(
String methodName,
ImmutableList<Content> history,
int expectedSize,
ImmutableList<Content> expectedUserContent) {
// GenAI output is flaky so we always print out the response.
// For the same reason, we don't do assertions much.
logger.info(String.format("%s: The whole history is:\n%s", methodName, history));
assertThat(history.size()).isEqualTo(expectedSize);
for (int i = 1; i < expectedSize; i += 2) {
if (!expectedUserContent.isEmpty()) {
assertThat(history.get(i - 1)).isEqualTo(expectedUserContent.get((i - 1) / 2));
}
assertThat(history.get(i - 1).getRole()).isEqualTo("user");
assertThat(history.get(i).getRole()).isEqualTo("model");
}
}

@Test
public void sendMessageMixedStreamAndUnary_historyOfFour() throws IOException {
// Arrange
Expand All @@ -77,17 +100,14 @@ public void sendMessageMixedStreamAndUnary_historyOfFour() throws IOException {
assertThat(resp.getCandidatesList()).isNotEmpty();
}
GenerateContentResponse response = chat.sendMessage(secondMessage);
List<Content> history = chat.getHistory();
ImmutableList<Content> history = chat.getHistory();

// Assert
// GenAI output is flaky so we always print out the response.
// For the same reason, we don't do assertions much.
logger.info(String.format("The whole history is:\n%s", history));
assertThat(history.size()).isEqualTo(4);
assertThat(history.get(0)).isEqualTo(expectedFirstContent);
assertThat(history.get(1).getRole()).isEqualTo("model");
assertThat(history.get(2)).isEqualTo(expectedThirdContent);
assertThat(history.get(3).getRole()).isEqualTo("model");
assertSizeAndAlternatingRolesInHistory(
Thread.currentThread().getStackTrace()[1].getMethodName(),
history,
4,
ImmutableList.of(expectedFirstContent, expectedThirdContent));
}

@Test
Expand All @@ -98,8 +118,8 @@ public void sendMessageWithNewConfigs_historyContainsFullConversation() throws I
Content expectedFirstContent = ContentMaker.fromString(firstMessage);
Content expectedThirdContent = ContentMaker.fromString(secondMessage);
GenerationConfig config = GenerationConfig.newBuilder().setTemperature(0.7F).build();
List<SafetySetting> safetySettings =
Arrays.asList(
ImmutableList<SafetySetting> safetySettings =
ImmutableList.of(
SafetySetting.newBuilder()
.setCategory(HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT)
.setThreshold(SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH)
Expand All @@ -122,14 +142,79 @@ public void sendMessageWithNewConfigs_historyContainsFullConversation() throws I
.sendMessage(secondMessage);

// Assert
List<Content> history = chat.getHistory();
// GenAI output is flaky so we always print out the response.
// For the same reason, we don't do assertions much.
logger.info(String.format("The whole history is:\n%s", history));
assertThat(history.size()).isEqualTo(4);
assertThat(history.get(0)).isEqualTo(expectedFirstContent);
assertThat(history.get(1).getRole()).isEqualTo("model");
assertThat(history.get(2)).isEqualTo(expectedThirdContent);
assertThat(history.get(3).getRole()).isEqualTo("model");
ImmutableList<Content> history = chat.getHistory();
assertSizeAndAlternatingRolesInHistory(
Thread.currentThread().getStackTrace()[1].getMethodName(),
history,
4,
ImmutableList.of(expectedFirstContent, expectedThirdContent));
}

@Test
public void sendMessageWithFunctionCalling_functionCallInResponse() throws IOException {
// Arrange
String firstMessage = "hello!";
String secondMessage = "What is the weather in Boston?";
// Making an Json object representing a function declaration
// The following code makes a function declaration
// {
// "name": "getCurrentWeather",
// "description": "Get the current weather in a given location",
// "parameters": {
// "type": "OBJECT",
// "properties": {
// "location": {
// "type": "STRING",
// "description": "location"
// }
// }
// }
// }
JsonObject locationJsonObject = new JsonObject();
locationJsonObject.addProperty("type", "STRING");
locationJsonObject.addProperty("description", "location");

JsonObject propertiesJsonObject = new JsonObject();
propertiesJsonObject.add("location", locationJsonObject);

JsonObject parametersJsonObject = new JsonObject();
parametersJsonObject.addProperty("type", "OBJECT");
parametersJsonObject.add("properties", propertiesJsonObject);

JsonObject jsonObject = new JsonObject();
jsonObject.addProperty("name", "getCurrentWeather");
jsonObject.addProperty("description", "Get the current weather in a given location");
jsonObject.add("parameters", parametersJsonObject);
Tool tool =
Tool.newBuilder()
.addFunctionDeclarations(FunctionDeclarationMaker.fromJsonObject(jsonObject))
.build();
ImmutableList<Tool> tools = ImmutableList.of(tool);

Content functionResponse =
ContentMaker.fromMultiModalData(
PartMaker.fromFunctionResponse(
"getCurrentWeather", Collections.singletonMap("currentWeather", "snowing")));

// Act
chat = model.startChat();
GenerateContentResponse firstResponse = chat.sendMessage(firstMessage);
GenerateContentResponse secondResponse = chat.withTools(tools).sendMessage(secondMessage);
GenerateContentResponse thirdResponse = chat.sendMessage(functionResponse);
ImmutableList<Content> history = chat.getHistory();

// Assert
assertThat(firstResponse.getCandidatesList()).hasSize(1);
assertThat(secondResponse.getCandidatesList()).hasSize(1);
assertThat(ResponseHandler.getFunctionCalls(secondResponse).size()).isEqualTo(1);
assertThat(thirdResponse.getCandidatesList()).hasSize(1);
assertSizeAndAlternatingRolesInHistory(
Thread.currentThread().getStackTrace()[1].getMethodName(),
history,
6,
ImmutableList.of(
ContentMaker.fromString(firstMessage),
ContentMaker.fromString(secondMessage),
functionResponse));
}
}

0 comments on commit 500d24b

Please sign in to comment.