Skip to content

Commit

Permalink
Fixed session expiration handling (#137)
Browse files Browse the repository at this point in the history
* Fixed session expiration handling

* Fixed test

* spotless
  • Loading branch information
merlimat authored Apr 29, 2024
1 parent 494acac commit 7396536
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import io.streamnative.oxia.client.grpc.OxiaStub;
import io.streamnative.oxia.client.metrics.SessionMetrics;
import io.streamnative.oxia.proto.CloseSessionRequest;
import io.streamnative.oxia.proto.CreateSessionRequest;
import io.streamnative.oxia.proto.SessionHeartbeat;
import java.time.Duration;
import java.util.function.Function;
Expand Down Expand Up @@ -53,9 +52,13 @@ public class Session implements AutoCloseable {
@Getter(PUBLIC)
private final long sessionId;

private final String clientIdentifier;

private final @NonNull SessionHeartbeat heartbeat;
private final @NonNull SessionMetrics metrics;

private final @NonNull SessionNotificationListener listener;

private Scheduler scheduler;
private Disposable keepAliveSubscription;

Expand All @@ -64,16 +67,24 @@ public class Session implements AutoCloseable {
@NonNull ClientConfig config,
long shardId,
long sessionId,
SessionMetrics metrics) {
SessionMetrics metrics,
SessionNotificationListener listener) {
this(
stubByShardId,
config.sessionTimeout(),
Duration.ofMillis(
Math.max(config.sessionTimeout().toMillis() / 10, Duration.ofSeconds(2).toMillis())),
shardId,
sessionId,
config.clientIdentifier(),
SessionHeartbeat.newBuilder().setShardId(shardId).setSessionId(sessionId).build(),
metrics);
metrics,
listener);
log.info(
"Session created shard={} sessionId={} clientIdentity={}",
shardId,
sessionId,
config.clientIdentifier());
var threadName = String.format("session-[id=%s,shard=%s]-keep-alive", sessionId, shardId);
scheduler = Schedulers.newSingle(threadName);
}
Expand All @@ -97,41 +108,38 @@ void start() {
.timeout(sessionTimeout)
.publishOn(scheduler)
.doOnEach(metrics::recordKeepAlive)
.doOnError(
t -> log.warn("Session keep-alive error: [id={},shard={}]", sessionId, shardId, t))
.doOnError(this::handleSessionExpired)
.subscribe();
}

private void handleSessionExpired(Throwable t) {
log.warn(
"Session expired shard={} sessionId={} clientIdentity={}: {}",
shardId,
sessionId,
clientIdentifier,
t.getMessage());
close();
}

@Override
public void close() throws Exception {
public void close() {
keepAliveSubscription.dispose();
var stub = stubByShardId.apply(shardId);
var request =
CloseSessionRequest.newBuilder().setShardId(shardId).setSessionId(sessionId).build();
stub.reactor().closeSession(request).block();
scheduler.dispose();
}

@RequiredArgsConstructor(access = PACKAGE)
static class Factory {
@NonNull ClientConfig config;
@NonNull Function<Long, OxiaStub> stubByShardId;
@NonNull SessionMetrics metrics;

@NonNull
Session create(long shardId) {
var stub = stubByShardId.apply(shardId);
var request =
CreateSessionRequest.newBuilder()
.setSessionTimeoutMs((int) config.sessionTimeout().toMillis())
.setShardId(shardId)
.setClientIdentity(config.clientIdentifier())
.build();
var response = stub.reactor().createSession(request).block();
if (response == null) {
throw new IllegalStateException("Empty session returned for shardId: " + shardId);
}
return new Session(stubByShardId, config, shardId, response.getSessionId(), metrics);
try {
stub.blocking().closeSession(request);
log.info(
"Session closed shard={} sessionId={} clientIdentity={}",
shardId,
sessionId,
clientIdentifier);
} catch (Exception e) {
// Ignore errors in closing the session, since it might have already expired
}
scheduler.dispose();
listener.onSessionClosed(this);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright © 2022-2024 StreamNative Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.streamnative.oxia.client.session;

import static lombok.AccessLevel.PACKAGE;

import io.streamnative.oxia.client.ClientConfig;
import io.streamnative.oxia.client.grpc.OxiaStub;
import io.streamnative.oxia.client.metrics.SessionMetrics;
import io.streamnative.oxia.proto.CreateSessionRequest;
import io.streamnative.oxia.proto.CreateSessionResponse;
import java.util.function.Function;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;

@RequiredArgsConstructor(access = PACKAGE)
public class SessionFactory {
@NonNull final ClientConfig config;

@NonNull final SessionNotificationListener listener;

@NonNull final Function<Long, OxiaStub> stubByShardId;

@NonNull final SessionMetrics metrics;

@NonNull
Session create(long shardId) {
var stub = stubByShardId.apply(shardId);
var request =
CreateSessionRequest.newBuilder()
.setSessionTimeoutMs((int) config.sessionTimeout().toMillis())
.setShardId(shardId)
.setClientIdentity(config.clientIdentifier())
.build();
CreateSessionResponse response = stub.blocking().createSession(request);
return new Session(stubByShardId, config, shardId, response.getSessionId(), metrics, listener);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,12 @@
package io.streamnative.oxia.client.session;

import static java.util.Collections.unmodifiableMap;
import static lombok.AccessLevel.PACKAGE;

import com.google.common.annotations.VisibleForTesting;
import io.streamnative.oxia.client.ClientConfig;
import io.streamnative.oxia.client.grpc.OxiaStub;
import io.streamnative.oxia.client.metrics.SessionMetrics;
import io.streamnative.oxia.client.shard.ShardManager.ShardAssignmentChanges;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
Expand All @@ -32,20 +30,24 @@
import java.util.function.Consumer;
import java.util.function.Function;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;

@Slf4j
@RequiredArgsConstructor(access = PACKAGE)
public class SessionManager implements AutoCloseable, Consumer<ShardAssignmentChanges> {
public class SessionManager
implements AutoCloseable, Consumer<ShardAssignmentChanges>, SessionNotificationListener {

private final ConcurrentMap<Long, Session> sessionsByShardId = new ConcurrentHashMap<>();
private final @NonNull Session.Factory factory;
private final SessionFactory factory;
private volatile boolean closed = false;

public SessionManager(
@NonNull ClientConfig config, @NonNull Function<Long, OxiaStub> stubByShardId) {
this(new Session.Factory(config, stubByShardId, SessionMetrics.create(config.metrics())));
this.factory =
new SessionFactory(config, this, stubByShardId, SessionMetrics.create(config.metrics()));
}

public SessionManager(SessionFactory factory) {
this.factory = factory;
}

@NonNull
Expand All @@ -66,16 +68,18 @@ public Session getSession(long shardId) {
}
}

@Override
public void onSessionClosed(Session session) {
sessionsByShardId.remove(session.getSessionId(), session);
}

@Override
public void close() throws Exception {
if (closed) {
return;
}
closed = true;
var closedSessions = new ArrayList<Session>();
sessionsByShardId.entrySet().parallelStream()
.forEach(entry -> closeQuietly(entry.getValue()).ifPresent(closedSessions::add));
closedSessions.forEach(s -> sessionsByShardId.remove(s.getSessionId()));
sessionsByShardId.entrySet().parallelStream().forEach(entry -> closeQuietly(entry.getValue()));
}

@VisibleForTesting
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*
* Copyright © 2022-2024 StreamNative Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.streamnative.oxia.client.session;

public interface SessionNotificationListener {
void onSessionClosed(Session session);
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.awaitility.Awaitility.await;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
Expand All @@ -37,7 +38,7 @@
@ExtendWith(MockitoExtension.class)
class SessionManagerTest {

@Mock Session.Factory factory;
@Mock SessionFactory factory;
@Mock Session session;
SessionManager manager;

Expand Down Expand Up @@ -73,6 +74,13 @@ void close() throws Exception {
var shardId = 1L;
when(session.getSessionId()).thenReturn(shardId);
when(factory.create(shardId)).thenReturn(session);
doAnswer(
invocation -> {
manager.onSessionClosed(session);
return null;
})
.when(session)
.close();
manager.getSession(shardId);

assertThat(manager.sessions()).containsEntry(shardId, session);
Expand Down Expand Up @@ -110,6 +118,26 @@ void accept() throws Exception {
verify(session).close();
}

@Test
void testSessionClosed() throws Exception {
var shardId = 1L;
when(session.getSessionId()).thenReturn(shardId);
doAnswer(
invocation -> {
manager.onSessionClosed(session);
return null;
})
.when(session)
.close();
when(factory.create(shardId)).thenReturn(session);
manager.getSession(shardId);

assertThat(manager.sessions()).containsEntry(shardId, session);

session.close();
assertThat(manager.sessions()).doesNotContainKey(shardId);
}

@Test
void closeQuietly() throws Exception {
var value = manager.closeQuietly(session);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static org.awaitility.Awaitility.await;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;

import io.grpc.Server;
Expand Down Expand Up @@ -108,14 +109,28 @@ public void stopServer() throws Exception {

@Test
void sessionId() {
var session = new Session(stubByShardId, config, shardId, sessionId, metrics);
var session =
new Session(
stubByShardId,
config,
shardId,
sessionId,
metrics,
mock(SessionNotificationListener.class));
assertThat(session.getShardId()).isEqualTo(shardId);
assertThat(session.getSessionId()).isEqualTo(sessionId);
}

@Test
void start() throws Exception {
var session = new Session(stubByShardId, config, shardId, sessionId, metrics);
var session =
new Session(
stubByShardId,
config,
shardId,
sessionId,
metrics,
mock(SessionNotificationListener.class));
session.start();

await()
Expand Down

0 comments on commit 7396536

Please sign in to comment.