From a222e8bbc6b90d15b6718091b61a017a9548a0fc Mon Sep 17 00:00:00 2001 From: Ketan Verma Date: Mon, 22 May 2023 09:50:33 +0530 Subject: [PATCH] Improvements to LongHash Signed-off-by: Ketan Verma --- .../opensearch/common/LongHashBenchmark.java | 153 ++++++++++++++++++ .../opensearch/common/util/LongRHHash.java | 153 ++++++++++++++++++ .../bucket/terms/LongKeyedBucketOrds.java | 6 +- .../common/util/LongRHHashTests.java | 63 ++++++++ 4 files changed, 372 insertions(+), 3 deletions(-) create mode 100644 benchmarks/src/main/java/org/opensearch/common/LongHashBenchmark.java create mode 100644 server/src/main/java/org/opensearch/common/util/LongRHHash.java create mode 100644 server/src/test/java/org/opensearch/common/util/LongRHHashTests.java diff --git a/benchmarks/src/main/java/org/opensearch/common/LongHashBenchmark.java b/benchmarks/src/main/java/org/opensearch/common/LongHashBenchmark.java new file mode 100644 index 0000000000000..ab17bde3e7507 --- /dev/null +++ b/benchmarks/src/main/java/org/opensearch/common/LongHashBenchmark.java @@ -0,0 +1,153 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common; + +import org.openjdk.jmh.annotations.*; +import org.openjdk.jmh.infra.Blackhole; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.BigArrays; +import org.opensearch.common.util.LongHash; +import org.opensearch.common.util.LongRHHash; +import org.opensearch.common.util.PageCacheRecycler; + +import java.util.Arrays; +import java.util.Random; +import java.util.concurrent.TimeUnit; + +@Fork(3) +@Warmup(iterations = 0) +@Measurement(iterations = 3) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +public class LongHashBenchmark { + + private static final int ITERATIONS = (1 << 24); + + private static final BigArrays BIG_ARRAYS = new BigArrays( + new PageCacheRecycler(Settings.EMPTY), + null, + "REQUEST" + ); + + @State(Scope.Benchmark) + public static class BenchmarkState { + @Param({ + "1", "2", "3", "6", "8", "10", "13", "14", "15", "18", "20", "24", "27", "31", "35", "38", "40", "42", "46", + "48", "53", "60", "68", "73", "77", "81", "86", "93", "100", "104", "109", "118", "130", "136", "149", "162", + "168", "174", "186", "201", "210", "229", "248", "261", "274", "284", "296", "313", "336", "352", "377", + "388", "408", "431", "449", "469", "488", "502", "520", "535", "545", "607", "647", "733", "764", "826", + "866", "881", "953", "968", "1017", "1032", "1078", "1132", "1196", "1209", "1263", "1358", "1390", "1444", + "1552", "2304", "2466", "3175", "3675", "4102", "4635", "4751", "5469", "6187", "6935", "7072", "7404", + "7808", "8336", "9037", "9154", "9844", "9996", "10287" + }) + public int size; + + @Param({"1", "128", "1024"}) + public long capacity; + + @Param({"0.6"}) + public float loadFactor; + + public long[] data; + public long[] queries; + + // Pre-populated instances to benchmark the 'find' method. + public LongHash lh; + public LongRHHash rh; + + @Setup(Level.Trial) + public void setUp() { + data = nyc_taxis(); + // data = http_logs(); + // data = random(); + + // Fisher-Yates shuffle. + // This will avoid hitting the cache lines which would otherwise unfairly favour + // the naive linear-probing algorithm. + Random random = new Random(0); + queries = Arrays.copyOf(data, size); + for (int i = size - 1; i > 0; i--) { + int j = Math.abs(random.nextInt()) % (i + 1); + long temp = queries[i]; + queries[i] = queries[j]; + queries[j] = temp; + } + + lh = new LongHash(capacity, loadFactor, BIG_ARRAYS); + rh = new LongRHHash(capacity, loadFactor, BIG_ARRAYS); + for (int i = 0; i < data.length * 2; i++) { + lh.add(data[i % data.length]); + rh.add(data[i % data.length]); + } + } + + private long[] nyc_taxis() { + long[] data = new long[size]; + for (int i = 0; i < size; i++) { + data[i] = 1420070400000L + 86400000L * i; + } + return data; + } + + private long[] http_logs() { + long[] data = new long[size]; + for (int i = 0; i < size; i++) { + data[i] = 893962800000L + 3600000L * i; + } + return data; + } + + private long[] random() { + Random random = new Random(0); + long[] data = new long[size]; + for (int i = 0; i < size; i++) { + data[i] = random.nextLong(); + } + return data; + } + } + + /* Benchmarks for the 'add' method. */ + @Benchmark + public void baselineAdd(Blackhole bh, BenchmarkState s) { + try (LongHash h = new LongHash(s.capacity, s.loadFactor, BIG_ARRAYS)) { + for (int i = 0; i < ITERATIONS; i++) { + long key = s.queries[i % s.queries.length]; + bh.consume(h.add(key)); + } + } + } + + @Benchmark + public void contenderAdd(Blackhole bh, BenchmarkState s) { + try (LongRHHash h = new LongRHHash(s.capacity, s.loadFactor, BIG_ARRAYS)) { + for (int i = 0; i < ITERATIONS; i++) { + long key = s.queries[i % s.queries.length]; + bh.consume(h.add(key)); + } + } + } + + /* Benchmarks for the 'find' method. */ + @Benchmark + public void baselineFind(Blackhole bh, BenchmarkState s) { + for (int i = 0; i < ITERATIONS; i++) { + long key = s.queries[i % s.queries.length]; + bh.consume(s.lh.find(key)); + } + } + + @Benchmark + public void contenderFind(Blackhole bh, BenchmarkState s) { + for (int i = 0; i < ITERATIONS; i++) { + long key = s.queries[i % s.queries.length]; + bh.consume(s.rh.find(key)); + } + } +} diff --git a/server/src/main/java/org/opensearch/common/util/LongRHHash.java b/server/src/main/java/org/opensearch/common/util/LongRHHash.java new file mode 100644 index 0000000000000..9b827ad6664ef --- /dev/null +++ b/server/src/main/java/org/opensearch/common/util/LongRHHash.java @@ -0,0 +1,153 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.util; + +import com.carrotsearch.hppc.BitMixer; +import org.opensearch.common.lease.Releasable; + + +public class LongRHHash implements Releasable { + private final float loadFactor; + private final BigArrays bigArrays; + + private long capacity; + private long mask; + private long size; + private long grow; + + private LongArray ords; + private LongArray keys; + + public LongRHHash(long capacity, float loadFactor, BigArrays bigArrays) { + capacity = (long) (capacity / loadFactor); + capacity = Math.max(1, Long.highestOneBit(capacity - 1) << 1); // next power of two + + this.loadFactor = loadFactor; + this.bigArrays = bigArrays; + this.capacity = capacity; + this.mask = capacity - 1; + this.size = 0; + this.grow = (long) (capacity * loadFactor); + + this.ords = bigArrays.newLongArray(capacity, false); + this.ords.fill(0, capacity, -1); + this.keys = bigArrays.newLongArray(capacity, false); + } + + public long add(final long key) { + final long found = find(key); + if (found != -1) { + return -(1 + found); + } + + if (size >= grow) { + grow(); + } + + return set(key, size); + } + + private long set(final long key, final long ordinal) { + long idx = slot(key), ord = ordinal, psl = 0; + long curOrd, curPsl; + + do { + if ((curOrd = ords.get(idx)) == -1) { + ords.set(idx, ord); + keys = bigArrays.grow(keys, size + 1); + keys.set(ordinal, key); + return size++; + } else if ((curPsl = psl(keys.get(curOrd), idx)) < psl) { + ord = ords.set(idx, ord); + psl = curPsl; + } + idx = (idx + 1) & mask; + psl++; + } while (true); + } + + public long get(final long ordinal) { + return keys.get(ordinal); + } + + public long find(final long key) { + for (long idx = slot(key);; idx = (idx + 1) & mask) { + final long ord = ords.get(idx); + if (ord == -1 || keys.get(ord) == key) { + return ord; + } + } + } + + private long slot(final long key) { + return BitMixer.mix64(key) & mask; + } + + private long psl(final long key, final long idx) { + return (capacity + idx - slot(key)) & mask; + } + + public long size() { + return size; + } + + public long maxPsl() { + long maxPsl = 0; + + for (long idx = 0; idx < capacity; idx++) { + long ordinal = ords.get(idx); + if (ordinal == -1) { + continue; + } + + long key = keys.get(ordinal); + maxPsl = Math.max(maxPsl, psl(key, idx)); + } + + return maxPsl; + } + + public double avgPsl() { + long pslSum = 0; + + for (long idx = 0; idx < capacity; idx++) { + long ordinal = ords.get(idx); + if (ordinal == -1) { + continue; + } + + long key = keys.get(ordinal); + pslSum += psl(key, idx); + } + + return (double) pslSum / size; + } + + private void grow() { + final long oldSize = size; + + capacity <<= 1; + mask = capacity - 1; + size = 0; + grow = (long) (capacity * loadFactor); + + ords = bigArrays.resize(ords, capacity); + ords.fill(0, capacity, -1); + + for (long ordinal = 0; ordinal < oldSize; ordinal++) { + set(keys.get(ordinal), ordinal); + } + } + + @Override + public void close() { + ords.close(); + keys.close(); + } +} diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/LongKeyedBucketOrds.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/LongKeyedBucketOrds.java index bcf77ee194ea4..a357ddcd0e3e6 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/LongKeyedBucketOrds.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/LongKeyedBucketOrds.java @@ -34,8 +34,8 @@ import org.opensearch.common.lease.Releasable; import org.opensearch.common.util.BigArrays; -import org.opensearch.common.util.LongHash; import org.opensearch.common.util.LongLongHash; +import org.opensearch.common.util.LongRHHash; import org.opensearch.search.aggregations.CardinalityUpperBound; /** @@ -148,10 +148,10 @@ public long value() { * @opensearch.internal */ public static class FromSingle extends LongKeyedBucketOrds { - private final LongHash ords; + private final LongRHHash ords; public FromSingle(BigArrays bigArrays) { - ords = new LongHash(1, bigArrays); + ords = new LongRHHash(1, 0.6f, bigArrays); } @Override diff --git a/server/src/test/java/org/opensearch/common/util/LongRHHashTests.java b/server/src/test/java/org/opensearch/common/util/LongRHHashTests.java new file mode 100644 index 0000000000000..46675ad5da645 --- /dev/null +++ b/server/src/test/java/org/opensearch/common/util/LongRHHashTests.java @@ -0,0 +1,63 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.util; + +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Map; +import java.util.TreeMap; + +public class LongRHHashTests extends OpenSearchTestCase { + + public void testFuzzy() { + Map reference = new TreeMap<>(); + LongRHHash h = new LongRHHash(1, 0.6f, BigArrays.NON_RECYCLING_INSTANCE); + + for (int i = 0; i < (1 << 20); i++) { + long key = randomLong() % (1 << 12); + if (reference.containsKey(key)) { + long expectedOrdinal = reference.get(key); + assertEquals(-1-expectedOrdinal, h.add(key)); + assertEquals(expectedOrdinal, h.find(key)); + } else { + assertEquals(-1, h.find(key)); + reference.put(key, (long) reference.size()); + assertEquals((long) reference.get(key), h.add(key)); + } + } + + h.close(); + } + + public void testReport() { + int[] sizes = new int[]{ + 1, 2, 3, 6, 8, 10, 13, 14, 15, 18, 20, 24, 27, 31, 35, 38, 40, 42, 46, + 48, 53, 60, 68, 73, 77, 81, 86, 93, 100, 104, 109, 118, 130, 136, 149, 162, + 168, 174, 186, 201, 210, 229, 248, 261, 274, 284, 296, 313, 336, 352, 377, + 388, 408, 431, 449, 469, 488, 502, 520, 535, 545, 607, 647, 733, 764, 826, + 866, 881, 953, 968, 1017, 1032, 1078, 1132, 1196, 1209, 1263, 1358, 1390, 1444, + 1552, 2304, 2466, 3175, 3675, 4102, 4635, 4751, 5469, 6187, 6935, 7072, 7404, + 7808, 8336, 9037, 9154, 9844, 9996, 10287 + }; + + for (int size : sizes) { + try (LongRHHash h = new LongRHHash(1, 0.6f, BigArrays.NON_RECYCLING_INSTANCE)) { + for (int i = 0; i < size; i++) { + long key = 1420070400000L + 86400000L * i; + h.add(key); + } + for (int i = 0; i < size; i++) { + long key = 1420070400000L + 86400000L * i; + h.add(key); + } + System.out.println("size: " + size + ", max_psl: " + h.maxPsl() + ", avg_psl: " + h.avgPsl()); + } + } + } +}