diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java
index cb67d7bd2ff3..1e1cc2ec9449 100644
--- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java
+++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java
@@ -61,10 +61,10 @@ public class VertexAI implements AutoCloseable {
   private final String apiEndpoint;
   private final Transport transport;
   private final CredentialsProvider credentialsProvider;
-  // The clients will be instantiated lazily
-  private PredictionServiceClient predictionServiceClient = null;
-  private LlmUtilityServiceClient llmUtilityClient = null;
   private final ReentrantLock lock = new ReentrantLock();
+  // The clients will be instantiated lazily
+  private Optional<PredictionServiceClient> predictionServiceClient = Optional.empty();
+  private Optional<LlmUtilityServiceClient> llmUtilityClient = Optional.empty();
 
   /**
    * Construct a VertexAI instance.
@@ -229,22 +229,22 @@ public Credentials getCredentials() throws IOException {
    */
   @InternalApi
   public PredictionServiceClient getPredictionServiceClient() throws IOException {
-    if (predictionServiceClient != null) {
-      return predictionServiceClient;
+    if (predictionServiceClient.isPresent()) {
+      return predictionServiceClient.get();
     }
     lock.lock();
     try {
-      if (predictionServiceClient == null) {
+      if (!predictionServiceClient.isPresent()) {
         PredictionServiceSettings settings = getPredictionServiceSettings();
         // Disable the warning message logged in getApplicationDefault
         Logger defaultCredentialsProviderLogger =
             Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
         Level previousLevel = defaultCredentialsProviderLogger.getLevel();
         defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
-        predictionServiceClient = PredictionServiceClient.create(settings);
+        predictionServiceClient = Optional.of(PredictionServiceClient.create(settings));
         defaultCredentialsProviderLogger.setLevel(previousLevel);
       }
-      return predictionServiceClient;
+      return predictionServiceClient.get();
     } finally {
       lock.unlock();
     }
@@ -280,22 +280,22 @@ private PredictionServiceSettings getPredictionServiceSettings() throws IOExcept
    */
   @InternalApi
   public LlmUtilityServiceClient getLlmUtilityClient() throws IOException {
-    if (llmUtilityClient != null) {
-      return llmUtilityClient;
+    if (llmUtilityClient.isPresent()) {
+      return llmUtilityClient.get();
     }
     lock.lock();
     try {
-      if (llmUtilityClient == null) {
+      if (!llmUtilityClient.isPresent()) {
         LlmUtilityServiceSettings settings = getLlmUtilityServiceClientSettings();
         // Disable the warning message logged in getApplicationDefault
         Logger defaultCredentialsProviderLogger =
             Logger.getLogger("com.google.auth.oauth2.DefaultCredentialsProvider");
         Level previousLevel = defaultCredentialsProviderLogger.getLevel();
         defaultCredentialsProviderLogger.setLevel(Level.SEVERE);
-        llmUtilityClient = LlmUtilityServiceClient.create(settings);
+        llmUtilityClient = Optional.of(LlmUtilityServiceClient.create(settings));
         defaultCredentialsProviderLogger.setLevel(previousLevel);
       }
-      return llmUtilityClient;
+      return llmUtilityClient.get();
     } finally {
       lock.unlock();
     }
@@ -325,11 +325,11 @@ private LlmUtilityServiceSettings getLlmUtilityServiceClientSettings() throws IO
   /** Closes the VertexAI instance together with all its instantiated clients. */
   @Override
   public void close() {
-    if (predictionServiceClient != null) {
-      predictionServiceClient.close();
+    if (predictionServiceClient.isPresent()) {
+      predictionServiceClient.get().close();
     }
-    if (llmUtilityClient != null) {
-      llmUtilityClient.close();
+    if (llmUtilityClient.isPresent()) {
+      llmUtilityClient.get().close();
     }
   }
 }
diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java
index ef6733af4c60..a4cabcf068c5 100644
--- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java
+++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/ChatSessionTest.java
@@ -45,6 +45,7 @@
 import java.util.Arrays;
 import java.util.Iterator;
 import java.util.List;
+import java.util.Optional;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
@@ -313,7 +314,7 @@ public void testChatSessionMergeHistoryToRootChatSession() throws Exception {
 
     Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
     field.setAccessible(true);
-    field.set(vertexAi, mockPredictionServiceClient);
+    field.set(vertexAi, Optional.of(mockPredictionServiceClient));
 
     when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
     when(mockUnaryCallable.call(any(GenerateContentRequest.class)))
diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java
index 84f8d0b61fa5..70174b29858f 100644
--- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java
+++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/generativeai/GenerativeModelTest.java
@@ -51,6 +51,7 @@
 import java.util.Arrays;
 import java.util.Iterator;
 import java.util.List;
+import java.util.Optional;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
@@ -240,7 +241,7 @@ public void testCountTokenswithText() throws Exception {
 
     Field field = VertexAI.class.getDeclaredField("llmUtilityClient");
     field.setAccessible(true);
-    field.set(vertexAi, mockLlmUtilityServiceClient);
+    field.set(vertexAi, Optional.of(mockLlmUtilityServiceClient));
 
     CountTokensResponse unused = model.countTokens(TEXT);
 
@@ -255,7 +256,7 @@ public void testCountTokenswithContent() throws Exception {
 
     Field field = VertexAI.class.getDeclaredField("llmUtilityClient");
     field.setAccessible(true);
-    field.set(vertexAi, mockLlmUtilityServiceClient);
+    field.set(vertexAi, Optional.of(mockLlmUtilityServiceClient));
 
     Content content = ContentMaker.fromString(TEXT);
     CountTokensResponse unused = model.countTokens(content);
@@ -271,7 +272,7 @@ public void testCountTokenswithContents() throws Exception {
 
     Field field = VertexAI.class.getDeclaredField("llmUtilityClient");
     field.setAccessible(true);
-    field.set(vertexAi, mockLlmUtilityServiceClient);
+    field.set(vertexAi, Optional.of(mockLlmUtilityServiceClient));
 
     Content content = ContentMaker.fromString(TEXT);
     CountTokensResponse unused = model.countTokens(Arrays.asList(content));
@@ -287,7 +288,7 @@ public void testGenerateContentwithText() throws Exception {
 
     Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
     field.setAccessible(true);
-    field.set(vertexAi, mockPredictionServiceClient);
+    field.set(vertexAi, Optional.of(mockPredictionServiceClient));
 
     when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
     when(mockUnaryCallable.call(any(GenerateContentRequest.class)))
@@ -307,7 +308,7 @@ public void testGenerateContentwithContent() throws Exception {
 
     Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
     field.setAccessible(true);
-    field.set(vertexAi, mockPredictionServiceClient);
+    field.set(vertexAi, Optional.of(mockPredictionServiceClient));
 
     when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
     when(mockUnaryCallable.call(any(GenerateContentRequest.class)))
@@ -329,7 +330,7 @@ public void testGenerateContentwithContents() throws Exception {
 
     Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
     field.setAccessible(true);
-    field.set(vertexAi, mockPredictionServiceClient);
+    field.set(vertexAi, Optional.of(mockPredictionServiceClient));
 
     when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
     when(mockUnaryCallable.call(any(GenerateContentRequest.class)))
@@ -356,7 +357,7 @@ public void testGenerateContentwithDefaultGenerationConfig() throws Exception {
 
     Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
     field.setAccessible(true);
-    field.set(vertexAi, mockPredictionServiceClient);
+    field.set(vertexAi, Optional.of(mockPredictionServiceClient));
 
     when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
     when(mockUnaryCallable.call(any(GenerateContentRequest.class)))
@@ -382,7 +383,7 @@ public void testGenerateContentwithDefaultSafetySettings() throws Exception {
 
     Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
     field.setAccessible(true);
-    field.set(vertexAi, mockPredictionServiceClient);
+    field.set(vertexAi, Optional.of(mockPredictionServiceClient));
 
     when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
     when(mockUnaryCallable.call(any(GenerateContentRequest.class)))
@@ -408,7 +409,7 @@ public void testGenerateContentwithDefaultTools() throws Exception {
 
     Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
     field.setAccessible(true);
-    field.set(vertexAi, mockPredictionServiceClient);
+    field.set(vertexAi, Optional.of(mockPredictionServiceClient));
 
     when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
     when(mockUnaryCallable.call(any(GenerateContentRequest.class)))
@@ -429,7 +430,7 @@ public void testGenerateContentwithFluentApi() throws Exception {
 
     Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
     field.setAccessible(true);
-    field.set(vertexAi, mockPredictionServiceClient);
+    field.set(vertexAi, Optional.of(mockPredictionServiceClient));
 
     when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
     when(mockUnaryCallable.call(any(GenerateContentRequest.class)))
@@ -457,7 +458,7 @@ public void testGenerateContentStreamwithText() throws Exception {
 
     Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
     field.setAccessible(true);
-    field.set(vertexAi, mockPredictionServiceClient);
+    field.set(vertexAi, Optional.of(mockPredictionServiceClient));
 
     when(mockPredictionServiceClient.streamGenerateContentCallable())
         .thenReturn(mockServerStreamCallable);
@@ -480,7 +481,7 @@ public void testGenerateContentStreamwithContent() throws Exception {
 
     Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
     field.setAccessible(true);
-    field.set(vertexAi, mockPredictionServiceClient);
+    field.set(vertexAi, Optional.of(mockPredictionServiceClient));
 
     when(mockPredictionServiceClient.streamGenerateContentCallable())
         .thenReturn(mockServerStreamCallable);
@@ -505,7 +506,7 @@ public void testGenerateContentStreamwithContents() throws Exception {
 
     Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
     field.setAccessible(true);
-    field.set(vertexAi, mockPredictionServiceClient);
+    field.set(vertexAi, Optional.of(mockPredictionServiceClient));
 
     when(mockPredictionServiceClient.streamGenerateContentCallable())
         .thenReturn(mockServerStreamCallable);
@@ -535,7 +536,7 @@ public void testGenerateContentStreamwithDefaultGenerationConfig() throws Except
 
     Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
     field.setAccessible(true);
-    field.set(vertexAi, mockPredictionServiceClient);
+    field.set(vertexAi, Optional.of(mockPredictionServiceClient));
 
     when(mockPredictionServiceClient.streamGenerateContentCallable())
         .thenReturn(mockServerStreamCallable);
@@ -562,7 +563,7 @@ public void testGenerateContentStreamwithDefaultSafetySettings() throws Exceptio
 
     Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
     field.setAccessible(true);
-    field.set(vertexAi, mockPredictionServiceClient);
+    field.set(vertexAi, Optional.of(mockPredictionServiceClient));
 
     when(mockPredictionServiceClient.streamGenerateContentCallable())
         .thenReturn(mockServerStreamCallable);
@@ -589,7 +590,7 @@ public void testGenerateContentStreamwithDefaultTools() throws Exception {
 
     Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
     field.setAccessible(true);
-    field.set(vertexAi, mockPredictionServiceClient);
+    field.set(vertexAi, Optional.of(mockPredictionServiceClient));
 
     when(mockPredictionServiceClient.streamGenerateContentCallable())
         .thenReturn(mockServerStreamCallable);
@@ -611,7 +612,7 @@ public void testGenerateContentStreamwithFluentApi() throws Exception {
 
     Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
     field.setAccessible(true);
-    field.set(vertexAi, mockPredictionServiceClient);
+    field.set(vertexAi, Optional.of(mockPredictionServiceClient));
 
     when(mockPredictionServiceClient.streamGenerateContentCallable())
         .thenReturn(mockServerStreamCallable);
@@ -641,7 +642,7 @@ public void generateContentAsync_withText_sendsCorrectRequest() throws Exception
 
     Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
     field.setAccessible(true);
-    field.set(vertexAi, mockPredictionServiceClient);
+    field.set(vertexAi, Optional.of(mockPredictionServiceClient));
 
     when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
     when(mockUnaryCallable.futureCall(any(GenerateContentRequest.class))).thenReturn(mockApiFuture);
@@ -662,7 +663,7 @@ public void generateContentAsync_withContent_sendsCorrectRequest() throws Except
 
     Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
     field.setAccessible(true);
-    field.set(vertexAi, mockPredictionServiceClient);
+    field.set(vertexAi, Optional.of(mockPredictionServiceClient));
 
     when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
     when(mockUnaryCallable.futureCall(any(GenerateContentRequest.class))).thenReturn(mockApiFuture);
@@ -683,7 +684,7 @@ public void generateContentAsync_withContents_sendsCorrectRequest() throws Excep
 
     Field field = VertexAI.class.getDeclaredField("predictionServiceClient");
     field.setAccessible(true);
-    field.set(vertexAi, mockPredictionServiceClient);
+    field.set(vertexAi, Optional.of(mockPredictionServiceClient));
 
     when(mockPredictionServiceClient.generateContentCallable()).thenReturn(mockUnaryCallable);
     when(mockUnaryCallable.futureCall(any(GenerateContentRequest.class))).thenReturn(mockApiFuture);