Skip to content

Commit

Permalink
EQL: Make EqlSearchResponse immutable
Browse files Browse the repository at this point in the history
Refactors EqlSearchResponse to make it immutable

Relates to elastic#49581
  • Loading branch information
imotov committed Jan 9, 2020
1 parent 3777324 commit 8df5e45
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 118 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentFragment;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
Expand All @@ -25,11 +25,10 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;

import static org.elasticsearch.common.xcontent.ObjectParser.fromList;
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;


Expand Down Expand Up @@ -79,9 +78,9 @@
*/
public class EqlSearchResponse extends ActionResponse implements ToXContentObject {

private Hits hits;
private long tookInMillis;
private boolean isTimeout;
private final Hits hits;
private final long tookInMillis;
private final boolean isTimeout;

private static final class Fields {
static final String TOOK = "took";
Expand All @@ -93,25 +92,25 @@ private static final class Fields {
private static final ParseField TIMED_OUT = new ParseField(Fields.TIMED_OUT);
private static final ParseField HITS = new ParseField(Fields.HITS);

private static final ObjectParser<EqlSearchResponse, Void> PARSER = objectParser(EqlSearchResponse::new);

private static <R extends EqlSearchResponse> ObjectParser<R, Void> objectParser(Supplier<R> supplier) {
ObjectParser<R, Void> parser = new ObjectParser<>("eql/search_response", false, supplier);
parser.declareLong(EqlSearchResponse::took, TOOK);
parser.declareBoolean(EqlSearchResponse::isTimeout, TIMED_OUT);
parser.declareObject(EqlSearchResponse::hits,
(p, c) -> Hits.fromXContent(p), HITS);
return parser;
}

// Constructor for parser from json
protected EqlSearchResponse() {
super();
private static final ConstructingObjectParser<EqlSearchResponse, Void> PARSER =
new ConstructingObjectParser<>("eql/search_response", true,
args -> {
int i = 0;
Hits hits = (Hits) args[i++];
Long took = (Long) args[i++];
Boolean timeout = (Boolean) args[i];
return new EqlSearchResponse(hits, took, timeout);
});

static {
PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> Hits.fromXContent(p), HITS);
PARSER.declareLong(ConstructingObjectParser.constructorArg(), TOOK);
PARSER.declareBoolean(ConstructingObjectParser.constructorArg(), TIMED_OUT);
}

public EqlSearchResponse(Hits hits, long tookInMillis, boolean isTimeout) {
super();
this.hits(hits);
this.hits = hits == null ? Hits.EMPTY : hits;
this.tookInMillis = tookInMillis;
this.isTimeout = isTimeout;
}
Expand Down Expand Up @@ -152,30 +151,14 @@ public long took() {
return tookInMillis;
}

public void took(long tookInMillis) {
this.tookInMillis = tookInMillis;
}

public boolean isTimeout() {
return isTimeout;
}

public void isTimeout(boolean isTimeout) {
this.isTimeout = isTimeout;
}

public Hits hits() {
return hits;
}

public void hits(Hits hits) {
if (hits == null) {
this.hits = new Hits((Events)null, null);
} else {
this.hits = hits;
}
}

@Override
public boolean equals(Object o) {
if (this == o) {
Expand Down Expand Up @@ -210,60 +193,40 @@ private static final class Fields {
private static final ParseField JOIN_KEYS = new ParseField(Fields.JOIN_KEYS);
private static final ParseField EVENTS = new ParseField(Events.NAME);

private static final ObjectParser<EqlSearchResponse.Sequence, Void> PARSER = objectParser(EqlSearchResponse.Sequence::new);
private static final ConstructingObjectParser<EqlSearchResponse.Sequence, Void> PARSER =
new ConstructingObjectParser<>("eql/search_response_sequence", true,
args -> {
int i = 0;
@SuppressWarnings("unchecked") List<String> joinKeys = (List<String>) args[i++];
@SuppressWarnings("unchecked") Events events = new Events(((List<SearchHit>) args[i]).toArray(new SearchHit[0]));
return new EqlSearchResponse.Sequence(joinKeys, events);
});

private static <R extends EqlSearchResponse.Sequence> ObjectParser<R, Void> objectParser(Supplier<R> supplier) {
ObjectParser<R, Void> parser = new ObjectParser<>("eql/search_response_sequence", false, supplier);
parser.declareStringArray(fromList(String.class, EqlSearchResponse.Sequence::joinKeys), JOIN_KEYS);
parser.declareObjectArray(Sequence::setEvents,
(p, c) -> SearchHit.fromXContent(p), EVENTS);
return parser;
static {
PARSER.declareStringArray(ConstructingObjectParser.optionalConstructorArg(), JOIN_KEYS);
PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> SearchHit.fromXContent(p), EVENTS);
}

private String[] joinKeys = null;
private Events events = null;
private final List<String> joinKeys;
private final Events events;

private Sequence(){
this(null, null);
}

public Sequence(String[] joinKeys, Events events) {
this.joinKeys(joinKeys);
if (events == null) {
this.events = new Events((SearchHit[])(null));
} else {
this.events = events;
}
public Sequence(List<String> joinKeys, Events events) {
this.joinKeys = joinKeys == null ? Collections.emptyList() : joinKeys;
this.events = events == null ? Events.EMPTY : events;
}

public Sequence(StreamInput in) throws IOException {
this.joinKeys = in.readStringArray();
this.joinKeys = in.readStringList();
this.events = new Events(in);
}

public void joinKeys(String[] joinKeys) {
if (joinKeys == null) {
this.joinKeys = new String[0];
} else {
this.joinKeys = joinKeys;
}
}

private void setEvents(List<SearchHit> hits) {
if (hits == null) {
this.events = new Events((SearchHit[])(null));
} else {
this.events = new Events(hits.toArray(new SearchHit[hits.size()]));
}
}

public static Sequence fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeStringArray(joinKeys);
out.writeStringCollection(joinKeys);
out.writeVInt(events.entries().length);
if (events.entries().length > 0) {
for (SearchHit hit : events.entries()) {
Expand Down Expand Up @@ -298,13 +261,13 @@ public boolean equals(Object o) {
return false;
}
Sequence that = (Sequence) o;
return Arrays.equals(joinKeys, that.joinKeys)
return Objects.equals(joinKeys, that.joinKeys)
&& Objects.equals(events, that.events);
}

@Override
public int hashCode() {
return Objects.hash(Arrays.hashCode(joinKeys), events);
return Objects.hash(joinKeys, events);
}
}

Expand All @@ -316,70 +279,58 @@ private static final class Fields {
static final String PERCENT = "_percent";
}

private int count;
private String[] keys;
private float percent;
private final int count;
private final List<String> keys;
private final float percent;

private static final ParseField COUNT = new ParseField(Fields.COUNT);
private static final ParseField KEYS = new ParseField(Fields.KEYS);
private static final ParseField PERCENT = new ParseField(Fields.PERCENT);

private static final ObjectParser<EqlSearchResponse.Count, Void> PARSER = objectParser(EqlSearchResponse.Count::new);
private static final ConstructingObjectParser<EqlSearchResponse.Count, Void> PARSER =
new ConstructingObjectParser<>("eql/search_response_count", true,
args -> {
int i = 0;
int count = (int) args[i++];
@SuppressWarnings("unchecked") List<String> joinKeys = (List<String>) args[i++];
float percent = (float) args[i];
return new EqlSearchResponse.Count(count, joinKeys, percent);
});

protected static <R extends EqlSearchResponse.Count> ObjectParser<R, Void> objectParser(Supplier<R> supplier) {
ObjectParser<R, Void> parser = new ObjectParser<>("eql/search_response_count", false, supplier);
parser.declareInt(EqlSearchResponse.Count::count, COUNT);
parser.declareStringArray(fromList(String.class, EqlSearchResponse.Count::keys), KEYS);
parser.declareFloat(EqlSearchResponse.Count::percent, PERCENT);
return parser;
static {
PARSER.declareInt(ConstructingObjectParser.constructorArg(), COUNT);
PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), KEYS);
PARSER.declareFloat(ConstructingObjectParser.constructorArg(), PERCENT);
}

private Count() {}

public Count(int count, String[] keys, float percent) {
public Count(int count, List<String> keys, float percent) {
this.count = count;
this.keys(keys);
this.keys = keys == null ? Collections.emptyList() : keys;
this.percent = percent;
}

public Count(StreamInput in) throws IOException {
count = in.readVInt();
keys = in.readStringArray();
keys = in.readStringList();
percent = in.readFloat();
}

public void count(int count) {
this.count = count;
}

public void keys(String[] keys) {
if (keys == null) {
this.keys = new String[0];
} else {
this.keys = keys;
}
}

public void percent(float percent) {
this.percent = percent;
}

public static Count fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(count);
out.writeStringArray(keys);
out.writeStringCollection(keys);
out.writeFloat(percent);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(Fields.COUNT, count);
builder.array(Fields.KEYS, keys);
builder.field(Fields.KEYS, keys);
builder.field(Fields.PERCENT, percent);
builder.endObject();
return builder;
Expand All @@ -395,13 +346,13 @@ public boolean equals(Object o) {
}
Count that = (Count) o;
return Objects.equals(count, that.count)
&& Arrays.equals(keys, that.keys)
&& Objects.equals(keys, that.keys)
&& Objects.equals(percent, that.percent);
}

@Override
public int hashCode() {
return Objects.hash(count, Arrays.hashCode(keys), percent);
return Objects.hash(count, keys, percent);
}
}

Expand Down Expand Up @@ -477,6 +428,7 @@ public int hashCode() {

// Events
public static class Events extends EqlSearchResponse.Entries<SearchHit> {
private static final Events EMPTY = new Events((SearchHit[]) null);
private static final String NAME = "events";

public Events(SearchHit[] entries) {
Expand Down Expand Up @@ -545,6 +497,8 @@ protected final Count createEntry(StreamInput in) throws IOException {

// Hits
public static class Hits implements Writeable, ToXContentFragment {
public static final Hits EMPTY = new Hits((Events)null, null);

private Events events = null;
private Sequences sequences = null;
private Counts counts = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import org.elasticsearch.xpack.eql.action.EqlSearchRequest;
import org.elasticsearch.xpack.eql.action.EqlSearchResponse;

import java.util.Collections;

public class TransportEqlSearchAction extends HandledTransportAction<EqlSearchRequest, EqlSearchResponse> {
private final SecurityContext securityContext;
private final ClusterService clusterService;
Expand Down Expand Up @@ -54,8 +56,8 @@ static EqlSearchResponse createResponse(EqlSearchRequest request) {
new SearchHit(2, "222", null),
});
EqlSearchResponse.Hits hits = new EqlSearchResponse.Hits(new EqlSearchResponse.Sequences(new EqlSearchResponse.Sequence[]{
new EqlSearchResponse.Sequence(new String[]{"4021"}, events),
new EqlSearchResponse.Sequence(new String[]{"2343"}, events)
new EqlSearchResponse.Sequence(Collections.singletonList("4021"), events),
new EqlSearchResponse.Sequence(Collections.singletonList("2343"), events)
}), new TotalHits(0, TotalHits.Relation.EQUAL_TO));
return new EqlSearchResponse(hits, 0, false);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.test.AbstractSerializingTestCase;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;

public class EqlSearchResponseTests extends AbstractSerializingTestCase<EqlSearchResponse> {

Expand Down Expand Up @@ -58,9 +60,9 @@ public static EqlSearchResponse createRandomSequencesResponse(TotalHits totalHit
if (randomBoolean()) {
seq = new EqlSearchResponse.Sequence[size];
for (int i = 0; i < size; i++) {
String[] joins = null;
List<String> joins = null;
if (randomBoolean()) {
joins = generateRandomStringArray(6, 11, false);
joins = Arrays.asList(generateRandomStringArray(6, 11, false));
}
seq[i] = new EqlSearchResponse.Sequence(joins, randomEvents());
}
Expand All @@ -82,9 +84,9 @@ public static EqlSearchResponse createRandomCountResponse(TotalHits totalHits) {
if (randomBoolean()) {
cn = new EqlSearchResponse.Count[size];
for (int i = 0; i < size; i++) {
String[] keys = null;
List<String> keys = null;
if (randomBoolean()) {
keys = generateRandomStringArray(6, 11, false);
keys = Arrays.asList(generateRandomStringArray(6, 11, false));
}
cn[i] = new EqlSearchResponse.Count(randomIntBetween(0, 41), keys, randomFloat());
}
Expand Down

0 comments on commit 8df5e45

Please sign in to comment.