Skip to content

Commit

Permalink
Migrates the java client authentication to Flight Auth v2 (#3423)
Browse files Browse the repository at this point in the history
There are also plumbing improvements that allow constructing sessions with specific authentication values, and the sharing of authenticated channels without extraneous duplication of logic.

Fixes #3285
  • Loading branch information
devinrsmith authored Feb 28, 2023
1 parent 3c71320 commit 203fc12
Show file tree
Hide file tree
Showing 41 changed files with 864 additions and 524 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public Optional<AuthContext> login(long protocolVersion, ByteBuffer payload, Han

@Override
public Optional<AuthContext> login(String payload, MetadataResponseListener listener) {
if (payload.length() == 0) {
if (payload.isEmpty()) {
return Optional.of(new AuthContext.Anonymous());
}
return Optional.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,11 @@
import io.grpc.ManagedChannel;
import org.apache.arrow.memory.BufferAllocator;

import java.util.concurrent.CompletableFuture;
import java.util.function.Function;

@Module
public class BarrageSessionModule {
@Provides
public static BarrageSession newDeephavenClientSession(
SessionImpl session, BufferAllocator allocator, ManagedChannel managedChannel) {
return BarrageSession.of(session, allocator, managedChannel);
}

@Provides
public static CompletableFuture<? extends BarrageSession> newDeephavenClientSessionFuture(
CompletableFuture<? extends SessionImpl> sessionFuture, BufferAllocator allocator,
ManagedChannel managedChannel) {
return sessionFuture.thenApply((Function<SessionImpl, BarrageSession>) session -> BarrageSession
.of(session, allocator, managedChannel));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,15 @@
import io.grpc.ManagedChannel;
import org.apache.arrow.memory.BufferAllocator;

import java.util.concurrent.CompletableFuture;
import javax.annotation.Nullable;
import javax.inject.Named;
import java.util.concurrent.ScheduledExecutorService;

@Subcomponent(modules = {SessionImplModule.class, FlightSessionModule.class, BarrageSessionModule.class})
public interface BarrageSubcomponent extends BarrageSessionFactory {

BarrageSession newBarrageSession();

CompletableFuture<? extends BarrageSession> newBarrageSessionFuture();

@Module(subcomponents = {BarrageSubcomponent.class})
interface DeephavenClientSubcomponentModule {

Expand All @@ -33,6 +32,9 @@ interface Builder extends BarrageSessionFactoryBuilder {

Builder allocator(@BindsInstance BufferAllocator bufferAllocator);

Builder authenticationTypeAndValue(
@BindsInstance @Nullable @Named("authenticationTypeAndValue") String authenticationTypeAndValue);

BarrageSubcomponent build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import io.deephaven.client.impl.BarrageSession;
import io.deephaven.client.impl.BarrageSessionFactory;
import io.deephaven.client.impl.BarrageSubcomponent.Builder;
import io.deephaven.client.impl.DaggerDeephavenBarrageRoot;
import io.grpc.ManagedChannel;
import org.apache.arrow.memory.BufferAllocator;
Expand All @@ -21,6 +22,9 @@ abstract class BarrageClientExampleBase implements Callable<Void> {
@ArgGroup(exclusive = false)
ConnectOptions connectOptions;

@ArgGroup(exclusive = true)
AuthenticationOptions authenticationOptions;

protected abstract void execute(BarrageSession session) throws Exception;

@Override
Expand All @@ -32,15 +36,15 @@ public final Void call() throws Exception {
Runtime.getRuntime()
.addShutdownHook(new Thread(() -> onShutdown(scheduler, managedChannel)));

final BarrageSessionFactory barrageFactory =
DaggerDeephavenBarrageRoot.create().factoryBuilder()
.managedChannel(managedChannel)
.scheduler(scheduler)
.allocator(bufferAllocator)
.build();

final Builder builder = DaggerDeephavenBarrageRoot.create().factoryBuilder()
.managedChannel(managedChannel)
.scheduler(scheduler)
.allocator(bufferAllocator);
if (authenticationOptions != null) {
authenticationOptions.ifPresent(builder::authenticationTypeAndValue);
}
final BarrageSessionFactory barrageFactory = builder.build();
final BarrageSession deephavenSession = barrageFactory.newBarrageSession();

try {
try {
execute(deephavenSession);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import io.deephaven.extensions.barrage.BarrageSnapshotOptions;
import io.deephaven.extensions.barrage.BarrageSubscriptionOptions;
import io.deephaven.proto.DeephavenChannel;
import io.deephaven.qst.table.TableSpec;
import io.grpc.CallOptions;
import io.grpc.Channel;
Expand All @@ -30,12 +31,9 @@ public static BarrageSession of(
return new BarrageSession(session, client, channel);
}

private final Channel interceptedChannel;

protected BarrageSession(
final SessionImpl session, final FlightClient client, final ManagedChannel channel) {
super(session, client);
this.interceptedChannel = ClientInterceptors.intercept(channel, new AuthInterceptor());
}

@Override
Expand Down Expand Up @@ -64,25 +62,12 @@ public BarrageSnapshot snapshot(final TableHandle tableHandle, final BarrageSnap
return new BarrageSnapshotImpl(this, session.executor(), tableHandle.newRef(), options);
}

public Channel channel() {
return interceptedChannel;
}

private class AuthInterceptor implements ClientInterceptor {
@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
final MethodDescriptor<ReqT, RespT> methodDescriptor, final CallOptions callOptions,
final Channel channel) {
return new ForwardingClientCall.SimpleForwardingClientCall<ReqT, RespT>(
channel.newCall(methodDescriptor, callOptions)) {
@Override
public void start(final Listener<RespT> responseListener, final Metadata headers) {
final AuthenticationInfo localAuth = ((SessionImpl) session()).auth();
headers.put(Metadata.Key.of(localAuth.sessionHeaderKey(), Metadata.ASCII_STRING_MARSHALLER),
localAuth.session());
super.start(responseListener, headers);
}
};
}
/**
* The authenticated channel.
*
* @return the authenticated channel
*/
public DeephavenChannel channel() {
return session.channel();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@
*/
package io.deephaven.client.impl;

import java.util.concurrent.CompletableFuture;

public interface BarrageSessionFactory {
BarrageSession newBarrageSession();

CompletableFuture<? extends BarrageSession> newBarrageSessionFuture();
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import io.grpc.ManagedChannel;
import org.apache.arrow.memory.BufferAllocator;

import javax.annotation.Nullable;
import java.util.concurrent.ScheduledExecutorService;

public interface BarrageSessionFactoryBuilder {
Expand All @@ -15,5 +16,7 @@ public interface BarrageSessionFactoryBuilder {

BarrageSessionFactoryBuilder allocator(BufferAllocator bufferAllocator);

BarrageSessionFactoryBuilder authenticationTypeAndValue(@Nullable String authenticationTypeAndValue);

BarrageSessionFactory build();
}
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ public BarrageSnapshotImpl(
final ClientCall<FlightData, BarrageMessage> call;
final Context previous = Context.ROOT.attach();
try {
call = session.channel().newCall(snapshotDescriptor, CallOptions.DEFAULT);
call = session.channel().channel().newCall(snapshotDescriptor, CallOptions.DEFAULT);
} finally {
Context.ROOT.detach(previous);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ public BarrageSubscriptionImpl(
final ClientCall<FlightData, BarrageMessage> call;
final Context previous = Context.ROOT.attach();
try {
call = session.channel().newCall(subscribeDescriptor, CallOptions.DEFAULT);
call = session.channel().channel().newCall(subscribeDescriptor, CallOptions.DEFAULT);
} finally {
Context.ROOT.detach(previous);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package io.deephaven.client.examples;

import picocli.CommandLine.Option;

import java.util.function.Consumer;

public class AuthenticationOptions {
@Option(names = {"--mtls"}, description = "Use the connect mTLS")
Boolean mtls;

@Option(names = {"--psk"}, description = "The pre-shared key")
String psk;

@Option(names = {"--explicit"}, description = "The explicit authentication type and value")
String explicit;

public String toAuthenticationTypeAndValue() {
if (mtls != null && mtls) {
return "io.deephaven.authentication.mtls.MTlsAuthenticationHandler";
}
if (psk != null) {
return "psk " + psk;
}
if (explicit != null) {
return explicit;
}
return null;
}

public void ifPresent(Consumer<String> consumer) {
final String authenticationTypeAndValue = toAuthenticationTypeAndValue();
if (authenticationTypeAndValue != null) {
consumer.accept(authenticationTypeAndValue);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
import io.grpc.ManagedChannel;
import org.apache.arrow.memory.BufferAllocator;

import java.util.concurrent.CompletableFuture;
import java.util.function.Function;

@Module
public class FlightSessionModule {

Expand All @@ -19,13 +16,4 @@ public static FlightSession newFlightSession(SessionImpl session, BufferAllocato
ManagedChannel managedChannel) {
return FlightSession.of(session, allocator, managedChannel);
}

@Provides
public static CompletableFuture<? extends FlightSession> newFlightSessionFuture(
CompletableFuture<? extends SessionImpl> sessionFuture, BufferAllocator allocator,
ManagedChannel managedChannel) {
return sessionFuture
.thenApply((Function<SessionImpl, FlightSession>) session -> FlightSession.of(session,
allocator, managedChannel));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,15 @@
import io.grpc.ManagedChannel;
import org.apache.arrow.memory.BufferAllocator;

import java.util.concurrent.CompletableFuture;
import javax.annotation.Nullable;
import javax.inject.Named;
import java.util.concurrent.ScheduledExecutorService;

@Subcomponent(modules = {SessionImplModule.class, FlightSessionModule.class})
public interface FlightSubcomponent extends FlightSessionFactory {

FlightSession newFlightSession();

CompletableFuture<? extends FlightSession> newFlightSessionFuture();

@Module(subcomponents = {FlightSubcomponent.class})
interface FlightSubcomponentModule {

Expand All @@ -33,6 +32,9 @@ interface Builder {

Builder allocator(@BindsInstance BufferAllocator bufferAllocator);

Builder authenticationTypeAndValue(
@BindsInstance @Nullable @Named("authenticationTypeAndValue") String authenticationTypeAndValue);

FlightSubcomponent build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import io.deephaven.client.impl.DaggerDeephavenFlightRoot;
import io.deephaven.client.impl.FlightSession;
import io.deephaven.client.impl.FlightSessionFactory;
import io.deephaven.client.impl.FlightSubcomponent.Builder;
import io.grpc.ManagedChannel;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
Expand All @@ -21,6 +22,9 @@ abstract class FlightExampleBase implements Callable<Void> {
@ArgGroup(exclusive = false)
ConnectOptions connectOptions;

@ArgGroup(exclusive = true)
AuthenticationOptions authenticationOptions;

BufferAllocator bufferAllocator = new RootAllocator();

protected abstract void execute(FlightSession flight) throws Exception;
Expand All @@ -32,15 +36,15 @@ public final Void call() throws Exception {
Runtime.getRuntime()
.addShutdownHook(new Thread(() -> onShutdown(scheduler, managedChannel)));

FlightSessionFactory flightSessionFactory =
DaggerDeephavenFlightRoot.create().factoryBuilder()
.managedChannel(managedChannel)
.scheduler(scheduler)
.allocator(bufferAllocator)
.build();

final Builder builder = DaggerDeephavenFlightRoot.create().factoryBuilder()
.managedChannel(managedChannel)
.scheduler(scheduler)
.allocator(bufferAllocator);
if (authenticationOptions != null) {
authenticationOptions.ifPresent(builder::authenticationTypeAndValue);
}
FlightSessionFactory flightSessionFactory = builder.build();
FlightSession flightSession = flightSessionFactory.newFlightSession();

try {
try {
execute(flightSession);
Expand Down

This file was deleted.

Loading

0 comments on commit 203fc12

Please sign in to comment.