Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add minimal sanity checks to custom/scripted similarities. #33564

Merged
merged 5 commits into from
Sep 19, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,22 @@

package org.elasticsearch.index.similarity;

import org.apache.logging.log4j.LogManager;
import org.apache.lucene.index.FieldInvertState;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.search.similarities.BooleanSimilarity;
import org.apache.lucene.search.similarities.ClassicSimilarity;
import org.apache.lucene.search.similarities.PerFieldSimilarityWrapper;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.search.similarities.Similarity.SimScorer;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.common.TriFunction;
import org.elasticsearch.common.logging.DeprecationLogger;
import org.elasticsearch.common.logging.Loggers;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.AbstractIndexComponent;
import org.elasticsearch.index.IndexModule;
Expand All @@ -44,7 +51,7 @@

public final class SimilarityService extends AbstractIndexComponent {

private static final DeprecationLogger DEPRECATION_LOGGER = new DeprecationLogger(Loggers.getLogger(SimilarityService.class));
private static final DeprecationLogger DEPRECATION_LOGGER = new DeprecationLogger(LogManager.getLogger(SimilarityService.class));
public static final String DEFAULT_SIMILARITY = "BM25";
private static final String CLASSIC_SIMILARITY = "classic";
private static final Map<String, Function<Version, Supplier<Similarity>>> DEFAULTS;
Expand Down Expand Up @@ -132,6 +139,7 @@ public SimilarityService(IndexSettings indexSettings, ScriptService scriptServic
TriFunction<Settings, Version, ScriptService, Similarity> defaultFactory = BUILT_IN.get(typeName);
TriFunction<Settings, Version, ScriptService, Similarity> factory = similarities.getOrDefault(typeName, defaultFactory);
final Similarity similarity = factory.apply(providerSettings, indexSettings.getIndexVersionCreated(), scriptService);
validateSimilarity(indexSettings.getIndexVersionCreated(), similarity);
providers.put(name, () -> similarity);
}
for (Map.Entry<String, Function<Version, Supplier<Similarity>>> entry : DEFAULTS.entrySet()) {
Expand Down Expand Up @@ -182,4 +190,79 @@ public Similarity get(String name) {
return (fieldType != null && fieldType.similarity() != null) ? fieldType.similarity().get() : defaultSimilarity;
}
}

/**
 * Runs the sanity checks on the given similarity, in order: scores are not
 * negative, scores do not decrease as term frequency grows, and scores do not
 * increase as the norm grows.
 *
 * @param indexCreatedVersion version the index was created with; forwarded to every check
 * @param similarity the similarity to validate
 */
static void validateSimilarity(Version indexCreatedVersion, Similarity similarity) {
validateScoresArePositive(indexCreatedVersion, similarity);
validateScoresDoNotDecreaseWithFreq(indexCreatedVersion, similarity);
validateScoresDoNotIncreaseWithNorm(indexCreatedVersion, similarity);
}

/**
 * Scores a synthetic term against a synthetic 20-token field for term
 * frequencies 1 through 10 and reports a failure for every negative score.
 *
 * @param indexCreatedVersion version the index was created with; passed on to {@code fail}
 * @param similarity the similarity under test
 */
private static void validateScoresArePositive(Version indexCreatedVersion, Similarity similarity) {
    // Synthetic field state: length = 20, no overlap.
    final FieldInvertState fieldState = new FieldInvertState(
            indexCreatedVersion.major, "some_field", IndexOptions.DOCS_AND_FREQS, 20, 20, 0, 50, 10, 3);
    final TermStatistics termStatistics = new TermStatistics(new BytesRef("some_value"), 100, 130);
    final CollectionStatistics fieldStatistics = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000);

    final SimScorer simScorer = similarity.scorer(2f, fieldStatistics, termStatistics);
    final long fieldNorm = similarity.computeNorm(fieldState);

