Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML][Inference] adjusting definition object schema and validation #47447

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
Expand All @@ -38,6 +39,7 @@ public class TrainedModelDefinition implements ToXContentObject {

public static final ParseField TRAINED_MODEL = new ParseField("trained_model");
public static final ParseField PREPROCESSORS = new ParseField("preprocessors");
public static final ParseField INPUT = new ParseField("input");

public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
true,
Expand All @@ -51,6 +53,7 @@ public class TrainedModelDefinition implements ToXContentObject {
(p, c, n) -> p.namedObject(PreProcessor.class, n, null),
(trainedModelDefBuilder) -> {/* Does not matter client side*/ },
PREPROCESSORS);
PARSER.declareObject(TrainedModelDefinition.Builder::setInput, (p, c) -> Input.fromXContent(p), INPUT);
}

public static TrainedModelDefinition.Builder fromXContent(XContentParser parser) throws IOException {
Expand All @@ -59,10 +62,12 @@ public static TrainedModelDefinition.Builder fromXContent(XContentParser parser)

private final TrainedModel trainedModel;
private final List<PreProcessor> preProcessors;
private final Input input;

TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors) {
TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors, Input input) {
this.trainedModel = trainedModel;
this.preProcessors = preProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(preProcessors);
this.input = input;
}

@Override
Expand All @@ -78,6 +83,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
true,
PREPROCESSORS.getPreferredName(),
preProcessors);
if (input != null) {
builder.field(INPUT.getPreferredName(), input);
}
builder.endObject();
return builder;
}
Expand All @@ -90,6 +98,10 @@ public List<PreProcessor> getPreProcessors() {
return preProcessors;
}

/**
 * Returns the {@link Input} section of this definition.
 *
 * @return the input section as handed to the constructor; may be {@code null}, in which case
 *         {@code toXContent} omits the field
 */
public Input getInput() {
return input;
}

@Override
public String toString() {
return Strings.toString(this);
Expand All @@ -101,18 +113,20 @@ public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) return false;
TrainedModelDefinition that = (TrainedModelDefinition) o;
return Objects.equals(trainedModel, that.trainedModel) &&
Objects.equals(preProcessors, that.preProcessors) ;
Objects.equals(preProcessors, that.preProcessors) &&
Objects.equals(input, that.input);
}

@Override
public int hashCode() {
return Objects.hash(trainedModel, preProcessors);
return Objects.hash(trainedModel, preProcessors, input);
}

public static class Builder {

private List<PreProcessor> preProcessors;
private TrainedModel trainedModel;
private Input input;

public Builder setPreProcessors(List<PreProcessor> preProcessors) {
this.preProcessors = preProcessors;
Expand All @@ -124,14 +138,71 @@ public Builder setTrainedModel(TrainedModel trainedModel) {
return this;
}

/**
 * Sets the {@link Input} section that {@code build()} passes to the definition.
 *
 * @param input the input section to use
 * @return this builder, for chaining
 */
public Builder setInput(Input input) {
this.input = input;
return this;
}

/**
 * Unwraps a single-element list and delegates to {@code setTrainedModel(TrainedModel)}.
 * The list size is only checked with an {@code assert}, so it is not enforced when
 * assertions are disabled.
 *
 * @param trainedModel list expected to hold exactly one model
 * @return this builder, for chaining
 */
private Builder setTrainedModel(List<TrainedModel> trainedModel) {
    assert trainedModel.size() == 1;
    final TrainedModel onlyModel = trainedModel.get(0);
    return setTrainedModel(onlyModel);
}

public TrainedModelDefinition build() {
return new TrainedModelDefinition(this.trainedModel, this.preProcessors);
return new TrainedModelDefinition(this.trainedModel, this.preProcessors, this.input);
}
}

public static class Input implements ToXContentObject {
przemekwitek marked this conversation as resolved.
Show resolved Hide resolved

// Name handed to the object parser for this section; spelled "trained_model_..." (the earlier
// value "trained_mode_definition_input" was missing the 'l').
public static final String NAME = "trained_model_definition_input";
// Key under which the list of field names is read and written.
public static final ParseField FIELD_NAMES = new ParseField("field_names");

// The unchecked cast is on the single constructor argument, which is declared as a string array
// in the static initializer below.
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<Input, Void> PARSER = new ConstructingObjectParser<>(
    NAME,
    true,
    args -> new Input((List<String>) args[0]));

static {
    PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), FIELD_NAMES);
}

