diff --git a/plugin/trino-cassandra/pom.xml b/plugin/trino-cassandra/pom.xml
index bcff703d6f71..22e551aeeada 100644
--- a/plugin/trino-cassandra/pom.xml
+++ b/plugin/trino-cassandra/pom.xml
@@ -51,6 +51,12 @@
jackson-databind
+
+ com.google.errorprone
+ error_prone_annotations
+ true
+
+
com.google.guava
guava
@@ -148,12 +154,6 @@
provided
-
- com.google.errorprone
- error_prone_annotations
- runtime
-
-
io.airlift
concurrent
diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraClientModule.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraClientModule.java
index a2f8089a1c01..fc1451013d63 100644
--- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraClientModule.java
+++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraClientModule.java
@@ -54,6 +54,7 @@
import static io.airlift.configuration.ConfigBinder.configBinder;
import static io.airlift.json.JsonBinder.jsonBinder;
import static io.airlift.json.JsonCodecBinder.jsonCodecBinder;
+import static io.trino.plugin.base.ClosingBinder.closingBinder;
import static io.trino.plugin.base.ssl.SslUtils.createSSLContext;
import static io.trino.plugin.cassandra.CassandraErrorCode.CASSANDRA_SSL_INITIALIZATION_FAILURE;
import static java.util.Objects.requireNonNull;
@@ -88,6 +89,8 @@ public void configure(Binder binder)
jsonCodecBinder(binder).bindListJsonCodec(ExtraColumnMetadata.class);
jsonBinder(binder).addDeserializerBinding(Type.class).to(TypeDeserializer.class);
+
+ closingBinder(binder).registerCloseable(CassandraSession.class);
}
public static final class TypeDeserializer
diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraSession.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraSession.java
index 6353c6036b66..7d06071d5d11 100644
--- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraSession.java
+++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraSession.java
@@ -46,6 +46,7 @@
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Ordering;
import com.google.common.collect.Sets;
+import com.google.errorprone.annotations.concurrent.GuardedBy;
import io.airlift.json.JsonCodec;
import io.airlift.log.Logger;
import io.airlift.units.Duration;
@@ -76,7 +77,6 @@
import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.literal;
import static com.datastax.oss.driver.api.querybuilder.QueryBuilder.selectFrom;
import static com.google.common.base.Preconditions.checkState;
-import static com.google.common.base.Suppliers.memoize;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Iterables.getOnlyElement;
@@ -105,9 +105,13 @@ public class CassandraSession
private final CassandraTypeManager cassandraTypeManager;
private final JsonCodec> extraColumnMetadataCodec;
- private final Supplier session;
private final Duration noHostAvailableRetryTimeout;
+ @GuardedBy("this")
+ private Supplier sessionSupplier;
+ @GuardedBy("this")
+ private CqlSession session;
+
public CassandraSession(
CassandraTypeManager cassandraTypeManager,
JsonCodec> extraColumnMetadataCodec,
@@ -117,7 +121,16 @@ public CassandraSession(
this.cassandraTypeManager = requireNonNull(cassandraTypeManager, "cassandraTypeManager is null");
this.extraColumnMetadataCodec = requireNonNull(extraColumnMetadataCodec, "extraColumnMetadataCodec is null");
this.noHostAvailableRetryTimeout = requireNonNull(noHostAvailableRetryTimeout, "noHostAvailableRetryTimeout is null");
- this.session = memoize(sessionSupplier::get);
+ this.sessionSupplier = requireNonNull(sessionSupplier, "sessionSupplier is null");
+ }
+
+ private synchronized CqlSession session()
+ {
+ if (session == null) {
+ checkState(sessionSupplier != null, "already closed");
+ session = sessionSupplier.get();
+ }
+ return session;
}
public Version getCassandraVersion()
@@ -559,12 +572,12 @@ private void checkSizeEstimatesTableExist()
private T executeWithSession(SessionCallable sessionCallable)
{
- ReconnectionPolicy reconnectionPolicy = session.get().getContext().getReconnectionPolicy();
+ ReconnectionPolicy reconnectionPolicy = session().getContext().getReconnectionPolicy();
ReconnectionPolicy.ReconnectionSchedule schedule = reconnectionPolicy.newControlConnectionSchedule(false);
long deadline = System.currentTimeMillis() + noHostAvailableRetryTimeout.toMillis();
while (true) {
try {
- return sessionCallable.executeWithSession(session.get());
+ return sessionCallable.executeWithSession(session());
}
catch (AllNodesFailedException e) {
long timeLeft = deadline - System.currentTimeMillis();
@@ -611,9 +624,13 @@ private List getTypeArguments(DataType dataType)
}
@Override
- public void close()
+ public synchronized void close()
{
- session.get().close();
+ sessionSupplier = null;
+ if (session != null) {
+ session.close();
+ session = null;
+ }
}
private interface SessionCallable