Skip to content

Commit

Permalink
feat: [vertexai] support number types in PartMaker.fromFunctionResponse
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 638351101
  • Loading branch information
jaycee-li authored and copybara-github committed May 29, 2024
1 parent d14d1f7 commit 114327d
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,9 @@ public static Part fromFunctionResponse(String name, Map<String, Object> respons
if (value instanceof String) {
String stringValue = (String) value;
structBuilder.putFields(key, Value.newBuilder().setStringValue(stringValue).build());
} else if (value instanceof Double) {
Double doubleValue = (Double) value;
} else if (value instanceof Number) {
// Convert a number to a double value since the proto only supports double.
double doubleValue = ((Number) value).doubleValue();
structBuilder.putFields(key, Value.newBuilder().setNumberValue(doubleValue).build());
} else if (value instanceof Boolean) {
Boolean boolValue = (Boolean) value;
Expand All @@ -126,7 +127,7 @@ public static Part fromFunctionResponse(String name, Map<String, Object> respons
} else {
throw new IllegalArgumentException(
"The value in the map can only be one of the following format: "
+ "String, Double, Boolean, null.");
+ "String, Number, Boolean, null.");
}
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ public void fromMimeTypeAndData_dataInURI() throws URISyntaxException {
}

@Test
public void testFromFunctionResponseWithStruct() {
public void testFromFunctionResponseWithStruct_containsRightFields() {
String functionName = "getCurrentWeather";
Struct functionResponse =
Struct.newBuilder()
Expand All @@ -95,7 +95,7 @@ public void testFromFunctionResponseWithStruct() {
}

@Test
public void testFromFunctionResponseWithMap() {
public void testFromFunctionResponseWithMap_containsRightFields() {
String functionName = "getCurrentWeather";
Map<String, Object> functionResponse = new HashMap<>();
functionResponse.put("currentWeather", "Super nice!");
Expand All @@ -115,7 +115,29 @@ public void testFromFunctionResponseWithMap() {
}

@Test
public void testFromFunctionResponseWithInvalidMap() {
public void testFromFunctionResponseWithNumberValues_containsRightFields() {
String functionName = "getCurrentWeather";
Map<String, Object> functionResponse = new HashMap<>();
functionResponse.put("integerNumber", 85);
functionResponse.put("doubleNumber", 85.0);
functionResponse.put("floatNumber", 85.0f);
functionResponse.put("longNumber", 85L);
functionResponse.put("shortNumber", (short) 85);

Part part = PartMaker.fromFunctionResponse(functionName, functionResponse);

assertThat(part.getFunctionResponse().getName()).isEqualTo("getCurrentWeather");

Map<String, Value> fieldsMap = part.getFunctionResponse().getResponse().getFieldsMap();
assertThat(fieldsMap.get("integerNumber").getNumberValue()).isEqualTo(85.0);
assertThat(fieldsMap.get("doubleNumber").getNumberValue()).isEqualTo(85.0);
assertThat(fieldsMap.get("floatNumber").getNumberValue()).isEqualTo(85.0);
assertThat(fieldsMap.get("longNumber").getNumberValue()).isEqualTo(85.0);
assertThat(fieldsMap.get("shortNumber").getNumberValue()).isEqualTo(85.0);
}

@Test
public void testFromFunctionResponseWithInvalidMap_throwsIllegalArgumentException() {
String functionName = "getCurrentWeather";
Map<String, Object> invalidResponse = new HashMap<>();
invalidResponse.put("currentWeather", new byte[] {1, 2, 3});
Expand All @@ -127,6 +149,6 @@ public void testFromFunctionResponseWithInvalidMap() {
.hasMessageThat()
.isEqualTo(
"The value in the map can only be one of the following format: "
+ "String, Double, Boolean, null.");
+ "String, Number, Boolean, null.");
}
}

0 comments on commit 114327d

Please sign in to comment.