Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add minimal sanity checks to custom/scripted similarities. #33564

Merged
merged 5 commits into from
Sep 19, 2018
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.elasticsearch.index.similarity;

import org.apache.lucene.index.FieldInvertState;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.search.similarities.Similarity;

/**
 * A {@link Similarity} decorator that refuses negative scores coming out of the wrapped
 * similarity. Failing loudly is preferred over silently returning corrupt top hits, so this
 * wrapper is meant to be put around every custom or scripted similarity.
 */
// public for testing
public final class NonNegativeScoresSimilarity extends Similarity {

    // Escape hatch: system property that can only be left unset or set to "false"
    private static final String ES_ENFORCE_POSITIVE_SCORES = "es.enforce.positive.scores";
    private static final boolean ENFORCE_POSITIVE_SCORES;
    static {
        final String configured = System.getProperty(ES_ENFORCE_POSITIVE_SCORES);
        if (configured != null && "false".equals(configured) == false) {
            throw new IllegalArgumentException(ES_ENFORCE_POSITIVE_SCORES + " may only be unset or set to [false], but got [" +
                    configured + "]");
        }
        // Past the guard the property is either unset (enforce) or exactly "false" (do not enforce).
        ENFORCE_POSITIVE_SCORES = configured == null;
    }

    private final Similarity in;

    /**
     * @param in the similarity whose scores are checked
     */
    public NonNegativeScoresSimilarity(Similarity in) {
        this.in = in;
    }

    /**
     * @return the wrapped similarity
     */
    public Similarity getDelegate() {
        return in;
    }

    @Override
    public long computeNorm(FieldInvertState state) {
        // Norms are not altered, only scores are checked.
        return in.computeNorm(state);
    }

