From ae498aa997155945ff24adb45454183e5afe93d9 Mon Sep 17 00:00:00 2001 From: Jimmy Lin Date: Wed, 22 Nov 2023 13:07:06 -0500 Subject: [PATCH] Refactoring of HNSW and InvertedDense searching, closing #2267 (#2268) --- README.md | 55 +-- .../io/anserini/search/SearchCollection.java | 117 +++--- .../search/SearchHnswDenseVectors.java | 337 ++++++------------ .../search/SearchInvertedDenseVectors.java | 278 ++++----------- .../dl19-passage-cos-dpr-distil-fw.yaml | 2 +- .../dl19-passage-cos-dpr-distil-lexlsh.yaml | 2 +- .../dl20-passage-cos-dpr-distil-fw.yaml | 2 +- .../dl20-passage-cos-dpr-distil-lexlsh.yaml | 2 +- .../msmarco-passage-cos-dpr-distil-fw.yaml | 2 +- ...msmarco-passage-cos-dpr-distil-lexlsh.yaml | 2 +- 10 files changed, 285 insertions(+), 514 deletions(-) diff --git a/README.md b/README.md index be5703721c..4e4c435da4 100644 --- a/README.md +++ b/README.md @@ -67,33 +67,34 @@ See individual pages for details! ### MS MARCO V1 Passage Regressions -| | dev | DL19 | DL20 | -|---------------------------------------------|:------------------------------------------------------------------------------------:|:---------------------------------------------------------------------------------:|:---------------------------------------------------------------------------------:| -| **Unsupervised Sparse Lexical** | | | | -| BoW baselines | [+](docs/regressions/regressions-msmarco-passage.md) | [+](docs/regressions/regressions-dl19-passage.md) | [+](docs/regressions/regressions-dl20-passage.md) | -| Quantized BM25 | [✓](docs/regressions/regressions-msmarco-passage-bm25-b8.md) | [✓](docs/regressions/regressions-dl19-passage-bm25-b8.md) | [✓](docs/regressions/regressions-dl20-passage-bm25-b8.md) | -| WP baselines | [+](docs/regressions/regressions-msmarco-passage-wp.md) | [+](docs/regressions/regressions-dl19-passage-wp.md) | [+](docs/regressions/regressions-dl20-passage-wp.md) | -| Huggingface WP baselines | [+](docs/regressions/regressions-msmarco-passage-hgf-wp.md) | 
[+](docs/regressions/regressions-dl19-passage-hgf-wp.md) | [+](docs/regressions/regressions-dl20-passage-hgf-wp.md) | -| doc2query | [+](docs/regressions/regressions-msmarco-passage-doc2query.md) | | | -| doc2query-T5 | [+](docs/regressions/regressions-msmarco-passage-docTTTTTquery.md) | [+](docs/regressions/regressions-dl19-passage-docTTTTTquery.md) | [+](docs/regressions/regressions-dl20-passage-docTTTTTquery.md) | -| **Learned Sparse Lexical (uniCOIL family)** | | | | -| uniCOIL noexp | [✓](docs/regressions/regressions-msmarco-passage-unicoil-noexp.md) | [✓](docs/regressions/regressions-dl19-passage-unicoil-noexp.md) | [✓](docs/regressions/regressions-dl20-passage-unicoil-noexp.md) | -| uniCOIL with doc2query-T5 | [✓](docs/regressions/regressions-msmarco-passage-unicoil.md) | [✓](docs/regressions/regressions-dl19-passage-unicoil.md) | [✓](docs/regressions/regressions-dl20-passage-unicoil.md) | -| uniCOIL with TILDE | [✓](docs/regressions/regressions-msmarco-passage-unicoil-tilde-expansion.md) | | | -| **Learned Sparse Lexical (other)** | | | | -| DeepImpact | [✓](docs/regressions/regressions-msmarco-passage-deepimpact.md) | | | -| SPLADEv2 | [✓](docs/regressions/regressions-msmarco-passage-distill-splade-max.md) | | | -| SPLADE-distill CoCodenser-medium | [✓](docs/regressions/regressions-msmarco-passage-splade-distil-cocodenser-medium.md) | [✓](docs/regressions/regressions-dl19-passage-splade-distil-cocodenser-medium.md) | [✓](docs/regressions/regressions-dl20-passage-splade-distil-cocodenser-medium.md) | -| SPLADE++ CoCondenser-EnsembleDistil | [✓](docs/regressions/regressions-msmarco-passage-splade-pp-ed.md) | [✓](docs/regressions/regressions-dl19-passage-splade-pp-ed.md) | [✓](docs/regressions/regressions-dl20-passage-splade-pp-ed.md) | -| SPLADE++ CoCondenser-EnsembleDistil (ONNX) | [✓](docs/regressions/regressions-msmarco-passage-splade-pp-ed-onnx.md) | [✓](docs/regressions/regressions-dl19-passage-splade-pp-ed-onnx.md) | 
[✓](docs/regressions/regressions-dl20-passage-splade-pp-ed-onnx.md) | -| SPLADE++ CoCondenser-SelfDistil | [✓](docs/regressions/regressions-msmarco-passage-splade-pp-sd.md) | [✓](docs/regressions/regressions-dl19-passage-splade-pp-sd.md) | [✓](docs/regressions/regressions-dl20-passage-splade-pp-sd.md) | -| SPLADE++ CoCondenser-SelfDistil (ONNX) | [✓](docs/regressions/regressions-msmarco-passage-splade-pp-sd-onnx.md) | [✓](docs/regressions/regressions-dl19-passage-splade-pp-sd-onnx.md) | [✓](docs/regressions/regressions-dl20-passage-splade-pp-sd-onnx.md) | -| **Learned Dense** | | | | -| cosDPR-distil w/ HNSW | [✓](docs/regressions/regressions-msmarco-passage-cos-dpr-distil-hnsw.md) | [✓](docs/regressions/regressions-dl19-passage-cos-dpr-distil-hnsw.md) | [✓](docs/regressions/regressions-dl20-passage-cos-dpr-distil-hnsw.md) | -| cosDPR-distil w/ HSNW (ONNX) | [✓](docs/regressions/regressions-msmarco-passage-cos-dpr-distil-hnsw-onnx.md) | [✓](docs/regressions/regressions-dl19-passage-cos-dpr-distil-hnsw-onnx.md) | [✓](docs/regressions/regressions-dl20-passage-cos-dpr-distil-hnsw-onnx.md) | -| cosDPR-distil w/ "fake words" | [✓](docs/regressions/regressions-msmarco-passage-cos-dpr-distil-fw.md) | [✓](docs/regressions/regressions-dl19-passage-cos-dpr-distil-fw.md) | [✓](docs/regressions/regressions-dl20-passage-cos-dpr-distil-fw.md) | -| cosDPR-distil w/ "LexLSH" | [✓](docs/regressions/regressions-msmarco-passage-cos-dpr-distil-lexlsh.md) | [✓](docs/regressions/regressions-dl19-passage-cos-dpr-distil-lexlsh.md) | [✓](docs/regressions/regressions-dl20-passage-cos-dpr-distil-lexlsh.md) | -| OpenAI-ada2 w/ HNSW | [✓](docs/regressions/regressions-msmarco-passage-openai-ada2.md) | [✓](docs/regressions/regressions-dl19-passage-openai-ada2.md) | [✓](docs/regressions/regressions-dl20-passage-openai-ada2.md) | +| | dev | DL19 | DL20 | 
+|--------------------------------------------|:------------------------------------------------------------------------------------:|:---------------------------------------------------------------------------------:|:---------------------------------------------------------------------------------:| +| **Unsupervised Sparse** | | | | +| BoW baselines | [+](docs/regressions/regressions-msmarco-passage.md) | [+](docs/regressions/regressions-dl19-passage.md) | [+](docs/regressions/regressions-dl20-passage.md) | +| Quantized BM25 | [✓](docs/regressions/regressions-msmarco-passage-bm25-b8.md) | [✓](docs/regressions/regressions-dl19-passage-bm25-b8.md) | [✓](docs/regressions/regressions-dl20-passage-bm25-b8.md) | +| WP baselines | [+](docs/regressions/regressions-msmarco-passage-wp.md) | [+](docs/regressions/regressions-dl19-passage-wp.md) | [+](docs/regressions/regressions-dl20-passage-wp.md) | +| Huggingface WP baselines | [+](docs/regressions/regressions-msmarco-passage-hgf-wp.md) | [+](docs/regressions/regressions-dl19-passage-hgf-wp.md) | [+](docs/regressions/regressions-dl20-passage-hgf-wp.md) | +| doc2query | [+](docs/regressions/regressions-msmarco-passage-doc2query.md) | | | +| doc2query-T5 | [+](docs/regressions/regressions-msmarco-passage-docTTTTTquery.md) | [+](docs/regressions/regressions-dl19-passage-docTTTTTquery.md) | [+](docs/regressions/regressions-dl20-passage-docTTTTTquery.md) | +| **Learned Sparse (uniCOIL family)** | | | | +| uniCOIL noexp | [✓](docs/regressions/regressions-msmarco-passage-unicoil-noexp.md) | [✓](docs/regressions/regressions-dl19-passage-unicoil-noexp.md) | [✓](docs/regressions/regressions-dl20-passage-unicoil-noexp.md) | +| uniCOIL with doc2query-T5 | [✓](docs/regressions/regressions-msmarco-passage-unicoil.md) | [✓](docs/regressions/regressions-dl19-passage-unicoil.md) | [✓](docs/regressions/regressions-dl20-passage-unicoil.md) | +| uniCOIL with TILDE | [✓](docs/regressions/regressions-msmarco-passage-unicoil-tilde-expansion.md) 
| | | +| **Learned Sparse (other)** | | | | +| DeepImpact | [✓](docs/regressions/regressions-msmarco-passage-deepimpact.md) | | | +| SPLADEv2 | [✓](docs/regressions/regressions-msmarco-passage-distill-splade-max.md) | | | +| SPLADE-distill CoCodenser-medium | [✓](docs/regressions/regressions-msmarco-passage-splade-distil-cocodenser-medium.md) | [✓](docs/regressions/regressions-dl19-passage-splade-distil-cocodenser-medium.md) | [✓](docs/regressions/regressions-dl20-passage-splade-distil-cocodenser-medium.md) | +| SPLADE++ CoCondenser-EnsembleDistil | [✓](docs/regressions/regressions-msmarco-passage-splade-pp-ed.md) | [✓](docs/regressions/regressions-dl19-passage-splade-pp-ed.md) | [✓](docs/regressions/regressions-dl20-passage-splade-pp-ed.md) | +| SPLADE++ CoCondenser-EnsembleDistil (ONNX) | [✓](docs/regressions/regressions-msmarco-passage-splade-pp-ed-onnx.md) | [✓](docs/regressions/regressions-dl19-passage-splade-pp-ed-onnx.md) | [✓](docs/regressions/regressions-dl20-passage-splade-pp-ed-onnx.md) | +| SPLADE++ CoCondenser-SelfDistil | [✓](docs/regressions/regressions-msmarco-passage-splade-pp-sd.md) | [✓](docs/regressions/regressions-dl19-passage-splade-pp-sd.md) | [✓](docs/regressions/regressions-dl20-passage-splade-pp-sd.md) | +| SPLADE++ CoCondenser-SelfDistil (ONNX) | [✓](docs/regressions/regressions-msmarco-passage-splade-pp-sd-onnx.md) | [✓](docs/regressions/regressions-dl19-passage-splade-pp-sd-onnx.md) | [✓](docs/regressions/regressions-dl20-passage-splade-pp-sd-onnx.md) | +| **Learned Dense** (HNSW) | | | | +| cosDPR-distil w/ HNSW | [✓](docs/regressions/regressions-msmarco-passage-cos-dpr-distil-hnsw.md) | [✓](docs/regressions/regressions-dl19-passage-cos-dpr-distil-hnsw.md) | [✓](docs/regressions/regressions-dl20-passage-cos-dpr-distil-hnsw.md) | +| cosDPR-distil w/ HNSW (ONNX) | [✓](docs/regressions/regressions-msmarco-passage-cos-dpr-distil-hnsw-onnx.md) | [✓](docs/regressions/regressions-dl19-passage-cos-dpr-distil-hnsw-onnx.md) | 
[✓](docs/regressions/regressions-dl20-passage-cos-dpr-distil-hnsw-onnx.md) | +| OpenAI-ada2 w/ HNSW | [✓](docs/regressions/regressions-msmarco-passage-openai-ada2.md) | [✓](docs/regressions/regressions-dl19-passage-openai-ada2.md) | [✓](docs/regressions/regressions-dl20-passage-openai-ada2.md) | +| **Learned Dense** (Inverted; experimental) | | | | +| cosDPR-distil w/ "fake words" | [✓](docs/regressions/regressions-msmarco-passage-cos-dpr-distil-fw.md) | [✓](docs/regressions/regressions-dl19-passage-cos-dpr-distil-fw.md) | [✓](docs/regressions/regressions-dl20-passage-cos-dpr-distil-fw.md) | +| cosDPR-distil w/ "LexLSH" | [✓](docs/regressions/regressions-msmarco-passage-cos-dpr-distil-lexlsh.md) | [✓](docs/regressions/regressions-dl19-passage-cos-dpr-distil-lexlsh.md) | [✓](docs/regressions/regressions-dl20-passage-cos-dpr-distil-lexlsh.md) | ### Available Corpora for Download diff --git a/src/main/java/io/anserini/search/SearchCollection.java b/src/main/java/io/anserini/search/SearchCollection.java index 599e40b69e..349e71e002 100644 --- a/src/main/java/io/anserini/search/SearchCollection.java +++ b/src/main/java/io/anserini/search/SearchCollection.java @@ -713,6 +713,68 @@ private Analyzer getAnalyzer() { private Map qrels; private Set queriesWithRel; + public static String generateRunOutput(ScoredDocuments docs, + K qid, + String format, + String runtag, + boolean removedups, + boolean removeQuery, + boolean selectMaxPassage, + String selectMaxPassage_delimiter, + int selectMaxPassage_hits) { + StringBuilder out = new StringBuilder(); + // For removing duplicate docids. + Set docids = new HashSet<>(); + + int rank = 1; + for (int i = 0; i < docs.documents.length; i++) { + String docid = docs.documents[i].get(Constants.ID); + + if (selectMaxPassage) { + docid = docid.split(selectMaxPassage_delimiter)[0]; + } + + if (docids.contains(docid)) + continue; + + // Remove docids that are identical to the query id if flag is set. 
+ if (removeQuery && docid.equals(qid)) + continue; + + if ("msmarco".equals(format)) { + // MS MARCO output format: + out.append(String.format(Locale.US, "%s\t%s\t%d\n", qid, docid, rank)); + } else { + // Standard TREC format: + // + the first column is the topic number. + // + the second column is currently unused and should always be "Q0". + // + the third column is the official document identifier of the retrieved document. + // + the fourth column is the rank the document is retrieved. + // + the fifth column shows the score (integer or floating point) that generated the ranking. + // + the sixth column is called the "run tag" and should be a unique identifier for your run. + out.append(String.format(Locale.US, "%s Q0 %s %d %f %s\n", + qid, docid, rank, docs.scores[i], runtag)); + } + + // Note that this option is set to false by default because duplicate documents usually indicate some + // underlying indexing issues, and we don't want to just eat errors silently. + // + // However, when we're performing passage retrieval, i.e., with "selectMaxPassage", we *do* want to remove + // duplicates. + if (removedups || selectMaxPassage) { + docids.add(docid); + } + + rank++; + + if (selectMaxPassage && rank > selectMaxPassage_hits) { + break; + } + } + + return out.toString(); + } + private final class SearcherThread extends Thread { final private IndexReader reader; final private IndexSearcher searcher; @@ -767,9 +829,6 @@ public void run() { // This is the per-query execution, in parallel. executor.execute(() -> { - // This is for holding the results. - StringBuilder out = new StringBuilder(); - String queryString = ""; if (args.topicField.contains("+")) { for (String field : args.topicField.split("\\+")) { @@ -811,56 +870,10 @@ public void run() { throw new CompletionException(e); } - // For removing duplicate docids. 
- Set docids = new HashSet<>(); - - int rank = 1; - for (int i = 0; i < docs.documents.length; i++) { - String docid = docs.documents[i].get(Constants.ID); - - if (args.selectMaxPassage) { - docid = docid.split(args.selectMaxPassage_delimiter)[0]; - } - - if (docids.contains(docid)) - continue; - - // Remove docids that are identical to the query id if flag is set. - if (args.removeQuery && docid.equals(qid)) - continue; - - if ("msmarco".equals(args.format)) { - // MS MARCO output format: - out.append(String.format(Locale.US, "%s\t%s\t%d\n", qid, docid, rank)); - } else { - // Standard TREC format: - // + the first column is the topic number. - // + the second column is currently unused and should always be "Q0". - // + the third column is the official document identifier of the retrieved document. - // + the fourth column is the rank the document is retrieved. - // + the fifth column shows the score (integer or floating point) that generated the ranking. - // + the sixth column is called the "run tag" and should be a unique identifier for your - out.append(String.format(Locale.US, "%s Q0 %s %d %f %s\n", - qid, docid, rank, docs.scores[i], runTag)); - } - - // Note that this option is set to false by default because duplicate documents usually indicate some - // underlying indexing issues, and we don't want to just eat errors silently. - // - // However, we we're performing passage retrieval, i.e., with "selectMaxSegment", we *do* want to remove - // duplicates. 
- if (args.removedups || args.selectMaxPassage) { - docids.add(docid); - } - - rank++; - - if (args.selectMaxPassage && rank > args.selectMaxPassage_hits) { - break; - } - } + String runOutput = generateRunOutput(docs, qid, args.format, runTag, args.removedups, args.removeQuery, + args.selectMaxPassage, args.selectMaxPassage_delimiter, args.selectMaxPassage_hits); - results.put(qid, out.toString()); + results.put(qid, runOutput); int n = cnt.incrementAndGet(); if (n % 100 == 0) { LOG.info(String.format("%s: %d queries processed", desc, n)); diff --git a/src/main/java/io/anserini/search/SearchHnswDenseVectors.java b/src/main/java/io/anserini/search/SearchHnswDenseVectors.java index 69fbc586bc..bea6efa156 100644 --- a/src/main/java/io/anserini/search/SearchHnswDenseVectors.java +++ b/src/main/java/io/anserini/search/SearchHnswDenseVectors.java @@ -16,10 +16,9 @@ package io.anserini.search; +import ai.onnxruntime.OrtException; import io.anserini.encoder.dense.DenseEncoder; -import io.anserini.encoder.sparse.SparseEncoder; import io.anserini.index.Constants; -import io.anserini.index.IndexHnswDenseVectors; import io.anserini.rerank.ScoredDocuments; import io.anserini.search.query.VectorQueryGenerator; import io.anserini.search.topicreader.TopicReader; @@ -30,34 +29,25 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.KnnFloatVectorQuery; -import org.apache.lucene.search.KnnVectorQuery; -import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.TotalHits; import org.apache.lucene.store.FSDirectory; -import org.apache.lucene.store.MMapDirectory; import org.kohsuke.args4j.CmdLineException; import org.kohsuke.args4j.CmdLineParser; import org.kohsuke.args4j.Option; import org.kohsuke.args4j.OptionHandlerFilter; import org.kohsuke.args4j.ParserProperties; import 
org.kohsuke.args4j.spi.StringArrayOptionHandler; -import ai.onnxruntime.OrtException; import java.io.Closeable; -import java.io.File; import java.io.IOException; import java.io.PrintWriter; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; -import java.util.HashSet; -import java.util.Locale; import java.util.Map; -import java.util.Set; import java.util.SortedMap; import java.util.TreeMap; import java.util.concurrent.CompletionException; @@ -68,9 +58,9 @@ import java.util.concurrent.atomic.AtomicInteger; /** - * Main entry point for search. + * Main entry point for HNSW search. */ -public final class SearchHnswDenseVectors implements Closeable { +public final class SearchHnswDenseVectors implements Runnable, Closeable { // These are the default tie-breaking rules for documents that end up with the same score with respect to a query. // For most collections, docids are strings, and we break ties by lexicographic sort order. 
public static final Sort BREAK_SCORE_TIES_BY_DOCID = @@ -91,14 +81,14 @@ public static class Args { @Option(name = "-topicReader", usage = "TopicReader to use.") public String topicReader = "JsonIntVector"; + @Option(name = "-topicField", usage = "Topic field that should be used as the query.") + public String topicField = "vector"; + @Option(name = "-generator", usage = "QueryGenerator to use.") public String queryGenerator = "VectorQueryGenerator"; - @Option(name = "-threads", metaVar = "[int]", usage = "Number of threads to use for running different parameter configurations.") - public int threads = 1; - - @Option(name = "-parallelism", metaVar = "[int]", usage = "Number of threads to use for each individual parameter configuration.") - public int parallelism = 8; + @Option(name = "-threads", metaVar = "[int]", usage = "Number of threads for running queries in parallel.") + public int threads = 4; @Option(name = "-removeQuery", usage = "Remove docids that have the query id when writing final run output.") public Boolean removeQuery = false; @@ -108,21 +98,14 @@ public static class Args { @Option(name = "-removedups", usage = "Remove duplicate docids when writing final run output.") public Boolean removedups = false; - @Option(name = "-hits", metaVar = "[number]", required = false, usage = "max number of hits to return") + @Option(name = "-hits", metaVar = "[number]", usage = "max number of hits to return") public int hits = 1000; - @Option(name = "-efSearch", metaVar = "[number]", required = false, usage = "efSearch parameter for HNSW search") + @Option(name = "-efSearch", metaVar = "[number]", usage = "efSearch parameter for HNSW search") public int efSearch = 100; - @Option(name = "-inmem", usage = "Boolean switch to read index in memory") - public Boolean inmem = false; - - @Option(name = "-topicField", usage = "Which field of the query should be used, default \"title\"." 
+ - " For TREC ad hoc topics, description or narrative can be used.") - public String topicfield = "vector"; - @Option(name = "-runtag", metaVar = "[tag]", usage = "runtag") - public String runtag = null; + public String runtag = "Anserini"; @Option(name = "-format", metaVar = "[output format]", usage = "Output format, default \"trec\", alternative \"msmarco\".") public String format = "trec"; @@ -160,160 +143,10 @@ public static class Args { private final Args args; private final IndexReader reader; - - private final class SearcherThread extends Thread { - final private IndexReader reader; - final private IndexSearcher searcher; - final private SortedMap> topics; - final private String outputPath; - final private String runTag; - - private SearcherThread(IndexReader reader, SortedMap> topics, String outputPath, String runTag) { - this.reader = reader; - this.topics = topics; - this.runTag = runTag; - this.outputPath = outputPath; - this.searcher = new IndexSearcher(this.reader); - setName(outputPath); - } - - @Override - public void run() { - try { - // A short descriptor of the ranking setup. - final String desc = String.format("ranker: kNN"); - // ThreadPool for parallelizing the execution of individual queries: - ThreadPoolExecutor executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(args.parallelism); - // Data structure for holding the per-query results, with the qid as the key and the results (the lines that - // will go into the final run file) as the value. 
- ConcurrentSkipListMap results = new ConcurrentSkipListMap<>(); - AtomicInteger cnt = new AtomicInteger(); - DenseEncoder queryEncoder; - if (args.encoder != null) { - queryEncoder = (DenseEncoder) Class - .forName(String.format("io.anserini.encoder.dense.%sEncoder", args.encoder)) - .getConstructor().newInstance(); - } else { - queryEncoder = null; - } - final long start = System.nanoTime(); - for (Map.Entry> entry : topics.entrySet()) { - K qid = entry.getKey(); - - // This is the per-query execution, in parallel. - executor.execute(() -> { - // This is for holding the results. - StringBuilder out = new StringBuilder(); - String queryString = entry.getValue().get(args.topicfield); - ScoredDocuments docs; - - float[] queryFloat = null; - if (queryEncoder != null) { - try { - queryFloat = queryEncoder.encode(queryString); - } catch (OrtException e) { - e.printStackTrace(); - } - } - try { - if (queryFloat != null) { - docs = search(this.searcher, queryFloat); - } else { - docs = search(this.searcher, queryString); - } - } catch (IOException e) { - throw new CompletionException(e); - } - - // For removing duplicate docids. - Set docids = new HashSet<>(); - - int rank = 1; - for (int i = 0; i < docs.documents.length; i++) { - String docid = docs.documents[i].get(Constants.ID); - - if (args.selectMaxPassage) { - docid = docid.split(args.selectMaxPassage_delimiter)[0]; - } - - if (docids.contains(docid)) - continue; - - // Remove docids that are identical to the query id if flag is set. - if (args.removeQuery && docid.equals(qid)) - continue; - - if ("msmarco".equals(args.format)) { - // MS MARCO output format: - out.append(String.format(Locale.US, "%s\t%s\t%d\n", qid, docid, rank)); - } else { - // Standard TREC format: - // + the first column is the topic number. - // + the second column is currently unused and should always be "Q0". - // + the third column is the official document identifier of the retrieved document. 
- // + the fourth column is the rank the document is retrieved. - // + the fifth column shows the score (integer or floating point) that generated the ranking. - // + the sixth column is called the "run tag" and should be a unique identifier for your - out.append(String.format(Locale.US, "%s Q0 %s %d %f %s\n", - qid, docid, rank, docs.scores[i], runTag)); - } - - // Note that this option is set to false by default because duplicate documents usually indicate some - // underlying indexing issues, and we don't want to just eat errors silently. - // - // However, we we're performing passage retrieval, i.e., with "selectMaxSegment", we *do* want to remove - // duplicates. - if (args.removedups || args.selectMaxPassage) { - docids.add(docid); - } - - rank++; - - if (args.selectMaxPassage && rank > args.selectMaxPassage_hits) { - break; - } - } - - results.put(qid, out.toString()); - int n = cnt.incrementAndGet(); - if (n % 100 == 0) { - LOG.info(String.format("%s: %d queries processed", desc, n)); - } - }); - } - - executor.shutdown(); - - try { - // Wait for existing tasks to terminate. - while (!executor.awaitTermination(1, TimeUnit.MINUTES)); - } catch (InterruptedException ie) { - // (Re-)Cancel if current thread also interrupted. - executor.shutdownNow(); - // Preserve interrupt status. - Thread.currentThread().interrupt(); - } - final long durationMillis = TimeUnit.MILLISECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS); - - LOG.info(desc + ": " + topics.size() + " queries processed in " + - DurationFormatUtils.formatDuration(durationMillis, "HH:mm:ss") + - String.format(" = ~%.2f q/s", topics.size()/(durationMillis/1000.0))); - - // Now we write the results to a run file. - PrintWriter out = new PrintWriter(Files.newBufferedWriter(Paths.get(outputPath), StandardCharsets.UTF_8)); - - // This is the default case: just dump out the qids by their natural order. 
- for (K qid : results.keySet()) { - out.print(results.get(qid)); - } - out.flush(); - out.close(); - - } catch (Exception e) { - LOG.error(Thread.currentThread().getName() + ": Unexpected Exception: ", e); - } - } - } + private final IndexSearcher searcher; + private final VectorQueryGenerator generator; + private final DenseEncoder queryEncoder; + private final ConcurrentSkipListMap results = new ConcurrentSkipListMap<>(); public SearchHnswDenseVectors(Args args) throws IOException { this.args = args; @@ -323,12 +156,37 @@ public SearchHnswDenseVectors(Args args) throws IOException { throw new IllegalArgumentException(String.format("Index path '%s' does not exist or is not a directory.", args.index)); } - LOG.info("============ Initializing Searcher ============"); + LOG.info("============ Initializing HNSW Searcher ============"); LOG.info("Index: " + indexPath); - this.reader = args.inmem ? DirectoryReader.open(MMapDirectory.open(indexPath)) : - DirectoryReader.open(FSDirectory.open(indexPath)); - LOG.info("Vector Search:"); - LOG.info("Number of threads for running different parameter configurations: " + args.threads); + LOG.info("Query generator: " + args.queryGenerator); + LOG.info("Encoder: " + args.encoder); + LOG.info("Threads: " + args.threads); + + this.reader = DirectoryReader.open(FSDirectory.open(indexPath)); + this.searcher = new IndexSearcher(this.reader); + + try { + this.generator = (VectorQueryGenerator) Class + .forName(String.format("io.anserini.search.query.%s", args.queryGenerator)) + .getConstructor().newInstance(); + } catch (Exception e) { + e.printStackTrace(); + throw new IllegalArgumentException("Unable to load QueryGenerator: " + args.queryGenerator); + } + + if (args.encoder != null) { + try { + queryEncoder = (DenseEncoder) Class + .forName(String.format("io.anserini.encoder.dense.%sEncoder", args.encoder)) + .getConstructor().newInstance(); + } catch (Exception e) { + e.printStackTrace(); + throw new 
IllegalArgumentException("Unable to load encoder: " + args.encoder); + } + } else { + queryEncoder = null; + } + } @Override @@ -337,11 +195,11 @@ public void close() throws IOException { } @SuppressWarnings("unchecked") - public void runTopics() throws IOException { + @Override + public void run() { SortedMap> topics = new TreeMap<>(); - - for (String singleTopicsFile : args.topics) { - Path topicsFilePath = Paths.get(singleTopicsFile); + for (String file : args.topics) { + Path topicsFilePath = Paths.get(file); if (!Files.exists(topicsFilePath) || !Files.isRegularFile(topicsFilePath) || !Files.isReadable(topicsFilePath)) { throw new IllegalArgumentException("Topics file : " + topicsFilePath + " does not exist or is not a (readable) file."); } @@ -356,84 +214,111 @@ public void runTopics() throws IOException { } } - final String runTag = args.runtag == null ? "Anserini" : args.runtag; - LOG.info("runtag: " + runTag); - + LOG.info("============ Launching Search Threads ============"); final ThreadPoolExecutor executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(args.threads); + final AtomicInteger cnt = new AtomicInteger(); - LOG.info("============ Launching Search Threads ============"); + final long start = System.nanoTime(); + for (Map.Entry> entry : topics.entrySet()) { + K qid = entry.getKey(); + + // This is the per-query execution, in parallel. + executor.execute(() -> { + String queryString = entry.getValue().get(args.topicField); + ScoredDocuments docs; + + try { + docs = queryEncoder != null ? 
+ search(this.searcher, queryEncoder.encode(queryString)) : + search(this.searcher, queryString); + } catch (IOException|OrtException e) { + throw new CompletionException(e); + } + + String runOutput = SearchCollection.generateRunOutput(docs, qid, args.format, args.runtag, args.removedups, + args.removeQuery, args.selectMaxPassage, args.selectMaxPassage_delimiter, args.selectMaxPassage_hits); + + results.put(qid, runOutput); + int n = cnt.incrementAndGet(); + if (n % 100 == 0) { + LOG.info(String.format("%d queries processed", n)); + } + }); + } - String outputPath = args.output; - executor.execute(new SearcherThread<>(reader, topics, outputPath, runTag)); executor.shutdown(); try { - // Wait for existing tasks to terminate - while (!executor.awaitTermination(1, TimeUnit.MINUTES)) { - } + // Wait for existing tasks to terminate. + while (!executor.awaitTermination(1, TimeUnit.MINUTES)); } catch (InterruptedException ie) { - // (Re-)Cancel if current thread also interrupted + // (Re-)Cancel if current thread also interrupted. executor.shutdownNow(); - // Preserve interrupt status + // Preserve interrupt status. Thread.currentThread().interrupt(); } - } + final long durationMillis = TimeUnit.MILLISECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS); - public ScoredDocuments search(IndexSearcher searcher, float[] queryFloat) throws IOException { - KnnFloatVectorQuery query = new KnnFloatVectorQuery(Constants.VECTOR, queryFloat, args.efSearch); + LOG.info(topics.size() + " queries processed in " + + DurationFormatUtils.formatDuration(durationMillis, "HH:mm:ss") + + String.format(" = ~%.2f q/s", topics.size()/(durationMillis/1000.0))); - TopDocs rs = searcher.search(query, args.hits, BREAK_SCORE_TIES_BY_DOCID, true); - ScoredDocuments scoredDocs = ScoredDocuments.fromTopDocs(rs, searcher); + // Now we write the results to a run file. 
+ try { + PrintWriter out = new PrintWriter(Files.newBufferedWriter(Paths.get(args.output), StandardCharsets.UTF_8)); - return scoredDocs; - } + // This is the default case: just dump out the qids by their natural order. + for (K qid : results.keySet()) { + out.print(results.get(qid)); + } - public ScoredDocuments search(IndexSearcher searcher, String queryString) throws IOException { - KnnFloatVectorQuery query; - VectorQueryGenerator generator; - try { - generator = (VectorQueryGenerator) Class.forName("io.anserini.search.query." + args.queryGenerator) - .getConstructor().newInstance(); - } catch (Exception e) { + out.flush(); + out.close(); + } catch (IOException e) { e.printStackTrace(); - throw new IllegalArgumentException("Unable to load QueryGenerator: " + args.topicReader); } + } - query = generator.buildQuery(Constants.VECTOR, queryString, args.efSearch); + private ScoredDocuments search(IndexSearcher searcher, float[] queryFloat) throws IOException { + KnnFloatVectorQuery query = new KnnFloatVectorQuery(Constants.VECTOR, queryFloat, args.efSearch); TopDocs rs = searcher.search(query, args.hits, BREAK_SCORE_TIES_BY_DOCID, true); - ScoredDocuments scoredDocs = ScoredDocuments.fromTopDocs(rs, searcher); - return scoredDocs; + return ScoredDocuments.fromTopDocs(rs, searcher); } + private ScoredDocuments search(IndexSearcher searcher, String queryString) throws IOException { + KnnFloatVectorQuery query = generator.buildQuery(Constants.VECTOR, queryString, args.efSearch); + TopDocs rs = searcher.search(query, args.hits, BREAK_SCORE_TIES_BY_DOCID, true); + + return ScoredDocuments.fromTopDocs(rs, searcher); + } public static void main(String[] args) throws Exception { Args searchArgs = new Args(); - CmdLineParser parser = new CmdLineParser(searchArgs, ParserProperties.defaults().withUsageWidth(100)); + CmdLineParser parser = new CmdLineParser(searchArgs, ParserProperties.defaults().withUsageWidth(120)); try { parser.parseArgument(args); } catch (CmdLineException 
e) { System.err.println(e.getMessage()); parser.printUsage(System.err); - System.err.println("Example: SearchCollection" + parser.printExample(OptionHandlerFilter.REQUIRED)); + System.err.println("Example: SearchHnswDenseVectors" + parser.printExample(OptionHandlerFilter.REQUIRED)); return; } final long start = System.nanoTime(); - SearchHnswDenseVectors searcher; // We're at top-level already inside a main; makes no sense to propagate exceptions further, so reformat the // exception messages and display on console. try { - searcher = new SearchHnswDenseVectors(searchArgs); + SearchHnswDenseVectors searcher = new SearchHnswDenseVectors(searchArgs); + searcher.run(); + searcher.close(); } catch (IllegalArgumentException e) { System.err.println(e.getMessage()); return; } - searcher.runTopics(); - searcher.close(); final long durationMillis = TimeUnit.MILLISECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS); LOG.info("Total run time: " + DurationFormatUtils.formatDuration(durationMillis, "HH:mm:ss")); } diff --git a/src/main/java/io/anserini/search/SearchInvertedDenseVectors.java b/src/main/java/io/anserini/search/SearchInvertedDenseVectors.java index 81c862e5b9..1d9bba7b1c 100644 --- a/src/main/java/io/anserini/search/SearchInvertedDenseVectors.java +++ b/src/main/java/io/anserini/search/SearchInvertedDenseVectors.java @@ -31,9 +31,7 @@ import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.TopScoreDocCollector; import org.apache.lucene.search.similarities.ClassicSimilarity; -import org.apache.lucene.search.similarities.Similarity; import org.apache.lucene.store.FSDirectory; import org.kohsuke.args4j.CmdLineException; import org.kohsuke.args4j.CmdLineParser; @@ -43,22 +41,17 @@ import org.kohsuke.args4j.spi.StringArrayOptionHandler; import java.io.Closeable; -import java.io.File; import java.io.IOException; import java.io.PrintWriter; import 
java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; -import java.util.HashSet; -import java.util.Locale; import java.util.Map; -import java.util.Set; import java.util.SortedMap; import java.util.TreeMap; import java.util.concurrent.CompletionException; import java.util.concurrent.ConcurrentSkipListMap; -import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; @@ -69,7 +62,7 @@ /** * Main entry point for inverted dense vector search. */ -public final class SearchInvertedDenseVectors implements Closeable { +public final class SearchInvertedDenseVectors implements Runnable, Closeable { // These are the default tie-breaking rules for documents that end up with the same score with respect to a query. // For most collections, docids are strings, and we break ties by lexicographic sort order. public static final Sort BREAK_SCORE_TIES_BY_DOCID = @@ -90,10 +83,10 @@ public static class Args { @Option(name = "-topicReader", usage = "TopicReader to use.") public String topicReader; - @Option(name = "-topicField", usage = "Which field of topic should be used as the query.") + @Option(name = "-topicField", usage = "Topic field that should be used as the query.") public String topicField = "title"; - @Option(name = "-encoding", metaVar = "[word]", required = true, usage = "encoding must be one of {fw, lexlsh}") + @Option(name = "-encoding", metaVar = "[word]", required = true, usage = "Encoding, must be one of {fw, lexlsh}") public String encoding; @Option(name = "-lexlsh.n", metaVar = "[int]", usage = "ngrams") @@ -114,14 +107,8 @@ public static class Args { @Option(name = "-fw.q", metaVar = "[int]", usage = "quantization factor") public int q = FakeWordsEncoderAnalyzer.DEFAULT_Q; - @Option(name = "-threads", metaVar = "[int]", usage = "Number of threads to use for running different parameter 
configurations.") - public int threads = 1; - - @Option(name = "-parallelism", metaVar = "[int]", usage = "Number of threads to use for each individual parameter configuration.") - public int parallelism = 8; - - @Option(name = "-threadsPerQuery", metaVar = "[int]", usage = "Number of threads used to execute each query.") - public int threadsPerQuery = 1; + @Option(name = "-threads", metaVar = "[int]", usage = "Number of threads for running queries in parallel.") + public int threads = 4; @Option(name = "-removeQuery", usage = "Remove docids that have the query id when writing final run output.") public Boolean removeQuery = false; @@ -131,14 +118,11 @@ public static class Args { @Option(name = "-removedups", usage = "Remove duplicate docids when writing final run output.") public Boolean removedups = false; - @Option(name = "-hits", metaVar = "[number]", required = false, usage = "max number of hits to return") + @Option(name = "-hits", metaVar = "[number]", usage = "max number of hits to return") public int hits = 1000; - @Option(name = "-inmem", usage = "Boolean switch to read index in memory") - public Boolean inmem = false; - @Option(name = "-runtag", metaVar = "[tag]", usage = "runtag") - public String runtag = null; + public String runtag = "Anserini"; @Option(name = "-format", metaVar = "[output format]", usage = "Output format, default \"trec\", alternative \"msmarco\".") public String format = "trec"; @@ -173,148 +157,9 @@ public static class Args { private final Args args; private final IndexReader reader; - - private InvertedDenseVectorQueryGenerator generator; - - private final class SearcherThread extends Thread { - - final private IndexReader reader; - final private IndexSearcher searcher; - final private SortedMap> topics; - final private String outputPath; - final private String runTag; - - private SearcherThread(IndexReader reader, SortedMap> topics, String outputPath, String runTag, - ExecutorService executorService, Similarity similarity) { - 
this.reader = reader; - this.topics = topics; - this.runTag = runTag; - this.outputPath = outputPath; - this.searcher = executorService != null ? new IndexSearcher(this.reader, executorService) : new IndexSearcher(this.reader); - if (similarity != null) { - searcher.setSimilarity(similarity); - } - setName(outputPath); - } - - @Override - public void run() { - try { - // A short descriptor of the ranking setup. - final String desc = String.format("ranker: kNN"); - // ThreadPool for parallelizing the execution of individual queries: - ThreadPoolExecutor executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(args.parallelism); - // Data structure for holding the per-query results, with the qid as the key and the results (the lines that - // will go into the final run file) as the value. - ConcurrentSkipListMap results = new ConcurrentSkipListMap<>(); - AtomicInteger cnt = new AtomicInteger(); - - final long start = System.nanoTime(); - for (Map.Entry> entry : topics.entrySet()) { - K qid = entry.getKey(); - - // This is the per-query execution, in parallel. - executor.execute(() -> { - // This is for holding the results. - StringBuilder out = new StringBuilder(); - String queryString = entry.getValue().get(args.topicField); - ScoredDocuments docs; - try { - docs = search(this.searcher, queryString); - } catch (IOException e) { - throw new CompletionException(e); - } - - // For removing duplicate docids. - Set docids = new HashSet<>(); - - int rank = 1; - for (int i = 0; i < docs.documents.length; i++) { - String docid = docs.documents[i].get(Constants.ID); - - if (args.selectMaxPassage) { - docid = docid.split(args.selectMaxPassage_delimiter)[0]; - } - - if (docids.contains(docid)) { - continue; - } - - // Remove docids that are identical to the query id if flag is set. 
- if (args.removeQuery && docid.equals(qid)) { - continue; - } - - if ("msmarco".equals(args.format)) { - // MS MARCO output format: - out.append(String.format(Locale.US, "%s\t%s\t%d\n", qid, docid, rank)); - } else { - // Standard TREC format: - // + the first column is the topic number. - // + the second column is currently unused and should always be "Q0". - // + the third column is the official document identifier of the retrieved document. - // + the fourth column is the rank the document is retrieved. - // + the fifth column shows the score (integer or floating point) that generated the ranking. - // + the sixth column is called the "run tag" and should be a unique identifier for your - out.append(String.format(Locale.US, "%s Q0 %s %d %f %s\n", - qid, docid, rank, docs.scores[i], runTag)); - } - - // Note that this option is set to false by default because duplicate documents usually indicate some - // underlying indexing issues, and we don't want to just eat errors silently. - // - // However, we we're performing passage retrieval, i.e., with "selectMaxSegment", we *do* want to remove - // duplicates. - if (args.removedups || args.selectMaxPassage) { - docids.add(docid); - } - - rank++; - - if (args.selectMaxPassage && rank > args.selectMaxPassage_hits) { - break; - } - } - - results.put(qid, out.toString()); - int n = cnt.incrementAndGet(); - if (n % 100 == 0) { - LOG.info(String.format("%s: %d queries processed", desc, n)); - } - }); - } - - executor.shutdown(); - - try { - // Wait for existing tasks to terminate. - while (!executor.awaitTermination(1, TimeUnit.MINUTES)) ; - } catch (InterruptedException ie) { - // (Re-)Cancel if current thread also interrupted. - executor.shutdownNow(); - // Preserve interrupt status. 
- Thread.currentThread().interrupt(); - } - final long durationMillis = TimeUnit.MILLISECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS); - - LOG.info(desc + ": " + topics.size() + " queries processed in " + - DurationFormatUtils.formatDuration(durationMillis, "HH:mm:ss") + - String.format(" = ~%.2f q/s", topics.size() / (durationMillis / 1000.0))); - - // Now we write the results to a run file. - PrintWriter out = new PrintWriter(Files.newBufferedWriter(Paths.get(outputPath), StandardCharsets.UTF_8)); - - // This is the default case: just dump out the qids by their natural order. - for (K qid : results.keySet()) { - out.print(results.get(qid)); - } - out.flush(); - out.close(); - } catch (Exception e) { - LOG.error(Thread.currentThread().getName() + ": Unexpected Exception: ", e); - } - } - } + private final IndexSearcher searcher; + private final InvertedDenseVectorQueryGenerator generator; + private final ConcurrentSkipListMap results = new ConcurrentSkipListMap<>(); public SearchInvertedDenseVectors(Args args) throws IOException { this.args = args; @@ -324,11 +169,18 @@ public SearchInvertedDenseVectors(Args args) throws IOException { throw new IllegalArgumentException(String.format("Index path '%s' does not exist or is not a directory.", args.index)); } - LOG.info("============ Initializing Searcher ============"); + LOG.info("============ Initializing InvertedDenseVector Searcher ============"); LOG.info("Index: " + indexPath); + LOG.info("Encoding: " + args.encoding); + LOG.info("Threads: " + args.threads); + this.reader = DirectoryReader.open(FSDirectory.open(indexPath)); - LOG.info("Vector Search:"); - LOG.info("Number of threads for running different parameter configurations: " + args.threads); + this.searcher = new IndexSearcher(this.reader); + if (args.encoding.equalsIgnoreCase(FW)) { + searcher.setSimilarity(new ClassicSimilarity()); + } + + this.generator = new InvertedDenseVectorQueryGenerator(args, true); } @Override @@ -337,9 +189,8 
@@ public void close() throws IOException { } @SuppressWarnings("unchecked") - public void runTopics() { - generator = new InvertedDenseVectorQueryGenerator(args, true); - TopicReader tr; + @Override + public void run() { SortedMap> topics = new TreeMap<>(); for (String singleTopicsFile : args.topics) { Path topicsFilePath = Paths.get(singleTopicsFile); @@ -347,7 +198,8 @@ public void runTopics() { throw new IllegalArgumentException("Topics file : " + topicsFilePath + " does not exist or is not a (readable) file."); } try { - tr = (TopicReader) Class.forName("io.anserini.search.topicreader." + args.topicReader + "TopicReader") + TopicReader tr = (TopicReader) Class + .forName(String.format("io.anserini.search.topicreader.%sTopicReader", args.topicReader)) .getConstructor(Path.class).newInstance(topicsFilePath); topics.putAll(tr.read()); } catch (Exception e) { @@ -356,44 +208,68 @@ public void runTopics() { } } - final String runTag = args.runtag == null ? "Anserini" : args.runtag; - LOG.info("runtag: " + runTag); - + LOG.info("============ Launching Search Threads ============"); final ThreadPoolExecutor executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(args.threads); + final AtomicInteger cnt = new AtomicInteger(); - LOG.info("============ Launching Search Threads ============"); + final long start = System.nanoTime(); + for (Map.Entry> entry : topics.entrySet()) { + K qid = entry.getKey(); - ExecutorService queryExecutor = null; - if (args.threadsPerQuery > 1) { - queryExecutor = Executors.newFixedThreadPool(args.threadsPerQuery); - } + // This is the per-query execution, in parallel. 
+ executor.execute(() -> { + ScoredDocuments docs; + try { + docs = search(this.searcher, entry.getValue().get(args.topicField)); + } catch (IOException e) { + throw new CompletionException(e); + } - Similarity similarity = null; - if (args.encoding.equalsIgnoreCase(FW)) { - similarity = new ClassicSimilarity(); + String runOutput = SearchCollection.generateRunOutput(docs, qid, args.format, args.runtag, args.removedups, + args.removeQuery, args.selectMaxPassage, args.selectMaxPassage_delimiter, args.selectMaxPassage_hits); + + results.put(qid, runOutput); + int n = cnt.incrementAndGet(); + if (n % 100 == 0) { + LOG.info(String.format("%d queries processed", n)); + } + }); } - String outputPath = args.output; - executor.execute(new SearcherThread<>(reader, topics, outputPath, runTag, queryExecutor, similarity)); executor.shutdown(); try { - // Wait for existing tasks to terminate - while (!executor.awaitTermination(1, TimeUnit.MINUTES)) { - } - if (queryExecutor != null) { - while (!queryExecutor.awaitTermination(1, TimeUnit.MINUTES)) { - } - } + // Wait for existing tasks to terminate. + while (!executor.awaitTermination(1, TimeUnit.MINUTES)) ; } catch (InterruptedException ie) { - // (Re-)Cancel if current thread also interrupted + // (Re-)Cancel if current thread also interrupted. executor.shutdownNow(); - // Preserve interrupt status + // Preserve interrupt status. Thread.currentThread().interrupt(); } + final long durationMillis = TimeUnit.MILLISECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS); + + LOG.info(topics.size() + " queries processed in " + + DurationFormatUtils.formatDuration(durationMillis, "HH:mm:ss") + + String.format(" = ~%.2f q/s", topics.size() / (durationMillis / 1000.0))); + + // Now we write the results to a run file. + try { + PrintWriter out = new PrintWriter(Files.newBufferedWriter(Paths.get(args.output), StandardCharsets.UTF_8)); + + // This is the default case: just dump out the qids by their natural order. 
+ for (K qid : results.keySet()) { + out.print(results.get(qid)); + } + + out.flush(); + out.close(); + } catch (IOException e) { + e.printStackTrace(); + } } - public ScoredDocuments search(IndexSearcher searcher, String queryString) throws IOException { + private ScoredDocuments search(IndexSearcher searcher, String queryString) throws IOException { Query query = generator.buildQuery(queryString); TopDocs results = searcher.search(query, args.hits, BREAK_SCORE_TIES_BY_DOCID, true); @@ -409,28 +285,24 @@ public static void main(String[] args) throws Exception { } catch (CmdLineException e) { System.err.println(e.getMessage()); parser.printUsage(System.err); - System.err.println("Example: SearchCollection" + parser.printExample(OptionHandlerFilter.REQUIRED)); + System.err.println("Example: SearchInvertedDenseVectors" + parser.printExample(OptionHandlerFilter.REQUIRED)); return; } final long start = System.nanoTime(); - SearchInvertedDenseVectors searcher; // We're at top-level already inside a main; makes no sense to propagate exceptions further, so reformat the // exception messages and display on console. 
try { - searcher = new SearchInvertedDenseVectors(searchArgs); + SearchInvertedDenseVectors searcher = new SearchInvertedDenseVectors(searchArgs); + searcher.run(); + searcher.close(); } catch (IllegalArgumentException e) { System.err.println(e.getMessage()); return; } - if (searchArgs.topicReader != null && searchArgs.topics != null) { - searcher.runTopics(); - } - searcher.close(); final long durationMillis = TimeUnit.MILLISECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS); LOG.info("Total run time: " + DurationFormatUtils.formatDuration(durationMillis, "HH:mm:ss")); } - } \ No newline at end of file diff --git a/src/main/resources/regression/dl19-passage-cos-dpr-distil-fw.yaml b/src/main/resources/regression/dl19-passage-cos-dpr-distil-fw.yaml index baf7a22423..aff3b80129 100644 --- a/src/main/resources/regression/dl19-passage-cos-dpr-distil-fw.yaml +++ b/src/main/resources/regression/dl19-passage-cos-dpr-distil-fw.yaml @@ -57,7 +57,7 @@ models: - name: cos-dpr-distil-fw-40 display: cosDPR-distill type: inverted-dense - params: -topicField vector -encoding fw -fw.q 40 -hits 1000 + params: -topicField vector -threads 16 -encoding fw -fw.q 40 -hits 1000 results: AP@1000: - 0.4271 diff --git a/src/main/resources/regression/dl19-passage-cos-dpr-distil-lexlsh.yaml b/src/main/resources/regression/dl19-passage-cos-dpr-distil-lexlsh.yaml index 8e99c127c6..e9fceecdf4 100644 --- a/src/main/resources/regression/dl19-passage-cos-dpr-distil-lexlsh.yaml +++ b/src/main/resources/regression/dl19-passage-cos-dpr-distil-lexlsh.yaml @@ -57,7 +57,7 @@ models: - name: cos-dpr-distil-lexlsh-600 display: cosDPR-distill type: inverted-dense - params: -topicField vector -encoding lexlsh -lexlsh.b 600 -hits 1000 + params: -topicField vector -threads 16 -encoding lexlsh -lexlsh.b 600 -hits 1000 results: AP@1000: - 0.4118 diff --git a/src/main/resources/regression/dl20-passage-cos-dpr-distil-fw.yaml b/src/main/resources/regression/dl20-passage-cos-dpr-distil-fw.yaml index 
9eb6b75f50..0f1b35cac5 100644 --- a/src/main/resources/regression/dl20-passage-cos-dpr-distil-fw.yaml +++ b/src/main/resources/regression/dl20-passage-cos-dpr-distil-fw.yaml @@ -57,7 +57,7 @@ models: - name: cos-dpr-distil-fw-40 display: cosDPR-distill type: inverted-dense - params: -topicField vector -encoding fw -fw.q 40 -hits 1000 + params: -topicField vector -threads 16 -encoding fw -fw.q 40 -hits 1000 results: AP@1000: - 0.4597 diff --git a/src/main/resources/regression/dl20-passage-cos-dpr-distil-lexlsh.yaml b/src/main/resources/regression/dl20-passage-cos-dpr-distil-lexlsh.yaml index 3f88b61ca6..6ea4b6b713 100644 --- a/src/main/resources/regression/dl20-passage-cos-dpr-distil-lexlsh.yaml +++ b/src/main/resources/regression/dl20-passage-cos-dpr-distil-lexlsh.yaml @@ -57,7 +57,7 @@ models: - name: cos-dpr-distil-lexlsh-600 display: cosDPR-distill type: inverted-dense - params: -topicField vector -encoding lexlsh -lexlsh.b 600 -hits 1000 + params: -topicField vector -threads 16 -encoding lexlsh -lexlsh.b 600 -hits 1000 results: AP@1000: - 0.4486 diff --git a/src/main/resources/regression/msmarco-passage-cos-dpr-distil-fw.yaml b/src/main/resources/regression/msmarco-passage-cos-dpr-distil-fw.yaml index 180cde86f0..5856ca3464 100644 --- a/src/main/resources/regression/msmarco-passage-cos-dpr-distil-fw.yaml +++ b/src/main/resources/regression/msmarco-passage-cos-dpr-distil-fw.yaml @@ -57,7 +57,7 @@ models: - name: cos-dpr-distil-fw-40 display: cosDPR-distill type: inverted-dense - params: -topicField vector -encoding fw -fw.q 40 -hits 1000 + params: -topicField vector -threads 16 -encoding fw -fw.q 40 -hits 1000 results: AP@1000: - 0.3654 diff --git a/src/main/resources/regression/msmarco-passage-cos-dpr-distil-lexlsh.yaml b/src/main/resources/regression/msmarco-passage-cos-dpr-distil-lexlsh.yaml index 7c1d5b8090..b8c82c7c36 100644 --- a/src/main/resources/regression/msmarco-passage-cos-dpr-distil-lexlsh.yaml +++ 
b/src/main/resources/regression/msmarco-passage-cos-dpr-distil-lexlsh.yaml @@ -57,7 +57,7 @@ models: - name: cos-dpr-distil-lexlsh-600 display: cosDPR-distill type: inverted-dense - params: -topicField vector -encoding lexlsh -lexlsh.b 600 -hits 1000 + params: -topicField vector -threads 16 -encoding lexlsh -lexlsh.b 600 -hits 1000 results: AP@1000: - 0.3509