Skip to content

Commit

Permalink
RankEvalRequest should implement IndicesRequest (#29188)
Browse files Browse the repository at this point in the history
Change RankEvalRequest to implement IndicesRequest, so it gets treated
in a similar fashion to regular search requests e.g. by security.
  • Loading branch information
Christoph Büscher authored Mar 22, 2018
1 parent d6d3fb3 commit e4b3007
Show file tree
Hide file tree
Showing 8 changed files with 162 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ static Request existsAlias(GetAliasesRequest getAliasesRequest) {
}

static Request rankEval(RankEvalRequest rankEvalRequest) throws IOException {
String endpoint = endpoint(rankEvalRequest.getIndices(), Strings.EMPTY_ARRAY, "_rank_eval");
String endpoint = endpoint(rankEvalRequest.indices(), Strings.EMPTY_ARRAY, "_rank_eval");
HttpEntity entity = createEntity(rankEvalRequest.getRankEvalSpec(), REQUEST_BODY_CONTENT_TYPE);
return new Request(HttpGet.METHOD_NAME, endpoint, Collections.emptyMap(), entity);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,47 @@
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.IndicesRequest;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;

import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;

/**
* Request to perform a search ranking evaluation.
*/
public class RankEvalRequest extends ActionRequest {
public class RankEvalRequest extends ActionRequest implements IndicesRequest.Replaceable {

private RankEvalSpec rankingEvaluationSpec;

private IndicesOptions indicesOptions = SearchRequest.DEFAULT_INDICES_OPTIONS;
private String[] indices = Strings.EMPTY_ARRAY;

public RankEvalRequest(RankEvalSpec rankingEvaluationSpec, String[] indices) {
this.rankingEvaluationSpec = rankingEvaluationSpec;
setIndices(indices);
this.rankingEvaluationSpec = Objects.requireNonNull(rankingEvaluationSpec, "ranking evaluation specification must not be null");
indices(indices);
}

RankEvalRequest(StreamInput in) throws IOException {
super.readFrom(in);
rankingEvaluationSpec = new RankEvalSpec(in);
if (in.getVersion().onOrAfter(Version.V_6_3_0)) {
indices = in.readStringArray();
indicesOptions = IndicesOptions.readIndicesOptions(in);
} else {
// readStringArray uses readVInt for size, we used readInt in 6.2
int indicesSize = in.readInt();
String[] indices = new String[indicesSize];
for (int i = 0; i < indicesSize; i++) {
indices[i] = in.readString();
}
// no indices options yet
}
}

RankEvalRequest() {
Expand Down Expand Up @@ -72,7 +95,8 @@ public void setRankEvalSpec(RankEvalSpec task) {
/**
* Sets the indices the search will be executed on.
*/
public RankEvalRequest setIndices(String... indices) {
@Override
public RankEvalRequest indices(String... indices) {
Objects.requireNonNull(indices, "indices must not be null");
for (String index : indices) {
Objects.requireNonNull(index, "index must not be null");
Expand All @@ -84,24 +108,23 @@ public RankEvalRequest setIndices(String... indices) {
/**
* @return the indices for this request
*/
public String[] getIndices() {
@Override
public String[] indices() {
return indices;
}

@Override
public IndicesOptions indicesOptions() {
return indicesOptions;
}

public void indicesOptions(IndicesOptions indicesOptions) {
this.indicesOptions = Objects.requireNonNull(indicesOptions, "indicesOptions must not be null");
}

@Override
public void readFrom(StreamInput in) throws IOException {
super.readFrom(in);
rankingEvaluationSpec = new RankEvalSpec(in);
if (in.getVersion().onOrAfter(Version.V_6_3_0)) {
indices = in.readStringArray();
} else {
// readStringArray uses readVInt for size, we used readInt in 6.2
int indicesSize = in.readInt();
String[] indices = new String[indicesSize];
for (int i = 0; i < indicesSize; i++) {
indices[i] = in.readString();
}
}
throw new UnsupportedOperationException("usage of Streamable is to be replaced by Writeable");
}

@Override
Expand All @@ -110,12 +133,33 @@ public void writeTo(StreamOutput out) throws IOException {
rankingEvaluationSpec.writeTo(out);
if (out.getVersion().onOrAfter(Version.V_6_3_0)) {
out.writeStringArray(indices);
indicesOptions.writeIndicesOptions(out);
} else {
// writeStringArray uses writeVInt for size, we used writeInt in 6.2
out.writeInt(indices.length);
for (String index : indices) {
out.writeString(index);
}
// no indices options yet
}
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
RankEvalRequest that = (RankEvalRequest) o;
return Objects.equals(indicesOptions, that.indicesOptions) &&
Arrays.equals(indices, that.indices) &&
Objects.equals(rankingEvaluationSpec, that.rankingEvaluationSpec);
}

@Override
public int hashCode() {
return Objects.hash(indicesOptions, Arrays.hashCode(indices), rankingEvaluationSpec);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
}

private static void parseRankEvalRequest(RankEvalRequest rankEvalRequest, RestRequest request, XContentParser parser) {
rankEvalRequest.setIndices(Strings.splitStringByCommaToArray(request.param("index")));
rankEvalRequest.indices(Strings.splitStringByCommaToArray(request.param("index")));
RankEvalSpec spec = RankEvalSpec.parse(parser);
rankEvalRequest.setRankEvalSpec(spec);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ public class TransportRankEvalAction extends HandledTransportAction<RankEvalRequ
public TransportRankEvalAction(Settings settings, ThreadPool threadPool, ActionFilters actionFilters,
IndexNameExpressionResolver indexNameExpressionResolver, Client client, TransportService transportService,
ScriptService scriptService, NamedXContentRegistry namedXContentRegistry) {
super(settings, RankEvalAction.NAME, threadPool, transportService, actionFilters, indexNameExpressionResolver,
RankEvalRequest::new);
super(settings, RankEvalAction.NAME, threadPool, transportService, actionFilters, RankEvalRequest::new,
indexNameExpressionResolver);
this.scriptService = scriptService;
this.namedXContentRegistry = namedXContentRegistry;
this.client = client;
Expand Down Expand Up @@ -126,7 +126,7 @@ LoggingDeprecationHandler.INSTANCE, new BytesArray(resolvedRequest), XContentTyp
} else {
ratedSearchSource.fetchSource(summaryFields.toArray(new String[summaryFields.size()]), new String[0]);
}
msearchRequest.add(new SearchRequest(request.getIndices(), ratedSearchSource));
msearchRequest.add(new SearchRequest(request.indices(), ratedSearchSource));
}
assert ratedRequestsInSearch.size() == msearchRequest.requests().size();
client.multiSearch(msearchRequest, new RankEvalActionListener(listener, metric,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ public void testPrecisionAtRequest() {
RankEvalAction.INSTANCE, new RankEvalRequest());
builder.setRankEvalSpec(task);

RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request().setIndices("test"))
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request().indices("test"))
.actionGet();
// the expected Prec@ for the first query is 4/6 and the expected Prec@ for the
// second is 1/6, divided by 2 to get the average
Expand Down Expand Up @@ -131,8 +131,7 @@ public void testPrecisionAtRequest() {
metric = new PrecisionAtK(1, false, 3);
task = new RankEvalSpec(specifications, metric);

builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest().setIndices("test"));
builder.setRankEvalSpec(task);
builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest(task, new String[] { "test" }));

response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
// if we look only at top 3 documente, the expected P@3 for the first query is
Expand Down Expand Up @@ -164,8 +163,7 @@ public void testDCGRequest() {
RankEvalSpec task = new RankEvalSpec(specifications, metric);

RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE,
new RankEvalRequest().setIndices("test"));
builder.setRankEvalSpec(task);
new RankEvalRequest(task, new String[] { "test" }));

RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
assertEquals(DiscountedCumulativeGainTests.EXPECTED_DCG, response.getEvaluationResult(), 10E-14);
Expand All @@ -174,8 +172,7 @@ public void testDCGRequest() {
metric = new DiscountedCumulativeGain(false, null, 3);
task = new RankEvalSpec(specifications, metric);

builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest().setIndices("test"));
builder.setRankEvalSpec(task);
builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest(task, new String[] { "test" }));

response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
assertEquals(12.39278926071437, response.getEvaluationResult(), 10E-14);
Expand All @@ -194,8 +191,7 @@ public void testMRRRequest() {
RankEvalSpec task = new RankEvalSpec(specifications, metric);

RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE,
new RankEvalRequest().setIndices("test"));
builder.setRankEvalSpec(task);
new RankEvalRequest(task, new String[] { "test" }));

RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
// the expected reciprocal rank for the amsterdam_query is 1/5
Expand All @@ -208,8 +204,7 @@ public void testMRRRequest() {
metric = new MeanReciprocalRank(1, 3);
task = new RankEvalSpec(specifications, metric);

builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest().setIndices("test"));
builder.setRankEvalSpec(task);
builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest(task, new String[] { "test" }));

response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
// limiting to top 3 results, the amsterdam_query has no relevant document in it
Expand Down Expand Up @@ -240,7 +235,7 @@ public void testBadQuery() {
RankEvalSpec task = new RankEvalSpec(specifications, new PrecisionAtK());

RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE,
new RankEvalRequest().setIndices("test"));
new RankEvalRequest(task, new String[] { "test" }));
builder.setRankEvalSpec(task);

RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request()).actionGet();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.elasticsearch.index.rankeval;

import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable.Reader;
import org.elasticsearch.common.util.ArrayUtils;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.junit.AfterClass;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class RankEvalRequestTests extends AbstractWireSerializingTestCase<RankEvalRequest> {

private static RankEvalPlugin rankEvalPlugin = new RankEvalPlugin();

@AfterClass
public static void releasePluginResources() throws IOException {
rankEvalPlugin.close();
}

@Override
protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(rankEvalPlugin.getNamedXContent());
}

@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
return new NamedWriteableRegistry(rankEvalPlugin.getNamedWriteables());
}

@Override
protected RankEvalRequest createTestInstance() {
int numberOfIndices = randomInt(3);
String[] indices = new String[numberOfIndices];
for (int i=0; i < numberOfIndices; i++) {
indices[i] = randomAlphaOfLengthBetween(5, 10);
}
RankEvalRequest rankEvalRequest = new RankEvalRequest(RankEvalSpecTests.createTestItem(), indices);
IndicesOptions indicesOptions = IndicesOptions.fromOptions(
randomBoolean(), randomBoolean(), randomBoolean(), randomBoolean(), randomBoolean(), randomBoolean(), randomBoolean());
rankEvalRequest.indicesOptions(indicesOptions);
return rankEvalRequest;
}

@Override
protected Reader<RankEvalRequest> instanceReader() {
return RankEvalRequest::new;
}

@Override
protected RankEvalRequest mutateInstance(RankEvalRequest instance) throws IOException {
RankEvalRequest mutation = copyInstance(instance);
List<Runnable> mutators = new ArrayList<>();
mutators.add(() -> mutation.indices(ArrayUtils.concat(instance.indices(), new String[] { randomAlphaOfLength(10) })));
mutators.add(() -> mutation.indicesOptions(randomValueOtherThan(instance.indicesOptions(),
() -> IndicesOptions.fromOptions(randomBoolean(), randomBoolean(), randomBoolean(), randomBoolean()))));
mutators.add(() -> mutation.setRankEvalSpec(RankEvalSpecTests.mutateTestItem(instance.getRankEvalSpec())));
randomFrom(mutators).run();
return mutation;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ private static <T> List<T> randomList(Supplier<T> randomSupplier) {
return result;
}

private static RankEvalSpec createTestItem() throws IOException {
static RankEvalSpec createTestItem() {
Supplier<EvaluationMetric> metric = randomFrom(Arrays.asList(
() -> PrecisionAtKTests.createTestItem(),
() -> MeanReciprocalRankTests.createTestItem(),
Expand All @@ -87,6 +87,9 @@ private static RankEvalSpec createTestItem() throws IOException {
builder.field("field", randomAlphaOfLengthBetween(1, 5));
builder.endObject();
script = Strings.toString(builder);
} catch (IOException e) {
// this shouldn't happen in tests, re-throw just not to swallow it
throw new RuntimeException(e);
}

templates = new HashSet<>();
Expand Down Expand Up @@ -156,7 +159,7 @@ public void testEqualsAndHash() throws IOException {
checkEqualsAndHashCode(createTestItem(), RankEvalSpecTests::copy, RankEvalSpecTests::mutateTestItem);
}

private static RankEvalSpec mutateTestItem(RankEvalSpec original) {
static RankEvalSpec mutateTestItem(RankEvalSpec original) {
List<RatedRequest> ratedRequests = new ArrayList<>(original.getRatedRequests());
EvaluationMetric metric = original.getMetric();
Map<String, Script> templates = new HashMap<>(original.getTemplates());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ public void testPrecisionAtRequest() throws IOException {
RankEvalRequestBuilder builder = new RankEvalRequestBuilder(client(), RankEvalAction.INSTANCE, new RankEvalRequest());
builder.setRankEvalSpec(task);

RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request().setIndices("test")).actionGet();
RankEvalResponse response = client().execute(RankEvalAction.INSTANCE, builder.request().indices("test")).actionGet();
assertEquals(0.9, response.getEvaluationResult(), Double.MIN_VALUE);
}

Expand Down

0 comments on commit e4b3007

Please sign in to comment.