diff --git a/server/src/main/java/org/elasticsearch/index/similarity/NonNegativeScoresSimilarity.java b/server/src/main/java/org/elasticsearch/index/similarity/NonNegativeScoresSimilarity.java new file mode 100644 index 0000000000000..319ac0ff4b283 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/similarity/NonNegativeScoresSimilarity.java @@ -0,0 +1,96 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.index.similarity; + +import org.apache.lucene.index.FieldInvertState; +import org.apache.lucene.search.CollectionStatistics; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.TermStatistics; +import org.apache.lucene.search.similarities.Similarity; + +/** + * A {@link Similarity} that rejects negative scores. This class exists so that users get + * an error instead of silently corrupt top hits. It should be applied to any custom or + * scripted similarity. + */ +// public for testing +public final class NonNegativeScoresSimilarity extends Similarity { + + // Escape hatch + private static final String ES_ENFORCE_POSITIVE_SCORES = "es.enforce.positive.scores"; + private static final boolean ENFORCE_POSITIVE_SCORES; + static { + String enforcePositiveScores = System.getProperty(ES_ENFORCE_POSITIVE_SCORES); + if (enforcePositiveScores == null) { + ENFORCE_POSITIVE_SCORES = true; + } else if ("false".equals(enforcePositiveScores)) { + ENFORCE_POSITIVE_SCORES = false; + } else { + throw new IllegalArgumentException(ES_ENFORCE_POSITIVE_SCORES + " may only be unset or set to [false], but got [" + + enforcePositiveScores + "]"); + } + } + + private final Similarity in; + + public NonNegativeScoresSimilarity(Similarity in) { + this.in = in; + } + + public Similarity getDelegate() { + return in; + } + + @Override + public long computeNorm(FieldInvertState state) { + return in.computeNorm(state); + } + + @Override + public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) { + final SimScorer inScorer = in.scorer(boost, collectionStats, termStats); + return new SimScorer() { + + @Override + public float score(float freq, long norm) { + float score = inScorer.score(freq, norm); + if (score < 0f) { + if (ENFORCE_POSITIVE_SCORES) { + throw new IllegalArgumentException("Similarities must not produce negative scores, but got:\n" + + inScorer.explain(Explanation.match(freq, "term frequency"), norm)); + } else { + return 0f; + } + } + return score; + } + + @Override + public Explanation explain(Explanation freq, long norm) { + Explanation expl = inScorer.explain(freq, norm); + if (expl.isMatch() && expl.getValue().floatValue() < 0) { + expl = Explanation.match(0f, "max of:", + expl, Explanation.match(0f, "Minimum allowed score")); + } + return expl; + } + }; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/similarity/SimilarityService.java b/server/src/main/java/org/elasticsearch/index/similarity/SimilarityService.java index eaed2169f11c0..06a476e64ec7a 100644 --- a/server/src/main/java/org/elasticsearch/index/similarity/SimilarityService.java +++ b/server/src/main/java/org/elasticsearch/index/similarity/SimilarityService.java @@ -19,15 +19,22 @@ package org.elasticsearch.index.similarity; +import org.apache.logging.log4j.LogManager; +import org.apache.lucene.index.FieldInvertState; +import org.apache.lucene.index.IndexOptions; +import org.apache.lucene.search.CollectionStatistics; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.TermStatistics; import org.apache.lucene.search.similarities.BM25Similarity; import org.apache.lucene.search.similarities.BooleanSimilarity; import org.apache.lucene.search.similarities.ClassicSimilarity; import org.apache.lucene.search.similarities.PerFieldSimilarityWrapper; import org.apache.lucene.search.similarities.Similarity; +import org.apache.lucene.search.similarities.Similarity.SimScorer; +import org.apache.lucene.util.BytesRef; import org.elasticsearch.Version; import org.elasticsearch.common.TriFunction; import org.elasticsearch.common.logging.DeprecationLogger; -import org.elasticsearch.common.logging.Loggers; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.AbstractIndexComponent; import org.elasticsearch.index.IndexModule; @@ -44,7 +51,7 @@ public final class SimilarityService extends AbstractIndexComponent { - private static final DeprecationLogger DEPRECATION_LOGGER = new DeprecationLogger(Loggers.getLogger(SimilarityService.class)); + private static final DeprecationLogger DEPRECATION_LOGGER = new DeprecationLogger(LogManager.getLogger(SimilarityService.class)); public static final String DEFAULT_SIMILARITY = "BM25"; private static final String CLASSIC_SIMILARITY = "classic"; private static final Map>> DEFAULTS; @@ -131,8 +138,14 @@ public SimilarityService(IndexSettings indexSettings, ScriptService scriptServic } TriFunction defaultFactory = BUILT_IN.get(typeName); TriFunction factory = similarities.getOrDefault(typeName, defaultFactory); - final Similarity similarity = factory.apply(providerSettings, indexSettings.getIndexVersionCreated(), scriptService); - providers.put(name, () -> similarity); + Similarity similarity = factory.apply(providerSettings, indexSettings.getIndexVersionCreated(), scriptService); + validateSimilarity(indexSettings.getIndexVersionCreated(), similarity); + if (BUILT_IN.containsKey(typeName) == false || "scripted".equals(typeName)) { + // We don't trust custom similarities + similarity = new NonNegativeScoresSimilarity(similarity); + } + final Similarity similarityF = similarity; // like similarity but final + providers.put(name, () -> similarityF); } for (Map.Entry>> entry : DEFAULTS.entrySet()) { providers.put(entry.getKey(), entry.getValue().apply(indexSettings.getIndexVersionCreated())); @@ -151,7 +164,7 @@ public Similarity similarity(MapperService mapperService) { defaultSimilarity; } - + public SimilarityProvider getSimilarity(String name) { Supplier sim = similarities.get(name); if (sim == null) { @@ -182,4 +195,80 @@ public Similarity get(String name) { return (fieldType != null && fieldType.similarity() != null) ? fieldType.similarity().get() : defaultSimilarity; } } + + static void validateSimilarity(Version indexCreatedVersion, Similarity similarity) { + validateScoresArePositive(indexCreatedVersion, similarity); + validateScoresDoNotDecreaseWithFreq(indexCreatedVersion, similarity); + validateScoresDoNotIncreaseWithNorm(indexCreatedVersion, similarity); + } + + private static void validateScoresArePositive(Version indexCreatedVersion, Similarity similarity) { + CollectionStatistics collectionStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000); + TermStatistics termStats = new TermStatistics(new BytesRef("some_value"), 100, 130); + SimScorer scorer = similarity.scorer(2f, collectionStats, termStats); + FieldInvertState state = new FieldInvertState(indexCreatedVersion.major, "some_field", + IndexOptions.DOCS_AND_FREQS, 20, 20, 0, 50, 10, 3); // length = 20, no overlap + final long norm = similarity.computeNorm(state); + for (int freq = 1; freq <= 10; ++freq) { + float score = scorer.score(freq, norm); + if (score < 0) { + fail(indexCreatedVersion, "Similarities should not return negative scores:\n" + + scorer.explain(Explanation.match(freq, "term freq"), norm)); + } + } + } + + private static void validateScoresDoNotDecreaseWithFreq(Version indexCreatedVersion, Similarity similarity) { + CollectionStatistics collectionStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000); + TermStatistics termStats = new TermStatistics(new BytesRef("some_value"), 100, 130); + SimScorer scorer = similarity.scorer(2f, collectionStats, termStats); + FieldInvertState state = new FieldInvertState(indexCreatedVersion.major, "some_field", + IndexOptions.DOCS_AND_FREQS, 20, 20, 0, 50, 10, 3); // length = 20, no overlap + final long norm = similarity.computeNorm(state); + float previousScore = 0; + for (int freq = 1; freq <= 10; ++freq) { + float score = scorer.score(freq, norm); + if (score < previousScore) { + fail(indexCreatedVersion, "Similarity scores should not decrease when term frequency increases:\n" + + scorer.explain(Explanation.match(freq - 1, "term freq"), norm) + "\n" + + scorer.explain(Explanation.match(freq, "term freq"), norm)); + } + previousScore = score; + } + } + + private static void validateScoresDoNotIncreaseWithNorm(Version indexCreatedVersion, Similarity similarity) { + CollectionStatistics collectionStats = new CollectionStatistics("some_field", 1200, 1100, 3000, 2000); + TermStatistics termStats = new TermStatistics(new BytesRef("some_value"), 100, 130); + SimScorer scorer = similarity.scorer(2f, collectionStats, termStats); + + long previousNorm = 0; + float previousScore = Float.MAX_VALUE; + for (int length = 1; length <= 10; ++length) { + FieldInvertState state = new FieldInvertState(indexCreatedVersion.major, "some_field", + IndexOptions.DOCS_AND_FREQS, length, length, 0, 50, 10, 3); // length = 20, no overlap + final long norm = similarity.computeNorm(state); + if (Long.compareUnsigned(previousNorm, norm) > 0) { + // esoteric similarity, skip this check + break; + } + float score = scorer.score(1, norm); + if (score > previousScore) { + fail(indexCreatedVersion, "Similarity scores should not increase when norm increases:\n" + + scorer.explain(Explanation.match(1, "term freq"), norm - 1) + "\n" + + scorer.explain(Explanation.match(1, "term freq"), norm)); + } + previousScore = score; + previousNorm = norm; + } + } + + private static void fail(Version indexCreatedVersion, String message) { + if (indexCreatedVersion.onOrAfter(Version.V_7_0_0_alpha1)) { + throw new IllegalArgumentException(message); + } else if (indexCreatedVersion.onOrAfter(Version.V_6_5_0)) { + DEPRECATION_LOGGER.deprecated(message); + } + } + } diff --git a/server/src/test/java/org/elasticsearch/index/IndexModuleTests.java b/server/src/test/java/org/elasticsearch/index/IndexModuleTests.java index 078ec5ec20abc..a1166029146e6 100644 --- a/server/src/test/java/org/elasticsearch/index/IndexModuleTests.java +++ b/server/src/test/java/org/elasticsearch/index/IndexModuleTests.java @@ -59,6 +59,7 @@ import org.elasticsearch.index.shard.IndexingOperationListener; import org.elasticsearch.index.shard.SearchOperationListener; import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.index.similarity.NonNegativeScoresSimilarity; import org.elasticsearch.index.similarity.SimilarityService; import org.elasticsearch.index.store.IndexStore; import org.elasticsearch.indices.IndicesModule; @@ -77,6 +78,7 @@ import org.elasticsearch.test.engine.MockEngineFactory; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; +import org.hamcrest.Matchers; import java.io.IOException; import java.util.Collections; @@ -295,10 +297,13 @@ public void testAddSimilarity() throws IOException { IndexService indexService = newIndexService(module); SimilarityService similarityService = indexService.similarityService(); - assertNotNull(similarityService.getSimilarity("my_similarity")); - assertTrue(similarityService.getSimilarity("my_similarity").get() instanceof TestSimilarity); + Similarity similarity = similarityService.getSimilarity("my_similarity").get(); + assertNotNull(similarity); + assertThat(similarity, Matchers.instanceOf(NonNegativeScoresSimilarity.class)); + similarity = ((NonNegativeScoresSimilarity) similarity).getDelegate(); + assertThat(similarity, Matchers.instanceOf(TestSimilarity.class)); assertEquals("my_similarity", similarityService.getSimilarity("my_similarity").name()); - assertEquals("there is a key", ((TestSimilarity) similarityService.getSimilarity("my_similarity").get()).key); + assertEquals("there is a key", ((TestSimilarity) similarity).key); indexService.close("simon says", false); } diff --git a/server/src/test/java/org/elasticsearch/index/similarity/NonNegativeScoresSimilarityTests.java b/server/src/test/java/org/elasticsearch/index/similarity/NonNegativeScoresSimilarityTests.java new file mode 100644 index 0000000000000..33528c2190051 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/similarity/NonNegativeScoresSimilarityTests.java @@ -0,0 +1,57 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.index.similarity; + +import org.apache.lucene.index.FieldInvertState; +import org.apache.lucene.search.CollectionStatistics; +import org.apache.lucene.search.TermStatistics; +import org.apache.lucene.search.similarities.Similarity; +import org.apache.lucene.search.similarities.Similarity.SimScorer; +import org.elasticsearch.test.ESTestCase; +import org.hamcrest.Matchers; + +public class NonNegativeScoresSimilarityTests extends ESTestCase { + + public void testBasics() { + Similarity negativeScoresSim = new Similarity() { + + @Override + public long computeNorm(FieldInvertState state) { + return state.getLength(); + } + + @Override + public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) { + return new SimScorer() { + @Override + public float score(float freq, long norm) { + return freq - 5; + } + }; + } + }; + Similarity assertingSimilarity = new NonNegativeScoresSimilarity(negativeScoresSim); + SimScorer scorer = assertingSimilarity.scorer(1f, null); + assertEquals(2f, scorer.score(7f, 1L), 0f); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> scorer.score(2f, 1L)); + assertThat(e.getMessage(), Matchers.containsString("Similarities must not produce negative scores")); + } + +} diff --git a/server/src/test/java/org/elasticsearch/index/similarity/SimilarityServiceTests.java b/server/src/test/java/org/elasticsearch/index/similarity/SimilarityServiceTests.java index 5d18a595e9687..48d1e2b9c9b6c 100644 --- a/server/src/test/java/org/elasticsearch/index/similarity/SimilarityServiceTests.java +++ b/server/src/test/java/org/elasticsearch/index/similarity/SimilarityServiceTests.java @@ -18,12 +18,18 @@ */ package org.elasticsearch.index.similarity; +import org.apache.lucene.index.FieldInvertState; +import org.apache.lucene.search.CollectionStatistics; +import org.apache.lucene.search.TermStatistics; import org.apache.lucene.search.similarities.BM25Similarity; import org.apache.lucene.search.similarities.BooleanSimilarity; +import org.apache.lucene.search.similarities.Similarity; +import org.elasticsearch.Version; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.IndexSettings; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.IndexSettingsModule; +import org.hamcrest.Matchers; import java.util.Collections; @@ -56,4 +62,76 @@ public void testOverrideDefaultSimilarity() { SimilarityService service = new SimilarityService(indexSettings, null, Collections.emptyMap()); assertTrue(service.getDefaultSimilarity() instanceof BooleanSimilarity); } + + public void testSimilarityValidation() { + Similarity negativeScoresSim = new Similarity() { + + @Override + public long computeNorm(FieldInvertState state) { + return state.getLength(); + } + + @Override + public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) { + return new SimScorer() { + + @Override + public float score(float freq, long norm) { + return -1; + } + + }; + } + }; + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, + () -> SimilarityService.validateSimilarity(Version.V_7_0_0_alpha1, negativeScoresSim)); + assertThat(e.getMessage(), Matchers.containsString("Similarities should not return negative scores")); + + Similarity decreasingScoresWithFreqSim = new Similarity() { + + @Override + public long computeNorm(FieldInvertState state) { + return state.getLength(); + } + + @Override + public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) { + return new SimScorer() { + + @Override + public float score(float freq, long norm) { + return 1 / (freq + norm); + } + + }; + } + }; + e = expectThrows(IllegalArgumentException.class, + () -> SimilarityService.validateSimilarity(Version.V_7_0_0_alpha1, decreasingScoresWithFreqSim)); + assertThat(e.getMessage(), Matchers.containsString("Similarity scores should not decrease when term frequency increases")); + + Similarity increasingScoresWithNormSim = new Similarity() { + + @Override + public long computeNorm(FieldInvertState state) { + return state.getLength(); + } + + @Override + public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) { + return new SimScorer() { + + @Override + public float score(float freq, long norm) { + return freq + norm; + } + + }; + } + }; + e = expectThrows(IllegalArgumentException.class, + () -> SimilarityService.validateSimilarity(Version.V_7_0_0_alpha1, increasingScoresWithNormSim)); + assertThat(e.getMessage(), Matchers.containsString("Similarity scores should not increase when norm increases")); + } + } diff --git a/server/src/test/java/org/elasticsearch/indices/IndicesServiceTests.java b/server/src/test/java/org/elasticsearch/indices/IndicesServiceTests.java index 35416c617fdd0..b4e98775d97ac 100644 --- a/server/src/test/java/org/elasticsearch/indices/IndicesServiceTests.java +++ b/server/src/test/java/org/elasticsearch/indices/IndicesServiceTests.java @@ -19,6 +19,7 @@ package org.elasticsearch.indices; import org.apache.lucene.search.similarities.BM25Similarity; +import org.apache.lucene.search.similarities.Similarity; import org.apache.lucene.store.AlreadyClosedException; import org.elasticsearch.Version; import org.elasticsearch.action.admin.indices.stats.CommonStatsFlags; @@ -56,6 +57,7 @@ import org.elasticsearch.index.shard.IndexShardState; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.index.shard.ShardPath; +import org.elasticsearch.index.similarity.NonNegativeScoresSimilarity; import org.elasticsearch.indices.IndicesService.ShardDeletionCheckResult; import org.elasticsearch.plugins.EnginePlugin; import org.elasticsearch.plugins.MapperPlugin; @@ -448,8 +450,10 @@ public void testStandAloneMapperServiceWithPlugins() throws IOException { .build(); MapperService mapperService = indicesService.createIndexMapperService(indexMetaData); assertNotNull(mapperService.documentMapperParser().parserContext("type").typeParser("fake-mapper")); - assertThat(mapperService.documentMapperParser().parserContext("type").getSimilarity("test").get(), - instanceOf(BM25Similarity.class)); + Similarity sim = mapperService.documentMapperParser().parserContext("type").getSimilarity("test").get(); + assertThat(sim, instanceOf(NonNegativeScoresSimilarity.class)); + sim = ((NonNegativeScoresSimilarity) sim).getDelegate(); + assertThat(sim, instanceOf(BM25Similarity.class)); } public void testStatsByShardDoesNotDieFromExpectedExceptions() {