Skip to content

Commit

Permalink
chore: [vertexai] Add integration test for model with generationConfi…
Browse files Browse the repository at this point in the history
…g and safetySettings. (#10668)

PiperOrigin-RevId: 622900461

Co-authored-by: Zhenyi Qi <[email protected]>
  • Loading branch information
copybara-service[bot] and Zhenyi Qi authored Apr 11, 2024
1 parent a11077f commit 4cbd770
Showing 1 changed file with 38 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
import com.google.cloud.vertexai.api.Content;
import com.google.cloud.vertexai.api.CountTokensResponse;
import com.google.cloud.vertexai.api.GenerateContentResponse;
import com.google.cloud.vertexai.api.GenerationConfig;
import com.google.cloud.vertexai.api.HarmCategory;
import com.google.cloud.vertexai.api.Part;
import com.google.cloud.vertexai.api.SafetySetting;
import com.google.cloud.vertexai.generativeai.ContentMaker;
import com.google.cloud.vertexai.generativeai.GenerativeModel;
import com.google.cloud.vertexai.generativeai.PartMaker;
Expand All @@ -31,6 +34,7 @@
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.net.URL;
import java.util.Arrays;
import java.util.logging.Logger;
import javax.imageio.ImageIO;
import org.junit.After;
Expand Down Expand Up @@ -113,6 +117,40 @@ public void generateContent_withPlainText_respondWithNonEmptyCandidateList() thr
assertNonEmptyAndLogResponse(methodName, TEXT, response);
}

@Test
public void generateContent_withCompleteConfig_respondWithNonEmptyCandidateList()
throws IOException {
logger.info(String.format("Generating response for question: %s", TEXT));
Integer maxOutputTokens = 50;
GenerationConfig generationConfig =
GenerationConfig.newBuilder()
.setMaxOutputTokens(maxOutputTokens)
.setTemperature(0)
.setTopP(0.3f)
.setTopK(2)
.addStopSequences("<end_of_sentence>")
.build();
SafetySetting safetySetting =
SafetySetting.newBuilder()
.setCategory(HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT)
.setThreshold(SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE)
.build();
GenerativeModel newModel =
textModel
.withGenerationConfig(generationConfig)
.withSafetySettings(Arrays.asList(safetySetting));

GenerateContentResponse response = newModel.generateContent(TEXT);
String contentText = ResponseHandler.getText(response);
int numWords = contentText.split("\\s+").length;

String methodName = Thread.currentThread().getStackTrace()[1].getMethodName();
assertNonEmptyAndLogResponse(methodName, TEXT, response);
// We avoid calling the countTokens service and just assert that the number of words should be
// less than the maxOutputTokens since each word on average results in more than one tokens.
assertThat(numWords).isAtMost(maxOutputTokens);
}

@Test
public void generateContentStream_withPlainText_respondWithNonEmptyCandidateList()
throws IOException {
Expand Down

0 comments on commit 4cbd770

Please sign in to comment.