Skip to content

Commit

Permalink
fix: [vertexai] check null and empty values for input String (#10658)
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 621886111

Co-authored-by: Jaycee Li <[email protected]>
  • Loading branch information
copybara-service[bot] and jaycee-li authored Apr 8, 2024
1 parent 76d2b2c commit 949889d
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@

package com.google.cloud.vertexai.generativeai;

import static com.google.common.base.Preconditions.checkArgument;

import com.google.cloud.vertexai.api.Content;
import com.google.cloud.vertexai.api.Part;
import com.google.common.base.Strings;

/** Helper class to create content. */
public class ContentMaker {
private static String role = "user";
private static final String DEFAULT_ROLE = "user";

/**
* Creates a ContentMakerForRole for a given role.
Expand All @@ -34,6 +37,7 @@ public static ContentMakerForRole forRole(String role) {
}

private static Content fromStringWithRole(String role, String text) {
checkArgument(!Strings.isNullOrEmpty(text), "text message can't be null or empty.");
return Content.newBuilder().addParts(Part.newBuilder().setText(text)).setRole(role).build();
}

Expand Down Expand Up @@ -61,7 +65,7 @@ private static Content fromMultiModalDataWithRole(String role, Object... multiMo
* <p>To create a text content for "model", use `ContentMaker.forRole("model").fromString(text);
*/
public static Content fromString(String text) {
return fromStringWithRole(role, text);
return fromStringWithRole(DEFAULT_ROLE, text);
}

/**
Expand All @@ -76,8 +80,9 @@ public static Content fromString(String text) {
* could be either a single String or a Part. When it's a single string, it's converted to a
* {@link com.google.cloud.vertexai.api.Part} that has the Text field set.
*/
// TODO(b/333097480) Deprecate ContentMakerForRole
public static Content fromMultiModalData(Object... multiModalData) {
return fromMultiModalDataWithRole(role, multiModalData);
return fromMultiModalDataWithRole(DEFAULT_ROLE, multiModalData);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,6 @@ public GenerativeModel withTools(List<Tool> tools) {
*/
@BetaApi
public CountTokensResponse countTokens(String text) throws IOException {
// TODO(b/330402637): Check null and empty values for the input string.
return countTokens(ContentMaker.fromString(text));
}

Expand All @@ -255,6 +254,7 @@ public CountTokensResponse countTokens(Content content) throws IOException {
*/
@BetaApi
public CountTokensResponse countTokens(List<Content> contents) throws IOException {
checkArgument(contents != null && !contents.isEmpty(), "contents can't be null or empty.");
CountTokensRequest request =
CountTokensRequest.newBuilder()
.setEndpoint(resourceName)
Expand Down Expand Up @@ -287,7 +287,6 @@ private CountTokensResponse countTokensFromRequest(CountTokensRequest request)
* @throws IOException if an I/O error occurs while making the API call
*/
public GenerateContentResponse generateContent(String text) throws IOException {
// TODO(b/330402637): Check null and empty values for the input string.
return generateContent(ContentMaker.fromString(text));
}

Expand Down Expand Up @@ -447,6 +446,7 @@ private ApiFuture<GenerateContentResponse> generateContentAsync(GenerateContentR
* contents and model configurations.
*/
private GenerateContentRequest buildGenerateContentRequest(List<Content> contents) {
checkArgument(contents != null && !contents.isEmpty(), "contents can't be null or empty.");
return GenerateContentRequest.newBuilder()
.setModel(resourceName)
.addAllContents(contents)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.google.cloud.vertexai.generativeai;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;

import com.google.cloud.vertexai.api.Content;
import com.google.protobuf.ByteString;
Expand All @@ -38,6 +39,24 @@ public void fromString_returnsContentWithText() {
assertThat(content.getParts(0).getText()).isEqualTo(stringInput);
}

@Test
public void fromString_throwsIllegalArgumentException_withEmptyString() {
String stringInput = "";

IllegalArgumentException thrown =
assertThrows(IllegalArgumentException.class, () -> ContentMaker.fromString(stringInput));
assertThat(thrown).hasMessageThat().isEqualTo("text message can't be null or empty.");
}

@Test
public void fromString_throwsIllegalArgumentException_withNullString() {
String stringInput = null;

IllegalArgumentException thrown =
assertThrows(IllegalArgumentException.class, () -> ContentMaker.fromString(stringInput));
assertThat(thrown).hasMessageThat().isEqualTo("text message can't be null or empty.");
}

@Test
public void forRole_returnsContentWithArbitraryRoleSet() {
// Although in our docstring, we said only three roles are acceptable, we make sure the code
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import com.google.cloud.vertexai.api.Type;
import com.google.cloud.vertexai.api.VertexAISearch;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
Expand Down Expand Up @@ -452,6 +453,16 @@ public void testGenerateContentwithFluentApi() throws Exception {
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
}

@Test
public void generateContent_withNullContents_throws() throws Exception {
model = new GenerativeModel(MODEL_NAME, vertexAi);
List<Content> contents = null;

IllegalArgumentException thrown =
assertThrows(IllegalArgumentException.class, () -> model.generateContent(contents));
assertThat(thrown).hasMessageThat().isEqualTo("contents can't be null or empty.");
}

@Test
public void testGenerateContentStreamwithText() throws Exception {
model = new GenerativeModel(MODEL_NAME, vertexAi);
Expand Down Expand Up @@ -636,6 +647,16 @@ public void testGenerateContentStreamwithFluentApi() throws Exception {
assertThat(request.getValue().getTools(0)).isEqualTo(TOOL);
}

@Test
public void generateContentStream_withEmptyContents_throws() throws Exception {
model = new GenerativeModel(MODEL_NAME, vertexAi);
List<Content> contents = new ArrayList<>();

IllegalArgumentException thrown =
assertThrows(IllegalArgumentException.class, () -> model.generateContentStream(contents));
assertThat(thrown).hasMessageThat().isEqualTo("contents can't be null or empty.");
}

@Test
public void generateContentAsync_withText_sendsCorrectRequest() throws Exception {
model = new GenerativeModel(MODEL_NAME, vertexAi);
Expand Down

0 comments on commit 949889d

Please sign in to comment.