diff --git a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java index d33642b9c3a..1429253d194 100644 --- a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java +++ b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java @@ -24,6 +24,7 @@ import com.datastax.driver.core.ConsistencyLevel; import com.datastax.driver.core.PlainTextAuthProvider; import com.datastax.driver.core.QueryOptions; +import com.datastax.driver.core.SSLOptions; import com.datastax.driver.core.Session; import com.datastax.driver.core.SocketOptions; import com.datastax.driver.core.policies.DCAwareRoundRobinPolicy; @@ -192,6 +193,9 @@ public abstract static class Read extends PTransform> @Nullable abstract ValueProvider> ringRanges(); + @Nullable + abstract ValueProvider sslOptions(); + abstract Builder builder(); /** Specify the hosts of the Apache Cassandra instances. */ @@ -385,6 +389,22 @@ public Read withRingRanges(ValueProvider> ringRange) { return builder().setRingRanges(ringRange).build(); } + /** + * Optionally, specify {@link SSLOptions} configuration to utilize SSL. See + * https://docs.datastax.com/en/developer/java-driver/3.11/manual/ssl/#jsse-programmatic + */ + public Read withSsl(SSLOptions sslOptions) { + return withSsl(ValueProvider.StaticValueProvider.of(sslOptions)); + } + + /** + * Optionally, specify {@link SSLOptions} configuration to utilize SSL. See + * https://docs.datastax.com/en/developer/java-driver/3.11/manual/ssl/#jsse-programmatic + */ + public Read withSsl(ValueProvider sslOptions) { + return builder().setSslOptions(sslOptions).build(); + } + @Override public PCollection expand(PBegin input) { checkArgument((hosts() != null && port() != null), "WithHosts() and withPort() are required"); @@ -422,7 +442,8 @@ private static Set getRingRanges(Read read) { read.localDc(), read.consistencyLevel(), read.connectTimeout(), - read.readTimeout())) { + read.readTimeout(), + read.sslOptions())) { if (isMurmur3Partitioner(cluster)) { LOG.info("Murmur3Partitioner detected, splitting"); Integer splitCount; @@ -495,6 +516,8 @@ abstract static class Builder { abstract Builder setRingRanges(ValueProvider> ringRange); + abstract Builder setSslOptions(ValueProvider sslOptions); + abstract Read autoBuild(); public Read build() { @@ -543,6 +566,8 @@ public abstract static class Write extends PTransform, PDone> abstract @Nullable ValueProvider readTimeout(); + abstract @Nullable ValueProvider sslOptions(); + abstract @Nullable SerializableFunction mapperFactoryFn(); abstract Builder builder(); @@ -725,6 +750,22 @@ public Write withMapperFactoryFn(SerializableFunction mapper return builder().setMapperFactoryFn(mapperFactoryFn).build(); } + /** + * Optionally, specify {@link SSLOptions} configuration to utilize SSL. See + * https://docs.datastax.com/en/developer/java-driver/3.11/manual/ssl/#jsse-programmatic + */ + public Write withSsl(SSLOptions sslOptions) { + return withSsl(ValueProvider.StaticValueProvider.of(sslOptions)); + } + + /** + * Optionally, specify {@link SSLOptions} configuration to utilize SSL. See + * https://docs.datastax.com/en/developer/java-driver/3.11/manual/ssl/#jsse-programmatic + */ + public Write withSsl(ValueProvider sslOptions) { + return builder().setSslOptions(sslOptions).build(); + } + @Override public void validate(PipelineOptions pipelineOptions) { checkState( @@ -799,6 +840,8 @@ abstract static class Builder { abstract Optional> mapperFactoryFn(); + abstract Builder setSslOptions(ValueProvider sslOptions); + abstract Write autoBuild(); // not public public Write build() { @@ -880,7 +923,8 @@ static Cluster getCluster( ValueProvider localDc, ValueProvider consistencyLevel, ValueProvider connectTimeout, - ValueProvider readTimeout) { + ValueProvider readTimeout, + ValueProvider sslOptions) { Cluster.Builder builder = Cluster.builder().addContactPoints(hosts.get().toArray(new String[0])).withPort(port.get()); @@ -913,6 +957,10 @@ static Cluster getCluster( socketOptions.setReadTimeoutMillis(readTimeout.get()); } + if (sslOptions != null) { + builder.withSSL(sslOptions.get()); + } + return builder.build(); } @@ -941,7 +989,8 @@ private static class Mutator { spec.localDc(), spec.consistencyLevel(), spec.connectTimeout(), - spec.readTimeout()); + spec.readTimeout(), + spec.sslOptions()); this.session = cluster.connect(spec.keyspace().get()); this.mapperFactoryFn = spec.mapperFactoryFn(); this.mutateFutures = new ArrayList<>(); diff --git a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/ConnectionManager.java b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/ConnectionManager.java index 21e7d257dca..962e8ad8ec0 100644 --- a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/ConnectionManager.java +++ b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/ConnectionManager.java @@ -71,7 +71,8 @@ static Session getSession(Read read) { read.localDc(), read.consistencyLevel(), read.connectTimeout(), - read.readTimeout())); + read.readTimeout(), + read.sslOptions())); return sessionMap.computeIfAbsent( readToSessionHash(read), k -> cluster.connect(Objects.requireNonNull(read.keyspace()).get()));