From f35a3d0d03cbd6ac3a6917ee03149ef06b3e2d15 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cle=CC=81ment=20Doumouro?= Date: Fri, 30 Aug 2024 10:59:09 +0200 Subject: [PATCH 1/5] feature: add batch processing API for `Pipeline` --- .../icij/datashare/text/nlp/AbstractPipeline.java | 6 ------ .../java/org/icij/datashare/text/nlp/Pipeline.java | 13 +++++++++++++ 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/datashare-api/src/main/java/org/icij/datashare/text/nlp/AbstractPipeline.java b/datashare-api/src/main/java/org/icij/datashare/text/nlp/AbstractPipeline.java index ec6a2e1d1..e9d313903 100644 --- a/datashare-api/src/main/java/org/icij/datashare/text/nlp/AbstractPipeline.java +++ b/datashare-api/src/main/java/org/icij/datashare/text/nlp/AbstractPipeline.java @@ -1,7 +1,6 @@ package org.icij.datashare.text.nlp; import org.icij.datashare.PropertiesProvider; -import org.icij.datashare.text.Document; import org.icij.datashare.text.Language; import org.icij.datashare.text.NamedEntity; import org.slf4j.Logger; @@ -65,11 +64,6 @@ public boolean initialize(Language language) throws InterruptedException { return true; } - /** - * Apply all specified stages/annotators on input - * @param doc is the document source to process */ - public abstract List process(Document doc) throws InterruptedException; - /** * Post-processing operations */ diff --git a/datashare-api/src/main/java/org/icij/datashare/text/nlp/Pipeline.java b/datashare-api/src/main/java/org/icij/datashare/text/nlp/Pipeline.java index 80c54b1cf..bbed5c344 100644 --- a/datashare-api/src/main/java/org/icij/datashare/text/nlp/Pipeline.java +++ b/datashare-api/src/main/java/org/icij/datashare/text/nlp/Pipeline.java @@ -1,5 +1,6 @@ package org.icij.datashare.text.nlp; +import java.util.stream.Stream; import org.icij.datashare.reflect.EnumTypeToken; import org.icij.datashare.text.Document; import org.icij.datashare.text.Language; @@ -22,6 +23,9 @@ static Set set(Type ...types) { return new 
HashSet<>(Arrays.asList(types)); } + record PipelineInput (Document doc, int contentLength, int contentOffset) {} + + enum Type implements EnumTypeToken { TEST((short)-1), CORENLP((short)0), @@ -98,6 +102,15 @@ public String getName() { List process(Document doc) throws InterruptedException; List process(Document doc, int contentLength, int contentOffset) throws InterruptedException; + default List> process(Stream inputs) throws InterruptedException { + return inputs.map(i -> { + try { + return process(i.doc, i.contentLength, i.contentOffset); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + }).collect(Collectors.toList()); + } void terminate(Language language) throws InterruptedException; From a7e9a974cebe806681df1392f7896621c6131378 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cle=CC=81ment=20Doumouro?= Date: Tue, 3 Sep 2024 10:21:40 +0200 Subject: [PATCH 2/5] feature: add `Indexer.sort --- .../icij/datashare/text/indexing/Indexer.java | 4 +++- .../elasticsearch/ElasticsearchSearcher.java | 18 ++++++++++++++++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/datashare-api/src/main/java/org/icij/datashare/text/indexing/Indexer.java b/datashare-api/src/main/java/org/icij/datashare/text/indexing/Indexer.java index 049a3e35a..b0d41ac64 100644 --- a/datashare-api/src/main/java/org/icij/datashare/text/indexing/Indexer.java +++ b/datashare-api/src/main/java/org/icij/datashare/text/indexing/Indexer.java @@ -33,7 +33,7 @@ public interface Indexer extends Closeable { boolean bulkUpdate(String indexName, List entities) throws IOException; void add(String indexName, T obj) throws IOException; void update(String indexName, T obj) throws IOException; - boolean exists(String indexName, String id) throws IOException; + boolean exists(String indexName, String id) throws IOException; T get(String indexName, String id); T get(String indexName, String id, List sourceExcludes); @@ -61,9 +61,11 @@ interface Searcher { Searcher 
withoutSource(String... fields); Searcher withSource(boolean source); Searcher limit(int maxCount); + Searcher sort(String field, SortOrder order); void clearScroll() throws IOException; long totalHits(); Searcher with(int fuzziness, boolean phraseMatches); + enum SortOrder { ASC, DESC } } interface QueryBuilderSearcher extends Searcher { diff --git a/datashare-index/src/main/java/org/icij/datashare/text/indexing/elasticsearch/ElasticsearchSearcher.java b/datashare-index/src/main/java/org/icij/datashare/text/indexing/elasticsearch/ElasticsearchSearcher.java index ff2cb795d..0b6560d5f 100644 --- a/datashare-index/src/main/java/org/icij/datashare/text/indexing/elasticsearch/ElasticsearchSearcher.java +++ b/datashare-index/src/main/java/org/icij/datashare/text/indexing/elasticsearch/ElasticsearchSearcher.java @@ -9,7 +9,6 @@ import co.elastic.clients.elasticsearch.core.SearchResponse; import co.elastic.clients.elasticsearch.core.search.Hit; import co.elastic.clients.elasticsearch.core.search.ResponseBody; -import co.elastic.clients.json.JsonpMappingException; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.node.ObjectNode; import jakarta.json.JsonException; @@ -64,7 +63,7 @@ static Stream resultStream(Class cls, Iterable T hitToObject(Hit searchHit, Class cls) { - return (T) JsonObjectMapper.getObject(searchHit.id(), searchHit.index(), JsonUtils.nodeToMap(searchHit.source()), cls); + return JsonObjectMapper.getObject(searchHit.id(), searchHit.index(), JsonUtils.nodeToMap(searchHit.source()), cls); } @Override @@ -177,6 +176,14 @@ public Indexer.Searcher limit(int maxCount) { return this; } + @Override + public Searcher sort(String field, SortOrder order) { + sourceBuilder.sort(builder -> + builder.field(fieldBuilder -> fieldBuilder.field(field).order(esSortOrder(order))) + ); + return this; + } + @Override public void clearScroll() throws IOException { @@ -194,4 +201,11 @@ public long totalHits() { public String toString() { 
return "query : " + jsonBoolQuery; } + + private co.elastic.clients.elasticsearch._types.SortOrder esSortOrder(SortOrder sortOrder) { + return switch (sortOrder) { + case ASC -> co.elastic.clients.elasticsearch._types.SortOrder.Asc; + case DESC -> co.elastic.clients.elasticsearch._types.SortOrder.Desc; + }; + } } \ No newline at end of file From c6e35d12f8aa12f538c00b050b0da862c86bb6c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cle=CC=81ment=20Doumouro?= Date: Tue, 3 Sep 2024 10:22:38 +0200 Subject: [PATCH 3/5] chore: sort documents by language when queuing NLP tasks --- .../java/org/icij/datashare/tasks/EnqueueFromIndexTask.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/datashare-app/src/main/java/org/icij/datashare/tasks/EnqueueFromIndexTask.java b/datashare-app/src/main/java/org/icij/datashare/tasks/EnqueueFromIndexTask.java index dfaca2387..feb2079f2 100644 --- a/datashare-app/src/main/java/org/icij/datashare/tasks/EnqueueFromIndexTask.java +++ b/datashare-app/src/main/java/org/icij/datashare/tasks/EnqueueFromIndexTask.java @@ -54,7 +54,8 @@ public EnqueueFromIndexTask(final DocumentCollectionFactory factory, fin public Long call() throws Exception { super.call(); Indexer.Searcher searcher = indexer.search(singletonList(projectName), Document.class) - .without(nlpPipeline).withSource("rootDocument").limit(scrollSize); + .without(nlpPipeline).withSource("rootDocument").limit(scrollSize) + .sort("language", Indexer.Searcher.SortOrder.ASC); logger.info("resuming NLP name finding for index {} and {} with {} scroll and size of {} : {} documents found", projectName, nlpPipeline, scrollDuration, scrollSize, searcher.totalHits()); List docsToProcess = searcher.scroll(scrollDuration).collect(toList()); From 3fbf02cbc85f0f9db7580148e29a9823faec310b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cle=CC=81ment=20Doumouro?= Date: Thu, 5 Sep 2024 10:42:46 +0200 Subject: [PATCH 4/5] feature: batch process --- .../org/icij/datashare/function/Pair.java | 10 +- 
.../org/icij/datashare/text/nlp/NlpTag.java | 2 +- .../org/icij/datashare/text/nlp/Pipeline.java | 41 +++-- .../extension/PipelineRegistryTest.java | 115 +++++++------- .../datashare/text/nlp/test/TestPipeline.java | 11 +- .../org/icij/datashare/nlp/EmailPipeline.java | 27 ++-- .../icij/datashare/tasks/ExtractNlpTask.java | 145 +++++++++++++++--- .../org/icij/datashare/web/NerResource.java | 4 +- .../icij/datashare/nlp/CoreNlpTestManual.java | 2 +- .../icij/datashare/nlp/EmailPipelineTest.java | 82 +++++----- .../tasks/ExtractNlpTaskIntTest.java | 47 +++++- .../datashare/tasks/ExtractNlpTaskTest.java | 140 ++++++++++++----- .../icij/datashare/web/NerResourceTest.java | 18 ++- .../datashare/cli/DatashareCliOptions.java | 7 +- .../text/nlp/corenlp/CorenlpPipeline.java | 47 ++---- 15 files changed, 452 insertions(+), 246 deletions(-) diff --git a/datashare-api/src/main/java/org/icij/datashare/function/Pair.java b/datashare-api/src/main/java/org/icij/datashare/function/Pair.java index 562facf1b..44926f42a 100644 --- a/datashare-api/src/main/java/org/icij/datashare/function/Pair.java +++ b/datashare-api/src/main/java/org/icij/datashare/function/Pair.java @@ -17,6 +17,10 @@ public Pair(T1 fst, T2 snd) { second = snd; } + public static Pair of(T1 fst, T2 snd) { + return new Pair<>(fst, snd); + } + public T1 _1() { return first; } public T2 _2() { return second; } @@ -27,12 +31,10 @@ public int hashCode() { @Override public boolean equals(Object o) { - if ( ! (o instanceof Pair) ) { + if ( ! 
(o instanceof Pair objPair) ) { return false; } - Pair objPair = (Pair) o; - return first .equals(objPair._1()) && - second.equals(objPair._2()); + return first .equals(objPair._1()) && second.equals(objPair._2()); } } \ No newline at end of file diff --git a/datashare-api/src/main/java/org/icij/datashare/text/nlp/NlpTag.java b/datashare-api/src/main/java/org/icij/datashare/text/nlp/NlpTag.java index 50a9f40cd..c4f835fa8 100644 --- a/datashare-api/src/main/java/org/icij/datashare/text/nlp/NlpTag.java +++ b/datashare-api/src/main/java/org/icij/datashare/text/nlp/NlpTag.java @@ -14,7 +14,7 @@ public class NlpTag { private final NamedEntity.Category category; - NlpTag(int begin, int end, NamedEntity.Category category) { + public NlpTag(int begin, int end, NamedEntity.Category category) { this.begin = begin; this.end = end; this.category = category; diff --git a/datashare-api/src/main/java/org/icij/datashare/text/nlp/Pipeline.java b/datashare-api/src/main/java/org/icij/datashare/text/nlp/Pipeline.java index bbed5c344..25d215575 100644 --- a/datashare-api/src/main/java/org/icij/datashare/text/nlp/Pipeline.java +++ b/datashare-api/src/main/java/org/icij/datashare/text/nlp/Pipeline.java @@ -1,5 +1,11 @@ package org.icij.datashare.text.nlp; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Optional; +import java.util.Properties; +import java.util.Set; import java.util.stream.Stream; import org.icij.datashare.reflect.EnumTypeToken; import org.icij.datashare.text.Document; @@ -7,7 +13,6 @@ import org.icij.datashare.text.NamedEntity; import java.nio.charset.Charset; -import java.util.*; import java.util.function.Function; import java.util.stream.Collectors; @@ -23,9 +28,6 @@ static Set set(Type ...types) { return new HashSet<>(Arrays.asList(types)); } - record PipelineInput (Document doc, int contentLength, int contentOffset) {} - - enum Type implements EnumTypeToken { TEST((short)-1), CORENLP((short)0), @@ -71,6 +73,10 @@ public 
static Set parseAll(final String comaSeparatedTypes) { return comaSeparatedTypes == null || comaSeparatedTypes.isEmpty() ? new HashSet<>(): stream(comaSeparatedTypes.split(",")).map(Type::valueOf).collect(Collectors.toSet()); } + + public boolean extractFromDoc() { + return this == Type.EMAIL; + } } enum Property { @@ -100,18 +106,25 @@ public String getName() { boolean initialize(Language language) throws InterruptedException; - List process(Document doc) throws InterruptedException; - List process(Document doc, int contentLength, int contentOffset) throws InterruptedException; - default List> process(Stream inputs) throws InterruptedException { - return inputs.map(i -> { - try { - return process(i.doc, i.contentLength, i.contentOffset); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - }).collect(Collectors.toList()); + default List processDoc(Document doc) throws InterruptedException { + return processDoc(doc, doc.getContentTextLength(), 0); + } + + default List processDoc(Document doc, int contentLength, int contentOffset) throws InterruptedException { + Annotations annotations = new Annotations(doc.getId(), this.getType(), doc.getLanguage()); + String docContent = doc.getContent(); + this.processText(Stream.of(docContent.substring(contentOffset, contentOffset + contentLength)), doc.getLanguage()) + .get(0) + .forEach(tag -> { + int begin = tag.getBegin() + contentOffset; + int end = tag.getEnd() + contentOffset; + annotations.add(begin, end, tag.getCategory()); + }); + return NamedEntity.allFrom(docContent, annotations); } + List> processText(Stream batch, Language language) throws InterruptedException; + void terminate(Language language) throws InterruptedException; boolean supports(Language language); diff --git a/datashare-api/src/test/java/org/icij/datashare/extension/PipelineRegistryTest.java b/datashare-api/src/test/java/org/icij/datashare/extension/PipelineRegistryTest.java index 24d815830..350bfc6f2 100644 --- 
a/datashare-api/src/test/java/org/icij/datashare/extension/PipelineRegistryTest.java +++ b/datashare-api/src/test/java/org/icij/datashare/extension/PipelineRegistryTest.java @@ -73,61 +73,62 @@ public void setUp() { loader = new ExtensionLoader(folder.getRoot().toPath()); } - String EXTENSION_PIPELINE_SOURCE = "package org.icij.datashare.text.nlp.test;\n" + - "\n" + - "import org.icij.datashare.PropertiesProvider;\n" + - "import org.icij.datashare.text.Document;\n" + - "import org.icij.datashare.text.Language;\n" + - "import org.icij.datashare.text.NamedEntity;\n" + - "import org.icij.datashare.text.nlp.Annotations;\n" + - "import org.icij.datashare.text.nlp.Pipeline;\n" + - "\n" + - "import java.nio.charset.Charset;\n" + - "import java.util.List;\n" + - "import java.util.Optional;\n" + - "\n" + - "public class ExtensionPipeline implements Pipeline {\n" + - " @Override\n" + - " public Type getType() {\n" + - " return Type.TEST;\n" + - " }\n" + - "\n" + - " public ExtensionPipeline(PropertiesProvider provider) {}\n" + - " @Override\n" + - " public boolean initialize(Language language) throws InterruptedException {\n" + - " return false;\n" + - " }\n" + - " @Override\n" + - " public List process(Document doc) throws InterruptedException {\n" + - " return null;\n" + - " }\n" + - " @Override\n" + - " public List process(Document doc, int contentLength, int offset) throws InterruptedException {\n" + - " return null;\n" + - " }\n" + - " @Override\n" + - " public void terminate(Language language) throws InterruptedException {\n" + - "\n" + - " }\n" + - "\n" + - " @Override\n" + - " public boolean supports(Language language) {\n" + - " return false;\n" + - " }\n" + - "\n" + - " @Override\n" + - " public List getTargetEntities() {\n" + - " return null;\n" + - " }\n" + - "\n" + - " @Override\n" + - " public boolean isCaching() {\n" + - " return false;\n" + - " }\n" + - "\n" + - " @Override\n" + - " public Charset getEncoding() {\n" + - " return null;\n" + - " }\n" + - "}\n"; 
+ String EXTENSION_PIPELINE_SOURCE = """ + package org.icij.datashare.text.nlp.test; + + import org.icij.datashare.PropertiesProvider; + import org.icij.datashare.text.Document; + import org.icij.datashare.text.Language; + import org.icij.datashare.text.NamedEntity; + import org.icij.datashare.text.nlp.Annotations; + import org.icij.datashare.text.nlp.Pipeline; + import org.icij.datashare.text.nlp.NlpTag; + + import java.nio.charset.Charset; + import java.util.List; + import java.util.stream.Stream; + import java.util.Optional; + + public class ExtensionPipeline implements Pipeline { + @Override + public Type getType() { + return Type.TEST; + } + + public ExtensionPipeline(PropertiesProvider provider) {} + @Override + public boolean initialize(Language language) throws InterruptedException { + return false; + } + @Override + public List> processText(Stream batch, Language language) { + return null; + } + + @Override + public void terminate(Language language) throws InterruptedException { + + } + + @Override + public boolean supports(Language language) { + return false; + } + + @Override + public List getTargetEntities() { + return null; + } + + @Override + public boolean isCaching() { + return false; + } + + @Override + public Charset getEncoding() { + return null; + } + } + """; } diff --git a/datashare-api/src/test/java/org/icij/datashare/text/nlp/test/TestPipeline.java b/datashare-api/src/test/java/org/icij/datashare/text/nlp/test/TestPipeline.java index 55a82529e..a4000191f 100644 --- a/datashare-api/src/test/java/org/icij/datashare/text/nlp/test/TestPipeline.java +++ b/datashare-api/src/test/java/org/icij/datashare/text/nlp/test/TestPipeline.java @@ -1,9 +1,10 @@ package org.icij.datashare.text.nlp.test; +import java.util.stream.Stream; import org.icij.datashare.PropertiesProvider; -import org.icij.datashare.text.Document; import org.icij.datashare.text.Language; import org.icij.datashare.text.NamedEntity; +import org.icij.datashare.text.nlp.NlpTag; import 
org.icij.datashare.text.nlp.Pipeline; import java.nio.charset.Charset; @@ -22,15 +23,9 @@ public boolean initialize(Language language) { } @Override - public List process(Document doc) { + public List> processText(Stream batch, Language language) { return null; } - - @Override - public List process(Document doc, int contentLength, int contentOffset) { - return null; - } - @Override public void terminate(Language language) { } diff --git a/datashare-app/src/main/java/org/icij/datashare/nlp/EmailPipeline.java b/datashare-app/src/main/java/org/icij/datashare/nlp/EmailPipeline.java index 9f3290409..5dc1e804e 100644 --- a/datashare-app/src/main/java/org/icij/datashare/nlp/EmailPipeline.java +++ b/datashare-app/src/main/java/org/icij/datashare/nlp/EmailPipeline.java @@ -8,6 +8,7 @@ import java.util.Set; import java.util.Collection; import java.util.Collections; +import java.util.stream.Stream; import org.icij.datashare.PropertiesProvider; import org.icij.datashare.text.Document; import org.icij.datashare.text.Language; @@ -19,6 +20,7 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.stream.Collectors; +import org.icij.datashare.text.nlp.NlpTag; import static java.util.Arrays.asList; import static java.util.Collections.unmodifiableSet; @@ -87,19 +89,26 @@ public EmailPipeline(final PropertiesProvider propertiesProvider) { } @Override - public List process(Document doc) { - return process(doc, doc.getContentTextLength(), 0); + public List> processText(Stream batch, Language ignored) { + return batch.map(text -> { + Matcher matcher = pattern.matcher(text); + return matcher.results() + .map(r -> new NlpTag(matcher.start(), matcher.end(), NamedEntity.Category.EMAIL)) + .toList(); + }).toList(); } @Override - public List process(Document doc, int contentLength, int contentOffset) { - Matcher matcher = pattern.matcher(doc.getContent().substring(contentOffset, Math.min(contentLength + contentOffset, doc.getContentTextLength()))); + public List 
processDoc(Document doc, int contentLength, int contentOffset) { + String docContent = doc.getContent(); NamedEntitiesBuilder namedEntitiesBuilder = new NamedEntitiesBuilder(EMAIL, doc.getId(), doc.getLanguage()).withRoot(doc.getRootDocument()); - while (matcher.find()) { - String email = matcher.group(0); - int start = matcher.start(); - namedEntitiesBuilder.add(NamedEntity.Category.EMAIL, email, start + contentOffset); - } + String chunkContent = docContent.substring(contentOffset, Math.min(contentLength + contentOffset, doc.getContentTextLength())); + this.processText(Stream.of(chunkContent), doc.getLanguage()) + .get(0) + .forEach(t -> { + String mention = chunkContent.substring(t.getBegin(), t.getEnd()); + namedEntitiesBuilder.add(NamedEntity.Category.EMAIL, mention, t.getBegin() + contentOffset); + }); List entities = namedEntitiesBuilder.build(); if ("message/rfc822".equals(doc.getContentType())) { entities.addAll(processMetadata(doc)); diff --git a/datashare-app/src/main/java/org/icij/datashare/tasks/ExtractNlpTask.java b/datashare-app/src/main/java/org/icij/datashare/tasks/ExtractNlpTask.java index e8eef61b1..baaab840c 100644 --- a/datashare-app/src/main/java/org/icij/datashare/tasks/ExtractNlpTask.java +++ b/datashare-app/src/main/java/org/icij/datashare/tasks/ExtractNlpTask.java @@ -1,7 +1,21 @@ package org.icij.datashare.tasks; +import static java.lang.String.valueOf; +import static java.util.Optional.ofNullable; +import static org.icij.datashare.cli.DatashareCliOptions.BATCH_SIZE_OPT; +import static org.icij.datashare.cli.DatashareCliOptions.DEFAULT_DEFAULT_PROJECT; +import static org.icij.datashare.cli.DatashareCliOptions.DEFAULT_PROJECT_OPT; +import static org.icij.datashare.cli.DatashareCliOptions.MAX_CONTENT_LENGTH_OPT; +import static org.icij.datashare.cli.DatashareCliOptions.NLP_PIPELINE_OPT; +import static org.icij.extract.document.Identifier.shorten; + import com.google.inject.Inject; import com.google.inject.assistedinject.Assisted; +import 
java.io.IOException; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.TimeUnit; import java.util.function.Function; import org.icij.datashare.HumanReadableSize; import org.icij.datashare.PropertiesProvider; @@ -11,32 +25,28 @@ import org.icij.datashare.extract.DocumentCollectionFactory; import org.icij.datashare.monitoring.Monitorable; import org.icij.datashare.text.Document; +import org.icij.datashare.text.Language; import org.icij.datashare.text.NamedEntity; import org.icij.datashare.text.Project; import org.icij.datashare.text.indexing.Indexer; +import org.icij.datashare.text.nlp.Annotations; +import org.icij.datashare.text.nlp.NlpTag; import org.icij.datashare.text.nlp.Pipeline; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.IOException; -import java.util.List; -import java.util.concurrent.TimeUnit; - -import static java.lang.String.valueOf; -import static java.util.Optional.ofNullable; -import static org.icij.datashare.cli.DatashareCliOptions.DEFAULT_DEFAULT_PROJECT; -import static org.icij.datashare.cli.DatashareCliOptions.DEFAULT_PROJECT_OPT; -import static org.icij.datashare.cli.DatashareCliOptions.MAX_CONTENT_LENGTH_OPT; -import static org.icij.datashare.cli.DatashareCliOptions.NLP_PIPELINE_OPT; -import static org.icij.extract.document.Identifier.shorten; - public class ExtractNlpTask extends PipelineTask implements Monitorable { - private static final int DEFAULT_MAX_CONTENT_LENGTH = 1024 * 1024; + private static final int DEFAULT_MAX_LENGTH = 4096; + private static final int DEFAULT_BATCH_SIZE = 256; private final Logger logger = LoggerFactory.getLogger(getClass()); private final Indexer indexer; private final Pipeline nlpPipeline; private final Project project; - private final int maxContentLengthChars; + private final int maxContentLength; + private final int batchSize; + + record BatchItem(Document doc, String text, int offset) { + } @Inject public 
ExtractNlpTask(Indexer indexer, PipelineRegistry registry, final DocumentCollectionFactory factory, @Assisted Task taskView, @Assisted final Function updateCallback) { @@ -44,11 +54,12 @@ public ExtractNlpTask(Indexer indexer, PipelineRegistry registry, final Document } - ExtractNlpTask(Indexer indexer, Pipeline pipeline, final DocumentCollectionFactory factory, @Assisted Task taskView, @Assisted final Function updateCallback) { + ExtractNlpTask(Indexer indexer, Pipeline pipeline, final DocumentCollectionFactory factory, @Assisted Task taskView, @Assisted final Function ignored) { super(Stage.NLP, taskView.getUser(), factory, new PropertiesProvider(taskView.args), String.class); this.nlpPipeline = pipeline; project = Project.project(ofNullable((String)taskView.args.get(DEFAULT_PROJECT_OPT)).orElse(DEFAULT_DEFAULT_PROJECT)); - maxContentLengthChars = (int) HumanReadableSize.parse(ofNullable((String)taskView.args.get(MAX_CONTENT_LENGTH_OPT)).orElse(valueOf(DEFAULT_MAX_CONTENT_LENGTH))); + maxContentLength = (int) HumanReadableSize.parse(ofNullable((String)taskView.args.get(MAX_CONTENT_LENGTH_OPT)).orElse(valueOf(DEFAULT_MAX_LENGTH))); + batchSize = (int) HumanReadableSize.parse(ofNullable((String)taskView.args.get(BATCH_SIZE_OPT)).orElse(valueOf(DEFAULT_BATCH_SIZE))); this.indexer = indexer; } @@ -56,38 +67,126 @@ public ExtractNlpTask(Indexer indexer, PipelineRegistry registry, final Document public Long call() throws Exception { super.call(); logger.info("extracting Named Entities with pipeline {} for {} from queue {}", nlpPipeline.getType(), project, inputQueue.getName()); + long nbMessages; + if (this.nlpPipeline.getType().extractFromDoc()) { + nbMessages = extractFromDocs(); + } else { + nbMessages = extractFromTexts(); + } + logger.info("exiting ExtractNlpTask loop after {} messages.", nbMessages); + return nbMessages; + } + + long extractFromTexts() throws InterruptedException { + // NLP models are loaded/initialized by language, to avoid loading overhead, 
docs are + // received grouped by language and sent batched to the pipeline to avoid model reload. + long nDocs = 0; + String docId; + Language currentLanguage = null; + boolean languageInitialized = false; + ArrayList batch = new ArrayList<>(batchSize); + while (!(STRING_POISON.equals(docId = inputQueue.poll(60, TimeUnit.SECONDS)))) { + Document doc = indexer.get(project.getName(), docId); + nDocs++; + if (doc != null) { + String docContent = doc.getContent(); + if (!doc.getLanguage().equals(currentLanguage)) { + if (!batch.isEmpty()) { + consumeBatch(batch, currentLanguage); + } + if (currentLanguage != null) { + nlpPipeline.terminate(currentLanguage); + } + currentLanguage = doc.getLanguage(); + languageInitialized = nlpPipeline.initialize(currentLanguage); + } + if (!languageInitialized) { + continue; + } + int docLength = docContent.length(); + for (int begin = 0; begin < docLength; begin += maxContentLength) { + int end = Math.min(begin + maxContentLength, docLength); + String text = docContent.substring(begin, end); + batch.add(new BatchItem(doc, text, begin)); + if (batch.size() >= batchSize) { + consumeBatch(batch, currentLanguage); + } + } + } else { + logger.warn("no document found in index with id " + docId); + } + } + if (!batch.isEmpty()) { + consumeBatch(batch, currentLanguage); + } + if (currentLanguage != null) { + nlpPipeline.terminate(currentLanguage); + } + return nDocs; + } + + private void consumeBatch(List batch, Language language) throws InterruptedException { + List> entities = nlpPipeline.processText(batch.stream().map(i -> i.text), language); + Iterator batchIt = batch.iterator(); + entities.forEach(chunkTags -> { + BatchItem item = batchIt.next(); + Document doc = item.doc; + Annotations annotations = + new Annotations(doc.getId(), nlpPipeline.getType(), doc.getLanguage()); + int offset = item.offset; + chunkTags.forEach(tag -> { + int begin = tag.getBegin() + offset; + int end = tag.getEnd() + offset; + annotations.add(begin, end, 
tag.getCategory()); + }); + List chunkEntities = NamedEntity.allFrom(doc.getContent(), annotations); + boolean isComplete = offset + item.text.length() == doc.getContentTextLength(); + try { + if (isComplete) { + indexer.bulkAdd(project.getName(), nlpPipeline.getType(), chunkEntities, doc); + } else { + indexer.bulkAdd(project.getName(), chunkEntities); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + batch.clear(); + } + + private long extractFromDocs() throws InterruptedException { String docId; long nbMessages = 0; while (!(STRING_POISON.equals(docId = inputQueue.poll(60, TimeUnit.SECONDS)))) { try { if (docId != null) { - findNamedEntities(project, docId); + findDocNamedEntities(project, docId); nbMessages++; } } catch (Throwable e) { logger.error("error in ExtractNlpTask loop", e); } } - logger.info("exiting ExtractNlpTask loop after {} messages.", nbMessages); return nbMessages; } - void findNamedEntities(final Project project, final String id) throws InterruptedException { + void findDocNamedEntities(final Project project, final String id) throws InterruptedException { try { Document doc = indexer.get(project.getName(), id); if (doc != null) { logger.info("extracting {} entities for document {}", nlpPipeline.getType(), shorten(doc.getId(), 4)); if (nlpPipeline.initialize(doc.getLanguage())) { int nbEntities = 0; - if (doc.getContent().length() < this.maxContentLengthChars) { - List namedEntities = nlpPipeline.process(doc); + if (doc.getContent().length() < this.maxContentLength) { + List namedEntities = nlpPipeline.processDoc(doc); indexer.bulkAdd(project.getName(), nlpPipeline.getType(), namedEntities, doc); nbEntities = namedEntities.size(); } else { - int nbChunks = doc.getContent().length() / this.maxContentLengthChars + 1; + int nbChunks = doc.getContent().length() / this.maxContentLength + 1; logger.info("document is too large, extracting entities for {} document chunks", nbChunks); for (int chunkIndex = 0; chunkIndex < 
nbChunks; chunkIndex++) { - List namedEntities = nlpPipeline.process(doc, maxContentLengthChars, chunkIndex * maxContentLengthChars); + List namedEntities = + nlpPipeline.processDoc(doc, maxContentLength, chunkIndex * maxContentLength); if (chunkIndex < nbChunks - 1) { indexer.bulkAdd(project.getName(), namedEntities); } else { diff --git a/datashare-app/src/main/java/org/icij/datashare/web/NerResource.java b/datashare-app/src/main/java/org/icij/datashare/web/NerResource.java index 1a8fabef0..887e7073c 100644 --- a/datashare-app/src/main/java/org/icij/datashare/web/NerResource.java +++ b/datashare-app/src/main/java/org/icij/datashare/web/NerResource.java @@ -10,6 +10,7 @@ import net.codestory.http.annotations.Post; import net.codestory.http.annotations.Prefix; import org.icij.datashare.extension.PipelineRegistry; +import org.icij.datashare.text.Document; import org.icij.datashare.text.DocumentBuilder; import org.icij.datashare.text.Language; import org.icij.datashare.text.NamedEntity; @@ -51,7 +52,8 @@ public List getAnnotations(@Parameter(name = "pipeline", descriptio Pipeline p = pipelineRegistry.get(Pipeline.Type.parse(pipeline)); Language language = languageGuesser.guess(text); if (p.initialize(language)) { - return p.process(DocumentBuilder.createDoc("inline").with(text).with(language).build()); + Document doc = DocumentBuilder.createDoc("inline").with(text).with(language).build(); + return p.processDoc(doc); } return emptyList(); } diff --git a/datashare-app/src/test/java/org/icij/datashare/nlp/CoreNlpTestManual.java b/datashare-app/src/test/java/org/icij/datashare/nlp/CoreNlpTestManual.java index b73d503a7..35ab7d82c 100644 --- a/datashare-app/src/test/java/org/icij/datashare/nlp/CoreNlpTestManual.java +++ b/datashare-app/src/test/java/org/icij/datashare/nlp/CoreNlpTestManual.java @@ -22,7 +22,7 @@ public void test_download_and_load_jar() throws Exception { systemClassLoader.add(distDir.toURI().toURL()); CorenlpPipeline corenlpPipeline = new 
CorenlpPipeline(new PropertiesProvider()); corenlpPipeline.initialize(Language.ENGLISH); - List process = corenlpPipeline.process(DocumentBuilder.createDoc("my_doc_id").with("this is Dwight's document").build()); + List process = corenlpPipeline.processDoc(DocumentBuilder.createDoc("my_doc_id").with("this is Dwight's document").build()); assertThat(process.size()).isGreaterThan(0); } } diff --git a/datashare-app/src/test/java/org/icij/datashare/nlp/EmailPipelineTest.java b/datashare-app/src/test/java/org/icij/datashare/nlp/EmailPipelineTest.java index 15d4b07d3..c1b82724c 100644 --- a/datashare-app/src/test/java/org/icij/datashare/nlp/EmailPipelineTest.java +++ b/datashare-app/src/test/java/org/icij/datashare/nlp/EmailPipelineTest.java @@ -1,5 +1,6 @@ package org.icij.datashare.nlp; +import java.util.Objects; import org.icij.datashare.PropertiesProvider; import org.icij.datashare.text.Document; import org.icij.datashare.text.Language; @@ -24,9 +25,9 @@ public class EmailPipelineTest { private final EmailPipeline pipeline = new EmailPipeline(new PropertiesProvider()); @Test - public void test_no_email() { - List annotations = pipeline.process(createDocument("this is a content without email but with an arobase (@).", "docId", Language.ENGLISH)); - assertThat(annotations).isEmpty(); + public void test_no_email() throws InterruptedException { + List namedEntities = pipeline.processDoc(createDocument("this is a content without email but with an arobase (@).", "docId", Language.ENGLISH)); + assertThat(namedEntities).isEmpty(); } private Document createDocument(String content, String docId, Language language) { @@ -34,63 +35,66 @@ private Document createDocument(String content, String docId, Language language) } @Test - public void test_one_email() { + public void test_one_email() throws InterruptedException { String content = "this is a content with email@domain.com"; - List annotations = pipeline.process(createDocument(content, "docId", Language.ENGLISH)); + List 
nameEntities = pipeline.processDoc(createDocument(content, "docId", Language.ENGLISH)); - assertThat(annotations).hasSize(1); - assertThat(annotations.get(0).getOffsets()).containsExactly(23L); - assertThat(annotations.get(0).getCategory()).isEqualTo(NamedEntity.Category.EMAIL); - assertThat(annotations.get(0).getMention()).isEqualTo("email@domain.com"); + assertThat(nameEntities).hasSize(1); + assertThat(nameEntities.get(0).getOffsets()).containsExactly(23L); + assertThat(nameEntities.get(0).getCategory()).isEqualTo(NamedEntity.Category.EMAIL); + assertThat(nameEntities.get(0).getMention()).isEqualTo("email@domain.com"); } @Test - public void test_one_email_twice() { - String content = "this is a content with email@domain.com\n" + - "that is twice in the document\n" + - "email@domain.com"; - List annotations = pipeline.process(createDocument(content, "docId", Language.ENGLISH)); - - assertThat(annotations).hasSize(1); - NamedEntity nlpTag = annotations.get(0); + public void test_one_email_twice() throws InterruptedException { + String content = """ +this is a content with email@domain.com +that is twice in the document +email@domain.com"""; + List namedEntities = pipeline.processDoc(createDocument(content, "docId", Language.ENGLISH)); + + assertThat(namedEntities).hasSize(1); + NamedEntity nlpTag = namedEntities.get(0); assertThat(nlpTag.getOffsets()).containsExactly(23L, 70L); assertThat(nlpTag.getMention()).isEqualTo("email@domain.com"); } @Test - public void test_three_emails() { - List annotations = pipeline.process(createDocument("this is a content with email@domain.com\n" + - "and another one : foo@bar.com\n" + - "and baz@qux.fr", "docId", Language.ENGLISH)); + public void test_three_emails() throws InterruptedException { + List namedEntities = pipeline.processDoc(createDocument(""" +this is a content with email@domain.com +and another one : foo@bar.com +and baz@qux.fr""", "docId", Language.ENGLISH)); - assertThat(annotations).hasSize(3); + 
assertThat(namedEntities).hasSize(3); } @Test public void test_emails_chunked_content() { - Document document = createDocument("this is a content with email@domain.com\n" + - "and another one : foo@bar.com\n" + - "and baz@qux.fr", "docId", Language.ENGLISH); - List annotations = pipeline.process(document, 20, 72); - - assertThat(annotations).hasSize(1); - assertThat(annotations.get(0).getMention()).isEqualTo("baz@qux.fr"); - assertThat(annotations.get(0).getOffsets()).containsExactly(74L); + Document document = createDocument(""" +this is a content with email@domain.com +and another one : foo@bar.com +and baz@qux.fr""", "docId", Language.ENGLISH); + List namedEntities = pipeline.processDoc(document, 20, 72); + + assertThat(namedEntities).hasSize(1); + assertThat(namedEntities.get(0).getMention()).isEqualTo("baz@qux.fr"); + assertThat(namedEntities.get(0).getOffsets()).containsExactly(74L); } @Test - public void test_acceptance() throws IOException { - Path emailFile = Paths.get(getClass().getResource("/email.eml").getPath()); + public void test_acceptance() throws IOException, InterruptedException { + Path emailFile = Paths.get(Objects.requireNonNull(getClass().getResource("/email.eml")).getPath()); String content = new String(Files.readAllBytes(emailFile)); - List annotations = pipeline.process(createDocument(content, "docId", Language.ENGLISH)); + List namedEntities = pipeline.processDoc(createDocument(content, "docId", Language.ENGLISH)); - assertThat(annotations).hasSize(3); - assertThat(annotations.get(0).getOffsets()).containsExactly(14L, 48L, 168L, 332L, 1283L, 1482L, 1544L, 1582L); + assertThat(namedEntities).hasSize(3); + assertThat(namedEntities.get(0).getOffsets()).containsExactly(14L, 48L, 168L, 332L, 1283L, 1482L, 1544L, 1582L); } @Test - public void test_adds_document_headers_parsing_for_email() { + public void test_adds_document_headers_parsing_for_email() throws InterruptedException { Document doc = createDoc("docid") .with("hello@world.com") 
.withRootId("root") @@ -101,7 +105,7 @@ public void test_adds_document_headers_parsing_for_email() { put(tikaMsgHeader("Cc"), "email2@domain.com,email3@domain.com"); }}).build(); - List namedEntities = pipeline.process(doc); + List namedEntities = pipeline.processDoc(doc); Map metaFirst = Map.of("emailHeaderField", "tika_metadata_message_to"); NamedEntity fourth = NamedEntity.create( @@ -124,7 +128,7 @@ public void test_adds_document_headers_parsing_for_email() { } @Test - public void test_filter_headers_that_contains_mail_addresses() { + public void test_filter_headers_that_contains_mail_addresses() throws InterruptedException { Document doc = createDoc("docid") .with("mail content") .ofContentType("message/rfc822") @@ -151,7 +155,7 @@ public void test_filter_headers_that_contains_mail_addresses() { put(tikaRawHeader("Resent-bcc"), "resent-bcc@head.er"); }}).build(); - List namedEntities = pipeline.process(doc); + List namedEntities = pipeline.processDoc(doc); assertThat(namedEntities).containsExactly( NamedEntity.create(EMAIL, "replyto@head.er", List.of(-1L), "docid", "root", Type.EMAIL, FRENCH, Map.of("emailHeaderField", "tika_metadata_message_raw_header_reply_to")), diff --git a/datashare-app/src/test/java/org/icij/datashare/tasks/ExtractNlpTaskIntTest.java b/datashare-app/src/test/java/org/icij/datashare/tasks/ExtractNlpTaskIntTest.java index 9377d517e..9f1c75834 100644 --- a/datashare-app/src/test/java/org/icij/datashare/tasks/ExtractNlpTaskIntTest.java +++ b/datashare-app/src/test/java/org/icij/datashare/tasks/ExtractNlpTaskIntTest.java @@ -5,6 +5,7 @@ import com.google.inject.Injector; import com.google.inject.Key; import com.google.inject.TypeLiteral; +import java.util.stream.Stream; import org.icij.datashare.PipelineHelper; import org.icij.datashare.PropertiesProvider; import org.icij.datashare.Stage; @@ -31,6 +32,7 @@ import static java.util.Arrays.asList; import static org.icij.datashare.text.DocumentBuilder.createDoc; import static 
org.icij.datashare.text.Language.ENGLISH; +import static org.icij.datashare.text.Language.FRENCH; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; @@ -50,21 +52,53 @@ public ExtractNlpTaskIntTest(Injector injector) { } @Test(timeout = 2000) - public void test_loop_consume_two_documents() throws Exception { + public void test_loop_consume_two_doc_texts() throws Exception { when(pipeline.getType()).thenReturn(Pipeline.Type.CORENLP); when(pipeline.initialize(any())).thenReturn(true); - Document doc = createDoc("content").build(); - when(indexer.get(anyString(), eq("docId"))).thenReturn(doc); + Document enDoc = createDoc("enId").with("content").with(ENGLISH).build(); + Document frDoc = createDoc("frId").with("contenu").with(FRENCH).build(); + when(indexer.get(anyString(), eq("enId"))).thenReturn(enDoc); + when(indexer.get(anyString(), eq("frId"))).thenReturn(frDoc); String queueName = new PipelineHelper(new PropertiesProvider()).getQueueNameFor(Stage.NLP); DocumentQueue queue = factory.createQueue(queueName, String.class); - queue.add("docId"); + queue.add("enId"); + queue.add("frId"); queue.add(PipelineTask.STRING_POISON); nlpTask.call(); verify(pipeline).initialize(ENGLISH); - verify(pipeline).process(doc); + verify(pipeline).processText(any(Stream.class), eq(ENGLISH)); + verify(pipeline).terminate(ENGLISH); + verify(pipeline).processText(any(Stream.class), eq(FRENCH)); + verify(pipeline).initialize(FRENCH); + verify(pipeline).terminate(FRENCH); + } + + @Test(timeout = 2000) + public void test_loop_consume_two_docs() throws Exception { + when(pipeline.getType()).thenReturn(Pipeline.Type.EMAIL); + when(pipeline.initialize(any())).thenReturn(true); + Document enDoc = createDoc("enId").with(ENGLISH).build(); + Document frDoc = createDoc("frId").with(FRENCH).build(); + when(indexer.get(anyString(), eq("enId"))).thenReturn(enDoc); + when(indexer.get(anyString(), 
eq("frId"))).thenReturn(frDoc); + + String queueName = new PipelineHelper(new PropertiesProvider()).getQueueNameFor(Stage.NLP); + DocumentQueue queue = factory.createQueue(queueName, String.class); + queue.add("enId"); + queue.add("frId"); + queue.add(PipelineTask.STRING_POISON); + + nlpTask.call(); + + verify(pipeline).initialize(ENGLISH); + verify(pipeline).processDoc(eq(enDoc)); + verify(pipeline).terminate(ENGLISH); + verify(pipeline).initialize(FRENCH); + verify(pipeline).processDoc(eq(frDoc)); + verify(pipeline).terminate(FRENCH); } @Parameterized.Parameters @@ -93,7 +127,8 @@ protected void configure() { public void setUp() { initMocks(this); nlpTask = new ExtractNlpTask(indexer, pipeline, factory, new Task<>(ExtractNlpTask.class.getName(), User.local(), new HashMap<>(){{ - put("maxContentLength", "32"); + put("maxContentLength", "8"); + put("batchSize", "1"); }}), null); } } diff --git a/datashare-app/src/test/java/org/icij/datashare/tasks/ExtractNlpTaskTest.java b/datashare-app/src/test/java/org/icij/datashare/tasks/ExtractNlpTaskTest.java index e34ce6c0d..35160a9bb 100644 --- a/datashare-app/src/test/java/org/icij/datashare/tasks/ExtractNlpTaskTest.java +++ b/datashare-app/src/test/java/org/icij/datashare/tasks/ExtractNlpTaskTest.java @@ -1,86 +1,146 @@ package org.icij.datashare.tasks; -import org.icij.datashare.asynctasks.Task; -import org.icij.datashare.extract.MemoryDocumentCollectionFactory; -import org.icij.datashare.text.Document; -import org.icij.datashare.text.Language; -import org.icij.datashare.text.indexing.Indexer; -import org.icij.datashare.text.nlp.AbstractPipeline; -import org.icij.datashare.user.User; -import org.junit.Before; -import org.junit.Test; -import org.mockito.Mock; - -import java.util.HashMap; - import static java.util.Collections.emptyList; +import static org.fest.assertions.Assertions.assertThat; +import static org.icij.datashare.tasks.PipelineTask.STRING_POISON; import static 
org.icij.datashare.text.DocumentBuilder.createDoc; import static org.icij.datashare.text.Language.ENGLISH; +import static org.icij.datashare.text.Language.FRENCH; import static org.icij.datashare.text.Project.project; - import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyString; - +import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.mockito.MockitoAnnotations.initMocks; +import java.util.ArrayList; +import java.util.Map; +import java.util.stream.Stream; +import org.icij.datashare.asynctasks.Task; +import org.icij.datashare.extract.MemoryDocumentCollectionFactory; +import org.icij.datashare.text.Document; +import org.icij.datashare.text.Language; +import org.icij.datashare.text.indexing.Indexer; +import org.icij.datashare.text.nlp.AbstractPipeline; +import org.icij.datashare.user.User; +import org.icij.extract.queue.DocumentQueue; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; + public class ExtractNlpTaskTest { - @Mock private Indexer indexer; - @Mock private AbstractPipeline pipeline; + @Mock + private Indexer indexer; + @Mock + private AbstractPipeline pipeline; private final MemoryDocumentCollectionFactory factory = new MemoryDocumentCollectionFactory<>(); private ExtractNlpTask nlpTask; + private final String INPUT_QUEUE_NAME = "extract:queue:nlp"; @Before public void setUp() { initMocks(this); - nlpTask = new ExtractNlpTask(indexer, pipeline, factory, new Task<>(ExtractNlpTask.class.getName(), User.local(), new HashMap<>(){{ - put("maxContentLength", "32"); - }}), null); + nlpTask = + new ExtractNlpTask(indexer, pipeline, factory, new Task<>(ExtractNlpTask.class.getName(), User.local(), + Map.of("maxContentLength", "8", 
"batchSize", "2")), + null + ); + factory.queues.values().forEach(q -> q.drainTo(new ArrayList<>())); } @Test - public void test_on_message_does_nothing__when_doc_not_found_in_index() throws Exception { - nlpTask.findNamedEntities(project("projectName"),"unknownId"); + public void test_find_doc_named_entities_does_nothing_when_doc_not_found_in_index() throws Exception { + nlpTask.findDocNamedEntities(project("projectName"), "unknownId"); verify(pipeline, never()).initialize(any(Language.class)); - verify(pipeline, never()).process(any()); + verify(pipeline, never()).processDoc(any()); + verify(pipeline, never()).processDoc(any(), anyInt(), anyInt()); } @Test - public void test_on_message_do_not_processNLP__when_init_fails() throws Exception { + public void test_extract_from_texts_does_nothing_when_doc_not_found_in_index() throws Exception { + DocumentQueue inputQueue = factory.getQueues(INPUT_QUEUE_NAME, String.class).get(0); + inputQueue.add("unknownId"); + inputQueue.add(STRING_POISON); + nlpTask.extractFromTexts(); + + verify(pipeline, never()).initialize(any(Language.class)); + verify(pipeline, never()).processText(any(), any()); + } + + @Test + public void test_find_doc_named_entities_does_nothing_when_init_fails() throws Exception { when(pipeline.initialize(any())).thenReturn(false); when(indexer.get(anyString(), anyString(), anyString())).thenReturn(createDoc("content").build()); - nlpTask.findNamedEntities(project("projectName"),"id"); - verify(pipeline, never()).process(any()); + nlpTask.findDocNamedEntities(project("projectName"), "id"); + verify(pipeline, never()).processDoc(any()); + verify(pipeline, never()).processDoc(any(), anyInt(), anyInt()); } @Test - public void test_on_message_processNLP__when_doc_found_in_index() throws Exception { - when(pipeline.initialize(any())).thenReturn(true); - Document doc = createDoc("content").build(); - when(pipeline.process(doc)).thenReturn(emptyList()); - when(indexer.get("projectName", 
doc.getId())).thenReturn(doc); + public void test_extract_from_texts_does_nothing_when_init_fails() throws Exception { + when(pipeline.initialize(any())).thenReturn(false); + DocumentQueue inputQueue = factory.getQueues(INPUT_QUEUE_NAME, String.class).get(0); + inputQueue.add("docId"); + inputQueue.add(STRING_POISON); - nlpTask.findNamedEntities(project("projectName"), doc.getId()); + nlpTask.extractFromTexts(); - verify(pipeline).initialize(ENGLISH); - verify(pipeline).process(doc); + verify(pipeline, never()).processText(any(), any()); } @Test - public void test_on_message_process__chunked_doc_when_doc_is_large() throws Exception { + public void test_find_doc_named_entities_chunks_doc_when_too_large() throws InterruptedException { when(pipeline.initialize(any())).thenReturn(true); Document doc = createDoc("huge_doc").with("0123456789abcdef0123456789abcdef+").build(); - when(pipeline.process(doc)).thenReturn(emptyList()); - when(indexer.get("projectName", doc.getId())).thenReturn(doc); + when(pipeline.processDoc(doc)).thenReturn(emptyList()); + when(indexer.get(anyString(), anyString())).thenReturn(doc); - nlpTask.findNamedEntities(project("projectName"), doc.getId()); + nlpTask.findDocNamedEntities(project("projectName"), doc.getId()); verify(pipeline).initialize(ENGLISH); - verify(pipeline).process(doc, 32, 0); - verify(pipeline).process(doc, 32, 32); + verify(pipeline).processDoc(doc, 8, 0); + verify(pipeline).processDoc(doc, 8, 8); + verify(pipeline).processDoc(doc, 8, 16); + verify(pipeline).processDoc(doc, 8, 24); + } + + @Test + public void test_should_process_docs_by_batch_grouped_by_language() throws InterruptedException { + // Given + when(pipeline.initialize(any())).thenReturn(true); + Document enDoc0 = createDoc("enId0").with("content").with(ENGLISH).build(); + Document enDoc1 = createDoc("enId1").with("long content").with(ENGLISH).build(); + Document frDoc0 = createDoc("frId0").with("contenu long").with(FRENCH).build(); + Document frDoc1 = 
createDoc("frId1").with("contenu").with(FRENCH).build(); + when(indexer.get(anyString(), same(enDoc0.getId()))).thenReturn(enDoc0); + when(indexer.get(anyString(), same(enDoc1.getId()))).thenReturn(enDoc1); + when(indexer.get(anyString(), same(frDoc0.getId()))).thenReturn(frDoc0); + when(indexer.get(anyString(), same(frDoc1.getId()))).thenReturn(frDoc1); + DocumentQueue inputQueue = factory.getQueues(INPUT_QUEUE_NAME, String.class).get(0); + inputQueue.add(enDoc0.getId()); + inputQueue.add(enDoc1.getId()); + inputQueue.add(frDoc0.getId()); + inputQueue.add(frDoc1.getId()); + inputQueue.add(STRING_POISON); + + // When + nlpTask.extractFromTexts(); + + // Then + verify(pipeline).initialize(ENGLISH); + verify(pipeline).terminate(ENGLISH); + verify(pipeline).initialize(FRENCH); + verify(pipeline).terminate(FRENCH); + ArgumentCaptor> streamCaptor = ArgumentCaptor.forClass(Stream.class); + verify(pipeline, times(2)).processText(streamCaptor.capture(), same(ENGLISH)); + verify(pipeline, times(2)).processText(streamCaptor.capture(), same(FRENCH)); + assertThat(streamCaptor.getAllValues().size()).isEqualTo(4); } } diff --git a/datashare-app/src/test/java/org/icij/datashare/web/NerResourceTest.java b/datashare-app/src/test/java/org/icij/datashare/web/NerResourceTest.java index 70ba0736c..9349a3fc2 100644 --- a/datashare-app/src/test/java/org/icij/datashare/web/NerResourceTest.java +++ b/datashare-app/src/test/java/org/icij/datashare/web/NerResourceTest.java @@ -17,7 +17,6 @@ import java.util.List; import java.util.Map; -import static java.util.Arrays.asList; import static java.util.Collections.emptyList; import static org.fest.assertions.Assertions.assertThat; import static org.fest.assertions.MapAssert.entry; @@ -25,6 +24,7 @@ import static org.icij.datashare.text.Language.ENGLISH; import static org.icij.datashare.text.nlp.Pipeline.Type.CORENLP; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; import static 
org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.verify; @@ -46,36 +46,38 @@ public void setUp() throws Exception { @Test public void test_post_empty_text() throws Exception { Document doc = DocumentBuilder.createDoc("inline").with("").with(ENGLISH).build(); - doReturn(emptyList()).when(pipeline).process(eq(doc)); + doReturn(emptyList()).when(pipeline).processDoc(eq(doc)); post("/api/ner/findNames/CORENLP", doc.getContent()).should().respond(200).contain("[]"); verify(pipeline).initialize(ENGLISH); - verify(pipeline).process(doc); + verify(pipeline).processDoc(eq(doc)); } @Test - public void test_get_pipeline_list() throws Exception { + public void test_get_pipeline_list() { doReturn(asSet(Pipeline.Type.EMAIL, Pipeline.Type.IXAPIPE)).when(registry).getPipelineTypes(); get("/api/ner/pipelines").should().respond(200).contain("EMAIL").contain("IXAPIPE"); } @Test - public void test_post_text_returns_NamedEntity_list() throws Exception { + public void test_post_text_returns_named_entity_list() throws Exception { Document doc = DocumentBuilder.createDoc("inline").with("This the 'foù' file content.").with(ENGLISH).build(); final Annotations annotations = new Annotations("inline", CORENLP, ENGLISH); annotations.add( 10, 13, NamedEntity.Category.PERSON); - doReturn(asList(NamedEntity.create(NamedEntity.Category.PERSON, "foù", asList(10L), doc.getId(), "root", CORENLP, ENGLISH))).when(pipeline).process(eq(doc)); + doReturn(List.of(NamedEntity.create(NamedEntity.Category.PERSON, "foù", List.of(10L), doc.getId(), "root", CORENLP, ENGLISH))) + .when(pipeline) + .processDoc(eq(doc)); Response response = post("/api/ner/findNames/CORENLP", doc.getContent()).response(); - List actualNerList = TypeConvert.fromJson(response.content(), List.class); + List actualNerList = TypeConvert.fromJson(response.content(), List.class); assertThat(actualNerList).hasSize(1); assertThat(actualNerList.get(0)).isInstanceOf(HashMap.class); 
assertThat((Map) actualNerList.get(0)).includes( entry("mention", "foù"), entry("extractor", "CORENLP"), entry("mentionNorm", "fou"), - entry("offsets", asList(10)) + entry("offsets", List.of(10)) ); } } diff --git a/datashare-cli/src/main/java/org/icij/datashare/cli/DatashareCliOptions.java b/datashare-cli/src/main/java/org/icij/datashare/cli/DatashareCliOptions.java index 77b366895..2d329fc07 100644 --- a/datashare-cli/src/main/java/org/icij/datashare/cli/DatashareCliOptions.java +++ b/datashare-cli/src/main/java/org/icij/datashare/cli/DatashareCliOptions.java @@ -37,6 +37,7 @@ public final class DatashareCliOptions { public static final String BATCH_SEARCH_MAX_TIME_OPT = "batchSearchMaxTimeSeconds"; public static final String BATCH_SEARCH_SCROLL_DURATION_OPT = "batchSearchScroll"; public static final String BATCH_SEARCH_SCROLL_SIZE_OPT = "batchSearchScrollSize"; + public static final String BATCH_SIZE_OPT = "batchSize"; public static final String BATCH_THROTTLE_OPT = "batchThrottleMilliseconds"; public static final String BROWSER_OPEN_LINK_OPT = "browserOpenLink"; public static final String BUS_TYPE_OPT = "busType"; @@ -212,7 +213,7 @@ static void followSymlinks(OptionParser parser) { singletonList(FOLLOW_SYMLINKS_OPT), "Follow symlinks while scanning documents") .withRequiredArg() .ofType(Boolean.class) - .defaultsTo(DEFAULT_FOLLOW_SYMLINKS);; + .defaultsTo(DEFAULT_FOLLOW_SYMLINKS); } static void cors(OptionParser parser) { @@ -344,7 +345,7 @@ static void dataDir(OptionParser parser) { static void artifactDir(OptionParser parser) { parser.acceptsAll( - asList(ARTIFACT_DIR_OPT), + List.of(ARTIFACT_DIR_OPT), "Artifact directory for embedded caching. If not provided datashare will use memory." 
) .withRequiredArg(); } @@ -790,7 +791,7 @@ public static void oauthClaimIdAttribute(OptionParser parser) { } public static ValueConverter toAbsolute() { - return new ValueConverter() { + return new ValueConverter<>() { @Override public String convert(String value) { Path path = Paths.get(value); diff --git a/datashare-nlp-corenlp/src/main/java/org/icij/datashare/text/nlp/corenlp/CorenlpPipeline.java b/datashare-nlp-corenlp/src/main/java/org/icij/datashare/text/nlp/corenlp/CorenlpPipeline.java index 7dc22a0b7..9b7bf038b 100644 --- a/datashare-nlp-corenlp/src/main/java/org/icij/datashare/text/nlp/corenlp/CorenlpPipeline.java +++ b/datashare-nlp-corenlp/src/main/java/org/icij/datashare/text/nlp/corenlp/CorenlpPipeline.java @@ -5,16 +5,15 @@ import com.google.inject.Inject; import edu.stanford.nlp.pipeline.CoreDocument; import edu.stanford.nlp.pipeline.StanfordCoreNLP; +import edu.stanford.nlp.util.Pair; import java.util.List; import java.util.Set; +import java.util.stream.Stream; import org.icij.datashare.PropertiesProvider; -import org.icij.datashare.function.ThrowingFunctions; -import org.icij.datashare.text.Document; -import org.icij.datashare.text.Hasher; import org.icij.datashare.text.Language; -import org.icij.datashare.text.NamedEntitiesBuilder; import org.icij.datashare.text.NamedEntity; import org.icij.datashare.text.nlp.AbstractPipeline; +import org.icij.datashare.text.nlp.NlpTag; import org.icij.datashare.text.nlp.Pipeline; import org.icij.datashare.text.nlp.corenlp.models.CoreNlpPipelineModels; @@ -49,14 +48,10 @@ public boolean initialize(Language language) throws InterruptedException { return initializePipelineAnnotator(language); } - @Override - public List process(Document doc) throws InterruptedException { - return process(doc, doc.getContentTextLength(), 0); - } @Override - public List process(Document doc, int contentLength, int contentOffset) throws InterruptedException { - return processPipeline(doc, contentLength, contentOffset); + public List> 
processText(Stream batch, Language language) throws InterruptedException { + return processPipeline(batch, language); } /** @@ -82,27 +77,15 @@ private boolean initializePipelineAnnotator(Language language) throws Interrupte return true; } - /** - * Named Entity Classifier (Conditional Random Fields) only - * - * @param doc the document - */ - private List processPipeline(Document doc, int contentLength, int contentOffset) - throws InterruptedException { - NamedEntitiesBuilder namedEntitiesBuilder = - new NamedEntitiesBuilder(getType(), doc.getId(), doc.getLanguage()).withRoot(doc.getRootDocument()); - LOGGER.info("name-finding for {} in document {} (offset {})", doc.getLanguage(), Hasher.shorten(doc.getId(), 4), - contentOffset); - final StanfordCoreNLP annotator; - annotator = CoreNlpPipelineModels.getInstance().get(doc.getLanguage()); - String text = doc.getContent() - .substring(contentOffset, Math.min(contentOffset + contentLength, doc.getContentTextLength())); - CoreDocument codeDoc = annotator.processToCoreDocument(text); - codeDoc.entityMentions().forEach(e -> { - NamedEntity.Category category = NamedEntity.Category.parse(e.entityType()); - String mention = ThrowingFunctions.removeNewLines.apply(e.text()); - namedEntitiesBuilder.add(category, mention, e.charOffsets().first + contentOffset); - }); - return namedEntitiesBuilder.build(); + private List> processPipeline(Stream batch, Language language) throws InterruptedException { + final StanfordCoreNLP annotator = CoreNlpPipelineModels.getInstance().get(language); + return batch.map(text -> { + CoreDocument codeDoc = annotator.processToCoreDocument(text); + return codeDoc.entityMentions().stream().map(e -> { + NamedEntity.Category category = NamedEntity.Category.parse(e.entityType()); + Pair offsets = e.charOffsets(); + return new NlpTag(offsets.first, offsets.second, category); + }).toList(); + }).toList(); } } From e3a920ff120b8bb95372bc28cd9089975b47b5f0 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Cle=CC=81ment=20Doumouro?= Date: Thu, 5 Sep 2024 16:06:56 +0200 Subject: [PATCH 5/5] chore: make `NlpTag` a record + serializable class --- .../org/icij/datashare/text/NamedEntity.java | 4 +-- .../org/icij/datashare/text/nlp/NlpTag.java | 30 +++++++------------ .../org/icij/datashare/text/nlp/Pipeline.java | 6 ++-- .../org/icij/datashare/nlp/EmailPipeline.java | 4 +-- .../icij/datashare/tasks/ExtractNlpTask.java | 6 ++-- 5 files changed, 20 insertions(+), 30 deletions(-) diff --git a/datashare-api/src/main/java/org/icij/datashare/text/NamedEntity.java b/datashare-api/src/main/java/org/icij/datashare/text/NamedEntity.java index 5962a5f7a..a4f9acbb0 100644 --- a/datashare-api/src/main/java/org/icij/datashare/text/NamedEntity.java +++ b/datashare-api/src/main/java/org/icij/datashare/text/NamedEntity.java @@ -135,8 +135,8 @@ public static List allFrom(String text, Annotations annotations) { } public static NamedEntity from(String text, NlpTag tag, Annotations annotations) { - String mention = ThrowingFunctions.removeNewLines.apply(text.substring(tag.getBegin(), tag.getEnd())); - return NamedEntity.create(tag.getCategory(), mention, List.of((long) tag.getBegin()), + String mention = ThrowingFunctions.removeNewLines.apply(text.substring(tag.begin(), tag.end())); + return NamedEntity.create(tag.category(), mention, List.of((long) tag.begin()), annotations.documentId, annotations.rootId, annotations.pipelineType, annotations.language ); } } diff --git a/datashare-api/src/main/java/org/icij/datashare/text/nlp/NlpTag.java b/datashare-api/src/main/java/org/icij/datashare/text/nlp/NlpTag.java index c4f835fa8..132f69938 100644 --- a/datashare-api/src/main/java/org/icij/datashare/text/nlp/NlpTag.java +++ b/datashare-api/src/main/java/org/icij/datashare/text/nlp/NlpTag.java @@ -2,34 +2,24 @@ import static java.util.Comparator.comparingInt; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; import 
java.util.Comparator; import org.icij.datashare.text.NamedEntity; -public class NlpTag { +public record NlpTag(int begin, int end, NamedEntity.Category category) { - public static final Comparator comparator = comparingInt(NlpTag::getBegin); + public static final Comparator comparator = comparingInt(NlpTag::begin); - private final int begin; - private final int end; - private final NamedEntity.Category category; - - - public NlpTag(int begin, int end, NamedEntity.Category category) { + @JsonCreator + public NlpTag( + @JsonProperty("begin") int begin, + @JsonProperty("end") int end, + @JsonProperty("category") NamedEntity.Category category + ) { this.begin = begin; this.end = end; this.category = category; } - public int getBegin() { - return begin; - } - - public int getEnd() { - return end; - } - - public NamedEntity.Category getCategory() { - return category; - } - } diff --git a/datashare-api/src/main/java/org/icij/datashare/text/nlp/Pipeline.java b/datashare-api/src/main/java/org/icij/datashare/text/nlp/Pipeline.java index 25d215575..e53009765 100644 --- a/datashare-api/src/main/java/org/icij/datashare/text/nlp/Pipeline.java +++ b/datashare-api/src/main/java/org/icij/datashare/text/nlp/Pipeline.java @@ -116,9 +116,9 @@ default List processDoc(Document doc, int contentLength, int conten this.processText(Stream.of(docContent.substring(contentOffset, contentOffset + contentLength)), doc.getLanguage()) .get(0) .forEach(tag -> { - int begin = tag.getBegin() + contentOffset; - int end = tag.getEnd() + contentOffset; - annotations.add(begin, end, tag.getCategory()); + int begin = tag.begin() + contentOffset; + int end = tag.end() + contentOffset; + annotations.add(begin, end, tag.category()); }); return NamedEntity.allFrom(docContent, annotations); } diff --git a/datashare-app/src/main/java/org/icij/datashare/nlp/EmailPipeline.java b/datashare-app/src/main/java/org/icij/datashare/nlp/EmailPipeline.java index 5dc1e804e..bd34d704e 100644 --- 
a/datashare-app/src/main/java/org/icij/datashare/nlp/EmailPipeline.java +++ b/datashare-app/src/main/java/org/icij/datashare/nlp/EmailPipeline.java @@ -106,8 +106,8 @@ public List processDoc(Document doc, int contentLength, int content this.processText(Stream.of(chunkContent), doc.getLanguage()) .get(0) .forEach(t -> { - String mention = chunkContent.substring(t.getBegin(), t.getEnd()); - namedEntitiesBuilder.add(NamedEntity.Category.EMAIL, mention, t.getBegin() + contentOffset); + String mention = chunkContent.substring(t.begin(), t.end()); + namedEntitiesBuilder.add(NamedEntity.Category.EMAIL, mention, t.begin() + contentOffset); }); List entities = namedEntitiesBuilder.build(); if ("message/rfc822".equals(doc.getContentType())) { diff --git a/datashare-app/src/main/java/org/icij/datashare/tasks/ExtractNlpTask.java b/datashare-app/src/main/java/org/icij/datashare/tasks/ExtractNlpTask.java index baaab840c..8287a5026 100644 --- a/datashare-app/src/main/java/org/icij/datashare/tasks/ExtractNlpTask.java +++ b/datashare-app/src/main/java/org/icij/datashare/tasks/ExtractNlpTask.java @@ -135,9 +135,9 @@ private void consumeBatch(List batch, Language language) throws Inter new Annotations(doc.getId(), nlpPipeline.getType(), doc.getLanguage()); int offset = item.offset; chunkTags.forEach(tag -> { - int begin = tag.getBegin() + offset; - int end = tag.getEnd() + offset; - annotations.add(begin, end, tag.getCategory()); + int begin = tag.begin() + offset; + int end = tag.end() + offset; + annotations.add(begin, end, tag.category()); }); List chunkEntities = NamedEntity.allFrom(doc.getContent(), annotations); boolean isComplete = offset + item.text.length() == doc.getContentTextLength();