/**
 * Parses an {@link Input} from the given parser by delegating to {@link #PARSER}.
 *
 * @param parser the content parser to read from
 * @return the parsed input section
 * @throws IOException declared by the underlying parse call
 */
public static Input fromXContent(XContentParser parser) throws IOException {
// No parser context is needed, hence the null second argument.
return PARSER.parse(parser, null);
}

private final List<String> fieldNames;

/**
 * Creates an input section from the given field names.
 *
 * <p>The list is exposed through an unmodifiable view so that callers of
 * {@code getFieldNames()} cannot mutate this object's state, matching how the enclosing
 * class treats its preprocessor list. The view does not copy: later changes to the list
 * passed in are still visible through it.
 *
 * @param fieldNames the field names; {@code null} is accepted and kept as {@code null},
 *                   in which case {@code toXContent} omits the field
 */
public Input(List<String> fieldNames) {
    this.fieldNames = fieldNames == null ? null : Collections.unmodifiableList(fieldNames);
}

/**
 * Returns the field names held by this input section.
 *
 * @return the list given to the constructor; may be {@code null}
 */
public List<String> getFieldNames() {
return fieldNames;
}

/**
 * Writes this section as an object; the field-names entry is emitted only when a list is present.
 *
 * @param builder the builder to write into
 * @param params  serialization parameters (not consulted here)
 * @return the same builder
 * @throws IOException declared by the builder calls
 */
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
    final boolean hasFieldNames = fieldNames != null;
    builder.startObject();
    if (hasFieldNames) {
        builder.field(FIELD_NAMES.getPreferredName(), fieldNames);
    }
    builder.endObject();
    return builder;
}

/**
 * Two inputs are equal when they are of the same class and hold equal field-name lists.
 */
@Override
public boolean equals(Object o) {
    if (this == o) {
        return true;
    }
    if (o == null) {
        return false;
    }
    if (getClass() != o.getClass()) {
        return false;
    }
    final TrainedModelDefinition.Input other = (TrainedModelDefinition.Input) o;
    return Objects.equals(fieldNames, other.fieldNames);
}

/**
 * Hash derived solely from {@code fieldNames}, the one field compared in {@code equals}.
 */
@Override
public int hashCode() {
return Objects.hash(fieldNames);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public class TargetMeanEncoding implements PreProcessor {
public static final String NAME = "target_mean_encoding";
public static final ParseField FIELD = new ParseField("field");
public static final ParseField FEATURE_NAME = new ParseField("feature_name");
public static final ParseField TARGET_MEANS = new ParseField("target_means");
public static final ParseField TARGET_MAP = new ParseField("target_map");
public static final ParseField DEFAULT_VALUE = new ParseField("default_value");

@SuppressWarnings("unchecked")
Expand All @@ -52,7 +52,7 @@ public class TargetMeanEncoding implements PreProcessor {
PARSER.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
PARSER.declareObject(ConstructingObjectParser.constructorArg(),
(p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
TARGET_MEANS);
TARGET_MAP);
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), DEFAULT_VALUE);
}

Expand Down Expand Up @@ -110,7 +110,7 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
builder.startObject();
builder.field(FIELD.getPreferredName(), field);
builder.field(FEATURE_NAME.getPreferredName(), featureName);
builder.field(TARGET_MEANS.getPreferredName(), meanMap);
builder.field(TARGET_MAP.getPreferredName(), meanMap);
builder.field(DEFAULT_VALUE.getPreferredName(), defaultValue);
builder.endObject();
return builder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ public static TrainedModelDefinition.Builder createRandomBuilder() {
TargetMeanEncodingTests.createRandom()))
.limit(numberOfProcessors)
.collect(Collectors.toList()))
.setTrainedModel(randomFrom(TreeTests.createRandom()));
.setTrainedModel(randomFrom(TreeTests.createRandom()))
.setInput(new TrainedModelDefinition.Input(Stream.generate(() -> randomAlphaOfLength(10))
.limit(randomLongBetween(1, 10))
.collect(Collectors.toList())));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
Expand All @@ -30,10 +31,11 @@

