Skip to content

Commit

Permalink
fix: [vertexai] remove last content in the chat history if API call f…
Browse files Browse the repository at this point in the history
…ails

PiperOrigin-RevId: 638391924
  • Loading branch information
jaycee-li authored and copybara-github committed May 29, 2024
1 parent 19b2c49 commit 3d62afb
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,18 @@ public ResponseStream<GenerateContentResponse> sendMessageStream(String text) th
* stream by stream() method.
*/
public ResponseStream<GenerateContentResponse> sendMessageStream(Content content)
throws IOException, IllegalArgumentException {
throws IOException {
checkLastResponseAndEditHistory();
history.add(content);
ResponseStream<GenerateContentResponse> respStream = model.generateContentStream(history);

ResponseStream<GenerateContentResponse> respStream;
try {
respStream = model.generateContentStream(history);
} catch (IOException e) {
// If the API call fails, remove the last content from the history before throwing.
removeLastContent();
throw e;
}
setCurrentResponseStream(Optional.of(respStream));

return respStream;
Expand All @@ -157,8 +165,17 @@ public GenerateContentResponse sendMessage(String text) throws IOException {
public GenerateContentResponse sendMessage(Content content) throws IOException {
checkLastResponseAndEditHistory();
history.add(content);
GenerateContentResponse response = model.generateContent(history);

GenerateContentResponse response;
try {
response = model.generateContent(history);
} catch (IOException e) {
// If the API call fails, remove the last content from the history before throwing.
removeLastContent();
throw e;
}
setCurrentResponse(Optional.of(response));

return response;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,9 @@ public static Part fromFunctionResponse(String name, Map<String, Object> respons
if (value instanceof String) {
String stringValue = (String) value;
structBuilder.putFields(key, Value.newBuilder().setStringValue(stringValue).build());
} else if (value instanceof Double) {
Double doubleValue = (Double) value;
} else if (value instanceof Number) {
// Convert a number to a double value since the proto only supports double.
double doubleValue = ((Number) value).doubleValue();
structBuilder.putFields(key, Value.newBuilder().setNumberValue(doubleValue).build());
} else if (value instanceof Boolean) {
Boolean boolValue = (Boolean) value;
Expand All @@ -126,7 +127,7 @@ public static Part fromFunctionResponse(String name, Map<String, Object> respons
} else {
throw new IllegalArgumentException(
"The value in the map can only be one of the following format: "
+ "String, Double, Boolean, null.");
+ "String, Number, Boolean, null.");
}
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public void fromMimeTypeAndData_dataInURI() throws URISyntaxException {
}

@Test
public void testFromFunctionResponseWithStruct() {
public void testFromFunctionResponseWithStruct_containsRightFields() {
String functionName = "getCurrentWeather";
Struct functionResponse =
Struct.newBuilder()
Expand All @@ -95,7 +95,7 @@ public void testFromFunctionResponseWithStruct() {
}

@Test
public void testFromFunctionResponseWithMap() {
public void testFromFunctionResponseWithMap_containsRightFields() {
String functionName = "getCurrentWeather";
Map<String, Object> functionResponse = new HashMap<>();
functionResponse.put("currentWeather", "Super nice!");
Expand All @@ -115,7 +115,29 @@ public void testFromFunctionResponseWithMap() {
}

@Test
public void testFromFunctionResponseWithInvalidMap() {
public void testFromFunctionResponseWithNumberValues_containsRightFields() {
String functionName = "getCurrentWeather";
Map<String, Object> functionResponse = new HashMap<>();
functionResponse.put("integerNumber", 85);
functionResponse.put("doubleNumber", 85.0);
functionResponse.put("floatNumber", 85.0f);
functionResponse.put("longNumber", 85L);
functionResponse.put("shortNumber", (short) 85);

Part part = PartMaker.fromFunctionResponse(functionName, functionResponse);

assertThat(part.getFunctionResponse().getName()).isEqualTo("getCurrentWeather");

Map<String, Value> fieldsMap = part.getFunctionResponse().getResponse().getFieldsMap();
assertThat(fieldsMap.get("integerNumber").getNumberValue()).isEqualTo(85.0);
assertThat(fieldsMap.get("doubleNumber").getNumberValue()).isEqualTo(85.0);
assertThat(fieldsMap.get("floatNumber").getNumberValue()).isEqualTo(85.0);
assertThat(fieldsMap.get("longNumber").getNumberValue()).isEqualTo(85.0);
assertThat(fieldsMap.get("shortNumber").getNumberValue()).isEqualTo(85.0);
}

@Test
public void testFromFunctionResponseWithInvalidMap_throwsIllegalArgumentException() {
String functionName = "getCurrentWeather";
Map<String, Object> invalidResponse = new HashMap<>();
invalidResponse.put("currentWeather", new byte[] {1, 2, 3});
Expand All @@ -127,6 +149,6 @@ public void testFromFunctionResponseWithInvalidMap() {
.hasMessageThat()
.isEqualTo(
"The value in the map can only be one of the following format: "
+ "String, Double, Boolean, null.");
+ "String, Number, Boolean, null.");
}
}

0 comments on commit 3d62afb

Please sign in to comment.