diff --git a/libs/common/src/main/java/org/opensearch/common/unit/TimeValue.java b/libs/common/src/main/java/org/opensearch/common/unit/TimeValue.java index 670275397893c..380dc0f7801ad 100644 --- a/libs/common/src/main/java/org/opensearch/common/unit/TimeValue.java +++ b/libs/common/src/main/java/org/opensearch/common/unit/TimeValue.java @@ -32,6 +32,7 @@ package org.opensearch.common.unit; +import java.io.Serializable; import java.util.Locale; import java.util.Objects; import java.util.concurrent.TimeUnit; @@ -41,7 +42,7 @@ * * @opensearch.api */ -public class TimeValue implements Comparable { +public class TimeValue implements Comparable, Serializable { /** How many nano-seconds in one milli-second */ public static final long NSEC_PER_MSEC = TimeUnit.NANOSECONDS.convert(1, TimeUnit.MILLISECONDS); diff --git a/server/src/main/java/org/opensearch/action/search/ProtobufSearchQueryThenFetchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/ProtobufSearchQueryThenFetchAsyncAction.java index b3dbf51cdfcfe..04b7a0a0d69cd 100644 --- a/server/src/main/java/org/opensearch/action/search/ProtobufSearchQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/ProtobufSearchQueryThenFetchAsyncAction.java @@ -152,7 +152,8 @@ private ProtobufShardSearchRequest rewriteShardSearchRequest(ProtobufShardSearch // set the current best bottom field doc if (bottomSortCollector.getBottomSortValues() != null) { - request.setBottomSortValues(bottomSortCollector.getBottomSortValues()); + // request.setBottomSortValues(bottomSortCollector.getBottomSortValues()); + System.out.println("Bottom sort values is not null......now what????"); } return request; } diff --git a/server/src/main/java/org/opensearch/index/query/QueryBuilder.java b/server/src/main/java/org/opensearch/index/query/QueryBuilder.java index a40ccf427794a..e4a96009b64cf 100644 --- a/server/src/main/java/org/opensearch/index/query/QueryBuilder.java +++ b/server/src/main/java/org/opensearch/index/query/QueryBuilder.java @@ -37,13 +37,14 @@ import org.opensearch.core.xcontent.ToXContentObject; import java.io.IOException; +import java.io.Serializable; /** * Foundation class for all OpenSearch query builders * * @opensearch.internal */ -public interface QueryBuilder extends NamedWriteable, ToXContentObject, Rewriteable { +public interface QueryBuilder extends NamedWriteable, ToXContentObject, Rewriteable, Serializable { /** * Converts this QueryBuilder to a lucene {@link Query}. diff --git a/server/src/main/java/org/opensearch/index/query/Rewriteable.java b/server/src/main/java/org/opensearch/index/query/Rewriteable.java index ea884f720f4fc..66278f085eeb8 100644 --- a/server/src/main/java/org/opensearch/index/query/Rewriteable.java +++ b/server/src/main/java/org/opensearch/index/query/Rewriteable.java @@ -112,31 +112,34 @@ static > void rewriteAndFetch( ActionListener rewriteResponse, int iteration ) { + System.out.println("In rewriteAndFetch"); + System.out.println("Original: " + original.getClass()); + System.out.println("Context: " + context.getClass()); T builder = original; try { - for (T rewrittenBuilder = builder.rewrite(context); rewrittenBuilder != builder; rewrittenBuilder = builder.rewrite(context)) { - builder = rewrittenBuilder; - if (iteration++ >= MAX_REWRITE_ROUNDS) { - // this is some protection against user provided queries if they don't obey the contract of rewrite we allow 16 rounds - // and then we fail to prevent infinite loops - throw new IllegalStateException( - "too many rewrite rounds, rewriteable might return new objects even if they are not " + "rewritten" - ); - } - if (context.hasAsyncActions()) { - T finalBuilder = builder; - final int currentIterationNumber = iteration; - context.executeAsyncActions( - ActionListener.wrap( - n -> rewriteAndFetch(finalBuilder, context, rewriteResponse, currentIterationNumber), - rewriteResponse::onFailure - ) - ); - return; - } - } + // for (T rewrittenBuilder = builder.rewrite(context); rewrittenBuilder != builder; rewrittenBuilder = builder.rewrite(context)) { + // builder = rewrittenBuilder; + // if (iteration++ >= MAX_REWRITE_ROUNDS) { + // // this is some protection against user provided queries if they don't obey the contract of rewrite we allow 16 rounds + // // and then we fail to prevent infinite loops + // throw new IllegalStateException( + // "too many rewrite rounds, rewriteable might return new objects even if they are not " + "rewritten" + // ); + // } + // if (context.hasAsyncActions()) { + // T finalBuilder = builder; + // final int currentIterationNumber = iteration; + // context.executeAsyncActions( + // ActionListener.wrap( + // n -> rewriteAndFetch(finalBuilder, context, rewriteResponse, currentIterationNumber), + // rewriteResponse::onFailure + // ) + // ); + // return; + // } + // } rewriteResponse.onResponse(builder); - } catch (IOException | IllegalArgumentException | ParsingException ex) { + } catch (IllegalArgumentException | ParsingException ex) { rewriteResponse.onFailure(ex); } } diff --git a/server/src/main/java/org/opensearch/search/DefaultSearchContext.java b/server/src/main/java/org/opensearch/search/DefaultSearchContext.java index 59377389dcaf8..1861708cd0200 100644 --- a/server/src/main/java/org/opensearch/search/DefaultSearchContext.java +++ b/server/src/main/java/org/opensearch/search/DefaultSearchContext.java @@ -394,7 +394,12 @@ public void preProcess(boolean rewrite) { } // initialize the filtering alias based on the provided filters try { - final QueryBuilder queryBuilder = request.getAliasFilter().getQueryBuilder(); + final QueryBuilder queryBuilder; + if (request == null) { + queryBuilder = protobufShardSearchRequest.getAliasFilter().getQueryBuilder(); + } else { + queryBuilder = request.getAliasFilter().getQueryBuilder(); + } aliasFilter = queryBuilder == null ? null : queryBuilder.toQuery(queryShardContext); } catch (IOException e) { throw new UncheckedIOException(e); @@ -465,6 +470,11 @@ public ShardSearchRequest request() { return this.request; } + @Override + public ProtobufShardSearchRequest protobufShardSearchRequest() { + return this.protobufShardSearchRequest; + } + @Override public SearchType searchType() { return this.searchType; @@ -969,8 +979,16 @@ public SearchShardTask getTask() { return task; } + @Override + public ProtobufSearchShardTask getProtobufTask() { + return protobufSearchShardTask; + } + @Override public boolean isCancelled() { + if (task == null) { + return protobufSearchShardTask.isCancelled(); + } return task.isCancelled(); } diff --git a/server/src/main/java/org/opensearch/search/Scroll.java b/server/src/main/java/org/opensearch/search/Scroll.java index 562979b98ec7d..aad5c370c2a62 100644 --- a/server/src/main/java/org/opensearch/search/Scroll.java +++ b/server/src/main/java/org/opensearch/search/Scroll.java @@ -38,6 +38,7 @@ import org.opensearch.common.unit.TimeValue; import java.io.IOException; +import java.io.Serializable; import java.util.Objects; /** @@ -46,7 +47,7 @@ * * @opensearch.internal */ -public final class Scroll implements Writeable { +public final class Scroll implements Writeable, Serializable { private final TimeValue keepAlive; diff --git a/server/src/main/java/org/opensearch/search/SearchExtBuilder.java b/server/src/main/java/org/opensearch/search/SearchExtBuilder.java index 4d86c6c2e2277..d9a92f44b1cf7 100644 --- a/server/src/main/java/org/opensearch/search/SearchExtBuilder.java +++ b/server/src/main/java/org/opensearch/search/SearchExtBuilder.java @@ -32,6 +32,8 @@ package org.opensearch.search; +import java.io.Serializable; + import org.opensearch.common.CheckedFunction; import org.opensearch.core.common.io.stream.NamedWriteable; import org.opensearch.core.common.io.stream.StreamInput; @@ -58,7 +60,7 @@ * * @opensearch.internal */ -public abstract class SearchExtBuilder implements NamedWriteable, ToXContentFragment { +public abstract class SearchExtBuilder implements NamedWriteable, ToXContentFragment, Serializable { public abstract int hashCode(); diff --git a/server/src/main/java/org/opensearch/search/SearchService.java b/server/src/main/java/org/opensearch/search/SearchService.java index 40d22ed492d5a..a7748f0a48bef 100644 --- a/server/src/main/java/org/opensearch/search/SearchService.java +++ b/server/src/main/java/org/opensearch/search/SearchService.java @@ -624,6 +624,7 @@ public void onFailure(Exception exc) { } private IndexShard getShard(ShardSearchRequest request) { + System.out.println("getShard"); if (request.readerId() != null) { return findReaderContext(request.readerId(), request).indexShard(); } else { @@ -632,6 +633,7 @@ private IndexShard getShard(ShardSearchRequest request) { } private IndexShard getShardProtobuf(ProtobufShardSearchRequest request) { + System.out.println("getShardProtobuf"); if (request.readerId() != null) { return findReaderContext(request.readerId(), request).indexShard(); } else { @@ -939,6 +941,8 @@ public void executeFetchPhaseProtobuf(ProtobufShardFetchRequest request, Protobu } private ReaderContext getReaderContext(ShardSearchContextId id) { + System.out.println("getReaderContext"); + System.out.println(id.getSessionId()); if (sessionId.equals(id.getSessionId()) == false && id.getSessionId().isEmpty() == false) { throw new SearchContextMissingException(id); } @@ -946,6 +950,8 @@ private ReaderContext getReaderContext(ShardSearchContextId id) { } private ReaderContext findReaderContext(ShardSearchContextId id, TransportRequest request) throws SearchContextMissingException { + System.out.println("findReaderContext"); + System.out.println(id); final ReaderContext reader = getReaderContext(id); if (reader == null) { throw new SearchContextMissingException(id); @@ -1357,7 +1363,7 @@ private DefaultSearchContext createSearchContextProtobuf(ReaderContext reader, P // might end up with incorrect state since we are using now() or script services // during rewrite and normalized / evaluate templates etc. QueryShardContext context = new QueryShardContext(searchContext.getQueryShardContext()); - Rewriteable.rewrite(request.getRewriteable(), context, true); + // Rewriteable.rewrite(request.getRewriteable(), context, true); assert searchContext.getQueryShardContext().isCacheable(); success = true; } finally { @@ -1746,7 +1752,7 @@ private void shortcutDocIdsToLoad(SearchContext context) { } else { completionSuggestions = Collections.emptyList(); } - if (context.request().scroll() != null) { + if ((context.request() != null && context.request().scroll() != null) || (context.protobufShardSearchRequest() != null && context.protobufShardSearchRequest().scroll() != null)) { TopDocs topDocs = context.queryResult().topDocs().topDocs; docIdsToLoad = new int[topDocs.scoreDocs.length + numSuggestDocs]; for (int i = 0; i < topDocs.scoreDocs.length; i++) { @@ -1991,6 +1997,8 @@ public static boolean canRewriteToMatchNone(SearchSourceBuilder source) { } private void rewriteAndFetchShardRequest(IndexShard shard, ShardSearchRequest request, ActionListener listener) { + System.out.println("SearchService rewriteAndFetchShardRequest"); + System.out.println("ShardSearchRequest: " + request); ActionListener actionListener = ActionListener.wrap(r -> { if (request.readerId() != null) { listener.onResponse(request); @@ -2006,6 +2014,8 @@ private void rewriteAndFetchShardRequest(IndexShard shard, ShardSearchRequest re } private void rewriteAndFetchShardRequestProtobuf(IndexShard shard, ProtobufShardSearchRequest request, ActionListener listener) { + System.out.println("SearchService rewriteAndFetchShardRequestProtobuf"); + System.out.println("ProtobufShardSearchRequest: " + request); ActionListener actionListener = ActionListener.wrap(r -> { if (request.readerId() != null) { listener.onResponse(request); diff --git a/server/src/main/java/org/opensearch/search/aggregations/AggregatorFactories.java b/server/src/main/java/org/opensearch/search/aggregations/AggregatorFactories.java index 04fa34466e0ff..6b80fad804dcc 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/AggregatorFactories.java +++ b/server/src/main/java/org/opensearch/search/aggregations/AggregatorFactories.java @@ -59,6 +59,7 @@ import org.opensearch.search.profile.aggregation.ProfilingAggregator; import java.io.IOException; +import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -333,7 +334,7 @@ public int countAggregators() { * * @opensearch.internal */ - public static class Builder implements Writeable, ToXContentObject { + public static class Builder implements Writeable, ToXContentObject, Serializable { private final Set names = new HashSet<>(); // Using LinkedHashSets to preserve the order of insertion, that makes the results diff --git a/server/src/main/java/org/opensearch/search/builder/PointInTimeBuilder.java b/server/src/main/java/org/opensearch/search/builder/PointInTimeBuilder.java index 26a7738177759..29c50e88f38de 100644 --- a/server/src/main/java/org/opensearch/search/builder/PointInTimeBuilder.java +++ b/server/src/main/java/org/opensearch/search/builder/PointInTimeBuilder.java @@ -44,6 +44,7 @@ import org.opensearch.core.xcontent.XContentParser; import java.io.IOException; +import java.io.Serializable; import java.util.Objects; /** @@ -52,7 +53,7 @@ * * @opensearch.internal */ -public final class PointInTimeBuilder implements Writeable, ToXContentObject { +public final class PointInTimeBuilder implements Writeable, ToXContentObject, Serializable { private static final ParseField ID_FIELD = new ParseField("id"); private static final ParseField KEEP_ALIVE_FIELD = new ParseField("keep_alive"); private static final ObjectParser PARSER; diff --git a/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java b/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java index 78cb895a0a4c0..e174acfe072d2 100644 --- a/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java +++ b/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java @@ -74,6 +74,7 @@ import org.opensearch.search.suggest.SuggestBuilder; import java.io.IOException; +import java.io.Serializable; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -93,7 +94,7 @@ * * @opensearch.internal */ -public final class SearchSourceBuilder implements Writeable, ToXContentObject, Rewriteable { +public final class SearchSourceBuilder implements Writeable, ToXContentObject, Rewriteable, Serializable { private static final DeprecationLogger deprecationLogger = DeprecationLogger.getLogger(SearchSourceBuilder.class); public static final ParseField FROM_FIELD = new ParseField("from"); @@ -1495,7 +1496,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws * * @opensearch.internal */ - public static class IndexBoost implements Writeable, ToXContentObject { + public static class IndexBoost implements Writeable, ToXContentObject, Serializable { private final String index; private final float boost; @@ -1596,7 +1597,7 @@ public boolean equals(Object obj) { * * @opensearch.internal */ - public static class ScriptField implements Writeable, ToXContentFragment { + public static class ScriptField implements Writeable, ToXContentFragment, Serializable { private final boolean ignoreFailure; private final String fieldName; diff --git a/server/src/main/java/org/opensearch/search/collapse/CollapseBuilder.java b/server/src/main/java/org/opensearch/search/collapse/CollapseBuilder.java index 288ca9339f8bd..5af80f545f734 100644 --- a/server/src/main/java/org/opensearch/search/collapse/CollapseBuilder.java +++ b/server/src/main/java/org/opensearch/search/collapse/CollapseBuilder.java @@ -49,6 +49,7 @@ import org.opensearch.index.query.QueryShardContext; import java.io.IOException; +import java.io.Serializable; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -59,7 +60,7 @@ * * @opensearch.internal */ -public class CollapseBuilder implements Writeable, ToXContentObject { +public class CollapseBuilder implements Writeable, ToXContentObject, Serializable { public static final ParseField FIELD_FIELD = new ParseField("field"); public static final ParseField INNER_HITS_FIELD = new ParseField("inner_hits"); public static final ParseField MAX_CONCURRENT_GROUP_REQUESTS_FIELD = new ParseField("max_concurrent_group_searches"); diff --git a/server/src/main/java/org/opensearch/search/fetch/StoredFieldsContext.java b/server/src/main/java/org/opensearch/search/fetch/StoredFieldsContext.java index e8c1dc57627fb..a2cad0fda7a8e 100644 --- a/server/src/main/java/org/opensearch/search/fetch/StoredFieldsContext.java +++ b/server/src/main/java/org/opensearch/search/fetch/StoredFieldsContext.java @@ -42,6 +42,7 @@ import org.opensearch.rest.RestRequest; import java.io.IOException; +import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -53,7 +54,7 @@ * * @opensearch.internal */ -public class StoredFieldsContext implements Writeable { +public class StoredFieldsContext implements Writeable, Serializable { public static final String _NONE_ = "_none_"; private final List fieldNames; diff --git a/server/src/main/java/org/opensearch/search/fetch/subphase/FetchSourceContext.java b/server/src/main/java/org/opensearch/search/fetch/subphase/FetchSourceContext.java index 5b9b9e1e70cfa..1590921814102 100644 --- a/server/src/main/java/org/opensearch/search/fetch/subphase/FetchSourceContext.java +++ b/server/src/main/java/org/opensearch/search/fetch/subphase/FetchSourceContext.java @@ -46,6 +46,7 @@ import org.opensearch.rest.RestRequest; import java.io.IOException; +import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -57,7 +58,7 @@ * * @opensearch.internal */ -public class FetchSourceContext implements Writeable, ToXContentObject { +public class FetchSourceContext implements Writeable, ToXContentObject, Serializable { public static final ParseField INCLUDES_FIELD = new ParseField("includes", "include"); public static final ParseField EXCLUDES_FIELD = new ParseField("excludes", "exclude"); diff --git a/server/src/main/java/org/opensearch/search/fetch/subphase/FieldAndFormat.java b/server/src/main/java/org/opensearch/search/fetch/subphase/FieldAndFormat.java index f7e4b06624c76..49e295e2d6158 100644 --- a/server/src/main/java/org/opensearch/search/fetch/subphase/FieldAndFormat.java +++ b/server/src/main/java/org/opensearch/search/fetch/subphase/FieldAndFormat.java @@ -44,6 +44,7 @@ import org.opensearch.core.xcontent.XContentParser; import java.io.IOException; +import java.io.Serializable; import java.util.Objects; /** @@ -52,7 +53,7 @@ * * @opensearch.internal */ -public final class FieldAndFormat implements Writeable, ToXContentObject { +public final class FieldAndFormat implements Writeable, ToXContentObject, Serializable { private static final ParseField FIELD_FIELD = new ParseField("field"); private static final ParseField FORMAT_FIELD = new ParseField("format"); diff --git a/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/AbstractHighlighterBuilder.java b/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/AbstractHighlighterBuilder.java index 162c79c28f982..4a02b819149e2 100644 --- a/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/AbstractHighlighterBuilder.java +++ b/server/src/main/java/org/opensearch/search/fetch/subphase/highlight/AbstractHighlighterBuilder.java @@ -52,6 +52,7 @@ import org.opensearch.search.fetch.subphase.highlight.HighlightBuilder.Order; import java.io.IOException; +import java.io.Serializable; import java.util.Arrays; import java.util.Locale; import java.util.Map; @@ -71,7 +72,8 @@ public abstract class AbstractHighlighterBuilder, - ToXContentObject { + ToXContentObject, + Serializable { public static final ParseField PRE_TAGS_FIELD = new ParseField("pre_tags"); public static final ParseField POST_TAGS_FIELD = new ParseField("post_tags"); public static final ParseField FIELDS_FIELD = new ParseField("fields"); diff --git a/server/src/main/java/org/opensearch/search/internal/AliasFilter.java b/server/src/main/java/org/opensearch/search/internal/AliasFilter.java index 408f67f5002d9..30e49bcbfb0bc 100644 --- a/server/src/main/java/org/opensearch/search/internal/AliasFilter.java +++ b/server/src/main/java/org/opensearch/search/internal/AliasFilter.java @@ -41,6 +41,7 @@ import org.opensearch.index.query.Rewriteable; import java.io.IOException; +import java.io.Serializable; import java.util.Arrays; import java.util.Objects; @@ -49,7 +50,7 @@ * * @opensearch.internal */ -public final class AliasFilter implements Writeable, Rewriteable { +public final class AliasFilter implements Writeable, Rewriteable, Serializable { private final String[] aliases; private final QueryBuilder filter; diff --git a/server/src/main/java/org/opensearch/search/internal/FilteredSearchContext.java b/server/src/main/java/org/opensearch/search/internal/FilteredSearchContext.java index bb990e69e7722..5e024c604a912 100644 --- a/server/src/main/java/org/opensearch/search/internal/FilteredSearchContext.java +++ b/server/src/main/java/org/opensearch/search/internal/FilteredSearchContext.java @@ -36,6 +36,7 @@ import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.Query; +import org.opensearch.action.search.ProtobufSearchShardTask; import org.opensearch.action.search.SearchShardTask; import org.opensearch.action.search.SearchType; import org.opensearch.common.unit.TimeValue; @@ -140,6 +141,11 @@ public ShardSearchRequest request() { return in.request(); } + @Override + public ProtobufShardSearchRequest protobufShardSearchRequest() { + return in.protobufShardSearchRequest(); + } + @Override public SearchType searchType() { return in.searchType(); @@ -520,6 +526,11 @@ public SearchShardTask getTask() { return in.getTask(); } + @Override + public ProtobufSearchShardTask getProtobufTask() { + return in.getProtobufTask(); + } + @Override public boolean isCancelled() { return in.isCancelled(); diff --git a/server/src/main/java/org/opensearch/search/internal/ProtobufShardSearchRequest.java b/server/src/main/java/org/opensearch/search/internal/ProtobufShardSearchRequest.java index 99fdf228cfcbf..9a835040b9df6 100644 --- a/server/src/main/java/org/opensearch/search/internal/ProtobufShardSearchRequest.java +++ b/server/src/main/java/org/opensearch/search/internal/ProtobufShardSearchRequest.java @@ -8,12 +8,10 @@ package org.opensearch.search.internal; -import org.opensearch.Version; import org.opensearch.action.IndicesRequest; import org.opensearch.action.OriginalIndices; import org.opensearch.action.search.ProtobufSearchShardTask; import org.opensearch.action.search.ProtobufSearchRequest; -import org.opensearch.action.search.SearchShardTask; import org.opensearch.action.search.SearchType; import org.opensearch.action.support.IndicesOptions; import org.opensearch.cluster.metadata.AliasMetadata; @@ -22,8 +20,6 @@ import org.opensearch.common.Nullable; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.common.Strings; import org.opensearch.core.xcontent.ToXContent; @@ -38,23 +34,24 @@ import org.opensearch.indices.AliasFilterParsingException; import org.opensearch.indices.InvalidAliasNameException; import org.opensearch.search.Scroll; -import org.opensearch.search.SearchSortValuesAndFormats; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.search.sort.FieldSortBuilder; import org.opensearch.server.proto.ShardSearchRequestProto; import org.opensearch.tasks.ProtobufTask; -import org.opensearch.tasks.Task; import org.opensearch.tasks.ProtobufTaskId; import org.opensearch.transport.TransportRequest; import com.google.protobuf.ByteString; +import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.io.OutputStream; import java.util.Collections; +import java.util.EnumSet; import java.util.Arrays; import java.util.Map; import java.util.function.Function; @@ -74,29 +71,29 @@ public class ProtobufShardSearchRequest extends TransportRequest implements Indi public static final ToXContent.Params FORMAT_PARAMS = new ToXContent.MapParams(Collections.singletonMap("pretty", "false")); private ShardSearchRequestProto.ShardSearchRequest shardSearchRequestProto; - private final String clusterAlias; - private final ShardId shardId; - private final int numberOfShards; - private final SearchType searchType; - private final Scroll scroll; - private final float indexBoost; - private final Boolean requestCache; - private final long nowInMillis; - private long inboundNetworkTime; - private long outboundNetworkTime; - private final boolean allowPartialSearchResults; - private final String[] indexRoutings; - private final String preference; - private final OriginalIndices originalIndices; - - private boolean canReturnNullResponseIfMatchNoDocs; - private SearchSortValuesAndFormats bottomSortValues; + // private final String clusterAlias; + // private final ShardId shardId; + // private final int numberOfShards; + // private final SearchType searchType; + // private final Scroll scroll; + // private final float indexBoost; + // private final Boolean requestCache; + // private final long nowInMillis; + // private long inboundNetworkTime; + // private long outboundNetworkTime; + // private final boolean allowPartialSearchResults; + // private final String[] indexRoutings; + // private final String preference; + // private final OriginalIndices originalIndices; + + // private boolean canReturnNullResponseIfMatchNoDocs; + // private SearchSortValuesAndFormats bottomSortValues; // these are the only mutable fields, as they are subject to rewriting - private AliasFilter aliasFilter; - private SearchSourceBuilder source; - private final ShardSearchContextId readerId; - private final TimeValue keepAlive; + // private AliasFilter aliasFilter; + // private SearchSourceBuilder source; + // private final ShardSearchContextId readerId; + // private final TimeValue keepAlive; public ProtobufShardSearchRequest( OriginalIndices originalIndices, @@ -190,7 +187,7 @@ private ProtobufShardSearchRequest( Boolean requestCache, AliasFilter aliasFilter, float indexBoost, - boolean allowPartialSearchResults, + Boolean allowPartialSearchResults, String[] indexRoutings, String preference, Scroll scroll, @@ -199,253 +196,298 @@ private ProtobufShardSearchRequest( ShardSearchContextId readerId, TimeValue keepAlive ) { - this.shardId = shardId; - this.numberOfShards = numberOfShards; - this.searchType = searchType; - this.source = source; - this.requestCache = requestCache; - this.aliasFilter = aliasFilter; - this.indexBoost = indexBoost; - this.allowPartialSearchResults = allowPartialSearchResults; - this.indexRoutings = indexRoutings; - this.preference = preference; - this.scroll = scroll; - this.nowInMillis = nowInMillis; - this.inboundNetworkTime = 0; - this.outboundNetworkTime = 0; - this.clusterAlias = clusterAlias; - this.originalIndices = originalIndices; - this.readerId = readerId; - this.keepAlive = keepAlive; - assert keepAlive == null || readerId != null : "readerId: " + readerId + " keepAlive: " + keepAlive; - // convert shardId to bytes - // ByteArrayOutputStream bos = new ByteArrayOutputStream(); - // ObjectOutputStream oos = new ObjectOutputStream(bos); - // oos.writeObject(shardId); - // oos.flush(); - // byte [] data = bos.toByteArray(); - // ByteString byteString = ByteString.copyFromUtf8(shardId); - // ShardSearchRequestProto.ShardSearchRequest.OriginalIndices.IndicesOptions indicesOptionsProto = - // ShardSearchRequestProto.ShardSearchRequest.OriginalIndices.IndicesOptions. - // if (originalIndices.indicesOptions().allowAliasesToMultipleIndices()) { - - } + // this.shardId = shardId; + // this.numberOfShards = numberOfShards; + // this.searchType = searchType; + // this.source = source; + // this.requestCache = requestCache; + // this.aliasFilter = aliasFilter; + // this.indexBoost = indexBoost; + // this.allowPartialSearchResults = allowPartialSearchResults; + // this.indexRoutings = indexRoutings; + // this.preference = preference; + // this.scroll = scroll; + // this.nowInMillis = nowInMillis; + // this.inboundNetworkTime = 0; + // this.outboundNetworkTime = 0; + // this.clusterAlias = clusterAlias; + // this.originalIndices = originalIndices; + // this.readerId = readerId; + // this.keepAlive = keepAlive; + // assert keepAlive == null || readerId != null : "readerId: " + readerId + " keepAlive: " + keepAlive; ShardSearchRequestProto.ShardSearchRequest.OriginalIndices originalIndicesProto = ShardSearchRequestProto.ShardSearchRequest.OriginalIndices.newBuilder() .addAllIndices(Arrays.stream(originalIndices.indices()).collect(Collectors.toList())) - .setIndicesOptions().build(); + .setIndicesOptions(ShardSearchRequestProto.ShardSearchRequest.OriginalIndices.IndicesOptions.newBuilder() + .setIgnoreUnavailable(originalIndices.indicesOptions().ignoreUnavailable()) + .setAllowNoIndices(originalIndices.indicesOptions().allowNoIndices()) + .setExpandWildcardsOpen(originalIndices.indicesOptions().expandWildcardsOpen()) + .setExpandWildcardsClosed(originalIndices.indicesOptions().expandWildcardsClosed()) + .setExpandWildcardsHidden(originalIndices.indicesOptions().allowAliasesToMultipleIndices()) + .setAllowAliasesToMultipleIndices(originalIndices.indicesOptions().allowAliasesToMultipleIndices()) + .setForbidClosedIndices(originalIndices.indicesOptions().forbidClosedIndices()) + .setIgnoreAliases(originalIndices.indicesOptions().ignoreAliases()) + .setIgnoreThrottled(originalIndices.indicesOptions().ignoreThrottled()) + .build()) + .build(); + ShardSearchRequestProto.ShardSearchRequest.ShardId shardIdProto = ShardSearchRequestProto.ShardSearchRequest.ShardId.newBuilder() + .setShardId(shardId.getId()) + .setHashCode(shardId.hashCode()) + .setIndexName(shardId.getIndexName()) + .setIndexUUID(shardId.getIndex().getUUID()) + .build(); + + ShardSearchRequestProto.ShardSearchRequest.ShardSearchContextId.Builder shardSearchContextId = ShardSearchRequestProto.ShardSearchRequest.ShardSearchContextId.newBuilder(); + System.out.println("Reader id: " + readerId); + if (readerId != null) { + shardSearchContextId.setSessionId(readerId.getSessionId()); + shardSearchContextId.setId(readerId.getId()); + } + + ShardSearchRequestProto.ShardSearchRequest.Builder builder = ShardSearchRequestProto.ShardSearchRequest.newBuilder(); + builder.setOriginalIndices(originalIndicesProto); + builder.setShardId(shardIdProto); + builder.setNumberOfShards(numberOfShards); + builder.setSearchType(ShardSearchRequestProto.ShardSearchRequest.SearchType.QUERY_THEN_FETCH); + builder.setSource(ByteString.copyFrom(convertToBytes(source))); + builder.setInboundNetworkTime(0); + builder.setOutboundNetworkTime(0); + + if (requestCache != null) { + builder.setRequestCache(requestCache); + } + + if (aliasFilter != null) { + builder.setAliasFilter(ByteString.copyFrom(convertToBytes(aliasFilter))); + } + builder.setIndexBoost(indexBoost); + if (allowPartialSearchResults != null) { + builder.setAllowPartialSearchResults(allowPartialSearchResults); + } - this.shardSearchRequestProto = ShardSearchRequestProto.ShardSearchRequest.newBuilder() - .setOriginalIndices(originalIndicesProto).build(); + if (indexRoutings != null) { + builder.addAllIndexRoutings(Arrays.stream(indexRoutings).collect(Collectors.toList())); + } + if (preference != null) { + builder.setPreference(preference); + } - // this.shardSearchRequestProto = ShardSearchRequestProto.ShardSearchRequest.newBuilder().setShardId(shardId) + if (scroll != null) { + builder.setScroll(ByteString.copyFrom(convertToBytes(scroll))); + } + builder.setNowInMillis(nowInMillis); + + if (clusterAlias != null) { + builder.setClusterAlias(clusterAlias); + } + if (readerId != null) { + builder.setReaderId(shardSearchContextId.build()); + } + + System.out.println("Keep alive: " + keepAlive); + if (keepAlive != null) { + builder.setTimeValue(keepAlive.getStringRep()); + } + + this.shardSearchRequestProto = builder.build(); } public ProtobufShardSearchRequest(byte[] in) throws IOException { super(in); - shardId = null; - searchType = null; - numberOfShards = 0; - scroll = null; - source = null; - // if (in.getVersion().before(Version.V_2_0_0)) { - // // types no longer relevant so ignore - // String[] types = in.readStringArray(); - // if (types.length > 0) { - // throw new IllegalStateException("types are no longer supported in ids query but found [" + Arrays.toString(types) + "]"); - // } - // } - aliasFilter = null; - indexBoost = 0; - nowInMillis = 0; - requestCache = false; - // if (in.getVersion().onOrAfter(Version.V_2_0_0)) { - // inboundNetworkTime = in.readVLong(); - // outboundNetworkTime = in.readVLong(); - // } - clusterAlias = ""; - allowPartialSearchResults = false; - indexRoutings = null; - preference = ""; - canReturnNullResponseIfMatchNoDocs = false; - bottomSortValues = null; - readerId = null; - keepAlive = null; - originalIndices = null; - assert keepAlive == null || readerId != null : "readerId: " + readerId + " keepAlive: " + keepAlive; + this.shardSearchRequestProto = ShardSearchRequestProto.ShardSearchRequest.parseFrom(in); } - public ProtobufShardSearchRequest(ProtobufShardSearchRequest clone) { - this.shardId = clone.shardId; - this.searchType = clone.searchType; - this.numberOfShards = clone.numberOfShards; - this.scroll = clone.scroll; - this.source = clone.source; - this.aliasFilter = clone.aliasFilter; - this.indexBoost = clone.indexBoost; - this.nowInMillis = clone.nowInMillis; - this.inboundNetworkTime = clone.inboundNetworkTime; - this.outboundNetworkTime = clone.outboundNetworkTime; - this.requestCache = clone.requestCache; - this.clusterAlias = clone.clusterAlias; - this.allowPartialSearchResults = clone.allowPartialSearchResults; - this.indexRoutings = clone.indexRoutings; - this.preference = clone.preference; - this.canReturnNullResponseIfMatchNoDocs = clone.canReturnNullResponseIfMatchNoDocs; - this.bottomSortValues = clone.bottomSortValues; - this.originalIndices = clone.originalIndices; - this.readerId = clone.readerId; - this.keepAlive = clone.keepAlive; + public ProtobufShardSearchRequest(ShardSearchRequestProto.ShardSearchRequest shardSearchRequest) { + this.shardSearchRequestProto = shardSearchRequest; } - @Override - public void writeTo(OutputStream out) throws IOException { - super.writeTo(out); - // innerWriteTo(out, false); - // OriginalIndices.writeOriginalIndices(originalIndices, out); + public ProtobufShardSearchRequest(ProtobufShardSearchRequest clone) { + this.shardSearchRequestProto = clone.shardSearchRequestProto; } - protected final void innerWriteTo(StreamOutput out, boolean asKey) throws IOException { - shardId.writeTo(out); - out.writeByte(searchType.id()); - if (!asKey) { - out.writeVInt(numberOfShards); - } - out.writeOptionalWriteable(scroll); - out.writeOptionalWriteable(source); - if (out.getVersion().before(Version.V_2_0_0)) { - // types not supported so send an empty array to previous versions - out.writeStringArray(Strings.EMPTY_ARRAY); - } - aliasFilter.writeTo(out); - out.writeFloat(indexBoost); - if (asKey == false) { - out.writeVLong(nowInMillis); - } - out.writeOptionalBoolean(requestCache); - if (asKey == false && out.getVersion().onOrAfter(Version.V_2_0_0)) { - out.writeVLong(inboundNetworkTime); - out.writeVLong(outboundNetworkTime); - } - out.writeOptionalString(clusterAlias); - out.writeBoolean(allowPartialSearchResults); - if (asKey == false) { - out.writeStringArray(indexRoutings); - out.writeOptionalString(preference); - } - if (asKey == false) { - out.writeBoolean(canReturnNullResponseIfMatchNoDocs); - out.writeOptionalWriteable(bottomSortValues); - } - if (asKey == false) { - out.writeOptionalWriteable(readerId); - out.writeOptionalTimeValue(keepAlive); + private byte[] convertToBytes(Object obj) { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + try { + ObjectOutputStream oos = new ObjectOutputStream(bos); + oos.writeObject(obj); + oos.flush(); + } catch (IOException e) { + // TODO Auto-generated catch block + e.printStackTrace(); } + return bos.toByteArray(); } + @Override + public void writeTo(OutputStream out) throws IOException { + super.writeTo(out); + out.write(this.shardSearchRequestProto.toByteArray()); + } + + // protected final void innerWriteTo(StreamOutput out, boolean asKey) throws IOException { + // shardId.writeTo(out); + // out.writeByte(searchType.id()); + // if (!asKey) { + // out.writeVInt(numberOfShards); + // } + // out.writeOptionalWriteable(scroll); + // out.writeOptionalWriteable(source); + // if (out.getVersion().before(Version.V_2_0_0)) { + // // types not supported so send an empty array to previous versions + // out.writeStringArray(Strings.EMPTY_ARRAY); + // } + // aliasFilter.writeTo(out); + // out.writeFloat(indexBoost); + // if (asKey == false) { + // out.writeVLong(nowInMillis); + // } + // out.writeOptionalBoolean(requestCache); + // if (asKey == false && out.getVersion().onOrAfter(Version.V_2_0_0)) { + // out.writeVLong(inboundNetworkTime); + // out.writeVLong(outboundNetworkTime); + // } + // out.writeOptionalString(clusterAlias); + // out.writeBoolean(allowPartialSearchResults); + // if (asKey == false) { + // out.writeStringArray(indexRoutings); + // out.writeOptionalString(preference); + // } + // if (asKey == false) { + // out.writeBoolean(canReturnNullResponseIfMatchNoDocs); + // out.writeOptionalWriteable(bottomSortValues); + // } + // if (asKey == false) { + // out.writeOptionalWriteable(readerId); + // out.writeOptionalTimeValue(keepAlive); + // } + // } + @Override public String[] indices() { - if (originalIndices == null) { + if (this.shardSearchRequestProto.getOriginalIndices() == null) { return null; } - return originalIndices.indices(); + return this.shardSearchRequestProto.getOriginalIndices().getIndicesList().toArray(new String[0]); } @Override public IndicesOptions indicesOptions() { - if (originalIndices == null) { + if (this.shardSearchRequestProto.getOriginalIndices() == null) { return null; } - return originalIndices.indicesOptions(); + IndicesOptions indicesOptions = new IndicesOptions(EnumSet.of(IndicesOptions.Option.ALLOW_NO_INDICES, IndicesOptions.Option.FORBID_CLOSED_INDICES, IndicesOptions.Option.IGNORE_THROTTLED), + EnumSet.of(IndicesOptions.WildcardStates.OPEN)); + return indicesOptions; } public ShardId shardId() { - return shardId; + return new ShardId(this.shardSearchRequestProto.getShardId().getIndexName(), this.shardSearchRequestProto.getShardId().getIndexUUID(), this.shardSearchRequestProto.getShardId().getShardId()); } public SearchSourceBuilder source() { - return source; + ByteArrayInputStream in = new ByteArrayInputStream(this.shardSearchRequestProto.getSource().toByteArray()); + try (ObjectInputStream is = new ObjectInputStream(in)) { + return (SearchSourceBuilder) is.readObject(); + } catch (ClassNotFoundException | IOException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } + return null; } public AliasFilter getAliasFilter() { - return aliasFilter; + ByteArrayInputStream in = new ByteArrayInputStream(this.shardSearchRequestProto.getAliasFilter().toByteArray()); + try (ObjectInputStream is = new ObjectInputStream(in)) { + return (AliasFilter) is.readObject(); + } catch (ClassNotFoundException | IOException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } + return null; } public void setAliasFilter(AliasFilter aliasFilter) { - this.aliasFilter = aliasFilter; + this.shardSearchRequestProto.toBuilder().setAliasFilter(ByteString.copyFrom(convertToBytes(aliasFilter))); } public void source(SearchSourceBuilder source) { - this.source = source; + this.shardSearchRequestProto.toBuilder().setSource(ByteString.copyFrom(convertToBytes(source))); } public int numberOfShards() { - return numberOfShards; + return this.shardSearchRequestProto.getNumberOfShards(); } public SearchType searchType() { - return searchType; + return SearchType.QUERY_THEN_FETCH; } public float indexBoost() { - return indexBoost; + return this.shardSearchRequestProto.getIndexBoost(); } public long nowInMillis() { - return nowInMillis; + return this.shardSearchRequestProto.getNowInMillis(); } public long getInboundNetworkTime() { - return inboundNetworkTime; + return this.shardSearchRequestProto.getInboundNetworkTime(); } public void setInboundNetworkTime(long newTime) { - this.inboundNetworkTime = newTime; + this.shardSearchRequestProto.toBuilder().setInboundNetworkTime(newTime); } public long getOutboundNetworkTime() { - return outboundNetworkTime; + return this.shardSearchRequestProto.getOutboundNetworkTime(); } public void setOutboundNetworkTime(long newTime) { - this.outboundNetworkTime = newTime; + this.shardSearchRequestProto.toBuilder().setOutboundNetworkTime(newTime); } public Boolean requestCache() { - return requestCache; + return this.shardSearchRequestProto.getRequestCache(); } public boolean allowPartialSearchResults() { - return allowPartialSearchResults; + return this.shardSearchRequestProto.getAllowPartialSearchResults(); } public Scroll scroll() { - return scroll; + ByteArrayInputStream in = new ByteArrayInputStream(this.shardSearchRequestProto.getScroll().toByteArray()); + try (ObjectInputStream is = new ObjectInputStream(in)) { + return (Scroll) is.readObject(); + } catch (ClassNotFoundException | IOException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } + return null; } public String[] indexRoutings() { - return indexRoutings; + return this.shardSearchRequestProto.getIndexRoutingsList().toArray(new String[0]); } public String preference() { - return preference; + return this.shardSearchRequestProto.getPreference(); } - /** - * Sets the bottom sort values that can be used by the searcher to filter documents - * that are after it. This value is computed by coordinating nodes that throttles the - * query phase. After a partial merge of successful shards the sort values of the - * bottom top document are passed as an hint on subsequent shard requests. - */ - public void setBottomSortValues(SearchSortValuesAndFormats values) { - this.bottomSortValues = values; - } + // /** + // * Sets the bottom sort values that can be used by the searcher to filter documents + // * that are after it. This value is computed by coordinating nodes that throttles the + // * query phase. After a partial merge of successful shards the sort values of the + // * bottom top document are passed as an hint on subsequent shard requests. + // */ + // public void setBottomSortValues(SearchSortValuesAndFormats values) { + // this.bottomSortValues = values; + // } - public SearchSortValuesAndFormats getBottomSortValues() { - return bottomSortValues; - } + // public SearchSortValuesAndFormats getBottomSortValues() { + // return bottomSortValues; + // } /** * Returns true if the caller can handle null response {@link QuerySearchResult#nullInstance()}. @@ -453,11 +495,11 @@ public SearchSortValuesAndFormats getBottomSortValues() { * response. */ public boolean canReturnNullResponseIfMatchNoDocs() { - return canReturnNullResponseIfMatchNoDocs; + return this.shardSearchRequestProto.getCanReturnNullResponseIfMatchNoDocs(); } public void canReturnNullResponseIfMatchNoDocs(boolean value) { - this.canReturnNullResponseIfMatchNoDocs = value; + this.shardSearchRequestProto.toBuilder().setCanReturnNullResponseIfMatchNoDocs(value); } private static final ThreadLocal scratch = ThreadLocal.withInitial(BytesStreamOutput::new); @@ -467,32 +509,26 @@ public void canReturnNullResponseIfMatchNoDocs(boolean value) { * otherwise, using the most up to date point-in-time reader. */ public ShardSearchContextId readerId() { - return readerId; + System.out.println("Getting readerId"); + if (this.shardSearchRequestProto.hasReaderId() == false) { + System.out.println("Returning null since the readerId is null"); + return null; + } + return new ShardSearchContextId(this.shardSearchRequestProto.getReaderId().getSessionId(), this.shardSearchRequestProto.getReaderId().getId()); } /** * Returns a non-null to specify the time to live of the point-in-time reader that is used to execute this request. */ public TimeValue keepAlive() { - return keepAlive; - } - - /** - * Returns the cache key for this shard search request, based on its content - */ - public BytesReference cacheKey() throws IOException { - BytesStreamOutput out = scratch.get(); - try { - this.innerWriteTo(out, true); - // copy it over since we don't want to share the thread-local bytes in #scratch - return out.copyBytes(); - } finally { - out.reset(); + if (!this.shardSearchRequestProto.hasTimeValue()) { + return null; } + return TimeValue.parseTimeValue(this.shardSearchRequestProto.getTimeValue(), null, "keep_alive"); } public String getClusterAlias() { - return clusterAlias; + return this.shardSearchRequestProto.getClusterAlias(); } @Override @@ -508,8 +544,15 @@ public String getDescription() { public String getMetadataSupplier() { StringBuilder sb = new StringBuilder(); - if (source != null) { - sb.append("source[").append(source.toString(FORMAT_PARAMS)).append("]"); + if (this.shardSearchRequestProto.getSource() != null) { + ByteArrayInputStream in = new ByteArrayInputStream(this.shardSearchRequestProto.getSource().toByteArray()); + try (ObjectInputStream is = new ObjectInputStream(in)) { + SearchSourceBuilder source = (SearchSourceBuilder) is.readObject(); + sb.append("source[").append(source.toString(FORMAT_PARAMS)).append("]"); + } catch (ClassNotFoundException | IOException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } } else { sb.append("source[]"); } @@ -530,25 +573,33 @@ static class RequestRewritable implements Rewriteable { @Override public Rewriteable rewrite(QueryRewriteContext ctx) throws IOException { - SearchSourceBuilder newSource = request.source() == null ? null : Rewriteable.rewrite(request.source(), ctx); - AliasFilter newAliasFilter = Rewriteable.rewrite(request.getAliasFilter(), ctx); + // System.out.println("Rewriting protobuf request source"); + // SearchSourceBuilder newSource = request.source() == null ? null : Rewriteable.rewrite(request.source(), ctx); + // System.out.println("Rewriting protobuf request source done"); + // System.out.println("Rewriting protobuf request alias filter"); + // AliasFilter newAliasFilter = Rewriteable.rewrite(request.getAliasFilter(), ctx); + // System.out.println("Rewriting protobuf request alias filter done"); + + SearchSourceBuilder newSource = request.source(); + AliasFilter newAliasFilter = request.getAliasFilter(); QueryShardContext shardContext = ctx.convertToShardContext(); FieldSortBuilder primarySort = FieldSortBuilder.getPrimaryFieldSortOrNull(newSource); if (shardContext != null && primarySort != null - && primarySort.isBottomSortShardDisjoint(shardContext, request.getBottomSortValues())) { + // && primarySort.isBottomSortShardDisjoint(shardContext, request.getBottomSortValues()) + ) { assert newSource != null : "source should contain a primary sort field"; newSource = newSource.shallowCopy(); - int trackTotalHitsUpTo = ProtobufSearchRequest.resolveTrackTotalHitsUpTo(request.scroll, request.source); + int trackTotalHitsUpTo = ProtobufSearchRequest.resolveTrackTotalHitsUpTo(request.scroll(), request.source()); if (trackTotalHitsUpTo == TRACK_TOTAL_HITS_DISABLED && newSource.suggest() == null && newSource.aggregations() == null) { newSource.query(new MatchNoneQueryBuilder()); } else { newSource.size(0); } request.source(newSource); - request.setBottomSortValues(null); + // request.setBottomSortValues(null); } if (newSource == request.source() && newAliasFilter == request.getAliasFilter()) { @@ -614,4 +665,9 @@ public static QueryBuilder parseAliasFilter( return combined; } } + + public ShardSearchRequestProto.ShardSearchRequest request() { + return this.shardSearchRequestProto; + } + } diff --git a/server/src/main/java/org/opensearch/search/internal/SearchContext.java b/server/src/main/java/org/opensearch/search/internal/SearchContext.java index c2f81b0d4b8b5..b20d840dba1e5 100644 --- a/server/src/main/java/org/opensearch/search/internal/SearchContext.java +++ b/server/src/main/java/org/opensearch/search/internal/SearchContext.java @@ -35,6 +35,7 @@ import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.Query; +import org.opensearch.action.search.ProtobufSearchShardTask; import org.opensearch.action.search.SearchShardTask; import org.opensearch.action.search.SearchType; import org.opensearch.common.Nullable; @@ -123,6 +124,8 @@ protected SearchContext() {} public abstract SearchShardTask getTask(); + public abstract ProtobufSearchShardTask getProtobufTask(); + public abstract boolean isCancelled(); public boolean isSearchTimedOut() { @@ -162,6 +165,8 @@ public final void close() { public abstract ShardSearchRequest request(); + public abstract ProtobufShardSearchRequest protobufShardSearchRequest(); + public abstract SearchType searchType(); public abstract SearchShardTarget shardTarget(); @@ -411,6 +416,9 @@ public void addReleasable(Releasable releasable) { * @return true if the request contains only suggest */ public final boolean hasOnlySuggest() { + if (request() == null) { + return protobufShardSearchRequest().source() != null && protobufShardSearchRequest().source().isSuggestOnly(); + } return request().source() != null && request().source().isSuggestOnly(); } diff --git a/server/src/main/java/org/opensearch/search/internal/ShardSearchRequest.java b/server/src/main/java/org/opensearch/search/internal/ShardSearchRequest.java index d2f6bc234e752..e0e3eaf170d58 100644 --- a/server/src/main/java/org/opensearch/search/internal/ShardSearchRequest.java +++ b/server/src/main/java/org/opensearch/search/internal/ShardSearchRequest.java @@ -262,6 +262,7 @@ public ShardSearchRequest(StreamInput in) throws IOException { canReturnNullResponseIfMatchNoDocs = in.readBoolean(); bottomSortValues = in.readOptionalWriteable(SearchSortValuesAndFormats::new); readerId = in.readOptionalWriteable(ShardSearchContextId::new); + System.out.println("Reader id: " + readerId); keepAlive = in.readOptionalTimeValue(); originalIndices = OriginalIndices.readOriginalIndices(in); assert keepAlive == null || readerId != null : "readerId: " + readerId + " keepAlive: " + keepAlive; @@ -520,8 +521,12 @@ static class RequestRewritable implements Rewriteable { @Override public Rewriteable rewrite(QueryRewriteContext ctx) throws IOException { + System.out.println("Rewriting request source"); SearchSourceBuilder newSource = request.source() == null ? null : Rewriteable.rewrite(request.source(), ctx); + System.out.println("Rewriting request source done"); + System.out.println("Rewriting request alias filter"); AliasFilter newAliasFilter = Rewriteable.rewrite(request.getAliasFilter(), ctx); + System.out.println("Rewriting request alias filter done"); QueryShardContext shardContext = ctx.convertToShardContext(); diff --git a/server/src/main/java/org/opensearch/search/query/QueryPhase.java b/server/src/main/java/org/opensearch/search/query/QueryPhase.java index 8418fdca2f777..f1e57ceda2987 100644 --- a/server/src/main/java/org/opensearch/search/query/QueryPhase.java +++ b/server/src/main/java/org/opensearch/search/query/QueryPhase.java @@ -46,6 +46,7 @@ import org.apache.lucene.search.Sort; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; +import org.opensearch.action.search.ProtobufSearchShardTask; import org.opensearch.action.search.SearchShardTask; import org.opensearch.common.Booleans; import org.opensearch.common.lucene.Lucene; @@ -130,6 +131,27 @@ public void preProcess(SearchContext context) { } } + public void preProcessProtobuf(SearchContext context) { + final Runnable cancellation; + if (context.lowLevelCancellation()) { + cancellation = context.searcher().addQueryCancellation(() -> { + ProtobufSearchShardTask task = context.getProtobufTask(); + if (task != null && task.isCancelled()) { + throw new TaskCancelledException("cancelled task with reason: " + task.getReasonCancelled()); + } + }); + } else { + cancellation = null; + } + try { + context.preProcess(true); + } finally { + if (cancellation != null) { + context.searcher().removeQueryCancellation(cancellation); + } + } + } + public void execute(SearchContext searchContext) throws QueryPhaseExecutionException { if (searchContext.hasOnlySuggest()) { suggestProcessor.process(searchContext); diff --git a/server/src/main/java/org/opensearch/search/rescore/RescorerBuilder.java b/server/src/main/java/org/opensearch/search/rescore/RescorerBuilder.java index d4094298f0c5f..3a28e14bcd234 100644 --- a/server/src/main/java/org/opensearch/search/rescore/RescorerBuilder.java +++ b/server/src/main/java/org/opensearch/search/rescore/RescorerBuilder.java @@ -46,6 +46,7 @@ import org.opensearch.index.query.Rewriteable; import java.io.IOException; +import java.io.Serializable; import java.util.Objects; /** @@ -57,7 +58,8 @@ public abstract class RescorerBuilder> implements NamedWriteable, ToXContentObject, - Rewriteable> { + Rewriteable>, + Serializable { public static final int DEFAULT_WINDOW_SIZE = 10; protected Integer windowSize; diff --git a/server/src/main/java/org/opensearch/search/searchafter/SearchAfterBuilder.java b/server/src/main/java/org/opensearch/search/searchafter/SearchAfterBuilder.java index 516b388ce2186..001059f90a999 100644 --- a/server/src/main/java/org/opensearch/search/searchafter/SearchAfterBuilder.java +++ b/server/src/main/java/org/opensearch/search/searchafter/SearchAfterBuilder.java @@ -53,6 +53,7 @@ import org.opensearch.search.sort.SortAndFormats; import java.io.IOException; +import java.io.Serializable; import java.math.BigInteger; import java.util.ArrayList; import java.util.Arrays; @@ -64,7 +65,7 @@ * * @opensearch.internal */ -public class SearchAfterBuilder implements ToXContentObject, Writeable { +public class SearchAfterBuilder implements ToXContentObject, Writeable, Serializable { public static final ParseField SEARCH_AFTER = new ParseField("search_after"); private static final Object[] EMPTY_SORT_VALUES = new Object[0]; diff --git a/server/src/main/java/org/opensearch/search/slice/SliceBuilder.java b/server/src/main/java/org/opensearch/search/slice/SliceBuilder.java index 31e03f5ef511e..612231845d810 100644 --- a/server/src/main/java/org/opensearch/search/slice/SliceBuilder.java +++ b/server/src/main/java/org/opensearch/search/slice/SliceBuilder.java @@ -60,6 +60,7 @@ import org.opensearch.search.internal.ShardSearchRequest; import java.io.IOException; +import java.io.Serializable; import java.util.Collections; import java.util.Map; import java.util.Objects; @@ -78,7 +79,7 @@ * * @opensearch.internal */ -public class SliceBuilder implements Writeable, ToXContentObject { +public class SliceBuilder implements Writeable, ToXContentObject, Serializable { private static final DeprecationLogger DEPRECATION_LOG = DeprecationLogger.getLogger(SliceBuilder.class); diff --git a/server/src/main/java/org/opensearch/search/sort/SortBuilder.java b/server/src/main/java/org/opensearch/search/sort/SortBuilder.java index 5bffb8a9ca56e..0e5d5ceccf012 100644 --- a/server/src/main/java/org/opensearch/search/sort/SortBuilder.java +++ b/server/src/main/java/org/opensearch/search/sort/SortBuilder.java @@ -54,6 +54,7 @@ import org.opensearch.search.DocValueFormat; import java.io.IOException; +import java.io.Serializable; import java.util.ArrayList; import java.util.List; import java.util.Objects; @@ -66,7 +67,7 @@ * * @opensearch.internal */ -public abstract class SortBuilder> implements NamedWriteable, ToXContentObject, Rewriteable> { +public abstract class SortBuilder> implements NamedWriteable, ToXContentObject, Rewriteable>, Serializable { protected SortOrder order = SortOrder.ASC; diff --git a/server/src/main/java/org/opensearch/search/suggest/SuggestBuilder.java b/server/src/main/java/org/opensearch/search/suggest/SuggestBuilder.java index 3daa4ac019cd5..9203a3dbdbd92 100644 --- a/server/src/main/java/org/opensearch/search/suggest/SuggestBuilder.java +++ b/server/src/main/java/org/opensearch/search/suggest/SuggestBuilder.java @@ -47,6 +47,7 @@ import org.opensearch.search.suggest.SuggestionSearchContext.SuggestionContext; import java.io.IOException; +import java.io.Serializable; import java.util.HashMap; import java.util.Map; import java.util.Map.Entry; @@ -61,7 +62,7 @@ * * @opensearch.internal */ -public class SuggestBuilder implements Writeable, ToXContentObject { +public class SuggestBuilder implements Writeable, ToXContentObject, Serializable { protected static final ParseField GLOBAL_TEXT_FIELD = new ParseField("text"); private String globalText; diff --git a/server/src/main/java/org/opensearch/transport/InboundHandler.java b/server/src/main/java/org/opensearch/transport/InboundHandler.java index d0d765226cdea..cb0bec789a893 100644 --- a/server/src/main/java/org/opensearch/transport/InboundHandler.java +++ b/server/src/main/java/org/opensearch/transport/InboundHandler.java @@ -62,7 +62,9 @@ import org.opensearch.server.proto.NodesStatsProto.NodesStats; import org.opensearch.server.proto.NodesStatsRequestProto.NodesStatsRequest; import org.opensearch.server.proto.MessageProto.OutboundInboundMessage; +import org.opensearch.server.proto.ShardSearchRequestProto.ShardSearchRequest; import org.opensearch.core.transport.TransportResponse; +import org.opensearch.search.internal.ProtobufShardSearchRequest; import org.opensearch.threadpool.ThreadPool; import java.io.EOFException; @@ -328,6 +330,23 @@ private void handleRequestProtobuf( } catch (Exception e) { sendErrorResponse(action, transportChannel, e); } + } else if (receivedMessage.hasShardSearchRequest()) { + System.out.println("ShardSearchRequest received"); + System.out.println(receivedMessage.getShardSearchRequest()); + final ShardSearchRequest shardSearchReq = receivedMessage.getShardSearchRequest(); + ProtobufShardSearchRequest protobufShardSearchRequest = new ProtobufShardSearchRequest(shardSearchReq); + final T request = (T) protobufShardSearchRequest; + request.remoteAddress(new TransportAddress(channel.getRemoteAddress())); + final String executor = reg.getExecutor(); + if (ThreadPool.Names.SAME.equals(executor)) { + try { + reg.processMessageReceived(request, transportChannel); + } catch (Exception e) { + sendErrorResponse(reg.getAction(), transportChannel, e); + } + } else { + threadPool.executor(executor).execute(new ProtobufRequestHandler<>(reg, request, transportChannel)); + } } } catch (Exception e) { sendErrorResponse(action, transportChannel, e); diff --git a/server/src/main/java/org/opensearch/transport/OutboundHandler.java b/server/src/main/java/org/opensearch/transport/OutboundHandler.java index 056a81e78454b..04612e5ccd79c 100644 --- a/server/src/main/java/org/opensearch/transport/OutboundHandler.java +++ b/server/src/main/java/org/opensearch/transport/OutboundHandler.java @@ -189,7 +189,9 @@ void sendRequest( ); sendProtobufMessage(channel, protobufMessage, listener); } else if (canonicalName.contains("ProtobufShardSearch")){ + System.out.println("OutboundHandler sendRequest for shardsearchrequest"); ProtobufShardSearchRequest protobufShardSearchRequest = (ProtobufShardSearchRequest) request; + System.out.println("ProtobufShardSearchRequest: " + protobufShardSearchRequest); byte[] bytes = new byte[1]; bytes[0] = 0; ProtobufOutboundMessage protobufMessage = new ProtobufOutboundMessage( @@ -197,10 +199,11 @@ void sendRequest( bytes, Version.CURRENT, threadPool.getThreadContext(), - protobufClusterStateRequest.request(), + protobufShardSearchRequest.request(), features, action ); + System.out.println("ProtobufOutboundMessage: " + protobufMessage); sendProtobufMessage(channel, protobufMessage, listener); } else { sendMessage(channel, message, listener); diff --git a/server/src/main/java/org/opensearch/transport/ProtobufOutboundMessage.java b/server/src/main/java/org/opensearch/transport/ProtobufOutboundMessage.java index b034035997627..e4872c7d21a97 100644 --- a/server/src/main/java/org/opensearch/transport/ProtobufOutboundMessage.java +++ b/server/src/main/java/org/opensearch/transport/ProtobufOutboundMessage.java @@ -25,6 +25,7 @@ import org.opensearch.server.proto.NodesInfoRequestProto.NodesInfoRequest; import org.opensearch.server.proto.NodesStatsProto.NodesStats; import org.opensearch.server.proto.NodesStatsRequestProto.NodesStatsRequest; +import org.opensearch.server.proto.ShardSearchRequestProto.ShardSearchRequest; import org.opensearch.server.proto.MessageProto.OutboundInboundMessage; import org.opensearch.server.proto.MessageProto.OutboundInboundMessage.Header; import org.opensearch.server.proto.MessageProto.OutboundInboundMessage.ResponseHandlersList; @@ -273,6 +274,45 @@ public ProtobufOutboundMessage( .setIsProtobuf(true) .build(); + } + + public ProtobufOutboundMessage( + long requestId, + byte[] status, + Version version, + ThreadContext threadContext, + ShardSearchRequest shardSearchReq, + String[] features, + String action + ) { + Header header = Header.newBuilder() + .addAllPrefix(Arrays.asList(ByteString.copyFrom(PREFIX))) + .setRequestId(requestId) + .setStatus(ByteString.copyFrom(status)) + .setVersionId(version.id) + .build(); + Map requestHeaders = threadContext.getHeaders(); + Map> responseHeaders = threadContext.getResponseHeaders(); + Map responseHandlers = new HashMap<>(); + for (Map.Entry> entry : responseHeaders.entrySet()) { + String key = entry.getKey(); + List value = entry.getValue(); + ResponseHandlersList responseHandlersList = ResponseHandlersList.newBuilder().addAllSetOfResponseHandlers(value).build(); + responseHandlers.put(key, responseHandlersList); + } + this.message = OutboundInboundMessage.newBuilder() + .setHeader(header) + .putAllRequestHeaders(requestHeaders) + .putAllResponseHandlers(responseHandlers) + .setVersion(version.toString()) + .setStatus(ByteString.copyFrom(status)) + .setRequestId(requestId) + .setShardSearchRequest(shardSearchReq) + .setAction(action) + .addAllFeatures(Arrays.asList(features)) + .setIsProtobuf(true) + .build(); + } public ProtobufOutboundMessage(byte[] data) throws InvalidProtocolBufferException { diff --git a/server/src/main/proto/server/MessageProto.proto b/server/src/main/proto/server/MessageProto.proto index e7fb8b821b737..033751f5ef72a 100644 --- a/server/src/main/proto/server/MessageProto.proto +++ b/server/src/main/proto/server/MessageProto.proto @@ -18,6 +18,7 @@ import "server/NodesInfoRequestProto.proto"; import "server/NodesInfoProto.proto"; import "server/NodesStatsRequestProto.proto"; import "server/NodesStatsProto.proto"; +import "server/ShardSearchRequestProto.proto"; option java_outer_classname = "MessageProto"; @@ -37,8 +38,9 @@ message OutboundInboundMessage { NodesInfo nodesInfoResponse = 12; NodesStatsRequest nodesStatsRequest = 13; NodesStats nodesStatsResponse = 14; + ShardSearchRequest shardSearchRequest = 15; } - bool isProtobuf = 15; + bool isProtobuf = 16; message Header { repeated bytes prefix = 1; diff --git a/server/src/main/proto/server/ShardSearchRequestProto.proto b/server/src/main/proto/server/ShardSearchRequestProto.proto index ce0bf6a9fb809..d7fbff6855c4b 100644 --- a/server/src/main/proto/server/ShardSearchRequestProto.proto +++ b/server/src/main/proto/server/ShardSearchRequestProto.proto @@ -29,23 +29,26 @@ message ShardSearchRequest { bytes scroll = 12; int64 nowInMillis = 13; optional string clusterAlias = 14; - bytes readerId = 15; - string timeValue = 16; + optional ShardSearchContextId readerId = 15; + optional string timeValue = 16; + int64 inboundNetworkTime = 17; + int64 outboundNetworkTime = 18; + bool canReturnNullResponseIfMatchNoDocs = 19; message OriginalIndices { repeated string indices = 1; IndicesOptions indicesOptions = 2; - enum IndicesOptions { - IGNORE_UNAVAILABLE = 0; - ALLOW_NO_INDICES = 1; - EXPAND_WILDCARDS_OPEN = 2; - EXPAND_WILDCARDS_CLOSED = 3; - EXPAND_WILDCARDS_HIDDEN = 4; - ALLOW_ALIASES_TO_MULTIPLE_INDICES = 5; - FORBID_CLOSED_INDICES = 6; - IGNORE_ALIASES = 7; - IGNORE_THROTTLED = 8; + message IndicesOptions { + bool ignoreUnavailable = 1; + bool allowNoIndices = 2; + bool expandWildcardsOpen = 3; + bool expandWildcardsClosed = 4; + bool expandWildcardsHidden = 5; + bool allowAliasesToMultipleIndices = 6; + bool forbidClosedIndices = 7; + bool ignoreAliases = 8; + bool ignoreThrottled = 9; } } @@ -60,4 +63,9 @@ message ShardSearchRequest { QUERY_THEN_FETCH = 0; DFS_QUERY_THEN_FETCH = 1; } + + message ShardSearchContextId { + string sessionId = 1; + int64 id = 2; + } } diff --git a/test/framework/src/main/java/org/opensearch/test/TestSearchContext.java b/test/framework/src/main/java/org/opensearch/test/TestSearchContext.java index 0ce63fbe2977e..7f4d4b4654f85 100644 --- a/test/framework/src/main/java/org/opensearch/test/TestSearchContext.java +++ b/test/framework/src/main/java/org/opensearch/test/TestSearchContext.java @@ -36,6 +36,7 @@ import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.Query; import org.opensearch.action.OriginalIndices; +import org.opensearch.action.search.ProtobufSearchShardTask; import org.opensearch.action.search.SearchShardTask; import org.opensearch.action.search.SearchType; import org.opensearch.common.unit.TimeValue; @@ -66,6 +67,7 @@ import org.opensearch.search.fetch.subphase.ScriptFieldsContext; import org.opensearch.search.fetch.subphase.highlight.SearchHighlightContext; import org.opensearch.search.internal.ContextIndexSearcher; +import org.opensearch.search.internal.ProtobufShardSearchRequest; import org.opensearch.search.internal.ReaderContext; import org.opensearch.search.internal.ScrollContext; import org.opensearch.search.internal.SearchContext; @@ -103,6 +105,7 @@ public class TestSearchContext extends SearchContext { Query query; Float minScore; SearchShardTask task; + ProtobufSearchShardTask protobufTask; SortAndFormats sort; boolean trackScores = false; int trackTotalHitsUpTo = SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO; @@ -191,6 +194,11 @@ public ShardSearchRequest request() { return null; } + @Override + public ProtobufShardSearchRequest protobufShardSearchRequest() { + return null; + } + @Override public SearchType searchType() { return null; @@ -644,6 +652,11 @@ public SearchShardTask getTask() { return task; } + @Override + public ProtobufSearchShardTask getProtobufTask() { + return protobufTask; + } + @Override public boolean isCancelled() { return task.isCancelled();