diff --git a/client/src/main/java/io/streamnative/oxia/client/AsyncOxiaClientImpl.java b/client/src/main/java/io/streamnative/oxia/client/AsyncOxiaClientImpl.java index 7414d55b..04cecd8d 100644 --- a/client/src/main/java/io/streamnative/oxia/client/AsyncOxiaClientImpl.java +++ b/client/src/main/java/io/streamnative/oxia/client/AsyncOxiaClientImpl.java @@ -72,7 +72,7 @@ class AsyncOxiaClientImpl implements AsyncOxiaClient { shardManager.addCallback(notificationManager); var readBatchManager = BatchManager.newReadBatchManager(config, stubByShardId, instrumentProvider); - var sessionManager = new SessionManager(config, stubByShardId, instrumentProvider); + var sessionManager = new SessionManager(executor, config, stubByShardId, instrumentProvider); shardManager.addCallback(sessionManager); var writeBatchManager = BatchManager.newWriteBatchManager( diff --git a/client/src/main/java/io/streamnative/oxia/client/session/Session.java b/client/src/main/java/io/streamnative/oxia/client/session/Session.java index 677da3c4..c4cde018 100644 --- a/client/src/main/java/io/streamnative/oxia/client/session/Session.java +++ b/client/src/main/java/io/streamnative/oxia/client/session/Session.java @@ -19,30 +19,31 @@ import static lombok.AccessLevel.PUBLIC; import com.google.common.annotations.VisibleForTesting; +import io.grpc.stub.StreamObserver; import io.opentelemetry.api.common.Attributes; import io.streamnative.oxia.client.ClientConfig; import io.streamnative.oxia.client.grpc.OxiaStub; import io.streamnative.oxia.client.metrics.Counter; import io.streamnative.oxia.client.metrics.InstrumentProvider; import io.streamnative.oxia.client.metrics.Unit; +import io.streamnative.oxia.client.util.Backoff; import io.streamnative.oxia.proto.CloseSessionRequest; +import io.streamnative.oxia.proto.CloseSessionResponse; +import io.streamnative.oxia.proto.KeepAliveResponse; import io.streamnative.oxia.proto.SessionHeartbeat; import java.time.Duration; +import java.time.Instant; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; import java.util.function.Function; import lombok.Getter; import lombok.NonNull; -import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; -import reactor.core.Disposable; -import reactor.core.publisher.Mono; -import reactor.core.scheduler.Scheduler; -import reactor.core.scheduler.Schedulers; -import reactor.util.retry.Retry; -import reactor.util.retry.RetryBackoffSpec; - -@RequiredArgsConstructor(access = PACKAGE) + @Slf4j -public class Session implements AutoCloseable { +public class Session implements StreamObserver { private final @NonNull Function stubByShardId; private final @NonNull Duration sessionTimeout; @@ -61,37 +62,42 @@ public class Session implements AutoCloseable { private final @NonNull SessionNotificationListener listener; - private Scheduler scheduler; - private Disposable keepAliveSubscription; + private volatile boolean closed; private Counter sessionsOpened; private Counter sessionsExpired; private Counter sessionsClosed; + private final ScheduledFuture heartbeatFuture; + private final Backoff backoff = new Backoff(); + + private volatile Instant lastSuccessfullResponse; + Session( + @NonNull ScheduledExecutorService executor, @NonNull Function stubByShardId, @NonNull ClientConfig config, long shardId, long sessionId, InstrumentProvider instrumentProvider, SessionNotificationListener listener) { - this( - stubByShardId, - config.sessionTimeout(), + this.stubByShardId = stubByShardId; + this.sessionTimeout = config.sessionTimeout(); + this.heartbeatInterval = Duration.ofMillis( - Math.max(config.sessionTimeout().toMillis() / 10, Duration.ofSeconds(2).toMillis())), - shardId, - sessionId, - config.clientIdentifier(), - SessionHeartbeat.newBuilder().setShardId(shardId).setSessionId(sessionId).build(), - listener); + Math.max(config.sessionTimeout().toMillis() / 10, Duration.ofSeconds(2).toMillis())); + this.shardId = shardId; + this.sessionId = sessionId; + this.clientIdentifier = config.clientIdentifier(); + this.heartbeat = + SessionHeartbeat.newBuilder().setShardId(shardId).setSessionId(sessionId).build(); + this.listener = listener; + log.info( "Session created shard={} sessionId={} clientIdentity={}", shardId, sessionId, config.clientIdentifier()); - var threadName = String.format("session-[id=%s,shard=%s]-keep-alive", sessionId, shardId); - scheduler = Schedulers.newSingle(threadName); this.sessionsOpened = instrumentProvider.newCounter( @@ -113,60 +119,98 @@ public class Session implements AutoCloseable { Attributes.builder().put("oxia.shard", shardId).build()); sessionsOpened.increment(); + + this.lastSuccessfullResponse = Instant.now(); + this.heartbeatFuture = + executor.scheduleAtFixedRate( + this::sendKeepAlive, + heartbeatInterval.toMillis(), + heartbeatInterval.toMillis(), + TimeUnit.MILLISECONDS); } - void start() { - RetryBackoffSpec retrySpec = - Retry.backoff(Long.MAX_VALUE, Duration.ofMillis(100)) - .doBeforeRetry( - signal -> - log.warn( - "Retrying sending keep-alives for session [id={},shard={}] - {}", - sessionId, - shardId, - signal)); - keepAliveSubscription = - Mono.just(heartbeat) - .repeat() - .delayElements(heartbeatInterval) - .flatMap(hb -> stubByShardId.apply(shardId).reactor().keepAlive(hb)) - .retryWhen(retrySpec) - .timeout(sessionTimeout) - .publishOn(scheduler) - .doOnError(this::handleSessionExpired) - .subscribe(); + private void sendKeepAlive() { + Duration diff = Duration.between(lastSuccessfullResponse, Instant.now()); + + if (diff.toMillis() > sessionTimeout.toMillis()) { + handleSessionExpired(); + return; + } + + stubByShardId.apply(shardId).async().keepAlive(heartbeat, this); } - private void handleSessionExpired(Throwable t) { - sessionsExpired.increment(); + @Override + public void onNext(KeepAliveResponse value) { + lastSuccessfullResponse = Instant.now(); + if (log.isDebugEnabled()) { + log.debug( + "Received keep-alive response shard={} sessionId={} clientIdentity={}", + shardId, + sessionId, + clientIdentifier); + } + } + + @Override + public void onError(Throwable t) { log.warn( - "Session expired shard={} sessionId={} clientIdentity={}: {}", + "Error during session keep-alive shard={} sessionId={} clientIdentity={}: {}", shardId, sessionId, clientIdentifier, t.getMessage()); - close(); } @Override - public void close() { + public void onCompleted() { + // Nothing to do + } + + private void handleSessionExpired() { + sessionsExpired.increment(); + log.warn( + "Session expired shard={} sessionId={} clientIdentity={}", + shardId, + sessionId, + clientIdentifier); + close(); + } + + public CompletableFuture close() { sessionsClosed.increment(); - keepAliveSubscription.dispose(); + heartbeatFuture.cancel(true); var stub = stubByShardId.apply(shardId); var request = CloseSessionRequest.newBuilder().setShardId(shardId).setSessionId(sessionId).build(); - try { - stub.blocking().closeSession(request); - log.info( - "Session closed shard={} sessionId={} clientIdentity={}", - shardId, - sessionId, - clientIdentifier); - } catch (Exception e) { - // Ignore errors in closing the session, since it might have already expired - } - scheduler.dispose(); - listener.onSessionClosed(this); + CompletableFuture result = new CompletableFuture<>(); + stub.async() + .closeSession( + request, + new StreamObserver<>() { + @Override + public void onNext(CloseSessionResponse value) { + log.info( + "Session closed shard={} sessionId={} clientIdentity={}", + shardId, + sessionId, + clientIdentifier); + listener.onSessionClosed(Session.this); + result.complete(null); + } + + @Override + public void onError(Throwable t) { + // Ignore errors in closing the session, since it might have already expired + listener.onSessionClosed(Session.this); + result.complete(null); + } + + @Override + public void onCompleted() {} + }); + + return result; } } diff --git a/client/src/main/java/io/streamnative/oxia/client/session/SessionFactory.java b/client/src/main/java/io/streamnative/oxia/client/session/SessionFactory.java index db20c647..25a61537 100644 --- a/client/src/main/java/io/streamnative/oxia/client/session/SessionFactory.java +++ b/client/src/main/java/io/streamnative/oxia/client/session/SessionFactory.java @@ -22,12 +22,14 @@ import io.streamnative.oxia.client.metrics.InstrumentProvider; import io.streamnative.oxia.proto.CreateSessionRequest; import io.streamnative.oxia.proto.CreateSessionResponse; +import java.util.concurrent.ScheduledExecutorService; import java.util.function.Function; import lombok.NonNull; import lombok.RequiredArgsConstructor; @RequiredArgsConstructor(access = PACKAGE) public class SessionFactory { + @NonNull private final ScheduledExecutorService executor; @NonNull final ClientConfig config; @NonNull final SessionNotificationListener listener; @@ -47,6 +49,12 @@ Session create(long shardId) { .build(); CreateSessionResponse response = stub.blocking().createSession(request); return new Session( - stubByShardId, config, shardId, response.getSessionId(), instrumentProvider, listener); + executor, + stubByShardId, + config, + shardId, + response.getSessionId(), + instrumentProvider, + listener); } } diff --git a/client/src/main/java/io/streamnative/oxia/client/session/SessionManager.java b/client/src/main/java/io/streamnative/oxia/client/session/SessionManager.java index ce72a4c7..7a83c760 100644 --- a/client/src/main/java/io/streamnative/oxia/client/session/SessionManager.java +++ b/client/src/main/java/io/streamnative/oxia/client/session/SessionManager.java @@ -27,6 +27,7 @@ import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ScheduledExecutorService; import java.util.function.Consumer; import java.util.function.Function; import lombok.NonNull; @@ -41,32 +42,23 @@ public class SessionManager private volatile boolean closed = false; public SessionManager( + @NonNull ScheduledExecutorService executor, @NonNull ClientConfig config, @NonNull Function stubByShardId, @NonNull InstrumentProvider instrumentProvider) { - this.factory = new SessionFactory(config, this, stubByShardId, instrumentProvider); + this.factory = new SessionFactory(executor, config, this, stubByShardId, instrumentProvider); } - public SessionManager(SessionFactory factory) { + SessionManager(SessionFactory factory) { this.factory = factory; } @NonNull public Session getSession(long shardId) { - try { - if (closed) { - throw new IllegalStateException("session manager has been closed"); - } - return sessionsByShardId.computeIfAbsent( - shardId, - s -> { - var session = factory.create(shardId); - session.start(); - return session; - }); - } catch (Exception e) { - throw e; + if (closed) { + throw new IllegalStateException("session manager has been closed"); } + return sessionsByShardId.computeIfAbsent(shardId, s -> factory.create(shardId)); } @Override @@ -99,11 +91,7 @@ public void accept(@NonNull ShardAssignmentChanges changes) { @VisibleForTesting Optional closeQuietly(Session session) { if (session != null) { - try { - session.close(); - } catch (Exception e) { - log.warn("Error closing session {}", session.getSessionId(), e); - } + session.close(); } return Optional.ofNullable(session); } diff --git a/client/src/test/java/io/streamnative/oxia/client/session/SessionManagerTest.java b/client/src/test/java/io/streamnative/oxia/client/session/SessionManagerTest.java index 9a6bda42..0e17fb0e 100644 --- a/client/src/test/java/io/streamnative/oxia/client/session/SessionManagerTest.java +++ b/client/src/test/java/io/streamnative/oxia/client/session/SessionManagerTest.java @@ -29,6 +29,9 @@ import io.streamnative.oxia.client.shard.Shard; import io.streamnative.oxia.client.shard.ShardManager.ShardAssignmentChanges; import java.util.Set; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -41,19 +44,25 @@ class SessionManagerTest { @Mock SessionFactory factory; @Mock Session session; SessionManager manager; + ScheduledExecutorService executor; @BeforeEach void setup() { + executor = Executors.newSingleThreadScheduledExecutor(); manager = new SessionManager(factory); } + @AfterEach + void cleanup() { + executor.shutdownNow(); + } + @Test void newSession() { var shardId = 1L; when(factory.create(shardId)).thenReturn(session); assertThat(manager.getSession(shardId)).isSameAs(session); verify(factory).create(shardId); - verify(session).start(); } @Test @@ -62,7 +71,6 @@ void existingSession() { when(factory.create(shardId)).thenReturn(session); var session1 = manager.getSession(shardId); verify(factory, times(1)).create(shardId); - verify(session, times(1)).start(); var session2 = manager.getSession(shardId); assertThat(session2).isSameAs(session1); diff --git a/client/src/test/java/io/streamnative/oxia/client/session/SessionTest.java b/client/src/test/java/io/streamnative/oxia/client/session/SessionTest.java index 6a2dea88..c86d9968 100644 --- a/client/src/test/java/io/streamnative/oxia/client/session/SessionTest.java +++ b/client/src/test/java/io/streamnative/oxia/client/session/SessionTest.java @@ -23,18 +23,21 @@ import io.grpc.Server; import io.grpc.inprocess.InProcessChannelBuilder; import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.stub.StreamObserver; import io.streamnative.oxia.client.ClientConfig; import io.streamnative.oxia.client.grpc.OxiaStub; import io.streamnative.oxia.client.metrics.InstrumentProvider; import io.streamnative.oxia.proto.CloseSessionRequest; import io.streamnative.oxia.proto.CloseSessionResponse; import io.streamnative.oxia.proto.KeepAliveResponse; -import io.streamnative.oxia.proto.ReactorOxiaClientGrpc; +import io.streamnative.oxia.proto.OxiaClientGrpc; import io.streamnative.oxia.proto.SessionHeartbeat; import java.io.IOException; import java.time.Duration; -import java.util.LinkedList; -import java.util.List; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.Executors; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; import org.junit.jupiter.api.AfterEach; @@ -42,8 +45,6 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.junit.jupiter.MockitoExtension; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; @ExtendWith(MockitoExtension.class) class SessionTest { @@ -58,10 +59,11 @@ class SessionTest { private Server server; private OxiaStub stub; private TestService service; + private ScheduledExecutorService executor; @BeforeEach void setup() throws IOException { - StepVerifier.setDefaultTimeout(Duration.ofSeconds(3)); + executor = Executors.newSingleThreadScheduledExecutor(); config = new ClientConfig( @@ -92,17 +94,18 @@ void setup() throws IOException { @AfterEach public void stopServer() throws Exception { server.shutdown(); - server.awaitTermination(); stub.close(); server = null; stub = null; + executor.shutdownNow(); } @Test void sessionId() { var session = new Session( + executor, stubByShardId, config, shardId, @@ -117,13 +120,13 @@ void sessionId() { void start() throws Exception { var session = new Session( + executor, stubByShardId, config, shardId, sessionId, InstrumentProvider.NOOP, mock(SessionNotificationListener.class)); - session.start(); await() .untilAsserted( @@ -136,33 +139,33 @@ void start() throws Exception { .setShardId(shardId) .build()); }); - session.close(); + session.close().join(); assertThat(service.closed).isTrue(); assertThat(service.signalsAfterClosed).isEmpty(); } - static class TestService extends ReactorOxiaClientGrpc.OxiaClientImplBase { - List signals = new LinkedList<>(); - List signalsAfterClosed = new LinkedList<>(); + static class TestService extends OxiaClientGrpc.OxiaClientImplBase { + BlockingQueue signals = new LinkedBlockingQueue<>(); + BlockingQueue signalsAfterClosed = new LinkedBlockingQueue<>(); AtomicBoolean closed = new AtomicBoolean(false); @Override - public Mono keepAlive(Mono request) { - return request.map( - heartbeat -> { - if (!closed.get()) { - signals.add(heartbeat); - } else { - signalsAfterClosed.add(heartbeat); - } - return KeepAliveResponse.getDefaultInstance(); - }); + public void keepAlive( + SessionHeartbeat heartbeat, StreamObserver responseObserver) { + if (!closed.get()) { + signals.add(heartbeat); + } else { + signalsAfterClosed.add(heartbeat); + } + + responseObserver.onNext(KeepAliveResponse.getDefaultInstance()); } @Override - public Mono closeSession(Mono request) { + public void closeSession( + CloseSessionRequest request, StreamObserver responseObserver) { closed.compareAndSet(false, true); - return Mono.just(CloseSessionResponse.getDefaultInstance()); + responseObserver.onNext(CloseSessionResponse.getDefaultInstance()); } } }