From 5971c9357137a90479101cd104dd1d47fe63806a Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Tue, 5 Sep 2023 16:47:40 +0800 Subject: [PATCH 01/70] sparse mapper field and query builder Signed-off-by: zhichao-aws --- .../index/mapper/SparseVectorMapper.java | 178 +++++++++++++++ .../neuralsearch/plugin/NeuralSearch.java | 23 +- .../sparse/BoundedLinearFeatureQuery.java | 210 ++++++++++++++++++ .../query/sparse/SparseQueryBuilder.java | 192 ++++++++++++++++ 4 files changed, 601 insertions(+), 2 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/index/mapper/SparseVectorMapper.java create mode 100644 src/main/java/org/opensearch/neuralsearch/query/sparse/BoundedLinearFeatureQuery.java create mode 100644 src/main/java/org/opensearch/neuralsearch/query/sparse/SparseQueryBuilder.java diff --git a/src/main/java/org/opensearch/neuralsearch/index/mapper/SparseVectorMapper.java b/src/main/java/org/opensearch/neuralsearch/index/mapper/SparseVectorMapper.java new file mode 100644 index 000000000..f117d31dc --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/index/mapper/SparseVectorMapper.java @@ -0,0 +1,178 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.index.mapper; + +import org.apache.lucene.document.FeatureField; +import org.apache.lucene.search.Query; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.mapper.FieldMapper; +import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.index.mapper.ParametrizedFieldMapper; +import org.opensearch.index.mapper.ParseContext; +import org.opensearch.index.mapper.SourceValueFetcher; +import org.opensearch.index.mapper.TextSearchInfo; +import org.opensearch.index.mapper.ValueFetcher; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.query.QueryShardException; +import org.opensearch.search.lookup.SearchLookup; + +import java.io.IOException; 
+import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * A FieldMapper that exposes Lucene's {@link FeatureField}. + * It is designed for learned sparse vectors, the expected for ingested content is a map of (token,weight) pairs, with String and Float type. + * In current version, this field doesn't support existing query clauses like "match" or "exists". + * The ingested documents can only be searched with our "sparse" query clause. + */ +public class SparseVectorMapper extends ParametrizedFieldMapper { + public static final String CONTENT_TYPE = "sparse_vector"; + + private static SparseVectorMapper toType(FieldMapper in) { + return (SparseVectorMapper) in; + } + + public static class SparseVectorBuilder extends ParametrizedFieldMapper.Builder { + + private final Parameter> meta = Parameter.metaParam(); + // Both match query and our sparse query use lucene Boolean query to connect all term-level queries. + // lucene BooleanQuery use WAND (Weak AND) algorithm to accelerate the search, and WAND algorithm + // uses term's max possible value to skip unnecessary calculations. The max possible value in match clause is term idf value. + // However, The default behavior of lucene FeatureQuery is returning Float.MAX_VALUE for every term. Which will + // invalidate WAND algorithm. + + // By setting maxTermScoreForSparseQuery, we'll use it as the term score upperbound to accelerate the search. + // Users can also overwrite this setting in sparse query. Our experiments show a proper maxTermScoreForSparseQuery + // value can reduce search latency by 4x while losing precision less than 0.5%. + + // If user doesn't set the value explicitly, we'll degrade to the default behavior in lucene FeatureQuery, + // i.e. using Float.MAX_VALUE. 
+ private final Parameter maxTermScoreForSparseQuery = Parameter.floatParam( + "max_term_score_for_sparse_query", + false, + m -> toType(m).maxTermScoreForSparseQuery, + Float.MAX_VALUE + ); + + public SparseVectorBuilder( + String name + ) { + super(name); + } + + @Override + protected List> getParameters() { + return Arrays.asList(maxTermScoreForSparseQuery, meta); + } + + @Override + public SparseVectorMapper build(BuilderContext context) { + return new SparseVectorMapper( + name, + new SparseVectorFieldType(buildFullName(context), meta.getValue(), maxTermScoreForSparseQuery.getValue()), + multiFieldsBuilder.build(this, context), + copyTo.build(), + maxTermScoreForSparseQuery.getValue() + ); + } + } + + public static final TypeParser PARSER = new TypeParser((n, c) -> new SparseVectorBuilder(n)); + + public static final class SparseVectorFieldType extends MappedFieldType { + private final float maxTermScoreForSparseQuery; + + public SparseVectorFieldType( + String name, + Map meta, + float maxTermScoreForSparseQuery + ) { + super(name, true, false, true, TextSearchInfo.NONE, meta); + this.maxTermScoreForSparseQuery = maxTermScoreForSparseQuery; + } + + public float maxTermScoreForSparseQuery() { + return maxTermScoreForSparseQuery; + } + + @Override + public String typeName() { + return CONTENT_TYPE; + } + + @Override + public ValueFetcher valueFetcher(QueryShardContext context, SearchLookup searchLookup, String format) { + if (format != null) { + throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] does not support format"); + } + return SourceValueFetcher.identity(name(), context, format); + } + + @Override + public Query existsQuery(QueryShardContext context) { + throw new QueryShardException( + context, + "Field [" + name() + "] of type [" + typeName() + "] does not support exists query for now" + ); + } + + @Override + public Query termQuery(Object value, QueryShardContext context) { + throw new QueryShardException( + 
context, + "Field [" + name() + "] of type [" + typeName() + "] does not support term query for now" + ); + } + } + + private final float maxTermScoreForSparseQuery; + + protected SparseVectorMapper( + String simpleName, + MappedFieldType mappedFieldType, + MultiFields multiFields, + CopyTo copyTo, + float maxTermScoreForSparseQuery + ) { + super(simpleName, mappedFieldType, multiFields, copyTo); + this.maxTermScoreForSparseQuery = maxTermScoreForSparseQuery; + } + + @Override + public ParametrizedFieldMapper.Builder getMergeBuilder() { + return new SparseVectorBuilder(simpleName()).init(this); + } + + @Override + protected void parseCreateField(ParseContext context) throws IOException { + if (XContentParser.Token.START_OBJECT != context.parser().currentToken()) { + throw new IllegalArgumentException( + "Wrong format for input data. Field type " + typeName() + " can only parse map object." + + ); + } + final Map termWeight = context.parser().map(HashMap::new, XContentParser::floatValue); + for (Map.Entry entry: termWeight.entrySet()) { + context.doc().add(new FeatureField(fieldType().name(), entry.getKey(), entry.getValue())); + } + } + + // Users are not supposed to give an array for the input value. 
+ // Here we set the return value of parsesArrayValue() as true, + // intercept the request and throw an exception in parseCreateField() + @Override + public final boolean parsesArrayValue() { + return true; + } + + @Override + protected String contentType() { + return CONTENT_TYPE; + } +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index e94a2957d..5fc98cd8f 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -26,8 +26,10 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.env.Environment; import org.opensearch.env.NodeEnvironment; +import org.opensearch.index.mapper.Mapper; import org.opensearch.ingest.Processor; import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.neuralsearch.index.mapper.SparseVectorMapper; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; @@ -40,10 +42,12 @@ import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; +import org.opensearch.neuralsearch.query.sparse.SparseQueryBuilder; import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher; import org.opensearch.plugins.ActionPlugin; import org.opensearch.plugins.ExtensiblePlugin; import org.opensearch.plugins.IngestPlugin; +import org.opensearch.plugins.MapperPlugin; import org.opensearch.plugins.Plugin; import org.opensearch.plugins.SearchPipelinePlugin; import org.opensearch.plugins.SearchPlugin; @@ -58,7 +62,13 @@ * Neural Search plugin class */ @Log4j2 -public class NeuralSearch extends 
Plugin implements ActionPlugin, SearchPlugin, IngestPlugin, ExtensiblePlugin, SearchPipelinePlugin { +public class NeuralSearch extends Plugin implements + ActionPlugin, + SearchPlugin, + IngestPlugin, + ExtensiblePlugin, + SearchPipelinePlugin, + MapperPlugin { private MLCommonsClientAccessor clientAccessor; private NormalizationProcessorWorkflow normalizationProcessorWorkflow; private final ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory(); @@ -87,7 +97,16 @@ public Collection createComponents( public List> getQueries() { return Arrays.asList( new QuerySpec<>(NeuralQueryBuilder.NAME, NeuralQueryBuilder::new, NeuralQueryBuilder::fromXContent), - new QuerySpec<>(HybridQueryBuilder.NAME, HybridQueryBuilder::new, HybridQueryBuilder::fromXContent) + new QuerySpec<>(HybridQueryBuilder.NAME, HybridQueryBuilder::new, HybridQueryBuilder::fromXContent), + new QuerySpec<>(SparseQueryBuilder.NAME, SparseQueryBuilder::new, SparseQueryBuilder::fromXContent) + ); + } + + @Override + public Map getMappers() { + return Collections.singletonMap( + SparseVectorMapper.CONTENT_TYPE, + SparseVectorMapper.PARSER ); } diff --git a/src/main/java/org/opensearch/neuralsearch/query/sparse/BoundedLinearFeatureQuery.java b/src/main/java/org/opensearch/neuralsearch/query/sparse/BoundedLinearFeatureQuery.java new file mode 100644 index 000000000..7b593c7b7 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/query/sparse/BoundedLinearFeatureQuery.java @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* This class is built based on lucene FeatureQuery. We use LinearFuntion and add an upperbound to it */ + +package org.opensearch.neuralsearch.query.sparse; + +import java.io.IOException; +import java.util.Objects; + +import org.apache.lucene.index.ImpactsEnum; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.PostingsEnum; +import org.apache.lucene.index.Term; +import org.apache.lucene.index.Terms; +import org.apache.lucene.index.TermsEnum; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.ImpactsDISI; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.Weight; +import org.apache.lucene.search.similarities.Similarity.SimScorer; +import org.apache.lucene.util.BytesRef; + +public final class BoundedLinearFeatureQuery extends Query { + + private final String fieldName; + private final String featureName; + private final Float scoreUpperBound; + + public BoundedLinearFeatureQuery(String fieldName, String featureName, Float scoreUpperBound) { + this.fieldName = Objects.requireNonNull(fieldName); + this.featureName = Objects.requireNonNull(featureName); + this.scoreUpperBound = Objects.requireNonNull(scoreUpperBound); + } + + @Override + public Query rewrite(IndexSearcher indexSearcher) throws 
IOException { + return super.rewrite(indexSearcher); + } + + @Override + public boolean equals(Object obj) { + if (obj == null || getClass() != obj.getClass()) { + return false; + } + BoundedLinearFeatureQuery that = (BoundedLinearFeatureQuery) obj; + return Objects.equals(fieldName, that.fieldName) + && Objects.equals(featureName, that.featureName) + && Objects.equals(scoreUpperBound, that.scoreUpperBound); + } + + @Override + public int hashCode() { + int h = getClass().hashCode(); + h = 31 * h + fieldName.hashCode(); + h = 31 * h + featureName.hashCode(); + h = 31 * h + scoreUpperBound.hashCode(); + return h; + } + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) + throws IOException { + if (!scoreMode.needsScores()) { + // We don't need scores (e.g. for faceting), and since features are stored as terms, + // allow TermQuery to optimize in this case + TermQuery tq = new TermQuery(new Term(fieldName, featureName)); + return searcher.rewrite(tq).createWeight(searcher, scoreMode, boost); + } + + return new Weight(this) { + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) throws IOException { + String desc = "weight(" + getQuery() + " in " + doc + ") [\" BoundedLinearFeatureQuery \"]"; + + Terms terms = context.reader().terms(fieldName); + if (terms == null) { + return Explanation.noMatch(desc + ". Field " + fieldName + " doesn't exist."); + } + TermsEnum termsEnum = terms.iterator(); + if (termsEnum.seekExact(new BytesRef(featureName)) == false) { + return Explanation.noMatch(desc + ". Feature " + featureName + " doesn't exist."); + } + + PostingsEnum postings = termsEnum.postings(null, PostingsEnum.FREQS); + if (postings.advance(doc) != doc) { + return Explanation.noMatch(desc + ". 
Feature " + featureName + " isn't set."); + } + + int freq = postings.freq(); + float featureValue = decodeFeatureValue(freq); + float score = boost * featureValue; + return Explanation.match( + score, + "Linear function on the " + + fieldName + + " field for the " + + featureName + + " feature, computed as w * S from:", + Explanation.match(boost, "w, weight of this function"), + Explanation.match(featureValue, "S, feature value")); + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + Terms terms = Terms.getTerms(context.reader(), fieldName); + TermsEnum termsEnum = terms.iterator(); + if (termsEnum.seekExact(new BytesRef(featureName)) == false) { + return null; + } + + final SimScorer scorer = new SimScorer() { + @Override + public float score(float freq, long norm) { + return boost * decodeFeatureValue(freq); + } + }; + final ImpactsEnum impacts = termsEnum.impacts(PostingsEnum.FREQS); + final ImpactsDISI impactsDisi = new ImpactsDISI(impacts, impacts, scorer); + + return new Scorer(this) { + + @Override + public int docID() { + return impacts.docID(); + } + + @Override + public float score() throws IOException { + return scorer.score(impacts.freq(), 1L); + } + + @Override + public DocIdSetIterator iterator() { + return impactsDisi; + } + + @Override + public int advanceShallow(int target) throws IOException { + return impactsDisi.advanceShallow(target); + } + + @Override + public float getMaxScore(int upTo) throws IOException { + return impactsDisi.getMaxScore(upTo); + } + + @Override + public void setMinCompetitiveScore(float minScore) { + impactsDisi.setMinCompetitiveScore(minScore); + } + }; + } + }; + } + + @Override + public void visit(QueryVisitor visitor) { + if (visitor.acceptField(fieldName)) { + visitor.visitLeaf(this); + } + } + + @Override + public String toString(String field) { + return "BoundedLinearFeatureQuery(field=" + + fieldName + + ", feature=" + + featureName + + ", scoreUpperBound=" + + scoreUpperBound 
+ + ")"; + } + + static final int MAX_FREQ = Float.floatToIntBits(Float.MAX_VALUE) >>> 15; + private float decodeFeatureValue(float freq) { + if (freq > MAX_FREQ) { + return scoreUpperBound; + } + int tf = (int) freq; // lossless + int featureBits = tf << 15; + return Math.min(Float.intBitsToFloat(featureBits), scoreUpperBound); + } +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseQueryBuilder.java new file mode 100644 index 000000000..6a97d69d6 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseQueryBuilder.java @@ -0,0 +1,192 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.query.sparse; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; +import lombok.experimental.Accessors; +import lombok.extern.log4j.Log4j2; + +import org.apache.commons.lang.builder.EqualsBuilder; +import org.apache.commons.lang.builder.HashCodeBuilder; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.BoostQuery; +import org.apache.lucene.search.Query; +import org.opensearch.core.ParseField; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.index.query.AbstractQueryBuilder; +import org.opensearch.index.query.QueryShardContext; + +import com.google.common.annotations.VisibleForTesting; +import 
org.opensearch.neuralsearch.index.mapper.SparseVectorMapper; + +@Log4j2 +@Getter +@Setter +@Accessors(chain = true, fluent = true) +@NoArgsConstructor +@AllArgsConstructor +public class SparseQueryBuilder extends AbstractQueryBuilder { + public static final String NAME = "sparse"; + + @VisibleForTesting + static final ParseField QUERY_TOKENS_FIELD = new ParseField("query_tokens"); + + private String fieldName; + // todo: if termWeight is null + private Map termWeight; + + public SparseQueryBuilder(StreamInput in) throws IOException { + super(in); + this.fieldName = in.readString(); + this.termWeight = in.readMap(StreamInput::readString, StreamInput::readFloat); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeString(fieldName); + out.writeMap(termWeight, StreamOutput::writeString, StreamOutput::writeFloat); + } + + @Override + protected void doXContent(XContentBuilder xContentBuilder, Params params) throws IOException { + xContentBuilder.startObject(NAME); + xContentBuilder.startObject(fieldName); + xContentBuilder.field(QUERY_TOKENS_FIELD.getPreferredName(), termWeight); + printBoostAndQueryName(xContentBuilder); + xContentBuilder.endObject(); + xContentBuilder.endObject(); + } + + /** + * The expected parsing form looks like: + * { + * "SAMPLE_FIELD": { + * "query_tokens": { + * "token_a": float, + * "token_b": float, + * ... 
+ * }, + * "max_term_score_for_sparse_query": float (optional) + * } + * } + */ + public static SparseQueryBuilder fromXContent(XContentParser parser) throws IOException { + SparseQueryBuilder sparseQueryBuilder = new SparseQueryBuilder(); + if (parser.currentToken() != XContentParser.Token.START_OBJECT) { + throw new ParsingException(parser.getTokenLocation(), "Token must be START_OBJECT"); + } + parser.nextToken(); + sparseQueryBuilder.fieldName(parser.currentName()); + parser.nextToken(); + parseQueryParams(parser, sparseQueryBuilder); + if (parser.nextToken() != XContentParser.Token.END_OBJECT) { + throw new ParsingException( + parser.getTokenLocation(), + "[" + + NAME + + "] query doesn't support multiple fields, found [" + + sparseQueryBuilder.fieldName() + + "] and [" + + parser.currentName() + + "]" + ); + } + + return sparseQueryBuilder; + } + + // todo: refactor this to switch style + private static void parseQueryParams(XContentParser parser, SparseQueryBuilder sparseQueryBuilder) throws IOException { + XContentParser.Token token; + String currentFieldName = ""; + while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + currentFieldName = parser.currentName(); + } else if (token.isValue()) { + if (NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + sparseQueryBuilder.queryName(parser.text()); + } else if (BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + sparseQueryBuilder.boost(parser.floatValue()); + } else { + throw new ParsingException( + parser.getTokenLocation(), + "[" + NAME + "] query does not support [" + currentFieldName + "]" + ); + } + } else if (QUERY_TOKENS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + sparseQueryBuilder.termWeight(parser.map(HashMap::new, XContentParser::floatValue)); +// sparseQueryBuilder.termWeight(castToTermWeight(parser.map())); + } else { + throw new ParsingException( + 
parser.getTokenLocation(), + "[" + NAME + "] unknown token [" + token + "] after [" + currentFieldName + "]" + ); + } + } + } + + @Override + protected Query doToQuery(QueryShardContext context) throws IOException { + final MappedFieldType ft = context.fieldMapper(fieldName); + if (!(ft instanceof SparseVectorMapper.SparseVectorFieldType)) { + throw new IllegalArgumentException( + "[" + NAME + "] query only works on [" + SparseVectorMapper.CONTENT_TYPE + "] fields, " + + "not [" + ft.typeName() + "]" + ); + } + final Float maxTermScoreForSparseQuery = ((SparseVectorMapper.SparseVectorFieldType) ft).maxTermScoreForSparseQuery(); + + BooleanQuery.Builder builder = new BooleanQuery.Builder(); + for (Map.Entry entry: termWeight.entrySet()) { + builder.add( + new BoostQuery( + new BoundedLinearFeatureQuery( + fieldName, + entry.getKey(), + maxTermScoreForSparseQuery + ), + entry.getValue() + ), + BooleanClause.Occur.SHOULD + ); + } + return builder.build(); + } + + @Override + protected boolean doEquals(SparseQueryBuilder obj) { + // todo: validate + if (this == obj) return true; + if (obj == null || getClass() != obj.getClass()) return false; + EqualsBuilder equalsBuilder = new EqualsBuilder(); + equalsBuilder.append(fieldName, obj.fieldName); + equalsBuilder.append(termWeight, obj.termWeight); + return equalsBuilder.isEquals(); + } + + @Override + protected int doHashCode() { + return new HashCodeBuilder().append(fieldName).append(termWeight).toHashCode(); + } + + @Override + public String getWriteableName() { + return NAME; + } +} \ No newline at end of file From 0386a95ef6ea0ce15d50a1d9fab9f5f08e0eab51 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 6 Sep 2023 11:47:56 +0800 Subject: [PATCH 02/70] fix typo Signed-off-by: zhichao-aws --- .../java/org/opensearch/neuralsearch/plugin/NeuralSearch.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java 
b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index 5fc98cd8f..b33810ee6 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -72,7 +72,7 @@ public class NeuralSearch extends Plugin implements private MLCommonsClientAccessor clientAccessor; private NormalizationProcessorWorkflow normalizationProcessorWorkflow; private final ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory(); - private final ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory();; + private final ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory(); @Override public Collection createComponents( From 489b6e8edf5efb32e78c354fcc6a505fe861594b Mon Sep 17 00:00:00 2001 From: zane-neo Date: Tue, 22 Aug 2023 12:05:37 +0800 Subject: [PATCH 03/70] Add map result support in neural search for non text embedding models Signed-off-by: zane-neo --- .../ml/MLCommonsClientAccessor.java | 51 +++++++- .../ml/MLCommonsClientAccessorTests.java | 111 ++++++++++++++++++ 2 files changed, 159 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index 768584ec9..3f434d3ce 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -8,6 +8,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import lombok.NonNull; @@ -15,6 +16,7 @@ import lombok.extern.log4j.Log4j2; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; import 
org.opensearch.ml.common.dataset.MLInputDataset; @@ -100,10 +102,38 @@ public void inferenceSentences( @NonNull final List inputText, @NonNull final ActionListener>> listener ) { - inferenceSentencesWithRetry(targetResponseFilters, modelId, inputText, 0, listener); + retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, 0, listener); } - private void inferenceSentencesWithRetry( + public void inferenceSentencesWithMapResult( + @NonNull final String modelId, + @NonNull final List inputText, + @NonNull final ActionListener> listener) { + retryableInferenceSentencesWithMapResult(modelId, inputText, 0, listener); + } + + private void retryableInferenceSentencesWithMapResult( + final String modelId, + final List inputText, + final int retryTime, + final ActionListener> listener + ) { + MLInput mlInput = createMLInput(null, inputText); + mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { + final Map result = buildMapResultFromResponse(mlOutput); + log.debug("Inference Response for input sentence {} is : {} ", inputText, result); + listener.onResponse(result); + }, e -> { + if (RetryUtil.shouldRetry(e, retryTime)) { + final int retryTimeAdd = retryTime + 1; + retryableInferenceSentencesWithMapResult(modelId, inputText, retryTimeAdd, listener); + } else { + listener.onFailure(e); + } + })); + } + + private void retryableInferenceSentencesWithVectorResult( final List targetResponseFilters, final String modelId, final List inputText, @@ -118,7 +148,7 @@ private void inferenceSentencesWithRetry( }, e -> { if (RetryUtil.shouldRetry(e, retryTime)) { final int retryTimeAdd = retryTime + 1; - inferenceSentencesWithRetry(targetResponseFilters, modelId, inputText, retryTimeAdd, listener); + retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, retryTimeAdd, listener); } else { listener.onFailure(e); } @@ -144,4 +174,19 @@ private List> buildVectorFromResponse(MLOutput mlOutput) { return vector; } + 
private Map buildMapResultFromResponse(MLOutput mlOutput) { + final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput; + final List tensorOutputList = modelTensorOutput.getMlModelOutputs(); + if (CollectionUtils.isEmpty(tensorOutputList)) { + log.error("No tensor output found!"); + return null; + } + List tensorList = tensorOutputList.get(0).getMlModelTensors(); + if (CollectionUtils.isEmpty(tensorList)) { + log.error("No tensor found!"); + return null; + } + return tensorList.get(0).getDataAsMap(); + } + } diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index 3ef5431b3..e6faedb93 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -5,6 +5,7 @@ package org.opensearch.neuralsearch.ml; +import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -161,6 +162,98 @@ public void testInferenceSentences_whenNotConnectionException_thenNoRetry() { Mockito.verify(resultListener).onFailure(illegalStateException); } + public void test_inferenceSentencesWithMapResult_whenValidInput_thenSuccess() { + final Map map = ImmutableMap.of("key", "value"); + final ActionListener> resultListener = mock(ActionListener.class); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(createModelTensorOutput(map)); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + + Mockito.verify(client) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), 
Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(resultListener).onResponse(map); + Mockito.verifyNoMoreInteractions(resultListener); + } + + public void test_inferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenReturnNull() { + final ActionListener> resultListener = mock(ActionListener.class); + final ModelTensorOutput modelTensorOutput = new ModelTensorOutput(Collections.emptyList()); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(modelTensorOutput); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + + Mockito.verify(client) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(resultListener).onResponse(null); + Mockito.verifyNoMoreInteractions(resultListener); + } + + public void test_inferenceSentencesWithMapResult_whenModelTensorListEmpty_thenReturnNull() { + final ActionListener> resultListener = mock(ActionListener.class); + final List tensorsList = new ArrayList<>(); + final List mlModelTensorList = new ArrayList<>(); + tensorsList.add(new ModelTensors(mlModelTensorList)); + final ModelTensorOutput modelTensorOutput = new ModelTensorOutput(tensorsList); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(modelTensorOutput); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + + Mockito.verify(client) + 
.predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(resultListener).onResponse(null); + Mockito.verifyNoMoreInteractions(resultListener); + } + + public void test_inferenceSentencesWithMapResult_whenRetryableException_retry3Times() { + final NodeNotConnectedException nodeNodeConnectedException = new NodeNotConnectedException( + mock(DiscoveryNode.class), + "Node not connected" + ); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(nodeNodeConnectedException); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + final ActionListener> resultListener = mock(ActionListener.class); + accessor.inferenceSentencesWithMapResult( + TestCommonConstants.MODEL_ID, + TestCommonConstants.SENTENCES_LIST, + resultListener + ); + + Mockito.verify(client, times(4)) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(resultListener).onFailure(nodeNodeConnectedException); + } + + public void test_inferenceSentencesWithMapResult_whenNotRetryableException_thenFail() { + final IllegalStateException illegalStateException = new IllegalStateException("Illegal state"); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(illegalStateException); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + final ActionListener> resultListener = mock(ActionListener.class); + accessor.inferenceSentencesWithMapResult( + TestCommonConstants.MODEL_ID, + TestCommonConstants.SENTENCES_LIST, + resultListener + ); + + Mockito.verify(client, times(1)) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), 
Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(resultListener).onFailure(illegalStateException); + } + private ModelTensorOutput createModelTensorOutput(final Float[] output) { final List tensorsList = new ArrayList<>(); final List mlModelTensorList = new ArrayList<>(); @@ -178,4 +271,22 @@ private ModelTensorOutput createModelTensorOutput(final Float[] output) { tensorsList.add(modelTensors); return new ModelTensorOutput(tensorsList); } + + private ModelTensorOutput createModelTensorOutput(final Map map) { + final List tensorsList = new ArrayList<>(); + final List mlModelTensorList = new ArrayList<>(); + final ModelTensor tensor = new ModelTensor( + "response", + null, + null, + null, + null, + null, + map + ); + mlModelTensorList.add(tensor); + final ModelTensors modelTensors = new ModelTensors(mlModelTensorList); + tensorsList.add(modelTensors); + return new ModelTensorOutput(tensorsList); + } } From 18022b8c47d16d6fafe998dcc8146f1d0ee2656a Mon Sep 17 00:00:00 2001 From: zane-neo Date: Fri, 1 Sep 2023 16:20:15 +0800 Subject: [PATCH 04/70] Fix compilation failure issue Signed-off-by: zane-neo --- .../neuralsearch/ml/MLCommonsClientAccessorTests.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index e6faedb93..813742878 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -5,7 +5,6 @@ package org.opensearch.neuralsearch.ml; -import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -163,7 +162,7 @@ public void testInferenceSentences_whenNotConnectionException_thenNoRetry() { } public void test_inferenceSentencesWithMapResult_whenValidInput_thenSuccess() 
{ - final Map map = ImmutableMap.of("key", "value"); + final Map map = Map.of("key", "value"); final ActionListener> resultListener = mock(ActionListener.class); Mockito.doAnswer(invocation -> { final ActionListener actionListener = invocation.getArgument(2); From 886cdeba25eb77b6172d518ab1f901050b858089 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Fri, 1 Sep 2023 16:46:04 +0800 Subject: [PATCH 05/70] Add more UTs Signed-off-by: zane-neo --- .../ml/MLCommonsClientAccessor.java | 14 +++--- .../ml/MLCommonsClientAccessorTests.java | 46 +++++++++++++++++-- 2 files changed, 50 insertions(+), 10 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index 3f434d3ce..55ee89bbd 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -177,14 +177,16 @@ private List> buildVectorFromResponse(MLOutput mlOutput) { private Map buildMapResultFromResponse(MLOutput mlOutput) { final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput; final List tensorOutputList = modelTensorOutput.getMlModelOutputs(); - if (CollectionUtils.isEmpty(tensorOutputList)) { - log.error("No tensor output found!"); - return null; + if (CollectionUtils.isEmpty(tensorOutputList) || CollectionUtils.isEmpty(tensorOutputList.get(0).getMlModelTensors())) { + throw new IllegalStateException( + "Empty model result produced. Expected 1 tensor output and 1 model tensor, but got [0]" + ); } List tensorList = tensorOutputList.get(0).getMlModelTensors(); - if (CollectionUtils.isEmpty(tensorList)) { - log.error("No tensor found!"); - return null; + if (tensorList.size() != 1) { + throw new IllegalStateException( + "Unexpected number of map result produced. 
Expected 1 map result to be returned, but got [" + tensorList.size() + "]" + ); } return tensorList.get(0).getDataAsMap(); } diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index 813742878..d7c2cddcb 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -16,12 +16,14 @@ import java.util.Map; import org.junit.Before; +import org.mockito.ArgumentCaptor; import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.action.ActionListener; +import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.MLOutput; @@ -177,7 +179,7 @@ public void test_inferenceSentencesWithMapResult_whenValidInput_thenSuccess() { Mockito.verifyNoMoreInteractions(resultListener); } - public void test_inferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenReturnNull() { + public void test_inferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenException() { final ActionListener> resultListener = mock(ActionListener.class); final ModelTensorOutput modelTensorOutput = new ModelTensorOutput(Collections.emptyList()); Mockito.doAnswer(invocation -> { @@ -189,11 +191,13 @@ public void test_inferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenR Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); - Mockito.verify(resultListener).onResponse(null); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IllegalStateException.class); + 
Mockito.verify(resultListener).onFailure(argumentCaptor.capture()); + assertEquals("Empty model result produced. Expected 1 tensor output and 1 model tensor, but got [0]", argumentCaptor.getValue().getMessage()); Mockito.verifyNoMoreInteractions(resultListener); } - public void test_inferenceSentencesWithMapResult_whenModelTensorListEmpty_thenReturnNull() { + public void test_inferenceSentencesWithMapResult_whenModelTensorListEmpty_thenException() { final ActionListener> resultListener = mock(ActionListener.class); final List tensorsList = new ArrayList<>(); final List mlModelTensorList = new ArrayList<>(); @@ -208,7 +212,41 @@ public void test_inferenceSentencesWithMapResult_whenModelTensorListEmpty_thenRe Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); - Mockito.verify(resultListener).onResponse(null); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IllegalStateException.class); + Mockito.verify(resultListener).onFailure(argumentCaptor.capture()); + assertEquals("Empty model result produced. 
Expected 1 tensor output and 1 model tensor, but got [0]", argumentCaptor.getValue().getMessage()); + Mockito.verifyNoMoreInteractions(resultListener); + } + + public void test_inferenceSentencesWithMapResult_whenModelTensorListSizeBiggerThan1_thenException() { + final ActionListener> resultListener = mock(ActionListener.class); + final List tensorsList = new ArrayList<>(); + final List mlModelTensorList = new ArrayList<>(); + final ModelTensor tensor = new ModelTensor( + "response", + null, + null, + null, + null, + null, + Map.of("key", "value") + ); + mlModelTensorList.add(tensor); + mlModelTensorList.add(tensor); + tensorsList.add(new ModelTensors(mlModelTensorList)); + final ModelTensorOutput modelTensorOutput = new ModelTensorOutput(tensorsList); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(modelTensorOutput); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + + Mockito.verify(client) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IllegalStateException.class); + Mockito.verify(resultListener).onFailure(argumentCaptor.capture()); + assertEquals("Unexpected number of map result produced. 
Expected 1 map result to be returned, but got [2]", argumentCaptor.getValue().getMessage()); Mockito.verifyNoMoreInteractions(resultListener); } From d336f88618f62b2ed5bc9a758f35fc6bcf67391b Mon Sep 17 00:00:00 2001 From: xinyual Date: Wed, 6 Sep 2023 16:33:26 +0800 Subject: [PATCH 06/70] add sparse encoding processor Signed-off-by: xinyual --- build.gradle | 2 +- gradle.properties | 1 + .../neuralsearch/plugin/NeuralSearch.java | 14 +-- .../neuralsearch/processor/DLProcessor.java | 93 +++++++++++++++++++ .../processor/SparseEncodingProcessor.java | 42 +++++++++ .../SparseEncodingProcessorFactory.java | 41 ++++++++ 6 files changed, 185 insertions(+), 8 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/DLProcessor.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java diff --git a/build.gradle b/build.gradle index 1d8eca483..41b3a3f7b 100644 --- a/build.gradle +++ b/build.gradle @@ -144,7 +144,7 @@ dependencies { zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${opensearch_build}" compileOnly fileTree(dir: knnJarDirectory, include: '*.jar') api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}" - implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.10' + implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.13.0' // ml-common excluded reflection for runtime so we need to add it by ourselves. 
// https://github.com/opensearch-project/ml-commons/commit/464bfe34c66d7a729a00dd457f03587ea4e504d9 // TODO: Remove following three lines of dependencies if ml-common include them in their jar diff --git a/gradle.properties b/gradle.properties index 5e5cd9ced..f4b55d2a3 100644 --- a/gradle.properties +++ b/gradle.properties @@ -9,3 +9,4 @@ org.gradle.jvmargs=--add-exports jdk.compiler/com.sun.tools.javac.api=ALL-UNNAME --add-exports jdk.compiler/com.sun.tools.javac.parser=ALL-UNNAMED \ --add-exports jdk.compiler/com.sun.tools.javac.tree=ALL-UNNAMED \ --add-exports jdk.compiler/com.sun.tools.javac.util=ALL-UNNAMED +customDistributionUrl=https://artifacts.opensearch.org/snapshots/core/opensearch/3.0.0-SNAPSHOT/opensearch-min-3.0.0-SNAPSHOT-darwin-x64-latest.tar.gz \ No newline at end of file diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index b33810ee6..ff9b20ce5 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -7,12 +7,7 @@ import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.NEURAL_SEARCH_HYBRID_SEARCH_DISABLED; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.Optional; +import java.util.*; import java.util.function.Supplier; import lombok.extern.log4j.Log4j2; @@ -33,11 +28,13 @@ import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; +import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; import 
org.opensearch.neuralsearch.processor.combination.ScoreCombiner; import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory; +import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; import org.opensearch.neuralsearch.query.HybridQueryBuilder; @@ -113,7 +110,10 @@ public Map getMappers() { @Override public Map getProcessors(Processor.Parameters parameters) { clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client)); - return Collections.singletonMap(TextEmbeddingProcessor.TYPE, new TextEmbeddingProcessorFactory(clientAccessor, parameters.env)); + Map allProcessors = new HashMap<>(); + allProcessors.put(TextEmbeddingProcessor.TYPE, new TextEmbeddingProcessorFactory(clientAccessor, parameters.env)); + allProcessors.put(SparseEncodingProcessor.TYPE, new SparseEncodingProcessorFactory(clientAccessor, parameters.env)); + return allProcessors; } @Override diff --git a/src/main/java/org/opensearch/neuralsearch/processor/DLProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/DLProcessor.java new file mode 100644 index 000000000..e4bc2e3f6 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/DLProcessor.java @@ -0,0 +1,93 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import org.apache.commons.lang3.StringUtils; +import org.opensearch.core.action.ActionListener; +import org.opensearch.env.Environment; +import org.opensearch.ingest.AbstractProcessor; +import org.opensearch.ingest.IngestDocument; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; + +import java.util.List; +import java.util.Map; +import 
java.util.function.BiConsumer; + +public abstract class DLProcessor extends AbstractProcessor { + public static final String TYPE = "text_embedding"; + public static final String MODEL_ID_FIELD = "model_id"; + public static final String FIELD_MAP_FIELD = "field_map"; + + private static final String LIST_TYPE_NESTED_MAP_KEY = "knn"; + + @VisibleForTesting + private final String modelId; + + private final Map fieldMap; + + private final MLCommonsClientAccessor mlCommonsClientAccessor; + + private final Environment environment; + + public DLProcessor( + String tag, + String description, + String modelId, + Map fieldMap, + MLCommonsClientAccessor clientAccessor, + Environment environment + ) { + super(tag, description); + if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, can not process it"); + validateFieldMapConfiguration(fieldMap); + + this.modelId = modelId; + this.fieldMap = fieldMap; + this.mlCommonsClientAccessor = clientAccessor; + this.environment = environment; + } + public abstract void validateFieldMapConfiguration(Map fieldMap); + + @Override + public IngestDocument execute(IngestDocument ingestDocument) { + return ingestDocument; + } + + /** + * This method will be invoked by PipelineService to make async inference and then delegate the handler to + * process the inference response or failure. + * @param ingestDocument {@link IngestDocument} which is the document passed to processor. + * @param handler {@link BiConsumer} which is the handler which can be used after the inference task is done. + */ + @Override + public void execute(IngestDocument ingestDocument, BiConsumer handler) { + // When received a bulk indexing request, the pipeline will be executed in this method, (see + // https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/action/bulk/TransportBulkAction.java#L226). 
+ // Before the pipeline execution, the pipeline will be marked as resolved (means executed), + // and then this overriding method will be invoked when executing the text embedding processor. + // After the inference completes, the handler will invoke the doInternalExecute method again to run actual write operation. + try { + validateIngestFieldsValue(ingestDocument); + Map knnMap = buildMapWithKnnKeyAndOriginalValue(ingestDocument); + List inferenceList = createInferenceList(knnMap); + if (inferenceList.size() == 0) { + handler.accept(ingestDocument, null); + } else { + mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, ActionListener.wrap(vectors -> { + setVectorFieldsToDocument(ingestDocument, knnMap, vectors); + handler.accept(ingestDocument, null); + }, e -> { handler.accept(null, e); })); + } + } catch (Exception e) { + handler.accept(null, e); + } + } + + public abstract void validateIngestFieldsValue(IngestDocument ingestDocument); + + + +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java new file mode 100644 index 000000000..c97ca4961 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import org.opensearch.env.Environment; +import org.opensearch.ingest.AbstractProcessor; +import org.opensearch.ingest.IngestDocument; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; + +import java.util.Map; + +public class SparseEncodingProcessor extends AbstractProcessor { + + public static final String TYPE = "text_embedding"; + public static final String MODEL_ID_FIELD = "model_id"; + public static final String FIELD_MAP_FIELD = "field_map"; + + private static final String LIST_TYPE_NESTED_MAP_KEY = "knn"; + 
+ @VisibleForTesting + private final String modelId; + + private final Map fieldMap; + + private final MLCommonsClientAccessor mlCommonsClientAccessor; + + private final Environment environment; + + + @Override + public IngestDocument execute(IngestDocument ingestDocument) throws Exception { + return null; + } + + @Override + public String getType() { + return null; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java new file mode 100644 index 000000000..658b5f419 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.factory; + +import org.opensearch.env.Environment; +import org.opensearch.ingest.Processor; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.DLProcessor; +import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; + +import java.util.Map; + +import static org.opensearch.ingest.ConfigurationUtils.readMap; +import static org.opensearch.ingest.ConfigurationUtils.readStringProperty; +import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.*; + +public class SparseEncodingProcessorFactory implements Processor.Factory { + private final MLCommonsClientAccessor clientAccessor; + private final Environment environment; + + public SparseEncodingProcessorFactory(MLCommonsClientAccessor clientAccessor, Environment environment) { + this.clientAccessor = clientAccessor; + this.environment = environment; + } + + @Override + public DLProcessor create( + Map registry, + String processorTag, + String description, + Map config + ) throws Exception { + String modelId = readStringProperty(TYPE, processorTag, 
config, MODEL_ID_FIELD); + Map filedMap = readMap(TYPE, processorTag, config, FIELD_MAP_FIELD); + + return new SparseEncodingProcessor(processorTag, description, modelId, filedMap, clientAccessor, environment); + } +} From bb14947a948ce17e304dab39912b94290db64b7a Mon Sep 17 00:00:00 2001 From: xinyual Date: Thu, 7 Sep 2023 14:11:02 +0800 Subject: [PATCH 07/70] add sparse encoding processor Signed-off-by: xinyual --- build.gradle | 1 + .../neuralsearch/processor/DLProcessor.java | 70 +---- .../processor/SparseEncodingProcessor.java | 294 +++++++++++++++++- .../SparseEncodingProcessorFactory.java | 5 +- 4 files changed, 293 insertions(+), 77 deletions(-) diff --git a/build.gradle b/build.gradle index 41b3a3f7b..a32195f15 100644 --- a/build.gradle +++ b/build.gradle @@ -151,6 +151,7 @@ dependencies { runtimeOnly group: 'org.reflections', name: 'reflections', version: '0.9.12' runtimeOnly group: 'org.javassist', name: 'javassist', version: '3.29.2-GA' runtimeOnly group: 'org.opensearch', name: 'common-utils', version: "${opensearch_build}" + runtimeOnly group: 'com.google.guava', name: 'guava', version: '32.0.1-jre' } // In order to add the jar to the classpath, we need to unzip the diff --git a/src/main/java/org/opensearch/neuralsearch/processor/DLProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/DLProcessor.java index e4bc2e3f6..a042fb517 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/DLProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/DLProcessor.java @@ -17,77 +17,9 @@ import java.util.function.BiConsumer; public abstract class DLProcessor extends AbstractProcessor { - public static final String TYPE = "text_embedding"; - public static final String MODEL_ID_FIELD = "model_id"; - public static final String FIELD_MAP_FIELD = "field_map"; - private static final String LIST_TYPE_NESTED_MAP_KEY = "knn"; - @VisibleForTesting - private final String modelId; - - private final Map fieldMap; - - private final 
MLCommonsClientAccessor mlCommonsClientAccessor; - - private final Environment environment; - - public DLProcessor( - String tag, - String description, - String modelId, - Map fieldMap, - MLCommonsClientAccessor clientAccessor, - Environment environment - ) { + protected DLProcessor(String tag, String description) { super(tag, description); - if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, can not process it"); - validateFieldMapConfiguration(fieldMap); - - this.modelId = modelId; - this.fieldMap = fieldMap; - this.mlCommonsClientAccessor = clientAccessor; - this.environment = environment; - } - public abstract void validateFieldMapConfiguration(Map fieldMap); - - @Override - public IngestDocument execute(IngestDocument ingestDocument) { - return ingestDocument; } - - /** - * This method will be invoked by PipelineService to make async inference and then delegate the handler to - * process the inference response or failure. - * @param ingestDocument {@link IngestDocument} which is the document passed to processor. - * @param handler {@link BiConsumer} which is the handler which can be used after the inference task is done. - */ - @Override - public void execute(IngestDocument ingestDocument, BiConsumer handler) { - // When received a bulk indexing request, the pipeline will be executed in this method, (see - // https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/action/bulk/TransportBulkAction.java#L226). - // Before the pipeline execution, the pipeline will be marked as resolved (means executed), - // and then this overriding method will be invoked when executing the text embedding processor. - // After the inference completes, the handler will invoke the doInternalExecute method again to run actual write operation. 
- try { - validateIngestFieldsValue(ingestDocument); - Map knnMap = buildMapWithKnnKeyAndOriginalValue(ingestDocument); - List inferenceList = createInferenceList(knnMap); - if (inferenceList.size() == 0) { - handler.accept(ingestDocument, null); - } else { - mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, ActionListener.wrap(vectors -> { - setVectorFieldsToDocument(ingestDocument, knnMap, vectors); - handler.accept(ingestDocument, null); - }, e -> { handler.accept(null, e); })); - } - } catch (Exception e) { - handler.accept(null, e); - } - } - - public abstract void validateIngestFieldsValue(IngestDocument ingestDocument); - - - } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index c97ca4961..1a9a5c034 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -5,20 +5,32 @@ package org.opensearch.neuralsearch.processor; +import lombok.extern.log4j.Log4j2; +import org.apache.commons.lang3.StringUtils; +import org.opensearch.core.action.ActionListener; import org.opensearch.env.Environment; +import org.opensearch.index.mapper.MapperService; import org.opensearch.ingest.AbstractProcessor; import org.opensearch.ingest.IngestDocument; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; -import java.util.Map; +import java.util.*; +import java.util.function.BiConsumer; +import java.util.function.Supplier; +import java.util.stream.IntStream; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMap; + + +@Log4j2 public class SparseEncodingProcessor extends AbstractProcessor { - public static final String TYPE = "text_embedding"; + public static final String TYPE = "sparse_encoding"; public static final String MODEL_ID_FIELD = "model_id"; 
public static final String FIELD_MAP_FIELD = "field_map"; - private static final String LIST_TYPE_NESTED_MAP_KEY = "knn"; + private static final String LIST_TYPE_NESTED_MAP_KEY = "sparseEncoding"; @VisibleForTesting private final String modelId; @@ -29,14 +41,284 @@ public class SparseEncodingProcessor extends AbstractProcessor { private final Environment environment; + public SparseEncodingProcessor( + String tag, + String description, + String modelId, + Map fieldMap, + MLCommonsClientAccessor clientAccessor, + Environment environment + ) { + super(tag, description); + if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, can not process it"); + validateEmbeddingConfiguration(fieldMap); + + this.modelId = modelId; + this.fieldMap = fieldMap; + this.mlCommonsClientAccessor = clientAccessor; + this.environment = environment; + } + + private void validateEmbeddingConfiguration(Map fieldMap) { + if (fieldMap == null + || fieldMap.size() == 0 + || fieldMap.entrySet() + .stream() + .anyMatch( + x -> StringUtils.isBlank(x.getKey()) || Objects.isNull(x.getValue()) || StringUtils.isBlank(x.getValue().toString()) + )) { + throw new IllegalArgumentException("Unable to create the TextEmbedding processor as field_map has invalid key or value"); + } + } @Override - public IngestDocument execute(IngestDocument ingestDocument) throws Exception { - return null; + public IngestDocument execute(IngestDocument ingestDocument) { + return ingestDocument; + } + + /** + * This method will be invoked by PipelineService to make async inference and then delegate the handler to + * process the inference response or failure. + * @param ingestDocument {@link IngestDocument} which is the document passed to processor. + * @param handler {@link BiConsumer} which is the handler which can be used after the inference task is done. 
+ */ + @Override + public void execute(IngestDocument ingestDocument, BiConsumer handler) { + // When received a bulk indexing request, the pipeline will be executed in this method, (see + // https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/action/bulk/TransportBulkAction.java#L226). + // Before the pipeline execution, the pipeline will be marked as resolved (means executed), + // and then this overriding method will be invoked when executing the text embedding processor. + // After the inference completes, the handler will invoke the doInternalExecute method again to run actual write operation. + try { + validateEmbeddingFieldsValue(ingestDocument); + Map ProcessMap = buildMapWithProcessorKeyAndOriginalValue(ingestDocument); + List inferenceList = createInferenceList(ProcessMap); + if (inferenceList.size() == 0) { + handler.accept(ingestDocument, null); + } else { + mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMap -> { + Map.Entry>> entry = ((Map>>)resultMap).entrySet().iterator().next(); + List > resultTokenWeights = entry.getValue(); + setVectorFieldsToDocument(ingestDocument, ProcessMap, resultTokenWeights); + handler.accept(ingestDocument, null); + }, e -> { handler.accept(null, e); })); + } + } catch (Exception e) { + handler.accept(null, e); + } + + } + + void setVectorFieldsToDocument(IngestDocument ingestDocument, Map processorMap, List > resultTokenWeights) { + Objects.requireNonNull(resultTokenWeights, "embedding failed, inference returns null result!"); + log.debug("Text embedding result fetched, starting build vector output!"); + Map sparseEncodingResult = buildSparseEncodingResult(processorMap, resultTokenWeights, ingestDocument.getSourceAndMetadata()); + sparseEncodingResult.forEach(ingestDocument::setFieldValue); + } + + @SuppressWarnings({ "unchecked" }) + private List createInferenceList(Map knnKeyMap) { + List texts = new 
ArrayList<>(); + knnKeyMap.entrySet().stream().filter(knnMapEntry -> knnMapEntry.getValue() != null).forEach(knnMapEntry -> { + Object sourceValue = knnMapEntry.getValue(); + if (sourceValue instanceof List) { + texts.addAll(((List) sourceValue)); + } else if (sourceValue instanceof Map) { + createInferenceListForMapTypeInput(sourceValue, texts); + } else { + texts.add(sourceValue.toString()); + } + }); + return texts; + } + + @SuppressWarnings("unchecked") + private void createInferenceListForMapTypeInput(Object sourceValue, List texts) { + if (sourceValue instanceof Map) { + ((Map) sourceValue).forEach((k, v) -> createInferenceListForMapTypeInput(v, texts)); + } else if (sourceValue instanceof List) { + texts.addAll(((List) sourceValue)); + } else { + if (sourceValue == null) return; + texts.add(sourceValue.toString()); + } + } + + @VisibleForTesting + Map buildMapWithProcessorKeyAndOriginalValue(IngestDocument ingestDocument) { + Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); + Map mapWithProcessorKeys = new LinkedHashMap<>(); + for (Map.Entry fieldMapEntry : fieldMap.entrySet()) { + String originalKey = fieldMapEntry.getKey(); + Object targetKey = fieldMapEntry.getValue(); + if (targetKey instanceof Map) { + Map treeRes = new LinkedHashMap<>(); + buildMapWithProcessorKeyAndOriginalValueForMapType(originalKey, targetKey, sourceAndMetadataMap, treeRes); + mapWithProcessorKeys.put(originalKey, treeRes.get(originalKey)); + } else { + mapWithProcessorKeys.put(String.valueOf(targetKey), sourceAndMetadataMap.get(originalKey)); + } + } + return mapWithProcessorKeys; + } + + @SuppressWarnings({ "unchecked" }) + private void buildMapWithProcessorKeyAndOriginalValueForMapType( + String parentKey, + Object processorKey, + Map sourceAndMetadataMap, + Map treeRes + ) { + if (processorKey == null || sourceAndMetadataMap == null) return; + if (processorKey instanceof Map) { + Map next = new LinkedHashMap<>(); + for (Map.Entry nestedFieldMapEntry : ((Map) 
processorKey).entrySet()) { + buildMapWithProcessorKeyAndOriginalValueForMapType( + nestedFieldMapEntry.getKey(), + nestedFieldMapEntry.getValue(), + (Map) sourceAndMetadataMap.get(parentKey), + next + ); + } + treeRes.put(parentKey, next); + } else { + String key = String.valueOf(processorKey); + treeRes.put(key, sourceAndMetadataMap.get(parentKey)); + } + } + + @SuppressWarnings({ "unchecked" }) + @VisibleForTesting + Map buildSparseEncodingResult( + Map processorMap, + List > resultTokenWeights, + Map sourceAndMetadataMap + ) { + SparseEncodingProcessor.IndexWrapper indexWrapper = new SparseEncodingProcessor.IndexWrapper(0); + Map result = new LinkedHashMap<>(); + for (Map.Entry knnMapEntry : processorMap.entrySet()) { + String knnKey = knnMapEntry.getKey(); + Object sourceValue = knnMapEntry.getValue(); + if (sourceValue instanceof String) { + result.put(knnKey, resultTokenWeights.get(indexWrapper.index++)); + } else if (sourceValue instanceof List) { + result.put(knnKey, buildSparseEncodingResultForListType((List) sourceValue, resultTokenWeights, indexWrapper)); + } else if (sourceValue instanceof Map) { + putSparseEncodingResultToSourceMapForMapType(knnKey, sourceValue, resultTokenWeights, indexWrapper, sourceAndMetadataMap); + } + } + return result; + } + + @SuppressWarnings({ "unchecked" }) + private void putSparseEncodingResultToSourceMapForMapType( + String processorKey, + Object sourceValue, + List > resultTokenWeights, + SparseEncodingProcessor.IndexWrapper indexWrapper, + Map sourceAndMetadataMap + ) { + if (processorKey == null || sourceAndMetadataMap == null || sourceValue == null) return; + if (sourceValue instanceof Map) { + for (Map.Entry inputNestedMapEntry : ((Map) sourceValue).entrySet()) { + putSparseEncodingResultToSourceMapForMapType( + inputNestedMapEntry.getKey(), + inputNestedMapEntry.getValue(), + resultTokenWeights, + indexWrapper, + (Map) sourceAndMetadataMap.get(processorKey) + ); + } + } else if (sourceValue instanceof String) { + 
sourceAndMetadataMap.put(processorKey, resultTokenWeights.get(indexWrapper.index++)); + } else if (sourceValue instanceof List) { + sourceAndMetadataMap.put( + processorKey, + buildSparseEncodingResultForListType((List) sourceValue, resultTokenWeights, indexWrapper) + ); + } + } + + private List>> buildSparseEncodingResultForListType( + List sourceValue, + List > resultTokenWeights, + SparseEncodingProcessor.IndexWrapper indexWrapper + ) { + List>> tokenWeights = new ArrayList<>(); + IntStream.range(0, sourceValue.size()) + .forEachOrdered(x -> tokenWeights.add(ImmutableMap.of(LIST_TYPE_NESTED_MAP_KEY, resultTokenWeights.get(indexWrapper.index++)))); + return tokenWeights; + } + + private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { + Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); + for (Map.Entry embeddingFieldsEntry : fieldMap.entrySet()) { + Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey()); + if (sourceValue != null) { + String sourceKey = embeddingFieldsEntry.getKey(); + Class sourceValueClass = sourceValue.getClass(); + if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) { + validateNestedTypeValue(sourceKey, sourceValue, () -> 1); + } else if (!String.class.isAssignableFrom(sourceValueClass)) { + throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, can not process it"); + } else if (StringUtils.isBlank(sourceValue.toString())) { + throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, can not process it"); + } + } + } + } + + @SuppressWarnings({ "rawtypes", "unchecked" }) + private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier maxDepthSupplier) { + int maxDepth = maxDepthSupplier.get(); + if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) { + throw new IllegalArgumentException("map type field [" + 
sourceKey + "] reached max depth limit, can not process it"); + } else if ((List.class.isAssignableFrom(sourceValue.getClass()))) { + validateListTypeValue(sourceKey, sourceValue); + } else if (Map.class.isAssignableFrom(sourceValue.getClass())) { + ((Map) sourceValue).values() + .stream() + .filter(Objects::nonNull) + .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1)); + } else if (!String.class.isAssignableFrom(sourceValue.getClass())) { + throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, can not process it"); + } else if (StringUtils.isBlank(sourceValue.toString())) { + throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, can not process it"); + } + } + + @SuppressWarnings({ "rawtypes" }) + private static void validateListTypeValue(String sourceKey, Object sourceValue) { + for (Object value : (List) sourceValue) { + if (value == null) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, can not process it"); + } else if (!(value instanceof String)) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, can not process it"); + } else if (StringUtils.isBlank(value.toString())) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, can not process it"); + } + } } @Override public String getType() { - return null; + return TYPE; } + + /** + * Since we need to build a {@link List} as the input for text embedding, and the result type is {@link List} of {@link List}, + * we need to map the result back to the input one by one with exactly order. For nested map type input, we're performing a pre-order + * traversal to extract the input strings, so when mapping back to the nested map, we still need a pre-order traversal to ensure the + * order. 
And we also need to ensure the index pointer goes forward in the recursive, so here the IndexWrapper is to store and increase + * the index pointer during the recursive. + * index: the index pointer of the text embedding result. + */ + static class IndexWrapper { + private int index; + + protected IndexWrapper(int index) { + this.index = index; + } + } + } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java index 658b5f419..27ec79d57 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java @@ -5,6 +5,7 @@ package org.opensearch.neuralsearch.processor.factory; +import lombok.extern.log4j.Log4j2; import org.opensearch.env.Environment; import org.opensearch.ingest.Processor; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; @@ -16,7 +17,7 @@ import static org.opensearch.ingest.ConfigurationUtils.readMap; import static org.opensearch.ingest.ConfigurationUtils.readStringProperty; import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.*; - +@Log4j2 public class SparseEncodingProcessorFactory implements Processor.Factory { private final MLCommonsClientAccessor clientAccessor; private final Environment environment; @@ -27,7 +28,7 @@ public SparseEncodingProcessorFactory(MLCommonsClientAccessor clientAccessor, En } @Override - public DLProcessor create( + public SparseEncodingProcessor create( Map registry, String processorTag, String description, From 5bb409b92afe72cae5ad2d77eb8b5b0c33be94d6 Mon Sep 17 00:00:00 2001 From: xinyual Date: Thu, 7 Sep 2023 14:12:14 +0800 Subject: [PATCH 08/70] remove guava in gradle Signed-off-by: xinyual --- build.gradle | 1 - 1 file changed, 1 deletion(-) diff --git a/build.gradle b/build.gradle index 
a32195f15..41b3a3f7b 100644 --- a/build.gradle +++ b/build.gradle @@ -151,7 +151,6 @@ dependencies { runtimeOnly group: 'org.reflections', name: 'reflections', version: '0.9.12' runtimeOnly group: 'org.javassist', name: 'javassist', version: '3.29.2-GA' runtimeOnly group: 'org.opensearch', name: 'common-utils', version: "${opensearch_build}" - runtimeOnly group: 'com.google.guava', name: 'guava', version: '32.0.1-jre' } // In order to add the jar to the classpath, we need to unzip the From 8e8df7e4ab25d66717675a2b6cbe52b85069bdab Mon Sep 17 00:00:00 2001 From: xinyual Date: Thu, 7 Sep 2023 16:35:16 +0800 Subject: [PATCH 09/70] modify access control Signed-off-by: xinyual --- .../ml/MLCommonsClientAccessor.java | 23 +++++++++++-------- .../processor/SparseEncodingProcessor.java | 16 ++++++------- .../ml/MLCommonsClientAccessorTests.java | 12 +++++----- 3 files changed, 26 insertions(+), 25 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index 55ee89bbd..feb7539c0 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -108,7 +108,7 @@ public void inferenceSentences( public void inferenceSentencesWithMapResult( @NonNull final String modelId, @NonNull final List inputText, - @NonNull final ActionListener> listener) { + @NonNull final ActionListener>> listener) { retryableInferenceSentencesWithMapResult(modelId, inputText, 0, listener); } @@ -116,11 +116,11 @@ private void retryableInferenceSentencesWithMapResult( final String modelId, final List inputText, final int retryTime, - final ActionListener> listener + final ActionListener>> listener ) { MLInput mlInput = createMLInput(null, inputText); mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { - final Map result = buildMapResultFromResponse(mlOutput); + final 
List> result = buildMapResultFromResponse(mlOutput); log.debug("Inference Response for input sentence {} is : {} ", inputText, result); listener.onResponse(result); }, e -> { @@ -174,7 +174,7 @@ private List> buildVectorFromResponse(MLOutput mlOutput) { return vector; } - private Map buildMapResultFromResponse(MLOutput mlOutput) { + private List > buildMapResultFromResponse(MLOutput mlOutput) { final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput; final List tensorOutputList = modelTensorOutput.getMlModelOutputs(); if (CollectionUtils.isEmpty(tensorOutputList) || CollectionUtils.isEmpty(tensorOutputList.get(0).getMlModelTensors())) { @@ -182,13 +182,16 @@ private List> buildVectorFromResponse(MLOutput mlOutput) { "Empty model result produced. Expected 1 tensor output and 1 model tensor, but got [0]" ); } - List tensorList = tensorOutputList.get(0).getMlModelTensors(); - if (tensorList.size() != 1) { - throw new IllegalStateException( - "Unexpected number of map result produced. 
Expected 1 map result to be returned, but got [" + tensorList.size() + "]" - ); + List > resultMaps = new ArrayList<>(); + for (ModelTensors tensors: tensorOutputList) + { + List tensorList = tensors.getMlModelTensors(); + for (ModelTensor tensor: tensorList) + { + resultMaps.add(tensor.getDataAsMap()); + } } - return tensorList.get(0).getDataAsMap(); + return resultMaps; } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index 1a9a5c034..6369a4aa0 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -96,9 +96,7 @@ public void execute(IngestDocument ingestDocument, BiConsumer { - Map.Entry>> entry = ((Map>>)resultMap).entrySet().iterator().next(); - List > resultTokenWeights = entry.getValue(); + mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultTokenWeights -> { setVectorFieldsToDocument(ingestDocument, ProcessMap, resultTokenWeights); handler.accept(ingestDocument, null); }, e -> { handler.accept(null, e); })); @@ -109,7 +107,7 @@ public void execute(IngestDocument ingestDocument, BiConsumer processorMap, List > resultTokenWeights) { + void setVectorFieldsToDocument(IngestDocument ingestDocument, Map processorMap, List > resultTokenWeights) { Objects.requireNonNull(resultTokenWeights, "embedding failed, inference returns null result!"); log.debug("Text embedding result fetched, starting build vector output!"); Map sparseEncodingResult = buildSparseEncodingResult(processorMap, resultTokenWeights, ingestDocument.getSourceAndMetadata()); @@ -191,7 +189,7 @@ private void buildMapWithProcessorKeyAndOriginalValueForMapType( @VisibleForTesting Map buildSparseEncodingResult( Map processorMap, - List > resultTokenWeights, + List > resultTokenWeights, Map 
sourceAndMetadataMap ) { SparseEncodingProcessor.IndexWrapper indexWrapper = new SparseEncodingProcessor.IndexWrapper(0); @@ -214,7 +212,7 @@ Map buildSparseEncodingResult( private void putSparseEncodingResultToSourceMapForMapType( String processorKey, Object sourceValue, - List > resultTokenWeights, + List > resultTokenWeights, SparseEncodingProcessor.IndexWrapper indexWrapper, Map sourceAndMetadataMap ) { @@ -239,12 +237,12 @@ private void putSparseEncodingResultToSourceMapForMapType( } } - private List>> buildSparseEncodingResultForListType( + private List>> buildSparseEncodingResultForListType( List sourceValue, - List > resultTokenWeights, + List > resultTokenWeights, SparseEncodingProcessor.IndexWrapper indexWrapper ) { - List>> tokenWeights = new ArrayList<>(); + List>> tokenWeights = new ArrayList<>(); IntStream.range(0, sourceValue.size()) .forEachOrdered(x -> tokenWeights.add(ImmutableMap.of(LIST_TYPE_NESTED_MAP_KEY, resultTokenWeights.get(indexWrapper.index++)))); return tokenWeights; diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index d7c2cddcb..a635e3fa2 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -165,7 +165,7 @@ public void testInferenceSentences_whenNotConnectionException_thenNoRetry() { public void test_inferenceSentencesWithMapResult_whenValidInput_thenSuccess() { final Map map = Map.of("key", "value"); - final ActionListener> resultListener = mock(ActionListener.class); + final ActionListener>> resultListener = mock(ActionListener.class); Mockito.doAnswer(invocation -> { final ActionListener actionListener = invocation.getArgument(2); actionListener.onResponse(createModelTensorOutput(map)); @@ -180,7 +180,7 @@ public void test_inferenceSentencesWithMapResult_whenValidInput_thenSuccess() { 
} public void test_inferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenException() { - final ActionListener> resultListener = mock(ActionListener.class); + final ActionListener>> resultListener = mock(ActionListener.class); final ModelTensorOutput modelTensorOutput = new ModelTensorOutput(Collections.emptyList()); Mockito.doAnswer(invocation -> { final ActionListener actionListener = invocation.getArgument(2); @@ -198,7 +198,7 @@ public void test_inferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenE } public void test_inferenceSentencesWithMapResult_whenModelTensorListEmpty_thenException() { - final ActionListener> resultListener = mock(ActionListener.class); + final ActionListener>> resultListener = mock(ActionListener.class); final List tensorsList = new ArrayList<>(); final List mlModelTensorList = new ArrayList<>(); tensorsList.add(new ModelTensors(mlModelTensorList)); @@ -219,7 +219,7 @@ public void test_inferenceSentencesWithMapResult_whenModelTensorListEmpty_thenEx } public void test_inferenceSentencesWithMapResult_whenModelTensorListSizeBiggerThan1_thenException() { - final ActionListener> resultListener = mock(ActionListener.class); + final ActionListener>> resultListener = mock(ActionListener.class); final List tensorsList = new ArrayList<>(); final List mlModelTensorList = new ArrayList<>(); final ModelTensor tensor = new ModelTensor( @@ -260,7 +260,7 @@ public void test_inferenceSentencesWithMapResult_whenRetryableException_retry3Ti actionListener.onFailure(nodeNodeConnectedException); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); - final ActionListener> resultListener = mock(ActionListener.class); + final ActionListener>> resultListener = mock(ActionListener.class); accessor.inferenceSentencesWithMapResult( TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, @@ -279,7 +279,7 @@ public void 
test_inferenceSentencesWithMapResult_whenNotRetryableException_thenF actionListener.onFailure(illegalStateException); return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); - final ActionListener> resultListener = mock(ActionListener.class); + final ActionListener>> resultListener = mock(ActionListener.class); accessor.inferenceSentencesWithMapResult( TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, From 5e81ee56b4cf819e3821e4d54bd1af4655e98b89 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Tue, 22 Aug 2023 12:05:37 +0800 Subject: [PATCH 10/70] Add map result support in neural search for non text embedding models Signed-off-by: zane-neo --- .../opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java | 1 + 1 file changed, 1 insertion(+) diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index a635e3fa2..c358295cf 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -5,6 +5,7 @@ package org.opensearch.neuralsearch.ml; +import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; From 20dd78e55087929af88586bcb66ea45c198bd7bf Mon Sep 17 00:00:00 2001 From: zane-neo Date: Fri, 1 Sep 2023 16:20:15 +0800 Subject: [PATCH 11/70] Fix compilation failure issue Signed-off-by: zane-neo --- .../opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java | 1 - 1 file changed, 1 deletion(-) diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index c358295cf..a635e3fa2 100644 --- 
a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -5,7 +5,6 @@ package org.opensearch.neuralsearch.ml; -import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; From 05c3be826e9a497eb5d4f57d4b34aecca72f5fa7 Mon Sep 17 00:00:00 2001 From: xinyual Date: Fri, 8 Sep 2023 17:01:41 +0800 Subject: [PATCH 12/70] change output logic Signed-off-by: xinyual --- .../neuralsearch/processor/SparseEncodingProcessor.java | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index 6369a4aa0..6e5ed3c05 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -96,7 +96,13 @@ public void execute(IngestDocument ingestDocument, BiConsumer { + mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> { + List > resultTokenWeights = new ArrayList<>(); + for (Map map: resultMaps) + { + resultTokenWeights.addAll((List>)map.get("response") ); + } + log.info(resultTokenWeights); setVectorFieldsToDocument(ingestDocument, ProcessMap, resultTokenWeights); handler.accept(ingestDocument, null); }, e -> { handler.accept(null, e); })); From 734fd50c0f96192ecdf71a22c5e929f11549d1dc Mon Sep 17 00:00:00 2001 From: xinyual Date: Mon, 11 Sep 2023 10:16:26 +0800 Subject: [PATCH 13/70] create abstract Signed-off-by: xinyual --- .../neuralsearch/processor/NLPProcessor.java | 223 ++++++++++++++++++ .../processor/SparseEncodingProcessor.java | 156 ++---------- 2 files changed, 237 insertions(+), 142 deletions(-) create mode 100644 
src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java new file mode 100644 index 000000000..fcf59d4f9 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java @@ -0,0 +1,223 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import com.google.common.annotations.VisibleForTesting; +import org.apache.commons.lang3.StringUtils; +import org.opensearch.core.action.ActionListener; +import org.opensearch.env.Environment; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.ingest.AbstractProcessor; +import org.opensearch.ingest.IngestDocument; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; + +import java.util.*; +import java.util.function.BiConsumer; +import java.util.function.Supplier; + + +public abstract class NLPProcessor extends AbstractProcessor { + + @VisibleForTesting + protected final String modelId; + + protected final Map fieldMap; + + protected final MLCommonsClientAccessor mlCommonsClientAccessor; + + protected final Environment environment; + + public NLPProcessor( + String tag, + String description, + String modelId, + Map fieldMap, + MLCommonsClientAccessor clientAccessor, + Environment environment + ) { + super(tag, description); + if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, can not process it"); + validateEmbeddingConfiguration(fieldMap); + + this.modelId = modelId; + this.fieldMap = fieldMap; + this.mlCommonsClientAccessor = clientAccessor; + this.environment = environment; + } + + + + private void validateEmbeddingConfiguration(Map fieldMap) { + if (fieldMap == null + || fieldMap.size() == 0 + || fieldMap.entrySet() + .stream() + .anyMatch( + x -> 
StringUtils.isBlank(x.getKey()) || Objects.isNull(x.getValue()) || StringUtils.isBlank(x.getValue().toString()) + )) { + throw new IllegalArgumentException("Unable to create the TextEmbedding processor as field_map has invalid key or value"); + } + } + + @SuppressWarnings({ "rawtypes" }) + private static void validateListTypeValue(String sourceKey, Object sourceValue) { + for (Object value : (List) sourceValue) { + if (value == null) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, can not process it"); + } else if (!(value instanceof String)) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, can not process it"); + } else if (StringUtils.isBlank(value.toString())) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, can not process it"); + } + } + } + + + @SuppressWarnings({ "rawtypes", "unchecked" }) + private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier maxDepthSupplier) { + int maxDepth = maxDepthSupplier.get(); + if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) { + throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, can not process it"); + } else if ((List.class.isAssignableFrom(sourceValue.getClass()))) { + validateListTypeValue(sourceKey, sourceValue); + } else if (Map.class.isAssignableFrom(sourceValue.getClass())) { + ((Map) sourceValue).values() + .stream() + .filter(Objects::nonNull) + .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1)); + } else if (!String.class.isAssignableFrom(sourceValue.getClass())) { + throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, can not process it"); + } else if (StringUtils.isBlank(sourceValue.toString())) { + throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, can not process it"); + } 
+ } + + + private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { + Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); + for (Map.Entry embeddingFieldsEntry : fieldMap.entrySet()) { + Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey()); + if (sourceValue != null) { + String sourceKey = embeddingFieldsEntry.getKey(); + Class sourceValueClass = sourceValue.getClass(); + if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) { + validateNestedTypeValue(sourceKey, sourceValue, () -> 1); + } else if (!String.class.isAssignableFrom(sourceValueClass)) { + throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, can not process it"); + } else if (StringUtils.isBlank(sourceValue.toString())) { + throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, can not process it"); + } + } + } + } + + + private void buildMapWithProcessorKeyAndOriginalValueForMapType( + String parentKey, + Object processorKey, + Map sourceAndMetadataMap, + Map treeRes + ) { + if (processorKey == null || sourceAndMetadataMap == null) return; + if (processorKey instanceof Map) { + Map next = new LinkedHashMap<>(); + for (Map.Entry nestedFieldMapEntry : ((Map) processorKey).entrySet()) { + buildMapWithProcessorKeyAndOriginalValueForMapType( + nestedFieldMapEntry.getKey(), + nestedFieldMapEntry.getValue(), + (Map) sourceAndMetadataMap.get(parentKey), + next + ); + } + treeRes.put(parentKey, next); + } else { + String key = String.valueOf(processorKey); + treeRes.put(key, sourceAndMetadataMap.get(parentKey)); + } + } + + @VisibleForTesting + Map buildMapWithProcessorKeyAndOriginalValue(IngestDocument ingestDocument) { + Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); + Map mapWithProcessorKeys = new LinkedHashMap<>(); + for (Map.Entry fieldMapEntry : fieldMap.entrySet()) { + String originalKey = fieldMapEntry.getKey(); 
+ Object targetKey = fieldMapEntry.getValue(); + if (targetKey instanceof Map) { + Map treeRes = new LinkedHashMap<>(); + buildMapWithProcessorKeyAndOriginalValueForMapType(originalKey, targetKey, sourceAndMetadataMap, treeRes); + mapWithProcessorKeys.put(originalKey, treeRes.get(originalKey)); + } else { + mapWithProcessorKeys.put(String.valueOf(targetKey), sourceAndMetadataMap.get(originalKey)); + } + } + return mapWithProcessorKeys; + } + + @SuppressWarnings("unchecked") + private void createInferenceListForMapTypeInput(Object sourceValue, List texts) { + if (sourceValue instanceof Map) { + ((Map) sourceValue).forEach((k, v) -> createInferenceListForMapTypeInput(v, texts)); + } else if (sourceValue instanceof List) { + texts.addAll(((List) sourceValue)); + } else { + if (sourceValue == null) return; + texts.add(sourceValue.toString()); + } + } + + + @SuppressWarnings({ "unchecked" }) + private List createInferenceList(Map knnKeyMap) { + List texts = new ArrayList<>(); + knnKeyMap.entrySet().stream().filter(knnMapEntry -> knnMapEntry.getValue() != null).forEach(knnMapEntry -> { + Object sourceValue = knnMapEntry.getValue(); + if (sourceValue instanceof List) { + texts.addAll(((List) sourceValue)); + } else if (sourceValue instanceof Map) { + createInferenceListForMapTypeInput(sourceValue, texts); + } else { + texts.add(sourceValue.toString()); + } + }); + return texts; + } + + public abstract void doExecute(IngestDocument ingestDocument,Map ProcessMap, List inferenceList, BiConsumer handler); + + @Override + public IngestDocument execute(IngestDocument ingestDocument) throws Exception { + return ingestDocument; + } + + /** + * This method will be invoked by PipelineService to make async inference and then delegate the handler to + * process the inference response or failure. + * @param ingestDocument {@link IngestDocument} which is the document passed to processor. 
+ * @param handler {@link BiConsumer} which is the handler which can be used after the inference task is done. + */ + @Override + public void execute(IngestDocument ingestDocument, BiConsumer handler){ + try { + validateEmbeddingFieldsValue(ingestDocument); + Map ProcessMap = buildMapWithProcessorKeyAndOriginalValue(ingestDocument); + List inferenceList = createInferenceList(ProcessMap); + if (inferenceList.size() == 0) { + handler.accept(ingestDocument, null); + } else { + doExecute(ingestDocument, ProcessMap, inferenceList, handler); + } + } catch (Exception e) { + handler.accept(null, e); + } + + } + + @Override + public String getType() { + return null; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index 6e5ed3c05..4b5ca2453 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -24,7 +24,7 @@ @Log4j2 -public class SparseEncodingProcessor extends AbstractProcessor { +public class SparseEncodingProcessor extends NLPProcessor { public static final String TYPE = "sparse_encoding"; public static final String MODEL_ID_FIELD = "model_id"; @@ -32,87 +32,25 @@ public class SparseEncodingProcessor extends AbstractProcessor { private static final String LIST_TYPE_NESTED_MAP_KEY = "sparseEncoding"; - @VisibleForTesting - private final String modelId; - - private final Map fieldMap; - - private final MLCommonsClientAccessor mlCommonsClientAccessor; - - private final Environment environment; - - public SparseEncodingProcessor( - String tag, - String description, - String modelId, - Map fieldMap, - MLCommonsClientAccessor clientAccessor, - Environment environment - ) { - super(tag, description); - if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, can not process 
it"); - validateEmbeddingConfiguration(fieldMap); - - this.modelId = modelId; - this.fieldMap = fieldMap; - this.mlCommonsClientAccessor = clientAccessor; - this.environment = environment; - } - - private void validateEmbeddingConfiguration(Map fieldMap) { - if (fieldMap == null - || fieldMap.size() == 0 - || fieldMap.entrySet() - .stream() - .anyMatch( - x -> StringUtils.isBlank(x.getKey()) || Objects.isNull(x.getValue()) || StringUtils.isBlank(x.getValue().toString()) - )) { - throw new IllegalArgumentException("Unable to create the TextEmbedding processor as field_map has invalid key or value"); - } + public SparseEncodingProcessor(String tag, String description, String modelId, Map fieldMap, MLCommonsClientAccessor clientAccessor, Environment environment) { + super(tag, description, modelId, fieldMap, clientAccessor, environment); } @Override - public IngestDocument execute(IngestDocument ingestDocument) { - return ingestDocument; - } - - /** - * This method will be invoked by PipelineService to make async inference and then delegate the handler to - * process the inference response or failure. - * @param ingestDocument {@link IngestDocument} which is the document passed to processor. - * @param handler {@link BiConsumer} which is the handler which can be used after the inference task is done. - */ - @Override - public void execute(IngestDocument ingestDocument, BiConsumer handler) { - // When received a bulk indexing request, the pipeline will be executed in this method, (see - // https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/action/bulk/TransportBulkAction.java#L226). - // Before the pipeline execution, the pipeline will be marked as resolved (means executed), - // and then this overriding method will be invoked when executing the text embedding processor. - // After the inference completes, the handler will invoke the doInternalExecute method again to run actual write operation. 
- try { - validateEmbeddingFieldsValue(ingestDocument); - Map ProcessMap = buildMapWithProcessorKeyAndOriginalValue(ingestDocument); - List inferenceList = createInferenceList(ProcessMap); - if (inferenceList.size() == 0) { - handler.accept(ingestDocument, null); - } else { - mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> { - List > resultTokenWeights = new ArrayList<>(); - for (Map map: resultMaps) - { - resultTokenWeights.addAll((List>)map.get("response") ); - } - log.info(resultTokenWeights); - setVectorFieldsToDocument(ingestDocument, ProcessMap, resultTokenWeights); - handler.accept(ingestDocument, null); - }, e -> { handler.accept(null, e); })); + public void doExecute(IngestDocument ingestDocument, Map ProcessMap, List inferenceList, BiConsumer handler) { + mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> { + List > resultTokenWeights = new ArrayList<>(); + for (Map map: resultMaps) + { + resultTokenWeights.addAll((List>)map.get("response") ); } - } catch (Exception e) { - handler.accept(null, e); - } - + log.info(resultTokenWeights); + setVectorFieldsToDocument(ingestDocument, ProcessMap, resultTokenWeights); + handler.accept(ingestDocument, null); + }, e -> { handler.accept(null, e); })); } + void setVectorFieldsToDocument(IngestDocument ingestDocument, Map processorMap, List > resultTokenWeights) { Objects.requireNonNull(resultTokenWeights, "embedding failed, inference returns null result!"); log.debug("Text embedding result fetched, starting build vector output!"); @@ -120,22 +58,6 @@ void setVectorFieldsToDocument(IngestDocument ingestDocument, Map createInferenceList(Map knnKeyMap) { - List texts = new ArrayList<>(); - knnKeyMap.entrySet().stream().filter(knnMapEntry -> knnMapEntry.getValue() != null).forEach(knnMapEntry -> { - Object sourceValue = knnMapEntry.getValue(); - if (sourceValue instanceof List) { - 
texts.addAll(((List) sourceValue)); - } else if (sourceValue instanceof Map) { - createInferenceListForMapTypeInput(sourceValue, texts); - } else { - texts.add(sourceValue.toString()); - } - }); - return texts; - } - @SuppressWarnings("unchecked") private void createInferenceListForMapTypeInput(Object sourceValue, List texts) { if (sourceValue instanceof Map) { @@ -254,56 +176,6 @@ private void putSparseEncodingResultToSourceMapForMapType( return tokenWeights; } - private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { - Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); - for (Map.Entry embeddingFieldsEntry : fieldMap.entrySet()) { - Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey()); - if (sourceValue != null) { - String sourceKey = embeddingFieldsEntry.getKey(); - Class sourceValueClass = sourceValue.getClass(); - if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) { - validateNestedTypeValue(sourceKey, sourceValue, () -> 1); - } else if (!String.class.isAssignableFrom(sourceValueClass)) { - throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, can not process it"); - } else if (StringUtils.isBlank(sourceValue.toString())) { - throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, can not process it"); - } - } - } - } - - @SuppressWarnings({ "rawtypes", "unchecked" }) - private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier maxDepthSupplier) { - int maxDepth = maxDepthSupplier.get(); - if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, can not process it"); - } else if ((List.class.isAssignableFrom(sourceValue.getClass()))) { - validateListTypeValue(sourceKey, sourceValue); - } else if 
(Map.class.isAssignableFrom(sourceValue.getClass())) { - ((Map) sourceValue).values() - .stream() - .filter(Objects::nonNull) - .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1)); - } else if (!String.class.isAssignableFrom(sourceValue.getClass())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, can not process it"); - } else if (StringUtils.isBlank(sourceValue.toString())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, can not process it"); - } - } - - @SuppressWarnings({ "rawtypes" }) - private static void validateListTypeValue(String sourceKey, Object sourceValue) { - for (Object value : (List) sourceValue) { - if (value == null) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, can not process it"); - } else if (!(value instanceof String)) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, can not process it"); - } else if (StringUtils.isBlank(value.toString())) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, can not process it"); - } - } - } - @Override public String getType() { return TYPE; From c00a4cfe146970b4041737dcd71699087b261ff4 Mon Sep 17 00:00:00 2001 From: xinyual Date: Mon, 11 Sep 2023 10:43:04 +0800 Subject: [PATCH 14/70] create abstract proccesor Signed-off-by: xinyual --- .../processor/SparseEncodingProcessor.java | 13 -- .../processor/TextEmbeddingProcessor.java | 128 ++---------------- 2 files changed, 8 insertions(+), 133 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index 4b5ca2453..488f8900e 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ 
b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -44,7 +44,6 @@ public void doExecute(IngestDocument ingestDocument, Map Process { resultTokenWeights.addAll((List>)map.get("response") ); } - log.info(resultTokenWeights); setVectorFieldsToDocument(ingestDocument, ProcessMap, resultTokenWeights); handler.accept(ingestDocument, null); }, e -> { handler.accept(null, e); })); @@ -58,18 +57,6 @@ void setVectorFieldsToDocument(IngestDocument ingestDocument, Map texts) { - if (sourceValue instanceof Map) { - ((Map) sourceValue).forEach((k, v) -> createInferenceListForMapTypeInput(v, texts)); - } else if (sourceValue instanceof List) { - texts.addAll(((List) sourceValue)); - } else { - if (sourceValue == null) return; - texts.add(sourceValue.toString()); - } - } - @VisibleForTesting Map buildMapWithProcessorKeyAndOriginalValue(IngestDocument ingestDocument) { Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java index 878a410a8..2bed8fcce 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java @@ -32,7 +32,7 @@ * and field_map can be used to indicate which fields needs text embedding and the corresponding keys for the text embedding results. 
*/ @Log4j2 -public class TextEmbeddingProcessor extends AbstractProcessor { +public class TextEmbeddingProcessor extends NLPProcessor { public static final String TYPE = "text_embedding"; public static final String MODEL_ID_FIELD = "model_id"; @@ -40,14 +40,6 @@ public class TextEmbeddingProcessor extends AbstractProcessor { private static final String LIST_TYPE_NESTED_MAP_KEY = "knn"; - @VisibleForTesting - private final String modelId; - - private final Map fieldMap; - - private final MLCommonsClientAccessor mlCommonsClientAccessor; - - private final Environment environment; public TextEmbeddingProcessor( String tag, @@ -57,14 +49,7 @@ public TextEmbeddingProcessor( MLCommonsClientAccessor clientAccessor, Environment environment ) { - super(tag, description); - if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, can not process it"); - validateEmbeddingConfiguration(fieldMap); - - this.modelId = modelId; - this.fieldMap = fieldMap; - this.mlCommonsClientAccessor = clientAccessor; - this.environment = environment; + super(tag, description, modelId, fieldMap, clientAccessor, environment); } private void validateEmbeddingConfiguration(Map fieldMap) { @@ -84,37 +69,15 @@ public IngestDocument execute(IngestDocument ingestDocument) { return ingestDocument; } - /** - * This method will be invoked by PipelineService to make async inference and then delegate the handler to - * process the inference response or failure. - * @param ingestDocument {@link IngestDocument} which is the document passed to processor. - * @param handler {@link BiConsumer} which is the handler which can be used after the inference task is done. 
- */ @Override - public void execute(IngestDocument ingestDocument, BiConsumer handler) { - // When received a bulk indexing request, the pipeline will be executed in this method, (see - // https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/action/bulk/TransportBulkAction.java#L226). - // Before the pipeline execution, the pipeline will be marked as resolved (means executed), - // and then this overriding method will be invoked when executing the text embedding processor. - // After the inference completes, the handler will invoke the doInternalExecute method again to run actual write operation. - try { - validateEmbeddingFieldsValue(ingestDocument); - Map knnMap = buildMapWithKnnKeyAndOriginalValue(ingestDocument); - List inferenceList = createInferenceList(knnMap); - if (inferenceList.size() == 0) { - handler.accept(ingestDocument, null); - } else { - mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, ActionListener.wrap(vectors -> { - setVectorFieldsToDocument(ingestDocument, knnMap, vectors); - handler.accept(ingestDocument, null); - }, e -> { handler.accept(null, e); })); - } - } catch (Exception e) { - handler.accept(null, e); - } - + public void doExecute(IngestDocument ingestDocument, Map ProcessMap, List inferenceList, BiConsumer handler) { + mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, ActionListener.wrap(vectors -> { + setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors); + handler.accept(ingestDocument, null); + }, e -> { handler.accept(null, e); })); } + void setVectorFieldsToDocument(IngestDocument ingestDocument, Map knnMap, List> vectors) { Objects.requireNonNull(vectors, "embedding failed, inference returns null result!"); log.debug("Text embedding result fetched, starting build vector output!"); @@ -122,33 +85,7 @@ void setVectorFieldsToDocument(IngestDocument ingestDocument, Map createInferenceList(Map knnKeyMap) { - List texts = new 
ArrayList<>(); - knnKeyMap.entrySet().stream().filter(knnMapEntry -> knnMapEntry.getValue() != null).forEach(knnMapEntry -> { - Object sourceValue = knnMapEntry.getValue(); - if (sourceValue instanceof List) { - texts.addAll(((List) sourceValue)); - } else if (sourceValue instanceof Map) { - createInferenceListForMapTypeInput(sourceValue, texts); - } else { - texts.add(sourceValue.toString()); - } - }); - return texts; - } - @SuppressWarnings("unchecked") - private void createInferenceListForMapTypeInput(Object sourceValue, List texts) { - if (sourceValue instanceof Map) { - ((Map) sourceValue).forEach((k, v) -> createInferenceListForMapTypeInput(v, texts)); - } else if (sourceValue instanceof List) { - texts.addAll(((List) sourceValue)); - } else { - if (sourceValue == null) return; - texts.add(sourceValue.toString()); - } - } @VisibleForTesting Map buildMapWithKnnKeyAndOriginalValue(IngestDocument ingestDocument) { @@ -257,55 +194,6 @@ private List>> buildTextEmbeddingResultForListType( return numbers; } - private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { - Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); - for (Map.Entry embeddingFieldsEntry : fieldMap.entrySet()) { - Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey()); - if (sourceValue != null) { - String sourceKey = embeddingFieldsEntry.getKey(); - Class sourceValueClass = sourceValue.getClass(); - if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) { - validateNestedTypeValue(sourceKey, sourceValue, () -> 1); - } else if (!String.class.isAssignableFrom(sourceValueClass)) { - throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, can not process it"); - } else if (StringUtils.isBlank(sourceValue.toString())) { - throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, can not process it"); - } - } - } - } - - 
@SuppressWarnings({ "rawtypes", "unchecked" }) - private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier maxDepthSupplier) { - int maxDepth = maxDepthSupplier.get(); - if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, can not process it"); - } else if ((List.class.isAssignableFrom(sourceValue.getClass()))) { - validateListTypeValue(sourceKey, sourceValue); - } else if (Map.class.isAssignableFrom(sourceValue.getClass())) { - ((Map) sourceValue).values() - .stream() - .filter(Objects::nonNull) - .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1)); - } else if (!String.class.isAssignableFrom(sourceValue.getClass())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, can not process it"); - } else if (StringUtils.isBlank(sourceValue.toString())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, can not process it"); - } - } - - @SuppressWarnings({ "rawtypes" }) - private static void validateListTypeValue(String sourceKey, Object sourceValue) { - for (Object value : (List) sourceValue) { - if (value == null) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, can not process it"); - } else if (!(value instanceof String)) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, can not process it"); - } else if (StringUtils.isBlank(value.toString())) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, can not process it"); - } - } - } @Override public String getType() { From a973b427808eb5bdbd072b944e80fa614cf813ee Mon Sep 17 00:00:00 2001 From: xinyual Date: Mon, 11 Sep 2023 11:40:17 +0800 Subject: [PATCH 15/70] add abstract class Signed-off-by: xinyual --- 
.../neuralsearch/processor/NLPProcessor.java | 98 +++++++++++- .../processor/SparseEncodingProcessor.java | 140 +----------------- .../processor/TextEmbeddingProcessor.java | 139 +---------------- 3 files changed, 101 insertions(+), 276 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java index fcf59d4f9..f61f99f5d 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java @@ -6,8 +6,9 @@ package org.opensearch.neuralsearch.processor; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMap; +import lombok.extern.log4j.Log4j2; import org.apache.commons.lang3.StringUtils; -import org.opensearch.core.action.ActionListener; import org.opensearch.env.Environment; import org.opensearch.index.mapper.MapperService; import org.opensearch.ingest.AbstractProcessor; @@ -17,8 +18,9 @@ import java.util.*; import java.util.function.BiConsumer; import java.util.function.Supplier; +import java.util.stream.IntStream; - +@Log4j2 public abstract class NLPProcessor extends AbstractProcessor { @VisibleForTesting @@ -30,6 +32,8 @@ public abstract class NLPProcessor extends AbstractProcessor { protected final Environment environment; + protected String LIST_TYPE_NESTED_MAP_KEY = "NLP"; + public NLPProcessor( String tag, String description, @@ -216,8 +220,98 @@ public void execute(IngestDocument ingestDocument, BiConsumer processorMap, List results) { + Objects.requireNonNull(results, "embedding failed, inference returns null result!"); + log.debug("Text embedding result fetched, starting build vector output!"); + Map nlpResult = buildNLPResult(processorMap, results, ingestDocument.getSourceAndMetadata()); + nlpResult.forEach(ingestDocument::setFieldValue); + } + + + @SuppressWarnings({ "unchecked" }) + @VisibleForTesting + Map 
buildNLPResult( + Map processorMap, + List results, + Map sourceAndMetadataMap + ) { + NLPProcessor.IndexWrapper indexWrapper = new NLPProcessor.IndexWrapper(0); + Map result = new LinkedHashMap<>(); + for (Map.Entry knnMapEntry : processorMap.entrySet()) { + String knnKey = knnMapEntry.getKey(); + Object sourceValue = knnMapEntry.getValue(); + if (sourceValue instanceof String) { + result.put(knnKey, results.get(indexWrapper.index++)); + } else if (sourceValue instanceof List) { + result.put(knnKey, buildNLPResultForListType((List) sourceValue, results, indexWrapper)); + } else if (sourceValue instanceof Map) { + putNLPResultToSourceMapForMapType(knnKey, sourceValue, results, indexWrapper, sourceAndMetadataMap); + } + } + return result; + } + + @SuppressWarnings({ "unchecked" }) + private void putNLPResultToSourceMapForMapType( + String processorKey, + Object sourceValue, + List results, + NLPProcessor.IndexWrapper indexWrapper, + Map sourceAndMetadataMap + ) { + if (processorKey == null || sourceAndMetadataMap == null || sourceValue == null) return; + if (sourceValue instanceof Map) { + for (Map.Entry inputNestedMapEntry : ((Map) sourceValue).entrySet()) { + putNLPResultToSourceMapForMapType( + inputNestedMapEntry.getKey(), + inputNestedMapEntry.getValue(), + results, + indexWrapper, + (Map) sourceAndMetadataMap.get(processorKey) + ); + } + } else if (sourceValue instanceof String) { + sourceAndMetadataMap.put(processorKey, results.get(indexWrapper.index++)); + } else if (sourceValue instanceof List) { + sourceAndMetadataMap.put( + processorKey, + buildNLPResultForListType((List) sourceValue, results, indexWrapper) + ); + } + } + + private List> buildNLPResultForListType( + List sourceValue, + List results, + NLPProcessor.IndexWrapper indexWrapper + ) { + List> keyToResult = new ArrayList<>(); + IntStream.range(0, sourceValue.size()) + .forEachOrdered(x -> keyToResult.add(ImmutableMap.of(LIST_TYPE_NESTED_MAP_KEY, results.get(indexWrapper.index++)))); + return 
keyToResult; + } + + + @Override public String getType() { return null; } + + /** + * Since we need to build a {@link List} as the input for text embedding, and the result type is {@link List} of {@link List}, + * we need to map the result back to the input one by one with exactly order. For nested map type input, we're performing a pre-order + * traversal to extract the input strings, so when mapping back to the nested map, we still need a pre-order traversal to ensure the + * order. And we also need to ensure the index pointer goes forward in the recursive, so here the IndexWrapper is to store and increase + * the index pointer during the recursive. + * index: the index pointer of the text embedding result. + */ + static class IndexWrapper { + private int index; + + protected IndexWrapper(int index) { + this.index = index; + } + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index 488f8900e..5200c7e75 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -30,158 +30,26 @@ public class SparseEncodingProcessor extends NLPProcessor { public static final String MODEL_ID_FIELD = "model_id"; public static final String FIELD_MAP_FIELD = "field_map"; - private static final String LIST_TYPE_NESTED_MAP_KEY = "sparseEncoding"; - public SparseEncodingProcessor(String tag, String description, String modelId, Map fieldMap, MLCommonsClientAccessor clientAccessor, Environment environment) { super(tag, description, modelId, fieldMap, clientAccessor, environment); + this.LIST_TYPE_NESTED_MAP_KEY = "sparseEncoding"; } @Override public void doExecute(IngestDocument ingestDocument, Map ProcessMap, List inferenceList, BiConsumer handler) { mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, 
ActionListener.wrap(resultMaps -> { - List > resultTokenWeights = new ArrayList<>(); + List > results = new ArrayList<>(); for (Map map: resultMaps) { - resultTokenWeights.addAll((List>)map.get("response") ); + results.addAll((List>)map.get("response") ); } - setVectorFieldsToDocument(ingestDocument, ProcessMap, resultTokenWeights); + setVectorFieldsToDocument(ingestDocument, ProcessMap, results); handler.accept(ingestDocument, null); }, e -> { handler.accept(null, e); })); } - - void setVectorFieldsToDocument(IngestDocument ingestDocument, Map processorMap, List > resultTokenWeights) { - Objects.requireNonNull(resultTokenWeights, "embedding failed, inference returns null result!"); - log.debug("Text embedding result fetched, starting build vector output!"); - Map sparseEncodingResult = buildSparseEncodingResult(processorMap, resultTokenWeights, ingestDocument.getSourceAndMetadata()); - sparseEncodingResult.forEach(ingestDocument::setFieldValue); - } - - @VisibleForTesting - Map buildMapWithProcessorKeyAndOriginalValue(IngestDocument ingestDocument) { - Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); - Map mapWithProcessorKeys = new LinkedHashMap<>(); - for (Map.Entry fieldMapEntry : fieldMap.entrySet()) { - String originalKey = fieldMapEntry.getKey(); - Object targetKey = fieldMapEntry.getValue(); - if (targetKey instanceof Map) { - Map treeRes = new LinkedHashMap<>(); - buildMapWithProcessorKeyAndOriginalValueForMapType(originalKey, targetKey, sourceAndMetadataMap, treeRes); - mapWithProcessorKeys.put(originalKey, treeRes.get(originalKey)); - } else { - mapWithProcessorKeys.put(String.valueOf(targetKey), sourceAndMetadataMap.get(originalKey)); - } - } - return mapWithProcessorKeys; - } - - @SuppressWarnings({ "unchecked" }) - private void buildMapWithProcessorKeyAndOriginalValueForMapType( - String parentKey, - Object processorKey, - Map sourceAndMetadataMap, - Map treeRes - ) { - if (processorKey == null || sourceAndMetadataMap == null) return; 
- if (processorKey instanceof Map) { - Map next = new LinkedHashMap<>(); - for (Map.Entry nestedFieldMapEntry : ((Map) processorKey).entrySet()) { - buildMapWithProcessorKeyAndOriginalValueForMapType( - nestedFieldMapEntry.getKey(), - nestedFieldMapEntry.getValue(), - (Map) sourceAndMetadataMap.get(parentKey), - next - ); - } - treeRes.put(parentKey, next); - } else { - String key = String.valueOf(processorKey); - treeRes.put(key, sourceAndMetadataMap.get(parentKey)); - } - } - - @SuppressWarnings({ "unchecked" }) - @VisibleForTesting - Map buildSparseEncodingResult( - Map processorMap, - List > resultTokenWeights, - Map sourceAndMetadataMap - ) { - SparseEncodingProcessor.IndexWrapper indexWrapper = new SparseEncodingProcessor.IndexWrapper(0); - Map result = new LinkedHashMap<>(); - for (Map.Entry knnMapEntry : processorMap.entrySet()) { - String knnKey = knnMapEntry.getKey(); - Object sourceValue = knnMapEntry.getValue(); - if (sourceValue instanceof String) { - result.put(knnKey, resultTokenWeights.get(indexWrapper.index++)); - } else if (sourceValue instanceof List) { - result.put(knnKey, buildSparseEncodingResultForListType((List) sourceValue, resultTokenWeights, indexWrapper)); - } else if (sourceValue instanceof Map) { - putSparseEncodingResultToSourceMapForMapType(knnKey, sourceValue, resultTokenWeights, indexWrapper, sourceAndMetadataMap); - } - } - return result; - } - - @SuppressWarnings({ "unchecked" }) - private void putSparseEncodingResultToSourceMapForMapType( - String processorKey, - Object sourceValue, - List > resultTokenWeights, - SparseEncodingProcessor.IndexWrapper indexWrapper, - Map sourceAndMetadataMap - ) { - if (processorKey == null || sourceAndMetadataMap == null || sourceValue == null) return; - if (sourceValue instanceof Map) { - for (Map.Entry inputNestedMapEntry : ((Map) sourceValue).entrySet()) { - putSparseEncodingResultToSourceMapForMapType( - inputNestedMapEntry.getKey(), - inputNestedMapEntry.getValue(), - resultTokenWeights, - 
indexWrapper, - (Map) sourceAndMetadataMap.get(processorKey) - ); - } - } else if (sourceValue instanceof String) { - sourceAndMetadataMap.put(processorKey, resultTokenWeights.get(indexWrapper.index++)); - } else if (sourceValue instanceof List) { - sourceAndMetadataMap.put( - processorKey, - buildSparseEncodingResultForListType((List) sourceValue, resultTokenWeights, indexWrapper) - ); - } - } - - private List>> buildSparseEncodingResultForListType( - List sourceValue, - List > resultTokenWeights, - SparseEncodingProcessor.IndexWrapper indexWrapper - ) { - List>> tokenWeights = new ArrayList<>(); - IntStream.range(0, sourceValue.size()) - .forEachOrdered(x -> tokenWeights.add(ImmutableMap.of(LIST_TYPE_NESTED_MAP_KEY, resultTokenWeights.get(indexWrapper.index++)))); - return tokenWeights; - } - @Override public String getType() { return TYPE; } - - /** - * Since we need to build a {@link List} as the input for text embedding, and the result type is {@link List} of {@link List}, - * we need to map the result back to the input one by one with exactly order. For nested map type input, we're performing a pre-order - * traversal to extract the input strings, so when mapping back to the nested map, we still need a pre-order traversal to ensure the - * order. And we also need to ensure the index pointer goes forward in the recursive, so here the IndexWrapper is to store and increase - * the index pointer during the recursive. - * index: the index pointer of the text embedding result. 
- */ - static class IndexWrapper { - private int index; - - protected IndexWrapper(int index) { - this.index = index; - } - } - } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java index 2bed8fcce..68dad0bae 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java @@ -38,9 +38,6 @@ public class TextEmbeddingProcessor extends NLPProcessor { public static final String MODEL_ID_FIELD = "model_id"; public static final String FIELD_MAP_FIELD = "field_map"; - private static final String LIST_TYPE_NESTED_MAP_KEY = "knn"; - - public TextEmbeddingProcessor( String tag, String description, @@ -50,6 +47,7 @@ public TextEmbeddingProcessor( Environment environment ) { super(tag, description, modelId, fieldMap, clientAccessor, environment); + this.LIST_TYPE_NESTED_MAP_KEY = "knn"; } private void validateEmbeddingConfiguration(Map fieldMap) { @@ -77,143 +75,8 @@ public void doExecute(IngestDocument ingestDocument, Map Process }, e -> { handler.accept(null, e); })); } - - void setVectorFieldsToDocument(IngestDocument ingestDocument, Map knnMap, List> vectors) { - Objects.requireNonNull(vectors, "embedding failed, inference returns null result!"); - log.debug("Text embedding result fetched, starting build vector output!"); - Map textEmbeddingResult = buildTextEmbeddingResult(knnMap, vectors, ingestDocument.getSourceAndMetadata()); - textEmbeddingResult.forEach(ingestDocument::setFieldValue); - } - - - - @VisibleForTesting - Map buildMapWithKnnKeyAndOriginalValue(IngestDocument ingestDocument) { - Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); - Map mapWithKnnKeys = new LinkedHashMap<>(); - for (Map.Entry fieldMapEntry : fieldMap.entrySet()) { - String originalKey = fieldMapEntry.getKey(); - Object targetKey = 
fieldMapEntry.getValue(); - if (targetKey instanceof Map) { - Map treeRes = new LinkedHashMap<>(); - buildMapWithKnnKeyAndOriginalValueForMapType(originalKey, targetKey, sourceAndMetadataMap, treeRes); - mapWithKnnKeys.put(originalKey, treeRes.get(originalKey)); - } else { - mapWithKnnKeys.put(String.valueOf(targetKey), sourceAndMetadataMap.get(originalKey)); - } - } - return mapWithKnnKeys; - } - - @SuppressWarnings({ "unchecked" }) - private void buildMapWithKnnKeyAndOriginalValueForMapType( - String parentKey, - Object knnKey, - Map sourceAndMetadataMap, - Map treeRes - ) { - if (knnKey == null || sourceAndMetadataMap == null) return; - if (knnKey instanceof Map) { - Map next = new LinkedHashMap<>(); - for (Map.Entry nestedFieldMapEntry : ((Map) knnKey).entrySet()) { - buildMapWithKnnKeyAndOriginalValueForMapType( - nestedFieldMapEntry.getKey(), - nestedFieldMapEntry.getValue(), - (Map) sourceAndMetadataMap.get(parentKey), - next - ); - } - treeRes.put(parentKey, next); - } else { - String key = String.valueOf(knnKey); - treeRes.put(key, sourceAndMetadataMap.get(parentKey)); - } - } - - @SuppressWarnings({ "unchecked" }) - @VisibleForTesting - Map buildTextEmbeddingResult( - Map knnMap, - List> modelTensorList, - Map sourceAndMetadataMap - ) { - IndexWrapper indexWrapper = new IndexWrapper(0); - Map result = new LinkedHashMap<>(); - for (Map.Entry knnMapEntry : knnMap.entrySet()) { - String knnKey = knnMapEntry.getKey(); - Object sourceValue = knnMapEntry.getValue(); - if (sourceValue instanceof String) { - List modelTensor = modelTensorList.get(indexWrapper.index++); - result.put(knnKey, modelTensor); - } else if (sourceValue instanceof List) { - result.put(knnKey, buildTextEmbeddingResultForListType((List) sourceValue, modelTensorList, indexWrapper)); - } else if (sourceValue instanceof Map) { - putTextEmbeddingResultToSourceMapForMapType(knnKey, sourceValue, modelTensorList, indexWrapper, sourceAndMetadataMap); - } - } - return result; - } - - 
@SuppressWarnings({ "unchecked" }) - private void putTextEmbeddingResultToSourceMapForMapType( - String knnKey, - Object sourceValue, - List> modelTensorList, - IndexWrapper indexWrapper, - Map sourceAndMetadataMap - ) { - if (knnKey == null || sourceAndMetadataMap == null || sourceValue == null) return; - if (sourceValue instanceof Map) { - for (Map.Entry inputNestedMapEntry : ((Map) sourceValue).entrySet()) { - putTextEmbeddingResultToSourceMapForMapType( - inputNestedMapEntry.getKey(), - inputNestedMapEntry.getValue(), - modelTensorList, - indexWrapper, - (Map) sourceAndMetadataMap.get(knnKey) - ); - } - } else if (sourceValue instanceof String) { - sourceAndMetadataMap.put(knnKey, modelTensorList.get(indexWrapper.index++)); - } else if (sourceValue instanceof List) { - sourceAndMetadataMap.put( - knnKey, - buildTextEmbeddingResultForListType((List) sourceValue, modelTensorList, indexWrapper) - ); - } - } - - private List>> buildTextEmbeddingResultForListType( - List sourceValue, - List> modelTensorList, - IndexWrapper indexWrapper - ) { - List>> numbers = new ArrayList<>(); - IntStream.range(0, sourceValue.size()) - .forEachOrdered(x -> numbers.add(ImmutableMap.of(LIST_TYPE_NESTED_MAP_KEY, modelTensorList.get(indexWrapper.index++)))); - return numbers; - } - - @Override public String getType() { return TYPE; } - - /** - * Since we need to build a {@link List} as the input for text embedding, and the result type is {@link List} of {@link List}, - * we need to map the result back to the input one by one with exactly order. For nested map type input, we're performing a pre-order - * traversal to extract the input strings, so when mapping back to the nested map, we still need a pre-order traversal to ensure the - * order. And we also need to ensure the index pointer goes forward in the recursive, so here the IndexWrapper is to store and increase - * the index pointer during the recursive. - * index: the index pointer of the text embedding result. 
- */ - static class IndexWrapper { - private int index; - - protected IndexWrapper(int index) { - this.index = index; - } - } - } From 30fc444a1aa646def380edb4c3d65bb5b283e484 Mon Sep 17 00:00:00 2001 From: xinyual Date: Mon, 11 Sep 2023 13:44:36 +0800 Subject: [PATCH 16/70] remove duplicate code Signed-off-by: xinyual --- .../neuralsearch/processor/TextEmbeddingProcessor.java | 5 ----- .../neuralsearch/processor/TextEmbeddingProcessorTests.java | 2 +- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java index 68dad0bae..eb8b299d3 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java @@ -62,11 +62,6 @@ private void validateEmbeddingConfiguration(Map fieldMap) { } } - @Override - public IngestDocument execute(IngestDocument ingestDocument) { - return ingestDocument; - } - @Override public void doExecute(IngestDocument ingestDocument, Map ProcessMap, List inferenceList, BiConsumer handler) { mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, ActionListener.wrap(vectors -> { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java index da0a46954..f4da16534 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java @@ -309,7 +309,7 @@ private Map createMaxDepthLimitExceedMap(Supplier maxDe return innerMap; } - public void testExecute_hybridTypeInput_successful() { + public void testExecute_hybridTypeInput_successful() throws Exception { List list1 = ImmutableList.of("test1", "test2"); Map> map1 = 
ImmutableMap.of("test3", list1); Map sourceAndMetadata = new HashMap<>(); From e2a30de74924b2d168d81887a8afb432ee303f7f Mon Sep 17 00:00:00 2001 From: xinyual Date: Mon, 11 Sep 2023 13:45:22 +0800 Subject: [PATCH 17/70] remove duplicate code Signed-off-by: xinyual --- .../processor/TextEmbeddingProcessor.java | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java index eb8b299d3..16735ce62 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java @@ -50,18 +50,6 @@ public TextEmbeddingProcessor( this.LIST_TYPE_NESTED_MAP_KEY = "knn"; } - private void validateEmbeddingConfiguration(Map fieldMap) { - if (fieldMap == null - || fieldMap.size() == 0 - || fieldMap.entrySet() - .stream() - .anyMatch( - x -> StringUtils.isBlank(x.getKey()) || Objects.isNull(x.getValue()) || StringUtils.isBlank(x.getValue().toString()) - )) { - throw new IllegalArgumentException("Unable to create the TextEmbedding processor as field_map has invalid key or value"); - } - } - @Override public void doExecute(IngestDocument ingestDocument, Map ProcessMap, List inferenceList, BiConsumer handler) { mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, ActionListener.wrap(vectors -> { From 6b94a1758b5cdb38a79a9b118781e43e303d28f2 Mon Sep 17 00:00:00 2001 From: xinyual Date: Mon, 11 Sep 2023 13:52:07 +0800 Subject: [PATCH 18/70] remove dl process Signed-off-by: xinyual --- .../neuralsearch/processor/DLProcessor.java | 25 ------------------- .../processor/SparseEncodingProcessor.java | 4 +-- .../SparseEncodingProcessorFactory.java | 1 - 3 files changed, 2 insertions(+), 28 deletions(-) delete mode 100644 src/main/java/org/opensearch/neuralsearch/processor/DLProcessor.java diff --git 
a/src/main/java/org/opensearch/neuralsearch/processor/DLProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/DLProcessor.java deleted file mode 100644 index a042fb517..000000000 --- a/src/main/java/org/opensearch/neuralsearch/processor/DLProcessor.java +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.neuralsearch.processor; - -import org.apache.commons.lang3.StringUtils; -import org.opensearch.core.action.ActionListener; -import org.opensearch.env.Environment; -import org.opensearch.ingest.AbstractProcessor; -import org.opensearch.ingest.IngestDocument; -import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; - -import java.util.List; -import java.util.Map; -import java.util.function.BiConsumer; - -public abstract class DLProcessor extends AbstractProcessor { - - - protected DLProcessor(String tag, String description) { - super(tag, description); - } -} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index 5200c7e75..39b2c3f81 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -38,10 +38,10 @@ public SparseEncodingProcessor(String tag, String description, String modelId, M @Override public void doExecute(IngestDocument ingestDocument, Map ProcessMap, List inferenceList, BiConsumer handler) { mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> { - List > results = new ArrayList<>(); + List > results = new ArrayList<>(); for (Map map: resultMaps) { - results.addAll((List>)map.get("response") ); + results.addAll((List>)map.get("response") ); } setVectorFieldsToDocument(ingestDocument, ProcessMap, results); 
handler.accept(ingestDocument, null); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java index 27ec79d57..8065af445 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java @@ -9,7 +9,6 @@ import org.opensearch.env.Environment; import org.opensearch.ingest.Processor; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; -import org.opensearch.neuralsearch.processor.DLProcessor; import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; import java.util.Map; From a3d09bdc8377693217c4747fc35958f6aaf2e3b7 Mon Sep 17 00:00:00 2001 From: xinyual Date: Mon, 11 Sep 2023 14:25:56 +0800 Subject: [PATCH 19/70] move static to abstract class Signed-off-by: xinyual --- .../org/opensearch/neuralsearch/processor/NLPProcessor.java | 3 +++ .../neuralsearch/processor/SparseEncodingProcessor.java | 2 -- .../neuralsearch/processor/TextEmbeddingProcessor.java | 2 -- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java index f61f99f5d..bdec04e13 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java @@ -26,6 +26,9 @@ public abstract class NLPProcessor extends AbstractProcessor { @VisibleForTesting protected final String modelId; + public static final String MODEL_ID_FIELD = "model_id"; + public static final String FIELD_MAP_FIELD = "field_map"; + protected final Map fieldMap; protected final MLCommonsClientAccessor mlCommonsClientAccessor; diff --git 
a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index 39b2c3f81..0bf157e07 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -27,8 +27,6 @@ public class SparseEncodingProcessor extends NLPProcessor { public static final String TYPE = "sparse_encoding"; - public static final String MODEL_ID_FIELD = "model_id"; - public static final String FIELD_MAP_FIELD = "field_map"; public SparseEncodingProcessor(String tag, String description, String modelId, Map fieldMap, MLCommonsClientAccessor clientAccessor, Environment environment) { super(tag, description, modelId, fieldMap, clientAccessor, environment); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java index 16735ce62..3fa0ef114 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java @@ -35,8 +35,6 @@ public class TextEmbeddingProcessor extends NLPProcessor { public static final String TYPE = "text_embedding"; - public static final String MODEL_ID_FIELD = "model_id"; - public static final String FIELD_MAP_FIELD = "field_map"; public TextEmbeddingProcessor( String tag, From f10c94d625eb4ad993c84c6bb625a3233b675380 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Fri, 8 Sep 2023 15:37:01 +0800 Subject: [PATCH 20/70] update query rewrite logic Signed-off-by: zhichao-aws --- .../index/mapper/SparseVectorMapper.java | 32 +-- .../neuralsearch/plugin/NeuralSearch.java | 1 + .../query/sparse/SparseQueryBuilder.java | 186 +++++++++++++++--- .../neuralsearch/util/TokenWeightUtil.java | 36 ++++ 4 files changed, 214 insertions(+), 41 
deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java diff --git a/src/main/java/org/opensearch/neuralsearch/index/mapper/SparseVectorMapper.java b/src/main/java/org/opensearch/neuralsearch/index/mapper/SparseVectorMapper.java index f117d31dc..e92d9fb64 100644 --- a/src/main/java/org/opensearch/neuralsearch/index/mapper/SparseVectorMapper.java +++ b/src/main/java/org/opensearch/neuralsearch/index/mapper/SparseVectorMapper.java @@ -47,16 +47,16 @@ public static class SparseVectorBuilder extends ParametrizedFieldMapper.Builder // However, The default behavior of lucene FeatureQuery is returning Float.MAX_VALUE for every term. Which will // invalidate WAND algorithm. - // By setting maxTermScoreForSparseQuery, we'll use it as the term score upperbound to accelerate the search. - // Users can also overwrite this setting in sparse query. Our experiments show a proper maxTermScoreForSparseQuery + // By setting tokenScoreUpperBound, we'll use it as the term score upperbound to accelerate the search. + // Users can also overwrite this setting in sparse query. Our experiments show a proper tokenScoreUpperBound // value can reduce search latency by 4x while losing precision less than 0.5%. // If user doesn't set the value explicitly, we'll degrade to the default behavior in lucene FeatureQuery, // i.e. using Float.MAX_VALUE. 
- private final Parameter maxTermScoreForSparseQuery = Parameter.floatParam( - "max_term_score_for_sparse_query", + private final Parameter tokenScoreUpperBound = Parameter.floatParam( + "token_score_upper_bound", false, - m -> toType(m).maxTermScoreForSparseQuery, + m -> toType(m).tokenScoreUpperBound, Float.MAX_VALUE ); @@ -68,17 +68,17 @@ public SparseVectorBuilder( @Override protected List> getParameters() { - return Arrays.asList(maxTermScoreForSparseQuery, meta); + return Arrays.asList(tokenScoreUpperBound, meta); } @Override public SparseVectorMapper build(BuilderContext context) { return new SparseVectorMapper( name, - new SparseVectorFieldType(buildFullName(context), meta.getValue(), maxTermScoreForSparseQuery.getValue()), + new SparseVectorFieldType(buildFullName(context), meta.getValue(), tokenScoreUpperBound.getValue()), multiFieldsBuilder.build(this, context), copyTo.build(), - maxTermScoreForSparseQuery.getValue() + tokenScoreUpperBound.getValue() ); } } @@ -86,19 +86,19 @@ public SparseVectorMapper build(BuilderContext context) { public static final TypeParser PARSER = new TypeParser((n, c) -> new SparseVectorBuilder(n)); public static final class SparseVectorFieldType extends MappedFieldType { - private final float maxTermScoreForSparseQuery; + private final float tokenScoreUpperBound; public SparseVectorFieldType( String name, Map meta, - float maxTermScoreForSparseQuery + float tokenScoreUpperBound ) { super(name, true, false, true, TextSearchInfo.NONE, meta); - this.maxTermScoreForSparseQuery = maxTermScoreForSparseQuery; + this.tokenScoreUpperBound = tokenScoreUpperBound; } - public float maxTermScoreForSparseQuery() { - return maxTermScoreForSparseQuery; + public float tokenScoreUpperBound() { + return tokenScoreUpperBound; } @Override @@ -131,17 +131,17 @@ public Query termQuery(Object value, QueryShardContext context) { } } - private final float maxTermScoreForSparseQuery; + private final float tokenScoreUpperBound; protected 
SparseVectorMapper( String simpleName, MappedFieldType mappedFieldType, MultiFields multiFields, CopyTo copyTo, - float maxTermScoreForSparseQuery + float tokenScoreUpperBound ) { super(simpleName, mappedFieldType, multiFields, copyTo); - this.maxTermScoreForSparseQuery = maxTermScoreForSparseQuery; + this.tokenScoreUpperBound = tokenScoreUpperBound; } @Override diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index ff9b20ce5..ec1f945d1 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -86,6 +86,7 @@ public Collection createComponents( final Supplier repositoriesServiceSupplier ) { NeuralQueryBuilder.initialize(clientAccessor); + SparseQueryBuilder.initialize(clientAccessor); normalizationProcessorWorkflow = new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()); return List.of(clientAccessor); } diff --git a/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseQueryBuilder.java index 6a97d69d6..0cea18f8e 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseQueryBuilder.java @@ -7,7 +7,9 @@ import java.io.IOException; import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.function.Supplier; import lombok.AllArgsConstructor; import lombok.Getter; @@ -22,7 +24,9 @@ import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.BoostQuery; import org.apache.lucene.search.Query; +import org.opensearch.common.SetOnce; import org.opensearch.core.ParseField; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.ParsingException; import 
org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -30,10 +34,14 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.query.AbstractQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.QueryShardContext; import com.google.common.annotations.VisibleForTesting; import org.opensearch.neuralsearch.index.mapper.SparseVectorMapper; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.util.TokenWeightUtil; @Log4j2 @Getter @@ -43,31 +51,63 @@ @AllArgsConstructor public class SparseQueryBuilder extends AbstractQueryBuilder { public static final String NAME = "sparse"; - @VisibleForTesting static final ParseField QUERY_TOKENS_FIELD = new ParseField("query_tokens"); + @VisibleForTesting + static final ParseField QUERY_TEXT_FIELD = new ParseField("query_text"); + @VisibleForTesting + static final ParseField MODEL_ID_FIELD = new ParseField("model_id"); + @VisibleForTesting + static final ParseField TOKEN_SCORE_UPPER_BOUND_FIELD = new ParseField("token_score_upper_bound"); + + private static MLCommonsClientAccessor ML_CLIENT; + + public static void initialize(MLCommonsClientAccessor mlClient) { + SparseQueryBuilder.ML_CLIENT = mlClient; + } private String fieldName; - // todo: if termWeight is null - private Map termWeight; + private Map queryTokens; + private String queryText; + private String modelId; + private Float tokenScoreUpperBound; + private Supplier> queryTokensSupplier; public SparseQueryBuilder(StreamInput in) throws IOException { super(in); this.fieldName = in.readString(); - this.termWeight = in.readMap(StreamInput::readString, StreamInput::readFloat); + // we don't have readOptionalMap or write, need to do it manually + if (in.readBoolean()) { + this.queryTokens = 
in.readMap(StreamInput::readString, StreamInput::readFloat); + } + this.queryText = in.readOptionalString(); + this.modelId = in.readOptionalString(); + this.tokenScoreUpperBound = in.readOptionalFloat(); } @Override protected void doWriteTo(StreamOutput out) throws IOException { out.writeString(fieldName); - out.writeMap(termWeight, StreamOutput::writeString, StreamOutput::writeFloat); + if (null != queryTokens) { + out.writeBoolean(true); + out.writeMap(queryTokens, StreamOutput::writeString, StreamOutput::writeFloat); + } else { + out.writeBoolean(false); + } + out.writeOptionalString(queryText); + out.writeOptionalString(modelId); + out.writeOptionalFloat(tokenScoreUpperBound); } @Override protected void doXContent(XContentBuilder xContentBuilder, Params params) throws IOException { xContentBuilder.startObject(NAME); xContentBuilder.startObject(fieldName); - xContentBuilder.field(QUERY_TOKENS_FIELD.getPreferredName(), termWeight); + if (null != queryTokens) xContentBuilder.field(QUERY_TOKENS_FIELD.getPreferredName(), queryTokens); + if (null != queryText) xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), queryText); + if (null != modelId) xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId); + if (null != tokenScoreUpperBound) + xContentBuilder.field(TOKEN_SCORE_UPPER_BOUND_FIELD.getPreferredName(), tokenScoreUpperBound); printBoostAndQueryName(xContentBuilder); xContentBuilder.endObject(); xContentBuilder.endObject(); @@ -82,14 +122,24 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws * "token_b": float, * ... 
* }, - * "max_term_score_for_sparse_query": float (optional) + * "token_score_upper_bound": float (optional) * } * } + * or + * "SAMPLE_FIELD": { + * "query_text": "string", + * "model_id": "string", + * "token_score_upper_bound": float (optional) + * } + * */ public static SparseQueryBuilder fromXContent(XContentParser parser) throws IOException { SparseQueryBuilder sparseQueryBuilder = new SparseQueryBuilder(); if (parser.currentToken() != XContentParser.Token.START_OBJECT) { - throw new ParsingException(parser.getTokenLocation(), "Token must be START_OBJECT"); + throw new ParsingException( + parser.getTokenLocation(), + "First token of " + NAME + "query must be START_OBJECT" + ); } parser.nextToken(); sparseQueryBuilder.fieldName(parser.currentName()); @@ -111,8 +161,10 @@ public static SparseQueryBuilder fromXContent(XContentParser parser) throws IOEx return sparseQueryBuilder; } - // todo: refactor this to switch style - private static void parseQueryParams(XContentParser parser, SparseQueryBuilder sparseQueryBuilder) throws IOException { + private static void parseQueryParams( + XContentParser parser, + SparseQueryBuilder sparseQueryBuilder + ) throws IOException { XContentParser.Token token; String currentFieldName = ""; while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { @@ -123,6 +175,12 @@ private static void parseQueryParams(XContentParser parser, SparseQueryBuilder s sparseQueryBuilder.queryName(parser.text()); } else if (BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { sparseQueryBuilder.boost(parser.floatValue()); + } else if (QUERY_TEXT_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + sparseQueryBuilder.queryText(parser.text()); + } else if (MODEL_ID_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + sparseQueryBuilder.modelId(parser.text()); + } else if (TOKEN_SCORE_UPPER_BOUND_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + 
sparseQueryBuilder.tokenScoreUpperBound(parser.floatValue()); } else { throw new ParsingException( parser.getTokenLocation(), @@ -130,8 +188,7 @@ private static void parseQueryParams(XContentParser parser, SparseQueryBuilder s ); } } else if (QUERY_TOKENS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - sparseQueryBuilder.termWeight(parser.map(HashMap::new, XContentParser::floatValue)); -// sparseQueryBuilder.termWeight(castToTermWeight(parser.map())); + sparseQueryBuilder.queryTokens(parser.map(HashMap::new, XContentParser::floatValue)); } else { throw new ParsingException( parser.getTokenLocation(), @@ -141,25 +198,62 @@ private static void parseQueryParams(XContentParser parser, SparseQueryBuilder s } } + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + // If the user has specified query_tokens field, then we don't need to inference the sentence, + // just re-rewrite to self. Otherwise, we need to inference the sentence to get the queryTokens. Then the + // logic is similar to NeuralQueryBuilder + if (null != queryTokens) { + return this; + } + if (null != queryTokensSupplier) { + return queryTokensSupplier.get() == null ? 
this : + new SparseQueryBuilder() + .fieldName(fieldName) + .queryTokens(queryTokensSupplier.get()) + .queryText(queryText) + .modelId(modelId) + .tokenScoreUpperBound(tokenScoreUpperBound); + } + + validateForRewrite(queryText, modelId); + SetOnce> queryTokensSetOnce = new SetOnce<>(); + queryRewriteContext.registerAsyncAction( + ((client, actionListener) -> ML_CLIENT.inferenceSentencesWithMapResult( + modelId(), + List.of(queryText), + ActionListener.wrap(mapResult -> { + queryTokensSetOnce.set(TokenWeightUtil.fetchQueryTokensList(mapResult).get(0)); + actionListener.onResponse(null); + }, actionListener::onFailure)) + ) + ); + return new SparseQueryBuilder() + .fieldName(fieldName) + .queryText(queryText) + .modelId(modelId) + .tokenScoreUpperBound(tokenScoreUpperBound) + .queryTokensSupplier(queryTokensSetOnce::get); + } + @Override protected Query doToQuery(QueryShardContext context) throws IOException { final MappedFieldType ft = context.fieldMapper(fieldName); - if (!(ft instanceof SparseVectorMapper.SparseVectorFieldType)) { - throw new IllegalArgumentException( - "[" + NAME + "] query only works on [" + SparseVectorMapper.CONTENT_TYPE + "] fields, " - + "not [" + ft.typeName() + "]" - ); - } - final Float maxTermScoreForSparseQuery = ((SparseVectorMapper.SparseVectorFieldType) ft).maxTermScoreForSparseQuery(); + validateFieldType(ft); + validateQueryTokens(queryTokens); + + // the tokenScoreUpperBound from query has higher priority + final Float scoreUpperBound = null != tokenScoreUpperBound? 
tokenScoreUpperBound: + ((SparseVectorMapper.SparseVectorFieldType) ft).tokenScoreUpperBound(); BooleanQuery.Builder builder = new BooleanQuery.Builder(); - for (Map.Entry entry: termWeight.entrySet()) { + for (Map.Entry entry: queryTokens.entrySet()) { builder.add( new BoostQuery( new BoundedLinearFeatureQuery( fieldName, entry.getKey(), - maxTermScoreForSparseQuery + scoreUpperBound ), entry.getValue() ), @@ -169,20 +263,62 @@ protected Query doToQuery(QueryShardContext context) throws IOException { return builder.build(); } + private static void validateForRewrite(String queryText, String modelId) { + if (null == queryText||null == modelId) { + throw new IllegalArgumentException( + "When " + QUERY_TOKENS_FIELD.getPreferredName() + " are not provided," + + QUERY_TEXT_FIELD.getPreferredName() + " and " + MODEL_ID_FIELD.getPreferredName() + + " can not be null." + ); + } + } + + private static void validateFieldType(MappedFieldType fieldType) { + if (!(fieldType instanceof SparseVectorMapper.SparseVectorFieldType)) { + throw new IllegalArgumentException( + "[" + NAME + "] query only works on [" + SparseVectorMapper.CONTENT_TYPE + "] fields, " + + "not [" + fieldType.typeName() + "]" + ); + } + } + + private static void validateQueryTokens(Map queryTokens) { + if (null == queryTokens) { + throw new IllegalArgumentException( + QUERY_TOKENS_FIELD.getPreferredName() + " field can not be null." 
+ ); + } + for (Map.Entry entry: queryTokens.entrySet()) { + if (entry.getValue() <= 0) { + throw new IllegalArgumentException( + "weight must be larger than 0, got: " + entry.getValue() + "for key " + entry.getKey()); + } + } + } + @Override protected boolean doEquals(SparseQueryBuilder obj) { // todo: validate if (this == obj) return true; if (obj == null || getClass() != obj.getClass()) return false; - EqualsBuilder equalsBuilder = new EqualsBuilder(); - equalsBuilder.append(fieldName, obj.fieldName); - equalsBuilder.append(termWeight, obj.termWeight); + EqualsBuilder equalsBuilder = new EqualsBuilder() + .append(fieldName, obj.fieldName) + .append(queryTokens, obj.queryTokens) + .append(queryText, obj.queryText) + .append(modelId, obj.modelId) + .append(tokenScoreUpperBound, obj.tokenScoreUpperBound); return equalsBuilder.isEquals(); } @Override protected int doHashCode() { - return new HashCodeBuilder().append(fieldName).append(termWeight).toHashCode(); + return new HashCodeBuilder() + .append(fieldName) + .append(queryTokens) + .append(queryText) + .append(modelId) + .append(tokenScoreUpperBound) + .toHashCode(); } @Override diff --git a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java new file mode 100644 index 000000000..92fcef48f --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.util; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +public class TokenWeightUtil { + /** + * Converts a Map result to Map + * + * @param mapResult {@link Map} of String and ? 
+ * @return Map of String and Float + */ + public static String RESPONSE_KEY = "response"; + + public static List> fetchQueryTokensList(Map mapResult) { + assert mapResult.get(RESPONSE_KEY) instanceof List; + List> responseList = (List) mapResult.get(RESPONSE_KEY); + return responseList.stream().map(TokenWeightUtil::buildQueryTokensMap).collect(Collectors.toList()); + } + + private static Map buildQueryTokensMap(Map mapResult) { + Map result = new HashMap<>(); + for (Map.Entry entry: mapResult.entrySet()) { + assert entry.getValue() instanceof Number; + result.put(entry.getKey(), ((Number) entry.getValue()).floatValue()); + } + return result; + } +} From 589b1c0159939df95067071adb46bc26d9cb7346 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Mon, 11 Sep 2023 11:45:12 +0800 Subject: [PATCH 21/70] modify header Signed-off-by: zhichao-aws --- .../sparse/BoundedLinearFeatureQuery.java | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/neuralsearch/query/sparse/BoundedLinearFeatureQuery.java b/src/main/java/org/opensearch/neuralsearch/query/sparse/BoundedLinearFeatureQuery.java index 7b593c7b7..f0b9b498c 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/sparse/BoundedLinearFeatureQuery.java +++ b/src/main/java/org/opensearch/neuralsearch/query/sparse/BoundedLinearFeatureQuery.java @@ -1,3 +1,11 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + /* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with @@ -15,7 +23,15 @@ * limitations under the License. */ -/* This class is built based on lucene FeatureQuery. We use LinearFuntion and add an upperbound to it */ +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +/* + * This class is built based on lucene FeatureQuery. We use LinearFuntion to + * build the query and add an upperbound to it. + */ package org.opensearch.neuralsearch.query.sparse; From ec3f4267c0575c4b2eb65d06337259799e6abe8f Mon Sep 17 00:00:00 2001 From: xinyual Date: Fri, 22 Sep 2023 14:31:42 +0800 Subject: [PATCH 22/70] merge conflict Signed-off-by: xinyual --- .../ml/MLCommonsClientAccessor.java | 40 +++++++++---------- .../query/sparse/SparseQueryBuilder.java | 4 +- .../neuralsearch/util/TokenWeightUtil.java | 23 ++++++----- 3 files changed, 36 insertions(+), 31 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index feb7539c0..bedeff46e 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -47,16 +47,16 @@ public class MLCommonsClientAccessor { * @param listener {@link ActionListener} which will be called when prediction is completed or errored out */ public void inferenceSentence( - @NonNull final String modelId, - @NonNull final String inputText, - @NonNull final ActionListener> listener + @NonNull final String modelId, + @NonNull final String inputText, + @NonNull final ActionListener> listener ) { inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, List.of(inputText), ActionListener.wrap(response -> { if (response.size() != 1) { listener.onFailure( - new IllegalStateException( - "Unexpected number of vectors produced. Expected 1 vector to be returned, but got [" + response.size() + "]" - ) + new IllegalStateException( + "Unexpected number of vectors produced. 
Expected 1 vector to be returned, but got [" + response.size() + "]" + ) ); return; } @@ -77,9 +77,9 @@ public void inferenceSentence( * @param listener {@link ActionListener} which will be called when prediction is completed or errored out */ public void inferenceSentences( - @NonNull final String modelId, - @NonNull final List inputText, - @NonNull final ActionListener>> listener + @NonNull final String modelId, + @NonNull final List inputText, + @NonNull final ActionListener>> listener ) { inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, inputText, listener); } @@ -97,10 +97,10 @@ public void inferenceSentences( * @param listener {@link ActionListener} which will be called when prediction is completed or errored out. */ public void inferenceSentences( - @NonNull final List targetResponseFilters, - @NonNull final String modelId, - @NonNull final List inputText, - @NonNull final ActionListener>> listener + @NonNull final List targetResponseFilters, + @NonNull final String modelId, + @NonNull final List inputText, + @NonNull final ActionListener>> listener ) { retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, 0, listener); } @@ -134,11 +134,11 @@ private void retryableInferenceSentencesWithMapResult( } private void retryableInferenceSentencesWithVectorResult( - final List targetResponseFilters, - final String modelId, - final List inputText, - final int retryTime, - final ActionListener>> listener + final List targetResponseFilters, + final String modelId, + final List inputText, + final int retryTime, + final ActionListener>> listener ) { MLInput mlInput = createMLInput(targetResponseFilters, inputText); mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { @@ -179,7 +179,7 @@ private List> buildVectorFromResponse(MLOutput mlOutput) { final List tensorOutputList = modelTensorOutput.getMlModelOutputs(); if (CollectionUtils.isEmpty(tensorOutputList) || 
CollectionUtils.isEmpty(tensorOutputList.get(0).getMlModelTensors())) { throw new IllegalStateException( - "Empty model result produced. Expected 1 tensor output and 1 model tensor, but got [0]" + "Empty model result produced. Expected 1 tensor output and 1 model tensor, but got [0]" ); } List > resultMaps = new ArrayList<>(); @@ -194,4 +194,4 @@ private List> buildVectorFromResponse(MLOutput mlOutput) { return resultMaps; } -} +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseQueryBuilder.java index 0cea18f8e..258bee68e 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseQueryBuilder.java @@ -222,8 +222,8 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws ((client, actionListener) -> ML_CLIENT.inferenceSentencesWithMapResult( modelId(), List.of(queryText), - ActionListener.wrap(mapResult -> { - queryTokensSetOnce.set(TokenWeightUtil.fetchQueryTokensList(mapResult).get(0)); + ActionListener.wrap(mapResultList -> { + queryTokensSetOnce.set(TokenWeightUtil.fetchQueryTokensList(mapResultList).get(0)); actionListener.onResponse(null); }, actionListener::onFailure)) ) diff --git a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java index 92fcef48f..3bb0eeb20 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java @@ -11,6 +11,7 @@ import java.util.stream.Collectors; public class TokenWeightUtil { + //todo: change comments, add validation (throw exception) /** * Converts a Map result to Map * @@ -19,18 +20,22 @@ public class TokenWeightUtil { */ public static String RESPONSE_KEY = "response"; - public static List> 
fetchQueryTokensList(Map mapResult) { - assert mapResult.get(RESPONSE_KEY) instanceof List; - List> responseList = (List) mapResult.get(RESPONSE_KEY); - return responseList.stream().map(TokenWeightUtil::buildQueryTokensMap).collect(Collectors.toList()); + public static List> fetchQueryTokensList(List> mapResultList) { + return mapResultList.stream().map(TokenWeightUtil::buildQueryTokensMap).collect(Collectors.toList()); } + @SuppressWarnings("unchecked") private static Map buildQueryTokensMap(Map mapResult) { - Map result = new HashMap<>(); - for (Map.Entry entry: mapResult.entrySet()) { - assert entry.getValue() instanceof Number; - result.put(entry.getKey(), ((Number) entry.getValue()).floatValue()); + Object response = mapResult.get(RESPONSE_KEY); + if (response instanceof Map) { + Map result = new HashMap<>(); + for (Map.Entry entry: ((Map) response).entrySet()) { + assert entry.getValue() instanceof Number; + result.put(entry.getKey(), ((Number) entry.getValue()).floatValue()); + } + return result; + } else { + throw new IllegalArgumentException("wrong type"); } - return result; } } From a8520d3f675955944d5e44212210801e2389465c Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 13 Sep 2023 13:44:27 +0800 Subject: [PATCH 23/70] delete index mapper, change to rank_features Signed-off-by: zhichao-aws --- .../index/mapper/SparseVectorMapper.java | 178 ------------------ .../neuralsearch/plugin/NeuralSearch.java | 18 +- .../query/sparse/SparseQueryBuilder.java | 8 +- 3 files changed, 4 insertions(+), 200 deletions(-) delete mode 100644 src/main/java/org/opensearch/neuralsearch/index/mapper/SparseVectorMapper.java diff --git a/src/main/java/org/opensearch/neuralsearch/index/mapper/SparseVectorMapper.java b/src/main/java/org/opensearch/neuralsearch/index/mapper/SparseVectorMapper.java deleted file mode 100644 index e92d9fb64..000000000 --- a/src/main/java/org/opensearch/neuralsearch/index/mapper/SparseVectorMapper.java +++ /dev/null @@ -1,178 +0,0 @@ -/* - * 
Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.neuralsearch.index.mapper; - -import org.apache.lucene.document.FeatureField; -import org.apache.lucene.search.Query; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.index.mapper.FieldMapper; -import org.opensearch.index.mapper.MappedFieldType; -import org.opensearch.index.mapper.ParametrizedFieldMapper; -import org.opensearch.index.mapper.ParseContext; -import org.opensearch.index.mapper.SourceValueFetcher; -import org.opensearch.index.mapper.TextSearchInfo; -import org.opensearch.index.mapper.ValueFetcher; -import org.opensearch.index.query.QueryShardContext; -import org.opensearch.index.query.QueryShardException; -import org.opensearch.search.lookup.SearchLookup; - -import java.io.IOException; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -/** - * A FieldMapper that exposes Lucene's {@link FeatureField}. - * It is designed for learned sparse vectors, the expected for ingested content is a map of (token,weight) pairs, with String and Float type. - * In current version, this field doesn't support existing query clauses like "match" or "exists". - * The ingested documents can only be searched with our "sparse" query clause. - */ -public class SparseVectorMapper extends ParametrizedFieldMapper { - public static final String CONTENT_TYPE = "sparse_vector"; - - private static SparseVectorMapper toType(FieldMapper in) { - return (SparseVectorMapper) in; - } - - public static class SparseVectorBuilder extends ParametrizedFieldMapper.Builder { - - private final Parameter> meta = Parameter.metaParam(); - // Both match query and our sparse query use lucene Boolean query to connect all term-level queries. - // lucene BooleanQuery use WAND (Weak AND) algorithm to accelerate the search, and WAND algorithm - // uses term's max possible value to skip unnecessary calculations. 
The max possible value in match clause is term idf value. - // However, The default behavior of lucene FeatureQuery is returning Float.MAX_VALUE for every term. Which will - // invalidate WAND algorithm. - - // By setting tokenScoreUpperBound, we'll use it as the term score upperbound to accelerate the search. - // Users can also overwrite this setting in sparse query. Our experiments show a proper tokenScoreUpperBound - // value can reduce search latency by 4x while losing precision less than 0.5%. - - // If user doesn't set the value explicitly, we'll degrade to the default behavior in lucene FeatureQuery, - // i.e. using Float.MAX_VALUE. - private final Parameter tokenScoreUpperBound = Parameter.floatParam( - "token_score_upper_bound", - false, - m -> toType(m).tokenScoreUpperBound, - Float.MAX_VALUE - ); - - public SparseVectorBuilder( - String name - ) { - super(name); - } - - @Override - protected List> getParameters() { - return Arrays.asList(tokenScoreUpperBound, meta); - } - - @Override - public SparseVectorMapper build(BuilderContext context) { - return new SparseVectorMapper( - name, - new SparseVectorFieldType(buildFullName(context), meta.getValue(), tokenScoreUpperBound.getValue()), - multiFieldsBuilder.build(this, context), - copyTo.build(), - tokenScoreUpperBound.getValue() - ); - } - } - - public static final TypeParser PARSER = new TypeParser((n, c) -> new SparseVectorBuilder(n)); - - public static final class SparseVectorFieldType extends MappedFieldType { - private final float tokenScoreUpperBound; - - public SparseVectorFieldType( - String name, - Map meta, - float tokenScoreUpperBound - ) { - super(name, true, false, true, TextSearchInfo.NONE, meta); - this.tokenScoreUpperBound = tokenScoreUpperBound; - } - - public float tokenScoreUpperBound() { - return tokenScoreUpperBound; - } - - @Override - public String typeName() { - return CONTENT_TYPE; - } - - @Override - public ValueFetcher valueFetcher(QueryShardContext context, SearchLookup 
searchLookup, String format) { - if (format != null) { - throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] does not support format"); - } - return SourceValueFetcher.identity(name(), context, format); - } - - @Override - public Query existsQuery(QueryShardContext context) { - throw new QueryShardException( - context, - "Field [" + name() + "] of type [" + typeName() + "] does not support exists query for now" - ); - } - - @Override - public Query termQuery(Object value, QueryShardContext context) { - throw new QueryShardException( - context, - "Field [" + name() + "] of type [" + typeName() + "] does not support term query for now" - ); - } - } - - private final float tokenScoreUpperBound; - - protected SparseVectorMapper( - String simpleName, - MappedFieldType mappedFieldType, - MultiFields multiFields, - CopyTo copyTo, - float tokenScoreUpperBound - ) { - super(simpleName, mappedFieldType, multiFields, copyTo); - this.tokenScoreUpperBound = tokenScoreUpperBound; - } - - @Override - public ParametrizedFieldMapper.Builder getMergeBuilder() { - return new SparseVectorBuilder(simpleName()).init(this); - } - - @Override - protected void parseCreateField(ParseContext context) throws IOException { - if (XContentParser.Token.START_OBJECT != context.parser().currentToken()) { - throw new IllegalArgumentException( - "Wrong format for input data. Field type " + typeName() + " can only parse map object." - - ); - } - final Map termWeight = context.parser().map(HashMap::new, XContentParser::floatValue); - for (Map.Entry entry: termWeight.entrySet()) { - context.doc().add(new FeatureField(fieldType().name(), entry.getKey(), entry.getValue())); - } - } - - // Users are not supposed to give an array for the input value. 
- // Here we set the return value of parsesArrayValue() as true, - // intercept the request and throw an exception in parseCreateField() - @Override - public final boolean parsesArrayValue() { - return true; - } - - @Override - protected String contentType() { - return CONTENT_TYPE; - } -} \ No newline at end of file diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index ec1f945d1..94abe44e7 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -21,10 +21,8 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.env.Environment; import org.opensearch.env.NodeEnvironment; -import org.opensearch.index.mapper.Mapper; import org.opensearch.ingest.Processor; import org.opensearch.ml.client.MachineLearningNodeClient; -import org.opensearch.neuralsearch.index.mapper.SparseVectorMapper; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; @@ -59,13 +57,7 @@ * Neural Search plugin class */ @Log4j2 -public class NeuralSearch extends Plugin implements - ActionPlugin, - SearchPlugin, - IngestPlugin, - ExtensiblePlugin, - SearchPipelinePlugin, - MapperPlugin { +public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, IngestPlugin, ExtensiblePlugin, SearchPipelinePlugin { private MLCommonsClientAccessor clientAccessor; private NormalizationProcessorWorkflow normalizationProcessorWorkflow; private final ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory(); @@ -100,14 +92,6 @@ public List> getQueries() { ); } - @Override - public Map getMappers() { - return Collections.singletonMap( - SparseVectorMapper.CONTENT_TYPE, - SparseVectorMapper.PARSER - 
); - } - @Override public Map getProcessors(Processor.Parameters parameters) { clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client)); diff --git a/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseQueryBuilder.java index 258bee68e..a967be2d8 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseQueryBuilder.java @@ -39,7 +39,6 @@ import org.opensearch.index.query.QueryShardContext; import com.google.common.annotations.VisibleForTesting; -import org.opensearch.neuralsearch.index.mapper.SparseVectorMapper; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.util.TokenWeightUtil; @@ -243,8 +242,7 @@ protected Query doToQuery(QueryShardContext context) throws IOException { validateQueryTokens(queryTokens); // the tokenScoreUpperBound from query has higher priority - final Float scoreUpperBound = null != tokenScoreUpperBound? tokenScoreUpperBound: - ((SparseVectorMapper.SparseVectorFieldType) ft).tokenScoreUpperBound(); + final Float scoreUpperBound = null != tokenScoreUpperBound? 
tokenScoreUpperBound: Float.MAX_VALUE; BooleanQuery.Builder builder = new BooleanQuery.Builder(); for (Map.Entry entry: queryTokens.entrySet()) { @@ -274,9 +272,9 @@ private static void validateForRewrite(String queryText, String modelId) { } private static void validateFieldType(MappedFieldType fieldType) { - if (!(fieldType instanceof SparseVectorMapper.SparseVectorFieldType)) { + if (!fieldType.typeName().equals("rank_features")) { throw new IllegalArgumentException( - "[" + NAME + "] query only works on [" + SparseVectorMapper.CONTENT_TYPE + "] fields, " + "[" + NAME + "] query only works on [rank_features] fields, " + "not [" + fieldType.typeName() + "]" ); } From b964d6c45097b65d3ae3b83cf4d5b08f3e484a52 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 13 Sep 2023 13:46:13 +0800 Subject: [PATCH 24/70] remove unused import Signed-off-by: zhichao-aws --- .../java/org/opensearch/neuralsearch/plugin/NeuralSearch.java | 1 - 1 file changed, 1 deletion(-) diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index 94abe44e7..6634775ba 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -42,7 +42,6 @@ import org.opensearch.plugins.ActionPlugin; import org.opensearch.plugins.ExtensiblePlugin; import org.opensearch.plugins.IngestPlugin; -import org.opensearch.plugins.MapperPlugin; import org.opensearch.plugins.Plugin; import org.opensearch.plugins.SearchPipelinePlugin; import org.opensearch.plugins.SearchPlugin; From be45f86116fab760626c0db4b224f93bf0d51780 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 13 Sep 2023 16:56:09 +0800 Subject: [PATCH 25/70] list return result Signed-off-by: zhichao-aws --- .../opensearch/neuralsearch/util/TokenWeightUtil.java | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git 
a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java index 3bb0eeb20..ab5edaa9d 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java @@ -27,15 +27,20 @@ public static List> fetchQueryTokensList(List> @SuppressWarnings("unchecked") private static Map buildQueryTokensMap(Map mapResult) { Object response = mapResult.get(RESPONSE_KEY); + Map result = new HashMap<>(); if (response instanceof Map) { - Map result = new HashMap<>(); for (Map.Entry entry: ((Map) response).entrySet()) { assert entry.getValue() instanceof Number; result.put(entry.getKey(), ((Number) entry.getValue()).floatValue()); } return result; } else { - throw new IllegalArgumentException("wrong type"); + assert response instanceof List; + for (Map.Entry entry: ((Map) ((List) response).get(0)).entrySet()) { + assert entry.getValue() instanceof Number; + result.put(entry.getKey(), ((Number) entry.getValue()).floatValue()); + } + return result; } } } From dbe00fd3c5c6870fa049a7eb6f5453f35fa04ae7 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Thu, 14 Sep 2023 15:55:00 +0800 Subject: [PATCH 26/70] refactor type and listTypeNestedMapKey, tidy Signed-off-by: zhichao-aws --- .../neuralsearch/processor/NLPProcessor.java | 35 ++++++++----------- .../processor/SparseEncodingProcessor.java | 9 ++--- .../processor/TextEmbeddingProcessor.java | 9 ++--- 3 files changed, 19 insertions(+), 34 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java index bdec04e13..e502cb61c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java @@ -23,23 +23,27 @@ @Log4j2 public abstract class NLPProcessor extends AbstractProcessor { - 
@VisibleForTesting - protected final String modelId; - public static final String MODEL_ID_FIELD = "model_id"; public static final String FIELD_MAP_FIELD = "field_map"; - protected final Map fieldMap; + protected final String type; + + protected final String listTypeNestedMapKey; + + @VisibleForTesting + protected final String modelId; - protected final MLCommonsClientAccessor mlCommonsClientAccessor; + protected final Map fieldMap; - protected final Environment environment; + protected final MLCommonsClientAccessor mlCommonsClientAccessor; - protected String LIST_TYPE_NESTED_MAP_KEY = "NLP"; + protected final Environment environment; public NLPProcessor( String tag, String description, + String type, + String listTypeNestedMapKey, String modelId, Map fieldMap, MLCommonsClientAccessor clientAccessor, @@ -49,14 +53,14 @@ public NLPProcessor( if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, can not process it"); validateEmbeddingConfiguration(fieldMap); + this.type = type; + this.listTypeNestedMapKey = listTypeNestedMapKey; this.modelId = modelId; this.fieldMap = fieldMap; this.mlCommonsClientAccessor = clientAccessor; this.environment = environment; } - - private void validateEmbeddingConfiguration(Map fieldMap) { if (fieldMap == null || fieldMap.size() == 0 @@ -82,7 +86,6 @@ private static void validateListTypeValue(String sourceKey, Object sourceValue) } } - @SuppressWarnings({ "rawtypes", "unchecked" }) private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier maxDepthSupplier) { int maxDepth = maxDepthSupplier.get(); @@ -102,7 +105,6 @@ private void validateNestedTypeValue(String sourceKey, Object sourceValue, Suppl } } - private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); for (Map.Entry embeddingFieldsEntry : fieldMap.entrySet()) { @@ -121,7 +123,6 @@ private void 
validateEmbeddingFieldsValue(IngestDocument ingestDocument) { } } - private void buildMapWithProcessorKeyAndOriginalValueForMapType( String parentKey, Object processorKey, @@ -176,7 +177,6 @@ private void createInferenceListForMapTypeInput(Object sourceValue, List } } - @SuppressWarnings({ "unchecked" }) private List createInferenceList(Map knnKeyMap) { List texts = new ArrayList<>(); @@ -220,10 +220,8 @@ public void execute(IngestDocument ingestDocument, BiConsumer processorMap, List results) { Objects.requireNonNull(results, "embedding failed, inference returns null result!"); log.debug("Text embedding result fetched, starting build vector output!"); @@ -231,7 +229,6 @@ protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map buildNLPResult( @@ -291,15 +288,13 @@ private List> buildNLPResultForListType( ) { List> keyToResult = new ArrayList<>(); IntStream.range(0, sourceValue.size()) - .forEachOrdered(x -> keyToResult.add(ImmutableMap.of(LIST_TYPE_NESTED_MAP_KEY, results.get(indexWrapper.index++)))); + .forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++)))); return keyToResult; } - - @Override public String getType() { - return null; + return type; } /** diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index 0bf157e07..814cda38d 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -27,10 +27,10 @@ public class SparseEncodingProcessor extends NLPProcessor { public static final String TYPE = "sparse_encoding"; + public static final String LIST_TYPE_NESTED_MAP_KEY = "sparseEncoding"; public SparseEncodingProcessor(String tag, String description, String modelId, Map fieldMap, MLCommonsClientAccessor clientAccessor, Environment environment) 
{ - super(tag, description, modelId, fieldMap, clientAccessor, environment); - this.LIST_TYPE_NESTED_MAP_KEY = "sparseEncoding"; + super(tag, description, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment); } @Override @@ -45,9 +45,4 @@ public void doExecute(IngestDocument ingestDocument, Map Process handler.accept(ingestDocument, null); }, e -> { handler.accept(null, e); })); } - - @Override - public String getType() { - return TYPE; - } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java index 3fa0ef114..60a9b9d4f 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java @@ -35,6 +35,7 @@ public class TextEmbeddingProcessor extends NLPProcessor { public static final String TYPE = "text_embedding"; + public static final String LIST_TYPE_NESTED_MAP_KEY = "knn"; public TextEmbeddingProcessor( String tag, @@ -44,8 +45,7 @@ public TextEmbeddingProcessor( MLCommonsClientAccessor clientAccessor, Environment environment ) { - super(tag, description, modelId, fieldMap, clientAccessor, environment); - this.LIST_TYPE_NESTED_MAP_KEY = "knn"; + super(tag, description, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment); } @Override @@ -55,9 +55,4 @@ public void doExecute(IngestDocument ingestDocument, Map Process handler.accept(ingestDocument, null); }, e -> { handler.accept(null, e); })); } - - @Override - public String getType() { - return TYPE; - } } From c109666207d667cb1d878084476da086c40598c0 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Thu, 14 Sep 2023 16:54:41 +0800 Subject: [PATCH 27/70] forbid nested input. tidy. 
Signed-off-by: zhichao-aws --- .../neuralsearch/processor/NLPProcessor.java | 15 ++++++++++----- .../processor/SparseEncodingProcessor.java | 19 +++++++++++-------- .../processor/TextEmbeddingProcessor.java | 11 ----------- .../TextEmbeddingProcessorTests.java | 2 +- 4 files changed, 22 insertions(+), 25 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java index e502cb61c..34e8e0f1c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java @@ -15,7 +15,11 @@ import org.opensearch.ingest.IngestDocument; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; -import java.util.*; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; import java.util.function.BiConsumer; import java.util.function.Supplier; import java.util.stream.IntStream; @@ -69,12 +73,13 @@ private void validateEmbeddingConfiguration(Map fieldMap) { .anyMatch( x -> StringUtils.isBlank(x.getKey()) || Objects.isNull(x.getValue()) || StringUtils.isBlank(x.getValue().toString()) )) { - throw new IllegalArgumentException("Unable to create the TextEmbedding processor as field_map has invalid key or value"); + throw new IllegalArgumentException("Unable to create the " + type + + " processor as field_map has invalid key or value"); } } @SuppressWarnings({ "rawtypes" }) - private static void validateListTypeValue(String sourceKey, Object sourceValue) { + protected static void validateListTypeValue(String sourceKey, Object sourceValue) { for (Object value : (List) sourceValue) { if (value == null) { throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, can not process it"); @@ -87,7 +92,7 @@ private static void validateListTypeValue(String sourceKey, Object sourceValue) } 
@SuppressWarnings({ "rawtypes", "unchecked" }) - private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier maxDepthSupplier) { + protected void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier maxDepthSupplier) { int maxDepth = maxDepthSupplier.get(); if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) { throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, can not process it"); @@ -224,7 +229,7 @@ public void execute(IngestDocument ingestDocument, BiConsumer processorMap, List results) { Objects.requireNonNull(results, "embedding failed, inference returns null result!"); - log.debug("Text embedding result fetched, starting build vector output!"); + log.debug("Model inference result fetched, starting build vector output!"); Map nlpResult = buildNLPResult(processorMap, results, ingestDocument.getSourceAndMetadata()); nlpResult.forEach(ingestDocument::setFieldValue); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index 814cda38d..66da003b5 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -6,21 +6,16 @@ package org.opensearch.neuralsearch.processor; import lombok.extern.log4j.Log4j2; -import org.apache.commons.lang3.StringUtils; import org.opensearch.core.action.ActionListener; import org.opensearch.env.Environment; -import org.opensearch.index.mapper.MapperService; -import org.opensearch.ingest.AbstractProcessor; import org.opensearch.ingest.IngestDocument; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; -import java.util.*; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; import java.util.function.BiConsumer; import 
java.util.function.Supplier; -import java.util.stream.IntStream; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableMap; @Log4j2 @@ -45,4 +40,12 @@ public void doExecute(IngestDocument ingestDocument, Map Process handler.accept(ingestDocument, null); }, e -> { handler.accept(null, e); })); } + + @Override + protected void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier maxDepthSupplier) { + throw new IllegalArgumentException( + "[ " + TYPE + " ] ingest processor can not process nested source value. " + + "Please use plain string instead." + ); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java index 60a9b9d4f..4ae895cdb 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java @@ -5,28 +5,17 @@ package org.opensearch.neuralsearch.processor; -import java.util.ArrayList; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.function.BiConsumer; -import java.util.function.Supplier; -import java.util.stream.IntStream; import lombok.extern.log4j.Log4j2; -import org.apache.commons.lang3.StringUtils; import org.opensearch.core.action.ActionListener; import org.opensearch.env.Environment; -import org.opensearch.index.mapper.MapperService; -import org.opensearch.ingest.AbstractProcessor; import org.opensearch.ingest.IngestDocument; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableMap; - /** * This processor is used for user input data text embedding processing, model_id can be used to indicate which model user use, * and field_map can be used to indicate which fields needs 
text embedding and the corresponding keys for the text embedding results. diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java index f4da16534..029b50c98 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java @@ -77,7 +77,7 @@ public void testTextEmbeddingProcessConstructor_whenConfigMapError_throwIllegalA try { textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); } catch (IllegalArgumentException e) { - assertEquals("Unable to create the TextEmbedding processor as field_map has invalid key or value", e.getMessage()); + assertEquals("Unable to create the text_embedding processor as field_map has invalid key or value", e.getMessage()); } } From 90516b275ab7c70dda38eb6e82d9c0fe96be9072 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Thu, 14 Sep 2023 17:02:21 +0800 Subject: [PATCH 28/70] tidy Signed-off-by: zhichao-aws --- .../org/opensearch/neuralsearch/plugin/NeuralSearch.java | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index 6634775ba..7be36aaa5 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -7,7 +7,12 @@ import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.NEURAL_SEARCH_HYBRID_SEARCH_DISABLED; -import java.util.*; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; import java.util.function.Supplier; import lombok.extern.log4j.Log4j2; From 4d79cc482e552a8c935a259b4f31ba586cf1c2df 
Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Fri, 15 Sep 2023 11:29:43 +0800 Subject: [PATCH 29/70] enable nested Signed-off-by: zhichao-aws --- .../neuralsearch/processor/NLPProcessor.java | 4 ++-- .../processor/SparseEncodingProcessor.java | 10 +--------- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java index 34e8e0f1c..9c1c2b2eb 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java @@ -79,7 +79,7 @@ private void validateEmbeddingConfiguration(Map fieldMap) { } @SuppressWarnings({ "rawtypes" }) - protected static void validateListTypeValue(String sourceKey, Object sourceValue) { + private static void validateListTypeValue(String sourceKey, Object sourceValue) { for (Object value : (List) sourceValue) { if (value == null) { throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, can not process it"); @@ -92,7 +92,7 @@ protected static void validateListTypeValue(String sourceKey, Object sourceValue } @SuppressWarnings({ "rawtypes", "unchecked" }) - protected void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier maxDepthSupplier) { + private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier maxDepthSupplier) { int maxDepth = maxDepthSupplier.get(); if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) { throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, can not process it"); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index 66da003b5..fc7f0d5c5 100644 --- 
a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -22,7 +22,7 @@ public class SparseEncodingProcessor extends NLPProcessor { public static final String TYPE = "sparse_encoding"; - public static final String LIST_TYPE_NESTED_MAP_KEY = "sparseEncoding"; + public static final String LIST_TYPE_NESTED_MAP_KEY = "sparse_encoding"; public SparseEncodingProcessor(String tag, String description, String modelId, Map fieldMap, MLCommonsClientAccessor clientAccessor, Environment environment) { super(tag, description, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment); @@ -40,12 +40,4 @@ public void doExecute(IngestDocument ingestDocument, Map Process handler.accept(ingestDocument, null); }, e -> { handler.accept(null, e); })); } - - @Override - protected void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier maxDepthSupplier) { - throw new IllegalArgumentException( - "[ " + TYPE + " ] ingest processor can not process nested source value. " + - "Please use plain string instead." 
- ); - } } From 79d861e9c42c412e26b736b6be58fea512dcfb39 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Fri, 15 Sep 2023 14:34:31 +0800 Subject: [PATCH 30/70] fix test Signed-off-by: zhichao-aws --- .../neuralsearch/processor/NLPProcessor.java | 2 +- .../ml/MLCommonsClientAccessorTests.java | 11 +++++------ .../processor/TextEmbeddingProcessorTests.java | 12 ++++++------ 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java index 9c1c2b2eb..17c74f0bc 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java @@ -54,10 +54,10 @@ public NLPProcessor( Environment environment ) { super(tag, description); + this.type = type; if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, can not process it"); validateEmbeddingConfiguration(fieldMap); - this.type = type; this.listTypeNestedMapKey = listTypeNestedMapKey; this.modelId = modelId; this.fieldMap = fieldMap; diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index a635e3fa2..7652c127b 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -164,6 +164,7 @@ public void testInferenceSentences_whenNotConnectionException_thenNoRetry() { } public void test_inferenceSentencesWithMapResult_whenValidInput_thenSuccess() { +// final List> map = List.of(Map.of("key", "value")); final Map map = Map.of("key", "value"); final ActionListener>> resultListener = mock(ActionListener.class); Mockito.doAnswer(invocation -> { @@ -175,7 +176,7 @@ public void 
test_inferenceSentencesWithMapResult_whenValidInput_thenSuccess() { Mockito.verify(client) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); - Mockito.verify(resultListener).onResponse(map); + Mockito.verify(resultListener).onResponse(List.of(map)); Mockito.verifyNoMoreInteractions(resultListener); } @@ -218,7 +219,7 @@ public void test_inferenceSentencesWithMapResult_whenModelTensorListEmpty_thenEx Mockito.verifyNoMoreInteractions(resultListener); } - public void test_inferenceSentencesWithMapResult_whenModelTensorListSizeBiggerThan1_thenException() { + public void test_inferenceSentencesWithMapResult_whenModelTensorListSizeBiggerThan1_thenSuccess() { final ActionListener>> resultListener = mock(ActionListener.class); final List tensorsList = new ArrayList<>(); final List mlModelTensorList = new ArrayList<>(); @@ -243,10 +244,8 @@ public void test_inferenceSentencesWithMapResult_whenModelTensorListSizeBiggerTh accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); Mockito.verify(client) - .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IllegalStateException.class); - Mockito.verify(resultListener).onFailure(argumentCaptor.capture()); - assertEquals("Unexpected number of map result produced. 
Expected 1 map result to be returned, but got [2]", argumentCaptor.getValue().getMessage()); + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(resultListener).onResponse(List.of(Map.of("key","value"),Map.of("key","value"))); Mockito.verifyNoMoreInteractions(resultListener); } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java index 029b50c98..846c23dcb 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java @@ -347,7 +347,7 @@ public void testProcessResponse_successful() throws Exception { IngestDocument ingestDocument = createPlainIngestDocument(); TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); - Map knnMap = processor.buildMapWithKnnKeyAndOriginalValue(ingestDocument); + Map knnMap = processor.buildMapWithProcessorKeyAndOriginalValue(ingestDocument); List> modelTensorList = createMockVectorResult(); processor.setVectorFieldsToDocument(ingestDocument, knnMap, modelTensorList); @@ -360,7 +360,7 @@ public void testBuildVectorOutput_withPlainStringValue_successful() { IngestDocument ingestDocument = createPlainIngestDocument(); TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); - Map knnMap = processor.buildMapWithKnnKeyAndOriginalValue(ingestDocument); + Map knnMap = processor.buildMapWithProcessorKeyAndOriginalValue(ingestDocument); // To assert the order is not changed between config map and generated map. 
List configValueList = new LinkedList<>(config.values()); @@ -371,7 +371,7 @@ public void testBuildVectorOutput_withPlainStringValue_successful() { assertEquals(knnKeyList.get(lastIndex), configValueList.get(lastIndex).toString()); List> modelTensorList = createMockVectorResult(); - Map result = processor.buildTextEmbeddingResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); + Map result = processor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); assertTrue(result.containsKey("oriKey1_knn")); assertTrue(result.containsKey("oriKey2_knn")); assertTrue(result.containsKey("oriKey3_knn")); @@ -386,9 +386,9 @@ public void testBuildVectorOutput_withNestedMap_successful() { Map config = createNestedMapConfiguration(); IngestDocument ingestDocument = createNestedMapIngestDocument(); TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); - Map knnMap = processor.buildMapWithKnnKeyAndOriginalValue(ingestDocument); + Map knnMap = processor.buildMapWithProcessorKeyAndOriginalValue(ingestDocument); List> modelTensorList = createMockVectorResult(); - processor.buildTextEmbeddingResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); + processor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); Map favoritesMap = (Map) ingestDocument.getSourceAndMetadata().get("favorites"); assertNotNull(favoritesMap); Map favoriteGames = (Map) favoritesMap.get("favorite.games"); @@ -402,7 +402,7 @@ public void test_updateDocument_appendVectorFieldsToDocument_successful() { Map config = createPlainStringConfiguration(); IngestDocument ingestDocument = createPlainIngestDocument(); TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); - Map knnMap = processor.buildMapWithKnnKeyAndOriginalValue(ingestDocument); + Map knnMap = processor.buildMapWithProcessorKeyAndOriginalValue(ingestDocument); List> modelTensorList = createMockVectorResult(); 
processor.setVectorFieldsToDocument(ingestDocument, knnMap, modelTensorList); From 9ab2e74def679206cdda4fc3d3cdc5096b64c065 Mon Sep 17 00:00:00 2001 From: xinyual Date: Mon, 18 Sep 2023 11:21:13 +0800 Subject: [PATCH 31/70] Add ut it to sparse encoding processor (#6) * fix original UT problem Signed-off-by: xinyual * add UT IT Signed-off-by: xinyual * add more UT Signed-off-by: xinyual * add more ut Signed-off-by: xinyual * fix typo error Signed-off-by: xinyual --------- Signed-off-by: xinyual --- build.gradle | 2 +- gradle.properties | 3 +- .../neuralsearch/processor/NLPProcessor.java | 3 +- .../common/BaseNeuralSearchIT.java | 8 +- .../processor/SparseEncodingProcessIT.java | 89 +++++++++ .../SparseEncodingProcessorTests.java | 170 ++++++++++++++++++ .../TextEmbeddingProcessorTests.java | 4 +- .../processor/SparseIndexMappings.json | 20 +++ .../SparsePipelineConfiguration.json | 13 ++ .../UploadSparseModelRequestBody.json | 9 + 10 files changed, 313 insertions(+), 8 deletions(-) create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java create mode 100644 src/test/resources/processor/SparseIndexMappings.json create mode 100644 src/test/resources/processor/SparsePipelineConfiguration.json create mode 100644 src/test/resources/processor/UploadSparseModelRequestBody.json diff --git a/build.gradle b/build.gradle index 41b3a3f7b..1d8eca483 100644 --- a/build.gradle +++ b/build.gradle @@ -144,7 +144,7 @@ dependencies { zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${opensearch_build}" compileOnly fileTree(dir: knnJarDirectory, include: '*.jar') api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}" - implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.13.0' + implementation group: 'org.apache.commons', name: 'commons-lang3', 
version: '3.10' // ml-common excluded reflection for runtime so we need to add it by ourselves. // https://github.com/opensearch-project/ml-commons/commit/464bfe34c66d7a729a00dd457f03587ea4e504d9 // TODO: Remove following three lines of dependencies if ml-common include them in their jar diff --git a/gradle.properties b/gradle.properties index f4b55d2a3..90e7a8445 100644 --- a/gradle.properties +++ b/gradle.properties @@ -8,5 +8,4 @@ org.gradle.jvmargs=--add-exports jdk.compiler/com.sun.tools.javac.api=ALL-UNNAME --add-exports jdk.compiler/com.sun.tools.javac.file=ALL-UNNAMED \ --add-exports jdk.compiler/com.sun.tools.javac.parser=ALL-UNNAMED \ --add-exports jdk.compiler/com.sun.tools.javac.tree=ALL-UNNAMED \ - --add-exports jdk.compiler/com.sun.tools.javac.util=ALL-UNNAMED -customDistributionUrl=https://artifacts.opensearch.org/snapshots/core/opensearch/3.0.0-SNAPSHOT/opensearch-min-3.0.0-SNAPSHOT-darwin-x64-latest.tar.gz \ No newline at end of file + --add-exports jdk.compiler/com.sun.tools.javac.util=ALL-UNNAMED \ No newline at end of file diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java index 17c74f0bc..02d5e5747 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java @@ -73,8 +73,7 @@ private void validateEmbeddingConfiguration(Map fieldMap) { .anyMatch( x -> StringUtils.isBlank(x.getKey()) || Objects.isNull(x.getValue()) || StringUtils.isBlank(x.getValue().toString()) )) { - throw new IllegalArgumentException("Unable to create the " + type - + " processor as field_map has invalid key or value"); + throw new IllegalArgumentException("Unable to create the processor as field_map has invalid key or value"); } } diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java 
b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java index b144ade6c..6aea1a9f2 100644 --- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java @@ -64,8 +64,14 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase { protected static final String DEFAULT_COMBINATION_METHOD = "arithmetic_mean"; protected static final String PARAM_NAME_WEIGHTS = "weights"; + protected String PIPELINE_CONFIGURATION_NAME = "processor/PipelineConfiguration.json"; + protected final ClassLoader classLoader = this.getClass().getClassLoader(); + protected void setPipelineConfigurationName(String pipelineConfigurationName){ + this.PIPELINE_CONFIGURATION_NAME = pipelineConfigurationName; + } + @Before public void setupSettings() { if (isUpdateClusterSettings()) { @@ -239,7 +245,7 @@ protected void createPipelineProcessor(String modelId, String pipelineName) thro toHttpEntity( String.format( LOCALE, - Files.readString(Path.of(classLoader.getResource("processor/PipelineConfiguration.json").toURI())), + Files.readString(Path.of(classLoader.getResource(PIPELINE_CONFIGURATION_NAME).toURI())), modelId ) ), diff --git a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java new file mode 100644 index 000000000..ea0dfaaff --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java @@ -0,0 +1,89 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import com.google.common.collect.ImmutableList; +import lombok.SneakyThrows; +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.apache.hc.core5.http.message.BasicHeader; +import org.junit.After; +import 
org.junit.Before; +import org.opensearch.client.Response; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.neuralsearch.common.BaseNeuralSearchIT; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Map; + +public class SparseEncodingProcessIT extends BaseNeuralSearchIT { + + private static final String INDEX_NAME = "sparse_encoding_index"; + + private static final String PIPELINE_NAME = "pipeline-sparse-encoding"; + + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + /* this is required to minimize chance of model not being deployed due to open memory CB, + * this happens in case we leave model from previous test case. We use new model for every test, and old model + * can be undeployed and deleted to free resources after each test case execution. + */ + findDeployedModels().forEach(this::deleteModel); + } + + @Before + public void setPipelineName() { + this.setPipelineConfigurationName("processor/SparsePipelineConfiguration.json"); + } + + public void testSparseEncodingProcessor() throws Exception { + String modelId = uploadSparseEncodingModel(); + loadModel(modelId); + createPipelineProcessor(modelId, PIPELINE_NAME); + createTextEmbeddingIndex(); + ingestDocument(); + assertEquals(1, getDocCount(INDEX_NAME)); + } + + private String uploadSparseEncodingModel() throws Exception { + String requestBody = Files.readString(Path.of(classLoader.getResource("processor/UploadSparseModelRequestBody.json").toURI())); + return uploadModel(requestBody); + } + + private void createTextEmbeddingIndex() throws Exception { + createIndexWithConfiguration( + INDEX_NAME, + Files.readString(Path.of(classLoader.getResource("processor/SparseIndexMappings.json").toURI())), + PIPELINE_NAME + ); + } + + private void ingestDocument() throws Exception { + String ingestDocument = "{\n" + + "\"passage_text\": \"This is a good day\"" + + "}\n"; + Response response = 
makeRequest( + client(), + "POST", + INDEX_NAME + "/_doc?refresh", + null, + toHttpEntity(ingestDocument), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + Map map = XContentHelper.convertToMap( + XContentType.JSON.xContent(), + EntityUtils.toString(response.getEntity()), + false + ); + assertEquals("created", map.get("result")); + } + + +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java new file mode 100644 index 000000000..e79f286fe --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java @@ -0,0 +1,170 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import lombok.SneakyThrows; +import org.junit.Before; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.env.Environment; +import org.opensearch.ingest.IngestDocument; +import org.opensearch.ingest.Processor; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory; +import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.*; +import java.util.function.BiConsumer; +import java.util.stream.IntStream; + +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; +import static org.mockito.Mockito.verify; + +public class SparseEncodingProcessorTests extends OpenSearchTestCase { + @Mock + private MLCommonsClientAccessor mlCommonsClientAccessor; 
+ + @Mock + private Environment env; + + @InjectMocks + private SparseEncodingProcessorFactory SparseEncodingProcessorFactory; + private static final String PROCESSOR_TAG = "mockTag"; + private static final String DESCRIPTION = "mockDescription"; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + Settings settings = Settings.builder().put("index.mapping.depth.limit", 20).build(); + when(env.settings()).thenReturn(settings); + } + + @SneakyThrows + private SparseEncodingProcessor createInstance() { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(SparseEncodingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put(SparseEncodingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1Mapped", "key2", "key2Mapped")); + return SparseEncodingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + } + + public void testExecute_successful() { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", "value1"); + sourceAndMetadata.put("key2", "value2"); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + SparseEncodingProcessor processor = createInstance(); + + List > dataAsMapList = createMockMapResult(2); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(dataAsMapList); + return null; + }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(any(IngestDocument.class), isNull()); + } + + @SneakyThrows + public void testExecute_whenInferenceTextListEmpty_SuccessWithoutAnyMap() { + Map sourceAndMetadata = new HashMap<>(); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + Map registry = new HashMap<>(); + MLCommonsClientAccessor accessor = 
mock(MLCommonsClientAccessor.class); + SparseEncodingProcessorFactory sparseEncodingProcessorFactory = new SparseEncodingProcessorFactory(accessor, env); + + Map config = new HashMap<>(); + config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put(TextEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1Mapped", "key2", "key2Mapped")); + SparseEncodingProcessor processor = sparseEncodingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + doThrow(new RuntimeException()).when(accessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(any(IngestDocument.class), isNull()); + } + + public void testExecute_withListTypeInput_successful() { + List list1 = ImmutableList.of("test1", "test2", "test3"); + List list2 = ImmutableList.of("test4", "test5", "test6"); + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", list1); + sourceAndMetadata.put("key2", list2); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + SparseEncodingProcessor processor = createInstance(); + + List > dataAsMapList = createMockMapResult(6); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(dataAsMapList); + return null; + }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(any(IngestDocument.class), isNull()); + } + + + public void testExecute_MLClientAccessorThrowFail_handlerFailure() { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", "value1"); + sourceAndMetadata.put("key2", "value2"); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + 
SparseEncodingProcessor processor = createInstance(); + + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onFailure(new IllegalArgumentException("illegal argument")); + return null; + }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); + } + + public void testExecute_withMapTypeInput_successful() { + Map map1 = ImmutableMap.of("test1", "test2"); + Map map2 = ImmutableMap.of("test4", "test5"); + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", map1); + sourceAndMetadata.put("key2", map2); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + SparseEncodingProcessor processor = createInstance(); + + List > dataAsMapList = createMockMapResult(2); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(dataAsMapList); + return null; + }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(any(IngestDocument.class), isNull()); + + } + + + private List > createMockMapResult(int number) + { + List> mockSparseEncodingResult = new ArrayList<>(); + IntStream.range(0, number) + .forEachOrdered(x -> mockSparseEncodingResult.add(ImmutableMap.of("hello", 1.0f))); + + List> mockMapResult = Collections.singletonList(Map.of("response", mockSparseEncodingResult)); + return mockMapResult; + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java index 846c23dcb..399cd1eb8 100644 --- 
a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java @@ -75,9 +75,9 @@ public void testTextEmbeddingProcessConstructor_whenConfigMapError_throwIllegalA fieldMap.put("key2", "key2Mapped"); config.put(TextEmbeddingProcessor.FIELD_MAP_FIELD, fieldMap); try { - textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + textEmbeddingProcessorFactory.create(registry, TextEmbeddingProcessor.TYPE, DESCRIPTION, config); } catch (IllegalArgumentException e) { - assertEquals("Unable to create the text_embedding processor as field_map has invalid key or value", e.getMessage()); + assertEquals("Unable to create the processor as field_map has invalid key or value", e.getMessage()); } } diff --git a/src/test/resources/processor/SparseIndexMappings.json b/src/test/resources/processor/SparseIndexMappings.json new file mode 100644 index 000000000..d06cfd600 --- /dev/null +++ b/src/test/resources/processor/SparseIndexMappings.json @@ -0,0 +1,20 @@ +{ + "settings":{ + "default_pipeline": "pipeline-sparse-encoding" + }, + "mappings": { + "properties": { + "passage_text": { + "type": "text" + }, + "passage_sparse": { + "type": "nested", + "properties":{ + "sparseEncoding":{ + "type": "rank_features" + } + } + } + } + } +} \ No newline at end of file diff --git a/src/test/resources/processor/SparsePipelineConfiguration.json b/src/test/resources/processor/SparsePipelineConfiguration.json new file mode 100644 index 000000000..297bbf80a --- /dev/null +++ b/src/test/resources/processor/SparsePipelineConfiguration.json @@ -0,0 +1,13 @@ +{ + "description": "An example sparse Encoding pipeline", + "processors" : [ + { + "sparse_encoding": { + "model_id": "%s", + "field_map": { + "passage_text": "passage_sparse" + } + } + } + ] +} \ No newline at end of file diff --git a/src/test/resources/processor/UploadSparseModelRequestBody.json 
b/src/test/resources/processor/UploadSparseModelRequestBody.json new file mode 100644 index 000000000..2d48a4170 --- /dev/null +++ b/src/test/resources/processor/UploadSparseModelRequestBody.json @@ -0,0 +1,9 @@ +{ + "name": "tokenize-idf-0915", + "version": "1.0.0", + "function_name": "TOKENIZE", + "description": "test model", + "model_format": "TORCH_SCRIPT", + "model_content_hash_value": "e23969f8bd417e7aec26f49201da4adfc6b74e6187d1ddfdfb98e473bdd95978", + "url": "https://github.com/xinyual/demo/raw/main/tokenizer-idf-msmarco.zip" +} \ No newline at end of file From 84915f03836d7ce32a40036c8beb15bbab6a757c Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Fri, 15 Sep 2023 15:56:15 +0800 Subject: [PATCH 32/70] utils, tidy Signed-off-by: zhichao-aws --- .../processor/SparseEncodingProcessor.java | 23 ++++-- .../processor/TextEmbeddingProcessor.java | 8 +- .../query/sparse/SparseQueryBuilder.java | 2 +- .../neuralsearch/util/TokenWeightUtil.java | 75 +++++++++++++------ 4 files changed, 77 insertions(+), 31 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index fc7f0d5c5..2061e2f21 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -10,12 +10,12 @@ import org.opensearch.env.Environment; import org.opensearch.ingest.IngestDocument; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.util.TokenWeightUtil; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.function.BiConsumer; -import java.util.function.Supplier; @Log4j2 @@ -24,19 +24,32 @@ public class SparseEncodingProcessor extends NLPProcessor { public static final String TYPE = "sparse_encoding"; public static final String LIST_TYPE_NESTED_MAP_KEY = 
"sparse_encoding"; - public SparseEncodingProcessor(String tag, String description, String modelId, Map fieldMap, MLCommonsClientAccessor clientAccessor, Environment environment) { + public SparseEncodingProcessor( + String tag, + String description, + String modelId, + Map fieldMap, + MLCommonsClientAccessor clientAccessor, + Environment environment + ) { super(tag, description, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment); } @Override - public void doExecute(IngestDocument ingestDocument, Map ProcessMap, List inferenceList, BiConsumer handler) { + public void doExecute( + IngestDocument ingestDocument, + Map ProcessMap, + List inferenceList, + BiConsumer handler + ) { mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> { List > results = new ArrayList<>(); - for (Map map: resultMaps) - { + for (Map map: resultMaps) { results.addAll((List>)map.get("response") ); } setVectorFieldsToDocument(ingestDocument, ProcessMap, results); +// setVectorFieldsToDocument(ingestDocument, ProcessMap, TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)); handler.accept(ingestDocument, null); }, e -> { handler.accept(null, e); })); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java index 4ae895cdb..bd690558d 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java @@ -38,7 +38,13 @@ public TextEmbeddingProcessor( } @Override - public void doExecute(IngestDocument ingestDocument, Map ProcessMap, List inferenceList, BiConsumer handler) { + public void doExecute( + IngestDocument ingestDocument, + Map ProcessMap, + List inferenceList, + BiConsumer handler + ) { mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, 
ActionListener.wrap(vectors -> { setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors); handler.accept(ingestDocument, null); diff --git a/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseQueryBuilder.java index a967be2d8..20368bdbb 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseQueryBuilder.java @@ -222,7 +222,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws modelId(), List.of(queryText), ActionListener.wrap(mapResultList -> { - queryTokensSetOnce.set(TokenWeightUtil.fetchQueryTokensList(mapResultList).get(0)); + queryTokensSetOnce.set(TokenWeightUtil.fetchListOfTokenWeightMap(mapResultList).get(0)); actionListener.onResponse(null); }, actionListener::onFailure)) ) diff --git a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java index ab5edaa9d..4e13c6e8f 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java @@ -5,42 +5,69 @@ package org.opensearch.neuralsearch.util; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; public class TokenWeightUtil { - //todo: change comments, add validation (throw exception) - /** - * Converts a Map result to Map - * - * @param mapResult {@link Map} of String and ? 
- * @return Map of String and Float - */ public static String RESPONSE_KEY = "response"; - public static List> fetchQueryTokensList(List> mapResultList) { - return mapResultList.stream().map(TokenWeightUtil::buildQueryTokensMap).collect(Collectors.toList()); + /** + * possible input data format + * case remote inference: + * { + * "response":{ + * [ + * { TOKEN_WEIGHT_MAP}, + * { TOKEN_WEIGHT_MAP} + * ] + * } + * } + * case local deploy: + * [{"response":{ + * [ + * { TOKEN_WEIGHT_MAP} + * ] + * } + * },{"response":{ + * [ + * { TOKEN_WEIGHT_MAP} + * ] + * }] + */ + public static List> fetchListOfTokenWeightMap(List> mapResultList) { + List results = new ArrayList<>(); + for (Map map: mapResultList) + { + if (!map.containsKey(RESPONSE_KEY)){ + throw new IllegalArgumentException("The inference result should be associated with the field [" + + RESPONSE_KEY + "]."); + } + if (!List.class.isAssignableFrom(map.get(RESPONSE_KEY).getClass())) { + throw new IllegalArgumentException("The data object associated with field [" + + RESPONSE_KEY + "] should be a list."); + } + results.addAll((List) map.get("response")); + } + return results.stream().map(TokenWeightUtil::buildTokenWeightMap).collect(Collectors.toList()); } - @SuppressWarnings("unchecked") - private static Map buildQueryTokensMap(Map mapResult) { - Object response = mapResult.get(RESPONSE_KEY); + private static Map buildTokenWeightMap(Object uncastedMap) { + if (!Map.class.isAssignableFrom(uncastedMap.getClass())) { + throw new IllegalArgumentException("The expected inference result is a Map with String keys and " + + " Float values."); + } Map result = new HashMap<>(); - if (response instanceof Map) { - for (Map.Entry entry: ((Map) response).entrySet()) { - assert entry.getValue() instanceof Number; - result.put(entry.getKey(), ((Number) entry.getValue()).floatValue()); - } - return result; - } else { - assert response instanceof List; - for (Map.Entry entry: ((Map) ((List) response).get(0)).entrySet()) { - 
assert entry.getValue() instanceof Number; - result.put(entry.getKey(), ((Number) entry.getValue()).floatValue()); + for (Map.Entry entry: ((Map) uncastedMap).entrySet()) { + if (!String.class.isAssignableFrom(entry.getKey().getClass()) + || !Number.class.isAssignableFrom(entry.getValue().getClass())){ + throw new IllegalArgumentException("The expected inference result is a Map with String keys and " + + " Float values."); } - return result; + result.put((String) entry.getKey(), ((Number) entry.getValue()).floatValue()); } + return result; } } From 3bb95e35a8670c22e19b51ae6249bda0e93c5216 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Fri, 15 Sep 2023 15:57:48 +0800 Subject: [PATCH 33/70] rename to sparse_encoding query Signed-off-by: zhichao-aws --- .../neuralsearch/plugin/NeuralSearch.java | 6 +-- ...r.java => SparseEncodingQueryBuilder.java} | 40 +++++++++---------- 2 files changed, 23 insertions(+), 23 deletions(-) rename src/main/java/org/opensearch/neuralsearch/query/sparse/{SparseQueryBuilder.java => SparseEncodingQueryBuilder.java} (89%) diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index 7be36aaa5..53b4b30d7 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -42,7 +42,7 @@ import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; -import org.opensearch.neuralsearch.query.sparse.SparseQueryBuilder; +import org.opensearch.neuralsearch.query.sparse.SparseEncodingQueryBuilder; import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher; import org.opensearch.plugins.ActionPlugin; import org.opensearch.plugins.ExtensiblePlugin; @@ -82,7 +82,7 @@ public Collection createComponents( final Supplier 
repositoriesServiceSupplier ) { NeuralQueryBuilder.initialize(clientAccessor); - SparseQueryBuilder.initialize(clientAccessor); + SparseEncodingQueryBuilder.initialize(clientAccessor); normalizationProcessorWorkflow = new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()); return List.of(clientAccessor); } @@ -92,7 +92,7 @@ public List> getQueries() { return Arrays.asList( new QuerySpec<>(NeuralQueryBuilder.NAME, NeuralQueryBuilder::new, NeuralQueryBuilder::fromXContent), new QuerySpec<>(HybridQueryBuilder.NAME, HybridQueryBuilder::new, HybridQueryBuilder::fromXContent), - new QuerySpec<>(SparseQueryBuilder.NAME, SparseQueryBuilder::new, SparseQueryBuilder::fromXContent) + new QuerySpec<>(SparseEncodingQueryBuilder.NAME, SparseEncodingQueryBuilder::new, SparseEncodingQueryBuilder::fromXContent) ); } diff --git a/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilder.java similarity index 89% rename from src/main/java/org/opensearch/neuralsearch/query/sparse/SparseQueryBuilder.java rename to src/main/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilder.java index 20368bdbb..f54c2614d 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilder.java @@ -48,8 +48,8 @@ @Accessors(chain = true, fluent = true) @NoArgsConstructor @AllArgsConstructor -public class SparseQueryBuilder extends AbstractQueryBuilder { - public static final String NAME = "sparse"; +public class SparseEncodingQueryBuilder extends AbstractQueryBuilder { + public static final String NAME = "sparse_encoding"; @VisibleForTesting static final ParseField QUERY_TOKENS_FIELD = new ParseField("query_tokens"); @VisibleForTesting @@ -62,7 +62,7 @@ public class SparseQueryBuilder extends AbstractQueryBuilder private static 
MLCommonsClientAccessor ML_CLIENT; public static void initialize(MLCommonsClientAccessor mlClient) { - SparseQueryBuilder.ML_CLIENT = mlClient; + SparseEncodingQueryBuilder.ML_CLIENT = mlClient; } private String fieldName; @@ -72,7 +72,7 @@ public static void initialize(MLCommonsClientAccessor mlClient) { private Float tokenScoreUpperBound; private Supplier> queryTokensSupplier; - public SparseQueryBuilder(StreamInput in) throws IOException { + public SparseEncodingQueryBuilder(StreamInput in) throws IOException { super(in); this.fieldName = in.readString(); // we don't have readOptionalMap or write, need to do it manually @@ -132,8 +132,8 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws * } * */ - public static SparseQueryBuilder fromXContent(XContentParser parser) throws IOException { - SparseQueryBuilder sparseQueryBuilder = new SparseQueryBuilder(); + public static SparseEncodingQueryBuilder fromXContent(XContentParser parser) throws IOException { + SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder(); if (parser.currentToken() != XContentParser.Token.START_OBJECT) { throw new ParsingException( parser.getTokenLocation(), @@ -141,28 +141,28 @@ public static SparseQueryBuilder fromXContent(XContentParser parser) throws IOEx ); } parser.nextToken(); - sparseQueryBuilder.fieldName(parser.currentName()); + sparseEncodingQueryBuilder.fieldName(parser.currentName()); parser.nextToken(); - parseQueryParams(parser, sparseQueryBuilder); + parseQueryParams(parser, sparseEncodingQueryBuilder); if (parser.nextToken() != XContentParser.Token.END_OBJECT) { throw new ParsingException( parser.getTokenLocation(), "[" + NAME + "] query doesn't support multiple fields, found [" - + sparseQueryBuilder.fieldName() + + sparseEncodingQueryBuilder.fieldName() + "] and [" + parser.currentName() + "]" ); } - return sparseQueryBuilder; + return sparseEncodingQueryBuilder; } private static void parseQueryParams( 
XContentParser parser, - SparseQueryBuilder sparseQueryBuilder + SparseEncodingQueryBuilder sparseEncodingQueryBuilder ) throws IOException { XContentParser.Token token; String currentFieldName = ""; @@ -171,15 +171,15 @@ private static void parseQueryParams( currentFieldName = parser.currentName(); } else if (token.isValue()) { if (NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - sparseQueryBuilder.queryName(parser.text()); + sparseEncodingQueryBuilder.queryName(parser.text()); } else if (BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - sparseQueryBuilder.boost(parser.floatValue()); + sparseEncodingQueryBuilder.boost(parser.floatValue()); } else if (QUERY_TEXT_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - sparseQueryBuilder.queryText(parser.text()); + sparseEncodingQueryBuilder.queryText(parser.text()); } else if (MODEL_ID_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - sparseQueryBuilder.modelId(parser.text()); + sparseEncodingQueryBuilder.modelId(parser.text()); } else if (TOKEN_SCORE_UPPER_BOUND_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - sparseQueryBuilder.tokenScoreUpperBound(parser.floatValue()); + sparseEncodingQueryBuilder.tokenScoreUpperBound(parser.floatValue()); } else { throw new ParsingException( parser.getTokenLocation(), @@ -187,7 +187,7 @@ private static void parseQueryParams( ); } } else if (QUERY_TOKENS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - sparseQueryBuilder.queryTokens(parser.map(HashMap::new, XContentParser::floatValue)); + sparseEncodingQueryBuilder.queryTokens(parser.map(HashMap::new, XContentParser::floatValue)); } else { throw new ParsingException( parser.getTokenLocation(), @@ -207,7 +207,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws } if (null != queryTokensSupplier) { return queryTokensSupplier.get() == null ? 
this : - new SparseQueryBuilder() + new SparseEncodingQueryBuilder() .fieldName(fieldName) .queryTokens(queryTokensSupplier.get()) .queryText(queryText) @@ -227,7 +227,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws }, actionListener::onFailure)) ) ); - return new SparseQueryBuilder() + return new SparseEncodingQueryBuilder() .fieldName(fieldName) .queryText(queryText) .modelId(modelId) @@ -295,7 +295,7 @@ private static void validateQueryTokens(Map queryTokens) { } @Override - protected boolean doEquals(SparseQueryBuilder obj) { + protected boolean doEquals(SparseEncodingQueryBuilder obj) { // todo: validate if (this == obj) return true; if (obj == null || getClass() != obj.getClass()) return false; From b4156f003aa7de9bc78b0ffca38e20ba23193969 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Mon, 18 Sep 2023 11:22:04 +0800 Subject: [PATCH 34/70] add validation and ut Signed-off-by: zhichao-aws --- .../processor/SparseEncodingProcessor.java | 7 +- .../sparse/SparseEncodingQueryBuilder.java | 17 + .../SparseEncodingQueryBuilderTests.java | 556 ++++++++++++++++++ 3 files changed, 574 insertions(+), 6 deletions(-) create mode 100644 src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilderTests.java diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index 2061e2f21..ea15c8b8c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -44,12 +44,7 @@ public void doExecute( Exception> handler ) { mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> { - List > results = new ArrayList<>(); - for (Map map: resultMaps) { - results.addAll((List>)map.get("response") ); - } - 
setVectorFieldsToDocument(ingestDocument, ProcessMap, results); -// setVectorFieldsToDocument(ingestDocument, ProcessMap, TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)); + setVectorFieldsToDocument(ingestDocument, ProcessMap, TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)); handler.accept(ingestDocument, null); }, e -> { handler.accept(null, e); })); } diff --git a/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilder.java index f54c2614d..7957116e1 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilder.java @@ -157,6 +157,23 @@ public static SparseEncodingQueryBuilder fromXContent(XContentParser parser) thr ); } + requireValue( + sparseEncodingQueryBuilder.fieldName(), + "Field name must be provided for " + NAME + " query" + ); + if (null==sparseEncodingQueryBuilder.queryTokens()) { + requireValue( + sparseEncodingQueryBuilder.queryText(), + "Either " + QUERY_TOKENS_FIELD.getPreferredName() + " or " + + QUERY_TEXT_FIELD.getPreferredName() + " must be provided for " + NAME + " query" + ); + requireValue( + sparseEncodingQueryBuilder.modelId(), + MODEL_ID_FIELD.getPreferredName() + " must be provided for " + NAME + + " query when using " + QUERY_TEXT_FIELD.getPreferredName() + ); + } + return sparseEncodingQueryBuilder; } diff --git a/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilderTests.java new file mode 100644 index 000000000..9195e5800 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilderTests.java @@ -0,0 +1,556 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package 
org.opensearch.neuralsearch.query.sparse; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; +import static org.opensearch.index.query.AbstractQueryBuilder.BOOST_FIELD; +import static org.opensearch.index.query.AbstractQueryBuilder.NAME_FIELD; +import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD; +import static org.opensearch.neuralsearch.TestUtils.xContentBuilderToMap; +import static org.opensearch.neuralsearch.query.sparse.SparseEncodingQueryBuilder.MODEL_ID_FIELD; +import static org.opensearch.neuralsearch.query.sparse.SparseEncodingQueryBuilder.NAME; +import static org.opensearch.neuralsearch.query.sparse.SparseEncodingQueryBuilder.QUERY_TEXT_FIELD; +import static org.opensearch.neuralsearch.query.sparse.SparseEncodingQueryBuilder.QUERY_TOKENS_FIELD; +import static org.opensearch.neuralsearch.query.sparse.SparseEncodingQueryBuilder.TOKEN_SCORE_UPPER_BOUND_FIELD; + +import lombok.SneakyThrows; +import org.opensearch.client.Client; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.ParseField; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.FilterStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.index.query.MatchNoneQueryBuilder; +import 
org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryRewriteContext; +import org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.neuralsearch.common.VectorUtil; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.function.BiConsumer; +import java.util.function.Supplier; + +public class SparseEncodingQueryBuilderTests extends OpenSearchTestCase { + + private static final String FIELD_NAME = "testField"; + private static final String QUERY_TEXT = "Hello world!"; + private static final Map QUERY_TOKENS = Map.of("hello", 1.f, "world", 2.f); + private static final String MODEL_ID = "mfgfgdsfgfdgsde"; + private static final Float TOKEN_SCORE_UPPER_BOUND = 123f; + private static final float BOOST = 1.8f; + private static final String QUERY_NAME = "queryName"; + private static final Supplier> QUERY_TOKENS_SUPPLIER = () -> Map.of("hello", 1.f, "world", 2.f); + + @SneakyThrows + public void testFromXContent_whenBuiltWithQueryText_thenBuildSuccessfully() { + /* + { + "VECTOR_FIELD": { + "query_text": "string", + "model_id": "string" + } + } + */ + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .endObject() + .endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder = SparseEncodingQueryBuilder.fromXContent(contentParser); + + assertEquals(FIELD_NAME, sparseEncodingQueryBuilder.fieldName()); + assertEquals(QUERY_TEXT, sparseEncodingQueryBuilder.queryText()); + assertEquals(MODEL_ID, 
sparseEncodingQueryBuilder.modelId()); + } + + @SneakyThrows + public void testFromXContent_whenBuiltWithQueryTokens_thenBuildSuccessfully() { + /* + { + "VECTOR_FIELD": { + "query_tokens": { + "string":float, + } + } + } + */ + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TOKENS_FIELD.getPreferredName(), QUERY_TOKENS) + .endObject() + .endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder = SparseEncodingQueryBuilder.fromXContent(contentParser); + + assertEquals(FIELD_NAME, sparseEncodingQueryBuilder.fieldName()); + assertEquals(QUERY_TOKENS, sparseEncodingQueryBuilder.queryTokens()); + } + + @SneakyThrows + public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() { + /* + { + "VECTOR_FIELD": { + "query_text": "string", + "model_id": "string", + "boost": 10.0, + "_name": "something", + } + } + */ + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(TOKEN_SCORE_UPPER_BOUND_FIELD.getPreferredName(), TOKEN_SCORE_UPPER_BOUND) + .field(BOOST_FIELD.getPreferredName(), BOOST) + .field(NAME_FIELD.getPreferredName(), QUERY_NAME) + .endObject() + .endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder = SparseEncodingQueryBuilder.fromXContent(contentParser); + + assertEquals(FIELD_NAME, sparseEncodingQueryBuilder.fieldName()); + assertEquals(QUERY_TEXT, sparseEncodingQueryBuilder.queryText()); + assertEquals(MODEL_ID, sparseEncodingQueryBuilder.modelId()); + assertEquals(TOKEN_SCORE_UPPER_BOUND, sparseEncodingQueryBuilder.tokenScoreUpperBound()); + assertEquals(BOOST, 
sparseEncodingQueryBuilder.boost(), 0.0); + assertEquals(QUERY_NAME, sparseEncodingQueryBuilder.queryName()); + } + + @SneakyThrows + public void testFromXContent_whenBuildWithMultipleRootFields_thenFail() { + /* + { + "VECTOR_FIELD": { + "query_text": "string", + "model_id": "string", + "boost": 10.0, + "_name": "something", + }, + "invalid": 10 + } + */ + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(BOOST_FIELD.getPreferredName(), BOOST) + .field(NAME_FIELD.getPreferredName(), QUERY_NAME) + .endObject() + .field("invalid", 10) + .endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + expectThrows(ParsingException.class, () -> SparseEncodingQueryBuilder.fromXContent(contentParser)); + } + + @SneakyThrows + public void testFromXContent_whenBuildWithMissingParameters_thenFail() { + /* + { + "VECTOR_FIELD": { + + } + } + */ + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject().startObject(FIELD_NAME).endObject().endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + expectThrows(IllegalArgumentException.class, () -> SparseEncodingQueryBuilder.fromXContent(contentParser)); + } + + @SneakyThrows + public void testFromXContent_whenBuildWithMissingModelId_thenFail() { + /* + { + "VECTOR_FIELD": { + + } + } + */ + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject().startObject(FIELD_NAME).endObject().endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + expectThrows(IllegalArgumentException.class, () -> SparseEncodingQueryBuilder.fromXContent(contentParser)); + } + + @SneakyThrows + public void testFromXContent_whenBuildWithDuplicateParameters_thenFail() { + /* + { + 
"VECTOR_FIELD": { + "query_text": "string", + "query_text": "string", + "model_id": "string", + "model_id": "string", + "k": int, + "k": int + } + } + */ + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(K_FIELD.getPreferredName(), K) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(K_FIELD.getPreferredName(), K) + .endObject() + .endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + expectThrows(IOException.class, () -> NeuralQueryBuilder.fromXContent(contentParser)); + } + + @SneakyThrows + public void testFromXContent_whenBuiltWithInvalidFilter_thenFail() { + /* + { + "VECTOR_FIELD": { + "query_text": "string", + "model_id": "string", + "k": int, + "boost": 10.0, + "filter": 12 + } + } + */ + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(K_FIELD.getPreferredName(), K) + .field(BOOST_FIELD.getPreferredName(), BOOST) + .field(NAME_FIELD.getPreferredName(), QUERY_NAME) + .field(FILTER_FIELD.getPreferredName(), 12) + .endObject() + .endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + expectThrows(ParsingException.class, () -> NeuralQueryBuilder.fromXContent(contentParser)); + } + + @SuppressWarnings("unchecked") + @SneakyThrows + public void testToXContent() { + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME) + .modelId(MODEL_ID) + .queryText(QUERY_TEXT) + .k(K) + .filter(TEST_FILTER); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder = 
neuralQueryBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS); + Map out = xContentBuilderToMap(builder); + + Object outer = out.get(NAME); + if (!(outer instanceof Map)) { + fail("neural does not map to nested object"); + } + + Map outerMap = (Map) outer; + + assertEquals(1, outerMap.size()); + assertTrue(outerMap.containsKey(FIELD_NAME)); + + Object secondInner = outerMap.get(FIELD_NAME); + if (!(secondInner instanceof Map)) { + fail("field name does not map to nested object"); + } + + Map secondInnerMap = (Map) secondInner; + + assertEquals(MODEL_ID, secondInnerMap.get(MODEL_ID_FIELD.getPreferredName())); + assertEquals(QUERY_TEXT, secondInnerMap.get(QUERY_TEXT_FIELD.getPreferredName())); + assertEquals(K, secondInnerMap.get(K_FIELD.getPreferredName())); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); + assertEquals( + xContentBuilderToMap(TEST_FILTER.toXContent(xContentBuilder, EMPTY_PARAMS)), + secondInnerMap.get(FILTER_FIELD.getPreferredName()) + ); + } + + @SneakyThrows + public void testStreams() { + NeuralQueryBuilder original = new NeuralQueryBuilder(); + original.fieldName(FIELD_NAME); + original.queryText(QUERY_TEXT); + original.modelId(MODEL_ID); + original.k(K); + original.boost(BOOST); + original.queryName(QUERY_NAME); + original.filter(TEST_FILTER); + + BytesStreamOutput streamOutput = new BytesStreamOutput(); + original.writeTo(streamOutput); + + FilterStreamInput filterStreamInput = new NamedWriteableAwareStreamInput( + streamOutput.bytes().streamInput(), + new NamedWriteableRegistry( + List.of(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new)) + ) + ); + + NeuralQueryBuilder copy = new NeuralQueryBuilder(filterStreamInput); + assertEquals(original, copy); + } + + public void testHashAndEquals() { + String fieldName1 = "field 1"; + String fieldName2 = "field 2"; + String queryText1 = "query text 1"; + String queryText2 = "query text 2"; + String modelId1 = "model-1"; + 
String modelId2 = "model-2"; + float boost1 = 1.8f; + float boost2 = 3.8f; + String queryName1 = "query-1"; + String queryName2 = "query-2"; + int k1 = 1; + int k2 = 2; + + QueryBuilder filter1 = new MatchAllQueryBuilder(); + QueryBuilder filter2 = new MatchNoneQueryBuilder(); + + NeuralQueryBuilder neuralQueryBuilder_baseline = new NeuralQueryBuilder().fieldName(fieldName1) + .queryText(queryText1) + .modelId(modelId1) + .k(k1) + .boost(boost1) + .queryName(queryName1) + .filter(filter1); + + // Identical to neuralQueryBuilder_baseline + NeuralQueryBuilder neuralQueryBuilder_baselineCopy = new NeuralQueryBuilder().fieldName(fieldName1) + .queryText(queryText1) + .modelId(modelId1) + .k(k1) + .boost(boost1) + .queryName(queryName1) + .filter(filter1); + + // Identical to neuralQueryBuilder_baseline except default boost and query name + NeuralQueryBuilder neuralQueryBuilder_defaultBoostAndQueryName = new NeuralQueryBuilder().fieldName(fieldName1) + .queryText(queryText1) + .modelId(modelId1) + .k(k1) + .filter(filter1); + + // Identical to neuralQueryBuilder_baseline except diff field name + NeuralQueryBuilder neuralQueryBuilder_diffFieldName = new NeuralQueryBuilder().fieldName(fieldName2) + .queryText(queryText1) + .modelId(modelId1) + .k(k1) + .boost(boost1) + .queryName(queryName1) + .filter(filter1); + + // Identical to neuralQueryBuilder_baseline except diff query text + NeuralQueryBuilder neuralQueryBuilder_diffQueryText = new NeuralQueryBuilder().fieldName(fieldName1) + .queryText(queryText2) + .modelId(modelId1) + .k(k1) + .boost(boost1) + .queryName(queryName1) + .filter(filter1); + + // Identical to neuralQueryBuilder_baseline except diff model ID + NeuralQueryBuilder neuralQueryBuilder_diffModelId = new NeuralQueryBuilder().fieldName(fieldName1) + .queryText(queryText1) + .modelId(modelId2) + .k(k1) + .boost(boost1) + .queryName(queryName1) + .filter(filter1); + + // Identical to neuralQueryBuilder_baseline except diff k + NeuralQueryBuilder 
neuralQueryBuilder_diffK = new NeuralQueryBuilder().fieldName(fieldName1) + .queryText(queryText1) + .modelId(modelId1) + .k(k2) + .boost(boost1) + .queryName(queryName1) + .filter(filter1); + + // Identical to neuralQueryBuilder_baseline except diff boost + NeuralQueryBuilder neuralQueryBuilder_diffBoost = new NeuralQueryBuilder().fieldName(fieldName1) + .queryText(queryText1) + .modelId(modelId1) + .k(k1) + .boost(boost2) + .queryName(queryName1) + .filter(filter1); + + // Identical to neuralQueryBuilder_baseline except diff query name + NeuralQueryBuilder neuralQueryBuilder_diffQueryName = new NeuralQueryBuilder().fieldName(fieldName1) + .queryText(queryText1) + .modelId(modelId1) + .k(k1) + .boost(boost1) + .queryName(queryName2) + .filter(filter1); + + // Identical to neuralQueryBuilder_baseline except no filter + NeuralQueryBuilder neuralQueryBuilder_noFilter = new NeuralQueryBuilder().fieldName(fieldName1) + .queryText(queryText1) + .modelId(modelId1) + .k(k1) + .boost(boost1) + .queryName(queryName2); + + // Identical to neuralQueryBuilder_baseline except no filter + NeuralQueryBuilder neuralQueryBuilder_diffFilter = new NeuralQueryBuilder().fieldName(fieldName1) + .queryText(queryText1) + .modelId(modelId1) + .k(k1) + .boost(boost1) + .queryName(queryName2) + .filter(filter2); + + assertEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_baseline); + assertEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_baseline.hashCode()); + + assertEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_baselineCopy); + assertEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_baselineCopy.hashCode()); + + assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_defaultBoostAndQueryName); + assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_defaultBoostAndQueryName.hashCode()); + + assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_diffFieldName); + 
assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_diffFieldName.hashCode()); + + assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_diffQueryText); + assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_diffQueryText.hashCode()); + + assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_diffModelId); + assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_diffModelId.hashCode()); + + assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_diffK); + assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_diffK.hashCode()); + + assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_diffBoost); + assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_diffBoost.hashCode()); + + assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_diffQueryName); + assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_diffQueryName.hashCode()); + + assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_noFilter); + assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_noFilter.hashCode()); + + assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_diffFilter); + assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_diffFilter.hashCode()); + } + + @SneakyThrows + public void testRewrite_whenVectorSupplierNull_thenSetVectorSupplier() { + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME).queryText(QUERY_TEXT).modelId(MODEL_ID).k(K); + List expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f); + MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onResponse(expectedVector); + return null; + }).when(mlCommonsClientAccessor).inferenceSentence(any(), any(), any()); + 
NeuralQueryBuilder.initialize(mlCommonsClientAccessor); + + final CountDownLatch inProgressLatch = new CountDownLatch(1); + QueryRewriteContext queryRewriteContext = mock(QueryRewriteContext.class); + doAnswer(invocation -> { + BiConsumer> biConsumer = invocation.getArgument(0); + biConsumer.accept( + null, + ActionListener.wrap( + response -> inProgressLatch.countDown(), + err -> fail("Failed to set vector supplier: " + err.getMessage()) + ) + ); + return null; + }).when(queryRewriteContext).registerAsyncAction(any()); + + NeuralQueryBuilder queryBuilder = (NeuralQueryBuilder) neuralQueryBuilder.doRewrite(queryRewriteContext); + assertNotNull(queryBuilder.vectorSupplier()); + assertTrue(inProgressLatch.await(5, TimeUnit.SECONDS)); + assertArrayEquals(VectorUtil.vectorAsListToArray(expectedVector), queryBuilder.vectorSupplier().get(), 0.0f); + } + + public void testRewrite_whenVectorNull_thenReturnCopy() { + Supplier nullSupplier = () -> null; + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID) + .k(K) + .vectorSupplier(nullSupplier); + QueryBuilder queryBuilder = neuralQueryBuilder.doRewrite(null); + assertEquals(neuralQueryBuilder, queryBuilder); + } + + public void testRewrite_whenVectorSupplierAndVectorSet_thenReturnKNNQueryBuilder() { + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID) + .k(K) + .vectorSupplier(TEST_VECTOR_SUPPLIER); + QueryBuilder queryBuilder = neuralQueryBuilder.doRewrite(null); + assertTrue(queryBuilder instanceof KNNQueryBuilder); + KNNQueryBuilder knnQueryBuilder = (KNNQueryBuilder) queryBuilder; + assertEquals(neuralQueryBuilder.fieldName(), knnQueryBuilder.fieldName()); + assertEquals(neuralQueryBuilder.k(), knnQueryBuilder.getK()); + assertArrayEquals(TEST_VECTOR_SUPPLIER.get(), (float[]) knnQueryBuilder.vector(), 0.0f); + } + + public void 
testRewrite_whenFilterSet_thenKNNQueryBuilderFilterSet() { + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID) + .k(K) + .vectorSupplier(TEST_VECTOR_SUPPLIER) + .filter(TEST_FILTER); + QueryBuilder queryBuilder = neuralQueryBuilder.doRewrite(null); + assertTrue(queryBuilder instanceof KNNQueryBuilder); + KNNQueryBuilder knnQueryBuilder = (KNNQueryBuilder) queryBuilder; + assertEquals(neuralQueryBuilder.filter(), knnQueryBuilder.getFilter()); + } +} From 4771cd1a4d2c27fb379e5077a9e5ae82cba30b87 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Mon, 18 Sep 2023 15:29:16 +0800 Subject: [PATCH 35/70] sparse encoding query builder ut Signed-off-by: zhichao-aws --- .../sparse/SparseEncodingQueryBuilder.java | 1 - .../SparseEncodingQueryBuilderTests.java | 324 ++++++++---------- 2 files changed, 149 insertions(+), 176 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilder.java index 7957116e1..e471381c5 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilder.java @@ -313,7 +313,6 @@ private static void validateQueryTokens(Map queryTokens) { @Override protected boolean doEquals(SparseEncodingQueryBuilder obj) { - // todo: validate if (this == obj) return true; if (obj == null || getClass() != obj.getClass()) return false; EqualsBuilder equalsBuilder = new EqualsBuilder() diff --git a/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilderTests.java index 9195e5800..ade08248e 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilderTests.java +++ 
b/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilderTests.java @@ -8,10 +8,8 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; -import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; import static org.opensearch.index.query.AbstractQueryBuilder.BOOST_FIELD; import static org.opensearch.index.query.AbstractQueryBuilder.NAME_FIELD; -import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD; import static org.opensearch.neuralsearch.TestUtils.xContentBuilderToMap; import static org.opensearch.neuralsearch.query.sparse.SparseEncodingQueryBuilder.MODEL_ID_FIELD; import static org.opensearch.neuralsearch.query.sparse.SparseEncodingQueryBuilder.NAME; @@ -23,28 +21,21 @@ import org.opensearch.client.Client; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.core.ParseField; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.ParsingException; -import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.io.stream.FilterStreamInput; import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; -import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.query.MatchAllQueryBuilder; -import org.opensearch.index.query.MatchNoneQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryRewriteContext; -import org.opensearch.knn.index.query.KNNQueryBuilder; -import org.opensearch.neuralsearch.common.VectorUtil; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import 
org.opensearch.test.OpenSearchTestCase; import java.io.IOException; -import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.concurrent.CountDownLatch; @@ -123,6 +114,7 @@ public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() { "VECTOR_FIELD": { "query_text": "string", "model_id": "string", + "token_score_upper_bound":123.0, "boost": 10.0, "_name": "something", } @@ -181,15 +173,20 @@ public void testFromXContent_whenBuildWithMultipleRootFields_thenFail() { } @SneakyThrows - public void testFromXContent_whenBuildWithMissingParameters_thenFail() { + public void testFromXContent_whenBuildWithMissingQuery_thenFail() { /* { "VECTOR_FIELD": { - + "model_id": "string" } } */ - XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject().startObject(FIELD_NAME).endObject().endObject(); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .endObject() + .endObject(); XContentParser contentParser = createParser(xContentBuilder); contentParser.nextToken(); @@ -201,28 +198,7 @@ public void testFromXContent_whenBuildWithMissingModelId_thenFail() { /* { "VECTOR_FIELD": { - - } - } - */ - XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject().startObject(FIELD_NAME).endObject().endObject(); - - XContentParser contentParser = createParser(xContentBuilder); - contentParser.nextToken(); - expectThrows(IllegalArgumentException.class, () -> SparseEncodingQueryBuilder.fromXContent(contentParser)); - } - - @SneakyThrows - public void testFromXContent_whenBuildWithDuplicateParameters_thenFail() { - /* - { - "VECTOR_FIELD": { - "query_text": "string", - "query_text": "string", - "model_id": "string", - "model_id": "string", - "k": int, - "k": int + "query_text": "string" } } */ @@ -230,29 +206,23 @@ public void testFromXContent_whenBuildWithDuplicateParameters_thenFail() { 
.startObject() .startObject(FIELD_NAME) .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) - .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) - .field(K_FIELD.getPreferredName(), K) - .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) - .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) - .field(K_FIELD.getPreferredName(), K) .endObject() .endObject(); XContentParser contentParser = createParser(xContentBuilder); contentParser.nextToken(); - expectThrows(IOException.class, () -> NeuralQueryBuilder.fromXContent(contentParser)); + expectThrows(IllegalArgumentException.class, () -> SparseEncodingQueryBuilder.fromXContent(contentParser)); } @SneakyThrows - public void testFromXContent_whenBuiltWithInvalidFilter_thenFail() { + public void testFromXContent_whenBuildWithDuplicateParameters_thenFail() { /* { "VECTOR_FIELD": { + "query_text": "string", "query_text": "string", "model_id": "string", - "k": int, - "boost": 10.0, - "filter": 12 + "model_id": "string" } } */ @@ -261,34 +231,32 @@ public void testFromXContent_whenBuiltWithInvalidFilter_thenFail() { .startObject(FIELD_NAME) .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) - .field(K_FIELD.getPreferredName(), K) - .field(BOOST_FIELD.getPreferredName(), BOOST) - .field(NAME_FIELD.getPreferredName(), QUERY_NAME) - .field(FILTER_FIELD.getPreferredName(), 12) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) .endObject() .endObject(); XContentParser contentParser = createParser(xContentBuilder); contentParser.nextToken(); - expectThrows(ParsingException.class, () -> NeuralQueryBuilder.fromXContent(contentParser)); + expectThrows(IOException.class, () -> SparseEncodingQueryBuilder.fromXContent(contentParser)); } @SuppressWarnings("unchecked") @SneakyThrows public void testToXContent() { - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME) + 
SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(FIELD_NAME) .modelId(MODEL_ID) .queryText(QUERY_TEXT) - .k(K) - .filter(TEST_FILTER); + .queryTokens(QUERY_TOKENS) + .tokenScoreUpperBound(TOKEN_SCORE_UPPER_BOUND); XContentBuilder builder = XContentFactory.jsonBuilder(); - builder = neuralQueryBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS); + builder = sparseEncodingQueryBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS); Map out = xContentBuilderToMap(builder); Object outer = out.get(NAME); if (!(outer instanceof Map)) { - fail("neural does not map to nested object"); + fail("sparse encoding does not map to nested object"); } Map outerMap = (Map) outer; @@ -305,24 +273,29 @@ public void testToXContent() { assertEquals(MODEL_ID, secondInnerMap.get(MODEL_ID_FIELD.getPreferredName())); assertEquals(QUERY_TEXT, secondInnerMap.get(QUERY_TEXT_FIELD.getPreferredName())); - assertEquals(K, secondInnerMap.get(K_FIELD.getPreferredName())); - XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); + // QUERY_TOKENS is map, the converted one use + Map convertedQueryTokensMap = (Map) secondInnerMap.get(QUERY_TOKENS_FIELD.getPreferredName()); + assertEquals(QUERY_TOKENS.size(), convertedQueryTokensMap.size()); + for (Map.Entry entry: QUERY_TOKENS.entrySet()) { + assertEquals(entry.getValue(), convertedQueryTokensMap.get(entry.getKey()).floatValue(), 0); + } assertEquals( - xContentBuilderToMap(TEST_FILTER.toXContent(xContentBuilder, EMPTY_PARAMS)), - secondInnerMap.get(FILTER_FIELD.getPreferredName()) + TOKEN_SCORE_UPPER_BOUND, + ((Double) secondInnerMap.get(TOKEN_SCORE_UPPER_BOUND_FIELD.getPreferredName())).floatValue(), + 0 ); } @SneakyThrows public void testStreams() { - NeuralQueryBuilder original = new NeuralQueryBuilder(); + SparseEncodingQueryBuilder original = new SparseEncodingQueryBuilder(); original.fieldName(FIELD_NAME); original.queryText(QUERY_TEXT); original.modelId(MODEL_ID); - original.k(K); + 
original.queryTokens(QUERY_TOKENS); + original.tokenScoreUpperBound(TOKEN_SCORE_UPPER_BOUND); original.boost(BOOST); original.queryName(QUERY_NAME); - original.filter(TEST_FILTER); BytesStreamOutput streamOutput = new BytesStreamOutput(); original.writeTo(streamOutput); @@ -334,7 +307,7 @@ public void testStreams() { ) ); - NeuralQueryBuilder copy = new NeuralQueryBuilder(filterStreamInput); + SparseEncodingQueryBuilder copy = new SparseEncodingQueryBuilder(filterStreamInput); assertEquals(original, copy); } @@ -349,152 +322,164 @@ public void testHashAndEquals() { float boost2 = 3.8f; String queryName1 = "query-1"; String queryName2 = "query-2"; - int k1 = 1; - int k2 = 2; - - QueryBuilder filter1 = new MatchAllQueryBuilder(); - QueryBuilder filter2 = new MatchNoneQueryBuilder(); + Map queryTokens1 = Map.of("hello", 1f); + Map queryTokens2 = Map.of("hello", 2f); + float tokenScoreUpperBound1 = 1f; + float tokenScoreUpperBound2 = 2f; - NeuralQueryBuilder neuralQueryBuilder_baseline = new NeuralQueryBuilder().fieldName(fieldName1) + SparseEncodingQueryBuilder sparseEncodingQueryBuilder_baseline = new SparseEncodingQueryBuilder() + .fieldName(fieldName1) .queryText(queryText1) + .queryTokens(queryTokens1) .modelId(modelId1) - .k(k1) .boost(boost1) .queryName(queryName1) - .filter(filter1); + .tokenScoreUpperBound(tokenScoreUpperBound1); - // Identical to neuralQueryBuilder_baseline - NeuralQueryBuilder neuralQueryBuilder_baselineCopy = new NeuralQueryBuilder().fieldName(fieldName1) + // Identical to sparseEncodingQueryBuilder_baseline + SparseEncodingQueryBuilder sparseEncodingQueryBuilder_baselineCopy = new SparseEncodingQueryBuilder() + .fieldName(fieldName1) .queryText(queryText1) + .queryTokens(queryTokens1) .modelId(modelId1) - .k(k1) .boost(boost1) .queryName(queryName1) - .filter(filter1); + .tokenScoreUpperBound(tokenScoreUpperBound1); - // Identical to neuralQueryBuilder_baseline except default boost and query name - NeuralQueryBuilder 
neuralQueryBuilder_defaultBoostAndQueryName = new NeuralQueryBuilder().fieldName(fieldName1) + // Identical to sparseEncodingQueryBuilder_baseline except default boost and query name + SparseEncodingQueryBuilder sparseEncodingQueryBuilder_defaultBoostAndQueryName = new SparseEncodingQueryBuilder() + .fieldName(fieldName1) .queryText(queryText1) + .queryTokens(queryTokens1) .modelId(modelId1) - .k(k1) - .filter(filter1); + .tokenScoreUpperBound(tokenScoreUpperBound1); - // Identical to neuralQueryBuilder_baseline except diff field name - NeuralQueryBuilder neuralQueryBuilder_diffFieldName = new NeuralQueryBuilder().fieldName(fieldName2) + // Identical to sparseEncodingQueryBuilder_baseline except diff field name + SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffFieldName = new SparseEncodingQueryBuilder() + .fieldName(fieldName2) .queryText(queryText1) + .queryTokens(queryTokens1) .modelId(modelId1) - .k(k1) .boost(boost1) .queryName(queryName1) - .filter(filter1); + .tokenScoreUpperBound(tokenScoreUpperBound1); - // Identical to neuralQueryBuilder_baseline except diff query text - NeuralQueryBuilder neuralQueryBuilder_diffQueryText = new NeuralQueryBuilder().fieldName(fieldName1) + // Identical to sparseEncodingQueryBuilder_baseline except diff query text + SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffQueryText = new SparseEncodingQueryBuilder() + .fieldName(fieldName1) .queryText(queryText2) + .queryTokens(queryTokens1) .modelId(modelId1) - .k(k1) .boost(boost1) .queryName(queryName1) - .filter(filter1); + .tokenScoreUpperBound(tokenScoreUpperBound1); - // Identical to neuralQueryBuilder_baseline except diff model ID - NeuralQueryBuilder neuralQueryBuilder_diffModelId = new NeuralQueryBuilder().fieldName(fieldName1) + // Identical to sparseEncodingQueryBuilder_baseline except diff model ID + SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffModelId = new SparseEncodingQueryBuilder() + .fieldName(fieldName1) .queryText(queryText1) + 
.queryTokens(queryTokens1) .modelId(modelId2) - .k(k1) .boost(boost1) .queryName(queryName1) - .filter(filter1); + .tokenScoreUpperBound(tokenScoreUpperBound1); - // Identical to neuralQueryBuilder_baseline except diff k - NeuralQueryBuilder neuralQueryBuilder_diffK = new NeuralQueryBuilder().fieldName(fieldName1) + // Identical to sparseEncodingQueryBuilder_baseline except diff query tokens + SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffQueryTokens = new SparseEncodingQueryBuilder() + .fieldName(fieldName1) .queryText(queryText1) + .queryTokens(queryTokens2) .modelId(modelId1) - .k(k2) .boost(boost1) .queryName(queryName1) - .filter(filter1); + .tokenScoreUpperBound(tokenScoreUpperBound1); - // Identical to neuralQueryBuilder_baseline except diff boost - NeuralQueryBuilder neuralQueryBuilder_diffBoost = new NeuralQueryBuilder().fieldName(fieldName1) + // Identical to sparseEncodingQueryBuilder_baseline except diff boost + SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffBoost = new SparseEncodingQueryBuilder() + .fieldName(fieldName1) .queryText(queryText1) + .queryTokens(queryTokens1) .modelId(modelId1) - .k(k1) .boost(boost2) .queryName(queryName1) - .filter(filter1); + .tokenScoreUpperBound(tokenScoreUpperBound1); - // Identical to neuralQueryBuilder_baseline except diff query name - NeuralQueryBuilder neuralQueryBuilder_diffQueryName = new NeuralQueryBuilder().fieldName(fieldName1) + // Identical to sparseEncodingQueryBuilder_baseline except diff query name + SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffQueryName = new SparseEncodingQueryBuilder() + .fieldName(fieldName1) .queryText(queryText1) + .queryTokens(queryTokens1) .modelId(modelId1) - .k(k1) .boost(boost1) .queryName(queryName2) - .filter(filter1); + .tokenScoreUpperBound(tokenScoreUpperBound1); - // Identical to neuralQueryBuilder_baseline except no filter - NeuralQueryBuilder neuralQueryBuilder_noFilter = new NeuralQueryBuilder().fieldName(fieldName1) + // 
Identical to sparseEncodingQueryBuilder_baseline except diff token_score_upper_bound + SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffTokenScoreUpperBound = new SparseEncodingQueryBuilder() + .fieldName(fieldName1) .queryText(queryText1) + .queryTokens(queryTokens1) .modelId(modelId1) - .k(k1) .boost(boost1) - .queryName(queryName2); - - // Identical to neuralQueryBuilder_baseline except no filter - NeuralQueryBuilder neuralQueryBuilder_diffFilter = new NeuralQueryBuilder().fieldName(fieldName1) - .queryText(queryText1) - .modelId(modelId1) - .k(k1) - .boost(boost1) - .queryName(queryName2) - .filter(filter2); + .queryName(queryName1) + .tokenScoreUpperBound(tokenScoreUpperBound2); - assertEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_baseline); - assertEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_baseline.hashCode()); + assertEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_baseline); + assertEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_baseline.hashCode()); - assertEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_baselineCopy); - assertEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_baselineCopy.hashCode()); + assertEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_baselineCopy); + assertEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_baselineCopy.hashCode()); - assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_defaultBoostAndQueryName); - assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_defaultBoostAndQueryName.hashCode()); + assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_defaultBoostAndQueryName); + assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_defaultBoostAndQueryName.hashCode()); - assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_diffFieldName); - 
assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_diffFieldName.hashCode()); + assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffFieldName); + assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffFieldName.hashCode()); - assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_diffQueryText); - assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_diffQueryText.hashCode()); + assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffQueryText); + assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffQueryText.hashCode()); - assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_diffModelId); - assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_diffModelId.hashCode()); + assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffModelId); + assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffModelId.hashCode()); - assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_diffK); - assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_diffK.hashCode()); + assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffQueryTokens); + assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffQueryTokens.hashCode()); - assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_diffBoost); - assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_diffBoost.hashCode()); + assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffBoost); + assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffBoost.hashCode()); - assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_diffQueryName); - 
assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_diffQueryName.hashCode()); + assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffQueryName); + assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffQueryName.hashCode()); - assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_noFilter); - assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_noFilter.hashCode()); + assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffTokenScoreUpperBound); + assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffTokenScoreUpperBound.hashCode()); + } - assertNotEquals(neuralQueryBuilder_baseline, neuralQueryBuilder_diffFilter); - assertNotEquals(neuralQueryBuilder_baseline.hashCode(), neuralQueryBuilder_diffFilter.hashCode()); + @SneakyThrows + public void testRewrite_whenQueryTokensNotNull_thenRewriteToSelf() { + SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder() + .queryTokens(QUERY_TOKENS) + .fieldName(FIELD_NAME) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID); + QueryBuilder queryBuilder = sparseEncodingQueryBuilder.doRewrite(null); + assert queryBuilder == sparseEncodingQueryBuilder; } @SneakyThrows - public void testRewrite_whenVectorSupplierNull_thenSetVectorSupplier() { - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME).queryText(QUERY_TEXT).modelId(MODEL_ID).k(K); - List expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f); + public void testRewrite_whenQueryTokensSupplierNull_thenSetQueryTokensSupplier() { + SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder() + .fieldName(FIELD_NAME) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID); + Map expectedMap = Map.of("1", 1f, "2", 2f); MLCommonsClientAccessor mlCommonsClientAccessor = 
mock(MLCommonsClientAccessor.class); doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(2); - listener.onResponse(expectedVector); + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(List.of(Map.of("response", List.of(expectedMap)))); return null; - }).when(mlCommonsClientAccessor).inferenceSentence(any(), any(), any()); - NeuralQueryBuilder.initialize(mlCommonsClientAccessor); + }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(any(), any(), any()); + SparseEncodingQueryBuilder.initialize(mlCommonsClientAccessor); final CountDownLatch inProgressLatch = new CountDownLatch(1); QueryRewriteContext queryRewriteContext = mock(QueryRewriteContext.class); @@ -504,53 +489,42 @@ public void testRewrite_whenVectorSupplierNull_thenSetVectorSupplier() { null, ActionListener.wrap( response -> inProgressLatch.countDown(), - err -> fail("Failed to set vector supplier: " + err.getMessage()) + err -> fail("Failed to set query tokens supplier: " + err.getMessage()) ) ); return null; }).when(queryRewriteContext).registerAsyncAction(any()); - NeuralQueryBuilder queryBuilder = (NeuralQueryBuilder) neuralQueryBuilder.doRewrite(queryRewriteContext); - assertNotNull(queryBuilder.vectorSupplier()); + SparseEncodingQueryBuilder queryBuilder = (SparseEncodingQueryBuilder) sparseEncodingQueryBuilder.doRewrite(queryRewriteContext); + assertNotNull(queryBuilder.queryTokensSupplier()); assertTrue(inProgressLatch.await(5, TimeUnit.SECONDS)); - assertArrayEquals(VectorUtil.vectorAsListToArray(expectedVector), queryBuilder.vectorSupplier().get(), 0.0f); + assertEquals(expectedMap, queryBuilder.queryTokensSupplier().get()); } + @SneakyThrows public void testRewrite_whenVectorNull_thenReturnCopy() { - Supplier nullSupplier = () -> null; - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME) + Supplier> nullSupplier = () -> null; + SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new 
SparseEncodingQueryBuilder().fieldName(FIELD_NAME) .queryText(QUERY_TEXT) .modelId(MODEL_ID) - .k(K) - .vectorSupplier(nullSupplier); - QueryBuilder queryBuilder = neuralQueryBuilder.doRewrite(null); - assertEquals(neuralQueryBuilder, queryBuilder); + .queryTokensSupplier(nullSupplier); + QueryBuilder queryBuilder = sparseEncodingQueryBuilder.doRewrite(null); + assertEquals(sparseEncodingQueryBuilder, queryBuilder); } - public void testRewrite_whenVectorSupplierAndVectorSet_thenReturnKNNQueryBuilder() { - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME) + @SneakyThrows + public void testRewrite_whenQueryTokensSupplierSet_thenSetQueryTokens() { + SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder() + .fieldName(FIELD_NAME) .queryText(QUERY_TEXT) .modelId(MODEL_ID) - .k(K) - .vectorSupplier(TEST_VECTOR_SUPPLIER); - QueryBuilder queryBuilder = neuralQueryBuilder.doRewrite(null); - assertTrue(queryBuilder instanceof KNNQueryBuilder); - KNNQueryBuilder knnQueryBuilder = (KNNQueryBuilder) queryBuilder; - assertEquals(neuralQueryBuilder.fieldName(), knnQueryBuilder.fieldName()); - assertEquals(neuralQueryBuilder.k(), knnQueryBuilder.getK()); - assertArrayEquals(TEST_VECTOR_SUPPLIER.get(), (float[]) knnQueryBuilder.vector(), 0.0f); - } - - public void testRewrite_whenFilterSet_thenKNNQueryBuilderFilterSet() { - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME) + .queryTokensSupplier(QUERY_TOKENS_SUPPLIER); + QueryBuilder queryBuilder = sparseEncodingQueryBuilder.doRewrite(null); + SparseEncodingQueryBuilder targetQueryBuilder = new SparseEncodingQueryBuilder() + .fieldName(FIELD_NAME) .queryText(QUERY_TEXT) .modelId(MODEL_ID) - .k(K) - .vectorSupplier(TEST_VECTOR_SUPPLIER) - .filter(TEST_FILTER); - QueryBuilder queryBuilder = neuralQueryBuilder.doRewrite(null); - assertTrue(queryBuilder instanceof KNNQueryBuilder); - KNNQueryBuilder knnQueryBuilder = 
(KNNQueryBuilder) queryBuilder; - assertEquals(neuralQueryBuilder.filter(), knnQueryBuilder.getFilter()); + .queryTokens(QUERY_TOKENS_SUPPLIER.get()); + assertEquals(queryBuilder, targetQueryBuilder); } } From 5d127586954d743d6312fc00bc7e16ef320a8666 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Tue, 19 Sep 2023 10:11:55 +0800 Subject: [PATCH 36/70] rename Signed-off-by: zhichao-aws --- .../query/sparse/SparseEncodingQueryBuilderTests.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilderTests.java index ade08248e..5f7ecb2d0 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilderTests.java @@ -502,7 +502,7 @@ public void testRewrite_whenQueryTokensSupplierNull_thenSetQueryTokensSupplier() } @SneakyThrows - public void testRewrite_whenVectorNull_thenReturnCopy() { + public void testRewrite_whenSupplierContentNull_thenReturnCopy() { Supplier> nullSupplier = () -> null; SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(FIELD_NAME) .queryText(QUERY_TEXT) From 51a9ef3c00565ccdb812510dd759d84504e9c3f2 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Tue, 19 Sep 2023 13:28:42 +0800 Subject: [PATCH 37/70] UT for utils Signed-off-by: zhichao-aws --- .../neuralsearch/util/TokenWeightUtil.java | 4 +- .../util/TokenWeightUtilTests.java | 120 ++++++++++++++++++ 2 files changed, 122 insertions(+), 2 deletions(-) create mode 100644 src/test/java/org/opensearch/neuralsearch/util/TokenWeightUtilTests.java diff --git a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java index 4e13c6e8f..bd118b480 100644 --- 
a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java @@ -17,14 +17,14 @@ public class TokenWeightUtil { /** * possible input data format * case remote inference: - * { + * [{ * "response":{ * [ * { TOKEN_WEIGHT_MAP}, * { TOKEN_WEIGHT_MAP} * ] * } - * } + * }] * case local deploy: * [{"response":{ * [ diff --git a/src/test/java/org/opensearch/neuralsearch/util/TokenWeightUtilTests.java b/src/test/java/org/opensearch/neuralsearch/util/TokenWeightUtilTests.java new file mode 100644 index 000000000..6fa2a9d22 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/util/TokenWeightUtilTests.java @@ -0,0 +1,120 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.util; + +import org.opensearch.test.OpenSearchTestCase; + +import java.util.List; +import java.util.Map; + +public class TokenWeightUtilTests extends OpenSearchTestCase { + private static final Map MOCK_DATA = Map.of("hello", 1.f, "world", 2.f); + + public void testFetchListOfTokenWeightMap_singleObject() { + /* + [{ + "response": [ + {"hello": 1.0, "world": 2.0} + ] + }] + */ + List> inputData = List.of(Map.of("response", List.of(MOCK_DATA))); + assertEquals( + TokenWeightUtil.fetchListOfTokenWeightMap(inputData), + List.of(MOCK_DATA) + ); + } + + public void testFetchListOfTokenWeightMap_multipleObjectsInOneResponse() { + /* + [{ + "response": [ + {"hello": 1.0, "world": 2.0}, + {"hello": 1.0, "world": 2.0} + ] + }] + */ + List> inputData = List.of(Map.of("response", List.of(MOCK_DATA, MOCK_DATA))); + assertEquals( + TokenWeightUtil.fetchListOfTokenWeightMap(inputData), + List.of(MOCK_DATA, MOCK_DATA) + ); + } + + public void testFetchListOfTokenWeightMap_multipleObjectsInMultipleResponse() { + /* + [{ + "response": [ + {"hello": 1.0, "world": 2.0} + ] + },{ + "response": [ + {"hello": 1.0, "world": 2.0} + ] + }] + */ + List> inputData = 
List.of( + Map.of("response", List.of(MOCK_DATA)), + Map.of("response", List.of(MOCK_DATA)) + ); + assertEquals( + TokenWeightUtil.fetchListOfTokenWeightMap(inputData), + List.of(MOCK_DATA, MOCK_DATA) + ); + } + + public void testFetchListOfTokenWeightMap_whenResponseValueNotList_thenFail() { + /* + [{ + "response": {"hello": 1.0, "world": 2.0} + }] + */ + List> inputData = List.of(Map.of("response", MOCK_DATA)); + expectThrows(IllegalArgumentException.class, () -> TokenWeightUtil.fetchListOfTokenWeightMap(inputData)); + } + + public void testFetchListOfTokenWeightMap_whenNotUseResponseKey_thenFail() { + /* + [{ + "some_key": [{"hello": 1.0, "world": 2.0}] + }] + */ + List> inputData = List.of(Map.of("some_key", List.of(MOCK_DATA))); + expectThrows(IllegalArgumentException.class, () -> TokenWeightUtil.fetchListOfTokenWeightMap(inputData)); + } + + public void testFetchListOfTokenWeightMap_whenInputObjectIsNotMap_thenFail() { + /* + [{ + "response": [[{"hello": 1.0, "world": 2.0}]] + }] + */ + List> inputData = List.of(Map.of("response", List.of(List.of(MOCK_DATA)))); + expectThrows(IllegalArgumentException.class, () -> TokenWeightUtil.fetchListOfTokenWeightMap(inputData)); + } + + public void testFetchListOfTokenWeightMap_whenInputTokenMapWithNonStringKeys_thenFail() { + /* + [{ + "response": [[{"hello": 1.0, 2.3: 2.0}]] + }] + */ + Map MOCK_DATA = Map.of("hello", 1.f, 2.3f, 2.f); + List> inputData = List.of(Map.of("response", List.of(MOCK_DATA))); + expectThrows(IllegalArgumentException.class, () -> TokenWeightUtil.fetchListOfTokenWeightMap(inputData)); + } + + public void testFetchListOfTokenWeightMap_whenInputTokenMapWithNonFloatValues_thenFail() { + /* + [{ + "response": [[{"hello": 1.0, 2.3: 2.0}]] + }] + */ + Map MOCK_DATA = Map.of("hello", 1.f, "world", "world"); + List> inputData = List.of(Map.of("response", List.of(MOCK_DATA))); + expectThrows(IllegalArgumentException.class, () -> TokenWeightUtil.fetchListOfTokenWeightMap(inputData)); + } +} From 
77eb30077d8193e66830e79467376a5f2ec142dc Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Tue, 19 Sep 2023 13:41:40 +0800 Subject: [PATCH 38/70] enrich sparse encoding IT mappings Signed-off-by: zhichao-aws --- .../processor/SparseEncodingProcessIT.java | 13 +++++++++++-- .../resources/processor/SparseIndexMappings.json | 14 ++++++++++---- .../processor/SparsePipelineConfiguration.json | 7 ++++++- 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java index ea0dfaaff..f2ae6940c 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java @@ -67,7 +67,17 @@ private void createTextEmbeddingIndex() throws Exception { private void ingestDocument() throws Exception { String ingestDocument = "{\n" - + "\"passage_text\": \"This is a good day\"" + + " \"title\": \"This is a good day\",\n" + + " \"description\": \"daily logging\",\n" + + " \"favor_list\": [\n" + + " \"test\",\n" + + " \"hello\",\n" + + " \"mock\"\n" + + " ],\n" + + " \"favorites\": {\n" + + " \"game\": \"overwatch\",\n" + + " \"movie\": null\n" + + " }\n" + "}\n"; Response response = makeRequest( client(), @@ -85,5 +95,4 @@ private void ingestDocument() throws Exception { assertEquals("created", map.get("result")); } - } diff --git a/src/test/resources/processor/SparseIndexMappings.json b/src/test/resources/processor/SparseIndexMappings.json index d06cfd600..87dee278e 100644 --- a/src/test/resources/processor/SparseIndexMappings.json +++ b/src/test/resources/processor/SparseIndexMappings.json @@ -4,16 +4,22 @@ }, "mappings": { "properties": { - "passage_text": { - "type": "text" + "title_sparse": { + "type": "rank_features" }, - "passage_sparse": { + "favor_list_sparse": { "type": "nested", "properties":{ - 
"sparseEncoding":{ + "sparse_encoding":{ "type": "rank_features" } } + }, + "favorites.game_sparse": { + "type": "rank_features" + }, + "favorites.movie_sparse": { + "type": "rank_features" } } } diff --git a/src/test/resources/processor/SparsePipelineConfiguration.json b/src/test/resources/processor/SparsePipelineConfiguration.json index 297bbf80a..82d13c8fe 100644 --- a/src/test/resources/processor/SparsePipelineConfiguration.json +++ b/src/test/resources/processor/SparsePipelineConfiguration.json @@ -5,7 +5,12 @@ "sparse_encoding": { "model_id": "%s", "field_map": { - "passage_text": "passage_sparse" + "title": "title_sparse", + "favor_list": "favor_list_sparse", + "favorites": { + "game": "game_sparse", + "movie": "movie_sparse" + } } } } From 854e9c4f5b3c453c3d987d24556f60aca5d44b9c Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Tue, 19 Sep 2023 17:14:30 +0800 Subject: [PATCH 39/70] add it Signed-off-by: zhichao-aws --- build.gradle | 2 + .../opensearch/neuralsearch/TestUtils.java | 24 +++ .../common/BaseNeuralSearchIT.java | 4 +- .../common/BaseSparseEncodingIT.java | 131 ++++++++++++ .../processor/SparseEncodingProcessIT.java | 16 +- .../query/sparse/SparseEncodingQueryIT.java | 198 ++++++++++++++++++ .../UploadSparseModelRequestBody.json | 1 + 7 files changed, 363 insertions(+), 13 deletions(-) create mode 100644 src/test/java/org/opensearch/neuralsearch/common/BaseSparseEncodingIT.java create mode 100644 src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryIT.java diff --git a/build.gradle b/build.gradle index 1d8eca483..853aa85e7 100644 --- a/build.gradle +++ b/build.gradle @@ -151,6 +151,8 @@ dependencies { runtimeOnly group: 'org.reflections', name: 'reflections', version: '0.9.12' runtimeOnly group: 'org.javassist', name: 'javassist', version: '3.29.2-GA' runtimeOnly group: 'org.opensearch', name: 'common-utils', version: "${opensearch_build}" + runtimeOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1' + 
runtimeOnly group: 'org.json', name: 'json', version: '20230227' } // In order to add the jar to the classpath, we need to unzip the diff --git a/src/test/java/org/opensearch/neuralsearch/TestUtils.java b/src/test/java/org/opensearch/neuralsearch/TestUtils.java index 3b131b886..7f433d47c 100644 --- a/src/test/java/org/opensearch/neuralsearch/TestUtils.java +++ b/src/test/java/org/opensearch/neuralsearch/TestUtils.java @@ -11,8 +11,12 @@ import static org.junit.Assert.assertTrue; import static org.opensearch.test.OpenSearchTestCase.randomFloat; +import java.math.BigDecimal; +import java.math.RoundingMode; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -72,6 +76,26 @@ public static float[] createRandomVector(int dimension) { return vector; } + public static Float createFloatNumberWithEffectiveDigits(float inputNumber, int scale) { + BigDecimal bd = new BigDecimal(inputNumber); + return bd.setScale(scale, RoundingMode.HALF_UP).floatValue(); + } + + /** + * Create a map of provided tokens, the values will be random float numbers + * + * @param tokens of the created map keys + * @return token weight map with random weight > 0 + */ + public static Map createRandomTokenWeightMap(Collection tokens) { + Map resultMap = new HashMap<>(); + for (String token: tokens) { + // use a small shift to ensure value > 0 + resultMap.put(token, createFloatNumberWithEffectiveDigits(Math.abs(randomFloat()) + 1e-3f, 3)); + } + return resultMap; + } + /** * Assert results of hybrid query after score normalization and combination * @param querySearchResults collection of query search results after they processed by normalization processor diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java index 6aea1a9f2..589f5d0d5 100644 --- 
a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java @@ -54,12 +54,12 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase { - private static final Locale LOCALE = Locale.ROOT; + protected static final Locale LOCALE = Locale.ROOT; private static final int MAX_TASK_RESULT_QUERY_TIME_IN_SECOND = 60 * 5; private static final int DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND = 1000; - private static final String DEFAULT_USER_AGENT = "Kibana"; + protected static final String DEFAULT_USER_AGENT = "Kibana"; protected static final String DEFAULT_NORMALIZATION_METHOD = "min_max"; protected static final String DEFAULT_COMBINATION_METHOD = "arithmetic_mean"; protected static final String PARAM_NAME_WEIGHTS = "weights"; diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseSparseEncodingIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseSparseEncodingIT.java new file mode 100644 index 000000000..65fe42151 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseSparseEncodingIT.java @@ -0,0 +1,131 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.common; + +import com.google.common.collect.ImmutableList; +import lombok.SneakyThrows; +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.apache.hc.core5.http.message.BasicHeader; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.neuralsearch.util.TokenWeightUtil; + +import java.nio.file.Files; +import 
java.nio.file.Path; +import java.util.List; +import java.util.Map; + +public abstract class BaseSparseEncodingIT extends BaseNeuralSearchIT{ + + @SneakyThrows + @Override + protected String prepareModel() { + String requestBody = Files.readString( + Path.of(classLoader.getResource("processor/UploadSparseModelRequestBody.json").toURI()) + ); + String modelId = uploadModel(requestBody); + loadModel(modelId); + return modelId; + } + + @SneakyThrows + protected void prepareSparseEncodingIndex(String indexName, List sparseEncodingFieldNames) { + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject("mappings") + .startObject("properties"); + + for (String fieldName: sparseEncodingFieldNames) { + xContentBuilder.startObject(fieldName) + .field("type", "rank_features") + .endObject(); + } + + xContentBuilder.endObject().endObject().endObject(); + String indexMappings = xContentBuilder.toString(); + createIndexWithConfiguration(indexName, indexMappings, ""); + } + + @SneakyThrows + protected void addSparseEncodingDoc( + String index, + String docId, + List fieldNames, + List> docs + ) { + Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + for (int i = 0; i < fieldNames.size(); i++) { + builder.field(fieldNames.get(i), docs.get(i)); + } + + builder.endObject(); + + request.setJsonEntity(builder.toString()); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.CREATED, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + + protected float computeExpectedScore(String modelId, Map tokenWeightMap, String queryText) { + Map queryTokens = runSparseModelInference(modelId, queryText); + return computeExpectedScore(tokenWeightMap, queryTokens); + } + + protected float computeExpectedScore(Map tokenWeightMap, Map queryTokens) { + Float score = 0f; + 
for (Map.Entry entry: queryTokens.entrySet()) { + if (tokenWeightMap.containsKey(entry.getKey())) { + score += entry.getValue() * getFeatureFieldCompressedNumber(tokenWeightMap.get(entry.getKey())); + } + } + return score; + } + + @SneakyThrows + protected Map runSparseModelInference(String modelId, String queryText) { + Response inferenceResponse = makeRequest( + client(), + "POST", + String.format(LOCALE, "/_plugins/_ml/models/%s/_predict", modelId), + null, + toHttpEntity(String.format(LOCALE, "{\"text_docs\": [\"%s\"]}", queryText)), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + + Map inferenceResJson = XContentHelper.convertToMap( + XContentType.JSON.xContent(), + EntityUtils.toString(inferenceResponse.getEntity()), + false + ); + + Object inference_results = inferenceResJson.get("inference_results"); + assertTrue(inference_results instanceof List); + List inferenceResultsAsMap = (List) inference_results; + assertEquals(1, inferenceResultsAsMap.size()); + Map result = (Map) inferenceResultsAsMap.get(0); + List output = (List) result.get("output"); + assertEquals(1, output.size()); + Map map = (Map) output.get(0); + assertEquals(1, map.size()); + Map dataAsMap = (Map) map.get("dataAsMap"); + return TokenWeightUtil.fetchListOfTokenWeightMap(List.of(dataAsMap)).get(0); + } + + // rank_features use lucene FeatureField, which will compress the Float number to 16 bit + // this function simulate the encoding and decoding progress in lucene FeatureField + protected Float getFeatureFieldCompressedNumber(Float originNumber) { + int freqBits = Float.floatToIntBits(originNumber); + freqBits = freqBits >> 15; + freqBits = ((int) ((float) freqBits)) << 15; + return Float.intBitsToFloat(freqBits); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java index f2ae6940c..d5ab277f4 100644 --- 
a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java @@ -15,13 +15,13 @@ import org.opensearch.client.Response; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.XContentType; -import org.opensearch.neuralsearch.common.BaseNeuralSearchIT; +import org.opensearch.neuralsearch.common.BaseSparseEncodingIT; import java.nio.file.Files; import java.nio.file.Path; import java.util.Map; -public class SparseEncodingProcessIT extends BaseNeuralSearchIT { +public class SparseEncodingProcessIT extends BaseSparseEncodingIT { private static final String INDEX_NAME = "sparse_encoding_index"; @@ -44,20 +44,14 @@ public void setPipelineName() { } public void testSparseEncodingProcessor() throws Exception { - String modelId = uploadSparseEncodingModel(); - loadModel(modelId); + String modelId = prepareModel(); createPipelineProcessor(modelId, PIPELINE_NAME); - createTextEmbeddingIndex(); + createSparseEncodingIndex(); ingestDocument(); assertEquals(1, getDocCount(INDEX_NAME)); } - private String uploadSparseEncodingModel() throws Exception { - String requestBody = Files.readString(Path.of(classLoader.getResource("processor/UploadSparseModelRequestBody.json").toURI())); - return uploadModel(requestBody); - } - - private void createTextEmbeddingIndex() throws Exception { + private void createSparseEncodingIndex() throws Exception { createIndexWithConfiguration( INDEX_NAME, Files.readString(Path.of(classLoader.getResource("processor/SparseIndexMappings.json").toURI())), diff --git a/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryIT.java new file mode 100644 index 000000000..aa3862567 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryIT.java @@ -0,0 +1,198 @@ +/* + * Copyright 
OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.query.sparse; + +import lombok.SneakyThrows; +import org.junit.After; +import org.junit.Before; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.neuralsearch.TestUtils; +import org.opensearch.neuralsearch.common.BaseSparseEncodingIT; + +import java.util.List; +import java.util.Map; + +import static org.opensearch.neuralsearch.TestUtils.objectToFloat; + +public class SparseEncodingQueryIT extends BaseSparseEncodingIT { + private static final String TEST_BASIC_INDEX_NAME = "test-sparse-basic-index"; + private static final String TEST_MULTI_VECTOR_FIELD_INDEX_NAME = "test-sparse-multi-field-index"; + private static final String TEST_TEXT_AND_VECTOR_FIELD_INDEX_NAME = "test-sparse-text-and-field-index"; + private static final String TEST_NESTED_INDEX_NAME = "test-sparse-nested-index"; + private static final String TEST_MULTI_DOC_INDEX_NAME = "test-sparse-multi-doc-index"; + private static final String TEST_QUERY_TEXT = "Hello world a b"; + private static final String TEST_SPARSE_ENCODING_FIELD_NAME_1 = "test-sparse-encoding-1"; + private static final String TEST_SPARSE_ENCODING_FIELD_NAME_2 = "test-sparse-encoding-2"; + private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field"; + private static final String TEST_SPARSE_ENCODING_FIELD_NAME_NESTED = "nested.sparse_encoding.field"; + + private static final List TEST_TOKENS = List.of("hello", "world", "a", "b", "c"); + private final Map testTokenWeightMap = TestUtils.createRandomTokenWeightMap(TEST_TOKENS); + + @Before + public void setUp() throws Exception { + super.setUp(); + updateClusterSettings(); + prepareModel(); + } + + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + findDeployedModels().forEach(this::deleteModel); + } + + /** + * Tests basic query: + * { + * "query": { + * "sparse_encoding": { + * "text_sparse": { + * "query_text": "Hello 
world a b", + * "model_id": "dcsdcasd" + * } + * } + * } + * } + */ + @SneakyThrows + public void testBasicQueryUsingQueryText() { + initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); + String modelId = getDeployedModelId(); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder() + .fieldName(TEST_SPARSE_ENCODING_FIELD_NAME_1) + .queryText(TEST_QUERY_TEXT) + .modelId(modelId); + Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1); + Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); + + assertEquals("1", firstInnerHit.get("_id")); + float expectedScore = computeExpectedScore(modelId, testTokenWeightMap, TEST_QUERY_TEXT); + assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), 0.0); + } + + /** + * Tests basic query: + * { + * "query": { + * "sparse_encoding": { + * "text_sparse": { + * "query_tokens": { + * "hello": float, + * "a": float, + * "c": float + * } + * } + * } + * } + * } + */ + @SneakyThrows + public void testBasicQueryUsingQueryTokens() { + initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); + Map queryTokens = TestUtils.createRandomTokenWeightMap(List.of("hello","a","b")); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder() + .fieldName(TEST_SPARSE_ENCODING_FIELD_NAME_1) + .queryTokens(queryTokens); + Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1); + Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); + + assertEquals("1", firstInnerHit.get("_id")); + float expectedScore = computeExpectedScore(testTokenWeightMap, queryTokens); + assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), 0.0); + } + + /** + * Tests basic query: + * { + * "query": { + * "sparse_encoding": { + * "text_sparse": { + * "query_text": "Hello world a b", + * "model_id": "dcsdcasd", + * "boost": 2 + * } + * } + * } + * } + */ + @SneakyThrows + public void testBoostQuery() { + 
initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); + String modelId = getDeployedModelId(); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder() + .fieldName(TEST_SPARSE_ENCODING_FIELD_NAME_1) + .queryText(TEST_QUERY_TEXT) + .modelId(modelId) + .boost(2.0f); + Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1); + Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); + + assertEquals("1", firstInnerHit.get("_id")); + float expectedScore = 2 * computeExpectedScore(modelId, testTokenWeightMap, TEST_QUERY_TEXT); + assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), 0.0); + } + + /** + * Tests rescore query: + * { + * "query" : { + * "match_all": {} + * }, + * "rescore": { + * "query": { + * "rescore_query": { + * "sparse_encoding": { + * "text_sparse": { + * * "query_text": "Hello world a b", + * * "model_id": "dcsdcasd" + * * } + * } + * } + * } + * } + * } + */ + @SneakyThrows + public void testRescoreQuery() { + initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); + String modelId = getDeployedModelId(); + MatchAllQueryBuilder matchAllQueryBuilder = new MatchAllQueryBuilder(); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder() + .fieldName(TEST_SPARSE_ENCODING_FIELD_NAME_1) + .queryText(TEST_QUERY_TEXT) + .modelId(modelId); + Map searchResponseAsMap = search( + TEST_BASIC_INDEX_NAME, + matchAllQueryBuilder, + sparseEncodingQueryBuilder, + 1 + ); + Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); + + assertEquals("1", firstInnerHit.get("_id")); + float expectedScore = computeExpectedScore(modelId, testTokenWeightMap, TEST_QUERY_TEXT); + assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), 0.0); + } + + @SneakyThrows + protected void initializeIndexIfNotExist(String indexName) { + if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) { + prepareSparseEncodingIndex( + 
TEST_BASIC_INDEX_NAME, + List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1) + ); + addSparseEncodingDoc( + TEST_BASIC_INDEX_NAME, + "1", + List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1), + List.of(testTokenWeightMap) + ); + assertEquals(1, getDocCount(TEST_BASIC_INDEX_NAME)); + } + } +} diff --git a/src/test/resources/processor/UploadSparseModelRequestBody.json b/src/test/resources/processor/UploadSparseModelRequestBody.json index 2d48a4170..e630e6dca 100644 --- a/src/test/resources/processor/UploadSparseModelRequestBody.json +++ b/src/test/resources/processor/UploadSparseModelRequestBody.json @@ -4,6 +4,7 @@ "function_name": "TOKENIZE", "description": "test model", "model_format": "TORCH_SCRIPT", + "model_group_id": "", "model_content_hash_value": "e23969f8bd417e7aec26f49201da4adfc6b74e6187d1ddfdfb98e473bdd95978", "url": "https://github.com/xinyual/demo/raw/main/tokenizer-idf-msmarco.zip" } \ No newline at end of file From 03ff8b8916b7dc88cb299068afeb0358aee8e1f4 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Tue, 19 Sep 2023 09:45:30 +0000 Subject: [PATCH 40/70] add it Signed-off-by: zhichao-aws --- .../opensearch/neuralsearch/TestUtils.java | 2 + .../query/sparse/SparseEncodingQueryIT.java | 75 +++++++++++++++++-- 2 files changed, 71 insertions(+), 6 deletions(-) diff --git a/src/test/java/org/opensearch/neuralsearch/TestUtils.java b/src/test/java/org/opensearch/neuralsearch/TestUtils.java index 7f433d47c..f3e36a6f5 100644 --- a/src/test/java/org/opensearch/neuralsearch/TestUtils.java +++ b/src/test/java/org/opensearch/neuralsearch/TestUtils.java @@ -76,6 +76,8 @@ public static float[] createRandomVector(int dimension) { return vector; } + // When ingesting token weight map, float number will be decoded to json, which may lose precision + // To compute match score without losing precision, we limit the effective digits of float number public static Float createFloatNumberWithEffectiveDigits(float inputNumber, int scale) { BigDecimal bd = new 
BigDecimal(inputNumber); return bd.setScale(scale, RoundingMode.HALF_UP).floatValue(); diff --git a/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryIT.java index aa3862567..db52fb73f 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryIT.java @@ -8,6 +8,7 @@ import lombok.SneakyThrows; import org.junit.After; import org.junit.Before; +import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.neuralsearch.TestUtils; import org.opensearch.neuralsearch.common.BaseSparseEncodingIT; @@ -19,8 +20,8 @@ public class SparseEncodingQueryIT extends BaseSparseEncodingIT { private static final String TEST_BASIC_INDEX_NAME = "test-sparse-basic-index"; - private static final String TEST_MULTI_VECTOR_FIELD_INDEX_NAME = "test-sparse-multi-field-index"; - private static final String TEST_TEXT_AND_VECTOR_FIELD_INDEX_NAME = "test-sparse-text-and-field-index"; + private static final String TEST_MULTI_SPARSE_ENCODING_FIELD_INDEX_NAME = "test-sparse-multi-field-index"; + private static final String TEST_TEXT_AND_SPARSE_ENCODING_FIELD_INDEX_NAME = "test-sparse-text-and-field-index"; private static final String TEST_NESTED_INDEX_NAME = "test-sparse-nested-index"; private static final String TEST_MULTI_DOC_INDEX_NAME = "test-sparse-multi-doc-index"; private static final String TEST_QUERY_TEXT = "Hello world a b"; @@ -179,20 +180,82 @@ public void testRescoreQuery() { assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), 0.0); } + /** + * Tests bool should query with query tokens: + * { + * "query": { + * "bool" : { + * "should": [ + * "sparse_encoding": { + * "field1": { + * "query_text": "Hello world a b", + * "model_id": "dcsdcasd" + * } + * }, + * "sparse_encoding": { + 
* "field2": { + * "query_text": "Hello world a b", + * "model_id": "dcsdcasd" + * } + * } + * ] + * } + * } + * } + */ + @SneakyThrows + public void testBooleanQuery_withMultipleNeuralQueries() { + initializeIndexIfNotExist(TEST_MULTI_SPARSE_ENCODING_FIELD_INDEX_NAME); + String modelId = getDeployedModelId(); + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + + SparseEncodingQueryBuilder sparseEncodingQueryBuilder1 = new SparseEncodingQueryBuilder() + .fieldName(TEST_SPARSE_ENCODING_FIELD_NAME_1) + .queryText(TEST_QUERY_TEXT) + .modelId(modelId); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder2 = new SparseEncodingQueryBuilder() + .fieldName(TEST_SPARSE_ENCODING_FIELD_NAME_2) + .queryText(TEST_QUERY_TEXT) + .modelId(modelId); + + boolQueryBuilder.should(sparseEncodingQueryBuilder1).should(sparseEncodingQueryBuilder2); + + Map searchResponseAsMap = search(TEST_MULTI_SPARSE_ENCODING_FIELD_INDEX_NAME, boolQueryBuilder, 1); + Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); + + assertEquals("1", firstInnerHit.get("_id")); + float expectedScore = 2 * computeExpectedScore(modelId, testTokenWeightMap, TEST_QUERY_TEXT); + assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), 0.0); + } + @SneakyThrows protected void initializeIndexIfNotExist(String indexName) { - if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) { + if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(indexName)) { prepareSparseEncodingIndex( - TEST_BASIC_INDEX_NAME, + indexName, List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1) ); addSparseEncodingDoc( - TEST_BASIC_INDEX_NAME, + indexName, "1", List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1), List.of(testTokenWeightMap) ); - assertEquals(1, getDocCount(TEST_BASIC_INDEX_NAME)); + assertEquals(1, getDocCount(indexName)); + } + + if (TEST_MULTI_SPARSE_ENCODING_FIELD_INDEX_NAME.equals(indexName) && !indexExists(indexName)) { + prepareSparseEncodingIndex( + indexName, + 
List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1, TEST_SPARSE_ENCODING_FIELD_NAME_2) + ); + addSparseEncodingDoc( + indexName, + "1", + List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1, TEST_SPARSE_ENCODING_FIELD_NAME_2), + List.of(testTokenWeightMap, testTokenWeightMap) + ); + assertEquals(1, getDocCount(indexName)); } } } From e79180786efd4b25aebdcb74f657fe07412470bf Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 20 Sep 2023 13:05:48 +0800 Subject: [PATCH 41/70] add integ test Signed-off-by: zhichao-aws --- .../common/BaseSparseEncodingIT.java | 16 ++++ .../query/sparse/SparseEncodingQueryIT.java | 78 ++++++++++++++++++- 2 files changed, 92 insertions(+), 2 deletions(-) diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseSparseEncodingIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseSparseEncodingIT.java index 65fe42151..0b96b85ab 100644 --- a/src/test/java/org/opensearch/neuralsearch/common/BaseSparseEncodingIT.java +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseSparseEncodingIT.java @@ -21,6 +21,7 @@ import java.nio.file.Files; import java.nio.file.Path; +import java.util.Collections; import java.util.List; import java.util.Map; @@ -61,6 +62,18 @@ protected void addSparseEncodingDoc( String docId, List fieldNames, List> docs + ) { + addSparseEncodingDoc(index, docId, fieldNames, docs, Collections.emptyList(), Collections.emptyList()); + } + + @SneakyThrows + protected void addSparseEncodingDoc( + String index, + String docId, + List fieldNames, + List> docs, + List textFieldNames, + List texts ) { Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); @@ -68,6 +81,9 @@ protected void addSparseEncodingDoc( builder.field(fieldNames.get(i), docs.get(i)); } + for (int i = 0; i < textFieldNames.size(); i++) { + builder.field(textFieldNames.get(i), texts.get(i)); + } builder.endObject(); request.setJsonEntity(builder.toString()); 
diff --git a/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryIT.java index db52fb73f..156a997d7 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryIT.java @@ -10,6 +10,7 @@ import org.junit.Before; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.index.query.MatchQueryBuilder; import org.opensearch.neuralsearch.TestUtils; import org.opensearch.neuralsearch.common.BaseSparseEncodingIT; @@ -23,7 +24,6 @@ public class SparseEncodingQueryIT extends BaseSparseEncodingIT { private static final String TEST_MULTI_SPARSE_ENCODING_FIELD_INDEX_NAME = "test-sparse-multi-field-index"; private static final String TEST_TEXT_AND_SPARSE_ENCODING_FIELD_INDEX_NAME = "test-sparse-text-and-field-index"; private static final String TEST_NESTED_INDEX_NAME = "test-sparse-nested-index"; - private static final String TEST_MULTI_DOC_INDEX_NAME = "test-sparse-multi-doc-index"; private static final String TEST_QUERY_TEXT = "Hello world a b"; private static final String TEST_SPARSE_ENCODING_FIELD_NAME_1 = "test-sparse-encoding-1"; private static final String TEST_SPARSE_ENCODING_FIELD_NAME_2 = "test-sparse-encoding-2"; @@ -204,7 +204,7 @@ public void testRescoreQuery() { * } */ @SneakyThrows - public void testBooleanQuery_withMultipleNeuralQueries() { + public void testBooleanQuery_withMultipleSparseEncodingQueries() { initializeIndexIfNotExist(TEST_MULTI_SPARSE_ENCODING_FIELD_INDEX_NAME); String modelId = getDeployedModelId(); BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); @@ -228,6 +228,50 @@ public void testBooleanQuery_withMultipleNeuralQueries() { assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), 0.0); } + /** + * Tests bool should query with 
a sparse encoding query and a BM25 match query: + * { + * "query": { + * "bool" : { + * "should": [ + * "sparse_encoding": { + * "field1": { + * "query_text": "Hello world a b", + * "model_id": "dcsdcasd" + * } + * }, + * "match": { + * "field2": { + * "query": "Hello world a b", + * "operator": "OR" + * } + * } + * ] + * } + * } + * } + */ + @SneakyThrows + public void testBooleanQuery_withSparseEncodingAndBM25Queries() { + initializeIndexIfNotExist(TEST_TEXT_AND_SPARSE_ENCODING_FIELD_INDEX_NAME); + String modelId = getDeployedModelId(); + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + + SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder() + .fieldName(TEST_SPARSE_ENCODING_FIELD_NAME_1) + .queryText(TEST_QUERY_TEXT) + .modelId(modelId); + MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT); + boolQueryBuilder.should(sparseEncodingQueryBuilder).should(matchQueryBuilder); + + Map searchResponseAsMap = search(TEST_TEXT_AND_SPARSE_ENCODING_FIELD_INDEX_NAME, boolQueryBuilder, 1); + Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); + + assertEquals("1", firstInnerHit.get("_id")); + float minExpectedScore = computeExpectedScore(modelId, testTokenWeightMap, TEST_QUERY_TEXT); + assertTrue(minExpectedScore < objectToFloat(firstInnerHit.get("_score"))); + } + @SneakyThrows protected void initializeIndexIfNotExist(String indexName) { if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(indexName)) { @@ -257,5 +301,35 @@ protected void initializeIndexIfNotExist(String indexName) { ); assertEquals(1, getDocCount(indexName)); } + + if (TEST_TEXT_AND_SPARSE_ENCODING_FIELD_INDEX_NAME.equals(indexName) && !indexExists(indexName)) { + prepareSparseEncodingIndex( + indexName, + List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1) + ); + addSparseEncodingDoc( + indexName, + "1", + List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1), + List.of(testTokenWeightMap), + List.of(TEST_TEXT_FIELD_NAME_1), + 
List.of(TEST_QUERY_TEXT) + ); + assertEquals(1, getDocCount(indexName)); + } + + if (TEST_NESTED_INDEX_NAME.equals(indexName) && !indexExists(indexName)) { + prepareSparseEncodingIndex( + indexName, + List.of(TEST_SPARSE_ENCODING_FIELD_NAME_NESTED) + ); + addSparseEncodingDoc( + indexName, + "1", + List.of(TEST_SPARSE_ENCODING_FIELD_NAME_NESTED), + List.of(testTokenWeightMap) + ); + assertEquals(1, getDocCount(TEST_NESTED_INDEX_NAME)); + } } } From 10e599a300de0ba8d1f84ba8ce2a15ee517dffad Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 20 Sep 2023 13:08:22 +0800 Subject: [PATCH 42/70] rename resource file Signed-off-by: zhichao-aws --- .../opensearch/neuralsearch/common/BaseSparseEncodingIT.java | 2 +- .../neuralsearch/processor/SparseEncodingProcessIT.java | 4 ++-- ...rseIndexMappings.json => SparseEncodingIndexMappings.json} | 0 ...guration.json => SparseEncodingPipelineConfiguration.json} | 0 ...estBody.json => UploadSparseEncodingModelRequestBody.json} | 0 5 files changed, 3 insertions(+), 3 deletions(-) rename src/test/resources/processor/{SparseIndexMappings.json => SparseEncodingIndexMappings.json} (100%) rename src/test/resources/processor/{SparsePipelineConfiguration.json => SparseEncodingPipelineConfiguration.json} (100%) rename src/test/resources/processor/{UploadSparseModelRequestBody.json => UploadSparseEncodingModelRequestBody.json} (100%) diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseSparseEncodingIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseSparseEncodingIT.java index 0b96b85ab..ea81d403f 100644 --- a/src/test/java/org/opensearch/neuralsearch/common/BaseSparseEncodingIT.java +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseSparseEncodingIT.java @@ -31,7 +31,7 @@ public abstract class BaseSparseEncodingIT extends BaseNeuralSearchIT{ @Override protected String prepareModel() { String requestBody = Files.readString( - 
Path.of(classLoader.getResource("processor/UploadSparseModelRequestBody.json").toURI()) + Path.of(classLoader.getResource("processor/UploadSparseEncodingModelRequestBody.json").toURI()) ); String modelId = uploadModel(requestBody); loadModel(modelId); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java index d5ab277f4..611fe23a6 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java @@ -40,7 +40,7 @@ public void tearDown() { @Before public void setPipelineName() { - this.setPipelineConfigurationName("processor/SparsePipelineConfiguration.json"); + this.setPipelineConfigurationName("processor/SparseEncodingPipelineConfiguration.json"); } public void testSparseEncodingProcessor() throws Exception { @@ -54,7 +54,7 @@ public void testSparseEncodingProcessor() throws Exception { private void createSparseEncodingIndex() throws Exception { createIndexWithConfiguration( INDEX_NAME, - Files.readString(Path.of(classLoader.getResource("processor/SparseIndexMappings.json").toURI())), + Files.readString(Path.of(classLoader.getResource("processor/SparseEncodingIndexMappings.json").toURI())), PIPELINE_NAME ); } diff --git a/src/test/resources/processor/SparseIndexMappings.json b/src/test/resources/processor/SparseEncodingIndexMappings.json similarity index 100% rename from src/test/resources/processor/SparseIndexMappings.json rename to src/test/resources/processor/SparseEncodingIndexMappings.json diff --git a/src/test/resources/processor/SparsePipelineConfiguration.json b/src/test/resources/processor/SparseEncodingPipelineConfiguration.json similarity index 100% rename from src/test/resources/processor/SparsePipelineConfiguration.json rename to src/test/resources/processor/SparseEncodingPipelineConfiguration.json diff --git 
a/src/test/resources/processor/UploadSparseModelRequestBody.json b/src/test/resources/processor/UploadSparseEncodingModelRequestBody.json similarity index 100% rename from src/test/resources/processor/UploadSparseModelRequestBody.json rename to src/test/resources/processor/UploadSparseEncodingModelRequestBody.json From 1e14a2601d2dcf917d52b1e11ddea4f421e9fa1a Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 20 Sep 2023 13:14:02 +0800 Subject: [PATCH 43/70] tidy Signed-off-by: zhichao-aws --- .../ml/MLCommonsClientAccessor.java | 55 ++- .../neuralsearch/plugin/NeuralSearch.java | 2 +- .../neuralsearch/processor/NLPProcessor.java | 114 +++--- .../processor/SparseEncodingProcessor.java | 14 +- .../processor/TextEmbeddingProcessor.java | 3 +- .../SparseEncodingProcessorFactory.java | 20 +- .../sparse/BoundedLinearFeatureQuery.java | 31 +- .../sparse/SparseEncodingQueryBuilder.java | 174 +++++---- .../neuralsearch/util/TokenWeightUtil.java | 22 +- .../opensearch/neuralsearch/TestUtils.java | 2 +- .../common/BaseNeuralSearchIT.java | 8 +- .../common/BaseSparseEncodingIT.java | 70 ++-- .../ml/MLCommonsClientAccessorTests.java | 49 +-- .../processor/SparseEncodingProcessIT.java | 58 +-- .../SparseEncodingProcessorTests.java | 35 +- .../SparseEncodingQueryBuilderTests.java | 333 +++++++++--------- .../query/sparse/SparseEncodingQueryIT.java | 123 +++---- .../util/TokenWeightUtilTests.java | 40 +-- 18 files changed, 516 insertions(+), 637 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index bedeff46e..2571dc4e2 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -47,16 +47,16 @@ public class MLCommonsClientAccessor { * @param listener {@link ActionListener} which will be called when prediction is completed or errored out */ 
public void inferenceSentence( - @NonNull final String modelId, - @NonNull final String inputText, - @NonNull final ActionListener> listener + @NonNull final String modelId, + @NonNull final String inputText, + @NonNull final ActionListener> listener ) { inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, List.of(inputText), ActionListener.wrap(response -> { if (response.size() != 1) { listener.onFailure( - new IllegalStateException( - "Unexpected number of vectors produced. Expected 1 vector to be returned, but got [" + response.size() + "]" - ) + new IllegalStateException( + "Unexpected number of vectors produced. Expected 1 vector to be returned, but got [" + response.size() + "]" + ) ); return; } @@ -77,9 +77,9 @@ public void inferenceSentence( * @param listener {@link ActionListener} which will be called when prediction is completed or errored out */ public void inferenceSentences( - @NonNull final String modelId, - @NonNull final List inputText, - @NonNull final ActionListener>> listener + @NonNull final String modelId, + @NonNull final List inputText, + @NonNull final ActionListener>> listener ) { inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, inputText, listener); } @@ -97,10 +97,10 @@ public void inferenceSentences( * @param listener {@link ActionListener} which will be called when prediction is completed or errored out. 
*/ public void inferenceSentences( - @NonNull final List targetResponseFilters, - @NonNull final String modelId, - @NonNull final List inputText, - @NonNull final ActionListener>> listener + @NonNull final List targetResponseFilters, + @NonNull final String modelId, + @NonNull final List inputText, + @NonNull final ActionListener>> listener ) { retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, 0, listener); } @@ -108,7 +108,8 @@ public void inferenceSentences( public void inferenceSentencesWithMapResult( @NonNull final String modelId, @NonNull final List inputText, - @NonNull final ActionListener>> listener) { + @NonNull final ActionListener>> listener + ) { retryableInferenceSentencesWithMapResult(modelId, inputText, 0, listener); } @@ -134,11 +135,11 @@ private void retryableInferenceSentencesWithMapResult( } private void retryableInferenceSentencesWithVectorResult( - final List targetResponseFilters, - final String modelId, - final List inputText, - final int retryTime, - final ActionListener>> listener + final List targetResponseFilters, + final String modelId, + final List inputText, + final int retryTime, + final ActionListener>> listener ) { MLInput mlInput = createMLInput(targetResponseFilters, inputText); mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { @@ -174,24 +175,20 @@ private List> buildVectorFromResponse(MLOutput mlOutput) { return vector; } - private List > buildMapResultFromResponse(MLOutput mlOutput) { + private List> buildMapResultFromResponse(MLOutput mlOutput) { final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput; final List tensorOutputList = modelTensorOutput.getMlModelOutputs(); if (CollectionUtils.isEmpty(tensorOutputList) || CollectionUtils.isEmpty(tensorOutputList.get(0).getMlModelTensors())) { - throw new IllegalStateException( - "Empty model result produced. 
Expected 1 tensor output and 1 model tensor, but got [0]" - ); + throw new IllegalStateException("Empty model result produced. Expected 1 tensor output and 1 model tensor, but got [0]"); } - List > resultMaps = new ArrayList<>(); - for (ModelTensors tensors: tensorOutputList) - { + List> resultMaps = new ArrayList<>(); + for (ModelTensors tensors : tensorOutputList) { List tensorList = tensors.getMlModelTensors(); - for (ModelTensor tensor: tensorList) - { + for (ModelTensor tensor : tensorList) { resultMaps.add(tensor.getDataAsMap()); } } return resultMaps; } -} \ No newline at end of file +} diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index 53b4b30d7..fc93af934 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -36,8 +36,8 @@ import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; -import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory; import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory; +import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; import org.opensearch.neuralsearch.query.HybridQueryBuilder; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java index 02d5e5747..d4b3c340f 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java +++ 
b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java @@ -5,16 +5,6 @@ package org.opensearch.neuralsearch.processor; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableMap; -import lombok.extern.log4j.Log4j2; -import org.apache.commons.lang3.StringUtils; -import org.opensearch.env.Environment; -import org.opensearch.index.mapper.MapperService; -import org.opensearch.ingest.AbstractProcessor; -import org.opensearch.ingest.IngestDocument; -import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; - import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; @@ -24,6 +14,18 @@ import java.util.function.Supplier; import java.util.stream.IntStream; +import lombok.extern.log4j.Log4j2; + +import org.apache.commons.lang3.StringUtils; +import org.opensearch.env.Environment; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.ingest.AbstractProcessor; +import org.opensearch.ingest.IngestDocument; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMap; + @Log4j2 public abstract class NLPProcessor extends AbstractProcessor { @@ -44,14 +46,14 @@ public abstract class NLPProcessor extends AbstractProcessor { protected final Environment environment; public NLPProcessor( - String tag, - String description, - String type, - String listTypeNestedMapKey, - String modelId, - Map fieldMap, - MLCommonsClientAccessor clientAccessor, - Environment environment + String tag, + String description, + String type, + String listTypeNestedMapKey, + String modelId, + Map fieldMap, + MLCommonsClientAccessor clientAccessor, + Environment environment ) { super(tag, description); this.type = type; @@ -67,11 +69,11 @@ public NLPProcessor( private void validateEmbeddingConfiguration(Map fieldMap) { if (fieldMap == null - || fieldMap.size() == 0 - || fieldMap.entrySet() + || 
fieldMap.size() == 0 + || fieldMap.entrySet() .stream() .anyMatch( - x -> StringUtils.isBlank(x.getKey()) || Objects.isNull(x.getValue()) || StringUtils.isBlank(x.getValue().toString()) + x -> StringUtils.isBlank(x.getKey()) || Objects.isNull(x.getValue()) || StringUtils.isBlank(x.getValue().toString()) )) { throw new IllegalArgumentException("Unable to create the processor as field_map has invalid key or value"); } @@ -99,9 +101,9 @@ private void validateNestedTypeValue(String sourceKey, Object sourceValue, Suppl validateListTypeValue(sourceKey, sourceValue); } else if (Map.class.isAssignableFrom(sourceValue.getClass())) { ((Map) sourceValue).values() - .stream() - .filter(Objects::nonNull) - .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1)); + .stream() + .filter(Objects::nonNull) + .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1)); } else if (!String.class.isAssignableFrom(sourceValue.getClass())) { throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, can not process it"); } else if (StringUtils.isBlank(sourceValue.toString())) { @@ -128,20 +130,20 @@ private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { } private void buildMapWithProcessorKeyAndOriginalValueForMapType( - String parentKey, - Object processorKey, - Map sourceAndMetadataMap, - Map treeRes + String parentKey, + Object processorKey, + Map sourceAndMetadataMap, + Map treeRes ) { if (processorKey == null || sourceAndMetadataMap == null) return; if (processorKey instanceof Map) { Map next = new LinkedHashMap<>(); for (Map.Entry nestedFieldMapEntry : ((Map) processorKey).entrySet()) { buildMapWithProcessorKeyAndOriginalValueForMapType( - nestedFieldMapEntry.getKey(), - nestedFieldMapEntry.getValue(), - (Map) sourceAndMetadataMap.get(parentKey), - next + nestedFieldMapEntry.getKey(), + nestedFieldMapEntry.getValue(), + (Map) sourceAndMetadataMap.get(parentKey), + next ); } treeRes.put(parentKey, 
next); @@ -197,7 +199,12 @@ private List createInferenceList(Map knnKeyMap) { return texts; } - public abstract void doExecute(IngestDocument ingestDocument,Map ProcessMap, List inferenceList, BiConsumer handler); + public abstract void doExecute( + IngestDocument ingestDocument, + Map ProcessMap, + List inferenceList, + BiConsumer handler + ); @Override public IngestDocument execute(IngestDocument ingestDocument) throws Exception { @@ -211,7 +218,7 @@ public IngestDocument execute(IngestDocument ingestDocument) throws Exception { * @param handler {@link BiConsumer} which is the handler which can be used after the inference task is done. */ @Override - public void execute(IngestDocument ingestDocument, BiConsumer handler){ + public void execute(IngestDocument ingestDocument, BiConsumer handler) { try { validateEmbeddingFieldsValue(ingestDocument); Map ProcessMap = buildMapWithProcessorKeyAndOriginalValue(ingestDocument); @@ -235,11 +242,7 @@ protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map buildNLPResult( - Map processorMap, - List results, - Map sourceAndMetadataMap - ) { + Map buildNLPResult(Map processorMap, List results, Map sourceAndMetadataMap) { NLPProcessor.IndexWrapper indexWrapper = new NLPProcessor.IndexWrapper(0); Map result = new LinkedHashMap<>(); for (Map.Entry knnMapEntry : processorMap.entrySet()) { @@ -258,41 +261,38 @@ Map buildNLPResult( @SuppressWarnings({ "unchecked" }) private void putNLPResultToSourceMapForMapType( - String processorKey, - Object sourceValue, - List results, - NLPProcessor.IndexWrapper indexWrapper, - Map sourceAndMetadataMap + String processorKey, + Object sourceValue, + List results, + NLPProcessor.IndexWrapper indexWrapper, + Map sourceAndMetadataMap ) { if (processorKey == null || sourceAndMetadataMap == null || sourceValue == null) return; if (sourceValue instanceof Map) { for (Map.Entry inputNestedMapEntry : ((Map) sourceValue).entrySet()) { putNLPResultToSourceMapForMapType( - 
inputNestedMapEntry.getKey(), - inputNestedMapEntry.getValue(), - results, - indexWrapper, - (Map) sourceAndMetadataMap.get(processorKey) + inputNestedMapEntry.getKey(), + inputNestedMapEntry.getValue(), + results, + indexWrapper, + (Map) sourceAndMetadataMap.get(processorKey) ); } } else if (sourceValue instanceof String) { sourceAndMetadataMap.put(processorKey, results.get(indexWrapper.index++)); } else if (sourceValue instanceof List) { - sourceAndMetadataMap.put( - processorKey, - buildNLPResultForListType((List) sourceValue, results, indexWrapper) - ); + sourceAndMetadataMap.put(processorKey, buildNLPResultForListType((List) sourceValue, results, indexWrapper)); } } private List> buildNLPResultForListType( - List sourceValue, - List results, - NLPProcessor.IndexWrapper indexWrapper + List sourceValue, + List results, + NLPProcessor.IndexWrapper indexWrapper ) { List> keyToResult = new ArrayList<>(); IntStream.range(0, sourceValue.size()) - .forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++)))); + .forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++)))); return keyToResult; } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index ea15c8b8c..217d551c4 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -5,19 +5,18 @@ package org.opensearch.neuralsearch.processor; +import java.util.List; +import java.util.Map; +import java.util.function.BiConsumer; + import lombok.extern.log4j.Log4j2; + import org.opensearch.core.action.ActionListener; import org.opensearch.env.Environment; import org.opensearch.ingest.IngestDocument; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import 
org.opensearch.neuralsearch.util.TokenWeightUtil; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.function.BiConsumer; - - @Log4j2 public class SparseEncodingProcessor extends NLPProcessor { @@ -40,8 +39,7 @@ public void doExecute( IngestDocument ingestDocument, Map ProcessMap, List inferenceList, - BiConsumer handler + BiConsumer handler ) { mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> { setVectorFieldsToDocument(ingestDocument, ProcessMap, TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java index bd690558d..354b53945 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java @@ -42,8 +42,7 @@ public void doExecute( IngestDocument ingestDocument, Map ProcessMap, List inferenceList, - BiConsumer handler + BiConsumer handler ) { mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, ActionListener.wrap(vectors -> { setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java index 8065af445..dff56e9c8 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java @@ -5,17 +5,19 @@ package org.opensearch.neuralsearch.processor.factory; +import static org.opensearch.ingest.ConfigurationUtils.readMap; +import static org.opensearch.ingest.ConfigurationUtils.readStringProperty; +import static 
org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.*; + +import java.util.Map; + import lombok.extern.log4j.Log4j2; + import org.opensearch.env.Environment; import org.opensearch.ingest.Processor; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; -import java.util.Map; - -import static org.opensearch.ingest.ConfigurationUtils.readMap; -import static org.opensearch.ingest.ConfigurationUtils.readStringProperty; -import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.*; @Log4j2 public class SparseEncodingProcessorFactory implements Processor.Factory { private final MLCommonsClientAccessor clientAccessor; @@ -28,10 +30,10 @@ public SparseEncodingProcessorFactory(MLCommonsClientAccessor clientAccessor, En @Override public SparseEncodingProcessor create( - Map registry, - String processorTag, - String description, - Map config + Map registry, + String processorTag, + String description, + Map config ) throws Exception { String modelId = readStringProperty(TYPE, processorTag, config, MODEL_ID_FIELD); Map filedMap = readMap(TYPE, processorTag, config, FIELD_MAP_FIELD); diff --git a/src/main/java/org/opensearch/neuralsearch/query/sparse/BoundedLinearFeatureQuery.java b/src/main/java/org/opensearch/neuralsearch/query/sparse/BoundedLinearFeatureQuery.java index f0b9b498c..5ce291c25 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/sparse/BoundedLinearFeatureQuery.java +++ b/src/main/java/org/opensearch/neuralsearch/query/sparse/BoundedLinearFeatureQuery.java @@ -81,8 +81,8 @@ public boolean equals(Object obj) { } BoundedLinearFeatureQuery that = (BoundedLinearFeatureQuery) obj; return Objects.equals(fieldName, that.fieldName) - && Objects.equals(featureName, that.featureName) - && Objects.equals(scoreUpperBound, that.scoreUpperBound); + && Objects.equals(featureName, that.featureName) + && Objects.equals(scoreUpperBound, that.scoreUpperBound); } 
@Override @@ -95,8 +95,7 @@ public int hashCode() { } @Override - public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) - throws IOException { + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { if (!scoreMode.needsScores()) { // We don't need scores (e.g. for faceting), and since features are stored as terms, // allow TermQuery to optimize in this case @@ -133,14 +132,11 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio float featureValue = decodeFeatureValue(freq); float score = boost * featureValue; return Explanation.match( - score, - "Linear function on the " - + fieldName - + " field for the " - + featureName - + " feature, computed as w * S from:", - Explanation.match(boost, "w, weight of this function"), - Explanation.match(featureValue, "S, feature value")); + score, + "Linear function on the " + fieldName + " field for the " + featureName + " feature, computed as w * S from:", + Explanation.match(boost, "w, weight of this function"), + Explanation.match(featureValue, "S, feature value") + ); } @Override @@ -205,16 +201,11 @@ public void visit(QueryVisitor visitor) { @Override public String toString(String field) { - return "BoundedLinearFeatureQuery(field=" - + fieldName - + ", feature=" - + featureName - + ", scoreUpperBound=" - + scoreUpperBound - + ")"; + return "BoundedLinearFeatureQuery(field=" + fieldName + ", feature=" + featureName + ", scoreUpperBound=" + scoreUpperBound + ")"; } static final int MAX_FREQ = Float.floatToIntBits(Float.MAX_VALUE) >>> 15; + private float decodeFeatureValue(float freq) { if (freq > MAX_FREQ) { return scoreUpperBound; @@ -223,4 +214,4 @@ private float decodeFeatureValue(float freq) { int featureBits = tf << 15; return Math.min(Float.intBitsToFloat(featureBits), scoreUpperBound); } -} \ No newline at end of file +} diff --git 
a/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilder.java index e471381c5..957e3477a 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilder.java @@ -37,11 +37,11 @@ import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.QueryShardContext; - -import com.google.common.annotations.VisibleForTesting; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.util.TokenWeightUtil; +import com.google.common.annotations.VisibleForTesting; + @Log4j2 @Getter @Setter @@ -105,8 +105,7 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws if (null != queryTokens) xContentBuilder.field(QUERY_TOKENS_FIELD.getPreferredName(), queryTokens); if (null != queryText) xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), queryText); if (null != modelId) xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId); - if (null != tokenScoreUpperBound) - xContentBuilder.field(TOKEN_SCORE_UPPER_BOUND_FIELD.getPreferredName(), tokenScoreUpperBound); + if (null != tokenScoreUpperBound) xContentBuilder.field(TOKEN_SCORE_UPPER_BOUND_FIELD.getPreferredName(), tokenScoreUpperBound); printBoostAndQueryName(xContentBuilder); xContentBuilder.endObject(); xContentBuilder.endObject(); @@ -135,10 +134,7 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws public static SparseEncodingQueryBuilder fromXContent(XContentParser parser) throws IOException { SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder(); if (parser.currentToken() != XContentParser.Token.START_OBJECT) { - throw new ParsingException( - parser.getTokenLocation(), - 
"First token of " + NAME + "query must be START_OBJECT" - ); + throw new ParsingException(parser.getTokenLocation(), "First token of " + NAME + "query must be START_OBJECT"); } parser.nextToken(); sparseEncodingQueryBuilder.fieldName(parser.currentName()); @@ -146,41 +142,43 @@ public static SparseEncodingQueryBuilder fromXContent(XContentParser parser) thr parseQueryParams(parser, sparseEncodingQueryBuilder); if (parser.nextToken() != XContentParser.Token.END_OBJECT) { throw new ParsingException( - parser.getTokenLocation(), - "[" - + NAME - + "] query doesn't support multiple fields, found [" - + sparseEncodingQueryBuilder.fieldName() - + "] and [" - + parser.currentName() - + "]" + parser.getTokenLocation(), + "[" + + NAME + + "] query doesn't support multiple fields, found [" + + sparseEncodingQueryBuilder.fieldName() + + "] and [" + + parser.currentName() + + "]" ); } - requireValue( - sparseEncodingQueryBuilder.fieldName(), - "Field name must be provided for " + NAME + " query" - ); - if (null==sparseEncodingQueryBuilder.queryTokens()) { + requireValue(sparseEncodingQueryBuilder.fieldName(), "Field name must be provided for " + NAME + " query"); + if (null == sparseEncodingQueryBuilder.queryTokens()) { requireValue( - sparseEncodingQueryBuilder.queryText(), - "Either " + QUERY_TOKENS_FIELD.getPreferredName() + " or " + - QUERY_TEXT_FIELD.getPreferredName() + " must be provided for " + NAME + " query" + sparseEncodingQueryBuilder.queryText(), + "Either " + + QUERY_TOKENS_FIELD.getPreferredName() + + " or " + + QUERY_TEXT_FIELD.getPreferredName() + + " must be provided for " + + NAME + + " query" ); requireValue( - sparseEncodingQueryBuilder.modelId(), - MODEL_ID_FIELD.getPreferredName() + " must be provided for " + NAME + - " query when using " + QUERY_TEXT_FIELD.getPreferredName() + sparseEncodingQueryBuilder.modelId(), + MODEL_ID_FIELD.getPreferredName() + + " must be provided for " + + NAME + + " query when using " + + QUERY_TEXT_FIELD.getPreferredName() ); 
} return sparseEncodingQueryBuilder; } - private static void parseQueryParams( - XContentParser parser, - SparseEncodingQueryBuilder sparseEncodingQueryBuilder - ) throws IOException { + private static void parseQueryParams(XContentParser parser, SparseEncodingQueryBuilder sparseEncodingQueryBuilder) throws IOException { XContentParser.Token token; String currentFieldName = ""; while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { @@ -199,16 +197,16 @@ private static void parseQueryParams( sparseEncodingQueryBuilder.tokenScoreUpperBound(parser.floatValue()); } else { throw new ParsingException( - parser.getTokenLocation(), - "[" + NAME + "] query does not support [" + currentFieldName + "]" + parser.getTokenLocation(), + "[" + NAME + "] query does not support [" + currentFieldName + "]" ); } } else if (QUERY_TOKENS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { sparseEncodingQueryBuilder.queryTokens(parser.map(HashMap::new, XContentParser::floatValue)); } else { throw new ParsingException( - parser.getTokenLocation(), - "[" + NAME + "] unknown token [" + token + "] after [" + currentFieldName + "]" + parser.getTokenLocation(), + "[" + NAME + "] unknown token [" + token + "] after [" + currentFieldName + "]" ); } } @@ -223,33 +221,32 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws return this; } if (null != queryTokensSupplier) { - return queryTokensSupplier.get() == null ? this : - new SparseEncodingQueryBuilder() - .fieldName(fieldName) - .queryTokens(queryTokensSupplier.get()) - .queryText(queryText) - .modelId(modelId) - .tokenScoreUpperBound(tokenScoreUpperBound); + return queryTokensSupplier.get() == null + ? 
this + : new SparseEncodingQueryBuilder().fieldName(fieldName) + .queryTokens(queryTokensSupplier.get()) + .queryText(queryText) + .modelId(modelId) + .tokenScoreUpperBound(tokenScoreUpperBound); } validateForRewrite(queryText, modelId); SetOnce> queryTokensSetOnce = new SetOnce<>(); queryRewriteContext.registerAsyncAction( - ((client, actionListener) -> ML_CLIENT.inferenceSentencesWithMapResult( - modelId(), - List.of(queryText), - ActionListener.wrap(mapResultList -> { - queryTokensSetOnce.set(TokenWeightUtil.fetchListOfTokenWeightMap(mapResultList).get(0)); - actionListener.onResponse(null); - }, actionListener::onFailure)) - ) + ((client, actionListener) -> ML_CLIENT.inferenceSentencesWithMapResult( + modelId(), + List.of(queryText), + ActionListener.wrap(mapResultList -> { + queryTokensSetOnce.set(TokenWeightUtil.fetchListOfTokenWeightMap(mapResultList).get(0)); + actionListener.onResponse(null); + }, actionListener::onFailure) + )) ); - return new SparseEncodingQueryBuilder() - .fieldName(fieldName) - .queryText(queryText) - .modelId(modelId) - .tokenScoreUpperBound(tokenScoreUpperBound) - .queryTokensSupplier(queryTokensSetOnce::get); + return new SparseEncodingQueryBuilder().fieldName(fieldName) + .queryText(queryText) + .modelId(modelId) + .tokenScoreUpperBound(tokenScoreUpperBound) + .queryTokensSupplier(queryTokensSetOnce::get); } @Override @@ -259,31 +256,28 @@ protected Query doToQuery(QueryShardContext context) throws IOException { validateQueryTokens(queryTokens); // the tokenScoreUpperBound from query has higher priority - final Float scoreUpperBound = null != tokenScoreUpperBound? tokenScoreUpperBound: Float.MAX_VALUE; + final Float scoreUpperBound = null != tokenScoreUpperBound ? 
tokenScoreUpperBound : Float.MAX_VALUE; BooleanQuery.Builder builder = new BooleanQuery.Builder(); - for (Map.Entry entry: queryTokens.entrySet()) { + for (Map.Entry entry : queryTokens.entrySet()) { builder.add( - new BoostQuery( - new BoundedLinearFeatureQuery( - fieldName, - entry.getKey(), - scoreUpperBound - ), - entry.getValue() - ), - BooleanClause.Occur.SHOULD + new BoostQuery(new BoundedLinearFeatureQuery(fieldName, entry.getKey(), scoreUpperBound), entry.getValue()), + BooleanClause.Occur.SHOULD ); } return builder.build(); } private static void validateForRewrite(String queryText, String modelId) { - if (null == queryText||null == modelId) { + if (null == queryText || null == modelId) { throw new IllegalArgumentException( - "When " + QUERY_TOKENS_FIELD.getPreferredName() + " are not provided," + - QUERY_TEXT_FIELD.getPreferredName() + " and " + MODEL_ID_FIELD.getPreferredName() + - " can not be null." + "When " + + QUERY_TOKENS_FIELD.getPreferredName() + + " are not provided," + + QUERY_TEXT_FIELD.getPreferredName() + + " and " + + MODEL_ID_FIELD.getPreferredName() + + " can not be null." ); } } @@ -291,22 +285,18 @@ private static void validateForRewrite(String queryText, String modelId) { private static void validateFieldType(MappedFieldType fieldType) { if (!fieldType.typeName().equals("rank_features")) { throw new IllegalArgumentException( - "[" + NAME + "] query only works on [rank_features] fields, " - + "not [" + fieldType.typeName() + "]" + "[" + NAME + "] query only works on [rank_features] fields, " + "not [" + fieldType.typeName() + "]" ); } } private static void validateQueryTokens(Map queryTokens) { if (null == queryTokens) { - throw new IllegalArgumentException( - QUERY_TOKENS_FIELD.getPreferredName() + " field can not be null." 
- ); + throw new IllegalArgumentException(QUERY_TOKENS_FIELD.getPreferredName() + " field can not be null."); } - for (Map.Entry entry: queryTokens.entrySet()) { + for (Map.Entry entry : queryTokens.entrySet()) { if (entry.getValue() <= 0) { - throw new IllegalArgumentException( - "weight must be larger than 0, got: " + entry.getValue() + "for key " + entry.getKey()); + throw new IllegalArgumentException("weight must be larger than 0, got: " + entry.getValue() + "for key " + entry.getKey()); } } } @@ -315,28 +305,26 @@ private static void validateQueryTokens(Map queryTokens) { protected boolean doEquals(SparseEncodingQueryBuilder obj) { if (this == obj) return true; if (obj == null || getClass() != obj.getClass()) return false; - EqualsBuilder equalsBuilder = new EqualsBuilder() - .append(fieldName, obj.fieldName) - .append(queryTokens, obj.queryTokens) - .append(queryText, obj.queryText) - .append(modelId, obj.modelId) - .append(tokenScoreUpperBound, obj.tokenScoreUpperBound); + EqualsBuilder equalsBuilder = new EqualsBuilder().append(fieldName, obj.fieldName) + .append(queryTokens, obj.queryTokens) + .append(queryText, obj.queryText) + .append(modelId, obj.modelId) + .append(tokenScoreUpperBound, obj.tokenScoreUpperBound); return equalsBuilder.isEquals(); } @Override protected int doHashCode() { - return new HashCodeBuilder() - .append(fieldName) - .append(queryTokens) - .append(queryText) - .append(modelId) - .append(tokenScoreUpperBound) - .toHashCode(); + return new HashCodeBuilder().append(fieldName) + .append(queryTokens) + .append(queryText) + .append(modelId) + .append(tokenScoreUpperBound) + .toHashCode(); } @Override public String getWriteableName() { return NAME; } -} \ No newline at end of file +} diff --git a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java index bd118b480..d552a4362 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java 
+++ b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java @@ -39,15 +39,12 @@ public class TokenWeightUtil { */ public static List> fetchListOfTokenWeightMap(List> mapResultList) { List results = new ArrayList<>(); - for (Map map: mapResultList) - { - if (!map.containsKey(RESPONSE_KEY)){ - throw new IllegalArgumentException("The inference result should be associated with the field [" - + RESPONSE_KEY + "]."); + for (Map map : mapResultList) { + if (!map.containsKey(RESPONSE_KEY)) { + throw new IllegalArgumentException("The inference result should be associated with the field [" + RESPONSE_KEY + "]."); } if (!List.class.isAssignableFrom(map.get(RESPONSE_KEY).getClass())) { - throw new IllegalArgumentException("The data object associated with field [" - + RESPONSE_KEY + "] should be a list."); + throw new IllegalArgumentException("The data object associated with field [" + RESPONSE_KEY + "] should be a list."); } results.addAll((List) map.get("response")); } @@ -56,15 +53,12 @@ public static List> fetchListOfTokenWeightMap(List buildTokenWeightMap(Object uncastedMap) { if (!Map.class.isAssignableFrom(uncastedMap.getClass())) { - throw new IllegalArgumentException("The expected inference result is a Map with String keys and " - + " Float values."); + throw new IllegalArgumentException("The expected inference result is a Map with String keys and " + " Float values."); } Map result = new HashMap<>(); - for (Map.Entry entry: ((Map) uncastedMap).entrySet()) { - if (!String.class.isAssignableFrom(entry.getKey().getClass()) - || !Number.class.isAssignableFrom(entry.getValue().getClass())){ - throw new IllegalArgumentException("The expected inference result is a Map with String keys and " - + " Float values."); + for (Map.Entry entry : ((Map) uncastedMap).entrySet()) { + if (!String.class.isAssignableFrom(entry.getKey().getClass()) || !Number.class.isAssignableFrom(entry.getValue().getClass())) { + throw new IllegalArgumentException("The expected inference 
result is a Map with String keys and " + " Float values."); } result.put((String) entry.getKey(), ((Number) entry.getValue()).floatValue()); } diff --git a/src/test/java/org/opensearch/neuralsearch/TestUtils.java b/src/test/java/org/opensearch/neuralsearch/TestUtils.java index f3e36a6f5..9565be474 100644 --- a/src/test/java/org/opensearch/neuralsearch/TestUtils.java +++ b/src/test/java/org/opensearch/neuralsearch/TestUtils.java @@ -91,7 +91,7 @@ public static Float createFloatNumberWithEffectiveDigits(float inputNumber, int */ public static Map createRandomTokenWeightMap(Collection tokens) { Map resultMap = new HashMap<>(); - for (String token: tokens) { + for (String token : tokens) { // use a small shift to ensure value > 0 resultMap.put(token, createFloatNumberWithEffectiveDigits(Math.abs(randomFloat()) + 1e-3f, 3)); } diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java index 589f5d0d5..84672d479 100644 --- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java @@ -68,7 +68,7 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase { protected final ClassLoader classLoader = this.getClass().getClassLoader(); - protected void setPipelineConfigurationName(String pipelineConfigurationName){ + protected void setPipelineConfigurationName(String pipelineConfigurationName) { this.PIPELINE_CONFIGURATION_NAME = pipelineConfigurationName; } @@ -243,11 +243,7 @@ protected void createPipelineProcessor(String modelId, String pipelineName) thro "/_ingest/pipeline/" + pipelineName, null, toHttpEntity( - String.format( - LOCALE, - Files.readString(Path.of(classLoader.getResource(PIPELINE_CONFIGURATION_NAME).toURI())), - modelId - ) + String.format(LOCALE, Files.readString(Path.of(classLoader.getResource(PIPELINE_CONFIGURATION_NAME).toURI())), 
modelId) ), ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) ); diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseSparseEncodingIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseSparseEncodingIT.java index ea81d403f..d0231cfe6 100644 --- a/src/test/java/org/opensearch/neuralsearch/common/BaseSparseEncodingIT.java +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseSparseEncodingIT.java @@ -5,8 +5,14 @@ package org.opensearch.neuralsearch.common; -import com.google.common.collect.ImmutableList; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Collections; +import java.util.List; +import java.util.Map; + import lombok.SneakyThrows; + import org.apache.hc.core5.http.HttpHeaders; import org.apache.hc.core5.http.io.entity.EntityUtils; import org.apache.hc.core5.http.message.BasicHeader; @@ -19,19 +25,15 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.neuralsearch.util.TokenWeightUtil; -import java.nio.file.Files; -import java.nio.file.Path; -import java.util.Collections; -import java.util.List; -import java.util.Map; +import com.google.common.collect.ImmutableList; -public abstract class BaseSparseEncodingIT extends BaseNeuralSearchIT{ +public abstract class BaseSparseEncodingIT extends BaseNeuralSearchIT { @SneakyThrows @Override protected String prepareModel() { String requestBody = Files.readString( - Path.of(classLoader.getResource("processor/UploadSparseEncodingModelRequestBody.json").toURI()) + Path.of(classLoader.getResource("processor/UploadSparseEncodingModelRequestBody.json").toURI()) ); String modelId = uploadModel(requestBody); loadModel(modelId); @@ -40,15 +42,10 @@ protected String prepareModel() { @SneakyThrows protected void prepareSparseEncodingIndex(String indexName, List sparseEncodingFieldNames) { - XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() - .startObject() - .startObject("mappings") - 
.startObject("properties"); - - for (String fieldName: sparseEncodingFieldNames) { - xContentBuilder.startObject(fieldName) - .field("type", "rank_features") - .endObject(); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject().startObject("mappings").startObject("properties"); + + for (String fieldName : sparseEncodingFieldNames) { + xContentBuilder.startObject(fieldName).field("type", "rank_features").endObject(); } xContentBuilder.endObject().endObject().endObject(); @@ -57,23 +54,18 @@ protected void prepareSparseEncodingIndex(String indexName, List sparseE } @SneakyThrows - protected void addSparseEncodingDoc( - String index, - String docId, - List fieldNames, - List> docs - ) { + protected void addSparseEncodingDoc(String index, String docId, List fieldNames, List> docs) { addSparseEncodingDoc(index, docId, fieldNames, docs, Collections.emptyList(), Collections.emptyList()); } @SneakyThrows protected void addSparseEncodingDoc( - String index, - String docId, - List fieldNames, - List> docs, - List textFieldNames, - List texts + String index, + String docId, + List fieldNames, + List> docs, + List textFieldNames, + List texts ) { Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); @@ -98,7 +90,7 @@ protected float computeExpectedScore(String modelId, Map tokenWei protected float computeExpectedScore(Map tokenWeightMap, Map queryTokens) { Float score = 0f; - for (Map.Entry entry: queryTokens.entrySet()) { + for (Map.Entry entry : queryTokens.entrySet()) { if (tokenWeightMap.containsKey(entry.getKey())) { score += entry.getValue() * getFeatureFieldCompressedNumber(tokenWeightMap.get(entry.getKey())); } @@ -109,18 +101,18 @@ protected float computeExpectedScore(Map tokenWeightMap, Map runSparseModelInference(String modelId, String queryText) { Response inferenceResponse = makeRequest( - client(), - "POST", - String.format(LOCALE, 
"/_plugins/_ml/models/%s/_predict", modelId), - null, - toHttpEntity(String.format(LOCALE, "{\"text_docs\": [\"%s\"]}", queryText)), - ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + client(), + "POST", + String.format(LOCALE, "/_plugins/_ml/models/%s/_predict", modelId), + null, + toHttpEntity(String.format(LOCALE, "{\"text_docs\": [\"%s\"]}", queryText)), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) ); Map inferenceResJson = XContentHelper.convertToMap( - XContentType.JSON.xContent(), - EntityUtils.toString(inferenceResponse.getEntity()), - false + XContentType.JSON.xContent(), + EntityUtils.toString(inferenceResponse.getEntity()), + false ); Object inference_results = inferenceResJson.get("inference_results"); diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index 7652c127b..a51c62977 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -23,7 +23,6 @@ import org.mockito.MockitoAnnotations; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.action.ActionListener; -import org.opensearch.index.IndexNotFoundException; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.MLOutput; @@ -164,7 +163,7 @@ public void testInferenceSentences_whenNotConnectionException_thenNoRetry() { } public void test_inferenceSentencesWithMapResult_whenValidInput_thenSuccess() { -// final List> map = List.of(Map.of("key", "value")); + // final List> map = List.of(Map.of("key", "value")); final Map map = Map.of("key", "value"); final ActionListener>> resultListener = mock(ActionListener.class); Mockito.doAnswer(invocation -> { @@ -194,7 +193,10 @@ public void 
test_inferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenE .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IllegalStateException.class); Mockito.verify(resultListener).onFailure(argumentCaptor.capture()); - assertEquals("Empty model result produced. Expected 1 tensor output and 1 model tensor, but got [0]", argumentCaptor.getValue().getMessage()); + assertEquals( + "Empty model result produced. Expected 1 tensor output and 1 model tensor, but got [0]", + argumentCaptor.getValue().getMessage() + ); Mockito.verifyNoMoreInteractions(resultListener); } @@ -215,7 +217,10 @@ public void test_inferenceSentencesWithMapResult_whenModelTensorListEmpty_thenEx .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IllegalStateException.class); Mockito.verify(resultListener).onFailure(argumentCaptor.capture()); - assertEquals("Empty model result produced. Expected 1 tensor output and 1 model tensor, but got [0]", argumentCaptor.getValue().getMessage()); + assertEquals( + "Empty model result produced. 
Expected 1 tensor output and 1 model tensor, but got [0]", + argumentCaptor.getValue().getMessage() + ); Mockito.verifyNoMoreInteractions(resultListener); } @@ -223,15 +228,7 @@ public void test_inferenceSentencesWithMapResult_whenModelTensorListSizeBiggerTh final ActionListener>> resultListener = mock(ActionListener.class); final List tensorsList = new ArrayList<>(); final List mlModelTensorList = new ArrayList<>(); - final ModelTensor tensor = new ModelTensor( - "response", - null, - null, - null, - null, - null, - Map.of("key", "value") - ); + final ModelTensor tensor = new ModelTensor("response", null, null, null, null, null, Map.of("key", "value")); mlModelTensorList.add(tensor); mlModelTensorList.add(tensor); tensorsList.add(new ModelTensors(mlModelTensorList)); @@ -244,8 +241,8 @@ public void test_inferenceSentencesWithMapResult_whenModelTensorListSizeBiggerTh accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); Mockito.verify(client) - .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); - Mockito.verify(resultListener).onResponse(List.of(Map.of("key","value"),Map.of("key","value"))); + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(resultListener).onResponse(List.of(Map.of("key", "value"), Map.of("key", "value"))); Mockito.verifyNoMoreInteractions(resultListener); } @@ -260,11 +257,7 @@ public void test_inferenceSentencesWithMapResult_whenRetryableException_retry3Ti return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); final ActionListener>> resultListener = mock(ActionListener.class); - accessor.inferenceSentencesWithMapResult( - TestCommonConstants.MODEL_ID, - TestCommonConstants.SENTENCES_LIST, - resultListener - ); + 
accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); Mockito.verify(client, times(4)) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -279,11 +272,7 @@ public void test_inferenceSentencesWithMapResult_whenNotRetryableException_thenF return null; }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); final ActionListener>> resultListener = mock(ActionListener.class); - accessor.inferenceSentencesWithMapResult( - TestCommonConstants.MODEL_ID, - TestCommonConstants.SENTENCES_LIST, - resultListener - ); + accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); Mockito.verify(client, times(1)) .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); @@ -311,15 +300,7 @@ private ModelTensorOutput createModelTensorOutput(final Float[] output) { private ModelTensorOutput createModelTensorOutput(final Map map) { final List tensorsList = new ArrayList<>(); final List mlModelTensorList = new ArrayList<>(); - final ModelTensor tensor = new ModelTensor( - "response", - null, - null, - null, - null, - null, - map - ); + final ModelTensor tensor = new ModelTensor("response", null, null, null, null, null, map); mlModelTensorList.add(tensor); final ModelTensors modelTensors = new ModelTensors(mlModelTensorList); tensorsList.add(modelTensors); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java index 611fe23a6..0312eaef7 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java @@ -5,8 +5,12 @@ package 
org.opensearch.neuralsearch.processor; -import com.google.common.collect.ImmutableList; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Map; + import lombok.SneakyThrows; + import org.apache.hc.core5.http.HttpHeaders; import org.apache.hc.core5.http.io.entity.EntityUtils; import org.apache.hc.core5.http.message.BasicHeader; @@ -17,9 +21,7 @@ import org.opensearch.common.xcontent.XContentType; import org.opensearch.neuralsearch.common.BaseSparseEncodingIT; -import java.nio.file.Files; -import java.nio.file.Path; -import java.util.Map; +import com.google.common.collect.ImmutableList; public class SparseEncodingProcessIT extends BaseSparseEncodingIT { @@ -53,38 +55,38 @@ public void testSparseEncodingProcessor() throws Exception { private void createSparseEncodingIndex() throws Exception { createIndexWithConfiguration( - INDEX_NAME, - Files.readString(Path.of(classLoader.getResource("processor/SparseEncodingIndexMappings.json").toURI())), - PIPELINE_NAME + INDEX_NAME, + Files.readString(Path.of(classLoader.getResource("processor/SparseEncodingIndexMappings.json").toURI())), + PIPELINE_NAME ); } private void ingestDocument() throws Exception { String ingestDocument = "{\n" - + " \"title\": \"This is a good day\",\n" - + " \"description\": \"daily logging\",\n" - + " \"favor_list\": [\n" - + " \"test\",\n" - + " \"hello\",\n" - + " \"mock\"\n" - + " ],\n" - + " \"favorites\": {\n" - + " \"game\": \"overwatch\",\n" - + " \"movie\": null\n" - + " }\n" - + "}\n"; + + " \"title\": \"This is a good day\",\n" + + " \"description\": \"daily logging\",\n" + + " \"favor_list\": [\n" + + " \"test\",\n" + + " \"hello\",\n" + + " \"mock\"\n" + + " ],\n" + + " \"favorites\": {\n" + + " \"game\": \"overwatch\",\n" + + " \"movie\": null\n" + + " }\n" + + "}\n"; Response response = makeRequest( - client(), - "POST", - INDEX_NAME + "/_doc?refresh", - null, - toHttpEntity(ingestDocument), - ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + 
client(), + "POST", + INDEX_NAME + "/_doc?refresh", + null, + toHttpEntity(ingestDocument), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) ); Map map = XContentHelper.convertToMap( - XContentType.JSON.xContent(), - EntityUtils.toString(response.getEntity()), - false + XContentType.JSON.xContent(), + EntityUtils.toString(response.getEntity()), + false ); assertEquals("created", map.get("result")); } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java index e79f286fe..209db58a8 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java @@ -5,9 +5,16 @@ package org.opensearch.neuralsearch.processor; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; +import static org.mockito.Mockito.verify; + +import java.util.*; +import java.util.function.BiConsumer; +import java.util.stream.IntStream; + import lombok.SneakyThrows; + import org.junit.Before; import org.mockito.InjectMocks; import org.mockito.Mock; @@ -19,16 +26,10 @@ import org.opensearch.ingest.Processor; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory; -import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory; import org.opensearch.test.OpenSearchTestCase; -import java.util.*; -import java.util.function.BiConsumer; -import java.util.stream.IntStream; - -import static org.mockito.ArgumentMatchers.*; -import static org.mockito.Mockito.*; -import static org.mockito.Mockito.verify; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; 
public class SparseEncodingProcessorTests extends OpenSearchTestCase { @Mock @@ -65,7 +66,7 @@ public void testExecute_successful() { IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); SparseEncodingProcessor processor = createInstance(); - List > dataAsMapList = createMockMapResult(2); + List> dataAsMapList = createMockMapResult(2); doAnswer(invocation -> { ActionListener>> listener = invocation.getArgument(2); listener.onResponse(dataAsMapList); @@ -104,7 +105,7 @@ public void testExecute_withListTypeInput_successful() { IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); SparseEncodingProcessor processor = createInstance(); - List > dataAsMapList = createMockMapResult(6); + List> dataAsMapList = createMockMapResult(6); doAnswer(invocation -> { ActionListener>> listener = invocation.getArgument(2); listener.onResponse(dataAsMapList); @@ -116,7 +117,6 @@ public void testExecute_withListTypeInput_successful() { verify(handler).accept(any(IngestDocument.class), isNull()); } - public void testExecute_MLClientAccessorThrowFail_handlerFailure() { Map sourceAndMetadata = new HashMap<>(); sourceAndMetadata.put("key1", "value1"); @@ -144,7 +144,7 @@ public void testExecute_withMapTypeInput_successful() { IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); SparseEncodingProcessor processor = createInstance(); - List > dataAsMapList = createMockMapResult(2); + List> dataAsMapList = createMockMapResult(2); doAnswer(invocation -> { ActionListener>> listener = invocation.getArgument(2); listener.onResponse(dataAsMapList); @@ -157,12 +157,9 @@ public void testExecute_withMapTypeInput_successful() { } - - private List > createMockMapResult(int number) - { + private List> createMockMapResult(int number) { List> mockSparseEncodingResult = new ArrayList<>(); - IntStream.range(0, number) - .forEachOrdered(x -> mockSparseEncodingResult.add(ImmutableMap.of("hello", 1.0f))); + 
IntStream.range(0, number).forEachOrdered(x -> mockSparseEncodingResult.add(ImmutableMap.of("hello", 1.0f))); List> mockMapResult = Collections.singletonList(Map.of("response", mockSparseEncodingResult)); return mockMapResult; diff --git a/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilderTests.java index 5f7ecb2d0..e01825d75 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilderTests.java @@ -17,7 +17,16 @@ import static org.opensearch.neuralsearch.query.sparse.SparseEncodingQueryBuilder.QUERY_TOKENS_FIELD; import static org.opensearch.neuralsearch.query.sparse.SparseEncodingQueryBuilder.TOKEN_SCORE_UPPER_BOUND_FIELD; +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.function.BiConsumer; +import java.util.function.Supplier; + import lombok.SneakyThrows; + import org.opensearch.client.Client; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.XContentFactory; @@ -35,19 +44,11 @@ import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.test.OpenSearchTestCase; -import java.io.IOException; -import java.util.List; -import java.util.Map; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.function.BiConsumer; -import java.util.function.Supplier; - public class SparseEncodingQueryBuilderTests extends OpenSearchTestCase { private static final String FIELD_NAME = "testField"; private static final String QUERY_TEXT = "Hello world!"; - private static final Map QUERY_TOKENS = Map.of("hello", 1.f, "world", 2.f); + private static final Map QUERY_TOKENS = Map.of("hello", 
1.f, "world", 2.f); private static final String MODEL_ID = "mfgfgdsfgfdgsde"; private static final Float TOKEN_SCORE_UPPER_BOUND = 123f; private static final float BOOST = 1.8f; @@ -65,12 +66,12 @@ public void testFromXContent_whenBuiltWithQueryText_thenBuildSuccessfully() { } */ XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() - .startObject() - .startObject(FIELD_NAME) - .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) - .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) - .endObject() - .endObject(); + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .endObject() + .endObject(); XContentParser contentParser = createParser(xContentBuilder); contentParser.nextToken(); @@ -93,11 +94,11 @@ public void testFromXContent_whenBuiltWithQueryTokens_thenBuildSuccessfully() { } */ XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() - .startObject() - .startObject(FIELD_NAME) - .field(QUERY_TOKENS_FIELD.getPreferredName(), QUERY_TOKENS) - .endObject() - .endObject(); + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TOKENS_FIELD.getPreferredName(), QUERY_TOKENS) + .endObject() + .endObject(); XContentParser contentParser = createParser(xContentBuilder); contentParser.nextToken(); @@ -121,15 +122,15 @@ public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() { } */ XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() - .startObject() - .startObject(FIELD_NAME) - .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) - .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) - .field(TOKEN_SCORE_UPPER_BOUND_FIELD.getPreferredName(), TOKEN_SCORE_UPPER_BOUND) - .field(BOOST_FIELD.getPreferredName(), BOOST) - .field(NAME_FIELD.getPreferredName(), QUERY_NAME) - .endObject() - .endObject(); + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + 
.field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(TOKEN_SCORE_UPPER_BOUND_FIELD.getPreferredName(), TOKEN_SCORE_UPPER_BOUND) + .field(BOOST_FIELD.getPreferredName(), BOOST) + .field(NAME_FIELD.getPreferredName(), QUERY_NAME) + .endObject() + .endObject(); XContentParser contentParser = createParser(xContentBuilder); contentParser.nextToken(); @@ -157,15 +158,15 @@ public void testFromXContent_whenBuildWithMultipleRootFields_thenFail() { } */ XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() - .startObject() - .startObject(FIELD_NAME) - .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) - .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) - .field(BOOST_FIELD.getPreferredName(), BOOST) - .field(NAME_FIELD.getPreferredName(), QUERY_NAME) - .endObject() - .field("invalid", 10) - .endObject(); + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(BOOST_FIELD.getPreferredName(), BOOST) + .field(NAME_FIELD.getPreferredName(), QUERY_NAME) + .endObject() + .field("invalid", 10) + .endObject(); XContentParser contentParser = createParser(xContentBuilder); contentParser.nextToken(); @@ -182,11 +183,11 @@ public void testFromXContent_whenBuildWithMissingQuery_thenFail() { } */ XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() - .startObject() - .startObject(FIELD_NAME) - .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) - .endObject() - .endObject(); + .startObject() + .startObject(FIELD_NAME) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .endObject() + .endObject(); XContentParser contentParser = createParser(xContentBuilder); contentParser.nextToken(); @@ -203,11 +204,11 @@ public void testFromXContent_whenBuildWithMissingModelId_thenFail() { } */ XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() - .startObject() - .startObject(FIELD_NAME) - .field(QUERY_TEXT_FIELD.getPreferredName(), 
QUERY_TEXT) - .endObject() - .endObject(); + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .endObject() + .endObject(); XContentParser contentParser = createParser(xContentBuilder); contentParser.nextToken(); @@ -227,14 +228,14 @@ public void testFromXContent_whenBuildWithDuplicateParameters_thenFail() { } */ XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() - .startObject() - .startObject(FIELD_NAME) - .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) - .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) - .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) - .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) - .endObject() - .endObject(); + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .endObject() + .endObject(); XContentParser contentParser = createParser(xContentBuilder); contentParser.nextToken(); @@ -245,10 +246,10 @@ public void testFromXContent_whenBuildWithDuplicateParameters_thenFail() { @SneakyThrows public void testToXContent() { SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(FIELD_NAME) - .modelId(MODEL_ID) - .queryText(QUERY_TEXT) - .queryTokens(QUERY_TOKENS) - .tokenScoreUpperBound(TOKEN_SCORE_UPPER_BOUND); + .modelId(MODEL_ID) + .queryText(QUERY_TEXT) + .queryTokens(QUERY_TOKENS) + .tokenScoreUpperBound(TOKEN_SCORE_UPPER_BOUND); XContentBuilder builder = XContentFactory.jsonBuilder(); builder = sparseEncodingQueryBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -276,13 +277,13 @@ public void testToXContent() { // QUERY_TOKENS is map, the converted one use Map convertedQueryTokensMap = (Map) secondInnerMap.get(QUERY_TOKENS_FIELD.getPreferredName()); assertEquals(QUERY_TOKENS.size(), 
convertedQueryTokensMap.size()); - for (Map.Entry entry: QUERY_TOKENS.entrySet()) { + for (Map.Entry entry : QUERY_TOKENS.entrySet()) { assertEquals(entry.getValue(), convertedQueryTokensMap.get(entry.getKey()).floatValue(), 0); } assertEquals( - TOKEN_SCORE_UPPER_BOUND, - ((Double) secondInnerMap.get(TOKEN_SCORE_UPPER_BOUND_FIELD.getPreferredName())).floatValue(), - 0 + TOKEN_SCORE_UPPER_BOUND, + ((Double) secondInnerMap.get(TOKEN_SCORE_UPPER_BOUND_FIELD.getPreferredName())).floatValue(), + 0 ); } @@ -301,10 +302,10 @@ public void testStreams() { original.writeTo(streamOutput); FilterStreamInput filterStreamInput = new NamedWriteableAwareStreamInput( - streamOutput.bytes().streamInput(), - new NamedWriteableRegistry( - List.of(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new)) - ) + streamOutput.bytes().streamInput(), + new NamedWriteableRegistry( + List.of(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new)) + ) ); SparseEncodingQueryBuilder copy = new SparseEncodingQueryBuilder(filterStreamInput); @@ -327,102 +328,92 @@ public void testHashAndEquals() { float tokenScoreUpperBound1 = 1f; float tokenScoreUpperBound2 = 2f; - SparseEncodingQueryBuilder sparseEncodingQueryBuilder_baseline = new SparseEncodingQueryBuilder() - .fieldName(fieldName1) - .queryText(queryText1) - .queryTokens(queryTokens1) - .modelId(modelId1) - .boost(boost1) - .queryName(queryName1) - .tokenScoreUpperBound(tokenScoreUpperBound1); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder_baseline = new SparseEncodingQueryBuilder().fieldName(fieldName1) + .queryText(queryText1) + .queryTokens(queryTokens1) + .modelId(modelId1) + .boost(boost1) + .queryName(queryName1) + .tokenScoreUpperBound(tokenScoreUpperBound1); // Identical to sparseEncodingQueryBuilder_baseline - SparseEncodingQueryBuilder sparseEncodingQueryBuilder_baselineCopy = new SparseEncodingQueryBuilder() - 
.fieldName(fieldName1) - .queryText(queryText1) - .queryTokens(queryTokens1) - .modelId(modelId1) - .boost(boost1) - .queryName(queryName1) - .tokenScoreUpperBound(tokenScoreUpperBound1); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder_baselineCopy = new SparseEncodingQueryBuilder().fieldName(fieldName1) + .queryText(queryText1) + .queryTokens(queryTokens1) + .modelId(modelId1) + .boost(boost1) + .queryName(queryName1) + .tokenScoreUpperBound(tokenScoreUpperBound1); // Identical to sparseEncodingQueryBuilder_baseline except default boost and query name - SparseEncodingQueryBuilder sparseEncodingQueryBuilder_defaultBoostAndQueryName = new SparseEncodingQueryBuilder() - .fieldName(fieldName1) - .queryText(queryText1) - .queryTokens(queryTokens1) - .modelId(modelId1) - .tokenScoreUpperBound(tokenScoreUpperBound1); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder_defaultBoostAndQueryName = new SparseEncodingQueryBuilder().fieldName( + fieldName1 + ).queryText(queryText1).queryTokens(queryTokens1).modelId(modelId1).tokenScoreUpperBound(tokenScoreUpperBound1); // Identical to sparseEncodingQueryBuilder_baseline except diff field name - SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffFieldName = new SparseEncodingQueryBuilder() - .fieldName(fieldName2) - .queryText(queryText1) - .queryTokens(queryTokens1) - .modelId(modelId1) - .boost(boost1) - .queryName(queryName1) - .tokenScoreUpperBound(tokenScoreUpperBound1); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffFieldName = new SparseEncodingQueryBuilder().fieldName(fieldName2) + .queryText(queryText1) + .queryTokens(queryTokens1) + .modelId(modelId1) + .boost(boost1) + .queryName(queryName1) + .tokenScoreUpperBound(tokenScoreUpperBound1); // Identical to sparseEncodingQueryBuilder_baseline except diff query text - SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffQueryText = new SparseEncodingQueryBuilder() - .fieldName(fieldName1) - .queryText(queryText2) - 
.queryTokens(queryTokens1) - .modelId(modelId1) - .boost(boost1) - .queryName(queryName1) - .tokenScoreUpperBound(tokenScoreUpperBound1); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffQueryText = new SparseEncodingQueryBuilder().fieldName(fieldName1) + .queryText(queryText2) + .queryTokens(queryTokens1) + .modelId(modelId1) + .boost(boost1) + .queryName(queryName1) + .tokenScoreUpperBound(tokenScoreUpperBound1); // Identical to sparseEncodingQueryBuilder_baseline except diff model ID - SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffModelId = new SparseEncodingQueryBuilder() - .fieldName(fieldName1) - .queryText(queryText1) - .queryTokens(queryTokens1) - .modelId(modelId2) - .boost(boost1) - .queryName(queryName1) - .tokenScoreUpperBound(tokenScoreUpperBound1); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffModelId = new SparseEncodingQueryBuilder().fieldName(fieldName1) + .queryText(queryText1) + .queryTokens(queryTokens1) + .modelId(modelId2) + .boost(boost1) + .queryName(queryName1) + .tokenScoreUpperBound(tokenScoreUpperBound1); // Identical to sparseEncodingQueryBuilder_baseline except diff query tokens - SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffQueryTokens = new SparseEncodingQueryBuilder() - .fieldName(fieldName1) - .queryText(queryText1) - .queryTokens(queryTokens2) - .modelId(modelId1) - .boost(boost1) - .queryName(queryName1) - .tokenScoreUpperBound(tokenScoreUpperBound1); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffQueryTokens = new SparseEncodingQueryBuilder().fieldName(fieldName1) + .queryText(queryText1) + .queryTokens(queryTokens2) + .modelId(modelId1) + .boost(boost1) + .queryName(queryName1) + .tokenScoreUpperBound(tokenScoreUpperBound1); // Identical to sparseEncodingQueryBuilder_baseline except diff boost - SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffBoost = new SparseEncodingQueryBuilder() - .fieldName(fieldName1) - .queryText(queryText1) - 
.queryTokens(queryTokens1) - .modelId(modelId1) - .boost(boost2) - .queryName(queryName1) - .tokenScoreUpperBound(tokenScoreUpperBound1); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffBoost = new SparseEncodingQueryBuilder().fieldName(fieldName1) + .queryText(queryText1) + .queryTokens(queryTokens1) + .modelId(modelId1) + .boost(boost2) + .queryName(queryName1) + .tokenScoreUpperBound(tokenScoreUpperBound1); // Identical to sparseEncodingQueryBuilder_baseline except diff query name - SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffQueryName = new SparseEncodingQueryBuilder() - .fieldName(fieldName1) - .queryText(queryText1) - .queryTokens(queryTokens1) - .modelId(modelId1) - .boost(boost1) - .queryName(queryName2) - .tokenScoreUpperBound(tokenScoreUpperBound1); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffQueryName = new SparseEncodingQueryBuilder().fieldName(fieldName1) + .queryText(queryText1) + .queryTokens(queryTokens1) + .modelId(modelId1) + .boost(boost1) + .queryName(queryName2) + .tokenScoreUpperBound(tokenScoreUpperBound1); // Identical to sparseEncodingQueryBuilder_baseline except diff token_score_upper_bound - SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffTokenScoreUpperBound = new SparseEncodingQueryBuilder() - .fieldName(fieldName1) - .queryText(queryText1) - .queryTokens(queryTokens1) - .modelId(modelId1) - .boost(boost1) - .queryName(queryName1) - .tokenScoreUpperBound(tokenScoreUpperBound2); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffTokenScoreUpperBound = new SparseEncodingQueryBuilder().fieldName( + fieldName1 + ) + .queryText(queryText1) + .queryTokens(queryTokens1) + .modelId(modelId1) + .boost(boost1) + .queryName(queryName1) + .tokenScoreUpperBound(tokenScoreUpperBound2); assertEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_baseline); assertEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_baseline.hashCode()); @@ -457,21 
+448,19 @@ public void testHashAndEquals() { @SneakyThrows public void testRewrite_whenQueryTokensNotNull_thenRewriteToSelf() { - SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder() - .queryTokens(QUERY_TOKENS) - .fieldName(FIELD_NAME) - .queryText(QUERY_TEXT) - .modelId(MODEL_ID); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().queryTokens(QUERY_TOKENS) + .fieldName(FIELD_NAME) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID); QueryBuilder queryBuilder = sparseEncodingQueryBuilder.doRewrite(null); assert queryBuilder == sparseEncodingQueryBuilder; } @SneakyThrows public void testRewrite_whenQueryTokensSupplierNull_thenSetQueryTokensSupplier() { - SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder() - .fieldName(FIELD_NAME) - .queryText(QUERY_TEXT) - .modelId(MODEL_ID); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(FIELD_NAME) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID); Map expectedMap = Map.of("1", 1f, "2", 2f); MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class); doAnswer(invocation -> { @@ -486,11 +475,11 @@ public void testRewrite_whenQueryTokensSupplierNull_thenSetQueryTokensSupplier() doAnswer(invocation -> { BiConsumer> biConsumer = invocation.getArgument(0); biConsumer.accept( - null, - ActionListener.wrap( - response -> inProgressLatch.countDown(), - err -> fail("Failed to set query tokens supplier: " + err.getMessage()) - ) + null, + ActionListener.wrap( + response -> inProgressLatch.countDown(), + err -> fail("Failed to set query tokens supplier: " + err.getMessage()) + ) ); return null; }).when(queryRewriteContext).registerAsyncAction(any()); @@ -505,26 +494,24 @@ public void testRewrite_whenQueryTokensSupplierNull_thenSetQueryTokensSupplier() public void testRewrite_whenSupplierContentNull_thenReturnCopy() { Supplier> nullSupplier = () -> null; 
SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(FIELD_NAME) - .queryText(QUERY_TEXT) - .modelId(MODEL_ID) - .queryTokensSupplier(nullSupplier); + .queryText(QUERY_TEXT) + .modelId(MODEL_ID) + .queryTokensSupplier(nullSupplier); QueryBuilder queryBuilder = sparseEncodingQueryBuilder.doRewrite(null); assertEquals(sparseEncodingQueryBuilder, queryBuilder); } @SneakyThrows public void testRewrite_whenQueryTokensSupplierSet_thenSetQueryTokens() { - SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder() - .fieldName(FIELD_NAME) - .queryText(QUERY_TEXT) - .modelId(MODEL_ID) - .queryTokensSupplier(QUERY_TOKENS_SUPPLIER); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(FIELD_NAME) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID) + .queryTokensSupplier(QUERY_TOKENS_SUPPLIER); QueryBuilder queryBuilder = sparseEncodingQueryBuilder.doRewrite(null); - SparseEncodingQueryBuilder targetQueryBuilder = new SparseEncodingQueryBuilder() - .fieldName(FIELD_NAME) - .queryText(QUERY_TEXT) - .modelId(MODEL_ID) - .queryTokens(QUERY_TOKENS_SUPPLIER.get()); + SparseEncodingQueryBuilder targetQueryBuilder = new SparseEncodingQueryBuilder().fieldName(FIELD_NAME) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID) + .queryTokens(QUERY_TOKENS_SUPPLIER.get()); assertEquals(queryBuilder, targetQueryBuilder); } } diff --git a/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryIT.java index 156a997d7..9c19226ea 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryIT.java @@ -5,7 +5,13 @@ package org.opensearch.neuralsearch.query.sparse; +import static org.opensearch.neuralsearch.TestUtils.objectToFloat; + +import java.util.List; +import 
java.util.Map; + import lombok.SneakyThrows; + import org.junit.After; import org.junit.Before; import org.opensearch.index.query.BoolQueryBuilder; @@ -14,11 +20,6 @@ import org.opensearch.neuralsearch.TestUtils; import org.opensearch.neuralsearch.common.BaseSparseEncodingIT; -import java.util.List; -import java.util.Map; - -import static org.opensearch.neuralsearch.TestUtils.objectToFloat; - public class SparseEncodingQueryIT extends BaseSparseEncodingIT { private static final String TEST_BASIC_INDEX_NAME = "test-sparse-basic-index"; private static final String TEST_MULTI_SPARSE_ENCODING_FIELD_INDEX_NAME = "test-sparse-multi-field-index"; @@ -64,10 +65,9 @@ public void tearDown() { public void testBasicQueryUsingQueryText() { initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); String modelId = getDeployedModelId(); - SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder() - .fieldName(TEST_SPARSE_ENCODING_FIELD_NAME_1) - .queryText(TEST_QUERY_TEXT) - .modelId(modelId); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName( + TEST_SPARSE_ENCODING_FIELD_NAME_1 + ).queryText(TEST_QUERY_TEXT).modelId(modelId); Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1); Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); @@ -95,10 +95,10 @@ public void testBasicQueryUsingQueryText() { @SneakyThrows public void testBasicQueryUsingQueryTokens() { initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); - Map queryTokens = TestUtils.createRandomTokenWeightMap(List.of("hello","a","b")); - SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder() - .fieldName(TEST_SPARSE_ENCODING_FIELD_NAME_1) - .queryTokens(queryTokens); + Map queryTokens = TestUtils.createRandomTokenWeightMap(List.of("hello", "a", "b")); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName( + 
TEST_SPARSE_ENCODING_FIELD_NAME_1 + ).queryTokens(queryTokens); Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1); Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); @@ -125,11 +125,9 @@ public void testBasicQueryUsingQueryTokens() { public void testBoostQuery() { initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); String modelId = getDeployedModelId(); - SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder() - .fieldName(TEST_SPARSE_ENCODING_FIELD_NAME_1) - .queryText(TEST_QUERY_TEXT) - .modelId(modelId) - .boost(2.0f); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName( + TEST_SPARSE_ENCODING_FIELD_NAME_1 + ).queryText(TEST_QUERY_TEXT).modelId(modelId).boost(2.0f); Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1); Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); @@ -163,16 +161,10 @@ public void testRescoreQuery() { initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); String modelId = getDeployedModelId(); MatchAllQueryBuilder matchAllQueryBuilder = new MatchAllQueryBuilder(); - SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder() - .fieldName(TEST_SPARSE_ENCODING_FIELD_NAME_1) - .queryText(TEST_QUERY_TEXT) - .modelId(modelId); - Map searchResponseAsMap = search( - TEST_BASIC_INDEX_NAME, - matchAllQueryBuilder, - sparseEncodingQueryBuilder, - 1 - ); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName( + TEST_SPARSE_ENCODING_FIELD_NAME_1 + ).queryText(TEST_QUERY_TEXT).modelId(modelId); + Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, matchAllQueryBuilder, sparseEncodingQueryBuilder, 1); Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); assertEquals("1", firstInnerHit.get("_id")); @@ -209,14 +201,12 @@ public void testBooleanQuery_withMultipleSparseEncodingQueries() { String modelId = 
getDeployedModelId(); BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); - SparseEncodingQueryBuilder sparseEncodingQueryBuilder1 = new SparseEncodingQueryBuilder() - .fieldName(TEST_SPARSE_ENCODING_FIELD_NAME_1) - .queryText(TEST_QUERY_TEXT) - .modelId(modelId); - SparseEncodingQueryBuilder sparseEncodingQueryBuilder2 = new SparseEncodingQueryBuilder() - .fieldName(TEST_SPARSE_ENCODING_FIELD_NAME_2) - .queryText(TEST_QUERY_TEXT) - .modelId(modelId); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder1 = new SparseEncodingQueryBuilder().fieldName( + TEST_SPARSE_ENCODING_FIELD_NAME_1 + ).queryText(TEST_QUERY_TEXT).modelId(modelId); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder2 = new SparseEncodingQueryBuilder().fieldName( + TEST_SPARSE_ENCODING_FIELD_NAME_2 + ).queryText(TEST_QUERY_TEXT).modelId(modelId); boolQueryBuilder.should(sparseEncodingQueryBuilder1).should(sparseEncodingQueryBuilder2); @@ -257,10 +247,9 @@ public void testBooleanQuery_withSparseEncodingAndBM25Queries() { String modelId = getDeployedModelId(); BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); - SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder() - .fieldName(TEST_SPARSE_ENCODING_FIELD_NAME_1) - .queryText(TEST_QUERY_TEXT) - .modelId(modelId); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName( + TEST_SPARSE_ENCODING_FIELD_NAME_1 + ).queryText(TEST_QUERY_TEXT).modelId(modelId); MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT); boolQueryBuilder.should(sparseEncodingQueryBuilder).should(matchQueryBuilder); @@ -275,60 +264,38 @@ public void testBooleanQuery_withSparseEncodingAndBM25Queries() { @SneakyThrows protected void initializeIndexIfNotExist(String indexName) { if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(indexName)) { - prepareSparseEncodingIndex( - indexName, - List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1) 
- ); - addSparseEncodingDoc( - indexName, - "1", - List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1), - List.of(testTokenWeightMap) - ); + prepareSparseEncodingIndex(indexName, List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1)); + addSparseEncodingDoc(indexName, "1", List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1), List.of(testTokenWeightMap)); assertEquals(1, getDocCount(indexName)); } if (TEST_MULTI_SPARSE_ENCODING_FIELD_INDEX_NAME.equals(indexName) && !indexExists(indexName)) { - prepareSparseEncodingIndex( - indexName, - List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1, TEST_SPARSE_ENCODING_FIELD_NAME_2) - ); + prepareSparseEncodingIndex(indexName, List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1, TEST_SPARSE_ENCODING_FIELD_NAME_2)); addSparseEncodingDoc( - indexName, - "1", - List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1, TEST_SPARSE_ENCODING_FIELD_NAME_2), - List.of(testTokenWeightMap, testTokenWeightMap) + indexName, + "1", + List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1, TEST_SPARSE_ENCODING_FIELD_NAME_2), + List.of(testTokenWeightMap, testTokenWeightMap) ); assertEquals(1, getDocCount(indexName)); } if (TEST_TEXT_AND_SPARSE_ENCODING_FIELD_INDEX_NAME.equals(indexName) && !indexExists(indexName)) { - prepareSparseEncodingIndex( - indexName, - List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1) - ); + prepareSparseEncodingIndex(indexName, List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1)); addSparseEncodingDoc( - indexName, - "1", - List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1), - List.of(testTokenWeightMap), - List.of(TEST_TEXT_FIELD_NAME_1), - List.of(TEST_QUERY_TEXT) + indexName, + "1", + List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1), + List.of(testTokenWeightMap), + List.of(TEST_TEXT_FIELD_NAME_1), + List.of(TEST_QUERY_TEXT) ); assertEquals(1, getDocCount(indexName)); } if (TEST_NESTED_INDEX_NAME.equals(indexName) && !indexExists(indexName)) { - prepareSparseEncodingIndex( - indexName, - List.of(TEST_SPARSE_ENCODING_FIELD_NAME_NESTED) - ); - addSparseEncodingDoc( - indexName, - "1", - 
List.of(TEST_SPARSE_ENCODING_FIELD_NAME_NESTED), - List.of(testTokenWeightMap) - ); + prepareSparseEncodingIndex(indexName, List.of(TEST_SPARSE_ENCODING_FIELD_NAME_NESTED)); + addSparseEncodingDoc(indexName, "1", List.of(TEST_SPARSE_ENCODING_FIELD_NAME_NESTED), List.of(testTokenWeightMap)); assertEquals(1, getDocCount(TEST_NESTED_INDEX_NAME)); } } diff --git a/src/test/java/org/opensearch/neuralsearch/util/TokenWeightUtilTests.java b/src/test/java/org/opensearch/neuralsearch/util/TokenWeightUtilTests.java index 6fa2a9d22..a4bc2c495 100644 --- a/src/test/java/org/opensearch/neuralsearch/util/TokenWeightUtilTests.java +++ b/src/test/java/org/opensearch/neuralsearch/util/TokenWeightUtilTests.java @@ -5,13 +5,13 @@ package org.opensearch.neuralsearch.util; -import org.opensearch.test.OpenSearchTestCase; - import java.util.List; import java.util.Map; +import org.opensearch.test.OpenSearchTestCase; + public class TokenWeightUtilTests extends OpenSearchTestCase { - private static final Map MOCK_DATA = Map.of("hello", 1.f, "world", 2.f); + private static final Map MOCK_DATA = Map.of("hello", 1.f, "world", 2.f); public void testFetchListOfTokenWeightMap_singleObject() { /* @@ -22,10 +22,7 @@ public void testFetchListOfTokenWeightMap_singleObject() { }] */ List> inputData = List.of(Map.of("response", List.of(MOCK_DATA))); - assertEquals( - TokenWeightUtil.fetchListOfTokenWeightMap(inputData), - List.of(MOCK_DATA) - ); + assertEquals(TokenWeightUtil.fetchListOfTokenWeightMap(inputData), List.of(MOCK_DATA)); } public void testFetchListOfTokenWeightMap_multipleObjectsInOneResponse() { @@ -38,10 +35,7 @@ public void testFetchListOfTokenWeightMap_multipleObjectsInOneResponse() { }] */ List> inputData = List.of(Map.of("response", List.of(MOCK_DATA, MOCK_DATA))); - assertEquals( - TokenWeightUtil.fetchListOfTokenWeightMap(inputData), - List.of(MOCK_DATA, MOCK_DATA) - ); + assertEquals(TokenWeightUtil.fetchListOfTokenWeightMap(inputData), List.of(MOCK_DATA, MOCK_DATA)); } public 
void testFetchListOfTokenWeightMap_multipleObjectsInMultipleResponse() { @@ -56,14 +50,8 @@ public void testFetchListOfTokenWeightMap_multipleObjectsInMultipleResponse() { ] }] */ - List> inputData = List.of( - Map.of("response", List.of(MOCK_DATA)), - Map.of("response", List.of(MOCK_DATA)) - ); - assertEquals( - TokenWeightUtil.fetchListOfTokenWeightMap(inputData), - List.of(MOCK_DATA, MOCK_DATA) - ); + List> inputData = List.of(Map.of("response", List.of(MOCK_DATA)), Map.of("response", List.of(MOCK_DATA))); + assertEquals(TokenWeightUtil.fetchListOfTokenWeightMap(inputData), List.of(MOCK_DATA, MOCK_DATA)); } public void testFetchListOfTokenWeightMap_whenResponseValueNotList_thenFail() { @@ -95,26 +83,26 @@ public void testFetchListOfTokenWeightMap_whenInputObjectIsNotMap_thenFail() { List> inputData = List.of(Map.of("response", List.of(List.of(MOCK_DATA)))); expectThrows(IllegalArgumentException.class, () -> TokenWeightUtil.fetchListOfTokenWeightMap(inputData)); } - + public void testFetchListOfTokenWeightMap_whenInputTokenMapWithNonStringKeys_thenFail() { /* [{ - "response": [[{"hello": 1.0, 2.3: 2.0}]] + "response": [{"hello": 1.0, 2.3: 2.0}] }] */ - Map MOCK_DATA = Map.of("hello", 1.f, 2.3f, 2.f); - List> inputData = List.of(Map.of("response", List.of(MOCK_DATA))); + Map mockData = Map.of("hello", 1.f, 2.3f, 2.f); + List> inputData = List.of(Map.of("response", List.of(mockData))); expectThrows(IllegalArgumentException.class, () -> TokenWeightUtil.fetchListOfTokenWeightMap(inputData)); } public void testFetchListOfTokenWeightMap_whenInputTokenMapWithNonFloatValues_thenFail() { /* [{ - "response": [[{"hello": 1.0, 2.3: 2.0}]] + "response": [{"hello": 1.0, "world": "world"}] }] */ - Map MOCK_DATA = Map.of("hello", 1.f, "world", "world"); - List> inputData = List.of(Map.of("response", List.of(MOCK_DATA))); + Map mockData = Map.of("hello", 1.f, "world", "world"); + List> inputData = List.of(Map.of("response", List.of(mockData))); 
expectThrows(IllegalArgumentException.class, () -> TokenWeightUtil.fetchListOfTokenWeightMap(inputData)); } } From 473c68d655f03038e80f0aa2ce14626c429096f2 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 20 Sep 2023 14:47:41 +0800 Subject: [PATCH 44/70] remove BoundedLinearQuery and TokenScoreUpperBound Signed-off-by: zhichao-aws --- .../sparse/BoundedLinearFeatureQuery.java | 217 ------------------ .../sparse/SparseEncodingQueryBuilder.java | 31 +-- .../SparseEncodingQueryBuilderTests.java | 56 +---- 3 files changed, 17 insertions(+), 287 deletions(-) delete mode 100644 src/main/java/org/opensearch/neuralsearch/query/sparse/BoundedLinearFeatureQuery.java diff --git a/src/main/java/org/opensearch/neuralsearch/query/sparse/BoundedLinearFeatureQuery.java b/src/main/java/org/opensearch/neuralsearch/query/sparse/BoundedLinearFeatureQuery.java deleted file mode 100644 index 5ce291c25..000000000 --- a/src/main/java/org/opensearch/neuralsearch/query/sparse/BoundedLinearFeatureQuery.java +++ /dev/null @@ -1,217 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -/* - * This class is built based on lucene FeatureQuery. We use LinearFuntion to - * build the query and add an upperbound to it. - */ - -package org.opensearch.neuralsearch.query.sparse; - -import java.io.IOException; -import java.util.Objects; - -import org.apache.lucene.index.ImpactsEnum; -import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.PostingsEnum; -import org.apache.lucene.index.Term; -import org.apache.lucene.index.Terms; -import org.apache.lucene.index.TermsEnum; -import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.search.Explanation; -import org.apache.lucene.search.ImpactsDISI; -import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.Query; -import org.apache.lucene.search.QueryVisitor; -import org.apache.lucene.search.ScoreMode; -import org.apache.lucene.search.Scorer; -import org.apache.lucene.search.TermQuery; -import org.apache.lucene.search.Weight; -import org.apache.lucene.search.similarities.Similarity.SimScorer; -import org.apache.lucene.util.BytesRef; - -public final class BoundedLinearFeatureQuery extends Query { - - private final String fieldName; - private final String featureName; - private final Float scoreUpperBound; - - public BoundedLinearFeatureQuery(String fieldName, String featureName, Float scoreUpperBound) { - this.fieldName = Objects.requireNonNull(fieldName); - this.featureName = Objects.requireNonNull(featureName); - this.scoreUpperBound = Objects.requireNonNull(scoreUpperBound); - } - - @Override - public Query rewrite(IndexSearcher indexSearcher) throws IOException { - return super.rewrite(indexSearcher); - } - - @Override - public boolean equals(Object obj) { - if (obj == null || getClass() != obj.getClass()) { - return false; - } - 
BoundedLinearFeatureQuery that = (BoundedLinearFeatureQuery) obj; - return Objects.equals(fieldName, that.fieldName) - && Objects.equals(featureName, that.featureName) - && Objects.equals(scoreUpperBound, that.scoreUpperBound); - } - - @Override - public int hashCode() { - int h = getClass().hashCode(); - h = 31 * h + fieldName.hashCode(); - h = 31 * h + featureName.hashCode(); - h = 31 * h + scoreUpperBound.hashCode(); - return h; - } - - @Override - public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { - if (!scoreMode.needsScores()) { - // We don't need scores (e.g. for faceting), and since features are stored as terms, - // allow TermQuery to optimize in this case - TermQuery tq = new TermQuery(new Term(fieldName, featureName)); - return searcher.rewrite(tq).createWeight(searcher, scoreMode, boost); - } - - return new Weight(this) { - - @Override - public boolean isCacheable(LeafReaderContext ctx) { - return true; - } - - @Override - public Explanation explain(LeafReaderContext context, int doc) throws IOException { - String desc = "weight(" + getQuery() + " in " + doc + ") [\" BoundedLinearFeatureQuery \"]"; - - Terms terms = context.reader().terms(fieldName); - if (terms == null) { - return Explanation.noMatch(desc + ". Field " + fieldName + " doesn't exist."); - } - TermsEnum termsEnum = terms.iterator(); - if (termsEnum.seekExact(new BytesRef(featureName)) == false) { - return Explanation.noMatch(desc + ". Feature " + featureName + " doesn't exist."); - } - - PostingsEnum postings = termsEnum.postings(null, PostingsEnum.FREQS); - if (postings.advance(doc) != doc) { - return Explanation.noMatch(desc + ". 
Feature " + featureName + " isn't set."); - } - - int freq = postings.freq(); - float featureValue = decodeFeatureValue(freq); - float score = boost * featureValue; - return Explanation.match( - score, - "Linear function on the " + fieldName + " field for the " + featureName + " feature, computed as w * S from:", - Explanation.match(boost, "w, weight of this function"), - Explanation.match(featureValue, "S, feature value") - ); - } - - @Override - public Scorer scorer(LeafReaderContext context) throws IOException { - Terms terms = Terms.getTerms(context.reader(), fieldName); - TermsEnum termsEnum = terms.iterator(); - if (termsEnum.seekExact(new BytesRef(featureName)) == false) { - return null; - } - - final SimScorer scorer = new SimScorer() { - @Override - public float score(float freq, long norm) { - return boost * decodeFeatureValue(freq); - } - }; - final ImpactsEnum impacts = termsEnum.impacts(PostingsEnum.FREQS); - final ImpactsDISI impactsDisi = new ImpactsDISI(impacts, impacts, scorer); - - return new Scorer(this) { - - @Override - public int docID() { - return impacts.docID(); - } - - @Override - public float score() throws IOException { - return scorer.score(impacts.freq(), 1L); - } - - @Override - public DocIdSetIterator iterator() { - return impactsDisi; - } - - @Override - public int advanceShallow(int target) throws IOException { - return impactsDisi.advanceShallow(target); - } - - @Override - public float getMaxScore(int upTo) throws IOException { - return impactsDisi.getMaxScore(upTo); - } - - @Override - public void setMinCompetitiveScore(float minScore) { - impactsDisi.setMinCompetitiveScore(minScore); - } - }; - } - }; - } - - @Override - public void visit(QueryVisitor visitor) { - if (visitor.acceptField(fieldName)) { - visitor.visitLeaf(this); - } - } - - @Override - public String toString(String field) { - return "BoundedLinearFeatureQuery(field=" + fieldName + ", feature=" + featureName + ", scoreUpperBound=" + scoreUpperBound + ")"; - } - - 
static final int MAX_FREQ = Float.floatToIntBits(Float.MAX_VALUE) >>> 15; - - private float decodeFeatureValue(float freq) { - if (freq > MAX_FREQ) { - return scoreUpperBound; - } - int tf = (int) freq; // lossless - int featureBits = tf << 15; - return Math.min(Float.intBitsToFloat(featureBits), scoreUpperBound); - } -} diff --git a/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilder.java index 957e3477a..812f255f2 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilder.java @@ -20,9 +20,9 @@ import org.apache.commons.lang.builder.EqualsBuilder; import org.apache.commons.lang.builder.HashCodeBuilder; +import org.apache.lucene.document.FeatureField; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; -import org.apache.lucene.search.BoostQuery; import org.apache.lucene.search.Query; import org.opensearch.common.SetOnce; import org.opensearch.core.ParseField; @@ -56,8 +56,6 @@ public class SparseEncodingQueryBuilder extends AbstractQueryBuilder queryTokens; private String queryText; private String modelId; - private Float tokenScoreUpperBound; private Supplier> queryTokensSupplier; public SparseEncodingQueryBuilder(StreamInput in) throws IOException { @@ -81,7 +78,6 @@ public SparseEncodingQueryBuilder(StreamInput in) throws IOException { } this.queryText = in.readOptionalString(); this.modelId = in.readOptionalString(); - this.tokenScoreUpperBound = in.readOptionalFloat(); } @Override @@ -95,7 +91,6 @@ protected void doWriteTo(StreamOutput out) throws IOException { } out.writeOptionalString(queryText); out.writeOptionalString(modelId); - out.writeOptionalFloat(tokenScoreUpperBound); } @Override @@ -105,7 +100,6 @@ protected void doXContent(XContentBuilder xContentBuilder, Params 
params) throws if (null != queryTokens) xContentBuilder.field(QUERY_TOKENS_FIELD.getPreferredName(), queryTokens); if (null != queryText) xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), queryText); if (null != modelId) xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId); - if (null != tokenScoreUpperBound) xContentBuilder.field(TOKEN_SCORE_UPPER_BOUND_FIELD.getPreferredName(), tokenScoreUpperBound); printBoostAndQueryName(xContentBuilder); xContentBuilder.endObject(); xContentBuilder.endObject(); @@ -119,15 +113,13 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws * "token_a": float, * "token_b": float, * ... - * }, - * "token_score_upper_bound": float (optional) + * } * } * } * or * "SAMPLE_FIELD": { * "query_text": "string", - * "model_id": "string", - * "token_score_upper_bound": float (optional) + * "model_id": "string" * } * */ @@ -193,8 +185,6 @@ private static void parseQueryParams(XContentParser parser, SparseEncodingQueryB sparseEncodingQueryBuilder.queryText(parser.text()); } else if (MODEL_ID_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { sparseEncodingQueryBuilder.modelId(parser.text()); - } else if (TOKEN_SCORE_UPPER_BOUND_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - sparseEncodingQueryBuilder.tokenScoreUpperBound(parser.floatValue()); } else { throw new ParsingException( parser.getTokenLocation(), @@ -226,8 +216,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws : new SparseEncodingQueryBuilder().fieldName(fieldName) .queryTokens(queryTokensSupplier.get()) .queryText(queryText) - .modelId(modelId) - .tokenScoreUpperBound(tokenScoreUpperBound); + .modelId(modelId); } validateForRewrite(queryText, modelId); @@ -245,7 +234,6 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws return new SparseEncodingQueryBuilder().fieldName(fieldName) .queryText(queryText) .modelId(modelId) - 
.tokenScoreUpperBound(tokenScoreUpperBound) .queryTokensSupplier(queryTokensSetOnce::get); } @@ -255,14 +243,11 @@ protected Query doToQuery(QueryShardContext context) throws IOException { validateFieldType(ft); validateQueryTokens(queryTokens); - // the tokenScoreUpperBound from query has higher priority - final Float scoreUpperBound = null != tokenScoreUpperBound ? tokenScoreUpperBound : Float.MAX_VALUE; - BooleanQuery.Builder builder = new BooleanQuery.Builder(); for (Map.Entry entry : queryTokens.entrySet()) { builder.add( - new BoostQuery(new BoundedLinearFeatureQuery(fieldName, entry.getKey(), scoreUpperBound), entry.getValue()), - BooleanClause.Occur.SHOULD + FeatureField.newLinearQuery(fieldName, entry.getKey(), entry.getValue()), + BooleanClause.Occur.SHOULD ); } return builder.build(); @@ -308,8 +293,7 @@ protected boolean doEquals(SparseEncodingQueryBuilder obj) { EqualsBuilder equalsBuilder = new EqualsBuilder().append(fieldName, obj.fieldName) .append(queryTokens, obj.queryTokens) .append(queryText, obj.queryText) - .append(modelId, obj.modelId) - .append(tokenScoreUpperBound, obj.tokenScoreUpperBound); + .append(modelId, obj.modelId); return equalsBuilder.isEquals(); } @@ -319,7 +303,6 @@ protected int doHashCode() { .append(queryTokens) .append(queryText) .append(modelId) - .append(tokenScoreUpperBound) .toHashCode(); } diff --git a/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilderTests.java index e01825d75..4461b6d90 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilderTests.java @@ -15,7 +15,6 @@ import static org.opensearch.neuralsearch.query.sparse.SparseEncodingQueryBuilder.NAME; import static org.opensearch.neuralsearch.query.sparse.SparseEncodingQueryBuilder.QUERY_TEXT_FIELD; import 
static org.opensearch.neuralsearch.query.sparse.SparseEncodingQueryBuilder.QUERY_TOKENS_FIELD; -import static org.opensearch.neuralsearch.query.sparse.SparseEncodingQueryBuilder.TOKEN_SCORE_UPPER_BOUND_FIELD; import java.io.IOException; import java.util.List; @@ -50,7 +49,6 @@ public class SparseEncodingQueryBuilderTests extends OpenSearchTestCase { private static final String QUERY_TEXT = "Hello world!"; private static final Map QUERY_TOKENS = Map.of("hello", 1.f, "world", 2.f); private static final String MODEL_ID = "mfgfgdsfgfdgsde"; - private static final Float TOKEN_SCORE_UPPER_BOUND = 123f; private static final float BOOST = 1.8f; private static final String QUERY_NAME = "queryName"; private static final Supplier> QUERY_TOKENS_SUPPLIER = () -> Map.of("hello", 1.f, "world", 2.f); @@ -115,7 +113,6 @@ public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() { "VECTOR_FIELD": { "query_text": "string", "model_id": "string", - "token_score_upper_bound":123.0, "boost": 10.0, "_name": "something", } @@ -126,7 +123,6 @@ public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() { .startObject(FIELD_NAME) .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) - .field(TOKEN_SCORE_UPPER_BOUND_FIELD.getPreferredName(), TOKEN_SCORE_UPPER_BOUND) .field(BOOST_FIELD.getPreferredName(), BOOST) .field(NAME_FIELD.getPreferredName(), QUERY_NAME) .endObject() @@ -139,7 +135,6 @@ public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() { assertEquals(FIELD_NAME, sparseEncodingQueryBuilder.fieldName()); assertEquals(QUERY_TEXT, sparseEncodingQueryBuilder.queryText()); assertEquals(MODEL_ID, sparseEncodingQueryBuilder.modelId()); - assertEquals(TOKEN_SCORE_UPPER_BOUND, sparseEncodingQueryBuilder.tokenScoreUpperBound()); assertEquals(BOOST, sparseEncodingQueryBuilder.boost(), 0.0); assertEquals(QUERY_NAME, sparseEncodingQueryBuilder.queryName()); } @@ -248,8 +243,7 @@ public 
void testToXContent() { SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(FIELD_NAME) .modelId(MODEL_ID) .queryText(QUERY_TEXT) - .queryTokens(QUERY_TOKENS) - .tokenScoreUpperBound(TOKEN_SCORE_UPPER_BOUND); + .queryTokens(QUERY_TOKENS); XContentBuilder builder = XContentFactory.jsonBuilder(); builder = sparseEncodingQueryBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -280,11 +274,6 @@ public void testToXContent() { for (Map.Entry entry : QUERY_TOKENS.entrySet()) { assertEquals(entry.getValue(), convertedQueryTokensMap.get(entry.getKey()).floatValue(), 0); } - assertEquals( - TOKEN_SCORE_UPPER_BOUND, - ((Double) secondInnerMap.get(TOKEN_SCORE_UPPER_BOUND_FIELD.getPreferredName())).floatValue(), - 0 - ); } @SneakyThrows @@ -294,7 +283,6 @@ public void testStreams() { original.queryText(QUERY_TEXT); original.modelId(MODEL_ID); original.queryTokens(QUERY_TOKENS); - original.tokenScoreUpperBound(TOKEN_SCORE_UPPER_BOUND); original.boost(BOOST); original.queryName(QUERY_NAME); @@ -325,16 +313,13 @@ public void testHashAndEquals() { String queryName2 = "query-2"; Map queryTokens1 = Map.of("hello", 1f); Map queryTokens2 = Map.of("hello", 2f); - float tokenScoreUpperBound1 = 1f; - float tokenScoreUpperBound2 = 2f; SparseEncodingQueryBuilder sparseEncodingQueryBuilder_baseline = new SparseEncodingQueryBuilder().fieldName(fieldName1) .queryText(queryText1) .queryTokens(queryTokens1) .modelId(modelId1) .boost(boost1) - .queryName(queryName1) - .tokenScoreUpperBound(tokenScoreUpperBound1); + .queryName(queryName1); // Identical to sparseEncodingQueryBuilder_baseline SparseEncodingQueryBuilder sparseEncodingQueryBuilder_baselineCopy = new SparseEncodingQueryBuilder().fieldName(fieldName1) @@ -342,13 +327,12 @@ public void testHashAndEquals() { .queryTokens(queryTokens1) .modelId(modelId1) .boost(boost1) - .queryName(queryName1) - .tokenScoreUpperBound(tokenScoreUpperBound1); + .queryName(queryName1); // Identical to 
sparseEncodingQueryBuilder_baseline except default boost and query name SparseEncodingQueryBuilder sparseEncodingQueryBuilder_defaultBoostAndQueryName = new SparseEncodingQueryBuilder().fieldName( fieldName1 - ).queryText(queryText1).queryTokens(queryTokens1).modelId(modelId1).tokenScoreUpperBound(tokenScoreUpperBound1); + ).queryText(queryText1).queryTokens(queryTokens1).modelId(modelId1); // Identical to sparseEncodingQueryBuilder_baseline except diff field name SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffFieldName = new SparseEncodingQueryBuilder().fieldName(fieldName2) @@ -356,8 +340,7 @@ public void testHashAndEquals() { .queryTokens(queryTokens1) .modelId(modelId1) .boost(boost1) - .queryName(queryName1) - .tokenScoreUpperBound(tokenScoreUpperBound1); + .queryName(queryName1); // Identical to sparseEncodingQueryBuilder_baseline except diff query text SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffQueryText = new SparseEncodingQueryBuilder().fieldName(fieldName1) @@ -365,8 +348,7 @@ public void testHashAndEquals() { .queryTokens(queryTokens1) .modelId(modelId1) .boost(boost1) - .queryName(queryName1) - .tokenScoreUpperBound(tokenScoreUpperBound1); + .queryName(queryName1); // Identical to sparseEncodingQueryBuilder_baseline except diff model ID SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffModelId = new SparseEncodingQueryBuilder().fieldName(fieldName1) @@ -374,8 +356,7 @@ public void testHashAndEquals() { .queryTokens(queryTokens1) .modelId(modelId2) .boost(boost1) - .queryName(queryName1) - .tokenScoreUpperBound(tokenScoreUpperBound1); + .queryName(queryName1); // Identical to sparseEncodingQueryBuilder_baseline except diff query tokens SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffQueryTokens = new SparseEncodingQueryBuilder().fieldName(fieldName1) @@ -383,8 +364,7 @@ public void testHashAndEquals() { .queryTokens(queryTokens2) .modelId(modelId1) .boost(boost1) - .queryName(queryName1) - 
.tokenScoreUpperBound(tokenScoreUpperBound1); + .queryName(queryName1); // Identical to sparseEncodingQueryBuilder_baseline except diff boost SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffBoost = new SparseEncodingQueryBuilder().fieldName(fieldName1) @@ -392,8 +372,7 @@ public void testHashAndEquals() { .queryTokens(queryTokens1) .modelId(modelId1) .boost(boost2) - .queryName(queryName1) - .tokenScoreUpperBound(tokenScoreUpperBound1); + .queryName(queryName1); // Identical to sparseEncodingQueryBuilder_baseline except diff query name SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffQueryName = new SparseEncodingQueryBuilder().fieldName(fieldName1) @@ -401,19 +380,7 @@ public void testHashAndEquals() { .queryTokens(queryTokens1) .modelId(modelId1) .boost(boost1) - .queryName(queryName2) - .tokenScoreUpperBound(tokenScoreUpperBound1); - - // Identical to sparseEncodingQueryBuilder_baseline except diff token_score_upper_bound - SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffTokenScoreUpperBound = new SparseEncodingQueryBuilder().fieldName( - fieldName1 - ) - .queryText(queryText1) - .queryTokens(queryTokens1) - .modelId(modelId1) - .boost(boost1) - .queryName(queryName1) - .tokenScoreUpperBound(tokenScoreUpperBound2); + .queryName(queryName2); assertEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_baseline); assertEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_baseline.hashCode()); @@ -441,9 +408,6 @@ public void testHashAndEquals() { assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffQueryName); assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffQueryName.hashCode()); - - assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffTokenScoreUpperBound); - assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), 
sparseEncodingQueryBuilder_diffTokenScoreUpperBound.hashCode()); } @SneakyThrows From fa11056211b309b6f6cd2a2873822dc7e4ce911e Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 20 Sep 2023 14:49:04 +0800 Subject: [PATCH 45/70] tidy Signed-off-by: zhichao-aws --- .../query/sparse/SparseEncodingQueryBuilder.java | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilder.java index 812f255f2..d49f6d33d 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilder.java @@ -245,10 +245,7 @@ protected Query doToQuery(QueryShardContext context) throws IOException { BooleanQuery.Builder builder = new BooleanQuery.Builder(); for (Map.Entry entry : queryTokens.entrySet()) { - builder.add( - FeatureField.newLinearQuery(fieldName, entry.getKey(), entry.getValue()), - BooleanClause.Occur.SHOULD - ); + builder.add(FeatureField.newLinearQuery(fieldName, entry.getKey(), entry.getValue()), BooleanClause.Occur.SHOULD); } return builder.build(); } @@ -299,11 +296,7 @@ protected boolean doEquals(SparseEncodingQueryBuilder obj) { @Override protected int doHashCode() { - return new HashCodeBuilder().append(fieldName) - .append(queryTokens) - .append(queryText) - .append(modelId) - .toHashCode(); + return new HashCodeBuilder().append(fieldName).append(queryTokens).append(queryText).append(modelId).toHashCode(); } @Override From 439d62815c3f3d45b0339562e4e1c42df8419681 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 20 Sep 2023 16:53:35 +0800 Subject: [PATCH 46/70] add delta to loose the equal Signed-off-by: zhichao-aws --- .../java/org/opensearch/neuralsearch/TestUtils.java | 9 +-------- .../query/sparse/SparseEncodingQueryIT.java | 12 +++++++----- 2 files changed, 8 
insertions(+), 13 deletions(-) diff --git a/src/test/java/org/opensearch/neuralsearch/TestUtils.java b/src/test/java/org/opensearch/neuralsearch/TestUtils.java index 9565be474..d9cb1ffee 100644 --- a/src/test/java/org/opensearch/neuralsearch/TestUtils.java +++ b/src/test/java/org/opensearch/neuralsearch/TestUtils.java @@ -76,13 +76,6 @@ public static float[] createRandomVector(int dimension) { return vector; } - // When ingesting token weight map, float number will be decoded to json, which may lose precision - // To compute match score without losing precision, we limit the effective digits of float number - public static Float createFloatNumberWithEffectiveDigits(float inputNumber, int scale) { - BigDecimal bd = new BigDecimal(inputNumber); - return bd.setScale(scale, RoundingMode.HALF_UP).floatValue(); - } - /** * Create a map of provided tokens, the values will be random float numbers * @@ -93,7 +86,7 @@ public static Map createRandomTokenWeightMap(Collection t Map resultMap = new HashMap<>(); for (String token : tokens) { // use a small shift to ensure value > 0 - resultMap.put(token, createFloatNumberWithEffectiveDigits(Math.abs(randomFloat()) + 1e-3f, 3)); + resultMap.put(token, Math.abs(randomFloat()) + 1e-3f); } return resultMap; } diff --git a/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryIT.java index 9c19226ea..7d125e67c 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryIT.java @@ -32,6 +32,8 @@ public class SparseEncodingQueryIT extends BaseSparseEncodingIT { private static final String TEST_SPARSE_ENCODING_FIELD_NAME_NESTED = "nested.sparse_encoding.field"; private static final List TEST_TOKENS = List.of("hello", "world", "a", "b", "c"); + + private static final Float DELTA = 1e-5f; private final Map testTokenWeightMap 
= TestUtils.createRandomTokenWeightMap(TEST_TOKENS); @Before @@ -73,7 +75,7 @@ public void testBasicQueryUsingQueryText() { assertEquals("1", firstInnerHit.get("_id")); float expectedScore = computeExpectedScore(modelId, testTokenWeightMap, TEST_QUERY_TEXT); - assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), 0.0); + assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA); } /** @@ -104,7 +106,7 @@ public void testBasicQueryUsingQueryTokens() { assertEquals("1", firstInnerHit.get("_id")); float expectedScore = computeExpectedScore(testTokenWeightMap, queryTokens); - assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), 0.0); + assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA); } /** @@ -133,7 +135,7 @@ public void testBoostQuery() { assertEquals("1", firstInnerHit.get("_id")); float expectedScore = 2 * computeExpectedScore(modelId, testTokenWeightMap, TEST_QUERY_TEXT); - assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), 0.0); + assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA); } /** @@ -169,7 +171,7 @@ public void testRescoreQuery() { assertEquals("1", firstInnerHit.get("_id")); float expectedScore = computeExpectedScore(modelId, testTokenWeightMap, TEST_QUERY_TEXT); - assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), 0.0); + assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA); } /** @@ -215,7 +217,7 @@ public void testBooleanQuery_withMultipleSparseEncodingQueries() { assertEquals("1", firstInnerHit.get("_id")); float expectedScore = 2 * computeExpectedScore(modelId, testTokenWeightMap, TEST_QUERY_TEXT); - assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), 0.0); + assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA); } /** From 99a739dc0a52badaee79ab0670d0016be7a8f2f0 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 20 Sep 
2023 16:56:06 +0800 Subject: [PATCH 47/70] move SparseEncodingQueryBuilder to upper level path Signed-off-by: zhichao-aws --- .../opensearch/neuralsearch/plugin/NeuralSearch.java | 2 +- .../{sparse => }/SparseEncodingQueryBuilder.java | 2 +- .../{sparse => }/SparseEncodingQueryBuilderTests.java | 11 ++++++----- .../query/{sparse => }/SparseEncodingQueryIT.java | 3 ++- 4 files changed, 10 insertions(+), 8 deletions(-) rename src/main/java/org/opensearch/neuralsearch/query/{sparse => }/SparseEncodingQueryBuilder.java (99%) rename src/test/java/org/opensearch/neuralsearch/query/{sparse => }/SparseEncodingQueryBuilderTests.java (97%) rename src/test/java/org/opensearch/neuralsearch/query/{sparse => }/SparseEncodingQueryIT.java (99%) diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index fc93af934..d72e1a1ed 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -42,7 +42,7 @@ import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; -import org.opensearch.neuralsearch.query.sparse.SparseEncodingQueryBuilder; +import org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder; import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher; import org.opensearch.plugins.ActionPlugin; import org.opensearch.plugins.ExtensiblePlugin; diff --git a/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java similarity index 99% rename from src/main/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilder.java rename to src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java index 
d49f6d33d..adb4315b5 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.neuralsearch.query.sparse; +package org.opensearch.neuralsearch.query; import java.io.IOException; import java.util.HashMap; diff --git a/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilderTests.java similarity index 97% rename from src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilderTests.java rename to src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilderTests.java index 4461b6d90..4805287b3 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilderTests.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.neuralsearch.query.sparse; +package org.opensearch.neuralsearch.query; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; @@ -11,10 +11,10 @@ import static org.opensearch.index.query.AbstractQueryBuilder.BOOST_FIELD; import static org.opensearch.index.query.AbstractQueryBuilder.NAME_FIELD; import static org.opensearch.neuralsearch.TestUtils.xContentBuilderToMap; -import static org.opensearch.neuralsearch.query.sparse.SparseEncodingQueryBuilder.MODEL_ID_FIELD; -import static org.opensearch.neuralsearch.query.sparse.SparseEncodingQueryBuilder.NAME; -import static org.opensearch.neuralsearch.query.sparse.SparseEncodingQueryBuilder.QUERY_TEXT_FIELD; -import static org.opensearch.neuralsearch.query.sparse.SparseEncodingQueryBuilder.QUERY_TOKENS_FIELD; +import static 
org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder.MODEL_ID_FIELD; +import static org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder.NAME; +import static org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder.QUERY_TEXT_FIELD; +import static org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder.QUERY_TOKENS_FIELD; import java.io.IOException; import java.util.List; @@ -41,6 +41,7 @@ import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder; import org.opensearch.test.OpenSearchTestCase; public class SparseEncodingQueryBuilderTests extends OpenSearchTestCase { diff --git a/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryIT.java similarity index 99% rename from src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryIT.java rename to src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryIT.java index 7d125e67c..ce9155f6a 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/sparse/SparseEncodingQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryIT.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.neuralsearch.query.sparse; +package org.opensearch.neuralsearch.query; import static org.opensearch.neuralsearch.TestUtils.objectToFloat; @@ -19,6 +19,7 @@ import org.opensearch.index.query.MatchQueryBuilder; import org.opensearch.neuralsearch.TestUtils; import org.opensearch.neuralsearch.common.BaseSparseEncodingIT; +import org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder; public class SparseEncodingQueryIT extends BaseSparseEncodingIT { private static final String TEST_BASIC_INDEX_NAME = "test-sparse-basic-index"; From 
65b1e4fc0115d998eff79ba07c0634a385586bc6 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 20 Sep 2023 08:59:57 +0000 Subject: [PATCH 48/70] tidy Signed-off-by: zhichao-aws --- src/test/java/org/opensearch/neuralsearch/TestUtils.java | 2 -- .../neuralsearch/query/SparseEncodingQueryBuilderTests.java | 1 - .../opensearch/neuralsearch/query/SparseEncodingQueryIT.java | 1 - 3 files changed, 4 deletions(-) diff --git a/src/test/java/org/opensearch/neuralsearch/TestUtils.java b/src/test/java/org/opensearch/neuralsearch/TestUtils.java index d9cb1ffee..385855a2e 100644 --- a/src/test/java/org/opensearch/neuralsearch/TestUtils.java +++ b/src/test/java/org/opensearch/neuralsearch/TestUtils.java @@ -11,8 +11,6 @@ import static org.junit.Assert.assertTrue; import static org.opensearch.test.OpenSearchTestCase.randomFloat; -import java.math.BigDecimal; -import java.math.RoundingMode; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; diff --git a/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilderTests.java index 4805287b3..7f33c44c2 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilderTests.java @@ -41,7 +41,6 @@ import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; -import org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder; import org.opensearch.test.OpenSearchTestCase; public class SparseEncodingQueryBuilderTests extends OpenSearchTestCase { diff --git a/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryIT.java index ce9155f6a..27517ddcb 100644 --- 
a/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryIT.java @@ -19,7 +19,6 @@ import org.opensearch.index.query.MatchQueryBuilder; import org.opensearch.neuralsearch.TestUtils; import org.opensearch.neuralsearch.common.BaseSparseEncodingIT; -import org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder; public class SparseEncodingQueryIT extends BaseSparseEncodingIT { private static final String TEST_BASIC_INDEX_NAME = "test-sparse-basic-index"; From 916b3cf4e0380586291deeebe31c06d1ae5c5993 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 20 Sep 2023 17:13:10 +0800 Subject: [PATCH 49/70] add it Signed-off-by: zhichao-aws --- .../neuralsearch/query/SparseEncodingQueryIT.java | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryIT.java index 27517ddcb..7d2a2314c 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryIT.java @@ -14,6 +14,7 @@ import org.junit.After; import org.junit.Before; +import org.opensearch.client.ResponseException; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.MatchQueryBuilder; @@ -263,6 +264,18 @@ public void testBooleanQuery_withSparseEncodingAndBM25Queries() { assertTrue(minExpectedScore < objectToFloat(firstInnerHit.get("_score"))); } + @SneakyThrows + public void testBasicQueryUsingQueryText_whenQueryWrongFieldType_thenFail() { + initializeIndexIfNotExist(TEST_TEXT_AND_SPARSE_ENCODING_FIELD_INDEX_NAME); + String modelId = getDeployedModelId(); + + SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(TEST_TEXT_FIELD_NAME_1) + 
.queryText(TEST_QUERY_TEXT) + .modelId(modelId); + + expectThrows(ResponseException.class, () -> search(TEST_TEXT_AND_SPARSE_ENCODING_FIELD_INDEX_NAME, sparseEncodingQueryBuilder, 1)); + } + @SneakyThrows protected void initializeIndexIfNotExist(String indexName) { if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(indexName)) { From 283a7a364d222c00f43a075e138c8e107adaf439 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Mon, 25 Sep 2023 15:33:10 +0800 Subject: [PATCH 50/70] Update src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java Co-authored-by: zane-neo Signed-off-by: zhichao-aws --- .../org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index 2571dc4e2..2c752b7ea 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -179,7 +179,7 @@ private List> buildVectorFromResponse(MLOutput mlOutput) { final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput; final List tensorOutputList = modelTensorOutput.getMlModelOutputs(); if (CollectionUtils.isEmpty(tensorOutputList) || CollectionUtils.isEmpty(tensorOutputList.get(0).getMlModelTensors())) { - throw new IllegalStateException("Empty model result produced. Expected 1 tensor output and 1 model tensor, but got [0]"); + throw new IllegalStateException("Empty model result produced. 
Expected at least [1] tensor output and [1] model tensor, but got [0]"); } List> resultMaps = new ArrayList<>(); for (ModelTensors tensors : tensorOutputList) { From c3d9fd3c81ca37ac3494de31c38c72ddc9a62ea2 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Mon, 25 Sep 2023 19:13:25 +0800 Subject: [PATCH 51/70] Update src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java Co-authored-by: zane-neo Signed-off-by: zhichao-aws --- .../java/org/opensearch/neuralsearch/util/TokenWeightUtil.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java index d552a4362..2b9613be3 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java @@ -53,7 +53,7 @@ public static List> fetchListOfTokenWeightMap(List buildTokenWeightMap(Object uncastedMap) { if (!Map.class.isAssignableFrom(uncastedMap.getClass())) { - throw new IllegalArgumentException("The expected inference result is a Map with String keys and " + " Float values."); + throw new IllegalArgumentException("The expected inference result is a Map with String keys and Float values."); } Map result = new HashMap<>(); for (Map.Entry entry : ((Map) uncastedMap).entrySet()) { From 11cf97d0c31a568c71ed9ce54cbf316690d7d514 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Mon, 25 Sep 2023 15:37:09 +0800 Subject: [PATCH 52/70] restore gradle.propeties Signed-off-by: zhichao-aws --- gradle.properties | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gradle.properties b/gradle.properties index 90e7a8445..5e5cd9ced 100644 --- a/gradle.properties +++ b/gradle.properties @@ -8,4 +8,4 @@ org.gradle.jvmargs=--add-exports jdk.compiler/com.sun.tools.javac.api=ALL-UNNAME --add-exports jdk.compiler/com.sun.tools.javac.file=ALL-UNNAMED \ --add-exports 
jdk.compiler/com.sun.tools.javac.parser=ALL-UNNAMED \ --add-exports jdk.compiler/com.sun.tools.javac.tree=ALL-UNNAMED \ - --add-exports jdk.compiler/com.sun.tools.javac.util=ALL-UNNAMED \ No newline at end of file + --add-exports jdk.compiler/com.sun.tools.javac.util=ALL-UNNAMED From 057f43508ce36cf980f04905ace5fa593c9beeb8 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Mon, 25 Sep 2023 17:55:42 +0800 Subject: [PATCH 53/70] add release notes Signed-off-by: zhichao-aws --- release-notes/opensearch-neural-search.release-notes-2.10.0.0.md | 1 + 1 file changed, 1 insertion(+) diff --git a/release-notes/opensearch-neural-search.release-notes-2.10.0.0.md b/release-notes/opensearch-neural-search.release-notes-2.10.0.0.md index 5c86a24dd..104460dc1 100644 --- a/release-notes/opensearch-neural-search.release-notes-2.10.0.0.md +++ b/release-notes/opensearch-neural-search.release-notes-2.10.0.0.md @@ -4,6 +4,7 @@ Compatible with OpenSearch 2.10.0 ### Features * Improved Hybrid Search relevancy by Score Normalization and Combination ([#241](https://github.com/opensearch-project/neural-search/pull/241/)) +* Support sparse semantic retrieval by introducing `sparse_encoding` ingest processor and query builder ([#333](https://github.com/opensearch-project/neural-search/pull/333)) ### Enhancements * Changed format for hybrid query results to a single list of scores with delimiter ([#259](https://github.com/opensearch-project/neural-search/pull/259)) From 351bae93732ff60670c0ea3a7950e3f5efa16a85 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Mon, 25 Sep 2023 18:35:28 +0800 Subject: [PATCH 54/70] change field modifier to private for NLPProcessor Signed-off-by: zhichao-aws --- .../opensearch/neuralsearch/processor/NLPProcessor.java | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java index 
d4b3c340f..17570190b 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java @@ -32,18 +32,17 @@ public abstract class NLPProcessor extends AbstractProcessor { public static final String MODEL_ID_FIELD = "model_id"; public static final String FIELD_MAP_FIELD = "field_map"; - protected final String type; + private final String type; - protected final String listTypeNestedMapKey; + private final String listTypeNestedMapKey; - @VisibleForTesting protected final String modelId; - protected final Map fieldMap; + private final Map fieldMap; protected final MLCommonsClientAccessor mlCommonsClientAccessor; - protected final Environment environment; + private final Environment environment; public NLPProcessor( String tag, From f58a0734a34b14e0cf704eb4efe39522663cd2c7 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Mon, 25 Sep 2023 19:43:30 +0800 Subject: [PATCH 55/70] add comments Signed-off-by: zhichao-aws --- .../org/opensearch/neuralsearch/processor/NLPProcessor.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java index 17570190b..7e81c3922 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java @@ -34,6 +34,8 @@ public abstract class NLPProcessor extends AbstractProcessor { private final String type; + // This field is used for nested knn_vector/rank_features field. The value of the field will be used as the + // default key for the nested object. 
private final String listTypeNestedMapKey; protected final String modelId; From 791c6ca98dd086bfa6b189a13962a34d71d9f5ad Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Mon, 25 Sep 2023 19:49:33 +0800 Subject: [PATCH 56/70] use StringUtils to check Signed-off-by: zhichao-aws --- .../neuralsearch/query/SparseEncodingQueryBuilder.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java index adb4315b5..f75a56b50 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java @@ -18,6 +18,7 @@ import lombok.experimental.Accessors; import lombok.extern.log4j.Log4j2; +import org.apache.commons.lang.StringUtils; import org.apache.commons.lang.builder.EqualsBuilder; import org.apache.commons.lang.builder.HashCodeBuilder; import org.apache.lucene.document.FeatureField; @@ -251,7 +252,7 @@ protected Query doToQuery(QueryShardContext context) throws IOException { } private static void validateForRewrite(String queryText, String modelId) { - if (null == queryText || null == modelId) { + if (StringUtils.isBlank(queryText) || StringUtils.isBlank(modelId)) { throw new IllegalArgumentException( "When " + QUERY_TOKENS_FIELD.getPreferredName() From 6878c0121fb471be4e4dc7ef6a7f78256a1998d7 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Mon, 25 Sep 2023 20:52:25 +0800 Subject: [PATCH 57/70] null check Signed-off-by: zhichao-aws --- .../neuralsearch/query/SparseEncodingQueryBuilder.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java index f75a56b50..3da98a186 100644 --- 
a/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java @@ -266,7 +266,7 @@ private static void validateForRewrite(String queryText, String modelId) { } private static void validateFieldType(MappedFieldType fieldType) { - if (!fieldType.typeName().equals("rank_features")) { + if (null == fieldType || !fieldType.typeName().equals("rank_features")) { throw new IllegalArgumentException( "[" + NAME + "] query only works on [rank_features] fields, " + "not [" + fieldType.typeName() + "]" ); From c6c631e35a802ad250d919cb8e63537a8583c8df Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Tue, 26 Sep 2023 06:20:28 +0000 Subject: [PATCH 58/70] modify changelog Signed-off-by: zhichao-aws --- CHANGELOG.md | 2 ++ .../opensearch-neural-search.release-notes-2.10.0.0.md | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1b3ea389e..038fe41e5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 3.0](https://github.com/opensearch-project/neural-search/compare/2.x...HEAD) ### Features +Support sparse semantic retrieval by introducing `sparse_encoding` ingest processor and query builder ([#333](https://github.com/opensearch-project/neural-search/pull/333)) ### Enhancements ### Bug Fixes ### Infrastructure @@ -14,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.10...2.x) ### Features +Support sparse semantic retrieval by introducing `sparse_encoding` ingest processor and query builder ([#333](https://github.com/opensearch-project/neural-search/pull/333)) ### Enhancements ### Bug Fixes ### Infrastructure diff --git 
a/release-notes/opensearch-neural-search.release-notes-2.10.0.0.md b/release-notes/opensearch-neural-search.release-notes-2.10.0.0.md index 104460dc1..5c86a24dd 100644 --- a/release-notes/opensearch-neural-search.release-notes-2.10.0.0.md +++ b/release-notes/opensearch-neural-search.release-notes-2.10.0.0.md @@ -4,7 +4,6 @@ Compatible with OpenSearch 2.10.0 ### Features * Improved Hybrid Search relevancy by Score Normalization and Combination ([#241](https://github.com/opensearch-project/neural-search/pull/241/)) -* Support sparse semantic retrieval by introducing `sparse_encoding` ingest processor and query builder ([#333](https://github.com/opensearch-project/neural-search/pull/333)) ### Enhancements * Changed format for hybrid query results to a single list of scores with delimiter ([#259](https://github.com/opensearch-project/neural-search/pull/259)) From ec70c34f7e97e18d46733a3f5a2333a28c599929 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Tue, 26 Sep 2023 14:54:10 +0800 Subject: [PATCH 59/70] nit Signed-off-by: zhichao-aws --- .../neuralsearch/plugin/NeuralSearch.java | 10 ++++---- .../neuralsearch/processor/NLPProcessor.java | 23 +++++++++++-------- .../query/SparseEncodingQueryBuilder.java | 4 ++-- .../ml/MLCommonsClientAccessorTests.java | 13 +++++------ 4 files changed, 27 insertions(+), 23 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index d72e1a1ed..601bf9003 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -99,10 +99,12 @@ public List> getQueries() { @Override public Map getProcessors(Processor.Parameters parameters) { clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client)); - Map allProcessors = new HashMap<>(); - allProcessors.put(TextEmbeddingProcessor.TYPE, new 
TextEmbeddingProcessorFactory(clientAccessor, parameters.env)); - allProcessors.put(SparseEncodingProcessor.TYPE, new SparseEncodingProcessorFactory(clientAccessor, parameters.env)); - return allProcessors; + return Map.of( + TextEmbeddingProcessor.TYPE, + new TextEmbeddingProcessorFactory(clientAccessor, parameters.env), + SparseEncodingProcessor.TYPE, + new SparseEncodingProcessorFactory(clientAccessor, parameters.env) + ); } @Override diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java index 7e81c3922..09edc301f 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java @@ -26,6 +26,9 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; +// The abstract class for text processing use cases. Users provide a field name map +// and a model id. During ingestion, the processor will use the corresponding model +// to inference the input texts, and set the target fields according to the field name map. 
@Log4j2 public abstract class NLPProcessor extends AbstractProcessor { @@ -58,7 +61,7 @@ public NLPProcessor( ) { super(tag, description); this.type = type; - if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, can not process it"); + if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, cannot process it"); validateEmbeddingConfiguration(fieldMap); this.listTypeNestedMapKey = listTypeNestedMapKey; @@ -81,14 +84,14 @@ private void validateEmbeddingConfiguration(Map fieldMap) { } @SuppressWarnings({ "rawtypes" }) - private static void validateListTypeValue(String sourceKey, Object sourceValue) { + private void validateListTypeValue(String sourceKey, Object sourceValue) { for (Object value : (List) sourceValue) { if (value == null) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, can not process it"); + throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, cannot process it"); } else if (!(value instanceof String)) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, can not process it"); + throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, cannot process it"); } else if (StringUtils.isBlank(value.toString())) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, can not process it"); + throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, cannot process it"); } } } @@ -97,7 +100,7 @@ private static void validateListTypeValue(String sourceKey, Object sourceValue) private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier maxDepthSupplier) { int maxDepth = maxDepthSupplier.get(); if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) { - throw new IllegalArgumentException("map type field [" + 
sourceKey + "] reached max depth limit, can not process it"); + throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, cannot process it"); } else if ((List.class.isAssignableFrom(sourceValue.getClass()))) { validateListTypeValue(sourceKey, sourceValue); } else if (Map.class.isAssignableFrom(sourceValue.getClass())) { @@ -106,9 +109,9 @@ private void validateNestedTypeValue(String sourceKey, Object sourceValue, Suppl .filter(Objects::nonNull) .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1)); } else if (!String.class.isAssignableFrom(sourceValue.getClass())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, can not process it"); + throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, cannot process it"); } else if (StringUtils.isBlank(sourceValue.toString())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, can not process it"); + throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, cannot process it"); } } @@ -122,9 +125,9 @@ private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) { validateNestedTypeValue(sourceKey, sourceValue, () -> 1); } else if (!String.class.isAssignableFrom(sourceValueClass)) { - throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, can not process it"); + throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, cannot process it"); } else if (StringUtils.isBlank(sourceValue.toString())) { - throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, can not process it"); + throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, cannot process it"); } } } diff 
--git a/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java index 3da98a186..eb369d1a7 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java @@ -260,7 +260,7 @@ private static void validateForRewrite(String queryText, String modelId) { + QUERY_TEXT_FIELD.getPreferredName() + " and " + MODEL_ID_FIELD.getPreferredName() - + " can not be null." + + " cannot be null." ); } } @@ -275,7 +275,7 @@ private static void validateFieldType(MappedFieldType fieldType) { private static void validateQueryTokens(Map queryTokens) { if (null == queryTokens) { - throw new IllegalArgumentException(QUERY_TOKENS_FIELD.getPreferredName() + " field can not be null."); + throw new IllegalArgumentException(QUERY_TOKENS_FIELD.getPreferredName() + " field cannot be null."); } for (Map.Entry entry : queryTokens.entrySet()) { if (entry.getValue() <= 0) { diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index a51c62977..000b9598b 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -162,8 +162,7 @@ public void testInferenceSentences_whenNotConnectionException_thenNoRetry() { Mockito.verify(resultListener).onFailure(illegalStateException); } - public void test_inferenceSentencesWithMapResult_whenValidInput_thenSuccess() { - // final List> map = List.of(Map.of("key", "value")); + public void testInferenceSentencesWithMapResult_whenValidInput_thenSuccess() { final Map map = Map.of("key", "value"); final ActionListener>> resultListener = mock(ActionListener.class); Mockito.doAnswer(invocation -> { @@ -179,7 +178,7 @@ public 
void test_inferenceSentencesWithMapResult_whenValidInput_thenSuccess() { Mockito.verifyNoMoreInteractions(resultListener); } - public void test_inferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenException() { + public void testInferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenException() { final ActionListener>> resultListener = mock(ActionListener.class); final ModelTensorOutput modelTensorOutput = new ModelTensorOutput(Collections.emptyList()); Mockito.doAnswer(invocation -> { @@ -200,7 +199,7 @@ public void test_inferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenE Mockito.verifyNoMoreInteractions(resultListener); } - public void test_inferenceSentencesWithMapResult_whenModelTensorListEmpty_thenException() { + public void testInferenceSentencesWithMapResult_whenModelTensorListEmpty_thenException() { final ActionListener>> resultListener = mock(ActionListener.class); final List tensorsList = new ArrayList<>(); final List mlModelTensorList = new ArrayList<>(); @@ -224,7 +223,7 @@ public void test_inferenceSentencesWithMapResult_whenModelTensorListEmpty_thenEx Mockito.verifyNoMoreInteractions(resultListener); } - public void test_inferenceSentencesWithMapResult_whenModelTensorListSizeBiggerThan1_thenSuccess() { + public void testInferenceSentencesWithMapResult_whenModelTensorListSizeBiggerThan1_thenSuccess() { final ActionListener>> resultListener = mock(ActionListener.class); final List tensorsList = new ArrayList<>(); final List mlModelTensorList = new ArrayList<>(); @@ -246,7 +245,7 @@ public void test_inferenceSentencesWithMapResult_whenModelTensorListSizeBiggerTh Mockito.verifyNoMoreInteractions(resultListener); } - public void test_inferenceSentencesWithMapResult_whenRetryableException_retry3Times() { + public void testInferenceSentencesWithMapResult_whenRetryableException_retry3Times() { final NodeNotConnectedException nodeNodeConnectedException = new NodeNotConnectedException( mock(DiscoveryNode.class), "Node not connected" 
@@ -264,7 +263,7 @@ public void test_inferenceSentencesWithMapResult_whenRetryableException_retry3Ti Mockito.verify(resultListener).onFailure(nodeNodeConnectedException); } - public void test_inferenceSentencesWithMapResult_whenNotRetryableException_thenFail() { + public void testInferenceSentencesWithMapResult_whenNotRetryableException_thenFail() { final IllegalStateException illegalStateException = new IllegalStateException("Illegal state"); Mockito.doAnswer(invocation -> { final ActionListener actionListener = invocation.getArgument(2); From 9223e3115b85716a2697d989467c64a3cecfe93f Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Tue, 26 Sep 2023 15:01:10 +0800 Subject: [PATCH 60/70] nit Signed-off-by: zhichao-aws --- .../neuralsearch/processor/NLPProcessor.java | 222 +++++++++--------- 1 file changed, 111 insertions(+), 111 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java index 09edc301f..2ca516404 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java @@ -83,77 +83,65 @@ private void validateEmbeddingConfiguration(Map fieldMap) { } } - @SuppressWarnings({ "rawtypes" }) - private void validateListTypeValue(String sourceKey, Object sourceValue) { - for (Object value : (List) sourceValue) { - if (value == null) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, cannot process it"); - } else if (!(value instanceof String)) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, cannot process it"); - } else if (StringUtils.isBlank(value.toString())) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, cannot process it"); - } - } - } + public abstract void doExecute( + IngestDocument ingestDocument, + Map ProcessMap, + List 
inferenceList, + BiConsumer handler + ); - @SuppressWarnings({ "rawtypes", "unchecked" }) - private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier maxDepthSupplier) { - int maxDepth = maxDepthSupplier.get(); - if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, cannot process it"); - } else if ((List.class.isAssignableFrom(sourceValue.getClass()))) { - validateListTypeValue(sourceKey, sourceValue); - } else if (Map.class.isAssignableFrom(sourceValue.getClass())) { - ((Map) sourceValue).values() - .stream() - .filter(Objects::nonNull) - .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1)); - } else if (!String.class.isAssignableFrom(sourceValue.getClass())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, cannot process it"); - } else if (StringUtils.isBlank(sourceValue.toString())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, cannot process it"); - } + @Override + public IngestDocument execute(IngestDocument ingestDocument) throws Exception { + return ingestDocument; } - private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { - Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); - for (Map.Entry embeddingFieldsEntry : fieldMap.entrySet()) { - Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey()); - if (sourceValue != null) { - String sourceKey = embeddingFieldsEntry.getKey(); - Class sourceValueClass = sourceValue.getClass(); - if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) { - validateNestedTypeValue(sourceKey, sourceValue, () -> 1); - } else if (!String.class.isAssignableFrom(sourceValueClass)) { - throw new IllegalArgumentException("field [" + sourceKey + "] is neither string 
nor nested type, cannot process it"); - } else if (StringUtils.isBlank(sourceValue.toString())) { - throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, cannot process it"); - } + /** + * This method will be invoked by PipelineService to make async inference and then delegate the handler to + * process the inference response or failure. + * @param ingestDocument {@link IngestDocument} which is the document passed to processor. + * @param handler {@link BiConsumer} which is the handler which can be used after the inference task is done. + */ + @Override + public void execute(IngestDocument ingestDocument, BiConsumer handler) { + try { + validateEmbeddingFieldsValue(ingestDocument); + Map ProcessMap = buildMapWithProcessorKeyAndOriginalValue(ingestDocument); + List inferenceList = createInferenceList(ProcessMap); + if (inferenceList.size() == 0) { + handler.accept(ingestDocument, null); + } else { + doExecute(ingestDocument, ProcessMap, inferenceList, handler); } + } catch (Exception e) { + handler.accept(null, e); } } - private void buildMapWithProcessorKeyAndOriginalValueForMapType( - String parentKey, - Object processorKey, - Map sourceAndMetadataMap, - Map treeRes - ) { - if (processorKey == null || sourceAndMetadataMap == null) return; - if (processorKey instanceof Map) { - Map next = new LinkedHashMap<>(); - for (Map.Entry nestedFieldMapEntry : ((Map) processorKey).entrySet()) { - buildMapWithProcessorKeyAndOriginalValueForMapType( - nestedFieldMapEntry.getKey(), - nestedFieldMapEntry.getValue(), - (Map) sourceAndMetadataMap.get(parentKey), - next - ); + @SuppressWarnings({ "unchecked" }) + private List createInferenceList(Map knnKeyMap) { + List texts = new ArrayList<>(); + knnKeyMap.entrySet().stream().filter(knnMapEntry -> knnMapEntry.getValue() != null).forEach(knnMapEntry -> { + Object sourceValue = knnMapEntry.getValue(); + if (sourceValue instanceof List) { + texts.addAll(((List) sourceValue)); + } else if (sourceValue 
instanceof Map) { + createInferenceListForMapTypeInput(sourceValue, texts); + } else { + texts.add(sourceValue.toString()); } - treeRes.put(parentKey, next); + }); + return texts; + } + + @SuppressWarnings("unchecked") + private void createInferenceListForMapTypeInput(Object sourceValue, List texts) { + if (sourceValue instanceof Map) { + ((Map) sourceValue).forEach((k, v) -> createInferenceListForMapTypeInput(v, texts)); + } else if (sourceValue instanceof List) { + texts.addAll(((List) sourceValue)); } else { - String key = String.valueOf(processorKey); - treeRes.put(key, sourceAndMetadataMap.get(parentKey)); + if (sourceValue == null) return; + texts.add(sourceValue.toString()); } } @@ -175,65 +163,77 @@ Map buildMapWithProcessorKeyAndOriginalValue(IngestDocument inge return mapWithProcessorKeys; } - @SuppressWarnings("unchecked") - private void createInferenceListForMapTypeInput(Object sourceValue, List texts) { - if (sourceValue instanceof Map) { - ((Map) sourceValue).forEach((k, v) -> createInferenceListForMapTypeInput(v, texts)); - } else if (sourceValue instanceof List) { - texts.addAll(((List) sourceValue)); + private void buildMapWithProcessorKeyAndOriginalValueForMapType( + String parentKey, + Object processorKey, + Map sourceAndMetadataMap, + Map treeRes + ) { + if (processorKey == null || sourceAndMetadataMap == null) return; + if (processorKey instanceof Map) { + Map next = new LinkedHashMap<>(); + for (Map.Entry nestedFieldMapEntry : ((Map) processorKey).entrySet()) { + buildMapWithProcessorKeyAndOriginalValueForMapType( + nestedFieldMapEntry.getKey(), + nestedFieldMapEntry.getValue(), + (Map) sourceAndMetadataMap.get(parentKey), + next + ); + } + treeRes.put(parentKey, next); } else { - if (sourceValue == null) return; - texts.add(sourceValue.toString()); + String key = String.valueOf(processorKey); + treeRes.put(key, sourceAndMetadataMap.get(parentKey)); } } - @SuppressWarnings({ "unchecked" }) - private List createInferenceList(Map knnKeyMap) { - 
List texts = new ArrayList<>(); - knnKeyMap.entrySet().stream().filter(knnMapEntry -> knnMapEntry.getValue() != null).forEach(knnMapEntry -> { - Object sourceValue = knnMapEntry.getValue(); - if (sourceValue instanceof List) { - texts.addAll(((List) sourceValue)); - } else if (sourceValue instanceof Map) { - createInferenceListForMapTypeInput(sourceValue, texts); - } else { - texts.add(sourceValue.toString()); + private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { + Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); + for (Map.Entry embeddingFieldsEntry : fieldMap.entrySet()) { + Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey()); + if (sourceValue != null) { + String sourceKey = embeddingFieldsEntry.getKey(); + Class sourceValueClass = sourceValue.getClass(); + if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) { + validateNestedTypeValue(sourceKey, sourceValue, () -> 1); + } else if (!String.class.isAssignableFrom(sourceValueClass)) { + throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, cannot process it"); + } else if (StringUtils.isBlank(sourceValue.toString())) { + throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, cannot process it"); + } } - }); - return texts; + } } - public abstract void doExecute( - IngestDocument ingestDocument, - Map ProcessMap, - List inferenceList, - BiConsumer handler - ); - - @Override - public IngestDocument execute(IngestDocument ingestDocument) throws Exception { - return ingestDocument; + @SuppressWarnings({ "rawtypes", "unchecked" }) + private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier maxDepthSupplier) { + int maxDepth = maxDepthSupplier.get(); + if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) { + throw new IllegalArgumentException("map type field [" + sourceKey 
+ "] reached max depth limit, cannot process it"); + } else if ((List.class.isAssignableFrom(sourceValue.getClass()))) { + validateListTypeValue(sourceKey, sourceValue); + } else if (Map.class.isAssignableFrom(sourceValue.getClass())) { + ((Map) sourceValue).values() + .stream() + .filter(Objects::nonNull) + .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1)); + } else if (!String.class.isAssignableFrom(sourceValue.getClass())) { + throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, cannot process it"); + } else if (StringUtils.isBlank(sourceValue.toString())) { + throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, cannot process it"); + } } - /** - * This method will be invoked by PipelineService to make async inference and then delegate the handler to - * process the inference response or failure. - * @param ingestDocument {@link IngestDocument} which is the document passed to processor. - * @param handler {@link BiConsumer} which is the handler which can be used after the inference task is done. 
- */ - @Override - public void execute(IngestDocument ingestDocument, BiConsumer handler) { - try { - validateEmbeddingFieldsValue(ingestDocument); - Map ProcessMap = buildMapWithProcessorKeyAndOriginalValue(ingestDocument); - List inferenceList = createInferenceList(ProcessMap); - if (inferenceList.size() == 0) { - handler.accept(ingestDocument, null); - } else { - doExecute(ingestDocument, ProcessMap, inferenceList, handler); + @SuppressWarnings({ "rawtypes" }) + private void validateListTypeValue(String sourceKey, Object sourceValue) { + for (Object value : (List) sourceValue) { + if (value == null) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, cannot process it"); + } else if (!(value instanceof String)) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, cannot process it"); + } else if (StringUtils.isBlank(value.toString())) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, cannot process it"); } - } catch (Exception e) { - handler.accept(null, e); } } From ba10e27dadf5d4003add308c2aeae2eec1a75ce7 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Tue, 26 Sep 2023 16:09:03 +0800 Subject: [PATCH 61/70] remove query tokens from user interface Signed-off-by: zhichao-aws --- .../query/SparseEncodingQueryBuilder.java | 91 +++++------------- .../SparseEncodingQueryBuilderTests.java | 93 ++----------------- .../query/SparseEncodingQueryIT.java | 55 +++-------- 3 files changed, 44 insertions(+), 195 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java index eb369d1a7..3e938933c 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java @@ -52,8 +52,6 @@ public class 
SparseEncodingQueryBuilder extends AbstractQueryBuilder { public static final String NAME = "sparse_encoding"; @VisibleForTesting - static final ParseField QUERY_TOKENS_FIELD = new ParseField("query_tokens"); - @VisibleForTesting static final ParseField QUERY_TEXT_FIELD = new ParseField("query_text"); @VisibleForTesting static final ParseField MODEL_ID_FIELD = new ParseField("model_id"); @@ -65,7 +63,6 @@ public static void initialize(MLCommonsClientAccessor mlClient) { } private String fieldName; - private Map queryTokens; private String queryText; private String modelId; private Supplier> queryTokensSupplier; @@ -73,10 +70,6 @@ public static void initialize(MLCommonsClientAccessor mlClient) { public SparseEncodingQueryBuilder(StreamInput in) throws IOException { super(in); this.fieldName = in.readString(); - // we don't have readOptionalMap or write, need to do it manually - if (in.readBoolean()) { - this.queryTokens = in.readMap(StreamInput::readString, StreamInput::readFloat); - } this.queryText = in.readOptionalString(); this.modelId = in.readOptionalString(); } @@ -84,12 +77,6 @@ public SparseEncodingQueryBuilder(StreamInput in) throws IOException { @Override protected void doWriteTo(StreamOutput out) throws IOException { out.writeString(fieldName); - if (null != queryTokens) { - out.writeBoolean(true); - out.writeMap(queryTokens, StreamOutput::writeString, StreamOutput::writeFloat); - } else { - out.writeBoolean(false); - } out.writeOptionalString(queryText); out.writeOptionalString(modelId); } @@ -98,7 +85,6 @@ protected void doWriteTo(StreamOutput out) throws IOException { protected void doXContent(XContentBuilder xContentBuilder, Params params) throws IOException { xContentBuilder.startObject(NAME); xContentBuilder.startObject(fieldName); - if (null != queryTokens) xContentBuilder.field(QUERY_TOKENS_FIELD.getPreferredName(), queryTokens); if (null != queryText) xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), queryText); if (null != modelId) 
xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId); printBoostAndQueryName(xContentBuilder); @@ -108,16 +94,6 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws /** * The expected parsing form looks like: - * { - * "SAMPLE_FIELD": { - * "query_tokens": { - * "token_a": float, - * "token_b": float, - * ... - * } - * } - * } - * or * "SAMPLE_FIELD": { * "query_text": "string", * "model_id": "string" @@ -147,26 +123,20 @@ public static SparseEncodingQueryBuilder fromXContent(XContentParser parser) thr } requireValue(sparseEncodingQueryBuilder.fieldName(), "Field name must be provided for " + NAME + " query"); - if (null == sparseEncodingQueryBuilder.queryTokens()) { - requireValue( - sparseEncodingQueryBuilder.queryText(), - "Either " - + QUERY_TOKENS_FIELD.getPreferredName() - + " or " - + QUERY_TEXT_FIELD.getPreferredName() - + " must be provided for " - + NAME - + " query" - ); - requireValue( - sparseEncodingQueryBuilder.modelId(), - MODEL_ID_FIELD.getPreferredName() - + " must be provided for " - + NAME - + " query when using " - + QUERY_TEXT_FIELD.getPreferredName() - ); - } + requireValue( + sparseEncodingQueryBuilder.queryText(), + QUERY_TEXT_FIELD.getPreferredName() + + " must be provided for " + + NAME + + " query" + ); + requireValue( + sparseEncodingQueryBuilder.modelId(), + MODEL_ID_FIELD.getPreferredName() + + " must be provided for " + + NAME + + " query" + ); return sparseEncodingQueryBuilder; } @@ -192,8 +162,6 @@ private static void parseQueryParams(XContentParser parser, SparseEncodingQueryB "[" + NAME + "] query does not support [" + currentFieldName + "]" ); } - } else if (QUERY_TOKENS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { - sparseEncodingQueryBuilder.queryTokens(parser.map(HashMap::new, XContentParser::floatValue)); } else { throw new ParsingException( parser.getTokenLocation(), @@ -205,19 +173,10 @@ private static void parseQueryParams(XContentParser parser, 
SparseEncodingQueryB @Override protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { - // If the user has specified query_tokens field, then we don't need to inference the sentence, - // just re-rewrite to self. Otherwise, we need to inference the sentence to get the queryTokens. Then the - // logic is similar to NeuralQueryBuilder - if (null != queryTokens) { - return this; - } + // We need to inference the sentence to get the queryTokens. The logic is similar to NeuralQueryBuilder + // If the inference is finished, then rewrite to self and call doToQuery, otherwise, continue doRewrite if (null != queryTokensSupplier) { - return queryTokensSupplier.get() == null - ? this - : new SparseEncodingQueryBuilder().fieldName(fieldName) - .queryTokens(queryTokensSupplier.get()) - .queryText(queryText) - .modelId(modelId); + return this; } validateForRewrite(queryText, modelId); @@ -242,6 +201,8 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws protected Query doToQuery(QueryShardContext context) throws IOException { final MappedFieldType ft = context.fieldMapper(fieldName); validateFieldType(ft); + + Map queryTokens = queryTokensSupplier.get(); validateQueryTokens(queryTokens); BooleanQuery.Builder builder = new BooleanQuery.Builder(); @@ -254,10 +215,7 @@ protected Query doToQuery(QueryShardContext context) throws IOException { private static void validateForRewrite(String queryText, String modelId) { if (StringUtils.isBlank(queryText) || StringUtils.isBlank(modelId)) { throw new IllegalArgumentException( - "When " - + QUERY_TOKENS_FIELD.getPreferredName() - + " are not provided," - + QUERY_TEXT_FIELD.getPreferredName() + QUERY_TEXT_FIELD.getPreferredName() + " and " + MODEL_ID_FIELD.getPreferredName() + " cannot be null." 
@@ -268,18 +226,18 @@ private static void validateForRewrite(String queryText, String modelId) { private static void validateFieldType(MappedFieldType fieldType) { if (null == fieldType || !fieldType.typeName().equals("rank_features")) { throw new IllegalArgumentException( - "[" + NAME + "] query only works on [rank_features] fields, " + "not [" + fieldType.typeName() + "]" + "[" + NAME + "] query only works on [rank_features] fields" ); } } private static void validateQueryTokens(Map queryTokens) { if (null == queryTokens) { - throw new IllegalArgumentException(QUERY_TOKENS_FIELD.getPreferredName() + " field cannot be null."); + throw new IllegalArgumentException("Query tokens cannot be null."); } for (Map.Entry entry : queryTokens.entrySet()) { if (entry.getValue() <= 0) { - throw new IllegalArgumentException("weight must be larger than 0, got: " + entry.getValue() + "for key " + entry.getKey()); + throw new IllegalArgumentException("Feature weight must be larger than 0, got: " + entry.getValue() + "for key " + entry.getKey()); } } } @@ -289,7 +247,6 @@ protected boolean doEquals(SparseEncodingQueryBuilder obj) { if (this == obj) return true; if (obj == null || getClass() != obj.getClass()) return false; EqualsBuilder equalsBuilder = new EqualsBuilder().append(fieldName, obj.fieldName) - .append(queryTokens, obj.queryTokens) .append(queryText, obj.queryText) .append(modelId, obj.modelId); return equalsBuilder.isEquals(); @@ -297,7 +254,7 @@ protected boolean doEquals(SparseEncodingQueryBuilder obj) { @Override protected int doHashCode() { - return new HashCodeBuilder().append(fieldName).append(queryTokens).append(queryText).append(modelId).toHashCode(); + return new HashCodeBuilder().append(fieldName).append(queryText).append(modelId).toHashCode(); } @Override diff --git a/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilderTests.java index 
7f33c44c2..6cb122c4f 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilderTests.java @@ -14,7 +14,6 @@ import static org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder.MODEL_ID_FIELD; import static org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder.NAME; import static org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder.QUERY_TEXT_FIELD; -import static org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder.QUERY_TOKENS_FIELD; import java.io.IOException; import java.util.List; @@ -47,7 +46,6 @@ public class SparseEncodingQueryBuilderTests extends OpenSearchTestCase { private static final String FIELD_NAME = "testField"; private static final String QUERY_TEXT = "Hello world!"; - private static final Map QUERY_TOKENS = Map.of("hello", 1.f, "world", 2.f); private static final String MODEL_ID = "mfgfgdsfgfdgsde"; private static final float BOOST = 1.8f; private static final String QUERY_NAME = "queryName"; @@ -80,32 +78,6 @@ public void testFromXContent_whenBuiltWithQueryText_thenBuildSuccessfully() { assertEquals(MODEL_ID, sparseEncodingQueryBuilder.modelId()); } - @SneakyThrows - public void testFromXContent_whenBuiltWithQueryTokens_thenBuildSuccessfully() { - /* - { - "VECTOR_FIELD": { - "query_tokens": { - "string":float, - } - } - } - */ - XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() - .startObject() - .startObject(FIELD_NAME) - .field(QUERY_TOKENS_FIELD.getPreferredName(), QUERY_TOKENS) - .endObject() - .endObject(); - - XContentParser contentParser = createParser(xContentBuilder); - contentParser.nextToken(); - SparseEncodingQueryBuilder sparseEncodingQueryBuilder = SparseEncodingQueryBuilder.fromXContent(contentParser); - - assertEquals(FIELD_NAME, sparseEncodingQueryBuilder.fieldName()); - assertEquals(QUERY_TOKENS, sparseEncodingQueryBuilder.queryTokens()); - } - @SneakyThrows public 
void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() { /* @@ -242,8 +214,7 @@ public void testFromXContent_whenBuildWithDuplicateParameters_thenFail() { public void testToXContent() { SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(FIELD_NAME) .modelId(MODEL_ID) - .queryText(QUERY_TEXT) - .queryTokens(QUERY_TOKENS); + .queryText(QUERY_TEXT); XContentBuilder builder = XContentFactory.jsonBuilder(); builder = sparseEncodingQueryBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -268,12 +239,6 @@ public void testToXContent() { assertEquals(MODEL_ID, secondInnerMap.get(MODEL_ID_FIELD.getPreferredName())); assertEquals(QUERY_TEXT, secondInnerMap.get(QUERY_TEXT_FIELD.getPreferredName())); - // QUERY_TOKENS is map, the converted one use - Map convertedQueryTokensMap = (Map) secondInnerMap.get(QUERY_TOKENS_FIELD.getPreferredName()); - assertEquals(QUERY_TOKENS.size(), convertedQueryTokensMap.size()); - for (Map.Entry entry : QUERY_TOKENS.entrySet()) { - assertEquals(entry.getValue(), convertedQueryTokensMap.get(entry.getKey()).floatValue(), 0); - } } @SneakyThrows @@ -282,7 +247,6 @@ public void testStreams() { original.fieldName(FIELD_NAME); original.queryText(QUERY_TEXT); original.modelId(MODEL_ID); - original.queryTokens(QUERY_TOKENS); original.boost(BOOST); original.queryName(QUERY_NAME); @@ -311,12 +275,9 @@ public void testHashAndEquals() { float boost2 = 3.8f; String queryName1 = "query-1"; String queryName2 = "query-2"; - Map queryTokens1 = Map.of("hello", 1f); - Map queryTokens2 = Map.of("hello", 2f); SparseEncodingQueryBuilder sparseEncodingQueryBuilder_baseline = new SparseEncodingQueryBuilder().fieldName(fieldName1) .queryText(queryText1) - .queryTokens(queryTokens1) .modelId(modelId1) .boost(boost1) .queryName(queryName1); @@ -324,7 +285,6 @@ public void testHashAndEquals() { // Identical to sparseEncodingQueryBuilder_baseline SparseEncodingQueryBuilder 
sparseEncodingQueryBuilder_baselineCopy = new SparseEncodingQueryBuilder().fieldName(fieldName1) .queryText(queryText1) - .queryTokens(queryTokens1) .modelId(modelId1) .boost(boost1) .queryName(queryName1); @@ -332,12 +292,11 @@ public void testHashAndEquals() { // Identical to sparseEncodingQueryBuilder_baseline except default boost and query name SparseEncodingQueryBuilder sparseEncodingQueryBuilder_defaultBoostAndQueryName = new SparseEncodingQueryBuilder().fieldName( fieldName1 - ).queryText(queryText1).queryTokens(queryTokens1).modelId(modelId1); + ).queryText(queryText1).modelId(modelId1); // Identical to sparseEncodingQueryBuilder_baseline except diff field name SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffFieldName = new SparseEncodingQueryBuilder().fieldName(fieldName2) .queryText(queryText1) - .queryTokens(queryTokens1) .modelId(modelId1) .boost(boost1) .queryName(queryName1); @@ -345,7 +304,6 @@ public void testHashAndEquals() { // Identical to sparseEncodingQueryBuilder_baseline except diff query text SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffQueryText = new SparseEncodingQueryBuilder().fieldName(fieldName1) .queryText(queryText2) - .queryTokens(queryTokens1) .modelId(modelId1) .boost(boost1) .queryName(queryName1); @@ -353,23 +311,13 @@ public void testHashAndEquals() { // Identical to sparseEncodingQueryBuilder_baseline except diff model ID SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffModelId = new SparseEncodingQueryBuilder().fieldName(fieldName1) .queryText(queryText1) - .queryTokens(queryTokens1) .modelId(modelId2) .boost(boost1) .queryName(queryName1); - // Identical to sparseEncodingQueryBuilder_baseline except diff query tokens - SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffQueryTokens = new SparseEncodingQueryBuilder().fieldName(fieldName1) - .queryText(queryText1) - .queryTokens(queryTokens2) - .modelId(modelId1) - .boost(boost1) - .queryName(queryName1); - // Identical to 
sparseEncodingQueryBuilder_baseline except diff boost SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffBoost = new SparseEncodingQueryBuilder().fieldName(fieldName1) .queryText(queryText1) - .queryTokens(queryTokens1) .modelId(modelId1) .boost(boost2) .queryName(queryName1); @@ -377,7 +325,6 @@ public void testHashAndEquals() { // Identical to sparseEncodingQueryBuilder_baseline except diff query name SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffQueryName = new SparseEncodingQueryBuilder().fieldName(fieldName1) .queryText(queryText1) - .queryTokens(queryTokens1) .modelId(modelId1) .boost(boost1) .queryName(queryName2); @@ -400,9 +347,6 @@ public void testHashAndEquals() { assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffModelId); assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffModelId.hashCode()); - assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffQueryTokens); - assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffQueryTokens.hashCode()); - assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffBoost); assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffBoost.hashCode()); @@ -410,16 +354,6 @@ public void testHashAndEquals() { assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffQueryName.hashCode()); } - @SneakyThrows - public void testRewrite_whenQueryTokensNotNull_thenRewriteToSelf() { - SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().queryTokens(QUERY_TOKENS) - .fieldName(FIELD_NAME) - .queryText(QUERY_TEXT) - .modelId(MODEL_ID); - QueryBuilder queryBuilder = sparseEncodingQueryBuilder.doRewrite(null); - assert queryBuilder == sparseEncodingQueryBuilder; - } - @SneakyThrows public void 
testRewrite_whenQueryTokensSupplierNull_thenSetQueryTokensSupplier() { SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(FIELD_NAME) @@ -455,27 +389,16 @@ public void testRewrite_whenQueryTokensSupplierNull_thenSetQueryTokensSupplier() } @SneakyThrows - public void testRewrite_whenSupplierContentNull_thenReturnCopy() { - Supplier> nullSupplier = () -> null; - SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(FIELD_NAME) - .queryText(QUERY_TEXT) - .modelId(MODEL_ID) - .queryTokensSupplier(nullSupplier); - QueryBuilder queryBuilder = sparseEncodingQueryBuilder.doRewrite(null); - assertEquals(sparseEncodingQueryBuilder, queryBuilder); - } - - @SneakyThrows - public void testRewrite_whenQueryTokensSupplierSet_thenSetQueryTokens() { + public void testRewrite_whenQueryTokensSupplierSet_thenReturnSelf() { SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(FIELD_NAME) .queryText(QUERY_TEXT) .modelId(MODEL_ID) .queryTokensSupplier(QUERY_TOKENS_SUPPLIER); QueryBuilder queryBuilder = sparseEncodingQueryBuilder.doRewrite(null); - SparseEncodingQueryBuilder targetQueryBuilder = new SparseEncodingQueryBuilder().fieldName(FIELD_NAME) - .queryText(QUERY_TEXT) - .modelId(MODEL_ID) - .queryTokens(QUERY_TOKENS_SUPPLIER.get()); - assertEquals(queryBuilder, targetQueryBuilder); + assertTrue(queryBuilder == sparseEncodingQueryBuilder); + + sparseEncodingQueryBuilder.queryTokensSupplier(() -> null); + queryBuilder = sparseEncodingQueryBuilder.doRewrite(null); + assertTrue(queryBuilder == sparseEncodingQueryBuilder); } } diff --git a/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryIT.java index 7d2a2314c..54991d7e2 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryIT.java +++ 
b/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryIT.java @@ -35,7 +35,7 @@ public class SparseEncodingQueryIT extends BaseSparseEncodingIT { private static final List TEST_TOKENS = List.of("hello", "world", "a", "b", "c"); private static final Float DELTA = 1e-5f; - private final Map testTokenWeightMap = TestUtils.createRandomTokenWeightMap(TEST_TOKENS); + private final Map testRankFeaturesDoc = TestUtils.createRandomTokenWeightMap(TEST_TOKENS); @Before public void setUp() throws Exception { @@ -75,38 +75,7 @@ public void testBasicQueryUsingQueryText() { Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); assertEquals("1", firstInnerHit.get("_id")); - float expectedScore = computeExpectedScore(modelId, testTokenWeightMap, TEST_QUERY_TEXT); - assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA); - } - - /** - * Tests basic query: - * { - * "query": { - * "sparse_encoding": { - * "text_sparse": { - * "query_tokens": { - * "hello": float, - * "a": float, - * "c": float - * } - * } - * } - * } - * } - */ - @SneakyThrows - public void testBasicQueryUsingQueryTokens() { - initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); - Map queryTokens = TestUtils.createRandomTokenWeightMap(List.of("hello", "a", "b")); - SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName( - TEST_SPARSE_ENCODING_FIELD_NAME_1 - ).queryTokens(queryTokens); - Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1); - Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); - - assertEquals("1", firstInnerHit.get("_id")); - float expectedScore = computeExpectedScore(testTokenWeightMap, queryTokens); + float expectedScore = computeExpectedScore(modelId, testRankFeaturesDoc, TEST_QUERY_TEXT); assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA); } @@ -135,7 +104,7 @@ public void testBoostQuery() { Map firstInnerHit = 
getFirstInnerHit(searchResponseAsMap); assertEquals("1", firstInnerHit.get("_id")); - float expectedScore = 2 * computeExpectedScore(modelId, testTokenWeightMap, TEST_QUERY_TEXT); + float expectedScore = 2 * computeExpectedScore(modelId, testRankFeaturesDoc, TEST_QUERY_TEXT); assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA); } @@ -171,12 +140,12 @@ public void testRescoreQuery() { Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); assertEquals("1", firstInnerHit.get("_id")); - float expectedScore = computeExpectedScore(modelId, testTokenWeightMap, TEST_QUERY_TEXT); + float expectedScore = computeExpectedScore(modelId, testRankFeaturesDoc, TEST_QUERY_TEXT); assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA); } /** - * Tests bool should query with query tokens: + * Tests bool should query with query text: * { * "query": { * "bool" : { @@ -217,12 +186,12 @@ public void testBooleanQuery_withMultipleSparseEncodingQueries() { Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); assertEquals("1", firstInnerHit.get("_id")); - float expectedScore = 2 * computeExpectedScore(modelId, testTokenWeightMap, TEST_QUERY_TEXT); + float expectedScore = 2 * computeExpectedScore(modelId, testRankFeaturesDoc, TEST_QUERY_TEXT); assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA); } /** - * Tests bool should query with query tokens: + * Tests bool should query with query text: * { * "query": { * "bool" : { @@ -260,7 +229,7 @@ public void testBooleanQuery_withSparseEncodingAndBM25Queries() { Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); assertEquals("1", firstInnerHit.get("_id")); - float minExpectedScore = computeExpectedScore(modelId, testTokenWeightMap, TEST_QUERY_TEXT); + float minExpectedScore = computeExpectedScore(modelId, testRankFeaturesDoc, TEST_QUERY_TEXT); assertTrue(minExpectedScore < objectToFloat(firstInnerHit.get("_score"))); } @@ -280,7 +249,7 @@ public void 
testBasicQueryUsingQueryText_whenQueryWrongFieldType_thenFail() { protected void initializeIndexIfNotExist(String indexName) { if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(indexName)) { prepareSparseEncodingIndex(indexName, List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1)); - addSparseEncodingDoc(indexName, "1", List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1), List.of(testTokenWeightMap)); + addSparseEncodingDoc(indexName, "1", List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1), List.of(testRankFeaturesDoc)); assertEquals(1, getDocCount(indexName)); } @@ -290,7 +259,7 @@ protected void initializeIndexIfNotExist(String indexName) { indexName, "1", List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1, TEST_SPARSE_ENCODING_FIELD_NAME_2), - List.of(testTokenWeightMap, testTokenWeightMap) + List.of(testRankFeaturesDoc, testRankFeaturesDoc) ); assertEquals(1, getDocCount(indexName)); } @@ -301,7 +270,7 @@ protected void initializeIndexIfNotExist(String indexName) { indexName, "1", List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1), - List.of(testTokenWeightMap), + List.of(testRankFeaturesDoc), List.of(TEST_TEXT_FIELD_NAME_1), List.of(TEST_QUERY_TEXT) ); @@ -310,7 +279,7 @@ protected void initializeIndexIfNotExist(String indexName) { if (TEST_NESTED_INDEX_NAME.equals(indexName) && !indexExists(indexName)) { prepareSparseEncodingIndex(indexName, List.of(TEST_SPARSE_ENCODING_FIELD_NAME_NESTED)); - addSparseEncodingDoc(indexName, "1", List.of(TEST_SPARSE_ENCODING_FIELD_NAME_NESTED), List.of(testTokenWeightMap)); + addSparseEncodingDoc(indexName, "1", List.of(TEST_SPARSE_ENCODING_FIELD_NAME_NESTED), List.of(testRankFeaturesDoc)); assertEquals(1, getDocCount(TEST_NESTED_INDEX_NAME)); } } From 9647ac9bf1f49874d313fe88e77281ee8c83d8a9 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Tue, 26 Sep 2023 16:15:58 +0800 Subject: [PATCH 62/70] fix test Signed-off-by: zhichao-aws --- .../neuralsearch/ml/MLCommonsClientAccessorTests.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index 000b9598b..e93033517 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -193,7 +193,7 @@ public void testInferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenEx ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IllegalStateException.class); Mockito.verify(resultListener).onFailure(argumentCaptor.capture()); assertEquals( - "Empty model result produced. Expected 1 tensor output and 1 model tensor, but got [0]", + "Empty model result produced. Expected at least [1] tensor output and [1] model tensor, but got [0]", argumentCaptor.getValue().getMessage() ); Mockito.verifyNoMoreInteractions(resultListener); @@ -217,7 +217,7 @@ public void testInferenceSentencesWithMapResult_whenModelTensorListEmpty_thenExc ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IllegalStateException.class); Mockito.verify(resultListener).onFailure(argumentCaptor.capture()); assertEquals( - "Empty model result produced. Expected 1 tensor output and 1 model tensor, but got [0]", + "Empty model result produced. 
Expected at least [1] tensor output and [1] model tensor, but got [0]", argumentCaptor.getValue().getMessage() ); Mockito.verifyNoMoreInteractions(resultListener); From 169934a0f4bde405f115506693e31d0d67e812bb Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Tue, 26 Sep 2023 16:16:22 +0800 Subject: [PATCH 63/70] tidy Signed-off-by: zhichao-aws --- .../ml/MLCommonsClientAccessor.java | 4 ++- .../neuralsearch/plugin/NeuralSearch.java | 9 +++---- .../neuralsearch/processor/NLPProcessor.java | 22 +++++++-------- .../query/SparseEncodingQueryBuilder.java | 27 +++++-------------- .../ml/MLCommonsClientAccessorTests.java | 2 +- 5 files changed, 26 insertions(+), 38 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index 2c752b7ea..ed0d95d1a 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -179,7 +179,9 @@ private List> buildVectorFromResponse(MLOutput mlOutput) { final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput; final List tensorOutputList = modelTensorOutput.getMlModelOutputs(); if (CollectionUtils.isEmpty(tensorOutputList) || CollectionUtils.isEmpty(tensorOutputList.get(0).getMlModelTensors())) { - throw new IllegalStateException("Empty model result produced. Expected at least [1] tensor output and [1] model tensor, but got [0]"); + throw new IllegalStateException( + "Empty model result produced. 
Expected at least [1] tensor output and [1] model tensor, but got [0]" + ); } List> resultMaps = new ArrayList<>(); for (ModelTensors tensors : tensorOutputList) { diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index 601bf9003..2ac8853e4 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -9,7 +9,6 @@ import java.util.Arrays; import java.util.Collection; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -100,10 +99,10 @@ public List> getQueries() { public Map getProcessors(Processor.Parameters parameters) { clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client)); return Map.of( - TextEmbeddingProcessor.TYPE, - new TextEmbeddingProcessorFactory(clientAccessor, parameters.env), - SparseEncodingProcessor.TYPE, - new SparseEncodingProcessorFactory(clientAccessor, parameters.env) + TextEmbeddingProcessor.TYPE, + new TextEmbeddingProcessorFactory(clientAccessor, parameters.env), + SparseEncodingProcessor.TYPE, + new SparseEncodingProcessorFactory(clientAccessor, parameters.env) ); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java index 2ca516404..52962c5bb 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java @@ -164,20 +164,20 @@ Map buildMapWithProcessorKeyAndOriginalValue(IngestDocument inge } private void buildMapWithProcessorKeyAndOriginalValueForMapType( - String parentKey, - Object processorKey, - Map sourceAndMetadataMap, - Map treeRes + String parentKey, + Object processorKey, + Map sourceAndMetadataMap, + Map treeRes ) { if (processorKey == null || 
sourceAndMetadataMap == null) return; if (processorKey instanceof Map) { Map next = new LinkedHashMap<>(); for (Map.Entry nestedFieldMapEntry : ((Map) processorKey).entrySet()) { buildMapWithProcessorKeyAndOriginalValueForMapType( - nestedFieldMapEntry.getKey(), - nestedFieldMapEntry.getValue(), - (Map) sourceAndMetadataMap.get(parentKey), - next + nestedFieldMapEntry.getKey(), + nestedFieldMapEntry.getValue(), + (Map) sourceAndMetadataMap.get(parentKey), + next ); } treeRes.put(parentKey, next); @@ -214,9 +214,9 @@ private void validateNestedTypeValue(String sourceKey, Object sourceValue, Suppl validateListTypeValue(sourceKey, sourceValue); } else if (Map.class.isAssignableFrom(sourceValue.getClass())) { ((Map) sourceValue).values() - .stream() - .filter(Objects::nonNull) - .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1)); + .stream() + .filter(Objects::nonNull) + .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1)); } else if (!String.class.isAssignableFrom(sourceValue.getClass())) { throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, cannot process it"); } else if (StringUtils.isBlank(sourceValue.toString())) { diff --git a/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java index 3e938933c..2f0647d68 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java @@ -6,7 +6,6 @@ package org.opensearch.neuralsearch.query; import java.io.IOException; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.function.Supplier; @@ -125,18 +124,9 @@ public static SparseEncodingQueryBuilder fromXContent(XContentParser parser) thr requireValue(sparseEncodingQueryBuilder.fieldName(), "Field name must be provided for " + NAME + " 
query"); requireValue( sparseEncodingQueryBuilder.queryText(), - QUERY_TEXT_FIELD.getPreferredName() - + " must be provided for " - + NAME - + " query" - ); - requireValue( - sparseEncodingQueryBuilder.modelId(), - MODEL_ID_FIELD.getPreferredName() - + " must be provided for " - + NAME - + " query" + QUERY_TEXT_FIELD.getPreferredName() + " must be provided for " + NAME + " query" ); + requireValue(sparseEncodingQueryBuilder.modelId(), MODEL_ID_FIELD.getPreferredName() + " must be provided for " + NAME + " query"); return sparseEncodingQueryBuilder; } @@ -215,19 +205,14 @@ protected Query doToQuery(QueryShardContext context) throws IOException { private static void validateForRewrite(String queryText, String modelId) { if (StringUtils.isBlank(queryText) || StringUtils.isBlank(modelId)) { throw new IllegalArgumentException( - QUERY_TEXT_FIELD.getPreferredName() - + " and " - + MODEL_ID_FIELD.getPreferredName() - + " cannot be null." + QUERY_TEXT_FIELD.getPreferredName() + " and " + MODEL_ID_FIELD.getPreferredName() + " cannot be null." 
); } } private static void validateFieldType(MappedFieldType fieldType) { if (null == fieldType || !fieldType.typeName().equals("rank_features")) { - throw new IllegalArgumentException( - "[" + NAME + "] query only works on [rank_features] fields" - ); + throw new IllegalArgumentException("[" + NAME + "] query only works on [rank_features] fields"); } } @@ -237,7 +222,9 @@ private static void validateQueryTokens(Map queryTokens) { } for (Map.Entry entry : queryTokens.entrySet()) { if (entry.getValue() <= 0) { - throw new IllegalArgumentException("Feature weight must be larger than 0, got: " + entry.getValue() + "for key " + entry.getKey()); + throw new IllegalArgumentException( + "Feature weight must be larger than 0, got: " + entry.getValue() + "for key " + entry.getKey() + ); } } } diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index e93033517..295daa948 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -217,7 +217,7 @@ public void testInferenceSentencesWithMapResult_whenModelTensorListEmpty_thenExc ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IllegalStateException.class); Mockito.verify(resultListener).onFailure(argumentCaptor.capture()); assertEquals( - "Empty model result produced. Expected at least [1] tensor output and [1] model tensor, but got [0]", + "Empty model result produced. 
Expected at least [1] tensor output and [1] model tensor, but got [0]", argumentCaptor.getValue().getMessage() ); Mockito.verifyNoMoreInteractions(resultListener); From a47c8b6dce159555a05429e63697403a72d2562b Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Tue, 26 Sep 2023 08:44:32 +0000 Subject: [PATCH 64/70] update function name Signed-off-by: zhichao-aws --- .../processor/UploadSparseEncodingModelRequestBody.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/resources/processor/UploadSparseEncodingModelRequestBody.json b/src/test/resources/processor/UploadSparseEncodingModelRequestBody.json index e630e6dca..eae58829b 100644 --- a/src/test/resources/processor/UploadSparseEncodingModelRequestBody.json +++ b/src/test/resources/processor/UploadSparseEncodingModelRequestBody.json @@ -1,7 +1,7 @@ { "name": "tokenize-idf-0915", "version": "1.0.0", - "function_name": "TOKENIZE", + "function_name": "SPARSE_TOKENIZE", "description": "test model", "model_format": "TORCH_SCRIPT", "model_group_id": "", From b48091fb78aa65ae719956ceb3ad03d1061a026d Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Tue, 26 Sep 2023 18:01:25 +0800 Subject: [PATCH 65/70] add javadoc Signed-off-by: zhichao-aws --- .../neuralsearch/processor/NLPProcessor.java | 8 +++++--- .../processor/SparseEncodingProcessor.java | 4 ++++ .../query/SparseEncodingQueryBuilder.java | 15 +++++++++++++++ .../neuralsearch/util/TokenWeightUtil.java | 8 ++++++++ 4 files changed, 32 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java index 52962c5bb..4ac63d419 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java @@ -26,9 +26,11 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableMap; -// The abstract class 
for text processing use cases. Users provide a field name map -// and a model id. During ingestion, the processor will use the corresponding model -// to inference the input texts, and set the target fields according to the field name map. +/** + * The abstract class for text processing use cases. Users provide a field name map and a model id. + * During ingestion, the processor will use the corresponding model to inference the input texts, + * and set the target fields according to the field name map. + */ @Log4j2 public abstract class NLPProcessor extends AbstractProcessor { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index 217d551c4..62857541e 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -17,6 +17,10 @@ import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.util.TokenWeightUtil; +/** + * This processor is used for user input data text sparse encoding processing, model_id can be used to indicate which model user use, + * and field_map can be used to indicate which fields needs text embedding and the corresponding keys for the sparse encoding results. 
+ */ @Log4j2 public class SparseEncodingProcessor extends NLPProcessor { diff --git a/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java index 2f0647d68..430e6a1f6 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java @@ -42,6 +42,12 @@ import com.google.common.annotations.VisibleForTesting; +/** + * SparseEncodingQueryBuilder is responsible for handling "sparse_encoding" query types. It uses an ML SPARSE_ENCODING model + * or SPARSE_TOKENIZE model to produce a Map with String keys and Float values for input text. Then it will be transformed + * to Lucene FeatureQuery wrapped by Lucene BooleanQuery. + */ + @Log4j2 @Getter @Setter @@ -66,6 +72,12 @@ public static void initialize(MLCommonsClientAccessor mlClient) { private String modelId; private Supplier> queryTokensSupplier; + /** + * Constructor from stream input + * + * @param in StreamInput to initialize object from + * @throws IOException thrown if unable to read from input stream + */ public SparseEncodingQueryBuilder(StreamInput in) throws IOException { super(in); this.fieldName = in.readString(); @@ -98,6 +110,9 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws * "model_id": "string" * } * + * @param parser XContentParser + * @return NeuralQueryBuilder + * @throws IOException can be thrown by parser */ public static SparseEncodingQueryBuilder fromXContent(XContentParser parser) throws IOException { SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder(); diff --git a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java index 2b9613be3..db249de0f 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java 
+++ b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java @@ -11,6 +11,12 @@ import java.util.Map; import java.util.stream.Collectors; +/** + * Utility class for working with sparse_encoding queries and ingest processor. + * Used to fetch the (token, weight) Map from the response returned by {@link org.opensearch.neuralsearch.ml.MLCommonsClientAccessor} + * + */ + public class TokenWeightUtil { public static String RESPONSE_KEY = "response"; @@ -36,6 +42,8 @@ public class TokenWeightUtil { * { TOKEN_WEIGHT_MAP} * ] * }] + * + * @param mapResultList {@link Map} which is the response from {@link org.opensearch.neuralsearch.ml.MLCommonsClientAccessor} */ public static List> fetchListOfTokenWeightMap(List> mapResultList) { List results = new ArrayList<>(); From cfc847deba0549050bd78716b779be1858095a7e Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 27 Sep 2023 08:08:45 +0800 Subject: [PATCH 66/70] remove debug log including inference result Signed-off-by: zhichao-aws --- .../org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index ed0d95d1a..6f8b790bb 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -122,7 +122,6 @@ private void retryableInferenceSentencesWithMapResult( MLInput mlInput = createMLInput(null, inputText); mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { final List> result = buildMapResultFromResponse(mlOutput); - log.debug("Inference Response for input sentence {} is : {} ", inputText, result); listener.onResponse(result); }, e -> { if (RetryUtil.shouldRetry(e, retryTime)) { @@ -144,7 +143,6 @@ private void retryableInferenceSentencesWithVectorResult( MLInput mlInput = 
createMLInput(targetResponseFilters, inputText); mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { final List> vector = buildVectorFromResponse(mlOutput); - log.debug("Inference Response for input sentence {} is : {} ", inputText, vector); listener.onResponse(vector); }, e -> { if (RetryUtil.shouldRetry(e, retryTime)) { From 508b462111d58180a2361a41e54c6991164cbf5e Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 27 Sep 2023 08:11:36 +0800 Subject: [PATCH 67/70] make query text and model id required Signed-off-by: zhichao-aws --- .../query/SparseEncodingQueryBuilder.java | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java index 430e6a1f6..07f581a9d 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java @@ -81,23 +81,23 @@ public static void initialize(MLCommonsClientAccessor mlClient) { public SparseEncodingQueryBuilder(StreamInput in) throws IOException { super(in); this.fieldName = in.readString(); - this.queryText = in.readOptionalString(); - this.modelId = in.readOptionalString(); + this.queryText = in.readString(); + this.modelId = in.readString(); } @Override protected void doWriteTo(StreamOutput out) throws IOException { out.writeString(fieldName); - out.writeOptionalString(queryText); - out.writeOptionalString(modelId); + out.writeString(queryText); + out.writeString(modelId); } @Override protected void doXContent(XContentBuilder xContentBuilder, Params params) throws IOException { xContentBuilder.startObject(NAME); xContentBuilder.startObject(fieldName); - if (null != queryText) xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), queryText); - if (null != modelId) 
xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId); + xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), queryText); + xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId); printBoostAndQueryName(xContentBuilder); xContentBuilder.endObject(); xContentBuilder.endObject(); From aae62d48d84b23ab5033265a66d491da4ea0ae8b Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 27 Sep 2023 09:07:23 +0800 Subject: [PATCH 68/70] minor changes based on comments Signed-off-by: zhichao-aws --- CHANGELOG.md | 1 - .../processor/SparseEncodingProcessor.java | 2 +- .../processor/TextEmbeddingProcessor.java | 2 +- .../SparseEncodingProcessorFactory.java | 3 ++ .../TextEmbeddingProcessorFactory.java | 3 ++ .../query/SparseEncodingQueryBuilder.java | 28 ++++++++++--------- .../neuralsearch/util/TokenWeightUtil.java | 5 +++- 7 files changed, 27 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 038fe41e5..da2ae9ec9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,6 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 3.0](https://github.com/opensearch-project/neural-search/compare/2.x...HEAD) ### Features -Support sparse semantic retrieval by introducing `sparse_encoding` ingest processor and query builder ([#333](https://github.com/opensearch-project/neural-search/pull/333)) ### Enhancements ### Bug Fixes ### Infrastructure diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index 62857541e..275117809 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -22,7 +22,7 @@ * and field_map can be used to indicate which fields needs text embedding and the corresponding keys for the sparse encoding results. 
*/ @Log4j2 -public class SparseEncodingProcessor extends NLPProcessor { +public final class SparseEncodingProcessor extends NLPProcessor { public static final String TYPE = "sparse_encoding"; public static final String LIST_TYPE_NESTED_MAP_KEY = "sparse_encoding"; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java index 354b53945..1df60baea 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java @@ -21,7 +21,7 @@ * and field_map can be used to indicate which fields needs text embedding and the corresponding keys for the text embedding results. */ @Log4j2 -public class TextEmbeddingProcessor extends NLPProcessor { +public final class TextEmbeddingProcessor extends NLPProcessor { public static final String TYPE = "text_embedding"; public static final String LIST_TYPE_NESTED_MAP_KEY = "knn"; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java index dff56e9c8..104418ec5 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java @@ -18,6 +18,9 @@ import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; +/** + * Factory for sparse encoding ingest processor for ingestion pipeline. Instantiates processor based on user provided input. 
+ */ @Log4j2 public class SparseEncodingProcessorFactory implements Processor.Factory { private final MLCommonsClientAccessor clientAccessor; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java index f805b29e1..0c9a6fa2c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java @@ -16,6 +16,9 @@ import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; +/** + * Factory for text embedding ingest processor for ingestion pipeline. Instantiates processor based on user provided input. + */ public class TextEmbeddingProcessorFactory implements Processor.Factory { private final MLCommonsClientAccessor clientAccessor; diff --git a/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java index 07f581a9d..a8c2baaf7 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java @@ -126,22 +126,24 @@ public static SparseEncodingQueryBuilder fromXContent(XContentParser parser) thr if (parser.nextToken() != XContentParser.Token.END_OBJECT) { throw new ParsingException( parser.getTokenLocation(), - "[" - + NAME - + "] query doesn't support multiple fields, found [" - + sparseEncodingQueryBuilder.fieldName() - + "] and [" - + parser.currentName() - + "]" + String.format( + "[%s] query doesn't support multiple fields, found [%s] and [%s]", + NAME, + sparseEncodingQueryBuilder.fieldName(), + parser.currentName() + ) ); } requireValue(sparseEncodingQueryBuilder.fieldName(), "Field name must be 
provided for " + NAME + " query"); requireValue( sparseEncodingQueryBuilder.queryText(), - QUERY_TEXT_FIELD.getPreferredName() + " must be provided for " + NAME + " query" + String.format("%s field must be provided for [%s] query", QUERY_TEXT_FIELD.getPreferredName(), NAME) + ); + requireValue( + sparseEncodingQueryBuilder.modelId(), + String.format("%s field must be provided for [%s] query", MODEL_ID_FIELD.getPreferredName(), NAME) ); - requireValue(sparseEncodingQueryBuilder.modelId(), MODEL_ID_FIELD.getPreferredName() + " must be provided for " + NAME + " query"); return sparseEncodingQueryBuilder; } @@ -164,13 +166,13 @@ private static void parseQueryParams(XContentParser parser, SparseEncodingQueryB } else { throw new ParsingException( parser.getTokenLocation(), - "[" + NAME + "] query does not support [" + currentFieldName + "]" + String.format("[%s] query does not support [%s] field", NAME, currentFieldName) ); } } else { throw new ParsingException( parser.getTokenLocation(), - "[" + NAME + "] unknown token [" + token + "] after [" + currentFieldName + "]" + String.format("[%s] unknown token [%s] after [%s]", NAME, token, currentFieldName) ); } } @@ -220,7 +222,7 @@ protected Query doToQuery(QueryShardContext context) throws IOException { private static void validateForRewrite(String queryText, String modelId) { if (StringUtils.isBlank(queryText) || StringUtils.isBlank(modelId)) { throw new IllegalArgumentException( - QUERY_TEXT_FIELD.getPreferredName() + " and " + MODEL_ID_FIELD.getPreferredName() + " cannot be null." 
+ String.format("%s and %s cannot be null", QUERY_TEXT_FIELD.getPreferredName(), MODEL_ID_FIELD.getPreferredName()) ); } } @@ -238,7 +240,7 @@ private static void validateQueryTokens(Map queryTokens) { for (Map.Entry entry : queryTokens.entrySet()) { if (entry.getValue() <= 0) { throw new IllegalArgumentException( - "Feature weight must be larger than 0, got: " + entry.getValue() + "for key " + entry.getKey() + "Feature weight must be larger than 0, feature [" + entry.getValue() + "] has negative weight." ); } } diff --git a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java index db249de0f..76ce0fa16 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java @@ -46,6 +46,9 @@ public class TokenWeightUtil { * @param mapResultList {@link Map} which is the response from {@link org.opensearch.neuralsearch.ml.MLCommonsClientAccessor} */ public static List> fetchListOfTokenWeightMap(List> mapResultList) { + if (null == mapResultList || mapResultList.isEmpty()) { + throw new IllegalArgumentException("The inference result can not be null or empty."); + } List results = new ArrayList<>(); for (Map map : mapResultList) { if (!map.containsKey(RESPONSE_KEY)) { @@ -66,7 +69,7 @@ private static Map buildTokenWeightMap(Object uncastedMap) { Map result = new HashMap<>(); for (Map.Entry entry : ((Map) uncastedMap).entrySet()) { if (!String.class.isAssignableFrom(entry.getKey().getClass()) || !Number.class.isAssignableFrom(entry.getValue().getClass())) { - throw new IllegalArgumentException("The expected inference result is a Map with String keys and " + " Float values."); + throw new IllegalArgumentException("The expected inference result is a Map with String keys and Float values."); } result.put((String) entry.getKey(), ((Number) entry.getValue()).floatValue()); } From 
2d51bb9b54c4fd17b79044552d405b5cac8e202d Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 27 Sep 2023 10:54:54 +0800 Subject: [PATCH 69/70] add locale to String.format Signed-off-by: zhichao-aws --- .../query/SparseEncodingQueryBuilder.java | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java index a8c2baaf7..4b8b6f0d4 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java @@ -7,6 +7,7 @@ import java.io.IOException; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.function.Supplier; @@ -127,6 +128,7 @@ public static SparseEncodingQueryBuilder fromXContent(XContentParser parser) thr throw new ParsingException( parser.getTokenLocation(), String.format( + Locale.ROOT, "[%s] query doesn't support multiple fields, found [%s] and [%s]", NAME, sparseEncodingQueryBuilder.fieldName(), @@ -138,11 +140,11 @@ public static SparseEncodingQueryBuilder fromXContent(XContentParser parser) thr requireValue(sparseEncodingQueryBuilder.fieldName(), "Field name must be provided for " + NAME + " query"); requireValue( sparseEncodingQueryBuilder.queryText(), - String.format("%s field must be provided for [%s] query", QUERY_TEXT_FIELD.getPreferredName(), NAME) + String.format(Locale.ROOT, "%s field must be provided for [%s] query", QUERY_TEXT_FIELD.getPreferredName(), NAME) ); requireValue( sparseEncodingQueryBuilder.modelId(), - String.format("%s field must be provided for [%s] query", MODEL_ID_FIELD.getPreferredName(), NAME) + String.format(Locale.ROOT, "%s field must be provided for [%s] query", MODEL_ID_FIELD.getPreferredName(), NAME) ); return sparseEncodingQueryBuilder; @@ -166,13 +168,13 @@ private static void 
parseQueryParams(XContentParser parser, SparseEncodingQueryB } else { throw new ParsingException( parser.getTokenLocation(), - String.format("[%s] query does not support [%s] field", NAME, currentFieldName) + String.format(Locale.ROOT, "[%s] query does not support [%s] field", NAME, currentFieldName) ); } } else { throw new ParsingException( parser.getTokenLocation(), - String.format("[%s] unknown token [%s] after [%s]", NAME, token, currentFieldName) + String.format(Locale.ROOT, "[%s] unknown token [%s] after [%s]", NAME, token, currentFieldName) ); } } @@ -222,7 +224,12 @@ protected Query doToQuery(QueryShardContext context) throws IOException { private static void validateForRewrite(String queryText, String modelId) { if (StringUtils.isBlank(queryText) || StringUtils.isBlank(modelId)) { throw new IllegalArgumentException( - String.format("%s and %s cannot be null", QUERY_TEXT_FIELD.getPreferredName(), MODEL_ID_FIELD.getPreferredName()) + String.format( + Locale.ROOT, + "%s and %s cannot be null", + QUERY_TEXT_FIELD.getPreferredName(), + MODEL_ID_FIELD.getPreferredName() + ) ); } } From 96114111f3f6103ab19b58ea2cbc09f86da8b88b Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 27 Sep 2023 05:32:24 +0000 Subject: [PATCH 70/70] update mock model url Signed-off-by: zhichao-aws --- .../processor/UploadSparseEncodingModelRequestBody.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/test/resources/processor/UploadSparseEncodingModelRequestBody.json b/src/test/resources/processor/UploadSparseEncodingModelRequestBody.json index eae58829b..c45334bae 100644 --- a/src/test/resources/processor/UploadSparseEncodingModelRequestBody.json +++ b/src/test/resources/processor/UploadSparseEncodingModelRequestBody.json @@ -5,6 +5,6 @@ "description": "test model", "model_format": "TORCH_SCRIPT", "model_group_id": "", - "model_content_hash_value": "e23969f8bd417e7aec26f49201da4adfc6b74e6187d1ddfdfb98e473bdd95978", - "url": 
"https://github.com/xinyual/demo/raw/main/tokenizer-idf-msmarco.zip" + "model_content_hash_value": "b345e9e943b62c405a8dd227ef2c46c84c5ff0a0b71b584be9132b37bce91a9a", + "url": "https://github.com/opensearch-project/ml-commons/raw/main/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/sparse_encoding/sparse_demo.zip" } \ No newline at end of file