    @Override
    public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
        final SimScorer inScorer = in.scorer(boost, collectionStats, termStats);
        return new SimScorer() {

            /**
             * Returns the wrapped score unchanged unless it is negative; a negative score either
             * raises an {@link IllegalArgumentException} or, when enforcement is switched off,
             * is replaced by zero.
             */
            @Override
            public float score(float freq, long norm) {
                final float wrappedScore = inScorer.score(freq, norm);
                if ((wrappedScore < 0f) == false) {
                    return wrappedScore;
                }
                if (ENFORCE_POSITIVE_SCORES == false) {
                    return 0f;
                }
                throw new IllegalArgumentException("Similarities must not produce negative scores, but got:\n" +
                        inScorer.explain(Explanation.match(freq, "term frequency"), norm));
            }

            /**
             * Returns the wrapped explanation, or, when that explanation is a match with a
             * negative value, a "max of" explanation with value zero that nests it.
             */
            @Override
            public Explanation explain(Explanation freq, long norm) {
                final Explanation wrapped = inScorer.explain(freq, norm);
                final boolean negativeMatch = wrapped.isMatch() && wrapped.getValue().floatValue() < 0;
                if (negativeMatch == false) {
                    return wrapped;
                }
                return Explanation.match(0f, "max of:",
                        wrapped, Explanation.match(0f, "Minimum allowed score"));
            }
        };
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,22 @@

package org.elasticsearch.index.similarity;

import org.apache.logging.log4j.LogManager;
import org.apache.lucene.index.FieldInvertState;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.search.similarities.BM25Similarity;
import org.apache.lucene.search.similarities.BooleanSimilarity;
import org.apache.lucene.search.similarities.ClassicSimilarity;
import org.apache.lucene.search.similarities.PerFieldSimilarityWrapper;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.search.similarities.Similarity.SimScorer;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.common.TriFunction;
import org.elasticsearch.common.logging.DeprecationLogger;
import org.elasticsearch.common.logging.Loggers;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.AbstractIndexComponent;
import org.elasticsearch.index.IndexModule;
Expand All @@ -44,7 +51,7 @@

public final class SimilarityService extends AbstractIndexComponent {

private static final DeprecationLogger DEPRECATION_LOGGER = new DeprecationLogger(Loggers.getLogger(SimilarityService.class));
private static final DeprecationLogger DEPRECATION_LOGGER = new DeprecationLogger(LogManager.getLogger(SimilarityService.class));
public static final String DEFAULT_SIMILARITY = "BM25";
private static final String CLASSIC_SIMILARITY = "classic";
private static final Map<String, Function<Version, Supplier<Similarity>>> DEFAULTS;
Expand Down Expand Up @@ -131,8 +138,14 @@ public SimilarityService(IndexSettings indexSettings, ScriptService scriptServic
}
TriFunction<Settings, Version, ScriptService, Similarity> defaultFactory = BUILT_IN.get(typeName);
TriFunction<Settings, Version, ScriptService, Similarity> factory = similarities.getOrDefault(typeName, defaultFactory);
final Similarity similarity = factory.apply(providerSettings, indexSettings.getIndexVersionCreated(), scriptService);
providers.put(name, () -> similarity);
Similarity similarity = factory.apply(providerSettings, indexSettings.getIndexVersionCreated(), scriptService);
validateSimilarity(indexSettings.getIndexVersionCreated(), similarity);
if (BUILT_IN.containsKey(typeName) == false || "scripted".equals(typeName)) {
// We don't trust custom similarities
similarity = new NonNegativeScoresSimilarity(similarity);
}
final Similarity similarityF = similarity; // like similarity but final
providers.put(name, () -> similarityF);
}
for (Map.Entry<String, Function<Version, Supplier<Similarity>>> entry : DEFAULTS.entrySet()) {
providers.put(entry.getKey(), entry.getValue().apply(indexSettings.getIndexVersionCreated()));
Expand All @@ -151,7 +164,7 @@ public Similarity similarity(MapperService mapperService) {
defaultSimilarity;
}


public SimilarityProvider getSimilarity(String name) {
Supplier<Similarity> sim = similarities.get(name);
if (sim == null) {
Expand Down Expand Up @@ -182,4 +195,80 @@ public Similarity get(String name) {
return (fieldType != null && fieldType.similarity() != null) ? fieldType.similarity().get() : defaultSimilarity;
}
}

/**
 * Runs the sanity checks on a similarity, in this order: scores are not negative, scores do
 * not decrease when the term frequency grows, and scores do not increase when the norm grows.
 * Each check reports a violation through {@code fail}, which decides from
 * {@code indexCreatedVersion} whether to throw, log a deprecation, or do nothing.
 *
 * @param indexCreatedVersion version the index was created with
 * @param similarity          the similarity to check
 */
static void validateSimilarity(Version indexCreatedVersion, Similarity similarity) {
validateScoresArePositive(indexCreatedVersion, similarity);
validateScoresDoNotDecreaseWithFreq(indexCreatedVersion, similarity);
validateScoresDoNotIncreaseWithNorm(indexCreatedVersion, similarity);
}

/**
 * Probes the similarity with made-up statistics and a fixed norm, and reports a failure for
 * every term frequency from 1 to 10 whose score is negative.
 *
 * @param indexCreatedVersion version the index was created with; forwarded to {@code fail}
 * @param similarity          the similarity to probe
 */
private static void validateScoresArePositive(Version indexCreatedVersion, Similarity similarity) {
    // Arbitrary sample statistics, only needed to obtain a scorer.
    final CollectionStatistics sampleCollection = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000);
    final TermStatistics sampleTerm = new TermStatistics(new BytesRef("some_value"), 100, 130);
    final SimScorer probe = similarity.scorer(2f, sampleCollection, sampleTerm);

    // length = 20, no overlap
    final FieldInvertState sampleState = new FieldInvertState(indexCreatedVersion.major, "some_field",
            IndexOptions.DOCS_AND_FREQS, 20, 20, 0, 50, 10, 3);
    final long norm = similarity.computeNorm(sampleState);

    for (int freq = 1; freq <= 10; ++freq) {
        final float observed = probe.score(freq, norm);
        if ((observed < 0) == false) {
            continue;
        }
        final String details = "Similarities should not return negative scores:\n" +
                probe.explain(Explanation.match(freq, "term freq"), norm);
        fail(indexCreatedVersion, details);
    }
}

/**
 * Probes the similarity with made-up statistics and a fixed norm, walking the term frequency
 * from 1 to 10 and reporting a failure whenever a score is lower than the one before it.
 * The comparison for frequency 1 is made against an initial score of zero.
 *
 * @param indexCreatedVersion version the index was created with; forwarded to {@code fail}
 * @param similarity          the similarity to probe
 */
private static void validateScoresDoNotDecreaseWithFreq(Version indexCreatedVersion, Similarity similarity) {
    // Arbitrary sample statistics, only needed to obtain a scorer.
    final CollectionStatistics sampleCollection = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000);
    final TermStatistics sampleTerm = new TermStatistics(new BytesRef("some_value"), 100, 130);
    final SimScorer probe = similarity.scorer(2f, sampleCollection, sampleTerm);

    // length = 20, no overlap
    final FieldInvertState sampleState = new FieldInvertState(indexCreatedVersion.major, "some_field",
            IndexOptions.DOCS_AND_FREQS, 20, 20, 0, 50, 10, 3);
    final long norm = similarity.computeNorm(sampleState);

    float lastScore = 0;
    int freq = 1;
    while (freq <= 10) {
        final float current = probe.score(freq, norm);
        if (current < lastScore) {
            final String details = "Similarity scores should not decrease when term frequency increases:\n" +
                    probe.explain(Explanation.match(freq - 1, "term freq"), norm) + "\n" +
                    probe.explain(Explanation.match(freq, "term freq"), norm);
            fail(indexCreatedVersion, details);
        }
        lastScore = current;
        freq++;
    }
}

/**
 * Probes the similarity with made-up statistics at term frequency 1 while the field length
 * grows from 1 to 10, and reports a failure whenever a score is higher than the score obtained
 * for the previous norm. The check stops early if the computed norms are not non-decreasing
 * when compared as unsigned longs, because the comparison is then meaningless.
 *
 * @param indexCreatedVersion version the index was created with; forwarded to {@code fail}
 * @param similarity          the similarity to probe
 */
private static void validateScoresDoNotIncreaseWithNorm(Version indexCreatedVersion, Similarity similarity) {
    CollectionStatistics collectionStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000);
    TermStatistics termStats = new TermStatistics(new BytesRef("some_value"), 100, 130);
    SimScorer scorer = similarity.scorer(2f, collectionStats, termStats);

