Skip to content

Commit

Permalink
chore: [vertexai] Make client fields Optional in VertexAI (#10666)
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 622249327

Co-authored-by: Jaycee Li <[email protected]>
  • Loading branch information
copybara-service[bot] and jaycee-li authored Apr 8, 2024
1 parent 5e05307 commit 76d2b2c
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ public class VertexAI implements AutoCloseable {
private final String apiEndpoint;
private final Transport transport;
private final CredentialsProvider credentialsProvider;
// The clients will be instantiated lazily
private PredictionServiceClient predictionServiceClient = null;
private LlmUtilityServiceClient llmUtilityClient = null;
private final ReentrantLock lock = new ReentrantLock();
// The clients will be instantiated lazily
private Optional<PredictionServiceClient> predictionServiceClient = Optional.empty();
private Optional<LlmUtilityServiceClient> llmUtilityClient = Optional.empty();

/**
* Construct a VertexAI instance.
Expand Down Expand Up @@ -229,22 +229,22 @@ public Credentials getCredentials() throws IOException {
*/
@InternalApi
public PredictionServiceClient getPredictionServiceClient() throws IOException {
if (predictionServiceClient != null) {
return predictionServiceClient;
if (predictionServiceClient.isPresent()) {
return predictionServiceClient.get();
}
lock.lock();
try {
if (predictionServiceClient == null) {
if (!predictionServiceClient.isPresent()) {
PredictionServiceSettings settings = getPredictionServiceSettings();
// Disable the warning message logged in getApplicationDefault
Logger defaultCredentialsProviderLogger =
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
predictionServiceClient = PredictionServiceClient.create(settings);
predictionServiceClient = Optional.of(PredictionServiceClient.create(settings));
defaultCredentialsProviderLogger.setLevel(previousLevel);
}
return predictionServiceClient;
return predictionServiceClient.get();
} finally {
lock.unlock();
}
Expand Down Expand Up @@ -280,22 +280,22 @@ private PredictionServiceSettings getPredictionServiceSettings() throws IOExcept
*/
@InternalApi
public LlmUtilityServiceClient getLlmUtilityClient() throws IOException {
if (llmUtilityClient != null) {
return llmUtilityClient;
if (llmUtilityClient.isPresent()) {
return llmUtilityClient.get();
}
lock.lock();
try {
if (llmUtilityClient == null) {
if (!llmUtilityClient.isPresent()) {
LlmUtilityServiceSettings settings = getLlmUtilityServiceClientSettings();
// Disable the warning message logged in getApplicationDefault
Logger defaultCredentialsProviderLogger =
Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
Level previousLevel = defaultCredentialsProviderLogger.getLevel();
defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
llmUtilityClient = LlmUtilityServiceClient.create(settings);
llmUtilityClient = Optional.of(LlmUtilityServiceClient.create(settings));
defaultCredentialsProviderLogger.setLevel(previousLevel);
}
return llmUtilityClient;
return llmUtilityClient.get();
} finally {
lock.unlock();
}
Expand Down Expand Up @@ -325,11 +325,11 @@ private LlmUtilityServiceSettings getLlmUtilityServiceClientSettings() throws IO
/** Closes the VertexAI instance together with all its instantiated clients. */
@Override
public void close() {
if (predictionServiceClient != null) {
predictionServiceClient.close();
if (predictionServiceClient.isPresent()) {
predictionServiceClient.get().close();
}
if (llmUtilityClient != null) {
llmUtilityClient.close();
if (llmUtilityClient.isPresent()) {
llmUtilityClient.get().close();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
Expand Down Expand Up @@ -313,7 +314,7 @@ public void testChatSessionMergeHistoryToRootChatSession() throws Exception {

Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
field.setAccessible(true);
field.set(vertexAi, mockPredictionServiceClient);
field.set(vertexAi, Optional.of(mockPredictionServiceClient));

when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
when(mockUnaryCallable.call(any(GenerateContentRequest.class)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
Expand Down Expand Up @@ -240,7 +241,7 @@ public void testCountTokenswithText() throws Exception {

Field field = VertexAI.class.getDeclaredField("llmUtilityClient");
field.setAccessible(true);
field.set(vertexAi, mockLlmUtilityServiceClient);
field.set(vertexAi, Optional.of(mockLlmUtilityServiceClient));

CountTokensResponse unused = model.countTokens(TEXT);

Expand All @@ -255,7 +256,7 @@ public void testCountTokenswithContent() throws Exception {

Field field = VertexAI.class.getDeclaredField("llmUtilityClient");
field.setAccessible(true);
field.set(vertexAi, mockLlmUtilityServiceClient);
field.set(vertexAi, Optional.of(mockLlmUtilityServiceClient));

Content content = ContentMaker.fromString(TEXT);
CountTokensResponse unused = model.countTokens(content);
Expand All @@ -271,7 +272,7 @@ public void testCountTokenswithContents() throws Exception {

Field field = VertexAI.class.getDeclaredField("llmUtilityClient");
field.setAccessible(true);
field.set(vertexAi, mockLlmUtilityServiceClient);
field.set(vertexAi, Optional.of(mockLlmUtilityServiceClient));

Content content = ContentMaker.fromString(TEXT);
CountTokensResponse unused = model.countTokens(Arrays.asList(content));
Expand All @@ -287,7 +288,7 @@ public void testGenerateContentwithText() throws Exception {

Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
field.setAccessible(true);
field.set(vertexAi, mockPredictionServiceClient);
field.set(vertexAi, Optional.of(mockPredictionServiceClient));

when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
when(mockUnaryCallable.call(any(GenerateContentRequest.class)))
Expand All @@ -307,7 +308,7 @@ public void testGenerateContentwithContent() throws Exception {

Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
field.setAccessible(true);
field.set(vertexAi, mockPredictionServiceClient);
field.set(vertexAi, Optional.of(mockPredictionServiceClient));

when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
when(mockUnaryCallable.call(any(GenerateContentRequest.class)))
Expand All @@ -329,7 +330,7 @@ public void testGenerateContentwithContents() throws Exception {

Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
field.setAccessible(true);
field.set(vertexAi, mockPredictionServiceClient);
field.set(vertexAi, Optional.of(mockPredictionServiceClient));

when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
when(mockUnaryCallable.call(any(GenerateContentRequest.class)))
Expand All @@ -356,7 +357,7 @@ public void testGenerateContentwithDefaultGenerationConfig() throws Exception {

Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
field.setAccessible(true);
field.set(vertexAi, mockPredictionServiceClient);
field.set(vertexAi, Optional.of(mockPredictionServiceClient));

when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
when(mockUnaryCallable.call(any(GenerateContentRequest.class)))
Expand All @@ -382,7 +383,7 @@ public void testGenerateContentwithDefaultSafetySettings() throws Exception {

Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
field.setAccessible(true);
field.set(vertexAi, mockPredictionServiceClient);
field.set(vertexAi, Optional.of(mockPredictionServiceClient));

when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
when(mockUnaryCallable.call(any(GenerateContentRequest.class)))
Expand All @@ -408,7 +409,7 @@ public void testGenerateContentwithDefaultTools() throws Exception {

Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
field.setAccessible(true);
field.set(vertexAi, mockPredictionServiceClient);
field.set(vertexAi, Optional.of(mockPredictionServiceClient));

when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
when(mockUnaryCallable.call(any(GenerateContentRequest.class)))
Expand All @@ -429,7 +430,7 @@ public void testGenerateContentwithFluentApi() throws Exception {

Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
field.setAccessible(true);
field.set(vertexAi, mockPredictionServiceClient);
field.set(vertexAi, Optional.of(mockPredictionServiceClient));

when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
when(mockUnaryCallable.call(any(GenerateContentRequest.class)))
Expand Down Expand Up @@ -457,7 +458,7 @@ public void testGenerateContentStreamwithText() throws Exception {

Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
field.setAccessible(true);
field.set(vertexAi, mockPredictionServiceClient);
field.set(vertexAi, Optional.of(mockPredictionServiceClient));

when(mockPredictionServiceClient.streamGenerateContentCallable())
.thenReturn(mockServerStreamCallable);
Expand All @@ -480,7 +481,7 @@ public void testGenerateContentStreamwithContent() throws Exception {

Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
field.setAccessible(true);
field.set(vertexAi, mockPredictionServiceClient);
field.set(vertexAi, Optional.of(mockPredictionServiceClient));

when(mockPredictionServiceClient.streamGenerateContentCallable())
.thenReturn(mockServerStreamCallable);
Expand All @@ -505,7 +506,7 @@ public void testGenerateContentStreamwithContents() throws Exception {

Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
field.setAccessible(true);
field.set(vertexAi, mockPredictionServiceClient);
field.set(vertexAi, Optional.of(mockPredictionServiceClient));

when(mockPredictionServiceClient.streamGenerateContentCallable())
.thenReturn(mockServerStreamCallable);
Expand Down Expand Up @@ -535,7 +536,7 @@ public void testGenerateContentStreamwithDefaultGenerationConfig() throws Except

Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
field.setAccessible(true);
field.set(vertexAi, mockPredictionServiceClient);
field.set(vertexAi, Optional.of(mockPredictionServiceClient));

when(mockPredictionServiceClient.streamGenerateContentCallable())
.thenReturn(mockServerStreamCallable);
Expand All @@ -562,7 +563,7 @@ public void testGenerateContentStreamwithDefaultSafetySettings() throws Exceptio

Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
field.setAccessible(true);
field.set(vertexAi, mockPredictionServiceClient);
field.set(vertexAi, Optional.of(mockPredictionServiceClient));

when(mockPredictionServiceClient.streamGenerateContentCallable())
.thenReturn(mockServerStreamCallable);
Expand All @@ -589,7 +590,7 @@ public void testGenerateContentStreamwithDefaultTools() throws Exception {

Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
field.setAccessible(true);
field.set(vertexAi, mockPredictionServiceClient);
field.set(vertexAi, Optional.of(mockPredictionServiceClient));

when(mockPredictionServiceClient.streamGenerateContentCallable())
.thenReturn(mockServerStreamCallable);
Expand All @@ -611,7 +612,7 @@ public void testGenerateContentStreamwithFluentApi() throws Exception {

Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
field.setAccessible(true);
field.set(vertexAi, mockPredictionServiceClient);
field.set(vertexAi, Optional.of(mockPredictionServiceClient));

when(mockPredictionServiceClient.streamGenerateContentCallable())
.thenReturn(mockServerStreamCallable);
Expand Down Expand Up @@ -641,7 +642,7 @@ public void generateContentAsync_withText_sendsCorrectRequest() throws Exception

Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
field.setAccessible(true);
field.set(vertexAi, mockPredictionServiceClient);
field.set(vertexAi, Optional.of(mockPredictionServiceClient));

when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
when(mockUnaryCallable.futureCall(any(GenerateContentRequest.class))).thenReturn(mockApiFuture);
Expand All @@ -662,7 +663,7 @@ public void generateContentAsync_withContent_sendsCorrectRequest() throws Except

Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
field.setAccessible(true);
field.set(vertexAi, mockPredictionServiceClient);
field.set(vertexAi, Optional.of(mockPredictionServiceClient));

when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
when(mockUnaryCallable.futureCall(any(GenerateContentRequest.class))).thenReturn(mockApiFuture);
Expand All @@ -683,7 +684,7 @@ public void generateContentAsync_withContents_sendsCorrectRequest() throws Excep

Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
field.setAccessible(true);
field.set(vertexAi, mockPredictionServiceClient);
field.set(vertexAi, Optional.of(mockPredictionServiceClient));

when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
when(mockUnaryCallable.futureCall(any(GenerateContentRequest.class))).thenReturn(mockApiFuture);
Expand Down

0 comments on commit 76d2b2c

Please sign in to comment.