From 65374d92cbad765643dd1086f45d0c91f5bbf38e Mon Sep 17 00:00:00 2001 From: Nicholas DiPiazza Date: Thu, 8 Oct 2020 17:03:08 -0500 Subject: [PATCH] port of nddipiazza:SOLR-14728 https://github.com/lucidworks/lucene-solr/pull/1 --- .../apache/solr/search/JoinQParserPlugin.java | 21 +- .../solr/search/PostFilterJoinQuery.java | 73 ++--- .../apache/solr/search/TopLevelJoinQuery.java | 269 ++++++++++++++++++ .../join/MultiValueTermOrdinalCollector.java | 65 +++++ 4 files changed, 372 insertions(+), 56 deletions(-) create mode 100644 solr/core/src/java/org/apache/solr/search/TopLevelJoinQuery.java create mode 100644 solr/core/src/java/org/apache/solr/search/join/MultiValueTermOrdinalCollector.java diff --git a/solr/core/src/java/org/apache/solr/search/JoinQParserPlugin.java b/solr/core/src/java/org/apache/solr/search/JoinQParserPlugin.java index 980851a14bc2..ea205382885f 100644 --- a/solr/core/src/java/org/apache/solr/search/JoinQParserPlugin.java +++ b/solr/core/src/java/org/apache/solr/search/JoinQParserPlugin.java @@ -84,8 +84,14 @@ public Query parse() throws SyntaxError { private boolean postFilterEnabled() { return localParams != null && - localParams.getInt(COST) != null && localParams.getPrimitiveInt(COST) > 99 && - localParams.getBool(CACHE) != null && localParams.getPrimitiveBool(CACHE) == false; + localParams.getInt(COST) != null && localParams.getPrimitiveInt(COST) == 100 && + localParams.getBool(CACHE) != null && !localParams.getPrimitiveBool(CACHE); + } + + private boolean topLevelJoinEnabled() { + return localParams != null && + localParams.getInt(COST) != null && localParams.getPrimitiveInt(COST) == 101 && + localParams.getBool(CACHE) != null && !localParams.getPrimitiveBool(CACHE); } Query parseJoin() throws SyntaxError { @@ -131,7 +137,14 @@ Query parseJoin() throws SyntaxError { final String indexToUse = coreName == null ? fromIndex : coreName; - final JoinQuery jq = postFilterEnabled() ? 
new PostFilterJoinQuery(fromField, toField, indexToUse, fromQuery) : new JoinQuery(fromField, toField, indexToUse, fromQuery); + final JoinQuery jq; + if (postFilterEnabled()) { + jq = new PostFilterJoinQuery(fromField, toField, indexToUse, fromQuery); + } else if (topLevelJoinEnabled()) { + jq = new TopLevelJoinQuery(fromField, toField, indexToUse, fromQuery); + } else { + jq = new JoinQuery(fromField, toField, indexToUse, fromQuery); + } jq.fromCoreOpenTime = fromCoreOpenTime; return jq; } @@ -186,7 +199,7 @@ public Weight createWeight(IndexSearcher searcher, boolean needsScores, float bo return new JoinQueryWeight((SolrIndexSearcher)searcher, boost); } - private class JoinQueryWeight extends ConstantScoreWeight { + protected class JoinQueryWeight extends ConstantScoreWeight { SolrIndexSearcher fromSearcher; RefCounted fromRef; SolrIndexSearcher toSearcher; diff --git a/solr/core/src/java/org/apache/solr/search/PostFilterJoinQuery.java b/solr/core/src/java/org/apache/solr/search/PostFilterJoinQuery.java index 11d65d686007..0863639edf6b 100644 --- a/solr/core/src/java/org/apache/solr/search/PostFilterJoinQuery.java +++ b/solr/core/src/java/org/apache/solr/search/PostFilterJoinQuery.java @@ -20,10 +20,11 @@ import java.io.Closeable; import java.io.IOException; import java.lang.invoke.MethodHandles; + import org.apache.lucene.index.DocValues; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.SortedSetDocValues; import org.apache.lucene.index.SortedDocValues; +import org.apache.lucene.index.SortedSetDocValues; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.LeafCollector; import org.apache.lucene.search.Query; @@ -65,51 +66,36 @@ public DelegatingCollector getFilterCollector(IndexSearcher searcher) { ensureJoinFieldExistsAndHasDocValues(fromSearcher, fromField, "from"); ensureJoinFieldExistsAndHasDocValues(toSearcher, toField, "to"); - final boolean isSameCoreSameField = fromSearcher.equals(toSearcher) && 
fromField.equals(toField); - final SortedDocValues fromValues = DocValues.getSorted(fromSearcher.getSlowAtomicReader(), fromField); final SortedSetDocValues toValues = DocValues.getSortedSet(toSearcher.getSlowAtomicReader(), toField); ensureDocValuesAreNonEmpty(fromValues, fromField, "from"); ensureDocValuesAreNonEmpty(toValues, toField, "to"); final LongBitSet fromOrdBitSet = new LongBitSet(fromValues.getValueCount()); + final LongBitSet toOrdBitSet = new LongBitSet(toValues.getValueCount()); final TermOrdinalCollector collector = new TermOrdinalCollector(fromField, fromValues, fromOrdBitSet); - - // runs a search using the from field fromSearcher.search(q, collector); - // goes through each result of the "from" result and converts it to a compatible "to" ordinal. - final LongBitSet toOrdBitSet; - boolean matchesAtLeastOneTerm = false; + long fromOrdinal = 0; long firstToOrd = -1; long lastToOrd = 0; - if (!isSameCoreSameField) { - toOrdBitSet = new LongBitSet(toValues.getValueCount()); - long fromOrdinal = 0; - long start = System.currentTimeMillis(); - int count = 0; - while ((fromOrdinal = fromOrdBitSet.nextSetBit(fromOrdinal)) >= 0) { - ++count; - final BytesRef fromBytesRef = fromValues.lookupOrd((int)fromOrdinal); - final long toOrdinal = lookupTerm(toValues, fromBytesRef, lastToOrd);//toValues.lookupTerm(fromBytesRef); - if (toOrdinal >= 0) { - toOrdBitSet.set(toOrdinal); - if (firstToOrd == -1) firstToOrd = toOrdinal; - lastToOrd = toOrdinal; - matchesAtLeastOneTerm = true; - } - fromOrdinal++; + boolean matchesAtLeastOneTerm = false; + long start = System.currentTimeMillis(); + int count = 0; + while ((fromOrdinal = fromOrdBitSet.nextSetBit(fromOrdinal)) >= 0) { + ++count; + final BytesRef fromBytesRef = fromValues.lookupOrd((int)fromOrdinal); + final long toOrdinal = lookupTerm(toValues, fromBytesRef, lastToOrd);//toValues.lookupTerm(fromBytesRef); + if (toOrdinal >= 0) { + toOrdBitSet.set(toOrdinal); + if (firstToOrd == -1) firstToOrd = toOrdinal; + 
lastToOrd = toOrdinal; + matchesAtLeastOneTerm = true; } - long end = System.currentTimeMillis(); - log.debug("Built the join filter in "+Long.toString(end-start)+" millis, filter term count is "+count); - } else { - matchesAtLeastOneTerm = true; - toOrdBitSet = fromOrdBitSet; - firstToOrd = collector.minOrdinal; - lastToOrd = collector.maxOrdinal; + fromOrdinal++; } - - // at this point, now the toOrdBitSet is set and ready to go + long end = System.currentTimeMillis(); + log.debug("Built the join filter in "+Long.toString(end-start)+" millis, filter term count is "+count); if (matchesAtLeastOneTerm) { return new JoinQueryCollector(toValues, toOrdBitSet, firstToOrd, lastToOrd); } else { @@ -266,8 +252,6 @@ private static class TermOrdinalCollector extends DelegatingCollector { private SortedDocValues topLevelDocValues; private final String fieldName; private final LongBitSet topLevelDocValuesBitSet; - private long minOrdinal = Long.MAX_VALUE; - private long maxOrdinal = Long.MIN_VALUE; public TermOrdinalCollector(String fieldName, SortedDocValues topLevelDocValues, LongBitSet topLevelDocValuesBitSet) { this.fieldName = fieldName; @@ -285,14 +269,6 @@ public boolean needsScores(){ return false; } - public long getMinOrdinal() { - return minOrdinal; - } - - public long getMaxOrdinal() { - return maxOrdinal; - } - @Override public void doSetNextReader(LeafReaderContext context) throws IOException { this.docBase = context.docBase; @@ -303,13 +279,6 @@ public void collect(int doc) throws IOException { final int globalDoc = docBase + doc; if (topLevelDocValues.advanceExact(globalDoc)) { // TODO The use of advanceExact assumes collect() is called in increasing docId order. Is that true? 
- if (topLevelDocValues.ordValue() < minOrdinal) { - minOrdinal = topLevelDocValues.ordValue(); - } - if (topLevelDocValues.ordValue() > maxOrdinal) { - maxOrdinal = topLevelDocValues.ordValue(); - } - topLevelDocValuesBitSet.set(topLevelDocValues.ordValue()); } } @@ -349,8 +318,8 @@ public void collect(int doc) throws IOException { while (true) { final long ord = topLevelDocValues.nextOrd(); if (ord == SortedSetDocValues.NO_MORE_ORDS) break; - if (lastOrd != -1 && ord > lastOrd) break; - if (firstOrd != -1 || ord < firstOrd) continue; + if (ord > lastOrd) break; + if (ord < firstOrd) continue; if (topLevelDocValuesBitSet.get(ord)) { leafCollector.collect(doc); break; diff --git a/solr/core/src/java/org/apache/solr/search/TopLevelJoinQuery.java b/solr/core/src/java/org/apache/solr/search/TopLevelJoinQuery.java new file mode 100644 index 000000000000..bb3fae1b15d1 --- /dev/null +++ b/solr/core/src/java/org/apache/solr/search/TopLevelJoinQuery.java @@ -0,0 +1,269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.solr.search; + +import java.io.IOException; +import java.lang.invoke.MethodHandles; +import java.util.Optional; + +import org.apache.lucene.index.DocValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.SortedSetDocValues; +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.ConstantScoreScorer; +import org.apache.lucene.search.ConstantScoreWeight; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.TwoPhaseIterator; +import org.apache.lucene.search.Weight; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.LongBitSet; +import org.apache.solr.common.SolrException; +import org.apache.solr.schema.IndexSchema; +import org.apache.solr.schema.SchemaField; +import org.apache.solr.search.join.MultiValueTermOrdinalCollector; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * {@link JoinQuery} implementation using global (top-level) DocValues ordinals to efficiently compare values in the + * "from" and "to" fields. + */ +public class TopLevelJoinQuery extends JoinQuery { + private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + + public TopLevelJoinQuery(String fromField, String toField, String coreName, Query subQuery) { + super(fromField, toField, coreName, subQuery); + } + + @Override + public Weight createWeight(IndexSearcher searcher, boolean needsScores, float boost) throws IOException { + if (! 
(searcher instanceof SolrIndexSearcher)) { + log.debug("Falling back to JoinQueryWeight because searcher [{}] is not the required SolrIndexSearcher", searcher); + return super.createWeight(searcher, needsScores, boost); + } + + final SolrIndexSearcher solrSearcher = (SolrIndexSearcher) searcher; + final JoinQueryWeight weight = new JoinQueryWeight(solrSearcher, 1.0f); + final SolrIndexSearcher fromSearcher = weight.fromSearcher; + final SolrIndexSearcher toSearcher = weight.toSearcher; + final boolean isSameCoreSameField = fromSearcher.equals(toSearcher) && fromField.equals(toField); + + try { + final SortedSetDocValues topLevelFromDocValues = validateAndFetchDocValues(fromSearcher, fromField, "from"); + final SortedSetDocValues topLevelToDocValues = validateAndFetchDocValues(toSearcher, toField, "to"); + if (topLevelFromDocValues.getValueCount() == 0 || topLevelToDocValues.getValueCount() == 0) { + return createNoMatchesWeight(boost); + } + + final LongBitSet fromOrdBitSet = findFieldOrdinalsMatchingQuery(q, fromField, fromSearcher, topLevelFromDocValues); + final Optional toBitsetBounds; + final LongBitSet toOrdBitSet; + if (isSameCoreSameField) { + // When the "to" searcher == "from" searcher, and the "to" field == the "from" field, + // we do not need to convert the ordinals on the "to" field to match the ordinals on the "from" field + // because they are the same ordinals. 
+ toOrdBitSet = fromOrdBitSet; + toBitsetBounds = Optional.empty(); + } else { + toOrdBitSet = new LongBitSet(topLevelToDocValues.getValueCount()); + toBitsetBounds = Optional.of(convertFromOrdinalsIntoToField(fromOrdBitSet, topLevelFromDocValues, toOrdBitSet, topLevelToDocValues)); + } + + final boolean toMultivalued = toSearcher.getSchema().getFieldOrNull(toField).multiValued(); + return new ConstantScoreWeight(this, boost) { + public Scorer scorer(LeafReaderContext context) throws IOException { + if (toBitsetBounds.isPresent() && toBitsetBounds.get().lower == BitsetBounds.NO_MATCHES) { + return null; + } + + final DocIdSetIterator toApproximation = (toMultivalued) ? context.reader().getSortedSetDocValues(toField) : + context.reader().getSortedDocValues(toField); + if (toApproximation == null) { + return null; + } + + final int docBase = context.docBase; + return new ConstantScoreScorer(this, this.score(), new TwoPhaseIterator(toApproximation) { + public boolean matches() throws IOException { + final boolean hasDoc = topLevelToDocValues.advanceExact(docBase + approximation.docID()); + if (hasDoc) { + for (long ord = topLevelToDocValues.nextOrd(); ord != -1L; ord = topLevelToDocValues.nextOrd()) { + if (toOrdBitSet.get(ord)) { + return true; + } + } + } + return false; + } + + public float matchCost() { + return 10.0F; + } + }); + + } + + public boolean isCacheable(LeafReaderContext ctx) { + return false; + } + }; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private Weight createNoMatchesWeight(float boost) { + return new ConstantScoreWeight(this, boost) { + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + return null; + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return false; + } + }; + } + + private SortedSetDocValues validateAndFetchDocValues(SolrIndexSearcher solrSearcher, String fieldName, String querySide) throws IOException { + final IndexSchema schema = 
solrSearcher.getSchema(); + final SchemaField field = schema.getFieldOrNull(fieldName); + if (field == null) { + throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, querySide + " field '" + fieldName + "' does not exist"); + } + + if (!field.hasDocValues()) { + throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, + "'top-level' join queries require both 'from' and 'to' fields to have docValues, but " + querySide + + " field [" + fieldName + "] does not."); + } + + final LeafReader leafReader = solrSearcher.getSlowAtomicReader(); + if (field.multiValued()) { + return DocValues.getSortedSet(leafReader, fieldName); + } + return DocValues.singleton(DocValues.getSorted(leafReader, fieldName)); + } + + private static LongBitSet findFieldOrdinalsMatchingQuery(Query q, String field, SolrIndexSearcher searcher, SortedSetDocValues docValues) throws IOException { + final LongBitSet fromOrdBitSet = new LongBitSet(docValues.getValueCount()); + final Collector fromCollector = new MultiValueTermOrdinalCollector(field, docValues, fromOrdBitSet); + + searcher.search(q, fromCollector); + + return fromOrdBitSet; + } + + /** + * When storing ordinals, the ordinal -> DocValue lookup is different per-searcher, and per-core. + * So when doing a join, if the "To" searcher is not the same as the "From" searcher, and/or if the + * "To" field is not the same as the "From" field, we will need to convert the "to" ordinals so that + * they match the "from" ordinals. + * + * For example: If we have FromCore=Core1, FromField=MyStatus + * + *
+   * Ordinal | Term for field "MyStatus"
+   * -------------------
+   * 0       | status_deleted
+   * 1       | status_published
+   * 
+ *

+ * And then if we have ToCore=Core2, ToField=SomeOtherStatus + * + *

+   * Ordinal | Term for field "SomeOtherStatus"
+   * -------------------
+   * 0       | status_cancel
+   * 1       | status_deleted
+   * 2       | status_published
+   * 
+ *

+ * For the term "status_published", the ordinal on "from" is 1, and the ordinal for "to" is 2. + *

+ * This method will iterate through each ordinal and populate a "to" ordinal bit set that is compatible with the + * "from" ordinal bit set. + * + * @param fromOrdBitSet The from ordinal bit set. + * @param fromDocValues The from doc values. + * @param toOrdBitSet The to ordinal bit set. + * @param toDocValues The to doc values. + * @return A boundary of the first + last to ordinal found in the mapping. + */ + private BitsetBounds convertFromOrdinalsIntoToField(LongBitSet fromOrdBitSet, SortedSetDocValues fromDocValues, + LongBitSet toOrdBitSet, SortedSetDocValues toDocValues) throws IOException { + long fromOrdinal = 0; + long firstToOrd = BitsetBounds.NO_MATCHES; + long lastToOrd = 0; + + while (fromOrdinal < fromOrdBitSet.length() && (fromOrdinal = fromOrdBitSet.nextSetBit(fromOrdinal)) >= 0) { + final BytesRef fromBytesRef = fromDocValues.lookupOrd(fromOrdinal); + final long toOrdinal = lookupTerm(toDocValues, fromBytesRef, lastToOrd); + if (toOrdinal >= 0) { + toOrdBitSet.set(toOrdinal); + if (firstToOrd == BitsetBounds.NO_MATCHES) firstToOrd = toOrdinal; + lastToOrd = toOrdinal; + } + fromOrdinal++; + } + + return new BitsetBounds(firstToOrd, lastToOrd); + } + + /* + * Same binary-search based implementation as SortedSetDocValues.lookupTerm(BytesRef), but with an + * optimization to narrow the search space where possible by providing a startOrd instead of beginning each search + * at 0. + */ + private long lookupTerm(SortedSetDocValues docValues, BytesRef key, long startOrd) throws IOException { + long low = startOrd; + long high = docValues.getValueCount()-1; + + while (low <= high) { + long mid = (low + high) >>> 1; + final BytesRef term = docValues.lookupOrd(mid); + int cmp = term.compareTo(key); + + if (cmp < 0) { + low = mid + 1; + } else if (cmp > 0) { + high = mid - 1; + } else { + return mid; // key found + } + } + + return -(low + 1); // key not found. 
+ } + + private static class BitsetBounds { + public static final long NO_MATCHES = -1L; + public final long lower; + public final long upper; + + public BitsetBounds(long lower, long upper) { + this.lower = lower; + this.upper = upper; + } + } +} diff --git a/solr/core/src/java/org/apache/solr/search/join/MultiValueTermOrdinalCollector.java b/solr/core/src/java/org/apache/solr/search/join/MultiValueTermOrdinalCollector.java new file mode 100644 index 000000000000..9bb9295a3fc0 --- /dev/null +++ b/solr/core/src/java/org/apache/solr/search/join/MultiValueTermOrdinalCollector.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.solr.search.join; + +import java.io.IOException; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.SortedSetDocValues; +import org.apache.lucene.search.SimpleCollector; +import org.apache.lucene.util.LongBitSet; + +/** + * Populates a bitset of (top-level) ordinals based on field values in a multi-valued field. 
+ */
+public class MultiValueTermOrdinalCollector extends SimpleCollector {
+
+  private int docBase; // docBase of the leaf currently being collected; added to leaf-local doc ids
+  private final SortedSetDocValues topLevelDocValues; // looked up by global (docBase + doc) id
+  private final String fieldName; // stored but not read anywhere in this class
+  // Records all ordinals found during collection
+  private final LongBitSet topLevelDocValuesBitSet;
+
+  public MultiValueTermOrdinalCollector(String fieldName, SortedSetDocValues topLevelDocValues, LongBitSet topLevelDocValuesBitSet) {
+    this.fieldName = fieldName;
+    this.topLevelDocValues = topLevelDocValues;
+    this.topLevelDocValuesBitSet = topLevelDocValuesBitSet;
+  }
+
+  @Override
+  public boolean needsScores() {
+    return false;
+  }
+
+  @Override
+  public void doSetNextReader(LeafReaderContext context) throws IOException {
+    this.docBase = context.docBase;
+  }
+
+  @Override
+  public void collect(int doc) throws IOException {
+    final int globalDoc = docBase + doc;
+    // NOTE(review): advanceExact presumably requires non-decreasing doc ids across collect() calls - confirm.
+    if (topLevelDocValues.advanceExact(globalDoc)) {
+      // Flag every ordinal this document carries; nextOrd() yields NO_MORE_ORDS when exhausted.
+      for (long ord = topLevelDocValues.nextOrd(); ord != SortedSetDocValues.NO_MORE_ORDS; ord = topLevelDocValues.nextOrd()) {
+        topLevelDocValuesBitSet.set(ord);
+      }
+    }
+  }
+}