Skip to content

Commit

Permalink
chore: [vertexai] revert the change in the VertexAI.Builder constructor
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 626417155
  • Loading branch information
jaycee-li authored and copybara-github committed Apr 19, 2024
1 parent 51b2af1 commit 0b801eb
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 146 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@
import com.google.cloud.vertexai.api.PredictionServiceClient;
import com.google.cloud.vertexai.api.PredictionServiceSettings;
import com.google.common.base.Strings;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import java.io.IOException;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Level;
import java.util.logging.Logger;

Expand All @@ -61,13 +63,12 @@ public class VertexAI implements AutoCloseable {
private final String apiEndpoint;
private final Transport transport;
private final CredentialsProvider credentialsProvider;
private final ReentrantLock lock = new ReentrantLock();
// The clients will be instantiated lazily
private Optional<PredictionServiceClient> predictionServiceClient = Optional.empty();
private Optional<LlmUtilityServiceClient> llmUtilityClient = Optional.empty();

private final transient Supplier<PredictionServiceClient> predictionClientSupplier;
private final transient Supplier<LlmUtilityServiceClient> llmClientSupplier;

/**
* Construct a VertexAI instance.
* Constructs a VertexAI instance.
*
* @param projectId the default project to use when making API calls
* @param location the default location to use when making API calls
Expand All @@ -78,8 +79,10 @@ public VertexAI(String projectId, String location) {
location,
Transport.GRPC,
ImmutableList.of(),
Optional.empty(),
Optional.empty());
/* credentials= */ Optional.empty(),
/* apiEndpoint= */ Optional.empty(),
/* predictionClientSupplierOpt= */ Optional.empty(),
/* llmClientSupplierOpt= */ Optional.empty());
}