public class TrainedModelDefinition implements ToXContentObject, Writeable {

public static final String NAME = "trained_model_doc";
public static final String NAME = "trained_mode_definition";

public static final ParseField TRAINED_MODEL = new ParseField("trained_model");
public static final ParseField PREPROCESSORS = new ParseField("preprocessors");
public static final ParseField INPUT = new ParseField("input");

// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
public static final ObjectParser<TrainedModelDefinition.Builder, Void> LENIENT_PARSER = createParser(true);
Expand All @@ -55,6 +57,7 @@ private static ObjectParser<TrainedModelDefinition.Builder, Void> createParser(b
p.namedObject(StrictlyParsedPreProcessor.class, n, null),
(trainedModelDefBuilder) -> trainedModelDefBuilder.setProcessorsInOrder(true),
PREPROCESSORS);
parser.declareObject(TrainedModelDefinition.Builder::setInput, (p, c) -> Input.fromXContent(p, ignoreUnknownFields), INPUT);
return parser;
}

Expand All @@ -64,21 +67,25 @@ public static TrainedModelDefinition.Builder fromXContent(XContentParser parser,

private final TrainedModel trainedModel;
private final List<PreProcessor> preProcessors;
private final Input input;

TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors) {
this.trainedModel = trainedModel;
TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors, Input input) {
this.trainedModel = ExceptionsHelper.requireNonNull(trainedModel, TRAINED_MODEL);
this.preProcessors = preProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(preProcessors);
this.input = ExceptionsHelper.requireNonNull(input, INPUT);
}

/**
 * Stream-deserializing constructor. Reads, in this order: the trained model, the
 * preprocessor list, then the input section. The order matches what {@code writeTo} writes.
 *
 * @param in the stream to read from
 * @throws IOException declared by the stream read calls
 */
public TrainedModelDefinition(StreamInput in) throws IOException {
this.trainedModel = in.readNamedWriteable(TrainedModel.class);
this.preProcessors = in.readNamedWriteableList(PreProcessor.class);
this.input = new Input(in);
}

/**
 * Writes this definition to the stream: trained model, preprocessor list, then the input
 * section. The stream constructor reads the same three parts in the same order.
 *
 * @param out the stream to write to
 * @throws IOException declared by the stream write calls
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeNamedWriteable(trainedModel);
out.writeNamedWriteableList(preProcessors);
input.writeTo(out);
}

@Override
Expand All @@ -94,6 +101,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
true,
PREPROCESSORS.getPreferredName(),
preProcessors);
builder.field(INPUT.getPreferredName(), input);
builder.endObject();
return builder;
}
Expand All @@ -106,6 +114,10 @@ public List<PreProcessor> getPreProcessors() {
return preProcessors;
}

/**
 * Returns the {@link Input} section of this definition.
 *
 * @return the input section; the visible constructors either pass it through
 *         {@code ExceptionsHelper.requireNonNull} or create it from the stream
 */
public Input getInput() {
return input;
}

@Override
public String toString() {
return Strings.toString(this);
Expand All @@ -117,19 +129,21 @@ public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) return false;
TrainedModelDefinition that = (TrainedModelDefinition) o;
return Objects.equals(trainedModel, that.trainedModel) &&
Objects.equals(preProcessors, that.preProcessors) ;
Objects.equals(input, that.input) &&
Objects.equals(preProcessors, that.preProcessors);
}

@Override
public int hashCode() {
return Objects.hash(trainedModel, preProcessors);
return Objects.hash(trainedModel, input, preProcessors);
}

