Skip to content

Commit

Permalink
Improve dedup logic using separate map (opensearch-project#7305)
Browse files Browse the repository at this point in the history
* Improve dedup logic using separate map
* Removed use of LinkedHashSet and used IdentityHashMap for dup logic

Signed-off-by: Dhwanil Patel <[email protected]>
Signed-off-by: Shivansh Arora <[email protected]>
  • Loading branch information
dhwanilpatel authored and shiv0408 committed Apr 25, 2024
1 parent 845e549 commit daee2c3
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ public abstract class TaskBatcher {
private final PrioritizedOpenSearchThreadPoolExecutor threadExecutor;
// package visible for tests
final Map<Object, LinkedHashSet<BatchedTask>> tasksPerBatchingKey = new ConcurrentHashMap<>();
final Map<Object, Map<Object, BatchedTask>> taskIdentityPerBatchingKey = new ConcurrentHashMap<>();
private final TaskBatcherListener taskBatcherListener;

public TaskBatcher(Logger logger, PrioritizedOpenSearchThreadPoolExecutor threadExecutor, TaskBatcherListener taskBatcherListener) {
Expand Down Expand Up @@ -90,20 +91,30 @@ public void submitTasks(List<? extends BatchedTask> tasks, @Nullable TimeValue t
throw new IllegalStateException("cannot add duplicate task: " + a);
}, IdentityHashMap::new));
LinkedHashSet<BatchedTask> newTasks = new LinkedHashSet<>(tasks);
tasksPerBatchingKey.merge(firstTask.batchingKey, newTasks, (existingTasks, updatedTasks) -> {
for (BatchedTask existing : existingTasks) {
// Need to maintain below order in which task identity map and task map are updated.
// For insert: First insert identity in taskIdentity map with dup check and then insert task in taskMap.
// For remove: First remove task from taskMap and then remove identity from taskIdentity map.
// We are inserting identity first and removing at last to ensure no duplicate tasks are enqueued.
// Changing this order might lead to duplicate tasks in queue.
taskIdentityPerBatchingKey.merge(firstTask.batchingKey, tasksIdentity, (existingIdentities, newIdentities) -> {
for (Object newIdentity : newIdentities.keySet()) {
// check that there won't be two tasks with the same identity for the same batching key
BatchedTask duplicateTask = tasksIdentity.get(existing.getTask());
if (duplicateTask != null) {
if (existingIdentities.containsKey(newIdentity)) {
BatchedTask duplicateTask = newIdentities.get(newIdentity);
throw new IllegalStateException(
"task ["
+ duplicateTask.describeTasks(Collections.singletonList(existing))
+ duplicateTask.describeTasks(Collections.singletonList(duplicateTask))
+ "] with source ["
+ duplicateTask.source
+ "] is already queued"
);
}
}
existingIdentities.putAll(newIdentities);
return existingIdentities;
});
// since we have checked for dup tasks in above map, we can add all new task in map.
tasksPerBatchingKey.merge(firstTask.batchingKey, newTasks, (existingTasks, updatedTasks) -> {
existingTasks.addAll(updatedTasks);
return existingTasks;
});
Expand All @@ -119,26 +130,37 @@ public void submitTasks(List<? extends BatchedTask> tasks, @Nullable TimeValue t
}
}

private void onTimeoutInternal(List<? extends BatchedTask> tasks, TimeValue timeout) {
void onTimeoutInternal(List<? extends BatchedTask> tasks, TimeValue timeout) {
final ArrayList<BatchedTask> toRemove = new ArrayList<>();
final ArrayList<Object> toRemoveIdentities = new ArrayList<>();
for (BatchedTask task : tasks) {
if (task.processed.getAndSet(true) == false) {
logger.debug("task [{}] timed out after [{}]", task.source, timeout);
toRemove.add(task);
toRemoveIdentities.add(task.getTask());
}
}
if (toRemove.isEmpty() == false) {
BatchedTask firstTask = toRemove.get(0);
Object batchingKey = firstTask.batchingKey;
assert tasks.stream().allMatch(t -> t.batchingKey == batchingKey)
: "tasks submitted in a batch should share the same batching key: " + tasks;
// While removing task, need to remove task first from taskMap and then remove identity from identityMap.
// Changing this order might lead to duplicate task during submission.
tasksPerBatchingKey.computeIfPresent(batchingKey, (tasksKey, existingTasks) -> {
existingTasks.removeAll(toRemove);
if (existingTasks.isEmpty()) {
return null;
}
return existingTasks;
});
taskIdentityPerBatchingKey.computeIfPresent(batchingKey, (tasksKey, existingIdentities) -> {
toRemoveIdentities.stream().forEach(existingIdentities::remove);
if (existingIdentities.isEmpty()) {
return null;
}
return existingIdentities;
});
taskBatcherListener.onTimeout(toRemove);
onTimeout(toRemove, timeout);
}
Expand All @@ -156,7 +178,10 @@ void runIfNotProcessed(BatchedTask updateTask) {
if (updateTask.processed.get() == false) {
final List<BatchedTask> toExecute = new ArrayList<>();
final Map<String, List<BatchedTask>> processTasksBySource = new HashMap<>();
// While removing task, need to remove task first from taskMap and then remove identity from identityMap.
// Changing this order might lead to duplicate task during submission.
LinkedHashSet<BatchedTask> pending = tasksPerBatchingKey.remove(updateTask.batchingKey);
taskIdentityPerBatchingKey.remove(updateTask.batchingKey);
if (pending != null) {
for (BatchedTask task : pending) {
if (task.processed.getAndSet(true) == false) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ public void testDuplicateSubmission() throws InterruptedException {
submitTask("blocking", blockingTask);

TestExecutor<SimpleTask> executor = tasks -> {};
SimpleTask task = new SimpleTask(1);
SimpleTask task1 = new SimpleTask(1);
TestListener listener = new TestListener() {
@Override
public void processed(String source) {
Expand All @@ -410,21 +410,110 @@ public void onFailure(String source, Exception e) {
}
};

submitTask("first time", task, ClusterStateTaskConfig.build(Priority.NORMAL), executor, listener);
submitTask("first time", task1, ClusterStateTaskConfig.build(Priority.NORMAL), executor, listener);

// submitting same task1 again, it should throw exception.
final IllegalStateException e = expectThrows(
IllegalStateException.class,
() -> submitTask("second time", task, ClusterStateTaskConfig.build(Priority.NORMAL), executor, listener)
() -> submitTask("second time", task1, ClusterStateTaskConfig.build(Priority.NORMAL), executor, listener)
);
assertThat(e, hasToString(containsString("task [1] with source [second time] is already queued")));

submitTask("third time a charm", new SimpleTask(1), ClusterStateTaskConfig.build(Priority.NORMAL), executor, listener);
// inserting new task with same data, this should pass as it is new object and reference is different.
SimpleTask task2 = new SimpleTask(1);
// equals method returns true for both task
assertTrue(task1.equals(task2));
// references of both tasks are different.
assertFalse(task1 == task2);
// submitting this task should be allowed, as it is new object.
submitTask("third time a charm", task2, ClusterStateTaskConfig.build(Priority.NORMAL), executor, listener);

// submitting same task2 again, it should throw exception, since it was submitted last time
final IllegalStateException e2 = expectThrows(
IllegalStateException.class,
() -> submitTask("second time", task2, ClusterStateTaskConfig.build(Priority.NORMAL), executor, listener)
);
assertThat(e2, hasToString(containsString("task [1] with source [second time] is already queued")));

assertThat(latch.getCount(), equalTo(2L));
}
latch.await();
}

public void testDuplicateSubmissionAfterTimeout() throws InterruptedException {
    // Verifies that a task identity may be re-submitted once its earlier
    // submission has been removed from the batcher maps by a timeout.
    final CountDownLatch latch = new CountDownLatch(2);
    final CountDownLatch timeOutLatch = new CountDownLatch(1);
    try (BlockingTask blockingTask = new BlockingTask(Priority.IMMEDIATE)) {
        // Block the executor so the submitted task stays queued until we time it out.
        submitTask("blocking", blockingTask);

        TestExecutor<SimpleTask> executor = tasks -> {};
        SimpleTask task1 = new SimpleTask(1);
        TestListener listener = new TestListener() {
            @Override
            public void processed(String source) {
                latch.countDown();
            }

            @Override
            public void onFailure(String source, Exception e) {
                if (e instanceof ProcessClusterEventTimeoutException) {
                    timeOutLatch.countDown();
                } else {
                    throw new AssertionError(e);
                }
            }
        };

        submitTask("first time", task1, ClusterStateTaskConfig.build(Priority.NORMAL), executor, listener);
        // Fix: use the diamond operator instead of a raw ArrayList (raw type
        // produced an unchecked-assignment warning).
        ArrayList<TaskBatcher.BatchedTask> tasks = new ArrayList<>();
        tasks.add(
            taskBatcher.new UpdateTask(
                ClusterStateTaskConfig.build(Priority.NORMAL).priority(), "first time", task1, listener, executor
            )
        );

        // Time out task1 directly; it must be removed from both the task map
        // and the task-identity map.
        taskBatcher.onTimeoutInternal(tasks, TimeValue.ZERO);
        timeOutLatch.await(); // wait until the timeout has been delivered to the listener
        // Submitting the same task1 instance again must now succeed, since the
        // previous submission timed out and its identity was cleaned up.
        submitTask("first time", task1, ClusterStateTaskConfig.build(Priority.NORMAL), executor, listener);
        assertThat(latch.getCount(), equalTo(2L));
    }
    latch.await();
}

public void testDuplicateSubmissionAfterExecution() throws InterruptedException {
    // Once a task has been executed it is removed from the batcher maps, so
    // the very same task instance can be submitted again afterwards.
    final CountDownLatch firstTaskDone = new CountDownLatch(1);
    final CountDownLatch allTasksDone = new CountDownLatch(2);

    TestExecutor<SimpleTask> noopExecutor = tasks -> {};
    SimpleTask task = new SimpleTask(1);
    TestListener resultListener = new TestListener() {
        @Override
        public void processed(String source) {
            firstTaskDone.countDown();
            allTasksDone.countDown();
        }

        @Override
        public void onFailure(String source, Exception e) {
            if (e instanceof ProcessClusterEventTimeoutException) {
                allTasksDone.countDown();
            } else {
                throw new AssertionError(e);
            }
        }
    };
    submitTask("first time", task, ClusterStateTaskConfig.build(Priority.NORMAL), noopExecutor, resultListener);

    firstTaskDone.await(); // block until the first submission has been executed

    // Re-submitting the identical instance must be accepted, since execution
    // removed it from the dedup maps.
    submitTask("first time", task, ClusterStateTaskConfig.build(Priority.NORMAL), noopExecutor, resultListener);

    allTasksDone.await(); // wait until both submissions have finished
}

protected static TaskBatcherListener getMockListener() {
return new TaskBatcherListener() {
@Override
Expand Down Expand Up @@ -458,12 +547,16 @@ private SimpleTask(int id) {

@Override
public int hashCode() {
    // Value-based hash so that tasks with the same id compare/hash alike;
    // identity-based dedup is handled separately via taskIdentityPerBatchingKey.
    // (Removes the stale diff line that still returned super.hashCode().)
    return this.id;
}

@Override
public boolean equals(Object obj) {
    // Guard against null and foreign types before casting: the unguarded
    // ((SimpleTask) obj) threw ClassCastException/NPE and violated the
    // Object.equals contract, which requires returning false in those cases.
    if (this == obj) {
        return true;
    }
    if ((obj instanceof SimpleTask) == false) {
        return false;
    }
    return ((SimpleTask) obj).getId() == this.id;
}

/** Returns the numeric id used for value equality between tasks. */
public int getId() {
    return this.id;
}

@Override
Expand Down

0 comments on commit daee2c3

Please sign in to comment.