Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a race condition between ShardConsumer shutdown and initialization #1319

Merged
merged 6 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,10 @@ public void executeLifecycle() {
// Task rejection during the subscribe() call will not be propagated back as it not executed
// in the context of the Scheduler thread. Hence we should not assume the subscription will
// always be successful.
// But if subscription was not successful, then it will recover
// during healthCheck which will restart subscription.
// From Shardconsumer point of view, initialization after the below subscribe call
// is complete
subscribe();
needsInitialization = false;
}
Expand Down Expand Up @@ -276,6 +280,16 @@ void subscribe() {

@VisibleForTesting
synchronized CompletableFuture<Boolean> initializeComplete() {
if (!needsInitialization) {
// initialization already complete, this must be a no-op.
// ShardConsumer must be in ProcessingState and
// any further activity will be driven by publisher pushing data to subscriber
// which invokes handleInput and that triggers ProcessTask.
// Scheduler is only meant to do health-checks to ensure the consumer
// is not stuck for any reason and to do shutdown handling.
return CompletableFuture.completedFuture(true);
}

if (taskOutcome != null) {
updateState(taskOutcome);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
Expand All @@ -45,6 +47,7 @@
import java.util.Optional;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
Expand All @@ -53,6 +56,7 @@
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;

import org.junit.After;
Expand All @@ -62,7 +66,9 @@
import org.junit.Test;
import org.junit.rules.TestName;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.runners.MockitoJUnitRunner;
import org.reactivestreams.Subscriber;
Expand Down Expand Up @@ -148,6 +154,7 @@ public class ShardConsumerTest {

@Before
public void before() {
MockitoAnnotations.initMocks(this);
shardInfo = new ShardInfo(shardId, concurrencyToken, null, ExtendedSequenceNumber.TRIM_HORIZON);
ThreadFactory factory = new ThreadFactoryBuilder().setNameFormat("test-" + testName.getMethodName() + "-%04d")
.setDaemon(true).build();
Expand Down Expand Up @@ -848,6 +855,114 @@ public void testLongRunningTasks() throws Exception {
verifyNoMoreInteractions(taskExecutionListener);
}

@Test
public void testEmptyShardProcessingRaceCondition() throws Exception {
final RecordsPublisher mockPublisher = mock(RecordsPublisher.class);
final ExecutorService mockExecutor = mock(ExecutorService.class);
final ConsumerState mockState = mock(ConsumerState.class);
final ShardConsumer consumer = new ShardConsumer(mockPublisher, mockExecutor, shardInfo, Optional.of(1L),
shardConsumerArgument, mockState, Function.identity(), 1, taskExecutionListener, 0);

when(mockState.state()).thenReturn(ShardConsumerState.WAITING_ON_PARENT_SHARDS);
when(mockState.taskType()).thenReturn(TaskType.BLOCK_ON_PARENT_SHARDS);
final ConsumerTask mockTask = mock(ConsumerTask.class);
when(mockState.createTask(any(), any(), any())).thenReturn(mockTask);
// Simulate successful BlockedOnParent task execution
// and successful Initialize task execution
when(mockTask.call()).thenReturn(new TaskResult(false));

log.info("Scheduler Thread: Invoking ShardConsumer.executeLifecycle() to initiate async" +
" processing of blocked on parent task");
consumer.executeLifecycle();
final ArgumentCaptor<Runnable> taskToExecute = ArgumentCaptor.forClass(Runnable.class);
verify(mockExecutor, timeout(100)).execute(taskToExecute.capture());
taskToExecute.getValue().run();
log.info("RecordProcessor Thread: Simulated successful execution of Blocked on parent task");
reset(mockExecutor);

log.info("Scheduler Thread: Invoking ShardConsumer.executeLifecycle() to move to InitializingState" +
" and initiate async processing of initialize task");
Comment on lines +883 to +884
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need logs in unit tests?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes its a complex unit test, the logs help for someone to follow along.

when(mockState.successTransition()).thenReturn(mockState);
when(mockState.state()).thenReturn(ShardConsumerState.INITIALIZING);
when(mockState.taskType()).thenReturn(TaskType.INITIALIZE);
consumer.executeLifecycle();
verify(mockExecutor, timeout(100)).execute(taskToExecute.capture());
log.info("RecordProcessor Thread: Simulated successful execution of Initialize task");
taskToExecute.getValue().run();

log.info("Scheduler Thread: Invoking ShardConsumer.executeLifecycle() to move to ProcessingState" +
" and mark initialization future as complete");
when(mockState.state()).thenReturn(ShardConsumerState.PROCESSING);
consumer.executeLifecycle();

// Simulate the race where
// scheduler invokes executeLifecycle which performs Publisher.subscribe(subscriber)
// on recordProcessor thread
// but before scheduler thread finishes initialization, handleInput is invoked
// on record processor thread.

// Since ShardConsumer creates its own instance of subscriber that cannot be mocked
// this test sequence will appear a little odd.
// In order to control the order in which execution occurs, lets first invoke
// handleInput, although this will never happen, since there isn't a way
// to control the precise timing of the thread execution, this is the best way
final CountDownLatch processTaskLatch = new CountDownLatch(1);
new Thread(() -> {
reset(mockState);
when(mockState.taskType()).thenReturn(TaskType.PROCESS);
final ConsumerTask mockProcessTask = mock(ConsumerTask.class);
when(mockState.createTask(any(), any(), any())).thenReturn(mockProcessTask);
when(mockProcessTask.call()).then(input -> {
// first we want to wait for subscribe to be called,
// but we cannot control the timing, so wait for 10 seconds
// to let the main thread invoke executeLifecyle which
// will perform subscribe
processTaskLatch.countDown();
log.info("Record Processor Thread: Holding shardConsumer lock, waiting for 10 seconds to" +
" let subscribe be called by scheduler thread");
Thread.sleep(10 * 1000);
log.info("RecordProcessor Thread: Done waiting");
// then return shard end result
log.info("RecordProcessor Thread: Simulating execution of ProcessTask and returning shard-end result");
return new TaskResult(true);
});
final Subscription mockSubscription = mock(Subscription.class);
consumer.handleInput(ProcessRecordsInput.builder().isAtShardEnd(true).build(), mockSubscription);
}).start();

processTaskLatch.await();

// invoke executeLifecycle, which should invoke subscribe
// meanwhile if scheduler tries to acquire the ShardConsumer lock it will
// be blocked during initialization processing because handleInput was
// already invoked and will be holding the lock. Thereby creating the
// race condition we want.
reset(mockState);
AtomicBoolean successTransitionCalled = new AtomicBoolean(false);
when(mockState.successTransition()).then(input -> {
successTransitionCalled.set(true);
return mockState;
});
AtomicBoolean shutdownTransitionCalled = new AtomicBoolean(false);
when(mockState.shutdownTransition(any())).then(input -> {
shutdownTransitionCalled.set(true);
return mockState;
});
when(mockState.state()).then(input -> {
if (successTransitionCalled.get() && shutdownTransitionCalled.get()) {
return ShardConsumerState.SHUTTING_DOWN;
}
return ShardConsumerState.PROCESSING;
});
log.info("Scheduler Thread: Invoking ShardConsumer.executeLifecycle() to invoke subscribe and" +
" complete initialization");
consumer.executeLifecycle();
log.info("Scheduler Thread: Done initializing the ShardConsumer");

log.info("Verifying scheduler did not perform shutdown transition during initialization");
verify(mockState, times(0)).shutdownTransition(any());
}

private void mockSuccessfulShutdown(CyclicBarrier taskCallBarrier) {
mockSuccessfulShutdown(taskCallBarrier, null);
}
Expand Down
Loading