diff --git a/docs/src/main/asciidoc/websockets-next-reference.adoc b/docs/src/main/asciidoc/websockets-next-reference.adoc index 55203bc86b1a2..6eb75e98c601e 100644 --- a/docs/src/main/asciidoc/websockets-next-reference.adoc +++ b/docs/src/main/asciidoc/websockets-next-reference.adoc @@ -641,6 +641,8 @@ quarkus.http.auth.permission.secured.policy=authenticated Other options for securing HTTP upgrade requests, such as using the security annotations, will be explored in the future. +NOTE: When OpenID Connect extension is used and token expires, Quarkus automatically closes connection. + [[websocket-next-configuration-reference]] == Configuration reference diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/AuthenticationExpiredTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/AuthenticationExpiredTest.java new file mode 100644 index 0000000000000..3351c71033053 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/security/AuthenticationExpiredTest.java @@ -0,0 +1,129 @@ +package io.quarkus.websockets.next.test.security; + +import static io.quarkus.websockets.next.test.security.SecurityTestBase.basicAuth; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.net.URI; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicReference; + +import jakarta.inject.Inject; +import jakarta.inject.Singleton; + +import org.awaitility.Awaitility; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.security.Authenticated; +import io.quarkus.security.identity.AuthenticationRequestContext; +import io.quarkus.security.identity.SecurityIdentity; +import io.quarkus.security.identity.SecurityIdentityAugmentor; +import io.quarkus.security.runtime.QuarkusSecurityIdentity; +import io.quarkus.security.test.utils.TestIdentityController; +import io.quarkus.security.test.utils.TestIdentityProvider; +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.CloseReason; +import io.quarkus.websockets.next.OnClose; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketConnection; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.smallrye.mutiny.Uni; +import io.vertx.core.Vertx; +import io.vertx.core.buffer.Buffer; + +public class AuthenticationExpiredTest { + + @Inject + Vertx vertx; + + @TestHTTPResource("end") + URI endUri; + + @BeforeAll + public static void setupUsers() { + TestIdentityController.resetRoles() + .add("admin", "admin", "admin") + .add("user", "user", "user"); + } + + @RegisterExtension + static final QuarkusUnitTest config = new QuarkusUnitTest() + .withApplicationRoot(root -> root.addClasses(Endpoint.class, TestIdentityProvider.class, + TestIdentityController.class, WSClient.class, ExpiredIdentityAugmentor.class, SecurityTestBase.class)); + + @Test + public void testConnectionClosedWhenAuthExpires() { + try (WSClient client = new WSClient(vertx)) { + client.connect(basicAuth("admin", "admin"), endUri); + + long threeSecondsFromNow = Duration.ofMillis(System.currentTimeMillis()).plusSeconds(3).toMillis(); + for (int i = 1; true; i++) { + if (client.isClosed()) { + break; + } else if (System.currentTimeMillis() > threeSecondsFromNow) { + Assertions.fail("Authentication expired, therefore connection should had been closed"); + } + client.sendAndAwaitReply("Hello #" + i + " from "); + } + + var receivedMessages = client.getMessages().stream().map(Buffer::toString).toList(); + assertTrue(receivedMessages.size() > 2, receivedMessages.toString()); + assertTrue(receivedMessages.contains("Hello #1 from admin"), receivedMessages.toString()); + assertTrue(receivedMessages.contains("Hello #2 from admin"), receivedMessages.toString()); + assertEquals(1008, client.closeStatusCode(), "Expected close status 1008, but got " + client.closeStatusCode()); + + Awaitility + .await() + .atMost(Duration.ofSeconds(1)) + .untilAsserted(() -> assertTrue(Endpoint.CLOSED_MESSAGE.get() + .startsWith("Connection closed with reason 'Authentication expired'"))); + } + } + + @Singleton + public static class ExpiredIdentityAugmentor implements SecurityIdentityAugmentor { + + @Override + public Uni augment(SecurityIdentity securityIdentity, + AuthenticationRequestContext authenticationRequestContext) { + return Uni + .createFrom() + .item(QuarkusSecurityIdentity + .builder(securityIdentity) + .addAttribute("quarkus.identity.expire-time", expireIn2Seconds()) + .build()); + } + + private static long expireIn2Seconds() { + return Duration.ofMillis(System.currentTimeMillis()) + .plusSeconds(2) + .toSeconds(); + } + } + + @WebSocket(path = "/end") + public static class Endpoint { + + static final AtomicReference CLOSED_MESSAGE = new AtomicReference<>(); + + @Inject + SecurityIdentity currentIdentity; + + @Authenticated + @OnTextMessage + String echo(String message) { + return message + currentIdentity.getPrincipal().getName(); + } + + @OnClose + void close(CloseReason reason, WebSocketConnection connection) { + CLOSED_MESSAGE.set("Connection closed with reason '%s': %s".formatted(reason.getMessage(), connection)); + } + } + +} diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/Endpoints.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/Endpoints.java index ce4d2c096628d..15980876612be 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/Endpoints.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/Endpoints.java @@ -208,6 +208,7 @@ public void handle(Void event) { handleFailure(unhandledFailureStrategy, r.cause(), "Unable to complete @OnClose callback", connection); } + securitySupport.onClose(); onClose.run(); if (timerId != null) { vertx.cancelTimer(timerId); diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/SecuritySupport.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/SecuritySupport.java index 8ec115e085e70..eeb5f5a5ad342 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/SecuritySupport.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/SecuritySupport.java @@ -1,22 +1,36 @@ package io.quarkus.websockets.next.runtime; import java.util.Objects; +import java.util.concurrent.TimeUnit; import jakarta.enterprise.inject.Instance; +import org.jboss.logging.Logger; + import io.quarkus.security.identity.CurrentIdentityAssociation; import io.quarkus.security.identity.SecurityIdentity; +import io.quarkus.websockets.next.CloseReason; +import io.vertx.core.Vertx; public class SecuritySupport { - static final SecuritySupport NOOP = new SecuritySupport(null, null); + private static final Logger LOG = Logger.getLogger(SecuritySupport.class); + static final SecuritySupport NOOP = new SecuritySupport(null, null, null, null); private final Instance currentIdentity; private final SecurityIdentity identity; + private final Runnable onClose; - SecuritySupport(Instance currentIdentity, SecurityIdentity identity) { + SecuritySupport(Instance currentIdentity, SecurityIdentity identity, Vertx vertx, + WebSocketConnectionImpl connection) { this.currentIdentity = currentIdentity; - this.identity = currentIdentity != null ? Objects.requireNonNull(identity) : identity; + if (this.currentIdentity != null) { + this.identity = Objects.requireNonNull(identity); + this.onClose = closeConnectionWhenIdentityExpired(vertx, connection, this.identity); + } else { + this.identity = null; + this.onClose = null; + } } /** @@ -29,4 +43,25 @@ void start() { } } + void onClose() { + if (onClose != null) { + onClose.run(); + } + } + + private static Runnable closeConnectionWhenIdentityExpired(Vertx vertx, WebSocketConnectionImpl connection, + SecurityIdentity identity) { + if (identity.getAttribute("quarkus.identity.expire-time") instanceof Long expireAt) { + long timerId = vertx.setTimer(TimeUnit.SECONDS.toMillis(expireAt) - System.currentTimeMillis(), + ignored -> connection + .close(new CloseReason(1008, "Authentication expired")) + .subscribe() + .with( + v -> LOG.tracef("Closed connection due to expired authentication: %s", connection), + e -> LOG.errorf("Unable to close connection [%s] after authentication " + + "expired due to unhandled failure: %s", connection, e))); + return () -> vertx.cancelTimer(timerId); + } + return null; + } } diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java index 35bdae2ca2206..2878f921d680c 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketServerRecorder.java @@ -90,8 +90,6 @@ public Handler createEndpointHandler(String generatedEndpointCla @Override public void handle(RoutingContext ctx) { - SecuritySupport securitySupport = initializeSecuritySupport(container, ctx); - Future future = ctx.request().toWebSocket(); future.onSuccess(ws -> { Vertx vertx = VertxCoreRecorder.getVertx().get(); @@ -101,6 +99,8 @@ public void handle(RoutingContext ctx) { connectionManager.add(generatedEndpointClass, connection); LOG.debugf("Connection created: %s", connection); + SecuritySupport securitySupport = initializeSecuritySupport(container, ctx, vertx, connection); + Endpoints.initialize(vertx, container, codecs, connection, ws, generatedEndpointClass, config.autoPingInterval(), securitySupport, config.unhandledFailureStrategy(), () -> connectionManager.remove(generatedEndpointClass, connection)); @@ -109,14 +109,15 @@ public void handle(RoutingContext ctx) { }; } - SecuritySupport initializeSecuritySupport(ArcContainer container, RoutingContext ctx) { + SecuritySupport initializeSecuritySupport(ArcContainer container, RoutingContext ctx, Vertx vertx, + WebSocketConnectionImpl connection) { Instance currentIdentityAssociation = container.select(CurrentIdentityAssociation.class); if (currentIdentityAssociation.isResolvable()) { // Security extension is present // Obtain the current security identity from the handshake request QuarkusHttpUser user = (QuarkusHttpUser) ctx.user(); if (user != null) { - return new SecuritySupport(currentIdentityAssociation, user.getSecurityIdentity()); + return new SecuritySupport(currentIdentityAssociation, user.getSecurityIdentity(), vertx, connection); } } return SecuritySupport.NOOP;