diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java index cb67d7bd2ff3..1e1cc2ec9449 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java @@ -61,10 +61,10 @@ public class VertexAI implements AutoCloseable { private final String apiEndpoint; private final Transport transport; private final CredentialsProvider credentialsProvider; - // The clients will be instantiated lazily - private PredictionServiceClient predictionServiceClient = null; - private LlmUtilityServiceClient llmUtilityClient = null; private final ReentrantLock lock = new ReentrantLock(); + // The clients will be instantiated lazily + private Optional<PredictionServiceClient> predictionServiceClient = Optional.empty(); + private Optional<LlmUtilityServiceClient> llmUtilityClient = Optional.empty(); /** * Construct a VertexAI instance. @@ -229,22 +229,22 @@ public Credentials getCredentials() throws IOException { */ @InternalApi public PredictionServiceClient getPredictionServiceClient() throws IOException { - if (predictionServiceClient != null) { - return predictionServiceClient; + if (predictionServiceClient.isPresent()) { + return predictionServiceClient.get(); } lock.lock(); try { - if (predictionServiceClient == null) { + if (!predictionServiceClient.isPresent()) { PredictionServiceSettings settings = getPredictionServiceSettings(); // Disable the warning message logged in getApplicationDefault Logger defaultCredentialsProviderLogger = Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider"); Level previousLevel = defaultCredentialsProviderLogger.getLevel(); defaultCredentialsProviderLogger.setLevel(Level.SEVERE); - predictionServiceClient = PredictionServiceClient.create(settings); + predictionServiceClient = Optional.of(PredictionServiceClient.create(settings)); defaultCredentialsProviderLogger.setLevel(previousLevel); } - return predictionServiceClient; + return predictionServiceClient.get(); } finally { lock.unlock(); } @@ -280,22 +280,22 @@ private PredictionServiceSettings getPredictionServiceSettings() throws IOExcept */ @InternalApi public LlmUtilityServiceClient getLlmUtilityClient() throws IOException { - if (llmUtilityClient != null) { - return llmUtilityClient; + if (llmUtilityClient.isPresent()) { + return llmUtilityClient.get(); } lock.lock(); try { - if (llmUtilityClient == null) { + if (!llmUtilityClient.isPresent()) { LlmUtilityServiceSettings settings = getLlmUtilityServiceClientSettings(); // Disable the warning message logged in getApplicationDefault Logger defaultCredentialsProviderLogger = Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider"); Level previousLevel = defaultCredentialsProviderLogger.getLevel(); defaultCredentialsProviderLogger.setLevel(Level.SEVERE); - llmUtilityClient = LlmUtilityServiceClient.create(settings); + llmUtilityClient = Optional.of(LlmUtilityServiceClient.create(settings)); defaultCredentialsProviderLogger.setLevel(previousLevel); } - return llmUtilityClient; + return llmUtilityClient.get(); } finally { lock.unlock(); } @@ -325,11 +325,11 @@ private LlmUtilityServiceSettings getLlmUtilityServiceClientSettings() throws IO /** Closes the VertexAI instance together with all its instantiated clients. */ @Override public void close() { - if (predictionServiceClient != null) { - predictionServiceClient.close(); + if (predictionServiceClient.isPresent()) { + predictionServiceClient.get().close(); } - if (llmUtilityClient != null) { - llmUtilityClient.close(); + if (llmUtilityClient.isPresent()) { + llmUtilityClient.get().close(); } } } diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java index ef6733af4c60..a4cabcf068c5 100644 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java @@ -45,6 +45,7 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; +import java.util.Optional; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -313,7 +314,7 @@ public void testChatSessionMergeHistoryToRootChatSession() throws Exception { Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); field.setAccessible(true); - field.set(vertexAi, mockPredictionServiceClient); + field.set(vertexAi, Optional.of(mockPredictionServiceClient)); when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable); when(mockUnaryCallable.call(any(GenerateContentRequest.class))) diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java index 84f8d0b61fa5..70174b29858f 100644 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java @@ -51,6 +51,7 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; +import java.util.Optional; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -240,7 +241,7 @@ public void testCountTokenswithText() throws Exception { Field field = VertexAI.class.getDeclaredField("llmUtilityClient"); field.setAccessible(true); - field.set(vertexAi, mockLlmUtilityServiceClient); + field.set(vertexAi, Optional.of(mockLlmUtilityServiceClient)); CountTokensResponse unused = model.countTokens(TEXT); @@ -255,7 +256,7 @@ public void testCountTokenswithContent() throws Exception { Field field = VertexAI.class.getDeclaredField("llmUtilityClient"); field.setAccessible(true); - field.set(vertexAi, mockLlmUtilityServiceClient); + field.set(vertexAi, Optional.of(mockLlmUtilityServiceClient)); Content content = ContentMaker.fromString(TEXT); CountTokensResponse unused = model.countTokens(content); @@ -271,7 +272,7 @@ public void testCountTokenswithContents() throws Exception { Field field = VertexAI.class.getDeclaredField("llmUtilityClient"); field.setAccessible(true); - field.set(vertexAi, mockLlmUtilityServiceClient); + field.set(vertexAi, Optional.of(mockLlmUtilityServiceClient)); Content content = ContentMaker.fromString(TEXT); CountTokensResponse unused = model.countTokens(Arrays.asList(content)); @@ -287,7 +288,7 @@ public void testGenerateContentwithText() throws Exception { Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); field.setAccessible(true); - field.set(vertexAi, mockPredictionServiceClient); + field.set(vertexAi, Optional.of(mockPredictionServiceClient)); when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable); when(mockUnaryCallable.call(any(GenerateContentRequest.class))) @@ -307,7 +308,7 @@ public void testGenerateContentwithContent() throws Exception { Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); field.setAccessible(true); - field.set(vertexAi, mockPredictionServiceClient); + field.set(vertexAi, Optional.of(mockPredictionServiceClient)); when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable); when(mockUnaryCallable.call(any(GenerateContentRequest.class))) @@ -329,7 +330,7 @@ public void testGenerateContentwithContents() throws Exception { Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); field.setAccessible(true); - field.set(vertexAi, mockPredictionServiceClient); + field.set(vertexAi, Optional.of(mockPredictionServiceClient)); when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable); when(mockUnaryCallable.call(any(GenerateContentRequest.class))) @@ -356,7 +357,7 @@ public void testGenerateContentwithDefaultGenerationConfig() throws Exception { Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); field.setAccessible(true); - field.set(vertexAi, mockPredictionServiceClient); + field.set(vertexAi, Optional.of(mockPredictionServiceClient)); when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable); when(mockUnaryCallable.call(any(GenerateContentRequest.class))) @@ -382,7 +383,7 @@ public void testGenerateContentwithDefaultSafetySettings() throws Exception { Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); field.setAccessible(true); - field.set(vertexAi, mockPredictionServiceClient); + field.set(vertexAi, Optional.of(mockPredictionServiceClient)); when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable); when(mockUnaryCallable.call(any(GenerateContentRequest.class))) @@ -408,7 +409,7 @@ public void testGenerateContentwithDefaultTools() throws Exception { Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); field.setAccessible(true); - field.set(vertexAi, mockPredictionServiceClient); + field.set(vertexAi, Optional.of(mockPredictionServiceClient)); when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable); when(mockUnaryCallable.call(any(GenerateContentRequest.class))) @@ -429,7 +430,7 @@ public void testGenerateContentwithFluentApi() throws Exception { Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); field.setAccessible(true); - field.set(vertexAi, mockPredictionServiceClient); + field.set(vertexAi, Optional.of(mockPredictionServiceClient)); when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable); when(mockUnaryCallable.call(any(GenerateContentRequest.class))) @@ -457,7 +458,7 @@ public void testGenerateContentStreamwithText() throws Exception { Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); field.setAccessible(true); - field.set(vertexAi, mockPredictionServiceClient); + field.set(vertexAi, Optional.of(mockPredictionServiceClient)); when(mockPredictionServiceClient.streamGenerateContentCallable()) .thenReturn(mockServerStreamCallable); @@ -480,7 +481,7 @@ public void testGenerateContentStreamwithContent() throws Exception { Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); field.setAccessible(true); - field.set(vertexAi, mockPredictionServiceClient); + field.set(vertexAi, Optional.of(mockPredictionServiceClient)); when(mockPredictionServiceClient.streamGenerateContentCallable()) .thenReturn(mockServerStreamCallable); @@ -505,7 +506,7 @@ public void testGenerateContentStreamwithContents() throws Exception { Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); field.setAccessible(true); - field.set(vertexAi, mockPredictionServiceClient); + field.set(vertexAi, Optional.of(mockPredictionServiceClient)); when(mockPredictionServiceClient.streamGenerateContentCallable()) .thenReturn(mockServerStreamCallable); @@ -535,7 +536,7 @@ public void testGenerateContentStreamwithDefaultGenerationConfig() throws Except Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); field.setAccessible(true); - field.set(vertexAi, mockPredictionServiceClient); + field.set(vertexAi, Optional.of(mockPredictionServiceClient)); when(mockPredictionServiceClient.streamGenerateContentCallable()) .thenReturn(mockServerStreamCallable); @@ -562,7 +563,7 @@ public void testGenerateContentStreamwithDefaultSafetySettings() throws Exceptio Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); field.setAccessible(true); - field.set(vertexAi, mockPredictionServiceClient); + field.set(vertexAi, Optional.of(mockPredictionServiceClient)); when(mockPredictionServiceClient.streamGenerateContentCallable()) .thenReturn(mockServerStreamCallable); @@ -589,7 +590,7 @@ public void testGenerateContentStreamwithDefaultTools() throws Exception { Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); field.setAccessible(true); - field.set(vertexAi, mockPredictionServiceClient); + field.set(vertexAi, Optional.of(mockPredictionServiceClient)); when(mockPredictionServiceClient.streamGenerateContentCallable()) .thenReturn(mockServerStreamCallable); @@ -611,7 +612,7 @@ public void testGenerateContentStreamwithFluentApi() throws Exception { Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); field.setAccessible(true); - field.set(vertexAi, mockPredictionServiceClient); + field.set(vertexAi, Optional.of(mockPredictionServiceClient)); when(mockPredictionServiceClient.streamGenerateContentCallable()) .thenReturn(mockServerStreamCallable); @@ -641,7 +642,7 @@ public void generateContentAsync_withText_sendsCorrectRequest() throws Exception Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); field.setAccessible(true); - field.set(vertexAi, mockPredictionServiceClient); + field.set(vertexAi, Optional.of(mockPredictionServiceClient)); when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable); when(mockUnaryCallable.futureCall(any(GenerateContentRequest.class))).thenReturn(mockApiFuture); @@ -662,7 +663,7 @@ public void generateContentAsync_withContent_sendsCorrectRequest() throws Except Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); field.setAccessible(true); - field.set(vertexAi, mockPredictionServiceClient); + field.set(vertexAi, Optional.of(mockPredictionServiceClient)); when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable); when(mockUnaryCallable.futureCall(any(GenerateContentRequest.class))).thenReturn(mockApiFuture); @@ -683,7 +684,7 @@ public void generateContentAsync_withContents_sendsCorrectRequest() throws Excep Field field = VertexAI.class.getDeclaredField("predictionServiceClient"); field.setAccessible(true); - field.set(vertexAi, mockPredictionServiceClient); + field.set(vertexAi, Optional.of(mockPredictionServiceClient)); when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable); when(mockUnaryCallable.futureCall(any(GenerateContentRequest.class))).thenReturn(mockApiFuture);