    int termFreq = 1;
    while (termFreq <= 10) {
        final float value = simScorer.score(termFreq, fieldNorm);
        if (value < 0) {
            final Explanation details = simScorer.explain(Explanation.match(termFreq, "term freq"), fieldNorm);
            fail(indexCreatedVersion, "Similarities should not return negative scores:\n" + details);
        }
        termFreq++;
    }
}

/**
 * Scores a synthetic term against a synthetic 20-token field for term
 * frequencies 1 through 10 and reports a failure whenever a score is lower
 * than the one obtained for the previous frequency (the baseline before the
 * first frequency is 0).
 *
 * @param indexCreatedVersion version the index was created with; passed on to {@code fail}
 * @param similarity the similarity under test
 */
private static void validateScoresDoNotDecreaseWithFreq(Version indexCreatedVersion, Similarity similarity) {
    // Synthetic field state: length = 20, no overlap.
    final FieldInvertState fieldState = new FieldInvertState(
            indexCreatedVersion.major, "some_field", IndexOptions.DOCS_AND_FREQS, 20, 20, 0, 50, 10, 3);
    final TermStatistics termStatistics = new TermStatistics(new BytesRef("some_value"), 100, 130);
    final CollectionStatistics fieldStatistics = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000);

    final SimScorer simScorer = similarity.scorer(2f, fieldStatistics, termStatistics);
    final long fieldNorm = similarity.computeNorm(fieldState);

    float lastScore = 0;
    int termFreq = 1;
    while (termFreq <= 10) {
        final float current = simScorer.score(termFreq, fieldNorm);
        if (current < lastScore) {
            // Show both sides of the comparison: the previous frequency and the current one.
            final Explanation before = simScorer.explain(Explanation.match(termFreq - 1, "term freq"), fieldNorm);
            final Explanation after = simScorer.explain(Explanation.match(termFreq, "term freq"), fieldNorm);
            fail(indexCreatedVersion, "Similarity scores should not decrease when term frequency increases:\n" +
                    before + "\n" +
                    after);
        }
        lastScore = current;
        termFreq++;
    }
}

/**
 * Scores a synthetic term with a term frequency of 1 against synthetic fields
 * of length 1 through 10 and reports a failure whenever the score is higher
 * than the one obtained for the previous (shorter) field.
 *
 * <p>The check stops silently as soon as the computed norm, compared as an
 * unsigned long, is smaller than the previous one, since the check assumes
 * norms that do not decrease with field length.
 *
 * @param indexCreatedVersion version the index was created with; passed on to {@code fail}
 * @param similarity the similarity under test
 */
private static void validateScoresDoNotIncreaseWithNorm(Version indexCreatedVersion, Similarity similarity) {
    CollectionStatistics collectionStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000);
    TermStatistics termStats = new TermStatistics(new BytesRef("some_value"), 100, 130);
    SimScorer scorer = similarity.scorer(2f, collectionStats, termStats);

    long previousNorm = 0;
    float previousScore = Float.MAX_VALUE;
    for (int length = 1; length <= 10; ++length) {
        FieldInvertState state = new FieldInvertState(indexCreatedVersion.major, "some_field",
                IndexOptions.DOCS_AND_FREQS, length, length, 0, 50, 10, 3); // length varies from 1 to 10, no overlap
        final long norm = similarity.computeNorm(state);
        if (Long.compareUnsigned(previousNorm, norm) > 0) {
            // esoteric similarity, skip this check
            break;
        }
        float score = scorer.score(1, norm);
        if (score > previousScore) {
            // Explain the two scores that were actually compared: the one for the
            // previous norm and the one for the current norm. The norms are not
            // guaranteed to be consecutive, so "norm - 1" could describe a score
            // that was never computed.
            fail(indexCreatedVersion, "Similarity scores should not increase when norm increases:\n" +
                    scorer.explain(Explanation.match(1, "term freq"), previousNorm) + "\n" +
                    scorer.explain(Explanation.match(1, "term freq"), norm));
        }
        previousScore = score;
        previousNorm = norm;
    }
}

