diff --git a/datashare-app/src/main/java/org/icij/datashare/tasks/BatchNlpTask.java b/datashare-app/src/main/java/org/icij/datashare/tasks/BatchNlpTask.java
new file mode 100644
index 000000000..1564fdfd2
--- /dev/null
+++ b/datashare-app/src/main/java/org/icij/datashare/tasks/BatchNlpTask.java
@@ -0,0 +1,112 @@
+package org.icij.datashare.tasks;
+
+import static java.util.Optional.ofNullable;
+import static org.icij.datashare.tasks.GroupHelper.JAVA_GROUP;
+
+import com.google.inject.Inject;
+import com.google.inject.assistedinject.Assisted;
+import java.util.List;
+import java.util.function.Function;
+import org.icij.datashare.asynctasks.CancellableTask;
+import org.icij.datashare.asynctasks.Task;
+import org.icij.datashare.asynctasks.TaskGroup;
+import org.icij.datashare.extension.PipelineRegistry;
+import org.icij.datashare.text.Document;
+import org.icij.datashare.text.Language;
+import org.icij.datashare.text.NamedEntity;
+import org.icij.datashare.text.indexing.Indexer;
+import org.icij.datashare.text.nlp.Pipeline;
+import org.icij.datashare.user.User;
+import org.icij.datashare.user.UserTask;
+import org.icij.task.DefaultTask;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Task running a Java NLP {@link Pipeline} over a batch of index documents.
+ * <p>The pipeline is initialized with the language of the first document of the batch,
+ * so the batch is presumably monolingual — NOTE(review): confirm with the enqueue task.
+ */
+@TaskGroup(JAVA_GROUP)
+public class BatchNlpTask extends DefaultTask implements UserTask, CancellableTask {
+    // TODO: Task and Function are still used as raw types below; parametrize them
+    //  once their type arguments are settled
+    private static final List<String> EXCLUDED_SOURCES = List.of("contentTranslated");
+    private final Logger logger = LoggerFactory.getLogger(getClass());
+    private final User user;
+    private volatile Thread taskThread;
+    private final Indexer indexer;
+    private final List<BatchEnqueueFromIndexTask.BatchDocument> docs;
+    private final Pipeline pipeline;
+    private final int maxLength;
+
+    @Inject
+    public BatchNlpTask(Indexer indexer, PipelineRegistry registry, @Assisted Task taskView,
+                        @Assisted final Function updateCallback) {
+        this(indexer, registry.get(Pipeline.Type.parse((String) taskView.args.get("pipeline"))), taskView, updateCallback);
+    }
+
+
+    @SuppressWarnings("unchecked")
+    BatchNlpTask(Indexer indexer, Pipeline pipeline, @Assisted Task taskView,
+                 @Assisted final Function ignored) {
+        this.user = taskView.getUser();
+        this.indexer = indexer;
+        this.pipeline = pipeline;
+        this.docs = (List<BatchEnqueueFromIndexTask.BatchDocument>) taskView.args.get("docs");
+        this.maxLength = (int) taskView.args.get("maxLength");
+    }
+
+    /**
+     * Runs NER on each document of the batch and bulk-indexes the resulting named entities.
+     * <p>Documents longer than {@code maxLength} are processed in {@code maxLength}-sized
+     * chunks; only the last chunk is bulk-added together with the source document.
+     *
+     * @return the number of processed documents
+     */
+    @Override
+    public Long call() throws Exception {
+        taskThread = Thread.currentThread();
+        if (this.docs.isEmpty()) {
+            return 0L;
+        }
+        Language language = this.docs.get(0).language();
+        pipeline.initialize(language);
+        logger.info("performing NER on {} docs in {}...", this.docs.size(), language);
+        // TODO: for now None of the Java NER seems to support batch processing, we just iterate docs one by one
+        // TODO: we could improve perfs by fetching docs and processing them concurrently...
+        for (BatchEnqueueFromIndexTask.BatchDocument doc : this.docs) {
+            String project = doc.project();
+            Document indexDoc = indexer.get(doc.id(), doc.rootDocument(), EXCLUDED_SOURCES);
+            if (indexDoc.getContentTextLength() < this.maxLength) {
+                List<NamedEntity> namedEntities = pipeline.process(indexDoc);
+                indexer.bulkAdd(project, pipeline.getType(), namedEntities, indexDoc);
+            } else {
+                // NOTE(review): when the length is an exact multiple of maxLength the last chunk is empty — confirm this is harmless
+                int nbChunks = indexDoc.getContentTextLength() / this.maxLength + 1;
+                for (int chunkIndex = 0; chunkIndex < nbChunks; chunkIndex++) {
+                    List<NamedEntity> namedEntities =
+                        pipeline.process(indexDoc, maxLength, chunkIndex * maxLength);
+                    if (chunkIndex < nbChunks - 1) {
+                        indexer.bulkAdd(project, namedEntities);
+                    } else {
+                        indexer.bulkAdd(project, pipeline.getType(), namedEntities, indexDoc);
+                    }
+                }
+            }
+        }
+        pipeline.terminate(language);
+        return (long) this.docs.size();
+    }
+
+    /** Interrupts the thread running {@link #call()}; the {@code requeue} flag is ignored here. */
+    @Override
+    public void cancel(boolean requeue) {
+        ofNullable(taskThread).ifPresent(Thread::interrupt);
+    }
+
+    @Override
+    public User getUser() {
+        return user;
+    }
+}
diff --git a/datashare-app/src/test/java/org/icij/datashare/tasks/BatchNlpTaskTest.java b/datashare-app/src/test/java/org/icij/datashare/tasks/BatchNlpTaskTest.java
new file mode 100644
index 000000000..558a7d166
--- /dev/null
+++ b/datashare-app/src/test/java/org/icij/datashare/tasks/BatchNlpTaskTest.java
@@ -0,0 +1,76 @@
+package org.icij.datashare.tasks;
+
+import static org.icij.datashare.test.ElasticsearchRule.TEST_INDEX;
+import static org.icij.datashare.text.DocumentBuilder.createDoc;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+import static org.mockito.MockitoAnnotations.openMocks;
+
+import java.util.List;
+import java.util.Map;
+import org.icij.datashare.asynctasks.Task;
+import org.icij.datashare.text.Document;
+import org.icij.datashare.text.Language;
+import 
org.icij.datashare.text.indexing.Indexer;
+import org.icij.datashare.text.nlp.AbstractPipeline;
+import org.icij.datashare.text.nlp.Pipeline;
+import org.icij.datashare.user.User;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mock;
+
+public class BatchNlpTaskTest {
+    @Mock
+    private Indexer indexer;
+    @Mock
+    private AbstractPipeline pipeline;
+    private AutoCloseable mocks;
+
+    @Before
+    public void setUp() {
+        this.mocks = openMocks(this);
+    }
+
+    @org.junit.After // was @Before: closing mocks before each test NPEs on the first run and defeats setUp(); must run after
+    public void tearDown() throws Exception {
+        this.mocks.close();
+    }
+
+    @Test(timeout = 2000)
+    public void test_batch_nlp() throws Exception {
+        // Given
+        int maxLength = 20;
+        String rootId = "rootId";
+        Language language = Language.ENGLISH;
+        Document doc0 = createDoc("doc0").with(language).withRootId(rootId)
+            .with("hello world").build();
+        Document doc1 = createDoc("doc1").with(language).withRootId(rootId)
+            .with("this is too long to be processed all at once").build();
+        when(pipeline.getType()).thenReturn(Pipeline.Type.CORENLP);
+        when(pipeline.initialize(any())).thenReturn(true);
+
+        when(indexer.get(anyString(), anyString(), any(List.class))).thenReturn(doc0, doc1);
+        List<BatchEnqueueFromIndexTask.BatchDocument> batchDocs = List.of(
+            new BatchEnqueueFromIndexTask.BatchDocument(doc0.getId(), doc0.getRootDocument(), TEST_INDEX, language),
+            new BatchEnqueueFromIndexTask.BatchDocument(doc1.getId(), doc1.getRootDocument(), TEST_INDEX, language)
+        );
+        Map<String, Object> properties = Map.of(
+            "docs", batchDocs,
+            "pipeline", "CORENLP", // unused by the ctor under test; aligned with the mocked pipeline type
+            "maxLength", maxLength,
+            "group", "JAVA"
+        );
+        BatchNlpTask nlpTask = new BatchNlpTask(
+            indexer, pipeline, new Task<>(BatchNlpTask.class.getName(), new User("test"), properties), null
+        );
+        // When
+        nlpTask.call();
+        // Then
+        verify(pipeline).process(eq(doc0));
+        verify(pipeline).process(eq(doc1), eq(maxLength), eq(0));
+        verify(pipeline).process(eq(doc1), eq(maxLength), eq(maxLength));
+    }
+}