public static class Builder {

private List<PreProcessor> preProcessors;
private TrainedModel trainedModel;
private boolean processorsInOrder;
private Input input;

private static Builder builderForParser() {
return new Builder(false);
Expand All @@ -153,6 +167,11 @@ public Builder setTrainedModel(TrainedModel trainedModel) {
return this;
}

/**
 * Sets the {@link Input} section that {@code build()} passes to the definition.
 *
 * @param input the input section to use
 * @return this builder, for chaining
 */
public Builder setInput(Input input) {
this.input = input;
return this;
}

private Builder setTrainedModel(List<TrainedModel> trainedModel) {
if (trainedModel.size() != 1) {
throw ExceptionsHelper.badRequestException("[{}] must have exactly one trained model defined.",
Expand All @@ -169,8 +188,71 @@ public TrainedModelDefinition build() {
if (preProcessors != null && preProcessors.size() > 1 && processorsInOrder == false) {
throw new IllegalArgumentException("preprocessors must be an array of preprocessor objects");
}
return new TrainedModelDefinition(this.trainedModel, this.preProcessors);
return new TrainedModelDefinition(this.trainedModel, this.preProcessors, this.input);
}
}

public static class Input implements ToXContentObject, Writeable {

// Name handed to the object parsers for this section; spelled "trained_model_..." (the earlier
// value "trained_mode_definition_input" was missing the 'l').
public static final String NAME = "trained_model_definition_input";
// Key under which the list of field names is read and written.
public static final ParseField FIELD_NAMES = new ParseField("field_names");
przemekwitek marked this conversation as resolved.
Show resolved Hide resolved

// Two parser instances built by createParser; they differ only in the ignoreUnknownFields flag
// (true for the lenient one, false for the strict one).
public static final ConstructingObjectParser<Input, Void> LENIENT_PARSER = createParser(true);
public static final ConstructingObjectParser<Input, Void> STRICT_PARSER = createParser(false);

/**
 * Builds a parser for {@link Input} whose single constructor argument is the string array
 * stored under {@code FIELD_NAMES}.
 *
 * @param ignoreUnknownFields forwarded to the {@code ConstructingObjectParser} constructor
 * @return a newly configured parser
 */
@SuppressWarnings("unchecked")
private static ConstructingObjectParser<Input, Void> createParser(boolean ignoreUnknownFields) {
    final ConstructingObjectParser<Input, Void> inputParser =
        new ConstructingObjectParser<>(NAME, ignoreUnknownFields, args -> new Input((List<String>) args[0]));
    inputParser.declareStringArray(ConstructingObjectParser.constructorArg(), FIELD_NAMES);
    return inputParser;
}

/**
 * Parses an {@link Input} using either the lenient or the strict parser.
 *
 * @param parser  the content parser to read from
 * @param lenient {@code true} selects {@code LENIENT_PARSER}, {@code false} selects {@code STRICT_PARSER}
 * @return the parsed input section
 * @throws IOException declared by the underlying parse call
 */
public static Input fromXContent(XContentParser parser, boolean lenient) throws IOException {
    final ConstructingObjectParser<Input, Void> chosenParser = lenient ? LENIENT_PARSER : STRICT_PARSER;
    return chosenParser.parse(parser, null);
}

private final List<String> fieldNames;

public Input(List<String> fieldNames) {
this.fieldNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(fieldNames, FIELD_NAMES));
}

public Input(StreamInput in) throws IOException {
this.fieldNames = Collections.unmodifiableList(in.readStringList());
}

/**
 * Returns the field names held by this input section.
 *
 * @return the list stored by the constructors, which wrap it in {@code Collections.unmodifiableList}
 */
public List<String> getFieldNames() {
return fieldNames;
}

/**
 * Writes the field names to the stream as a string collection; the stream constructor
 * reads them back with {@code readStringList}.
 *
 * @param out the stream to write to
 * @throws IOException declared by the stream write call
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeStringCollection(fieldNames);
}

/**
 * Writes this section as an object containing the field-names entry.
 *
 * @param builder the builder to write into
 * @param params  serialization parameters (not consulted here)
 * @return the same builder
 * @throws IOException declared by the builder calls
 */
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
// Written unconditionally: there is no null guard here.
// NOTE(review): relies on the constructors never storing null — confirm requireNonNull throws.
builder.field(FIELD_NAMES.getPreferredName(), fieldNames);
builder.endObject();
return builder;
}

/**
 * Two inputs are equal when they are of the same class and hold equal field-name lists.
 */
@Override
public boolean equals(Object o) {
    if (this == o) {
        return true;
    }
    if (o == null) {
        return false;
    }
    if (getClass() != o.getClass()) {
        return false;
    }
    final TrainedModelDefinition.Input other = (TrainedModelDefinition.Input) o;
    return Objects.equals(fieldNames, other.fieldNames);
}

/**
 * Hash derived solely from {@code fieldNames}, the one field compared in {@code equals}.
 */
@Override
public int hashCode() {
return Objects.hash(fieldNames);
}

}

}
Loading