From c81be82bde3c56a1ccce7824fd0f20d87fef6ff1 Mon Sep 17 00:00:00 2001 From: Qiao Wang Date: Thu, 7 Mar 2024 14:46:34 -0800 Subject: [PATCH] feat: Add getFunctionCalls to ResponseHanlder PiperOrigin-RevId: 613710240 --- .../generativeai/ResponseHandler.java | 51 +++++++++++++------ .../generativeai/ResponseHandlerTest.java | 35 +++++++++++++ 2 files changed, 70 insertions(+), 16 deletions(-) diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ResponseHandler.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ResponseHandler.java index 7f9e5f616da5..44cca9bd4925 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ResponseHandler.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/generativeai/ResponseHandler.java @@ -22,8 +22,10 @@ import com.google.cloud.vertexai.api.Citation; import com.google.cloud.vertexai.api.CitationMetadata; import com.google.cloud.vertexai.api.Content; +import com.google.cloud.vertexai.api.FunctionCall; import com.google.cloud.vertexai.api.GenerateContentResponse; import com.google.cloud.vertexai.api.Part; +import com.google.common.collect.ImmutableList; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -33,7 +35,7 @@ public class ResponseHandler { /** - * Get the text message in a GenerateContentResponse. + * Gets the text message in a GenerateContentResponse. * * @param response a {@link com.google.cloud.vertexai.api.GenerateContentResponse} instance * @return a String that aggregates all the text parts in the response @@ -41,12 +43,7 @@ public class ResponseHandler { * response is blocked by safety reason or unauthorized citations */ public static String getText(GenerateContentResponse response) { - FinishReason finishReason = getFinishReason(response); - if (finishReason == FinishReason.SAFETY) { - throw new IllegalArgumentException("The response is blocked due to safety reason."); - } else if (finishReason == FinishReason.RECITATION) { - throw new IllegalArgumentException("The response is blocked due to unauthorized citations."); - } + checkFinishReason(getFinishReason(response)); String text = ""; List parts = response.getCandidates(0).getContent().getPartsList(); @@ -58,7 +55,26 @@ public static String getText(GenerateContentResponse response) { } /** - * Get the content in a GenerateContentResponse. + * Gets the list of function calls in a GenerateContentResponse. + * + * @param response a {@link com.google.cloud.vertexai.api.GenerateContentResponse} instance + * @return a list of {@link com.google.cloud.vertexai.api.FunctionCall} in the response + * @throws IllegalArgumentException if the response has 0 or more than 1 candidates, or if the + * response is blocked by safety reason or unauthorized citations + */ + public static ImmutableList getFunctionCalls(GenerateContentResponse response) { + checkFinishReason(getFinishReason(response)); + if (response.getCandidatesCount() == 0) { + return ImmutableList.of(); + } + return response.getCandidates(0).getContent().getPartsList().stream() + .filter((part) -> part.hasFunctionCall()) + .map((part) -> part.getFunctionCall()) + .collect(ImmutableList.toImmutableList()); + } + + /** + * Gets the content in a GenerateContentResponse. * * @param response a {@link com.google.cloud.vertexai.api.GenerateContentResponse} instance * @return the {@link com.google.cloud.vertexai.api.Content} in the response @@ -66,18 +82,13 @@ public static String getText(GenerateContentResponse response) { * response is blocked by safety reason or unauthorized citations */ public static Content getContent(GenerateContentResponse response) { - FinishReason finishReason = getFinishReason(response); - if (finishReason == FinishReason.SAFETY) { - throw new IllegalArgumentException("The response is blocked due to safety reason."); - } else if (finishReason == FinishReason.RECITATION) { - throw new IllegalArgumentException("The response is blocked due to unauthorized citations."); - } + checkFinishReason(getFinishReason(response)); return response.getCandidates(0).getContent(); } /** - * Get the finish reason in a GenerateContentResponse. + * Gets the finish reason in a GenerateContentResponse. * * @param response a {@link com.google.cloud.vertexai.api.GenerateContentResponse} instance * @return the {@link com.google.cloud.vertexai.api.FinishReason} in the response @@ -93,7 +104,7 @@ public static FinishReason getFinishReason(GenerateContentResponse response) { return response.getCandidates(0).getFinishReason(); } - /** Aggregate a stream of responses into a single GenerateContentResponse. */ + /** Aggregates a stream of responses into a single GenerateContentResponse. */ static GenerateContentResponse aggregateStreamIntoResponse( ResponseStream responseStream) { GenerateContentResponse res = GenerateContentResponse.getDefaultInstance(); @@ -170,4 +181,12 @@ static GenerateContentResponse aggregateStreamIntoResponse( return res; } + + private static void checkFinishReason(FinishReason finishReason) { + if (finishReason == FinishReason.SAFETY) { + throw new IllegalArgumentException("The response is blocked due to safety reason."); + } else if (finishReason == FinishReason.RECITATION) { + throw new IllegalArgumentException("The response is blocked due to unauthorized citations."); + } + } } diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ResponseHandlerTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ResponseHandlerTest.java index 74c633779565..80487e7355fe 100644 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ResponseHandlerTest.java +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ResponseHandlerTest.java @@ -25,8 +25,10 @@ import com.google.cloud.vertexai.api.Citation; import com.google.cloud.vertexai.api.CitationMetadata; import com.google.cloud.vertexai.api.Content; +import com.google.cloud.vertexai.api.FunctionCall; import com.google.cloud.vertexai.api.GenerateContentResponse; import com.google.cloud.vertexai.api.Part; +import com.google.common.collect.ImmutableList; import java.util.Arrays; import java.util.Iterator; import org.junit.Rule; @@ -47,6 +49,13 @@ public final class ResponseHandlerTest { .addParts(Part.newBuilder().setText(TEXT_1)) .addParts(Part.newBuilder().setText(TEXT_2)) .build(); + private static final Content CONTENT_WITH_FNCTION_CALL = + Content.newBuilder() + .addParts(Part.newBuilder().setText(TEXT_1)) + .addParts(Part.newBuilder().setFunctionCall(FunctionCall.getDefaultInstance())) + .addParts(Part.newBuilder().setText(TEXT_2)) + .addParts(Part.newBuilder().setFunctionCall(FunctionCall.getDefaultInstance())) + .build(); private static final Citation CITATION_1 = Citation.newBuilder().setUri("gs://citation1").setStartIndex(1).setEndIndex(2).build(); private static final Citation CITATION_2 = @@ -61,10 +70,14 @@ public final class ResponseHandlerTest { .setContent(CONTENT) .setCitationMetadata(CitationMetadata.newBuilder().addCitations(CITATION_2)) .build(); + private static final Candidate CANDIDATE_3 = + Candidate.newBuilder().setContent(CONTENT_WITH_FNCTION_CALL).build(); private static final GenerateContentResponse RESPONSE_1 = GenerateContentResponse.newBuilder().addCandidates(CANDIDATE_1).build(); private static final GenerateContentResponse RESPONSE_2 = GenerateContentResponse.newBuilder().addCandidates(CANDIDATE_2).build(); + private static final GenerateContentResponse RESPONSE_3 = + GenerateContentResponse.newBuilder().addCandidates(CANDIDATE_3).build(); private static final GenerateContentResponse INVALID_RESPONSE = GenerateContentResponse.newBuilder() .addCandidates(CANDIDATE_1) @@ -94,6 +107,28 @@ public void testGetTextFromInvalidResponse() { INVALID_RESPONSE.getCandidatesCount())); } + @Test + public void testGetFunctionCallsFromResponse() { + ImmutableList functionCalls = ResponseHandler.getFunctionCalls(RESPONSE_3); + assertThat(functionCalls.size()).isEqualTo(2); + assertThat(functionCalls.get(0)).isEqualTo(FunctionCall.getDefaultInstance()); + assertThat(functionCalls.get(1)).isEqualTo(FunctionCall.getDefaultInstance()); + } + + @Test + public void testGetFunctionCallsFromInvalidResponse() { + IllegalArgumentException thrown = + assertThrows( + IllegalArgumentException.class, + () -> ResponseHandler.getFunctionCalls(INVALID_RESPONSE)); + assertThat(thrown) + .hasMessageThat() + .isEqualTo( + String.format( + "This response should have exactly 1 candidate, but it has %s.", + INVALID_RESPONSE.getCandidatesCount())); + } + @Test public void testGetContentFromResponse() { Content content = ResponseHandler.getContent(RESPONSE_1);