private VertexAI(
Expand All @@ -88,7 +91,9 @@ private VertexAI(
Transport transport,
List<String> scopes,
Optional<Credentials> credentials,
Optional<String> apiEndpoint) {
Optional<String> apiEndpoint,
Optional<Supplier<PredictionServiceClient>> predictionClientSupplierOpt,
Optional<Supplier<LlmUtilityServiceClient>> llmClientSupplierOpt) {
if (!scopes.isEmpty() && credentials.isPresent()) {
throw new IllegalArgumentException(
"At most one of Credentials and scopes should be specified.");
Expand All @@ -113,6 +118,12 @@ private VertexAI(
.build();
}

this.predictionClientSupplier =
Suppliers.memoize(predictionClientSupplierOpt.orElse(this::newPredictionServiceClient));

this.llmClientSupplier =
Suppliers.memoize(llmClientSupplierOpt.orElse(this::newLlmUtilityClient));

this.apiEndpoint = apiEndpoint.orElse(String.format("%s-aiplatform.googleapis.com", location));
}

Expand All @@ -125,48 +136,79 @@ public static class Builder {
private Optional<Credentials> credentials = Optional.empty();
private Optional<String> apiEndpoint = Optional.empty();

private Supplier<PredictionServiceClient> predictionClientSupplier;

private Supplier<LlmUtilityServiceClient> llmClientSupplier;

public VertexAI build() {
checkNotNull(projectId, "projectId must be set.");
checkNotNull(location, "location must be set.");

return new VertexAI(projectId, location, transport, scopes, credentials, apiEndpoint);
return new VertexAI(
projectId,
location,
transport,
scopes,
credentials,
apiEndpoint,
Optional.ofNullable(predictionClientSupplier),
Optional.ofNullable(llmClientSupplier));
}

@CanIgnoreReturnValue
public Builder setProjectId(String projectId) {
checkArgument(!Strings.isNullOrEmpty(projectId), "projectId can't be null or empty");

this.projectId = projectId;
return this;
}

@CanIgnoreReturnValue
public Builder setLocation(String location) {
checkArgument(!Strings.isNullOrEmpty(location), "location can't be null or empty");

this.location = location;
return this;
}

@CanIgnoreReturnValue
public Builder setApiEndpoint(String apiEndpoint) {
checkArgument(!Strings.isNullOrEmpty(apiEndpoint), "apiEndpoint can't be null or empty");

this.apiEndpoint = Optional.of(apiEndpoint);
return this;
}

@CanIgnoreReturnValue
public Builder setTransport(Transport transport) {
checkNotNull(transport, "transport can't be null");

this.transport = transport;
return this;
}

@CanIgnoreReturnValue
public Builder setCredentials(Credentials credentials) {
checkNotNull(credentials, "credentials can't be null");

this.credentials = Optional.of(credentials);
return this;
}

@CanIgnoreReturnValue
public Builder setPredictionClientSupplier(
Supplier<PredictionServiceClient> predictionClientSupplier) {
this.predictionClientSupplier = predictionClientSupplier;
return this;
}

@CanIgnoreReturnValue
public Builder setLlmClientSupplier(Supplier<LlmUtilityServiceClient> llmClientSupplier) {
this.llmClientSupplier = llmClientSupplier;
return this;
}

@CanIgnoreReturnValue
public Builder setScopes(List<String> scopes) {
checkNotNull(scopes, "scopes can't be null");

Expand Down Expand Up @@ -228,25 +270,23 @@ public Credentials getCredentials() throws IOException {
* method calls that map to the API methods.
*/
@InternalApi
public PredictionServiceClient getPredictionServiceClient() throws IOException {
if (predictionServiceClient.isPresent()) {
return predictionServiceClient.get();
}
lock.lock();
public PredictionServiceClient getPredictionServiceClient() {
return predictionClientSupplier.get();
}

private PredictionServiceClient newPredictionServiceClient() {
// Disable the warning message logged in getApplicationDefault
Logger defaultCredentialsProviderLogger =
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);

try {
if (!predictionServiceClient.isPresent()) {
PredictionServiceSettings settings = getPredictionServiceSettings();
// Disable the warning message logged in getApplicationDefault
Logger defaultCredentialsProviderLogger =
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
predictionServiceClient = Optional.of(PredictionServiceClient.create(settings));
defaultCredentialsProviderLogger.setLevel(previousLevel);
}
return predictionServiceClient.get();
return PredictionServiceClient.create(getPredictionServiceSettings());
} catch (IOException e) {
throw new IllegalStateException(e);
} finally {
lock.unlock();
defaultCredentialsProviderLogger.setLevel(previousLevel);
}
}

Expand All @@ -257,8 +297,8 @@ private PredictionServiceSettings getPredictionServiceSettings() throws IOExcept
} else {
builder = PredictionServiceSettings.newBuilder();
}
builder.setEndpoint(String.format("%s:443", this.apiEndpoint));
builder.setCredentialsProvider(this.credentialsProvider);
builder.setEndpoint(String.format("%s:443", apiEndpoint));
builder.setCredentialsProvider(credentialsProvider);

HeaderProvider headerProvider =
FixedHeaderProvider.create(
Expand All @@ -279,25 +319,23 @@ private PredictionServiceSettings getPredictionServiceSettings() throws IOExcept
* calls that map to the API methods.
*/
@InternalApi
public LlmUtilityServiceClient getLlmUtilityClient() throws IOException {
if (llmUtilityClient.isPresent()) {
return llmUtilityClient.get();
}
lock.lock();
public LlmUtilityServiceClient getLlmUtilityClient() {
return llmClientSupplier.get();
}

private LlmUtilityServiceClient newLlmUtilityClient() {
// Disable the warning message logged in getApplicationDefault
Logger defaultCredentialsProviderLogger =
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);

try {
if (!llmUtilityClient.isPresent()) {
LlmUtilityServiceSettings settings = getLlmUtilityServiceClientSettings();
// Disable the warning message logged in getApplicationDefault
Logger defaultCredentialsProviderLogger =
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
llmUtilityClient = Optional.of(LlmUtilityServiceClient.create(settings));
defaultCredentialsProviderLogger.setLevel(previousLevel);
}
return llmUtilityClient.get();
return LlmUtilityServiceClient.create(getLlmUtilityServiceClientSettings());
} catch (IOException e) {
throw new IllegalStateException(e);
} finally {
lock.unlock();
defaultCredentialsProviderLogger.setLevel(previousLevel);
}
}

Expand All @@ -308,8 +346,8 @@ private LlmUtilityServiceSettings getLlmUtilityServiceClientSettings() throws IO
} else {
settingsBuilder = LlmUtilityServiceSettings.newBuilder();
}
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
settingsBuilder.setEndpoint(String.format("%s:443", apiEndpoint));
settingsBuilder.setCredentialsProvider(credentialsProvider);

HeaderProvider headerProvider =
FixedHeaderProvider.create(
Expand All @@ -325,11 +363,7 @@ private LlmUtilityServiceSettings getLlmUtilityServiceClientSettings() throws IO
/** Closes the VertexAI instance together with all its instantiated clients. */
@Override
public void close() {
if (predictionServiceClient.isPresent()) {
predictionServiceClient.get().close();
}
if (llmUtilityClient.isPresent()) {
llmUtilityClient.get().close();
}
predictionClientSupplier.get().close();
llmClientSupplier.get().close();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import com.google.cloud.vertexai.api.Tool;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
Expand Down Expand Up @@ -129,6 +130,7 @@ public GenerativeModel build() {
* names can be found at
* https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models#gemini-models
*/
@CanIgnoreReturnValue
public Builder setModelName(String modelName) {
checkArgument(
!Strings.isNullOrEmpty(modelName),
Expand All @@ -144,6 +146,7 @@ public Builder setModelName(String modelName) {
* Sets {@link com.google.cloud.vertexai.VertexAI} that contains the default configs for the
* generative model. This is required for building a GenerativeModel instance.
*/
@CanIgnoreReturnValue
public Builder setVertexAi(VertexAI vertexAi) {
checkNotNull(vertexAi, "VertexAI can't be null.");
this.vertexAi = vertexAi;
Expand All @@ -154,6 +157,7 @@ public Builder setVertexAi(VertexAI vertexAi) {
* Sets {@link com.google.cloud.vertexai.api.GenerationConfig} that will be used by default to
* interact with the generative model.
*/
@CanIgnoreReturnValue
public Builder setGenerationConfig(GenerationConfig generationConfig) {
checkNotNull(generationConfig, "GenerationConfig can't be null.");
this.generationConfig = generationConfig;
Expand All @@ -164,6 +168,7 @@ public Builder setGenerationConfig(GenerationConfig generationConfig) {
* Sets a list of {@link com.google.cloud.vertexai.api.SafetySetting} that will be used by
* default to interact with the generative model.
*/
@CanIgnoreReturnValue
public Builder setSafetySettings(List<SafetySetting> safetySettings) {
checkNotNull(
safetySettings,
Expand All @@ -176,6 +181,7 @@ public Builder setSafetySettings(List<SafetySetting> safetySettings) {
* Sets a list of {@link com.google.cloud.vertexai.api.Tool} that will be used by default to
* interact with the generative model.
*/
@CanIgnoreReturnValue
public Builder setTools(List<Tool> tools) {
checkNotNull(tools, "tools can't be null. Use an empty list if no tool is to be used.");
this.tools = ImmutableList.copyOf(tools);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,9 @@
import com.google.cloud.vertexai.api.Tool;
import com.google.cloud.vertexai.api.Type;
import java.io.IOException;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
Expand Down Expand Up @@ -309,12 +307,15 @@ public void sendMessageWithText_throwsIllegalStateExceptionWhenFinishReasonIsNot
public void testChatSessionMergeHistoryToRootChatSession() throws Exception {

// (Arrange) Set up the return value of the generateContent
VertexAI vertexAi = new VertexAI(PROJECT, LOCATION);
GenerativeModel model = new GenerativeModel("gemini-pro", vertexAi);

Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
field.setAccessible(true);
field.set(vertexAi, Optional.of(mockPredictionServiceClient));
VertexAI vertexAi =
new VertexAI.Builder()
.setProjectId(PROJECT)
.setLocation(LOCATION)
.setPredictionClientSupplier(() -> mockPredictionServiceClient)
.build();

GenerativeModel model = new GenerativeModel("gemini-pro", vertexAi);

when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
when(mockUnaryCallable.call(any(GenerateContentRequest.class)))
Expand Down
Loading

0 comments on commit 0b801eb

Please sign in to comment.