Skip to content

Commit

Permalink
feature: implement BatchNlpTask
Browse files Browse the repository at this point in the history
  • Loading branch information
ClemDoum committed Nov 5, 2024
1 parent 21d7a09 commit 5520b8c
Show file tree
Hide file tree
Showing 2 changed files with 172 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package org.icij.datashare.tasks;

import static java.util.Optional.ofNullable;
import static org.icij.datashare.tasks.GroupHelper.JAVA_GROUP;

import com.google.inject.Inject;
import com.google.inject.assistedinject.Assisted;
import java.util.List;
import java.util.function.Function;
import org.icij.datashare.asynctasks.CancellableTask;
import org.icij.datashare.asynctasks.Task;
import org.icij.datashare.asynctasks.TaskGroup;
import org.icij.datashare.extension.PipelineRegistry;
import org.icij.datashare.text.Document;
import org.icij.datashare.text.Language;
import org.icij.datashare.text.NamedEntity;
import org.icij.datashare.text.indexing.Indexer;
import org.icij.datashare.text.nlp.Pipeline;
import org.icij.datashare.user.User;
import org.icij.datashare.user.UserTask;
import org.icij.task.DefaultTask;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@TaskGroup(JAVA_GROUP)
public class BatchNlpTask extends DefaultTask<Long> implements UserTask, CancellableTask {
// TODO: fix the raw used of parametrized type...
private static final List<String> EXCLUDED_SOURCES = List.of("contentTranslated");
private final Logger logger = LoggerFactory.getLogger(getClass());
private final User user;
private volatile Thread taskThread;
private final Indexer indexer;
private final List<BatchEnqueueFromIndexTask.BatchDocument> docs;
private final Pipeline pipeline;
private final int maxLength;

@Inject
public BatchNlpTask(Indexer indexer, PipelineRegistry registry, @Assisted Task<Long> taskView,
@Assisted final Function<Double, Void> updateCallback) {
this(indexer, registry.get(Pipeline.Type.parse((String) taskView.args.get("pipeline"))), taskView, updateCallback);
}


BatchNlpTask(Indexer indexer, Pipeline pipeline, @Assisted Task<Long> taskView,
@Assisted final Function<Double, Void> ignored) {
this.user = taskView.getUser();
this.indexer = indexer;
this.pipeline = pipeline;
this.docs = (List<BatchEnqueueFromIndexTask.BatchDocument>) taskView.args.get("docs");
this.maxLength = (int) taskView.args.get("maxLength");
}

@Override
public Long call() throws Exception {
taskThread = Thread.currentThread();
if (this.docs.isEmpty()) {
return 0L;
}
Language language = this.docs.get(0).language();
pipeline.initialize(language);
logger.info("performing NER on {} docs in {}...", this.docs.size(), language);
// TODO: for now None of the Java NER seems to support batch processing, we just iterate docs one by one
// TODO: we could improve perfs by fetching docs and processing them concurrently...
for (BatchEnqueueFromIndexTask.BatchDocument doc : this.docs) {
String project = doc.project();
Document indexDoc = indexer.get(doc.id(), doc.rootDocument(), EXCLUDED_SOURCES);
if (indexDoc.getContentTextLength() < this.maxLength) {
List<NamedEntity> namedEntities = pipeline.process(indexDoc);
indexer.bulkAdd(project, pipeline.getType(), namedEntities, indexDoc);
} else {
int nbChunks = indexDoc.getContentTextLength() / this.maxLength + 1;
for (int chunkIndex = 0; chunkIndex < nbChunks; chunkIndex++) {
List<NamedEntity> namedEntities =
pipeline.process(indexDoc, maxLength, chunkIndex * maxLength);
if (chunkIndex < nbChunks - 1) {
indexer.bulkAdd(project, namedEntities);
} else {
indexer.bulkAdd(project, pipeline.getType(), namedEntities, indexDoc);
}
}
}
}
pipeline.terminate(language);
return (long) this.docs.size();
}

@Override
public void cancel(boolean requeue) {
ofNullable(taskThread).ifPresent(Thread::interrupt);
}

@Override
public User getUser() {
return user;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package org.icij.datashare.tasks;

import static org.icij.datashare.test.ElasticsearchRule.TEST_INDEX;
import static org.icij.datashare.text.DocumentBuilder.createDoc;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.mockito.MockitoAnnotations.openMocks;

import java.util.List;
import java.util.Map;
import org.icij.datashare.asynctasks.Task;
import org.icij.datashare.text.Document;
import org.icij.datashare.text.Language;
import org.icij.datashare.text.indexing.Indexer;
import org.icij.datashare.text.nlp.AbstractPipeline;
import org.icij.datashare.text.nlp.Pipeline;
import org.icij.datashare.user.User;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;

public class BatchNlpTaskTest {
@Mock
private Indexer indexer;
@Mock
private AbstractPipeline pipeline;
private AutoCloseable mocks;

@Before
public void setUp() {
this.mocks = openMocks(this);
}

@Before
public void tearDow() throws Exception {
this.mocks.close();
}

@Test(timeout = 2000)
public void test_batch_nlp() throws Exception {
// Given
int maxLength = 20;
String rootId = "rootId";
Language language = Language.ENGLISH;
Document doc0 = createDoc("doc0").with(language).withRootId(rootId)
.with("hello world").build();
Document doc1 = createDoc("doc1").with(language).withRootId(rootId)
.with("this is too long to be processed all at once").build();
when(pipeline.getType()).thenReturn(Pipeline.Type.CORENLP);
when(pipeline.initialize(any())).thenReturn(true);

when(indexer.get(anyString(), anyString(), any(List.class))).thenReturn(doc0, doc1);
List<BatchEnqueueFromIndexTask.BatchDocument> batchDocs = List.of(
new BatchEnqueueFromIndexTask.BatchDocument(doc0.getId(), doc0.getRootDocument(), TEST_INDEX, language),
new BatchEnqueueFromIndexTask.BatchDocument(doc1.getId(), doc1.getRootDocument(), TEST_INDEX, language)
);
Map<String, Object> properties = Map.of(
"docs", batchDocs,
"pipeline", "OPENNLP",
"maxLength", maxLength,
"group", "JAVA"
);
BatchNlpTask nlpTask = new BatchNlpTask(
indexer, pipeline, new Task<>(BatchNlpTask.class.getName(), new User("test"), properties), null
);
// When
nlpTask.call();
// Then
verify(pipeline).process(eq(doc0));
verify(pipeline).process(eq(doc1), eq(maxLength), eq(0));
verify(pipeline).process(eq(doc1), eq(maxLength), eq(maxLength));
}
}

0 comments on commit 5520b8c

Please sign in to comment.