Skip to content

Commit

Permalink
Working protobuf search requests and node to node communication for t…
Browse files Browse the repository at this point in the history
…hose requests

Signed-off-by: Vacha Shah <[email protected]>
  • Loading branch information
VachaShah committed Nov 14, 2023
1 parent 86090b0 commit 3153068
Show file tree
Hide file tree
Showing 33 changed files with 513 additions and 273 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

package org.opensearch.common.unit;

import java.io.Serializable;
import java.util.Locale;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
Expand All @@ -41,7 +42,7 @@
*
* @opensearch.api
*/
public class TimeValue implements Comparable<TimeValue> {
public class TimeValue implements Comparable<TimeValue>, Serializable {

/** How many nano-seconds in one milli-second */
public static final long NSEC_PER_MSEC = TimeUnit.NANOSECONDS.convert(1, TimeUnit.MILLISECONDS);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,8 @@ private ProtobufShardSearchRequest rewriteShardSearchRequest(ProtobufShardSearch

// set the current best bottom field doc
if (bottomSortCollector.getBottomSortValues() != null) {
request.setBottomSortValues(bottomSortCollector.getBottomSortValues());
// request.setBottomSortValues(bottomSortCollector.getBottomSortValues());
System.out.println("Bottom sort values is not null......now what????");
}
return request;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,14 @@
import org.opensearch.core.xcontent.ToXContentObject;

import java.io.IOException;
import java.io.Serializable;

/**
* Foundation class for all OpenSearch query builders
*
* @opensearch.internal
*/
public interface QueryBuilder extends NamedWriteable, ToXContentObject, Rewriteable<QueryBuilder> {
public interface QueryBuilder extends NamedWriteable, ToXContentObject, Rewriteable<QueryBuilder>, Serializable {

/**
* Converts this QueryBuilder to a lucene {@link Query}.
Expand Down
47 changes: 25 additions & 22 deletions server/src/main/java/org/opensearch/index/query/Rewriteable.java
Original file line number Diff line number Diff line change
Expand Up @@ -112,31 +112,34 @@ static <T extends Rewriteable<T>> void rewriteAndFetch(
ActionListener<T> rewriteResponse,
int iteration
) {
System.out.println("In rewriteAndFetch");
System.out.println("Original: " + original.getClass());
System.out.println("Context: " + context.getClass());
T builder = original;
try {
for (T rewrittenBuilder = builder.rewrite(context); rewrittenBuilder != builder; rewrittenBuilder = builder.rewrite(context)) {
builder = rewrittenBuilder;
if (iteration++ >= MAX_REWRITE_ROUNDS) {
// this is some protection against user provided queries if they don't obey the contract of rewrite we allow 16 rounds
// and then we fail to prevent infinite loops
throw new IllegalStateException(
"too many rewrite rounds, rewriteable might return new objects even if they are not " + "rewritten"
);
}
if (context.hasAsyncActions()) {
T finalBuilder = builder;
final int currentIterationNumber = iteration;
context.executeAsyncActions(
ActionListener.wrap(
n -> rewriteAndFetch(finalBuilder, context, rewriteResponse, currentIterationNumber),
rewriteResponse::onFailure
)
);
return;
}
}
// for (T rewrittenBuilder = builder.rewrite(context); rewrittenBuilder != builder; rewrittenBuilder = builder.rewrite(context)) {
// builder = rewrittenBuilder;
// if (iteration++ >= MAX_REWRITE_ROUNDS) {
// // this is some protection against user provided queries if they don't obey the contract of rewrite we allow 16 rounds
// // and then we fail to prevent infinite loops
// throw new IllegalStateException(
// "too many rewrite rounds, rewriteable might return new objects even if they are not " + "rewritten"
// );
// }
// if (context.hasAsyncActions()) {
// T finalBuilder = builder;
// final int currentIterationNumber = iteration;
// context.executeAsyncActions(
// ActionListener.wrap(
// n -> rewriteAndFetch(finalBuilder, context, rewriteResponse, currentIterationNumber),
// rewriteResponse::onFailure
// )
// );
// return;
// }
// }
rewriteResponse.onResponse(builder);
} catch (IOException | IllegalArgumentException | ParsingException ex) {
} catch (IllegalArgumentException | ParsingException ex) {
rewriteResponse.onFailure(ex);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,12 @@ public void preProcess(boolean rewrite) {
}
// initialize the filtering alias based on the provided filters
try {
final QueryBuilder queryBuilder = request.getAliasFilter().getQueryBuilder();
final QueryBuilder queryBuilder;
if (request == null) {
queryBuilder = protobufShardSearchRequest.getAliasFilter().getQueryBuilder();
} else {
queryBuilder = request.getAliasFilter().getQueryBuilder();
}
aliasFilter = queryBuilder == null ? null : queryBuilder.toQuery(queryShardContext);
} catch (IOException e) {
throw new UncheckedIOException(e);
Expand Down Expand Up @@ -465,6 +470,11 @@ public ShardSearchRequest request() {
return this.request;
}

@Override
public ProtobufShardSearchRequest protobufShardSearchRequest() {
return this.protobufShardSearchRequest;
}

@Override
public SearchType searchType() {
return this.searchType;
Expand Down Expand Up @@ -969,8 +979,16 @@ public SearchShardTask getTask() {
return task;
}

@Override
public ProtobufSearchShardTask getProtobufTask() {
return protobufSearchShardTask;
}

@Override
public boolean isCancelled() {
if (task == null) {
return protobufSearchShardTask.isCancelled();
}
return task.isCancelled();
}

Expand Down
3 changes: 2 additions & 1 deletion server/src/main/java/org/opensearch/search/Scroll.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.opensearch.common.unit.TimeValue;

import java.io.IOException;
import java.io.Serializable;
import java.util.Objects;

/**
Expand All @@ -46,7 +47,7 @@
*
* @opensearch.internal
*/
public final class Scroll implements Writeable {
public final class Scroll implements Writeable, Serializable {

private final TimeValue keepAlive;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@

package org.opensearch.search;

import java.io.Serializable;

import org.opensearch.common.CheckedFunction;
import org.opensearch.core.common.io.stream.NamedWriteable;
import org.opensearch.core.common.io.stream.StreamInput;
Expand All @@ -58,7 +60,7 @@
*
* @opensearch.internal
*/
public abstract class SearchExtBuilder implements NamedWriteable, ToXContentFragment {
public abstract class SearchExtBuilder implements NamedWriteable, ToXContentFragment, Serializable {

public abstract int hashCode();

Expand Down
14 changes: 12 additions & 2 deletions server/src/main/java/org/opensearch/search/SearchService.java
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,7 @@ public void onFailure(Exception exc) {
}

private IndexShard getShard(ShardSearchRequest request) {
System.out.println("getShard");
if (request.readerId() != null) {
return findReaderContext(request.readerId(), request).indexShard();
} else {
Expand All @@ -632,6 +633,7 @@ private IndexShard getShard(ShardSearchRequest request) {
}

private IndexShard getShardProtobuf(ProtobufShardSearchRequest request) {
System.out.println("getShardProtobuf");
if (request.readerId() != null) {
return findReaderContext(request.readerId(), request).indexShard();
} else {
Expand Down Expand Up @@ -939,13 +941,17 @@ public void executeFetchPhaseProtobuf(ProtobufShardFetchRequest request, Protobu
}

private ReaderContext getReaderContext(ShardSearchContextId id) {
System.out.println("getReaderContext");
System.out.println(id.getSessionId());
if (sessionId.equals(id.getSessionId()) == false && id.getSessionId().isEmpty() == false) {
throw new SearchContextMissingException(id);
}
return activeReaders.get(id.getId());
}

private ReaderContext findReaderContext(ShardSearchContextId id, TransportRequest request) throws SearchContextMissingException {
System.out.println("findReaderContext");
System.out.println(id);
final ReaderContext reader = getReaderContext(id);
if (reader == null) {
throw new SearchContextMissingException(id);
Expand Down Expand Up @@ -1357,7 +1363,7 @@ private DefaultSearchContext createSearchContextProtobuf(ReaderContext reader, P
// might end up with incorrect state since we are using now() or script services
// during rewrite and normalized / evaluate templates etc.
QueryShardContext context = new QueryShardContext(searchContext.getQueryShardContext());
Rewriteable.rewrite(request.getRewriteable(), context, true);
// Rewriteable.rewrite(request.getRewriteable(), context, true);
assert searchContext.getQueryShardContext().isCacheable();
success = true;
} finally {
Expand Down Expand Up @@ -1746,7 +1752,7 @@ private void shortcutDocIdsToLoad(SearchContext context) {
} else {
completionSuggestions = Collections.emptyList();
}
if (context.request().scroll() != null) {
if ((context.request() != null && context.request().scroll() != null) || (context.protobufShardSearchRequest() != null && context.protobufShardSearchRequest().scroll() != null)) {
TopDocs topDocs = context.queryResult().topDocs().topDocs;
docIdsToLoad = new int[topDocs.scoreDocs.length + numSuggestDocs];
for (int i = 0; i < topDocs.scoreDocs.length; i++) {
Expand Down Expand Up @@ -1991,6 +1997,8 @@ public static boolean canRewriteToMatchNone(SearchSourceBuilder source) {
}

private void rewriteAndFetchShardRequest(IndexShard shard, ShardSearchRequest request, ActionListener<ShardSearchRequest> listener) {
System.out.println("SearchService rewriteAndFetchShardRequest");
System.out.println("ShardSearchRequest: " + request);
ActionListener<Rewriteable> actionListener = ActionListener.wrap(r -> {
if (request.readerId() != null) {
listener.onResponse(request);
Expand All @@ -2006,6 +2014,8 @@ private void rewriteAndFetchShardRequest(IndexShard shard, ShardSearchRequest re
}

private void rewriteAndFetchShardRequestProtobuf(IndexShard shard, ProtobufShardSearchRequest request, ActionListener<ProtobufShardSearchRequest> listener) {
System.out.println("SearchService rewriteAndFetchShardRequestProtobuf");
System.out.println("ProtobufShardSearchRequest: " + request);
ActionListener<Rewriteable> actionListener = ActionListener.wrap(r -> {
if (request.readerId() != null) {
listener.onResponse(request);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import org.opensearch.search.profile.aggregation.ProfilingAggregator;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
Expand Down Expand Up @@ -333,7 +334,7 @@ public int countAggregators() {
*
* @opensearch.internal
*/
public static class Builder implements Writeable, ToXContentObject {
public static class Builder implements Writeable, ToXContentObject, Serializable {
private final Set<String> names = new HashSet<>();

// Using LinkedHashSets to preserve the order of insertion, that makes the results
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.opensearch.core.xcontent.XContentParser;

import java.io.IOException;
import java.io.Serializable;
import java.util.Objects;

/**
Expand All @@ -52,7 +53,7 @@
*
* @opensearch.internal
*/
public final class PointInTimeBuilder implements Writeable, ToXContentObject {
public final class PointInTimeBuilder implements Writeable, ToXContentObject, Serializable {
private static final ParseField ID_FIELD = new ParseField("id");
private static final ParseField KEEP_ALIVE_FIELD = new ParseField("keep_alive");
private static final ObjectParser<XContentParams, Void> PARSER;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
import org.opensearch.search.suggest.SuggestBuilder;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
Expand All @@ -93,7 +94,7 @@
*
* @opensearch.internal
*/
public final class SearchSourceBuilder implements Writeable, ToXContentObject, Rewriteable<SearchSourceBuilder> {
public final class SearchSourceBuilder implements Writeable, ToXContentObject, Rewriteable<SearchSourceBuilder>, Serializable {
private static final DeprecationLogger deprecationLogger = DeprecationLogger.getLogger(SearchSourceBuilder.class);

public static final ParseField FROM_FIELD = new ParseField("from");
Expand Down Expand Up @@ -1495,7 +1496,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
*
* @opensearch.internal
*/
public static class IndexBoost implements Writeable, ToXContentObject {
public static class IndexBoost implements Writeable, ToXContentObject, Serializable {
private final String index;
private final float boost;

Expand Down Expand Up @@ -1596,7 +1597,7 @@ public boolean equals(Object obj) {
*
* @opensearch.internal
*/
public static class ScriptField implements Writeable, ToXContentFragment {
public static class ScriptField implements Writeable, ToXContentFragment, Serializable {

private final boolean ignoreFailure;
private final String fieldName;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.opensearch.index.query.QueryShardContext;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
Expand All @@ -59,7 +60,7 @@
*
* @opensearch.internal
*/
public class CollapseBuilder implements Writeable, ToXContentObject {
public class CollapseBuilder implements Writeable, ToXContentObject, Serializable {
public static final ParseField FIELD_FIELD = new ParseField("field");
public static final ParseField INNER_HITS_FIELD = new ParseField("inner_hits");
public static final ParseField MAX_CONCURRENT_GROUP_REQUESTS_FIELD = new ParseField("max_concurrent_group_searches");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.opensearch.rest.RestRequest;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
Expand All @@ -53,7 +54,7 @@
*
* @opensearch.internal
*/
public class StoredFieldsContext implements Writeable {
public class StoredFieldsContext implements Writeable, Serializable {
public static final String _NONE_ = "_none_";

private final List<String> fieldNames;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.opensearch.rest.RestRequest;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
Expand All @@ -57,7 +58,7 @@
*
* @opensearch.internal
*/
public class FetchSourceContext implements Writeable, ToXContentObject {
public class FetchSourceContext implements Writeable, ToXContentObject, Serializable {

public static final ParseField INCLUDES_FIELD = new ParseField("includes", "include");
public static final ParseField EXCLUDES_FIELD = new ParseField("excludes", "exclude");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.opensearch.core.xcontent.XContentParser;

import java.io.IOException;
import java.io.Serializable;
import java.util.Objects;

/**
Expand All @@ -52,7 +53,7 @@
*
* @opensearch.internal
*/
public final class FieldAndFormat implements Writeable, ToXContentObject {
public final class FieldAndFormat implements Writeable, ToXContentObject, Serializable {
private static final ParseField FIELD_FIELD = new ParseField("field");
private static final ParseField FORMAT_FIELD = new ParseField("format");

Expand Down
Loading

0 comments on commit 3153068

Please sign in to comment.