Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: [vertexai] adding system instruction support #10775

Merged
merged 5 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;

/** This class holds a generative model that can complete what you provided. */
public final class GenerativeModel {
Expand All @@ -45,6 +46,7 @@ public final class GenerativeModel {
private final GenerationConfig generationConfig;
private final ImmutableList<SafetySetting> safetySettings;
private final ImmutableList<Tool> tools;
private final Optional<Content> systemInstruction;

/**
* Constructs a GenerativeModel instance.
Expand All @@ -53,7 +55,7 @@ public final class GenerativeModel {
* "models/gemini-pro", "publishers/google/models/gemini-pro", where "gemini-pro" is the model
* name. Valid model names can be found at
* https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models#gemini-models
* @param vertexAI a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs
* @param vertexAi a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs
* for the generative model
*/
public GenerativeModel(String modelName, VertexAI vertexAi) {
Expand All @@ -62,6 +64,7 @@ public GenerativeModel(String modelName, VertexAI vertexAi) {
GenerationConfig.getDefaultInstance(),
ImmutableList.of(),
ImmutableList.of(),
Optional.empty(),
vertexAi);
}

Expand All @@ -76,14 +79,15 @@ public GenerativeModel(String modelName, VertexAI vertexAi) {
* that will be used by default for generating response
* @param tools a list of {@link com.google.cloud.vertexai.api.Tool} instances that can be used by
* the model as auxiliary tools to generate content.
* @param vertexAI a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs
* @param vertexAi a {@link com.google.cloud.vertexai.VertexAI} that contains the default configs
* for the generative model
*/
private GenerativeModel(
String modelName,
GenerationConfig generationConfig,
ImmutableList<SafetySetting> safetySettings,
ImmutableList<Tool> tools,
Optional<Content> systemInstruction,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a setSystemInstruction in the GenerativeModel.Builder class? Also a getter method in the GenerativeModel

VertexAI vertexAi) {
checkArgument(
!Strings.isNullOrEmpty(modelName),
Expand All @@ -105,6 +109,7 @@ private GenerativeModel(
this.generationConfig = generationConfig;
this.safetySettings = safetySettings;
this.tools = tools;
this.systemInstruction = systemInstruction;
}

/** Builder class for {@link GenerativeModel}. */
Expand All @@ -114,20 +119,22 @@ public static class Builder {
private GenerationConfig generationConfig = GenerationConfig.getDefaultInstance();
private ImmutableList<SafetySetting> safetySettings = ImmutableList.of();
private ImmutableList<Tool> tools = ImmutableList.of();
private Optional<Content> systemInstructions = Optional.empty();

public GenerativeModel build() {
checkArgument(
!Strings.isNullOrEmpty(modelName),
"modelName is required. Please call setModelName() before building.");
checkNotNull(vertexAi, "vertexAi is required. Please call setVertexAi() before building.");
return new GenerativeModel(modelName, generationConfig, safetySettings, tools, vertexAi);
return new GenerativeModel(
modelName, generationConfig, safetySettings, tools, systemInstructions, vertexAi);
}

/**
* Sets the name of the generative model. This is required for building a GenerativeModel
* instance. Supported format: "gemini-pro", "models/gemini-pro",
* "publishers/google/models/gemini-pro", where "gemini-pro" is the model name. Valid model
* names can be found at
* names can be found in the Gemini models documentation
* https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models#gemini-models
*/
@CanIgnoreReturnValue
Expand Down Expand Up @@ -197,7 +204,13 @@ public Builder setTools(List<Tool> tools) {
* @return a new {@link GenerativeModel} instance with the specified GenerationConfig.
*/
public GenerativeModel withGenerationConfig(GenerationConfig generationConfig) {
return new GenerativeModel(modelName, generationConfig, safetySettings, tools, vertexAi);
return new GenerativeModel(
modelName,
generationConfig,
ImmutableList.copyOf(safetySettings),
ImmutableList.copyOf(tools),
systemInstruction,
vertexAi);
}

/**
Expand All @@ -209,19 +222,46 @@ public GenerativeModel withGenerationConfig(GenerationConfig generationConfig) {
*/
public GenerativeModel withSafetySettings(List<SafetySetting> safetySettings) {
return new GenerativeModel(
modelName, generationConfig, ImmutableList.copyOf(safetySettings), tools, vertexAi);
modelName,
generationConfig,
ImmutableList.copyOf(safetySettings),
ImmutableList.copyOf(tools),
systemInstruction,
vertexAi);
}

/**
* Creates a copy of the current model with updated tools.
*
* @param safetySettings a list of {@link com.google.cloud.vertexai.api.Tool} that will be used in
* the new model.
* @param tools a list of {@link com.google.cloud.vertexai.api.Tool} that will be used in the new
* model.
* @return a new {@link GenerativeModel} instance with the specified tools.
*/
public GenerativeModel withTools(List<Tool> tools) {
return new GenerativeModel(
modelName, generationConfig, safetySettings, ImmutableList.copyOf(tools), vertexAi);
modelName,
generationConfig,
ImmutableList.copyOf(safetySettings),
ImmutableList.copyOf(tools),
systemInstruction,
vertexAi);
}

/**
* Creates a copy of the current model with updated system instructions.
*
* @param systemInstructions a {@link com.google.cloud.vertexai.api.Content} containing system
* instructions.
* @return a new {@link GenerativeModel} instance with the specified tools.
*/
public GenerativeModel withSystemInstructions(Content systemInstructions) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we replace all SystemInstructions with SystemInstruction?

return new GenerativeModel(
modelName,
generationConfig,
ImmutableList.copyOf(safetySettings),
ImmutableList.copyOf(tools),
Optional.of(systemInstructions),
vertexAi);
}

/**
Expand Down Expand Up @@ -453,13 +493,20 @@ private ApiFuture<GenerateContentResponse> generateContentAsync(GenerateContentR
*/
private GenerateContentRequest buildGenerateContentRequest(List<Content> contents) {
checkArgument(contents != null && !contents.isEmpty(), "contents can't be null or empty.");
return GenerateContentRequest.newBuilder()
.setModel(resourceName)
.addAllContents(contents)
.setGenerationConfig(generationConfig)
.addAllSafetySettings(safetySettings)
.addAllTools(tools)
.build();

GenerateContentRequest.Builder requestBuilder =
GenerateContentRequest.newBuilder()
.setModel(resourceName)
.addAllContents(contents)
.setGenerationConfig(generationConfig)
.addAllSafetySettings(safetySettings)
.addAllTools(tools);

if (systemInstruction.isPresent()) {
requestBuilder.setSystemInstruction(systemInstruction.get());
}

return requestBuilder.build();
}

/** Returns the model name of this generative model. */
Expand All @@ -475,8 +522,7 @@ public GenerationConfig getGenerationConfig() {
}

/**
* Returns a list of {@link com.google.cloud.vertexai.api.SafetySettings} of this generative
* model.
* Returns a list of {@link com.google.cloud.vertexai.api.SafetySetting} of this generative model.
*/
public ImmutableList<SafetySetting> getSafetySettings() {
return safetySettings;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,28 @@ public void testGenerateContentwithContents() throws Exception {
assertThat(request.getValue().getContents(0).getParts(0).getText()).isEqualTo(TEXT);
}

@Test
public void testGenerateContentwithSystemInstructions() throws Exception {
String systemInstructionText =
"You're a helpful assistant that starts all its answers with: \"COOL\"";
Content systemInstructions = ContentMaker.fromString(systemInstructionText);

model = new GenerativeModel(MODEL_NAME, vertexAi).withSystemInstructions(systemInstructions);

when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
when(mockUnaryCallable.call(any(GenerateContentRequest.class)))
.thenReturn(mockGenerateContentResponse);

Content content = ContentMaker.fromString(TEXT);
GenerateContentResponse unused = model.generateContent(Arrays.asList(content));

ArgumentCaptor<GenerateContentRequest> request =
ArgumentCaptor.forClass(GenerateContentRequest.class);
verify(mockUnaryCallable).call(request.capture());
assertThat(request.getValue().getSystemInstruction().getParts(0).getText())
.isEqualTo(systemInstructionText);
}

@Test
public void testGenerateContentwithDefaultGenerationConfig() throws Exception {
model =
Expand Down
Loading