Skip to content

Commit

Permalink
fix: Simplify VertexAI with Suppliers.memorize and avoid accessing pr…
Browse files Browse the repository at this point in the history
…ivate members in tests. (#10694)

- Implement lazy init using Suppliers.memorize instead of an explicit lock.
  - Add a newBuilder method in VertexAI.
  - Updates unit tests to avoid accessing private fields in VertexAI.

PiperOrigin-RevId: 624303836

Co-authored-by: A Vertex SDK engineer <[email protected]>
  • Loading branch information
copybara-service[bot] and vertex-sdk-bot authored Apr 17, 2024
1 parent ae22f1c commit 7bdfa55
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 173 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@
import com.google.cloud.vertexai.api.PredictionServiceClient;
import com.google.cloud.vertexai.api.PredictionServiceSettings;
import com.google.common.base.Strings;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import java.io.IOException;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Level;
import java.util.logging.Logger;

Expand All @@ -61,13 +63,12 @@ public class VertexAI implements AutoCloseable {
private final String apiEndpoint;
private final Transport transport;
private final CredentialsProvider credentialsProvider;
private final ReentrantLock lock = new ReentrantLock();
// The clients will be instantiated lazily
private Optional<PredictionServiceClient> predictionServiceClient = Optional.empty();
private Optional<LlmUtilityServiceClient> llmUtilityClient = Optional.empty();

private final transient Supplier<PredictionServiceClient> predictionClientSupplier;
private final transient Supplier<LlmUtilityServiceClient> llmClientSupplier;

/**
* Construct a VertexAI instance.
* Constructs a VertexAI instance.
*
* @param projectId the default project to use when making API calls
* @param location the default location to use when making API calls
Expand All @@ -78,8 +79,10 @@ public VertexAI(String projectId, String location) {
location,
Transport.GRPC,
ImmutableList.of(),
Optional.empty(),
Optional.empty());
/* credentials= */ Optional.empty(),
/* apiEndpoint= */ Optional.empty(),
/* predictionClientSupplierOpt= */ Optional.empty(),
/* llmClientSupplierOpt= */ Optional.empty());
}

