Skip to content

Commit

Permalink
Refactor to avoid using 1 thread per shard for sessions management (#150)
Browse files Browse the repository at this point in the history
  • Loading branch information
merlimat authored May 6, 2024
1 parent f3b2a5f commit 1068626
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 109 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class AsyncOxiaClientImpl implements AsyncOxiaClient {
shardManager.addCallback(notificationManager);
var readBatchManager =
BatchManager.newReadBatchManager(config, stubByShardId, instrumentProvider);
var sessionManager = new SessionManager(config, stubByShardId, instrumentProvider);
var sessionManager = new SessionManager(executor, config, stubByShardId, instrumentProvider);
shardManager.addCallback(sessionManager);
var writeBatchManager =
BatchManager.newWriteBatchManager(
Expand Down
166 changes: 105 additions & 61 deletions client/src/main/java/io/streamnative/oxia/client/session/Session.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,30 +19,31 @@
import static lombok.AccessLevel.PUBLIC;

import com.google.common.annotations.VisibleForTesting;
import io.grpc.stub.StreamObserver;
import io.opentelemetry.api.common.Attributes;
import io.streamnative.oxia.client.ClientConfig;
import io.streamnative.oxia.client.grpc.OxiaStub;
import io.streamnative.oxia.client.metrics.Counter;
import io.streamnative.oxia.client.metrics.InstrumentProvider;
import io.streamnative.oxia.client.metrics.Unit;
import io.streamnative.oxia.client.util.Backoff;
import io.streamnative.oxia.proto.CloseSessionRequest;
import io.streamnative.oxia.proto.CloseSessionResponse;
import io.streamnative.oxia.proto.KeepAliveResponse;
import io.streamnative.oxia.proto.SessionHeartbeat;
import java.time.Duration;
import java.time.Instant;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import lombok.Getter;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import reactor.core.Disposable;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers;
import reactor.util.retry.Retry;
import reactor.util.retry.RetryBackoffSpec;

@RequiredArgsConstructor(access = PACKAGE)

@Slf4j
public class Session implements AutoCloseable {
public class Session implements StreamObserver<KeepAliveResponse> {

private final @NonNull Function<Long, OxiaStub> stubByShardId;
private final @NonNull Duration sessionTimeout;
Expand All @@ -61,37 +62,42 @@ public class Session implements AutoCloseable {

private final @NonNull SessionNotificationListener listener;

private Scheduler scheduler;
private Disposable keepAliveSubscription;
private volatile boolean closed;

private Counter sessionsOpened;
private Counter sessionsExpired;
private Counter sessionsClosed;

private final ScheduledFuture<?> heartbeatFuture;
private final Backoff backoff = new Backoff();

private volatile Instant lastSuccessfullResponse;

Session(
@NonNull ScheduledExecutorService executor,
@NonNull Function<Long, OxiaStub> stubByShardId,
@NonNull ClientConfig config,
long shardId,
long sessionId,
InstrumentProvider instrumentProvider,
SessionNotificationListener listener) {
this(
stubByShardId,
config.sessionTimeout(),
this.stubByShardId = stubByShardId;
this.sessionTimeout = config.sessionTimeout();
this.heartbeatInterval =
Duration.ofMillis(
Math.max(config.sessionTimeout().toMillis() / 10, Duration.ofSeconds(2).toMillis())),
shardId,
sessionId,
config.clientIdentifier(),
SessionHeartbeat.newBuilder().setShardId(shardId).setSessionId(sessionId).build(),
listener);
Math.max(config.sessionTimeout().toMillis() / 10, Duration.ofSeconds(2).toMillis()));
this.shardId = shardId;
this.sessionId = sessionId;
this.clientIdentifier = config.clientIdentifier();
this.heartbeat =
SessionHeartbeat.newBuilder().setShardId(shardId).setSessionId(sessionId).build();
this.listener = listener;

log.info(
"Session created shard={} sessionId={} clientIdentity={}",
shardId,
sessionId,
config.clientIdentifier());
var threadName = String.format("session-[id=%s,shard=%s]-keep-alive", sessionId, shardId);
scheduler = Schedulers.newSingle(threadName);

this.sessionsOpened =
instrumentProvider.newCounter(
Expand All @@ -113,60 +119,98 @@ public class Session implements AutoCloseable {
Attributes.builder().put("oxia.shard", shardId).build());

sessionsOpened.increment();

this.lastSuccessfullResponse = Instant.now();
this.heartbeatFuture =
executor.scheduleAtFixedRate(
this::sendKeepAlive,
heartbeatInterval.toMillis(),
heartbeatInterval.toMillis(),
TimeUnit.MILLISECONDS);
}

void start() {
RetryBackoffSpec retrySpec =
Retry.backoff(Long.MAX_VALUE, Duration.ofMillis(100))
.doBeforeRetry(
signal ->
log.warn(
"Retrying sending keep-alives for session [id={},shard={}] - {}",
sessionId,
shardId,
signal));
keepAliveSubscription =
Mono.just(heartbeat)
.repeat()
.delayElements(heartbeatInterval)
.flatMap(hb -> stubByShardId.apply(shardId).reactor().keepAlive(hb))
.retryWhen(retrySpec)
.timeout(sessionTimeout)
.publishOn(scheduler)
.doOnError(this::handleSessionExpired)
.subscribe();
/**
 * Periodic keep-alive task, scheduled at a fixed rate on the executor passed to the constructor.
 *
 * <p>If the time elapsed since the last successful keep-alive response exceeds the session
 * timeout, the session is handled as expired and no heartbeat is sent. Otherwise the heartbeat
 * is sent through the async stub, with this object registered as the response observer.
 *
 * <p>Exceptions are caught and logged rather than propagated: a task scheduled with {@code
 * scheduleAtFixedRate} that throws has all of its subsequent executions suppressed, which would
 * silently stop both the heartbeats and the expiry check.
 */
private void sendKeepAlive() {
    try {
        Duration diff = Duration.between(lastSuccessfullResponse, Instant.now());

        if (diff.toMillis() > sessionTimeout.toMillis()) {
            handleSessionExpired();
            return;
        }

        stubByShardId.apply(shardId).async().keepAlive(heartbeat, this);
    } catch (Exception e) {
        // Keep the periodic task alive; the next tick retries or detects expiry.
        log.warn(
                "Error in session keep-alive task shard={} sessionId={} clientIdentity={}",
                shardId,
                sessionId,
                clientIdentifier,
                e);
    }
}

private void handleSessionExpired(Throwable t) {
sessionsExpired.increment();
/**
 * Handles a keep-alive response: records the current time as the moment of the last successful
 * response and, when debug logging is enabled, logs the receipt.
 *
 * @param value the keep-alive response (its content is not inspected)
 */
@Override
public void onNext(KeepAliveResponse value) {
    lastSuccessfullResponse = Instant.now();

    if (!log.isDebugEnabled()) {
        return;
    }
    log.debug(
            "Received keep-alive response shard={} sessionId={} clientIdentity={}",
            shardId,
            sessionId,
            clientIdentifier);
}

@Override
public void onError(Throwable t) {
log.warn(
"Session expired shard={} sessionId={} clientIdentity={}: {}",
"Error during session keep-alive shard={} sessionId={} clientIdentity={}: {}",
shardId,
sessionId,
clientIdentifier,
t.getMessage());
close();
}

@Override
public void close() {
public void onCompleted() {
// Nothing to do
}

/**
 * Reacts to a session whose keep-alive responses have been missing for longer than the session
 * timeout: bumps the expired-sessions counter, logs a warning and closes the session.
 */
private void handleSessionExpired() {
    sessionsExpired.increment();

    final String message = "Session expired shard={} sessionId={} clientIdentity={}";
    log.warn(message, shardId, sessionId, clientIdentifier);

    close();
}

public CompletableFuture<Void> close() {
sessionsClosed.increment();
keepAliveSubscription.dispose();
heartbeatFuture.cancel(true);
var stub = stubByShardId.apply(shardId);
var request =
CloseSessionRequest.newBuilder().setShardId(shardId).setSessionId(sessionId).build();

try {
stub.blocking().closeSession(request);
log.info(
"Session closed shard={} sessionId={} clientIdentity={}",
shardId,
sessionId,
clientIdentifier);
} catch (Exception e) {
// Ignore errors in closing the session, since it might have already expired
}
scheduler.dispose();
listener.onSessionClosed(this);
CompletableFuture<Void> result = new CompletableFuture<>();
stub.async()
.closeSession(
request,
new StreamObserver<>() {
@Override
public void onNext(CloseSessionResponse value) {
log.info(
"Session closed shard={} sessionId={} clientIdentity={}",
shardId,
sessionId,
clientIdentifier);
listener.onSessionClosed(Session.this);
result.complete(null);
}

@Override
public void onError(Throwable t) {
// Ignore errors in closing the session, since it might have already expired
listener.onSessionClosed(Session.this);
result.complete(null);
}

@Override
public void onCompleted() {}
});

return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@
import io.streamnative.oxia.client.metrics.InstrumentProvider;
import io.streamnative.oxia.proto.CreateSessionRequest;
import io.streamnative.oxia.proto.CreateSessionResponse;
import java.util.concurrent.ScheduledExecutorService;
import java.util.function.Function;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;

@RequiredArgsConstructor(access = PACKAGE)
public class SessionFactory {
@NonNull private final ScheduledExecutorService executor;
@NonNull final ClientConfig config;

@NonNull final SessionNotificationListener listener;
Expand All @@ -47,6 +49,12 @@ Session create(long shardId) {
.build();
CreateSessionResponse response = stub.blocking().createSession(request);
return new Session(
stubByShardId, config, shardId, response.getSessionId(), instrumentProvider, listener);
executor,
stubByShardId,
config,
shardId,
response.getSessionId(),
instrumentProvider,
listener);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ScheduledExecutorService;
import java.util.function.Consumer;
import java.util.function.Function;
import lombok.NonNull;
Expand All @@ -41,32 +42,23 @@ public class SessionManager
private volatile boolean closed = false;

public SessionManager(
@NonNull ScheduledExecutorService executor,
@NonNull ClientConfig config,
@NonNull Function<Long, OxiaStub> stubByShardId,
@NonNull InstrumentProvider instrumentProvider) {
this.factory = new SessionFactory(config, this, stubByShardId, instrumentProvider);
this.factory = new SessionFactory(executor, config, this, stubByShardId, instrumentProvider);
}

public SessionManager(SessionFactory factory) {
SessionManager(SessionFactory factory) {
this.factory = factory;
}

@NonNull
public Session getSession(long shardId) {
try {
if (closed) {
throw new IllegalStateException("session manager has been closed");
}
return sessionsByShardId.computeIfAbsent(
shardId,
s -> {
var session = factory.create(shardId);
session.start();
return session;
});
} catch (Exception e) {
throw e;
if (closed) {
throw new IllegalStateException("session manager has been closed");
}
return sessionsByShardId.computeIfAbsent(shardId, s -> factory.create(shardId));
}

@Override
Expand Down Expand Up @@ -99,11 +91,7 @@ public void accept(@NonNull ShardAssignmentChanges changes) {
@VisibleForTesting
Optional<Session> closeQuietly(Session session) {
if (session != null) {
try {
session.close();
} catch (Exception e) {
log.warn("Error closing session {}", session.getSessionId(), e);
}
session.close();
}
return Optional.ofNullable(session);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
import io.streamnative.oxia.client.shard.Shard;
import io.streamnative.oxia.client.shard.ShardManager.ShardAssignmentChanges;
import java.util.Set;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
Expand All @@ -41,19 +44,25 @@ class SessionManagerTest {
@Mock SessionFactory factory;
@Mock Session session;
SessionManager manager;
ScheduledExecutorService executor;

/**
 * Per-test fixture: creates a single-threaded scheduled executor and a manager backed by the
 * mocked factory. The manager is built from the factory only, so the executor is not handed to
 * it here.
 */
@BeforeEach
void setup() {
    this.executor = Executors.newSingleThreadScheduledExecutor();
    this.manager = new SessionManager(this.factory);
}

/** Per-test teardown: shuts down the executor created in {@code setup()}. */
@AfterEach
void cleanup() {
    this.executor.shutdownNow();
}

@Test
void newSession() {
var shardId = 1L;
when(factory.create(shardId)).thenReturn(session);
assertThat(manager.getSession(shardId)).isSameAs(session);
verify(factory).create(shardId);
verify(session).start();
}

@Test
Expand All @@ -62,7 +71,6 @@ void existingSession() {
when(factory.create(shardId)).thenReturn(session);
var session1 = manager.getSession(shardId);
verify(factory, times(1)).create(shardId);
verify(session, times(1)).start();

var session2 = manager.getSession(shardId);
assertThat(session2).isSameAs(session1);
Expand Down
Loading

0 comments on commit 1068626

Please sign in to comment.