Performance: attempt to optimize the ArrayHitCounter by maintaining some state while updating the counter #721

Draft: wants to merge 17 commits into main. Changes shown from all commits.
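In short: the previous ArrayHitCounter rebuilt a histogram of hit counts inside kthGreatest() on every query. This change maintains that histogram incrementally: each increment() updates a count-of-counts array (countToCount) plus the running min/max doc ID and max count, so docIdSetIterator() only needs a short backward scan over countToCount to find the kth-greatest threshold. A minimal sketch of the technique (hypothetical names, simplified from the real class below; array growth and min/max doc ID bookkeeping are omitted):

// Illustrative sketch only: maintain a count-of-counts histogram while
// counting, instead of rebuilding it at query time.
final class CountOfCountsSketch {
  private final short[] docIdToCount; // doc ID -> hit count
  private final short[] countToCount; // hit count -> number of docs with that count
  private int maxCount = 0;

  CountOfCountsSketch(int numDocs, int maxPossibleCount) {
    docIdToCount = new short[numDocs];
    countToCount = new short[maxPossibleCount + 1];
  }

  void increment(int docId) {
    int newCount = ++docIdToCount[docId];
    if (newCount > maxCount) maxCount = newCount;
    if (newCount > 1) countToCount[newCount - 1]--; // leave the old bucket...
    countToCount[newCount]++;                       // ...and enter the new one.
  }

  // The minimum count a doc needs in order to be among the top `candidates` docs.
  int kthGreatestCount(int candidates) {
    int kth = maxCount;
    int numGreaterEqual = 0;
    while (kth > 1 && numGreaterEqual + countToCount[kth] < candidates) {
      numGreaterEqual += countToCount[kth];
      kth--;
    }
    return kth;
  }
}

The trade-off: every increment pays for two extra array writes, in exchange for skipping the O(numDocs) histogram pass at query time.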
4 changes: 4 additions & 0 deletions build.sbt
@@ -76,6 +76,10 @@ lazy val `elastiknn-lucene` = project
     "org.apache.lucene" % "lucene-core" % LuceneVersion,
     "org.apache.lucene" % "lucene-analysis-common" % LuceneVersion % Test
   ),
+  javacOptions ++= Seq(
+    "--add-exports",
+    "java.base/jdk.internal.vm.annotation=ALL-UNNAMED"
+  ),
   TestSettings
 )

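A note on the build change: jdk.internal.vm.annotation.ForceInline (used by the new ArrayHitCounter below) lives in a package that java.base does not export, so compiling against it requires the --add-exports flag. The sbt setting above is equivalent to passing the flag to javac directly, roughly (file path hypothetical):

javac --add-exports java.base/jdk.internal.vm.annotation=ALL-UNNAMED ArrayHitCounter.java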
2 changes: 1 addition & 1 deletion docs/pages/performance/fashion-mnist/plot.b64

Large diffs are not rendered by default.

Binary file modified docs/pages/performance/fashion-mnist/plot.png
6 changes: 3 additions & 3 deletions docs/pages/performance/fashion-mnist/results.md
@@ -1,5 +1,5 @@
 |Model|Parameters|Recall|Queries per Second|
 |---|---|---|---|
-|eknn-l2lsh|L=175 k=7 w=3900 candidates=100 probes=0|0.607|304.462|
-|eknn-l2lsh|L=175 k=7 w=3900 candidates=500 probes=0|0.921|269.909|
-|eknn-l2lsh|L=175 k=7 w=3900 candidates=1000 probes=0|0.962|239.598|
+|eknn-l2lsh|L=175 k=7 w=3900 candidates=100 probes=0|0.607|301.546|
+|eknn-l2lsh|L=175 k=7 w=3900 candidates=500 probes=0|0.921|266.820|
+|eknn-l2lsh|L=175 k=7 w=3900 candidates=1000 probes=0|0.962|231.592|
HitCounterBenchmarks.scala
@@ -1,63 +1,39 @@
 package com.klibisz.elastiknn.jmhbenchmarks

-import org.openjdk.jmh.annotations._
-import org.apache.lucene.internal.hppc.IntIntHashMap
-import org.eclipse.collections.impl.map.mutable.primitive.IntShortHashMap
+import com.klibisz.elastiknn.search.ArrayHitCounter
+import org.openjdk.jmh.annotations.*
+import org.apache.lucene.search.DocIdSetIterator

 import scala.util.Random

 @State(Scope.Benchmark)
 class HitCounterBenchmarksFixtures {
   val rng = new Random(0)
   val numDocs = 60000
-  val numHits = 2000
-  val initialMapSize = 1000
+  val numHits = 30000
+  val candidates = 1000
   val docs: Array[Int] = (1 to numHits).map(_ => rng.nextInt(numDocs)).toArray
+  // The largest number of times any single doc ID occurs in docs.
+  val maxCount: Int = docs.groupBy(identity).values.map(_.length).max
 }

 class HitCounterBenchmarks {

-  @Benchmark
-  @BenchmarkMode(Array(Mode.Throughput))
-  @Fork(value = 1)
-  @Warmup(time = 5, iterations = 5)
-  @Measurement(time = 5, iterations = 5)
-  def arrayCountBaseline(f: HitCounterBenchmarksFixtures): Unit = {
-    val arr = new Array[Int](f.numDocs)
-    for (d <- f.docs) arr.update(d, arr(d) + 1)
-    ()
-  }
-
-  @Benchmark
-  @BenchmarkMode(Array(Mode.Throughput))
-  @Fork(value = 1)
-  @Warmup(time = 5, iterations = 5)
-  @Measurement(time = 5, iterations = 5)
-  def hashMapGetOrDefault(f: HitCounterBenchmarksFixtures): Unit = {
-    val h = new java.util.HashMap[Int, Int](f.initialMapSize, 0.99f)
-    for (d <- f.docs) h.put(d, h.getOrDefault(d, 0) + 1)
-    ()
-  }
-
-  @Benchmark
-  @BenchmarkMode(Array(Mode.Throughput))
-  @Fork(value = 1)
-  @Warmup(time = 5, iterations = 5)
-  @Measurement(time = 5, iterations = 5)
-  def luceneIntIntHashMap(f: HitCounterBenchmarksFixtures): Unit = {
-    val m = new IntIntHashMap(f.initialMapSize, 0.99d)
-    for (d <- f.docs) m.putOrAdd(d, 1, 1)
+  private def consumeDocIdSetIterator(disi: DocIdSetIterator): Unit = {
+    while (disi.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
+      val _ = disi.docID()
+    }
     ()
   }

   @Benchmark
   @BenchmarkMode(Array(Mode.Throughput))
   @Fork(value = 1)
   @Warmup(time = 5, iterations = 5)
-  @Measurement(time = 5, iterations = 5)
-  def eclipseIntShortHashMapAddToValue(f: HitCounterBenchmarksFixtures): Unit = {
-    val m = new IntShortHashMap(f.initialMapSize)
-    for (d <- f.docs) m.addToValue(d, 1)
+  @Measurement(time = 5, iterations = 10)
+  def arrayHitCounter(f: HitCounterBenchmarksFixtures): Unit = {
+    val ahc = new ArrayHitCounter(f.numDocs, f.maxCount)
+    for (d <- f.docs) ahc.increment(d)
+    consumeDocIdSetIterator(ahc.docIdSetIterator(f.candidates))
+    ()
+  }
 }
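Only the arrayHitCounter case remains; the array, HashMap, and primitive-map baselines are deleted, and the benchmark now measures the full path, counting plus iterator consumption, rather than bare counter updates. As a usage note, assuming these benchmarks are wired up through the sbt-jmh plugin (the project name here is a guess), a run would look something like:

sbt "elastiknn-jmh-benchmarks/Jmh/run HitCounterBenchmarks"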
ArrayHitCounter.java
@@ -1,94 +1,97 @@
 package com.klibisz.elastiknn.search;

+import jdk.internal.vm.annotation.ForceInline;
 import org.apache.lucene.search.DocIdSetIterator;

+import java.util.Arrays;
+
 public final class ArrayHitCounter implements HitCounter {

-  private final short[] counts;
-  private int numHits;
-  private int minKey;
-  private int maxKey;
-  private short maxValue;
+  // Mapping an integer doc ID to the number of times it has occurred.
+  // E.g., if document 10 has been matched 11 times, then docIdToCount[10] = 11.
+  private final short[] docIdToCount;

-  public ArrayHitCounter(int capacity) {
-    counts = new short[capacity];
-    numHits = 0;
-    minKey = capacity;
-    maxKey = 0;
-    maxValue = 0;
+  // Mapping an integer count to the number of times it has occurred.
+  // E.g., if there are 10 docs which have each been matched 11 times, countToCount[11] = 10.
+  private short[] countToCount;
+
+  private int minDocId;
+  private int maxDocId;
+
+  private int maxCount = 0;
+
+  public ArrayHitCounter(int numDocs, int expectedMaxCount) {
+    docIdToCount = new short[numDocs];
+    countToCount = new short[expectedMaxCount + 1];
+    minDocId = Integer.MAX_VALUE;
+    maxDocId = 0;
+  }
+
+  public ArrayHitCounter(int numDocs) {
+    this(numDocs, 10);
+  }
+
+  @ForceInline
+  private void incrementKeyByCount(int docId, short count) {
+    int newCount = (docIdToCount[docId] += count);
+    if (newCount > maxCount) maxCount = newCount;
+
+    // Potentially grow the count array.
+    if (newCount >= countToCount.length) {
+      countToCount = Arrays.copyOf(countToCount, newCount + 1);
+    }
+
+    // Update the old count.
+    int oldCount = newCount - count;
+    if (oldCount > 0) countToCount[oldCount] -= 1;
+
+    // Update the new count.
+    countToCount[newCount]++;
+
+    // Update min/max doc IDs.
+    if (docId < minDocId) minDocId = docId;
+    if (docId > maxDocId) maxDocId = docId;
   }

   @Override
   public void increment(int key) {
-    short after = ++counts[key];
-    if (after == 1) {
-      numHits++;
-      minKey = Math.min(key, minKey);
-      maxKey = Math.max(key, maxKey);
-    }
-    if (after > maxValue) maxValue = after;
+    incrementKeyByCount(key, (short) 1);
   }

   @Override
   public void increment(int key, short count) {
-    short after = (counts[key] += count);
-    if (after == count) {
-      numHits++;
-      minKey = Math.min(key, minKey);
-      maxKey = Math.max(key, maxKey);
-    }
-    if (after > maxValue) maxValue = after;
+    incrementKeyByCount(key, count);
   }

   @Override
   public short get(int key) {
-    return counts[key];
+    return docIdToCount[key];
   }

   @Override
   public int capacity() {
-    return counts.length;
-  }
-
-  private KthGreatestResult kthGreatest(int k) {
-    // Find the kth greatest document hit count in O(n) time and O(n) space.
-    // Though the space is typically negligibly small in practice.
-    // This implementation exploits the fact that we're specifically counting document hit counts.
-    // Counts are integers, and they're likely to be pretty small, since we're unlikely to match
-    // the same document many times.
-
-    // Start by building a histogram of all counts.
-    // e.g., if the counts are [0, 4, 1, 1, 2],
-    // then the histogram is [1, 2, 1, 0, 1],
-    // because 0 occurs once, 1 occurs twice, 2 occurs once, 3 occurs zero times, and 4 occurs once.
-    short[] hist = new short[maxValue + 1];
-    for (short c : counts) hist[c]++;
-
-    // Now we start at the max value and iterate backwards through the histogram,
-    // accumulating counts of counts until we've exceeded k.
-    int numGreaterEqual = 0;
-    short kthGreatest = maxValue;
-
-    while (true) {
-      numGreaterEqual += hist[kthGreatest];
-      if (kthGreatest > 1 && numGreaterEqual < k) kthGreatest--;
-      else break;
-    }
-
-    // Finally we find the number that were greater than the kth greatest count.
-    // There's a special case if kthGreatest is zero: then the number that were greater is the number of hits.
-    int numGreater = numGreaterEqual - hist[kthGreatest];
-    return new KthGreatestResult(kthGreatest, numGreater);
+    return docIdToCount.length;
   }

   @Override
   public DocIdSetIterator docIdSetIterator(int candidates) {
-    if (numHits == 0) return DocIdSetIterator.empty();
+    if (maxCount == 0) return DocIdSetIterator.empty();
     else {
-      KthGreatestResult kgr = kthGreatest(candidates);
+      // Loop backwards through countToCount to find the minimum count required for a
+      // document to be a candidate.
+      int kthGreatest = maxCount;
+      int numGreaterEqual = 0;
+      while (true) {
+        numGreaterEqual += countToCount[kthGreatest];
+        if (kthGreatest > 1 && numGreaterEqual < candidates) kthGreatest--;
+        else break;
+      }
+      // Local copies must be (effectively) final to be referenced from the anonymous class below.
+      final int finalKthGreatest = kthGreatest;
+      final int finalMinDocId = minDocId;
+      final int finalMaxDocId = maxDocId;
+      final int numGreaterThan = numGreaterEqual - countToCount[kthGreatest];

+      // Return an iterator over the doc IDs with count >= the min candidate count.
       return new DocIdSetIterator() {
@@ -97,9 +100,15 @@ public DocIdSetIterator docIdSetIterator(int candidates) {
         private int docID = -1;
         private boolean started = false;

-        // Track the number of ids emitted, and the number of ids with count = kgr.kthGreatest emitted.
-        private int numEmitted = 0;
-        private int numEq = 0;
+        // Track the total number of IDs emitted.
+        private int numTotalEmitted = 0;
+
+        // The threshold of IDs w/ count = kthGreatest that can be emitted.
+        private final int numEqThreshold = candidates - numGreaterThan;
+
+        // Track the number of IDs w/ count = kthGreatest that have been emitted.
+        private int numEqEmitted = 0;

         @Override
         public int docID() {
@@ -111,23 +120,23 @@ public int nextDoc() {

           if (!started) {
             started = true;
-            docID = minKey - 1;
+            docID = finalMinDocId - 1;
           }

-          // Ensure that docs with count = kgr.kthGreatest are only emitted when there are fewer
-          // than `candidates` docs with count > kgr.kthGreatest.
+          // Ensure that docs with count = kthGreatest are only emitted when there are fewer
+          // than `candidates` docs with count > kthGreatest.
           while (true) {
-            if (numEmitted == candidates || docID + 1 > maxKey) {
+            if (numTotalEmitted == candidates || docID + 1 > finalMaxDocId) {
               docID = DocIdSetIterator.NO_MORE_DOCS;
               return docID;
             } else {
               docID++;
-              if (counts[docID] > kgr.kthGreatest) {
-                numEmitted++;
+              if (docIdToCount[docID] > finalKthGreatest) {
+                numTotalEmitted++;
                 return docID;
-              } else if (counts[docID] == kgr.kthGreatest && numEq < candidates - kgr.numGreaterThan) {
-                numEq++;
-                numEmitted++;
+              } else if (docIdToCount[docID] == finalKthGreatest && numEqEmitted < numEqThreshold) {
+                numEqEmitted++;
+                numTotalEmitted++;
                 return docID;
               }
             }
@@ -142,10 +151,9 @@ public int advance(int target) {

         @Override
         public long cost() {
-          return maxKey - minKey;
+          return finalMaxDocId - finalMinDocId;
         }
       };
     }
   }
 }
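To make the new query-time selection concrete, here is a standalone demo (illustrative only, not part of the PR; it assumes the two-argument constructor and the increment overloads shown above). It reproduces the counts [0, 4, 1, 1, 2] from the deleted kthGreatest() comment and asks for 2 candidates: the threshold loop settles on kthGreatest = 2 with numGreaterThan = 1, so the iterator emits doc 1 (count 4 > 2) and then at most one doc with count exactly 2, which is doc 4.

// Illustrative demo, not part of the PR.
import com.klibisz.elastiknn.search.ArrayHitCounter;
import org.apache.lucene.search.DocIdSetIterator;

import java.io.IOException;

public final class ArrayHitCounterDemo {
  public static void main(String[] args) throws IOException {
    // 5 docs, expected max count 4.
    ArrayHitCounter c = new ArrayHitCounter(5, 4);
    // Produce the counts [0, 4, 1, 1, 2] for docs 0..4.
    c.increment(1, (short) 4);
    c.increment(2);
    c.increment(3);
    c.increment(4, (short) 2);
    // countToCount is now [0, 2, 1, 0, 1] and maxCount = 4.
    DocIdSetIterator it = c.docIdSetIterator(2);
    while (it.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
      System.out.println(it.docID()); // prints 1, then 4
    }
  }
}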
KthGreatestResult.java
@@ -1,23 +1,7 @@
 package com.klibisz.elastiknn.search;

-public class KthGreatestResult {
-  public final short kthGreatest;
-  public final int numGreaterThan;
-  public KthGreatestResult(short kthGreatest, int numGreaterThan) {
-    this.kthGreatest = kthGreatest;
-    this.numGreaterThan = numGreaterThan;
-  }
-
-  @Override
-  public boolean equals(Object o) {
-    if (o == this) {
-      return true;
-    } else if (!(o instanceof KthGreatestResult other)) {
-      return false;
-    } else {
-      return kthGreatest == other.kthGreatest && numGreaterThan == other.numGreaterThan;
-    }
-  }
+public record KthGreatestResult(short kthGreatest, int numGreaterThan) {

   @Override
   public String toString() {
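A note on the conversion: records generate equals, hashCode, and accessors automatically, so the deleted handwritten equals is redundant rather than removed behavior; for example (illustrative), new KthGreatestResult((short) 2, 1).equals(new KthGreatestResult((short) 2, 1)) is true. After this PR, ArrayHitCounter no longer constructs KthGreatestResult at query time; the type appears to survive for the reference implementation exercised in the tests.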
@@ -58,7 +58,7 @@ private HitCounter countHits(LeafReader reader) throws IOException {
     TermsEnum termsEnum = terms.iterator();
     PostingsEnum docs = null;

-    HitCounter counter = new ArrayHitCounter(reader.maxDoc());
+    HitCounter counter = new ArrayHitCounter(reader.maxDoc(), hashAndFrequencies.length);
     for (HashAndFreq hf : hashAndFrequencies) {
       // We take two different paths here, depending on the frequency of the current hash.
       // If the frequency is one, we avoid checking the frequency of matching docs when
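The sizing rationale is not stated in the PR, but a plausible reading: each entry of hashAndFrequencies can raise a matching document's count, so hashAndFrequencies.length serves as an estimate of the maximum per-document count, and incrementKeyByCount grows countToCount on demand if a count ever exceeds it.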
ArrayHitCounterSpec.scala
@@ -94,6 +94,7 @@ final class ArrayHitCounterSpec extends AnyFreeSpec with Matchers {
     val rng = new Random(seed)
     val numDocs = 60000
     val numMatches = numDocs / 2
+    val maxCount = 10
     info(s"Using seed $seed")
     for (_ <- 0 until 99) {
       val matches = (0 until numMatches).map(_ => rng.nextInt(numDocs))
@@ -103,7 +104,7 @@ final class ArrayHitCounterSpec extends AnyFreeSpec with Matchers {
         ref.increment(doc)
         ahc.increment(doc)
         ahc.get(doc) shouldBe ref.get(doc)
-        val count = rng.nextInt(10).toShort
+        val count = rng.nextInt(maxCount).toShort
         ref.increment(doc, count)
         ahc.increment(doc, count)
         ahc.get(doc) shouldBe ref.get(doc)