From 3b0a6420e49863d9fe5908cf6e99582eb2d2882e Mon Sep 17 00:00:00 2001 From: Jimmy Lin Date: Wed, 13 May 2020 14:17:37 -0400 Subject: [PATCH] Untangle SimpleSearcher from SimpleTweetSearcher + verification script (#1189) --- pom.xml | 8 + .../io/anserini/search/SimpleSearcher.java | 215 ++++++++++------- .../anserini/search/SimpleTweetSearcher.java | 226 ++++++++++++++++++ .../search/topicreader/TopicReader.java | 22 +- .../io/anserini/util/DumpAnalyzedQueries.java | 2 +- src/main/python/verify_simplesearcher.py | 76 ++++++ .../search/topicreader/TopicReaderTest.java | 23 +- 7 files changed, 485 insertions(+), 87 deletions(-) create mode 100644 src/main/java/io/anserini/search/SimpleTweetSearcher.java create mode 100644 src/main/python/verify_simplesearcher.py diff --git a/pom.xml b/pom.xml index 39c5db735a..d8296ab62b 100644 --- a/pom.xml +++ b/pom.xml @@ -128,6 +128,14 @@ io.anserini.search.SearchMsmarco SearchMsmarco + + io.anserini.search.SimpleSearcher + SimpleSearcher + + + io.anserini.search.SimpleTweetSearcher + SimpleTweetSearcher + io.anserini.util.DumpAnalyzedQueries DumpAnalyzedQueries diff --git a/src/main/java/io/anserini/search/SimpleSearcher.java b/src/main/java/io/anserini/search/SimpleSearcher.java index fef03c92ea..96057df445 100644 --- a/src/main/java/io/anserini/search/SimpleSearcher.java +++ b/src/main/java/io/anserini/search/SimpleSearcher.java @@ -17,19 +17,18 @@ package io.anserini.search; import io.anserini.analysis.AnalyzerUtils; -import io.anserini.analysis.TweetAnalyzer; import io.anserini.index.IndexArgs; import io.anserini.index.IndexCollection; import io.anserini.index.IndexReaderUtils; -import io.anserini.index.generator.TweetGenerator; import io.anserini.rerank.RerankerCascade; import io.anserini.rerank.RerankerContext; import io.anserini.rerank.ScoredDocuments; import io.anserini.rerank.lib.Rm3Reranker; import io.anserini.rerank.lib.ScoreTiesAdjusterReranker; import io.anserini.search.query.BagOfWordsQueryGenerator; -import 
io.anserini.search.query.PhraseQueryGenerator; import io.anserini.search.query.QueryGenerator; +import io.anserini.search.topicreader.TopicReader; +import org.apache.commons.lang3.time.DurationFormatUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.analysis.Analyzer; @@ -37,29 +36,44 @@ import org.apache.lucene.analysis.bn.BengaliAnalyzer; import org.apache.lucene.analysis.cjk.CJKAnalyzer; import org.apache.lucene.analysis.de.GermanAnalyzer; -import org.apache.lucene.analysis.en.EnglishAnalyzer; import org.apache.lucene.analysis.es.SpanishAnalyzer; import org.apache.lucene.analysis.fr.FrenchAnalyzer; import org.apache.lucene.analysis.hi.HindiAnalyzer; import org.apache.lucene.document.Document; -import org.apache.lucene.document.LongPoint; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexableField; -import org.apache.lucene.search.*; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.BoostQuery; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.SortField; +import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.similarities.BM25Similarity; import org.apache.lucene.search.similarities.LMDirichletSimilarity; import org.apache.lucene.search.similarities.Similarity; import org.apache.lucene.store.FSDirectory; +import org.kohsuke.args4j.CmdLineException; +import org.kohsuke.args4j.CmdLineParser; +import org.kohsuke.args4j.Option; +import org.kohsuke.args4j.OptionHandlerFilter; +import org.kohsuke.args4j.ParserProperties; import java.io.Closeable; import java.io.IOException; +import java.io.PrintWriter; +import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; 
+import java.util.ArrayList; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; +import java.util.SortedMap; import java.util.concurrent.CompletionException; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executors; @@ -69,24 +83,40 @@ /** * Class that exposes basic search functionality, designed specifically to provide the bridge between Java and Python - * via pyjnius. + * via Pyjnius. */ public class SimpleSearcher implements Closeable { public static final Sort BREAK_SCORE_TIES_BY_DOCID = new Sort(SortField.FIELD_SCORE, new SortField(IndexArgs.ID, SortField.Type.STRING_VAL)); - public static final Sort BREAK_SCORE_TIES_BY_TWEETID = - new Sort(SortField.FIELD_SCORE, - new SortField(TweetGenerator.TweetField.ID_LONG.name, SortField.Type.LONG, true)); private static final Logger LOG = LogManager.getLogger(SimpleSearcher.class); - private final IndexReader reader; - private Similarity similarity; - private Analyzer analyzer; - private RerankerCascade cascade; - private boolean searchtweets; - private boolean isRerank; + public static final class Args { + @Option(name = "-index", metaVar = "[path]", required = true, usage = "Path to Lucene index.") + public String index; - private IndexSearcher searcher = null; + @Option(name = "-topics", metaVar = "[file]", required = true, usage = "Topics file.") + public String topics; + + @Option(name = "-output", metaVar = "[file]", required = true, usage = "Output run file.") + public String output; + + @Option(name = "-rm3", usage = "Flag to use RM3.") + public Boolean useRM3 = false; + + @Option(name = "-hits", metaVar = "[number]", usage = "Max number of hits to return.") + public int hits = 1000; + + @Option(name = "-threads", metaVar = "[number]", usage = "Number of threads to use.") + public int threads = 1; + } + + protected IndexReader reader; + protected Similarity similarity; + protected Analyzer analyzer; + protected RerankerCascade cascade; + 
protected boolean isRerank; + + protected IndexSearcher searcher = null; /** * This class is meant to serve as the bridge between Anserini and Pyserini. @@ -98,7 +128,7 @@ public class Result { public float score; public String contents; public String raw; - public Document lucene_document; + public Document lucene_document; // Since this is for Python access, we're using Python naming conventions. public Result(String docid, int lucene_docid, float score, String contents, String raw, Document lucene_document) { this.docid = docid; @@ -110,6 +140,9 @@ public Result(String docid, int lucene_docid, float score, String contents, Stri } } + protected SimpleSearcher() { + } + public SimpleSearcher(String indexDir) throws IOException { this(indexDir, IndexCollection.DEFAULT_ANALYZER); } @@ -121,20 +154,16 @@ public SimpleSearcher(String indexDir, Analyzer analyzer) throws IOException { throw new IllegalArgumentException(indexDir + " does not exist or is not a directory."); } + SearchArgs defaults = new SearchArgs(); + this.reader = DirectoryReader.open(FSDirectory.open(indexPath)); - this.similarity = new BM25Similarity(0.9f, 0.4f); + this.similarity = new BM25Similarity(Float.parseFloat(defaults.bm25_k1[0]), Float.parseFloat(defaults.bm25_b[0])); this.analyzer = analyzer; - this.searchtweets = false; this.isRerank = false; cascade = new RerankerCascade(); cascade.add(new ScoreTiesAdjusterReranker()); } - public void setSearchTweets(boolean flag) { - this.searchtweets = flag; - this.analyzer = flag? 
new TweetAnalyzer(true) : new EnglishAnalyzer(); - } - public void setAnalyzer(Analyzer analyzer) { this.analyzer = analyzer; } @@ -168,7 +197,9 @@ public void unsetRM3Reranker() { } public void setRM3Reranker() { - setRM3Reranker(10, 10, 0.5f, false); + SearchArgs defaults = new SearchArgs(); + + setRM3Reranker(Integer.parseInt(defaults.rm3_fbTerms[0]), 10, 0.5f, false); } public void setRM3Reranker(int fbTerms, int fbDocs, float originalQueryWeight) { @@ -220,20 +251,11 @@ public void close() throws IOException { } public Map batchSearch(List queries, List qids, int k, int threads) { - return batchSearchFields(queries, qids, k, -1, threads, new HashMap<>()); - } - - public Map batchSearch(List queries, List qids, int k, long t, int threads) { - return batchSearchFields(queries, qids, k, t, threads, new HashMap<>()); + return batchSearchFields(queries, qids, k, threads, new HashMap<>()); } public Map batchSearchFields(List queries, List qids, int k, int threads, Map fields) { - return batchSearchFields(queries, qids, k, -1, threads, fields); - } - - public Map batchSearchFields(List queries, List qids, int k, long t, int threads, - Map fields) { // Create the IndexSearcher here, if needed. We do it here because if we leave the creation to the search // method, we might end up with a race condition as multiple threads try to concurrently create the IndexSearcher. 
if (searcher == null) { @@ -253,9 +275,9 @@ public Map batchSearchFields(List queries, List { try { if (fields.size() > 0) { - results.put(qid, searchFields(query, fields, k, t)); + results.put(qid, searchFields(query, fields, k)); } else { - results.put(qid, search(query, k, t)); + results.put(qid, search(query, k)); } } catch (IOException e) { throw new CompletionException(e); @@ -264,7 +286,7 @@ public Map batchSearchFields(List queries, List queryTokens = AnalyzerUtils.analyze(analyzer, q); - return search(query, queryTokens, q, k, t); + return search(query, queryTokens, q, k); } public Result[] search(Query query, int k) throws IOException { - return search(query, null, null, k, -1); + return search(query, null, null, k); } public Result[] search(QueryGenerator generator, String q, int k) throws IOException { Query query = generator.buildQuery(IndexArgs.CONTENTS, analyzer, q); - return search(query, null, null, k, -1); + return search(query, null, null, k); } - protected Result[] search(Query query, List queryTokens, String queryString, int k, - long t) throws IOException { + protected Result[] search(Query query, List queryTokens, String queryString, int k) throws IOException { // Create an IndexSearch only once. Note that the object is thread safe. 
if (searcher == null) { searcher = new IndexSearcher(reader); @@ -328,36 +345,12 @@ protected Result[] search(Query query, List queryTokens, String queryStr SearchArgs searchArgs = new SearchArgs(); searchArgs.arbitraryScoreTieBreak = false; searchArgs.hits = k; - searchArgs.searchtweets = searchtweets; TopDocs rs; RerankerContext context; - if (searchtweets) { - if (t > 0) { - // Do not consider the tweets with tweet ids that are beyond the queryTweetTime - // tag contains the timestamp of the query in terms of the - // chronologically nearest tweet id within the corpus - Query filter = LongPoint.newRangeQuery(TweetGenerator.TweetField.ID_LONG.name, 0L, t); - BooleanQuery.Builder builder = new BooleanQuery.Builder(); - builder.add(filter, BooleanClause.Occur.FILTER); - builder.add(query, BooleanClause.Occur.MUST); - Query compositeQuery = builder.build(); - rs = searcher.search(compositeQuery, isRerank ? searchArgs.rerankcutoff : - k, BREAK_SCORE_TIES_BY_TWEETID, true); - context = new RerankerContext<>(searcher, null, compositeQuery, null, - queryString, queryTokens, filter, searchArgs); - } else { - rs = searcher.search(query, - isRerank ? searchArgs.rerankcutoff : k, BREAK_SCORE_TIES_BY_TWEETID, true); - context = new RerankerContext<>(searcher, null, query, null, - queryString, queryTokens, null, searchArgs); - } - } else { - rs = searcher.search(query, - isRerank ? searchArgs.rerankcutoff : k, BREAK_SCORE_TIES_BY_DOCID, true); - context = new RerankerContext<>(searcher, null, query, null, + rs = searcher.search(query, isRerank ? 
searchArgs.rerankcutoff : k, BREAK_SCORE_TIES_BY_DOCID, true); + context = new RerankerContext<>(searcher, null, query, null, queryString, queryTokens, null, searchArgs); - } ScoredDocuments hits = cascade.run(ScoredDocuments.fromTopDocs(rs, searcher), context); @@ -379,13 +372,9 @@ protected Result[] search(Query query, List queryTokens, String queryStr return results; } - public Result[] searchFields(String q, Map fields, int k) throws IOException { - return searchFields(q, fields, k, -1); - } - // searching both the defaults contents fields and another field with weight boost // this is used for MS MARCO experiments with document expansion. - public Result[] searchFields(String q, Map fields, int k, long t) throws IOException { + public Result[] searchFields(String q, Map fields, int k) throws IOException { IndexSearcher searcher = new IndexSearcher(reader); searcher.setSimilarity(similarity); @@ -401,7 +390,7 @@ public Result[] searchFields(String q, Map fields, int k, long t) BooleanQuery query = queryBuilder.build(); List queryTokens = AnalyzerUtils.analyze(analyzer, q); - return search(query, queryTokens, q, k, -1); + return search(query, queryTokens, q, k); } /** @@ -505,4 +494,70 @@ public String documentRaw(String docid) { return IndexReaderUtils.documentRaw(reader, docid); } + // Note that this class is primarily meant to be used by automated regression scripts, not humans! + // tl;dr - Do not use this class for running experiments. Use SearchCollection instead! + // + // SimpleSearcher is the main class that exposes search functionality for Pyserini (in Python). + // As such, it has a different code path than SearchCollection, the preferred entry point for running experiments + // from Java. The main method here exposes only barebone options, primarily designed to verify that results from + // SimpleSearcher are *exactly* the same as SearchCollection (e.g., via automated regression scripts). 
+ public static void main(String[] args) throws Exception { + Args searchArgs = new Args(); + CmdLineParser parser = new CmdLineParser(searchArgs, ParserProperties.defaults().withUsageWidth(100)); + + try { + parser.parseArgument(args); + } catch (CmdLineException e) { + System.err.println(e.getMessage()); + parser.printUsage(System.err); + System.err.println("Example: SimpleSearcher" + parser.printExample(OptionHandlerFilter.REQUIRED)); + return; + } + + final long start = System.nanoTime(); + SimpleSearcher searcher = new SimpleSearcher(searchArgs.index); + SortedMap> topics = TopicReader.getTopicsByFile(searchArgs.topics); + + PrintWriter out = new PrintWriter(Files.newBufferedWriter(Paths.get(searchArgs.output), StandardCharsets.US_ASCII)); + + if (searchArgs.useRM3) { + searcher.setRM3Reranker(); + } + + if (searchArgs.threads == 1) { + for (Object id : topics.keySet()) { + Result[] results = searcher.search(topics.get(id).get("title"), searchArgs.hits); + + for (int i = 0; i < results.length; i++) { + out.println(String.format(Locale.US, "%s Q0 %s %d %f Anserini", + id, results[i].docid, (i + 1), results[i].score)); + } + } + } else { + List qids = new ArrayList<>(); + List queries = new ArrayList<>(); + + for (Object id : topics.keySet()) { + qids.add(id.toString()); + queries.add(topics.get(id).get("title")); + } + + Map allResults = searcher.batchSearch(queries, qids, searchArgs.hits, searchArgs.threads); + + // We iterate through, in natural object order. 
+ for (Object id : topics.keySet()) { + Result[] results = allResults.get(id.toString()); + + for (int i = 0; i < results.length; i++) { + out.println(String.format(Locale.US, "%s Q0 %s %d %f Anserini", + id, results[i].docid, (i + 1), results[i].score)); + } + } + } + + out.close(); + + final long durationMillis = TimeUnit.MILLISECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS); + LOG.info("Total run time: " + DurationFormatUtils.formatDuration(durationMillis, "HH:mm:ss")); + } } diff --git a/src/main/java/io/anserini/search/SimpleTweetSearcher.java b/src/main/java/io/anserini/search/SimpleTweetSearcher.java new file mode 100644 index 0000000000..ad2c8a2244 --- /dev/null +++ b/src/main/java/io/anserini/search/SimpleTweetSearcher.java @@ -0,0 +1,226 @@ +/* + * Anserini: A Lucene toolkit for replicable information retrieval research + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.anserini.search; + +import io.anserini.analysis.AnalyzerUtils; +import io.anserini.analysis.TweetAnalyzer; +import io.anserini.index.IndexArgs; +import io.anserini.index.IndexCollection; +import io.anserini.index.IndexReaderUtils; +import io.anserini.index.generator.TweetGenerator; +import io.anserini.rerank.RerankerCascade; +import io.anserini.rerank.RerankerContext; +import io.anserini.rerank.ScoredDocuments; +import io.anserini.rerank.lib.Rm3Reranker; +import io.anserini.rerank.lib.ScoreTiesAdjusterReranker; +import io.anserini.search.query.BagOfWordsQueryGenerator; +import io.anserini.search.query.QueryGenerator; +import io.anserini.search.topicreader.TopicReader; +import org.apache.commons.lang3.time.DurationFormatUtils; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.analysis.ar.ArabicAnalyzer; +import org.apache.lucene.analysis.bn.BengaliAnalyzer; +import org.apache.lucene.analysis.cjk.CJKAnalyzer; +import org.apache.lucene.analysis.de.GermanAnalyzer; +import org.apache.lucene.analysis.en.EnglishAnalyzer; +import org.apache.lucene.analysis.es.SpanishAnalyzer; +import org.apache.lucene.analysis.fr.FrenchAnalyzer; +import org.apache.lucene.analysis.hi.HindiAnalyzer; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.LongPoint; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexableField; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.BoostQuery; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.SortField; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.similarities.BM25Similarity; +import 
org.apache.lucene.search.similarities.LMDirichletSimilarity; +import org.apache.lucene.search.similarities.Similarity; +import org.apache.lucene.store.FSDirectory; +import org.kohsuke.args4j.CmdLineException; +import org.kohsuke.args4j.CmdLineParser; +import org.kohsuke.args4j.Option; +import org.kohsuke.args4j.OptionHandlerFilter; +import org.kohsuke.args4j.ParserProperties; + +import java.io.Closeable; +import java.io.IOException; +import java.io.PrintWriter; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.SortedMap; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; + +/** + * Class that exposes basic search functionality, designed specifically to provide the bridge between Java and Python + * via pyjnius. 
+ */ +public class SimpleTweetSearcher extends SimpleSearcher implements Closeable { + public static final Sort BREAK_SCORE_TIES_BY_TWEETID = + new Sort(SortField.FIELD_SCORE, + new SortField(TweetGenerator.TweetField.ID_LONG.name, SortField.Type.LONG, true)); + private static final Logger LOG = LogManager.getLogger(SimpleTweetSearcher.class); + + public static final class Args { + @Option(name = "-index", metaVar = "[path]", required = true, usage = "Path to Lucene index.") + public String index; + + @Option(name = "-topics", metaVar = "[file]", required = true, usage = "Topics file.") + public String topics; + + @Option(name = "-output", metaVar = "[file]", required = true, usage = "Output run file.") + public String output; + + @Option(name = "-rm3", usage = "Flag to use RM3.") + public Boolean useRM3 = false; + + @Option(name = "-hits", metaVar = "[number]", usage = "max number of hits to return") + public int hits = 1000; + } + + protected SimpleTweetSearcher() { + } + + public SimpleTweetSearcher(String indexDir) throws IOException { + super(indexDir, new TweetAnalyzer()); + } + + @Override + public void close() throws IOException { + reader.close(); + } + + public Result[] searchTweets(String q, int k, long t) throws IOException { + Query query = new BagOfWordsQueryGenerator().buildQuery(IndexArgs.CONTENTS, analyzer, q); + List queryTokens = AnalyzerUtils.analyze(analyzer, q); + + return searchTweets(query, queryTokens, q, k, t); + } + + protected Result[] searchTweets(Query query, List queryTokens, String queryString, int k, long t) + throws IOException { + // Create an IndexSearch only once. Note that the object is thread safe. 
+ if (searcher == null) { + searcher = new IndexSearcher(reader); + searcher.setSimilarity(similarity); + } + + SearchArgs searchArgs = new SearchArgs(); + searchArgs.arbitraryScoreTieBreak = false; + searchArgs.hits = k; + searchArgs.searchtweets = true; + + TopDocs rs; + RerankerContext context; + + // Do not consider the tweets with tweet ids that are beyond the queryTweetTime + // tag contains the timestamp of the query in terms of the + // chronologically nearest tweet id within the corpus + Query filter = LongPoint.newRangeQuery(TweetGenerator.TweetField.ID_LONG.name, 0L, t); + BooleanQuery.Builder builder = new BooleanQuery.Builder(); + builder.add(filter, BooleanClause.Occur.FILTER); + builder.add(query, BooleanClause.Occur.MUST); + Query compositeQuery = builder.build(); + rs = searcher.search(compositeQuery, isRerank ? searchArgs.rerankcutoff : + k, BREAK_SCORE_TIES_BY_TWEETID, true); + context = new RerankerContext<>(searcher, null, compositeQuery, null, + queryString, queryTokens, filter, searchArgs); + + ScoredDocuments hits = cascade.run(ScoredDocuments.fromTopDocs(rs, searcher), context); + + Result[] results = new Result[hits.ids.length]; + for (int i = 0; i < hits.ids.length; i++) { + Document doc = hits.documents[i]; + String docid = doc.getField(IndexArgs.ID).stringValue(); + + IndexableField field; + field = doc.getField(IndexArgs.CONTENTS); + String contents = field == null ? null : field.stringValue(); + + field = doc.getField(IndexArgs.RAW); + String raw = field == null ? null : field.stringValue(); + + results[i] = new Result(docid, hits.ids[i], hits.scores[i], contents, raw, doc); + } + + return results; + } + + // Note that this class is primarily meant to be used by automated regression scripts, not humans! + // tl;dr - Do not use this class for running experiments. Use SearchCollection instead! + // + // SimpleTweetSearcher is the main class that exposes search functionality for Pyserini (in Python). 
+ // As such, it has a different code path than SearchCollection, the preferred entry point for running experiments + // from Java. The main method here exposes only barebone options, primarily designed to verify that results from + // SimpleSearcher are *exactly* the same as SearchCollection (e.g., via automated regression scripts). + public static void main(String[] args) throws Exception { + Args searchArgs = new Args(); + CmdLineParser parser = new CmdLineParser(searchArgs, ParserProperties.defaults().withUsageWidth(100)); + + try { + parser.parseArgument(args); + } catch (CmdLineException e) { + System.err.println(e.getMessage()); + parser.printUsage(System.err); + System.err.println("Example: SimpleTweetSearcher" + parser.printExample(OptionHandlerFilter.REQUIRED)); + return; + } + + final long start = System.nanoTime(); + SimpleTweetSearcher searcher = new SimpleTweetSearcher(searchArgs.index); + SortedMap> topics = TopicReader.getTopicsByFile(searchArgs.topics); + + PrintWriter out = new PrintWriter(Files.newBufferedWriter(Paths.get(searchArgs.output), StandardCharsets.US_ASCII)); + + if (searchArgs.useRM3) { + searcher.setRM3Reranker(); + } + + for (Object id : topics.keySet()) { + long t = Long.parseLong(topics.get(id).get("time")); + Result[] results = searcher.searchTweets(topics.get(id).get("title"), 1000, t); + + for (int i=0; i { * @param file topics file * @return the {@link TopicReader} class corresponding to a known topics file, or null if unknown. */ - public static Class getTopicReaderByFile(String file) { + public static Class getTopicReaderClassByFile(String file) { // If we're given something that looks like a path with directories, pull out only the file name at the end. if (file.contains("/")) { String[] parts = file.split("/"); @@ -146,6 +146,26 @@ public static SortedMap> getTopics(Topics topics) { } } + /** + * Returns evaluation topics, automatically trying to infer its type and format. 
+ * + * @param file topics file + * @param type of topic id + * @return a set of evaluation topics + */ + @SuppressWarnings("unchecked") + public static SortedMap> getTopicsByFile(String file) { + try { + // Get the constructor + Constructor[] ctors = getTopicReaderClassByFile(file).getDeclaredConstructors(); + // The one we want is always the zero-th one; pass in a dummy Path. + TopicReader reader = (TopicReader) ctors[0].newInstance(Paths.get(file)); + return reader.read(); + } catch (Exception e) { + return null; + } + } + /** * Returns a standard set of evaluation topics, with strings as topic ids. This method is * primarily meant for calling from Python via Pyjnius. The conversion to string topic ids diff --git a/src/main/java/io/anserini/util/DumpAnalyzedQueries.java b/src/main/java/io/anserini/util/DumpAnalyzedQueries.java index 916ac7055f..535f0d4c85 100644 --- a/src/main/java/io/anserini/util/DumpAnalyzedQueries.java +++ b/src/main/java/io/anserini/util/DumpAnalyzedQueries.java @@ -66,7 +66,7 @@ public static void main(String[] argv) throws IOException { TopicReader tr; try { // Can we infer the TopicReader? - Class clazz = TopicReader.getTopicReaderByFile(args.topicsFile.toString()); + Class clazz = TopicReader.getTopicReaderClassByFile(args.topicsFile.toString()); if (clazz != null) { System.out.println(String.format("Inferring %s has TopicReader class %s.", args.topicsFile, clazz)); } else { diff --git a/src/main/python/verify_simplesearcher.py b/src/main/python/verify_simplesearcher.py new file mode 100644 index 0000000000..28b6405c62 --- /dev/null +++ b/src/main/python/verify_simplesearcher.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- +''' +Anserini: A Lucene toolkit for replicable information retrieval research + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +''' + +import logging +import regression_utils + +logger = logging.getLogger('run_es_regression') +ch = logging.StreamHandler() +ch.setFormatter(logging.Formatter('%(asctime)s %(levelname)s - %(message)s')) +logger.addHandler(ch) +logger.setLevel(logging.INFO) + +sc_command = 'target/appassembler/bin/SearchCollection ' +ss_command = 'target/appassembler/bin/SimpleSearcher ' +st_command = 'target/appassembler/bin/SimpleTweetSearcher' +robust04_index = 'indexes/lucene-index.robust04.pos+docvectors+raw' +robust04_topics = 'src/main/resources/topics-and-qrels/topics.robust04.txt' +mb11_index = 'indexes/lucene-index.mb11.pos+docvectors+raw' +mb11_topics = 'src/main/resources/topics-and-qrels/topics.microblog2011.txt' + +# Raw data for verification: array of arrays +# Each top level array contains commands whose outputs should be *identical* +# Note that spacing is intentional to make command easy to read. 
+groups = [ [ f'{sc_command} -index {robust04_index} -topicreader Trec -topics {robust04_topics} -bm25 -output', \ + f'{ss_command} -index {robust04_index} -topics {robust04_topics} -output', \ + f'{ss_command} -index {robust04_index} -topics {robust04_topics} -threads 4 -output' ], \ + [ f'{sc_command} -index {robust04_index} -topicreader Trec -topics {robust04_topics} -bm25 -rm3 -output', \ + f'{ss_command} -index {robust04_index} -topics {robust04_topics} -rm3 -output', \ + f'{ss_command} -index {robust04_index} -topics {robust04_topics} -rm3 -threads 4 -output' ], \ + [ f'{sc_command} -index {mb11_index} -topicreader Microblog -topics {mb11_topics} -bm25 -searchtweets -output', \ + f'{st_command} -index {mb11_index} -topics {mb11_topics} -output'], \ + [ f'{sc_command} -index {mb11_index} -topicreader Microblog -topics {mb11_topics} -bm25 -rm3 -searchtweets -output', \ + f'{st_command} -index {mb11_index} -topics {mb11_topics} -rm3 -output'], \ + ] + +if __name__ == '__main__': + group_cnt = 0 + for group in groups: + print(f'# Verifying Group {group_cnt}') + entry_cnt = 0 + group_runs = [] + for entry in group: + run_file = f'runs/run.ss_verify.g{group_cnt}.e{entry_cnt}.txt' + cmd = f'{entry} {run_file}' + print(f'Running: {cmd}') + regression_utils.run_shell_command(cmd, logger, echo=False) + + # Load in the run file. + with open(run_file, 'r') as file: + group_runs.append(file.read().replace('\n', '')) + + entry_cnt += 1 + + # Check that all run files are identical. 
+ for i in range(len(group_runs)): + if group_runs[0] != group_runs[i]: + raise ValueError(f'Group {group_cnt}: Results are not identical!') + + print(f'# Group {group_cnt}: Results identical') + group_cnt += 1 + + print('All tests passed!') diff --git a/src/test/java/io/anserini/search/topicreader/TopicReaderTest.java b/src/test/java/io/anserini/search/topicreader/TopicReaderTest.java index d39a8ee107..80c1959e2f 100644 --- a/src/test/java/io/anserini/search/topicreader/TopicReaderTest.java +++ b/src/test/java/io/anserini/search/topicreader/TopicReaderTest.java @@ -29,18 +29,31 @@ public class TopicReaderTest { @Test public void testTopicReaderClassLookup() { assertEquals(TrecTopicReader.class, - TopicReader.getTopicReaderByFile("src/main/resources/topics-and-qrels/topics.robust04.txt")); + TopicReader.getTopicReaderClassByFile("src/main/resources/topics-and-qrels/topics.robust04.txt")); assertEquals(TrecTopicReader.class, - TopicReader.getTopicReaderByFile("topics.robust04.txt")); + TopicReader.getTopicReaderClassByFile("topics.robust04.txt")); assertEquals(CovidTopicReader.class, - TopicReader.getTopicReaderByFile("src/main/resources/topics-and-qrels/topics.covid-round1.xml")); + TopicReader.getTopicReaderClassByFile("src/main/resources/topics-and-qrels/topics.covid-round1.xml")); assertEquals(CovidTopicReader.class, - TopicReader.getTopicReaderByFile("topics.covid-round1.xml")); + TopicReader.getTopicReaderClassByFile("topics.covid-round1.xml")); // Unknown TopicReader class. 
assertEquals(null, - TopicReader.getTopicReaderByFile("topics.unknown.txt")); + TopicReader.getTopicReaderClassByFile("topics.unknown.txt")); + } + + @Test + public void testGetTopicsByFile() { + SortedMap> topics = + TopicReader.getTopicsByFile("src/main/resources/topics-and-qrels/topics.robust04.txt"); + + assertNotNull(topics); + assertEquals(250, topics.size()); + assertEquals(301, (int) topics.firstKey()); + assertEquals("International Organized Crime", topics.get(topics.firstKey()).get("title")); + assertEquals(700, (int) topics.lastKey()); + assertEquals("gasoline tax U.S.", topics.get(topics.lastKey()).get("title")); } @Test