From 114327d08e551e0420a35638b587a108a784c0a7 Mon Sep 17 00:00:00 2001 From: Jaycee Li Date: Wed, 29 May 2024 11:18:59 -0700 Subject: [PATCH] feat: [vertexai] support number types in PartMaker.fromFunctionResponse PiperOrigin-RevId: 638351101 --- .../vertexai/generativeai/PartMaker.java | 7 +++-- .../vertexai/generativeai/PartMakerTest.java | 30 ++++++++++++++++--- 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/PartMaker.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/PartMaker.java index eb5ca4a49e56..4e045f70f1a7 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/PartMaker.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/PartMaker.java @@ -114,8 +114,9 @@ public static Part fromFunctionResponse(String name, Map respons if (value instanceof String) { String stringValue = (String) value; structBuilder.putFields(key, Value.newBuilder().setStringValue(stringValue).build()); - } else if (value instanceof Double) { - Double doubleValue = (Double) value; + } else if (value instanceof Number) { + // Convert a number to a double value since the proto only supports double. + double doubleValue = ((Number) value).doubleValue(); structBuilder.putFields(key, Value.newBuilder().setNumberValue(doubleValue).build()); } else if (value instanceof Boolean) { Boolean boolValue = (Boolean) value; @@ -126,7 +127,7 @@ public static Part fromFunctionResponse(String name, Map respons } else { throw new IllegalArgumentException( "The value in the map can only be one of the following format: " - + "String, Double, Boolean, null."); + + "String, Number, Boolean, null."); } }); diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/PartMakerTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/PartMakerTest.java index 991f8bbc733f..206a49287eed 100644 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/PartMakerTest.java +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/PartMakerTest.java @@ -79,7 +79,7 @@ public void fromMimeTypeAndData_dataInURI() throws URISyntaxException { } @Test - public void testFromFunctionResponseWithStruct() { + public void testFromFunctionResponseWithStruct_containsRightFields() { String functionName = "getCurrentWeather"; Struct functionResponse = Struct.newBuilder() @@ -95,7 +95,7 @@ public void testFromFunctionResponseWithStruct() { } @Test - public void testFromFunctionResponseWithMap() { + public void testFromFunctionResponseWithMap_containsRightFields() { String functionName = "getCurrentWeather"; Map functionResponse = new HashMap<>(); functionResponse.put("currentWeather", "Super nice!"); @@ -115,7 +115,29 @@ public void testFromFunctionResponseWithMap() { } @Test - public void testFromFunctionResponseWithInvalidMap() { + public void testFromFunctionResponseWithNumberValues_containsRightFields() { + String functionName = "getCurrentWeather"; + Map functionResponse = new HashMap<>(); + functionResponse.put("integerNumber", 85); + functionResponse.put("doubleNumber", 85.0); + functionResponse.put("floatNumber", 85.0f); + functionResponse.put("longNumber", 85L); + functionResponse.put("shortNumber", (short) 85); + + Part part = PartMaker.fromFunctionResponse(functionName, functionResponse); + + assertThat(part.getFunctionResponse().getName()).isEqualTo("getCurrentWeather"); + + Map fieldsMap = part.getFunctionResponse().getResponse().getFieldsMap(); + assertThat(fieldsMap.get("integerNumber").getNumberValue()).isEqualTo(85.0); + assertThat(fieldsMap.get("doubleNumber").getNumberValue()).isEqualTo(85.0); + assertThat(fieldsMap.get("floatNumber").getNumberValue()).isEqualTo(85.0); + assertThat(fieldsMap.get("longNumber").getNumberValue()).isEqualTo(85.0); + assertThat(fieldsMap.get("shortNumber").getNumberValue()).isEqualTo(85.0); + } + + @Test + public void testFromFunctionResponseWithInvalidMap_throwsIllegalArgumentException() { String functionName = "getCurrentWeather"; Map invalidResponse = new HashMap<>(); invalidResponse.put("currentWeather", new byte[] {1, 2, 3}); @@ -127,6 +149,6 @@ public void testFromFunctionResponseWithInvalidMap() { .hasMessageThat() .isEqualTo( "The value in the map can only be one of the following format: " - + "String, Double, Boolean, null."); + + "String, Number, Boolean, null."); } }