From 3d62afb770009faaa7586b9e47a85e62f5b280e1 Mon Sep 17 00:00:00 2001 From: Jaycee Li Date: Wed, 29 May 2024 13:24:09 -0700 Subject: [PATCH] fix: [vertexai] remove last content in the chat history if API call fails PiperOrigin-RevId: 638391924 --- .../vertexai/generativeai/ChatSession.java | 23 ++++++++++++-- .../vertexai/generativeai/PartMaker.java | 7 +++-- .../vertexai/generativeai/PartMakerTest.java | 30 ++++++++++++++++--- 3 files changed, 50 insertions(+), 10 deletions(-) diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ChatSession.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ChatSession.java index 3f1d1b1f9ba5..b72bf6eedac9 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ChatSession.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ChatSession.java @@ -129,10 +129,18 @@ public ResponseStream sendMessageStream(String text) th * stream by stream() method. */ public ResponseStream sendMessageStream(Content content) - throws IOException, IllegalArgumentException { + throws IOException { checkLastResponseAndEditHistory(); history.add(content); - ResponseStream respStream = model.generateContentStream(history); + + ResponseStream respStream; + try { + respStream = model.generateContentStream(history); + } catch (IOException e) { + // If the API call fails, remove the last content from the history before throwing. + removeLastContent(); + throw e; + } setCurrentResponseStream(Optional.of(respStream)); return respStream; @@ -157,8 +165,17 @@ public GenerateContentResponse sendMessage(String text) throws IOException { public GenerateContentResponse sendMessage(Content content) throws IOException { checkLastResponseAndEditHistory(); history.add(content); - GenerateContentResponse response = model.generateContent(history); + + GenerateContentResponse response; + try { + response = model.generateContent(history); + } catch (IOException e) { + // If the API call fails, remove the last content from the history before throwing. + removeLastContent(); + throw e; + } setCurrentResponse(Optional.of(response)); + return response; } diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/PartMaker.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/PartMaker.java index eb5ca4a49e56..4e045f70f1a7 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/PartMaker.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/PartMaker.java @@ -114,8 +114,9 @@ public static Part fromFunctionResponse(String name, Map respons if (value instanceof String) { String stringValue = (String) value; structBuilder.putFields(key, Value.newBuilder().setStringValue(stringValue).build()); - } else if (value instanceof Double) { - Double doubleValue = (Double) value; + } else if (value instanceof Number) { + // Convert a number to a double value since the proto only supports double. + double doubleValue = ((Number) value).doubleValue(); structBuilder.putFields(key, Value.newBuilder().setNumberValue(doubleValue).build()); } else if (value instanceof Boolean) { Boolean boolValue = (Boolean) value; @@ -126,7 +127,7 @@ public static Part fromFunctionResponse(String name, Map respons } else { throw new IllegalArgumentException( "The value in the map can only be one of the following format: " - + "String, Double, Boolean, null."); + + "String, Number, Boolean, null."); } }); diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/PartMakerTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/PartMakerTest.java index 991f8bbc733f..206a49287eed 100644 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/PartMakerTest.java +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/PartMakerTest.java @@ -79,7 +79,7 @@ public void fromMimeTypeAndData_dataInURI() throws URISyntaxException { } @Test - public void testFromFunctionResponseWithStruct() { + public void testFromFunctionResponseWithStruct_containsRightFields() { String functionName = "getCurrentWeather"; Struct functionResponse = Struct.newBuilder() @@ -95,7 +95,7 @@ public void testFromFunctionResponseWithStruct() { } @Test - public void testFromFunctionResponseWithMap() { + public void testFromFunctionResponseWithMap_containsRightFields() { String functionName = "getCurrentWeather"; Map functionResponse = new HashMap<>(); functionResponse.put("currentWeather", "Super nice!"); @@ -115,7 +115,29 @@ public void testFromFunctionResponseWithMap() { } @Test - public void testFromFunctionResponseWithInvalidMap() { + public void testFromFunctionResponseWithNumberValues_containsRightFields() { + String functionName = "getCurrentWeather"; + Map functionResponse = new HashMap<>(); + functionResponse.put("integerNumber", 85); + functionResponse.put("doubleNumber", 85.0); + functionResponse.put("floatNumber", 85.0f); + functionResponse.put("longNumber", 85L); + functionResponse.put("shortNumber", (short) 85); + + Part part = PartMaker.fromFunctionResponse(functionName, functionResponse); + + assertThat(part.getFunctionResponse().getName()).isEqualTo("getCurrentWeather"); + + Map fieldsMap = part.getFunctionResponse().getResponse().getFieldsMap(); + assertThat(fieldsMap.get("integerNumber").getNumberValue()).isEqualTo(85.0); + assertThat(fieldsMap.get("doubleNumber").getNumberValue()).isEqualTo(85.0); + assertThat(fieldsMap.get("floatNumber").getNumberValue()).isEqualTo(85.0); + assertThat(fieldsMap.get("longNumber").getNumberValue()).isEqualTo(85.0); + assertThat(fieldsMap.get("shortNumber").getNumberValue()).isEqualTo(85.0); + } + + @Test + public void testFromFunctionResponseWithInvalidMap_throwsIllegalArgumentException() { String functionName = "getCurrentWeather"; Map invalidResponse = new HashMap<>(); invalidResponse.put("currentWeather", new byte[] {1, 2, 3}); @@ -127,6 +149,6 @@ public void testFromFunctionResponseWithInvalidMap() { .hasMessageThat() .isEqualTo( "The value in the map can only be one of the following format: " - + "String, Double, Boolean, null."); + + "String, Number, Boolean, null."); } }