    long previousNorm = 0;
    float previousScore = Float.MAX_VALUE;
    for (int length = 1; length <= 10; ++length) {
        FieldInvertState state = new FieldInvertState(indexCreatedVersion.major, "some_field",
                IndexOptions.DOCS_AND_FREQS, length, length, 0, 50, 10, 3); // length grows with each iteration, no overlap
        final long norm = similarity.computeNorm(state);
        if (Long.compareUnsigned(previousNorm, norm) > 0) {
            // esoteric similarity, skip this check
            break;
        }
        float score = scorer.score(1, norm);
        if (score > previousScore) {
            // Explain the norm that actually produced previousScore; it is not necessarily norm - 1.
            fail(indexCreatedVersion, "Similarity scores should not increase when norm increases:\n" +
                    scorer.explain(Explanation.match(1, "term freq"), previousNorm) + "\n" +
                    scorer.explain(Explanation.match(1, "term freq"), norm));
        }
        previousScore = score;
        previousNorm = norm;
    }
}

/**
 * Reports a failed sanity check. Indices created on or after 7.0.0-alpha1 get an
 * {@link IllegalArgumentException}; indices created on or after 6.5.0 only get a deprecation
 * log entry; for older indices nothing happens.
 *
 * @param indexCreatedVersion version the index was created with
 * @param message             description of the violated expectation
 * @throws IllegalArgumentException if the index was created on or after 7.0.0-alpha1
 */
private static void fail(Version indexCreatedVersion, String message) {
    final boolean strict = indexCreatedVersion.onOrAfter(Version.V_7_0_0_alpha1);
    if (strict) {
        throw new IllegalArgumentException(message);
    }
    if (indexCreatedVersion.onOrAfter(Version.V_6_5_0)) {
        DEPRECATION_LOGGER.deprecated(message);
    }
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import org.elasticsearch.index.shard.IndexingOperationListener;
import org.elasticsearch.index.shard.SearchOperationListener;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.index.similarity.NonNegativeScoresSimilarity;
import org.elasticsearch.index.similarity.SimilarityService;
import org.elasticsearch.index.store.IndexStore;
import org.elasticsearch.indices.IndicesModule;
Expand All @@ -77,6 +78,7 @@
import org.elasticsearch.test.engine.MockEngineFactory;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.hamcrest.Matchers;

import java.io.IOException;
import java.util.Collections;
Expand Down Expand Up @@ -295,10 +297,13 @@ public void testAddSimilarity() throws IOException {

IndexService indexService = newIndexService(module);
SimilarityService similarityService = indexService.similarityService();
assertNotNull(similarityService.getSimilarity("my_similarity"));
assertTrue(similarityService.getSimilarity("my_similarity").get() instanceof TestSimilarity);
Similarity similarity = similarityService.getSimilarity("my_similarity").get();
assertNotNull(similarity);
assertThat(similarity, Matchers.instanceOf(NonNegativeScoresSimilarity.class));
similarity = ((NonNegativeScoresSimilarity) similarity).getDelegate();
assertThat(similarity, Matchers.instanceOf(TestSimilarity.class));
assertEquals("my_similarity", similarityService.getSimilarity("my_similarity").name());
assertEquals("there is a key", ((TestSimilarity) similarityService.getSimilarity("my_similarity").get()).key);
assertEquals("there is a key", ((TestSimilarity) similarity).key);
indexService.close("simon says", false);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.elasticsearch.index.similarity;

import org.apache.lucene.index.FieldInvertState;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.search.similarities.Similarity.SimScorer;
import org.elasticsearch.test.ESTestCase;
import org.hamcrest.Matchers;

public class NonNegativeScoresSimilarityTests extends ESTestCase {

    /**
     * Wraps a similarity that scores {@code freq - 5} and checks that a non-negative score is
     * passed through while a negative one is rejected with an {@link IllegalArgumentException}.
     */
    public void testBasics() {
        final Similarity freqMinusFive = new Similarity() {

            @Override
            public long computeNorm(FieldInvertState state) {
                return state.getLength();
            }

            @Override
            public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
                return new SimScorer() {
                    @Override
                    public float score(float freq, long norm) {
                        // negative for every freq below 5
                        return freq - 5;
                    }
                };
            }
        };

        final SimScorer wrappedScorer = new NonNegativeScoresSimilarity(freqMinusFive).scorer(1f, null);

        // 7 - 5 = 2: non-negative, so it is returned as is
        assertEquals(2f, wrappedScorer.score(7f, 1L), 0f);

        // 2 - 5 = -3: negative, so it must be rejected
        final IllegalArgumentException rejected =
                expectThrows(IllegalArgumentException.class, () -> wrappedScorer.score(2f, 1L));
        assertThat(rejected.getMessage(), Matchers.containsString("Similarities must not produce negative scores"));
    }

}
Loading