private VertexAI(
Expand All @@ -88,7 +91,9 @@ private VertexAI(
Transport transport,
List<String> scopes,
Optional<Credentials> credentials,
Optional<String> apiEndpoint) {
Optional<String> apiEndpoint,
Optional<Supplier<PredictionServiceClient>> predictionClientSupplierOpt,
Optional<Supplier<LlmUtilityServiceClient>> llmClientSupplierOpt) {
if (!scopes.isEmpty() && credentials.isPresent()) {
throw new IllegalArgumentException(
"At most one of Credentials and scopes should be specified.");
Expand All @@ -113,9 +118,19 @@ private VertexAI(
.build();
}

this.predictionClientSupplier =
Suppliers.memoize(predictionClientSupplierOpt.orElse(this::newPredictionServiceClient));

this.llmClientSupplier =
Suppliers.memoize(llmClientSupplierOpt.orElse(this::newLlmUtilityClient));

this.apiEndpoint = apiEndpoint.orElse(String.format("%s-aiplatform.googleapis.com", location));
}

public static Builder builder() {
return new Builder();
}

/** Builder for {@link VertexAI}. */
public static class Builder {
private String projectId;
Expand All @@ -125,11 +140,25 @@ public static class Builder {
private Optional<Credentials> credentials = Optional.empty();
private Optional<String> apiEndpoint = Optional.empty();

private Supplier<PredictionServiceClient> predictionClientSupplier;

private Supplier<LlmUtilityServiceClient> llmClientSupplier;

Builder() {}

public VertexAI build() {
checkNotNull(projectId, "projectId must be set.");
checkNotNull(location, "location must be set.");

return new VertexAI(projectId, location, transport, scopes, credentials, apiEndpoint);
return new VertexAI(
projectId,
location,
transport,
scopes,
credentials,
apiEndpoint,
Optional.ofNullable(predictionClientSupplier),
Optional.ofNullable(llmClientSupplier));
}

public Builder setProjectId(String projectId) {
Expand Down Expand Up @@ -167,6 +196,19 @@ public Builder setCredentials(Credentials credentials) {
return this;
}

@CanIgnoreReturnValue
public Builder setPredictionClientSupplier(
Supplier<PredictionServiceClient> predictionClientSupplier) {
this.predictionClientSupplier = predictionClientSupplier;
return this;
}

@CanIgnoreReturnValue
public Builder setLlmClientSupplier(Supplier<LlmUtilityServiceClient> llmClientSupplier) {
this.llmClientSupplier = llmClientSupplier;
return this;
}

public Builder setScopes(List<String> scopes) {
checkNotNull(scopes, "scopes can't be null");

Expand Down Expand Up @@ -228,25 +270,23 @@ public Credentials getCredentials() throws IOException {
* method calls that map to the API methods.
*/
@InternalApi
public PredictionServiceClient getPredictionServiceClient() throws IOException {
if (predictionServiceClient.isPresent()) {
return predictionServiceClient.get();
}
lock.lock();
public PredictionServiceClient getPredictionServiceClient() {
return predictionClientSupplier.get();
}

private PredictionServiceClient newPredictionServiceClient() {
// Disable the warning message logged in getApplicationDefault
Logger defaultCredentialsProviderLogger =
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);

try {
if (!predictionServiceClient.isPresent()) {
PredictionServiceSettings settings = getPredictionServiceSettings();
// Disable the warning message logged in getApplicationDefault
Logger defaultCredentialsProviderLogger =
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
predictionServiceClient = Optional.of(PredictionServiceClient.create(settings));
defaultCredentialsProviderLogger.setLevel(previousLevel);
}
return predictionServiceClient.get();
return PredictionServiceClient.create(getPredictionServiceSettings());
} catch (IOException e) {
throw new IllegalStateException(e);
} finally {
lock.unlock();
defaultCredentialsProviderLogger.setLevel(previousLevel);
}
}

Expand All @@ -257,8 +297,8 @@ private PredictionServiceSettings getPredictionServiceSettings() throws IOExcept
} else {
builder = PredictionServiceSettings.newBuilder();
}
builder.setEndpoint(String.format("%s:443", this.apiEndpoint));
builder.setCredentialsProvider(this.credentialsProvider);
builder.setEndpoint(String.format("%s:443", apiEndpoint));
builder.setCredentialsProvider(credentialsProvider);

HeaderProvider headerProvider =
FixedHeaderProvider.create(
Expand All @@ -279,25 +319,23 @@ private PredictionServiceSettings getPredictionServiceSettings() throws IOExcept
* calls that map to the API methods.
*/
@InternalApi
public LlmUtilityServiceClient getLlmUtilityClient() throws IOException {
if (llmUtilityClient.isPresent()) {
return llmUtilityClient.get();
}
lock.lock();
public LlmUtilityServiceClient getLlmUtilityClient() {
return llmClientSupplier.get();
}

private LlmUtilityServiceClient newLlmUtilityClient() {
// Disable the warning message logged in getApplicationDefault
Logger defaultCredentialsProviderLogger =
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);

try {
if (!llmUtilityClient.isPresent()) {
LlmUtilityServiceSettings settings = getLlmUtilityServiceClientSettings();
// Disable the warning message logged in getApplicationDefault
Logger defaultCredentialsProviderLogger =
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
llmUtilityClient = Optional.of(LlmUtilityServiceClient.create(settings));
defaultCredentialsProviderLogger.setLevel(previousLevel);
}
return llmUtilityClient.get();
return LlmUtilityServiceClient.create(getLlmUtilityServiceClientSettings());
} catch (IOException e) {
throw new IllegalStateException(e);
} finally {
lock.unlock();
defaultCredentialsProviderLogger.setLevel(previousLevel);
}
}

Expand All @@ -308,8 +346,8 @@ private LlmUtilityServiceSettings getLlmUtilityServiceClientSettings() throws IO
} else {
settingsBuilder = LlmUtilityServiceSettings.newBuilder();
}
settingsBuilder.setEndpoint(String.format("%s:443", this.apiEndpoint));
settingsBuilder.setCredentialsProvider(this.credentialsProvider);
settingsBuilder.setEndpoint(String.format("%s:443", apiEndpoint));
settingsBuilder.setCredentialsProvider(credentialsProvider);

HeaderProvider headerProvider =
FixedHeaderProvider.create(
Expand All @@ -325,11 +363,7 @@ private LlmUtilityServiceSettings getLlmUtilityServiceClientSettings() throws IO
/** Closes the VertexAI instance together with all its instantiated clients. */
@Override
public void close() {
if (predictionServiceClient.isPresent()) {
predictionServiceClient.get().close();
}
if (llmUtilityClient.isPresent()) {
llmUtilityClient.get().close();
}
predictionClientSupplier.get().close();
llmClientSupplier.get().close();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,9 @@
import com.google.cloud.vertexai.api.Tool;
import com.google.cloud.vertexai.api.Type;
import java.io.IOException;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
Expand Down Expand Up @@ -309,12 +307,15 @@ public void sendMessageWithText_throwsIllegalStateExceptionWhenFinishReasonIsNot
public void testChatSessionMergeHistoryToRootChatSession() throws Exception {

// (Arrange) Set up the return value of the generateContent
VertexAI vertexAi = new VertexAI(PROJECT, LOCATION);
GenerativeModel model = new GenerativeModel("gemini-pro", vertexAi);

Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
field.setAccessible(true);
field.set(vertexAi, Optional.of(mockPredictionServiceClient));
VertexAI vertexAi =
VertexAI.builder()
.setProjectId(PROJECT)
.setLocation(LOCATION)
.setPredictionClientSupplier(() -> mockPredictionServiceClient)
.build();

GenerativeModel model = new GenerativeModel("gemini-pro", vertexAi);

when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
when(mockUnaryCallable.call(any(GenerateContentRequest.class)))
Expand Down
Loading

0 comments on commit 7bdfa55

Please sign in to comment.