Skip to content

Commit

Permalink
Changing AutoCloseableRefCounted to a generic class
Browse files Browse the repository at this point in the history
This is in response to PR feedback, and helps avoid a verbose get() method by introducing the generic at the class level.

Signed-off-by: Kartik Ganesh <[email protected]>
  • Loading branch information
kartg committed Mar 15, 2022
1 parent 065798f commit 9988089
Show file tree
Hide file tree
Showing 8 changed files with 31 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,18 @@
* Adapter class that enables a {@link RefCounted} implementation to function like an {@link AutoCloseable}.
* The {@link #close()} API invokes {@link RefCounted#decRef()} and ensures idempotency using a {@link OneWayGate}.
*/
public class AutoCloseableRefCounted implements AutoCloseable {
public class AutoCloseableRefCounted<T extends RefCounted> implements AutoCloseable {

private final RefCounted ref;
private final T ref;
private final OneWayGate gate;

public AutoCloseableRefCounted(RefCounted ref) {
public AutoCloseableRefCounted(T ref) {
this.ref = ref;
gate = new OneWayGate();
}

public <T extends RefCounted> T get(Class<T> returnType) {
return (T) ref;
public T get() {
return ref;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ private void doRecovery(final long recoveryId, final StartRecoveryRequest preExi
logger.trace("not running recovery with id [{}] - can not find it (probably finished)", recoveryId);
return;
}
final RecoveryTarget recoveryTarget = recoveryRef.get(RecoveryTarget.class);
final RecoveryTarget recoveryTarget = recoveryRef.get();
timer = recoveryTarget.state().getTimer();
cancellableThreads = recoveryTarget.cancellableThreads();
if (preExistingRequest == null) {
Expand Down Expand Up @@ -377,7 +377,7 @@ public void messageReceived(RecoveryPrepareForTranslogOperationsRequest request,
if (listener == null) {
return;
}
recoveryRef.get(RecoveryTarget.class).prepareForTranslogOperations(request.totalTranslogOps(), listener);
recoveryRef.get().prepareForTranslogOperations(request.totalTranslogOps(), listener);
}
}
}
Expand All @@ -391,7 +391,7 @@ public void messageReceived(RecoveryFinalizeRecoveryRequest request, TransportCh
if (listener == null) {
return;
}
recoveryRef.get(RecoveryTarget.class).finalizeRecovery(request.globalCheckpoint(), request.trimAboveSeqNo(), listener);
recoveryRef.get().finalizeRecovery(request.globalCheckpoint(), request.trimAboveSeqNo(), listener);
}
}
}
Expand All @@ -402,7 +402,7 @@ class HandoffPrimaryContextRequestHandler implements TransportRequestHandler<Rec
public void messageReceived(final RecoveryHandoffPrimaryContextRequest request, final TransportChannel channel, Task task)
throws Exception {
try (RecoveryRef recoveryRef = onGoingRecoveries.getRecoverySafe(request.recoveryId(), request.shardId())) {
recoveryRef.get(RecoveryTarget.class).handoffPrimaryContext(request.primaryContext());
recoveryRef.get().handoffPrimaryContext(request.primaryContext());
}
channel.sendResponse(TransportResponse.Empty.INSTANCE);
}
Expand All @@ -415,7 +415,7 @@ class TranslogOperationsRequestHandler implements TransportRequestHandler<Recove
public void messageReceived(final RecoveryTranslogOperationsRequest request, final TransportChannel channel, Task task)
throws IOException {
try (RecoveryRef recoveryRef = onGoingRecoveries.getRecoverySafe(request.recoveryId(), request.shardId())) {
final RecoveryTarget recoveryTarget = recoveryRef.get(RecoveryTarget.class);
final RecoveryTarget recoveryTarget = recoveryRef.get();
final ActionListener<Void> listener = createOrFinishListener(
recoveryRef,
channel,
Expand All @@ -436,7 +436,7 @@ private void performTranslogOps(
final ActionListener<Void> listener,
final RecoveryRef recoveryRef
) {
final RecoveryTarget recoveryTarget = recoveryRef.get(RecoveryTarget.class);
final RecoveryTarget recoveryTarget = recoveryRef.get();

final ClusterStateObserver observer = new ClusterStateObserver(clusterService, null, logger, threadPool.getThreadContext());
final Consumer<Exception> retryOnMappingException = exception -> {
Expand Down Expand Up @@ -501,7 +501,7 @@ public void messageReceived(RecoveryFilesInfoRequest request, TransportChannel c
return;
}

recoveryRef.get(RecoveryTarget.class)
recoveryRef.get()
.receiveFileInfo(
request.phase1FileNames,
request.phase1FileSizes,
Expand All @@ -524,7 +524,7 @@ public void messageReceived(RecoveryCleanFilesRequest request, TransportChannel
return;
}

recoveryRef.get(RecoveryTarget.class)
recoveryRef.get()
.cleanFiles(request.totalTranslogOps(), request.getGlobalCheckpoint(), request.sourceMetaSnapshot(), listener);
}
}
Expand All @@ -538,7 +538,7 @@ class FileChunkTransportRequestHandler implements TransportRequestHandler<Recove
@Override
public void messageReceived(final RecoveryFileChunkRequest request, TransportChannel channel, Task task) throws Exception {
try (RecoveryRef recoveryRef = onGoingRecoveries.getRecoverySafe(request.recoveryId(), request.shardId())) {
final RecoveryTarget recoveryTarget = recoveryRef.get(RecoveryTarget.class);
final RecoveryTarget recoveryTarget = recoveryRef.get();
final ActionListener<Void> listener = createOrFinishListener(recoveryRef, channel, Actions.FILE_CHUNK, request);
if (listener == null) {
return;
Expand Down Expand Up @@ -588,7 +588,7 @@ private ActionListener<Void> createOrFinishListener(
final RecoveryTransportRequest request,
final CheckedFunction<Void, TransportResponse, Exception> responseFn
) {
final RecoveryTarget recoveryTarget = recoveryRef.get(RecoveryTarget.class);
final RecoveryTarget recoveryTarget = recoveryRef.get();
final ActionListener<TransportResponse> channelListener = new ChannelActionListener<>(channel, action, request);
final ActionListener<Void> voidListener = ActionListener.map(channelListener, responseFn);

Expand Down Expand Up @@ -622,7 +622,7 @@ public void onFailure(Exception e) {
try (RecoveryRef recoveryRef = onGoingRecoveries.getRecovery(recoveryId)) {
if (recoveryRef != null) {
logger.error(() -> new ParameterizedMessage("unexpected error during recovery [{}], failing shard", recoveryId), e);
RecoveryTarget recoveryTarget = recoveryRef.get(RecoveryTarget.class);
RecoveryTarget recoveryTarget = recoveryRef.get();
onGoingRecoveries.failRecovery(
recoveryId,
new RecoveryFailedException(recoveryTarget.state(), "unexpected error", e),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ public RecoveryRef getRecoverySafe(long id, ShardId shardId) {
if (recoveryRef == null) {
throw new IndexShardClosedException(shardId);
}
RecoveryTarget recoveryTarget = recoveryRef.get(RecoveryTarget.class);
RecoveryTarget recoveryTarget = recoveryRef.get();
assert recoveryTarget.indexShard().shardId().equals(shardId);
return recoveryRef;
}
Expand Down Expand Up @@ -284,7 +284,7 @@ public boolean cancelRecoveriesForShard(ShardId shardId, String reason) {
* causes {@link RecoveryTarget#decRef()} to be called. This makes sure that the underlying resources
* will not be freed until {@link RecoveryRef#close()} is called.
*/
public static class RecoveryRef extends AutoCloseableRefCounted {
public static class RecoveryRef extends AutoCloseableRefCounted<RecoveryTarget> {

/**
* Important: {@link RecoveryTarget#tryIncRef()} should
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ private void doReplication(final long replicationId) {
logger.trace("not running replication with id [{}] - can not find it (probably finished)", replicationId);
return;
}
final SegmentReplicationTarget replicationTarget = replicationRef.get(SegmentReplicationTarget.class);
final SegmentReplicationTarget replicationTarget = replicationRef.get();
timer = replicationTarget.state().getTimer();
final IndexShard indexShard = replicationTarget.indexShard();

Expand Down Expand Up @@ -218,7 +218,7 @@ public void onFailure(Exception e) {
() -> new ParameterizedMessage("unexpected error during replication [{}], failing shard", replicationId),
e
);
SegmentReplicationTarget replicationTarget = replicationRef.get(SegmentReplicationTarget.class);
SegmentReplicationTarget replicationTarget = replicationRef.get();
onGoingReplications.failReplication(
replicationId,
new ReplicationFailedException(replicationTarget.indexShard(), "unexpected error", e),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ public void messageReceived(final ReplicationFileChunkRequest request, Transport
ReplicationCollection.ReplicationRef replicationRef = segmentReplicationReplicaService.getOnGoingReplications()
.getReplicationSafe(request.getReplicationId(), request.shardId())
) {
final SegmentReplicationTarget replicationTarget = replicationRef.get(SegmentReplicationTarget.class);
final SegmentReplicationTarget replicationTarget = replicationRef.get();
final ActionListener<Void> listener = createOrFinishListener(replicationRef, channel, Actions.FILE_CHUNK, request);
if (listener == null) {
return;
Expand Down Expand Up @@ -274,7 +274,7 @@ private ActionListener<Void> createOrFinishListener(
final ReplicationFileChunkRequest request,
final CheckedFunction<Void, TransportResponse, Exception> responseFn
) {
final SegmentReplicationTarget replicationTarget = replicationRef.get(SegmentReplicationTarget.class);
final SegmentReplicationTarget replicationTarget = replicationRef.get();
final ActionListener<TransportResponse> channelListener = new ChannelActionListener<>(channel, action, request);
final ActionListener<Void> voidListener = ActionListener.map(channelListener, responseFn);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ public ReplicationRef getReplicationSafe(long id, ShardId shardId) {
if (replicationRef == null) {
throw new IndexShardClosedException(shardId);
}
SegmentReplicationTarget replicationTarget = replicationRef.get(SegmentReplicationTarget.class);
SegmentReplicationTarget replicationTarget = replicationRef.get();
assert replicationTarget.indexShard().shardId().equals(shardId);
return replicationRef;
}
Expand Down Expand Up @@ -208,7 +208,7 @@ public boolean cancelRecoveriesForShard(ShardId shardId, String reason) {
* causes {@link SegmentReplicationTarget#decRef()} to be called. This makes sure that the underlying resources
* will not be freed until {@link ReplicationRef#close()} is called.
*/
public static class ReplicationRef extends AutoCloseableRefCounted {
public static class ReplicationRef extends AutoCloseableRefCounted<SegmentReplicationTarget> {

/**
* Important: {@link SegmentReplicationTarget#tryIncRef()} should
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,16 @@
public class AutoCloseableRefCountedTests extends OpenSearchTestCase {

private RefCounted mockRefCounted;
private AutoCloseableRefCounted testObject;
private AutoCloseableRefCounted<RefCounted> testObject;

@Before
public void setup() {
mockRefCounted = mock(RefCounted.class);
testObject = new AutoCloseableRefCounted(mockRefCounted);
testObject = new AutoCloseableRefCounted<>(mockRefCounted);
}

public void testGet() {
assertEquals(mockRefCounted, testObject.get(RefCounted.class));
assertEquals(mockRefCounted, testObject.get());
}

public void testClose() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,10 @@ public void testLastAccessTimeUpdate() throws Exception {
final RecoveriesCollection collection = new RecoveriesCollection(logger, threadPool);
final long recoveryId = startRecovery(collection, shards.getPrimaryNode(), shards.addReplica());
try (RecoveriesCollection.RecoveryRef status = collection.getRecovery(recoveryId)) {
final long lastSeenTime = status.get(RecoveryTarget.class).lastAccessTime();
final long lastSeenTime = status.get().lastAccessTime();
assertBusy(() -> {
try (RecoveriesCollection.RecoveryRef currentStatus = collection.getRecovery(recoveryId)) {
assertThat(
"access time failed to update",
lastSeenTime,
lessThan(currentStatus.get(RecoveryTarget.class).lastAccessTime())
);
assertThat("access time failed to update", lastSeenTime, lessThan(currentStatus.get().lastAccessTime()));
}
});
} finally {
Expand Down Expand Up @@ -124,7 +120,7 @@ public void testRecoveryCancellation() throws Exception {
final long recoveryId = startRecovery(collection, shards.getPrimaryNode(), shards.addReplica());
final long recoveryId2 = startRecovery(collection, shards.getPrimaryNode(), shards.addReplica());
try (RecoveriesCollection.RecoveryRef recoveryRef = collection.getRecovery(recoveryId)) {
ShardId shardId = recoveryRef.get(RecoveryTarget.class).indexShard().shardId();
ShardId shardId = recoveryRef.get().indexShard().shardId();
assertTrue("failed to cancel recoveries", collection.cancelRecoveriesForShard(shardId, "test"));
assertThat("all recoveries should be cancelled", collection.size(), equalTo(0));
} finally {
Expand Down Expand Up @@ -164,7 +160,7 @@ public void testResetRecovery() throws Exception {
assertEquals(currentAsTarget, shard.recoveryStats().currentAsTarget());
try (RecoveriesCollection.RecoveryRef newRecoveryRef = collection.getRecovery(resetRecoveryId)) {
shards.recoverReplica(shard, (s, n) -> {
RecoveryTarget newRecoveryTarget = newRecoveryRef.get(RecoveryTarget.class);
RecoveryTarget newRecoveryTarget = newRecoveryRef.get();
assertSame(s, newRecoveryTarget.indexShard());
return newRecoveryTarget;
}, false);
Expand Down

0 comments on commit 9988089

Please sign in to comment.