Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix CombinedFieldQuery (Lucene 9999) #74857

Merged
merged 1 commit into from
Jul 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* Modifications copyright (C) 2020 Elasticsearch B.V.
*/

package org.apache.lucene.search;

import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.search.XCombinedFieldQuery.FieldAndWeight;
import org.apache.lucene.search.similarities.Similarity.SimScorer;
import org.apache.lucene.util.SmallFloat;

Expand All @@ -28,134 +30,148 @@
import java.util.List;
import java.util.Objects;

import static org.apache.lucene.search.XCombinedFieldQuery.FieldAndWeight;

/**
* Copy of {@link LeafSimScorer} that sums a document's norms from multiple fields.
* Copy of {@link MultiNormsLeafSimScorer} that contains a fix for LUCENE-9999.
* TODO: remove once LUCENE-9999 is fixed and integrated
*
* TODO: this is temporarily copied from Lucene, remove once we update to Lucene 8.9.
* <p>This scorer requires that either all fields or no fields have norms enabled. It will throw an
* error if some fields have norms enabled, while others have norms disabled.
*/
final class XMultiNormsLeafSimScorer {
/**
 * Lookup table of decoded norms: entry {@code i} holds {@code SmallFloat.byte4ToInt((byte) i)},
 * filled once when the class is initialized.
 */
private static final float[] LENGTH_TABLE = new float[256];

static {
  int code = 0;
  while (code < LENGTH_TABLE.length) {
    LENGTH_TABLE[code] = SmallFloat.byte4ToInt((byte) code);
    code++;
  }
}
/** Cache of decoded norms. */
private static final float[] LENGTH_TABLE = new float[256];

private final SimScorer scorer;
private final NumericDocValues norms;

/**
* Sole constructor: Score documents of {@code reader} with {@code scorer}.
*
*/
XMultiNormsLeafSimScorer(SimScorer scorer,
LeafReader reader,
Collection<FieldAndWeight> normFields,
boolean needsScores) throws IOException {
this.scorer = Objects.requireNonNull(scorer);
if (needsScores) {
final List<NumericDocValues> normsList = new ArrayList<>();
final List<Float> weightList = new ArrayList<>();
for (FieldAndWeight field : normFields) {
NumericDocValues norms = reader.getNormValues(field.field);
if (norms != null) {
normsList.add(norms);
weightList.add(field.weight);
}
}
if (normsList.isEmpty()) {
norms = null;
} else if (normsList.size() == 1) {
norms = normsList.get(0);
} else {
final NumericDocValues[] normsArr = normsList.toArray(new NumericDocValues[0]);
final float[] weightArr = new float[normsList.size()];
for (int i = 0; i < weightList.size(); i++) {
weightArr[i] = weightList.get(i);
}
norms = new XMultiNormsLeafSimScorer.MultiFieldNormValues(normsArr, weightArr);
}
} else {
norms = null;
}
static {
for (int i = 0; i < 256; i++) {
LENGTH_TABLE[i] = SmallFloat.byte4ToInt((byte) i);
}

private long getNormValue(int doc) throws IOException {
}

private final SimScorer scorer;
private final NumericDocValues norms;

/** Sole constructor: Score documents of {@code reader} with {@code scorer}. */
XMultiNormsLeafSimScorer(
SimScorer scorer,
LeafReader reader,
Collection<FieldAndWeight> normFields,
boolean needsScores)
throws IOException {
this.scorer = Objects.requireNonNull(scorer);
if (needsScores) {
final List<NumericDocValues> normsList = new ArrayList<>();
final List<Float> weightList = new ArrayList<>();
for (FieldAndWeight field : normFields) {
NumericDocValues norms = reader.getNormValues(field.field);
if (norms != null) {
boolean found = norms.advanceExact(doc);
assert found;
return norms.longValue();
} else {
return 1L; // default norm
normsList.add(norms);
weightList.add(field.weight);
}
}

if (normsList.isEmpty() == false && normsList.size() != normFields.size()) {
throw new IllegalArgumentException(
getClass().getSimpleName()
+ " requires norms to be consistent across fields: some fields cannot"
+ " have norms enabled, while others have norms disabled");
}

if (normsList.isEmpty()) {
norms = null;
} else if (normsList.size() == 1) {
norms = normsList.get(0);
} else {
final NumericDocValues[] normsArr = normsList.toArray(new NumericDocValues[0]);
final float[] weightArr = new float[normsList.size()];
for (int i = 0; i < weightList.size(); i++) {
weightArr[i] = weightList.get(i);
}
norms = new MultiFieldNormValues(normsArr, weightArr);
}
} else {
norms = null;
}

/**
 * Scores {@code doc} for the given term frequency by handing the frequency and the document's
 * norm value to the wrapped {@link SimScorer}.
 *
 * <p>This method must be called on non-decreasing sequences of doc ids.
 *
 * @param doc the document to score
 * @param freq the term document frequency to score with
 * @return the result of {@code scorer.score(freq, norm)}
 * @throws IOException propagated from the norm lookup
 * @see SimScorer#score(float, long)
 */
public float score(int doc, float freq) throws IOException {
  final long norm = getNormValue(doc);
  return scorer.score(freq, norm);
}

private long getNormValue(int doc) throws IOException {
if (norms != null) {
boolean found = norms.advanceExact(doc);
assert found;
return norms.longValue();
} else {
return 1L; // default norm
}

/**
 * Explains the score of {@code doc} by handing the frequency explanation and the document's
 * norm value to the wrapped {@link SimScorer}.
 *
 * <p>This method must be called on non-decreasing sequences of doc ids.
 *
 * @param doc the document whose score is explained
 * @param freqExpl explanation of the term document frequency
 * @return the result of {@code scorer.explain(freqExpl, norm)}
 * @throws IOException propagated from the norm lookup
 * @see SimScorer#explain(Explanation, long)
 */
public Explanation explain(int doc, Explanation freqExpl) throws IOException {
  final long norm = getNormValue(doc);
  return scorer.explain(freqExpl, norm);
}

/**
 * Looks up the norm value for {@code doc} and asks the wrapped {@link SimScorer} to score it
 * together with the supplied term document frequency.
 *
 * <p>Doc ids passed to successive calls must be non-decreasing.
 *
 * @param doc the document to score
 * @param freq the term document frequency
 * @return whatever {@code scorer.score(freq, norm)} returns
 * @throws IOException propagated from the norm lookup
 * @see SimScorer#score(float, long)
 */
public float score(int doc, float freq) throws IOException {
  final long docNorm = getNormValue(doc);
  return scorer.score(freq, docNorm);
}

/**
 * Looks up the norm value for {@code doc} and asks the wrapped {@link SimScorer} to explain the
 * score from it and the supplied frequency explanation.
 *
 * <p>Doc ids passed to successive calls must be non-decreasing.
 *
 * @param doc the document whose score is explained
 * @param freqExpl explanation of the term document frequency
 * @return whatever {@code scorer.explain(freqExpl, norm)} returns
 * @throws IOException propagated from the norm lookup
 * @see SimScorer#explain(Explanation, long)
 */
public Explanation explain(int doc, Explanation freqExpl) throws IOException {
  final long docNorm = getNormValue(doc);
  return scorer.explain(freqExpl, docNorm);
}

private static class MultiFieldNormValues extends NumericDocValues {
private final NumericDocValues[] normsArr;
private final float[] weightArr;
private long current;
private int docID = -1;

/**
 * Keeps references to the per-field norms iterators and their weights. The arrays are stored as
 * given (no copy); {@code weightArr[i]} is the weight applied to {@code normsArr[i]}.
 */
MultiFieldNormValues(NumericDocValues[] normsArr, float[] weightArr) {
  this.weightArr = weightArr;
  this.normsArr = normsArr;
}

private static class MultiFieldNormValues extends NumericDocValues {
private final NumericDocValues[] normsArr;
private final float[] weightArr;
private long current;
private int docID = -1;

/**
 * Captures the per-field norms iterators together with the weight of each field. Both arrays
 * are held by reference, and index {@code i} of one corresponds to index {@code i} of the other.
 */
MultiFieldNormValues(NumericDocValues[] normsArr, float[] weightArr) {
  this.weightArr = weightArr;
  this.normsArr = normsArr;
}

/**
 * Returns the value held in {@code current}, which {@code advanceExact} overwrites on each call.
 */
@Override
public long longValue() {
  return this.current;
}
/** Returns the encoded combined norm last stored in {@code current} by {@code advanceExact}. */
@Override
public long longValue() {
  return this.current;
}

@Override
public boolean advanceExact(int target) throws IOException {
float normValue = 0;
for (int i = 0; i < normsArr.length; i++) {
boolean found = normsArr[i].advanceExact(target);
assert found;
normValue += weightArr[i] * LENGTH_TABLE[Byte.toUnsignedInt((byte) normsArr[i].longValue())];
}
current = SmallFloat.intToByte4(Math.round(normValue));
return true;
/**
 * Positions every per-field norms iterator on {@code target} and folds their values into one
 * encoded norm.
 *
 * <p>For each field that has a norm for {@code target}, the norm byte is decoded through
 * {@code LENGTH_TABLE}, multiplied by that field's weight and added to a running sum. Fields
 * without a value for the document contribute nothing. The rounded sum is re-encoded with
 * {@code SmallFloat.intToByte4} and stored for {@link #longValue()}.
 *
 * @param target the document to advance to
 * @return {@code true} if at least one field has a norm value for {@code target}
 * @throws IOException propagated from the underlying norms iterators
 */
@Override
public boolean advanceExact(int target) throws IOException {
  float weightedLength = 0;
  boolean anyFieldHasNorm = false;
  for (int i = 0; i < normsArr.length; i++) {
    final NumericDocValues fieldNorms = normsArr[i];
    if (fieldNorms.advanceExact(target)) {
      final int normByte = Byte.toUnsignedInt((byte) fieldNorms.longValue());
      weightedLength += weightArr[i] * LENGTH_TABLE[normByte];
      anyFieldHasNorm = true;
    }
  }
  current = SmallFloat.intToByte4(Math.round(weightedLength));
  // Record the position so docID() reports the target instead of staying at its initial -1.
  docID = target;
  return anyFieldHasNorm;
}

/** Returns the value of the {@code docID} field, which is initialized to {@code -1}. */
@Override
public int docID() {
  return this.docID;
}
/** Reports the doc id stored in the {@code docID} field (initially {@code -1}). */
@Override
public int docID() {
  return this.docID;
}

/**
 * Not supported by this class.
 *
 * @throws UnsupportedOperationException always
 */
@Override
public int nextDoc() {
throw new UnsupportedOperationException();
}
/**
 * Sequential iteration is not implemented here; calling this method always fails.
 *
 * @throws UnsupportedOperationException always
 */
@Override
public int nextDoc() {
throw new UnsupportedOperationException();
}

/**
 * Not supported by this class.
 *
 * @param target ignored
 * @throws UnsupportedOperationException always
 */
@Override
public int advance(int target) {
throw new UnsupportedOperationException();
}
/**
 * Advancing via this method is not implemented here; calling it always fails.
 *
 * @param target ignored
 * @throws UnsupportedOperationException always
 */
@Override
public int advance(int target) {
throw new UnsupportedOperationException();
}

/**
 * Not supported by this class.
 *
 * @throws UnsupportedOperationException always
 */
@Override
public long cost() {
throw new UnsupportedOperationException();
}
/**
 * No cost estimate is provided; calling this method always fails.
 *
 * @throws UnsupportedOperationException always
 */
@Override
public long cost() {
throw new UnsupportedOperationException();
}
}
}
Loading