Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: [vertexai]Integration test for function calling in ChatSession. #10700

Merged
merged 1 commit into from
Apr 17, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,18 @@
import com.google.cloud.vertexai.api.GenerationConfig;
import com.google.cloud.vertexai.api.HarmCategory;
import com.google.cloud.vertexai.api.SafetySetting;
import com.google.cloud.vertexai.api.Tool;
import com.google.cloud.vertexai.generativeai.ChatSession;
import com.google.cloud.vertexai.generativeai.ContentMaker;
import com.google.cloud.vertexai.generativeai.FunctionDeclarationMaker;
import com.google.cloud.vertexai.generativeai.GenerativeModel;
import com.google.cloud.vertexai.generativeai.PartMaker;
import com.google.cloud.vertexai.generativeai.ResponseHandler;
import com.google.cloud.vertexai.generativeai.ResponseStream;
import com.google.common.collect.ImmutableList;
import com.google.gson.JsonObject;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Collections;
import java.util.logging.Logger;
import org.junit.After;
import org.junit.Before;
Expand Down Expand Up @@ -60,6 +65,24 @@ public void tearDown() throws IOException {
vertexAi.close();
}

private static void assertSizeAndAlternatingRolesInHistory(
String methodName,
ImmutableList<Content> history,
int expectedSize,
ImmutableList<Content> expectedUserContent) {
// GenAI output is flaky so we always print out the response.
// For the same reason, we don't do assertions much.
logger.info(String.format("%s: The whole history is:\n%s", methodName, history));
assertThat(history.size()).isEqualTo(expectedSize);
for (int i = 1; i < expectedSize; i += 2) {
if (!expectedUserContent.isEmpty()) {
assertThat(history.get(i - 1)).isEqualTo(expectedUserContent.get((i - 1) / 2));
}
assertThat(history.get(i - 1).getRole()).isEqualTo("user");
assertThat(history.get(i).getRole()).isEqualTo("model");
}
}

@Test
public void sendMessageMixedStreamAndUnary_historyOfFour() throws IOException {
// Arrange
Expand All @@ -77,17 +100,14 @@ public void sendMessageMixedStreamAndUnary_historyOfFour() throws IOException {
assertThat(resp.getCandidatesList()).isNotEmpty();
}
GenerateContentResponse response = chat.sendMessage(secondMessage);
List<Content> history = chat.getHistory();
ImmutableList<Content> history = chat.getHistory();

// Assert
// GenAI output is flaky so we always print out the response.
// For the same reason, we don't do assertions much.
logger.info(String.format("The whole history is:\n%s", history));
assertThat(history.size()).isEqualTo(4);
assertThat(history.get(0)).isEqualTo(expectedFirstContent);
assertThat(history.get(1).getRole()).isEqualTo("model");
assertThat(history.get(2)).isEqualTo(expectedThirdContent);
assertThat(history.get(3).getRole()).isEqualTo("model");
assertSizeAndAlternatingRolesInHistory(
Thread.currentThread().getStackTrace()[1].getMethodName(),
history,
4,
ImmutableList.of(expectedFirstContent, expectedThirdContent));
}

@Test
Expand All @@ -98,8 +118,8 @@ public void sendMessageWithNewConfigs_historyContainsFullConversation() throws I
Content expectedFirstContent = ContentMaker.fromString(firstMessage);
Content expectedThirdContent = ContentMaker.fromString(secondMessage);
GenerationConfig config = GenerationConfig.newBuilder().setTemperature(0.7F).build();
List<SafetySetting> safetySettings =
Arrays.asList(
ImmutableList<SafetySetting> safetySettings =
ImmutableList.of(
SafetySetting.newBuilder()
.setCategory(HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT)
.setThreshold(SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH)
Expand All @@ -122,14 +142,79 @@ public void sendMessageWithNewConfigs_historyContainsFullConversation() throws I
.sendMessage(secondMessage);

// Assert
List<Content> history = chat.getHistory();
// GenAI output is flaky so we always print out the response.
// For the same reason, we don't do assertions much.
logger.info(String.format("The whole history is:\n%s", history));
assertThat(history.size()).isEqualTo(4);
assertThat(history.get(0)).isEqualTo(expectedFirstContent);
assertThat(history.get(1).getRole()).isEqualTo("model");
assertThat(history.get(2)).isEqualTo(expectedThirdContent);
assertThat(history.get(3).getRole()).isEqualTo("model");
ImmutableList<Content> history = chat.getHistory();
assertSizeAndAlternatingRolesInHistory(
Thread.currentThread().getStackTrace()[1].getMethodName(),
history,
4,
ImmutableList.of(expectedFirstContent, expectedThirdContent));
}

@Test
public void sendMessageWithFunctionCalling_functionCallInResponse() throws IOException {
// Arrange
String firstMessage = "hello!";
String secondMessage = "What is the weather in Boston?";
// Making an Json object representing a function declaration
// The following code makes a function declaration
// {
// "name": "getCurrentWeather",
// "description": "Get the current weather in a given location",
// "parameters": {
// "type": "OBJECT",
// "properties": {
// "location": {
// "type": "STRING",
// "description": "location"
// }
// }
// }
// }
JsonObject locationJsonObject = new JsonObject();
locationJsonObject.addProperty("type", "STRING");
locationJsonObject.addProperty("description", "location");

JsonObject propertiesJsonObject = new JsonObject();
propertiesJsonObject.add("location", locationJsonObject);

JsonObject parametersJsonObject = new JsonObject();
parametersJsonObject.addProperty("type", "OBJECT");
parametersJsonObject.add("properties", propertiesJsonObject);

JsonObject jsonObject = new JsonObject();
jsonObject.addProperty("name", "getCurrentWeather");
jsonObject.addProperty("description", "Get the current weather in a given location");
jsonObject.add("parameters", parametersJsonObject);
Tool tool =
Tool.newBuilder()
.addFunctionDeclarations(FunctionDeclarationMaker.fromJsonObject(jsonObject))
.build();
ImmutableList<Tool> tools = ImmutableList.of(tool);

Content functionResponse =
ContentMaker.fromMultiModalData(
PartMaker.fromFunctionResponse(
"getCurrentWeather", Collections.singletonMap("currentWeather", "snowing")));

// Act
chat = model.startChat();
GenerateContentResponse firstResponse = chat.sendMessage(firstMessage);
GenerateContentResponse secondResponse = chat.withTools(tools).sendMessage(secondMessage);
GenerateContentResponse thirdResponse = chat.sendMessage(functionResponse);
ImmutableList<Content> history = chat.getHistory();

// Assert
assertThat(firstResponse.getCandidatesList()).hasSize(1);
assertThat(secondResponse.getCandidatesList()).hasSize(1);
assertThat(ResponseHandler.getFunctionCalls(secondResponse).size()).isEqualTo(1);
assertThat(thirdResponse.getCandidatesList()).hasSize(1);
assertSizeAndAlternatingRolesInHistory(
Thread.currentThread().getStackTrace()[1].getMethodName(),
history,
6,
ImmutableList.of(
ContentMaker.fromString(firstMessage),
ContentMaker.fromString(secondMessage),
functionResponse));
}
}
Loading