/**
 * Reports a validation failure with a severity that depends on the version
 * the index was created with: an {@link IllegalArgumentException} on or after
 * 7.0.0-alpha1, a deprecation log entry on or after 6.5.0, and nothing for
 * older versions.
 *
 * @param indexCreatedVersion version the index was created with
 * @param message description of the violated expectation
 * @throws IllegalArgumentException if the index was created on or after 7.0.0-alpha1
 */
private static void fail(Version indexCreatedVersion, String message) {
    final boolean rejectOutright = indexCreatedVersion.onOrAfter(Version.V_7_0_0_alpha1);
    if (rejectOutright) {
        throw new IllegalArgumentException(message);
    }
    if (indexCreatedVersion.onOrAfter(Version.V_6_5_0)) {
        DEPRECATION_LOGGER.deprecated(message);
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,18 @@
*/
package org.elasticsearch.index.similarity;

import org.apache.lucene.index.FieldInvertState;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.search.similarities.BooleanSimilarity;
import org.apache.lucene.search.similarities.Similarity;
import org.elasticsearch.Version;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.IndexSettings;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.IndexSettingsModule;
import org.hamcrest.Matchers;

import java.util.Collections;

Expand Down Expand Up @@ -56,4 +62,75 @@ public void testOverrideDefaultSimilarity() {
SimilarityService service = new SimilarityService(indexSettings, null, Collections.emptyMap());
assertTrue(service.getDefaultSimilarity() instanceof BooleanSimilarity);
}

/**
 * Checks that {@code SimilarityService.validateSimilarity} rejects, for an
 * index created on 7.0.0-alpha1, similarities that return negative scores,
 * scores that decrease with term frequency, or scores that increase with the
 * norm. Every stub similarity uses the field length as its norm.
 */
public void testSimilarityValidation() {
    // Case 1: every score is negative.
    final Similarity alwaysNegative = new Similarity() {
        @Override
        public long computeNorm(FieldInvertState state) {
            return state.getLength();
        }

        @Override
        public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
            return new SimScorer() {
                @Override
                public float score(float freq, long norm) {
                    return -1;
                }
            };
        }
    };
    final IllegalArgumentException negativeFailure = expectThrows(IllegalArgumentException.class,
            () -> SimilarityService.validateSimilarity(Version.V_7_0_0_alpha1, alwaysNegative));
    assertThat(negativeFailure.getMessage(),
            Matchers.containsString("Similarities should not return negative scores"));

    // Case 2: the score shrinks as the term frequency grows.
    final Similarity shrinksWithFreq = new Similarity() {
        @Override
        public long computeNorm(FieldInvertState state) {
            return state.getLength();
        }

        @Override
        public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
            return new SimScorer() {
                @Override
                public float score(float freq, long norm) {
                    return 1 / (freq + norm);
                }
            };
        }
    };
    final IllegalArgumentException freqFailure = expectThrows(IllegalArgumentException.class,
            () -> SimilarityService.validateSimilarity(Version.V_7_0_0_alpha1, shrinksWithFreq));
    assertThat(freqFailure.getMessage(),
            Matchers.containsString("Similarity scores should not decrease when term frequency increases"));

    // Case 3: the score grows as the norm grows.
    final Similarity growsWithNorm = new Similarity() {
        @Override
        public long computeNorm(FieldInvertState state) {
            return state.getLength();
        }

        @Override
        public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
            return new SimScorer() {
                @Override
                public float score(float freq, long norm) {
                    return freq + norm;
                }
            };
        }
    };
    final IllegalArgumentException normFailure = expectThrows(IllegalArgumentException.class,
            () -> SimilarityService.validateSimilarity(Version.V_7_0_0_alpha1, growsWithNorm));
    assertThat(normFailure.getMessage(),
            Matchers.containsString("Similarity scores should not increase when norm increases"));
}
}