diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java index 616aaea21d12b..792cc8f7303b4 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java @@ -20,7 +20,6 @@ import org.elasticsearch.Version; import org.elasticsearch.client.common.TimeUtil; -import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.ObjectParser; @@ -31,7 +30,6 @@ import java.io.IOException; import java.time.Instant; import java.util.Collections; -import java.util.List; import java.util.Map; import java.util.Objects; @@ -64,9 +62,8 @@ public class TrainedModelConfig implements ToXContentObject { PARSER.declareLong(TrainedModelConfig.Builder::setModelVersion, MODEL_VERSION); PARSER.declareString(TrainedModelConfig.Builder::setModelType, MODEL_TYPE); PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA); - PARSER.declareNamedObjects(TrainedModelConfig.Builder::setDefinition, - (p, c, n) -> p.namedObject(TrainedModel.class, n, null), - (modelDocBuilder) -> { /* Noop does not matter client side */ }, + PARSER.declareObject(TrainedModelConfig.Builder::setDefinition, + (p, c) -> TrainedModelDefinition.fromXContent(p), DEFINITION); } @@ -82,7 +79,7 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr private final Long modelVersion; private final String modelType; private final Map metadata; - private final TrainedModel definition; + private final TrainedModelDefinition definition; TrainedModelConfig(String modelId, String createdBy, @@ -91,7 +88,7 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr Instant createdTime, Long modelVersion, String modelType, - TrainedModel definition, + TrainedModelDefinition definition, Map metadata) { this.modelId = modelId; this.createdBy = createdBy; @@ -136,7 +133,7 @@ public Map getMetadata() { return metadata; } - public TrainedModel getDefinition() { + public TrainedModelDefinition getDefinition() { return definition; } @@ -169,11 +166,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(MODEL_TYPE.getPreferredName(), modelType); } if (definition != null) { - NamedXContentObjectHelper.writeNamedObjects(builder, - params, - false, - DEFINITION.getPreferredName(), - Collections.singletonList(definition)); + builder.field(DEFINITION.getPreferredName(), definition); } if (metadata != null) { builder.field(METADATA.getPreferredName(), metadata); @@ -227,7 +220,7 @@ public static class Builder { private Long modelVersion; private String modelType; private Map metadata; - private TrainedModel definition; + private TrainedModelDefinition.Builder definition; public Builder setModelId(String modelId) { this.modelId = modelId; @@ -273,16 +266,11 @@ public Builder setMetadata(Map metadata) { return this; } - public Builder setDefinition(TrainedModel definition) { + public Builder setDefinition(TrainedModelDefinition.Builder definition) { this.definition = definition; return this; } - private Builder setDefinition(List definition) { - assert definition.size() == 1; - return setDefinition(definition.get(0)); - } - public TrainedModelConfig build() { return new TrainedModelConfig( modelId, @@ -292,7 +280,7 @@ public TrainedModelConfig build() { createdTime, modelVersion, modelType, - definition, + definition == null ? null : definition.build(), metadata); } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelDefinition.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelDefinition.java new file mode 100644 index 0000000000000..7b564a9e684fa --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelDefinition.java @@ -0,0 +1,137 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference; + +import org.elasticsearch.client.ml.inference.preprocessing.PreProcessor; +import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +public class TrainedModelDefinition implements ToXContentObject { + + public static final String NAME = "trained_model_doc"; + + public static final ParseField TRAINED_MODEL = new ParseField("trained_model"); + public static final ParseField PREPROCESSORS = new ParseField("preprocessors"); + + public static final ObjectParser PARSER = new ObjectParser<>(NAME, + true, + TrainedModelDefinition.Builder::new); + static { + PARSER.declareNamedObjects(TrainedModelDefinition.Builder::setTrainedModel, + (p, c, n) -> p.namedObject(TrainedModel.class, n, null), + (modelDocBuilder) -> { /* Noop does not matter client side*/ }, + TRAINED_MODEL); + PARSER.declareNamedObjects(TrainedModelDefinition.Builder::setPreProcessors, + (p, c, n) -> p.namedObject(PreProcessor.class, n, null), + (trainedModelDefBuilder) -> {/* Does not matter client side*/ }, + PREPROCESSORS); + } + + public static TrainedModelDefinition.Builder fromXContent(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + + private final TrainedModel trainedModel; + private final List preProcessors; + + TrainedModelDefinition(TrainedModel trainedModel, List preProcessors) { + this.trainedModel = trainedModel; + this.preProcessors = preProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(preProcessors); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + NamedXContentObjectHelper.writeNamedObjects(builder, + params, + false, + TRAINED_MODEL.getPreferredName(), + Collections.singletonList(trainedModel)); + NamedXContentObjectHelper.writeNamedObjects(builder, + params, + true, + PREPROCESSORS.getPreferredName(), + preProcessors); + builder.endObject(); + return builder; + } + + public TrainedModel getTrainedModel() { + return trainedModel; + } + + public List getPreProcessors() { + return preProcessors; + } + + @Override + public String toString() { + return Strings.toString(this); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TrainedModelDefinition that = (TrainedModelDefinition) o; + return Objects.equals(trainedModel, that.trainedModel) && + Objects.equals(preProcessors, that.preProcessors) ; + } + + @Override + public int hashCode() { + return Objects.hash(trainedModel, preProcessors); + } + + public static class Builder { + + private List preProcessors; + private TrainedModel trainedModel; + + public Builder setPreProcessors(List preProcessors) { + this.preProcessors = preProcessors; + return this; + } + + public Builder setTrainedModel(TrainedModel trainedModel) { + this.trainedModel = trainedModel; + return this; + } + + private Builder setTrainedModel(List trainedModel) { + assert trainedModel.size() == 1; + return setTrainedModel(trainedModel.get(0)); + } + + public TrainedModelDefinition build() { + return new TrainedModelDefinition(this.trainedModel, this.preProcessors); + } + } + +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/preprocessing/PreProcessor.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/preprocessing/PreProcessor.java index ea814a8a0d61a..72c5e612abb3e 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/preprocessing/PreProcessor.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/preprocessing/PreProcessor.java @@ -18,13 +18,13 @@ */ package org.elasticsearch.client.ml.inference.preprocessing; -import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.client.ml.inference.NamedXContentObject; /** * Describes a pre-processor for a defined machine learning model */ -public interface PreProcessor extends ToXContentObject { +public interface PreProcessor extends NamedXContentObject { /** * @return The name of the pre-processor diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java index 1f484a991aa9a..ccd6f10104c1b 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java @@ -19,7 +19,6 @@ package org.elasticsearch.client.ml.inference; import org.elasticsearch.Version; -import org.elasticsearch.client.ml.inference.trainedmodel.tree.TreeTests; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentParser; @@ -61,7 +60,7 @@ protected TrainedModelConfig createTestInstance() { Instant.ofEpochMilli(randomNonNegativeLong()), randomBoolean() ? null : randomNonNegativeLong(), randomAlphaOfLength(10), - randomFrom(TreeTests.createRandom()), + randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(), randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10))); } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelDefinitionTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelDefinitionTests.java new file mode 100644 index 0000000000000..8eeec2ce2fcb9 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelDefinitionTests.java @@ -0,0 +1,83 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference; + +import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncodingTests; +import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncodingTests; +import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncodingTests; +import org.elasticsearch.client.ml.inference.trainedmodel.tree.TreeTests; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.Stream; + + +public class TrainedModelDefinitionTests extends AbstractXContentTestCase { + + @Override + protected TrainedModelDefinition doParseInstance(XContentParser parser) throws IOException { + return TrainedModelDefinition.fromXContent(parser).build(); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + return field -> !field.isEmpty(); + } + + public static TrainedModelDefinition.Builder createRandomBuilder() { + int numberOfProcessors = randomIntBetween(1, 10); + return new TrainedModelDefinition.Builder() + .setPreProcessors( + randomBoolean() ? null : + Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(), + OneHotEncodingTests.createRandom(), + TargetMeanEncodingTests.createRandom())) + .limit(numberOfProcessors) + .collect(Collectors.toList())) + .setTrainedModel(randomFrom(TreeTests.createRandom())); + } + + @Override + protected TrainedModelDefinition createTestInstance() { + return createRandomBuilder().build(); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + return new NamedXContentRegistry(namedXContent); + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java index 6b8694e7e3b0c..6140c4387831c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java @@ -17,18 +17,13 @@ import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.xpack.core.common.time.TimeUtils; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.MlStrings; -import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper; import java.io.IOException; import java.time.Instant; import java.util.Collections; -import java.util.List; import java.util.Map; import java.util.Objects; @@ -65,11 +60,8 @@ private static ObjectParser createParser(boole parser.declareLong(TrainedModelConfig.Builder::setModelVersion, MODEL_VERSION); parser.declareString(TrainedModelConfig.Builder::setModelType, MODEL_TYPE); parser.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA); - parser.declareNamedObjects(TrainedModelConfig.Builder::setDefinition, - (p, c, n) -> ignoreUnknownFields ? - p.namedObject(LenientlyParsedTrainedModel.class, n, null) : - p.namedObject(StrictlyParsedTrainedModel.class, n, null), - (modelDocBuilder) -> { /* Noop does not matter as we will throw if more than one is defined */ }, + parser.declareObject(TrainedModelConfig.Builder::setDefinition, + (p, c) -> TrainedModelDefinition.fromXContent(p, ignoreUnknownFields), DEFINITION); return parser; } @@ -94,7 +86,7 @@ public static String documentId(String modelId, long modelVersion) { // TODO how to reference and store large models that will not be executed in Java??? // Potentially allow this to be null and have an {index: indexName, doc: model_doc_id} or something // TODO Should this be lazily parsed when loading via the index??? - private final TrainedModel definition; + private final TrainedModelDefinition definition; TrainedModelConfig(String modelId, String createdBy, Version version, @@ -102,7 +94,7 @@ public static String documentId(String modelId, long modelVersion) { Instant createdTime, Long modelVersion, String modelType, - TrainedModel definition, + TrainedModelDefinition definition, Map metadata) { this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID); this.createdBy = ExceptionsHelper.requireNonNull(createdBy, CREATED_BY); @@ -123,7 +115,7 @@ public TrainedModelConfig(StreamInput in) throws IOException { createdTime = in.readInstant(); modelVersion = in.readVLong(); modelType = in.readString(); - definition = in.readOptionalNamedWriteable(TrainedModel.class); + definition = in.readOptionalWriteable(TrainedModelDefinition::new); metadata = in.readMap(); } @@ -160,7 +152,7 @@ public Map getMetadata() { } @Nullable - public TrainedModel getDefinition() { + public TrainedModelDefinition getDefinition() { return definition; } @@ -177,7 +169,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeInstant(createdTime); out.writeVLong(modelVersion); out.writeString(modelType); - out.writeOptionalNamedWriteable(definition); + out.writeOptionalWriteable(definition); out.writeMap(metadata); } @@ -194,11 +186,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(MODEL_VERSION.getPreferredName(), modelVersion); builder.field(MODEL_TYPE.getPreferredName(), modelType); if (definition != null) { - NamedXContentObjectHelper.writeNamedObjects(builder, - params, - false, - DEFINITION.getPreferredName(), - Collections.singletonList(definition)); + builder.field(DEFINITION.getPreferredName(), definition); } if (metadata != null) { builder.field(METADATA.getPreferredName(), metadata); @@ -241,7 +229,6 @@ public int hashCode() { modelVersion); } - public static class Builder { private String modelId; @@ -252,7 +239,7 @@ public static class Builder { private Long modelVersion; private String modelType; private Map metadata; - private TrainedModel definition; + private TrainedModelDefinition.Builder definition; public Builder setModelId(String modelId) { this.modelId = modelId; @@ -298,19 +285,11 @@ public Builder setMetadata(Map metadata) { return this; } - public Builder setDefinition(TrainedModel definition) { + public Builder setDefinition(TrainedModelDefinition.Builder definition) { this.definition = definition; return this; } - private Builder setDefinition(List definition) { - if (definition.size() != 1) { - throw ExceptionsHelper.badRequestException("[{}] must have exactly one trained model defined.", - DEFINITION.getPreferredName()); - } - return setDefinition(definition.get(0)); - } - // TODO move to REST level instead of here in the builder public void validate() { // We require a definition to be available until we support other means of supplying the definition @@ -352,7 +331,7 @@ public TrainedModelConfig build() { createdTime, modelVersion, modelType, - definition, + definition == null ? null : definition.build(), metadata); } @@ -365,7 +344,7 @@ public TrainedModelConfig build(Version version) { Instant.now(), modelVersion, modelType, - definition, + definition == null ? null : definition.build(), metadata); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java new file mode 100644 index 0000000000000..6daa530e0277c --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java @@ -0,0 +1,176 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +public class TrainedModelDefinition implements ToXContentObject, Writeable { + + public static final String NAME = "trained_model_doc"; + + public static final ParseField TRAINED_MODEL = new ParseField("trained_model"); + public static final ParseField PREPROCESSORS = new ParseField("preprocessors"); + + // These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly + public static final ObjectParser LENIENT_PARSER = createParser(true); + public static final ObjectParser STRICT_PARSER = createParser(false); + + private static ObjectParser createParser(boolean ignoreUnknownFields) { + ObjectParser parser = new ObjectParser<>(NAME, + ignoreUnknownFields, + TrainedModelDefinition.Builder::new); + parser.declareNamedObjects(TrainedModelDefinition.Builder::setTrainedModel, + (p, c, n) -> ignoreUnknownFields ? + p.namedObject(LenientlyParsedTrainedModel.class, n, null) : + p.namedObject(StrictlyParsedTrainedModel.class, n, null), + (modelDocBuilder) -> { /* Noop does not matter as we will throw if more than one is defined */ }, + TRAINED_MODEL); + parser.declareNamedObjects(TrainedModelDefinition.Builder::setPreProcessors, + (p, c, n) -> ignoreUnknownFields ? + p.namedObject(LenientlyParsedPreProcessor.class, n, null) : + p.namedObject(StrictlyParsedPreProcessor.class, n, null), + (trainedModelDefBuilder) -> trainedModelDefBuilder.setProcessorsInOrder(true), + PREPROCESSORS); + return parser; + } + + public static TrainedModelDefinition.Builder fromXContent(XContentParser parser, boolean lenient) throws IOException { + return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null); + } + + private final TrainedModel trainedModel; + private final List preProcessors; + + TrainedModelDefinition(TrainedModel trainedModel, List preProcessors) { + this.trainedModel = trainedModel; + this.preProcessors = preProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(preProcessors); + } + + public TrainedModelDefinition(StreamInput in) throws IOException { + this.trainedModel = in.readNamedWriteable(TrainedModel.class); + this.preProcessors = in.readNamedWriteableList(PreProcessor.class); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeNamedWriteable(trainedModel); + out.writeNamedWriteableList(preProcessors); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + NamedXContentObjectHelper.writeNamedObjects(builder, + params, + false, + TRAINED_MODEL.getPreferredName(), + Collections.singletonList(trainedModel)); + NamedXContentObjectHelper.writeNamedObjects(builder, + params, + true, + PREPROCESSORS.getPreferredName(), + preProcessors); + builder.endObject(); + return builder; + } + + public TrainedModel getTrainedModel() { + return trainedModel; + } + + public List getPreProcessors() { + return preProcessors; + } + + @Override + public String toString() { + return Strings.toString(this); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TrainedModelDefinition that = (TrainedModelDefinition) o; + return Objects.equals(trainedModel, that.trainedModel) && + Objects.equals(preProcessors, that.preProcessors) ; + } + + @Override + public int hashCode() { + return Objects.hash(trainedModel, preProcessors); + } + + public static class Builder { + + private List preProcessors; + private TrainedModel trainedModel; + private boolean processorsInOrder; + + private static Builder builderForParser() { + return new Builder(false); + } + + private Builder(boolean processorsInOrder) { + this.processorsInOrder = processorsInOrder; + } + + public Builder() { + this(true); + } + + public Builder setPreProcessors(List preProcessors) { + this.preProcessors = preProcessors; + return this; + } + + public Builder setTrainedModel(TrainedModel trainedModel) { + this.trainedModel = trainedModel; + return this; + } + + private Builder setTrainedModel(List trainedModel) { + if (trainedModel.size() != 1) { + throw ExceptionsHelper.badRequestException("[{}] must have exactly one trained model defined.", + TRAINED_MODEL.getPreferredName()); + } + return setTrainedModel(trainedModel.get(0)); + } + + private void setProcessorsInOrder(boolean value) { + this.processorsInOrder = value; + } + + public TrainedModelDefinition build() { + if (preProcessors != null && preProcessors.size() > 1 && processorsInOrder == false) { + throw new IllegalArgumentException("preprocessors must be an array of preprocessor objects"); + } + return new TrainedModelDefinition(this.trainedModel, this.preProcessors); + } + } + +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java index 7a6b884eee34c..c9e66e634b374 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java @@ -14,7 +14,6 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.AbstractSerializingTestCase; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.MlStrings; import org.junit.Before; @@ -65,7 +64,7 @@ protected TrainedModelConfig createTestInstance() { Instant.ofEpochMilli(randomNonNegativeLong()), randomBoolean() ? null : randomNonNegativeLong(), randomAlphaOfLength(10), - randomBoolean() ? null : randomFrom(TreeTests.createRandom()), + randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(), randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10))); } @@ -97,14 +96,18 @@ public void testValidateWithNullDefinition() { public void testValidateWithInvalidID() { String modelId = "InvalidID-"; ElasticsearchException ex = expectThrows(ElasticsearchException.class, - () -> TrainedModelConfig.builder().setDefinition(randomFrom(TreeTests.createRandom())).setModelId(modelId).validate()); + () -> TrainedModelConfig.builder() + .setDefinition(TrainedModelDefinitionTests.createRandomBuilder()) + .setModelId(modelId).validate()); assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INVALID_ID, "model_id", modelId))); } public void testValidateWithLongID() { String modelId = IntStream.range(0, 100).mapToObj(x -> "a").collect(Collectors.joining()); ElasticsearchException ex = expectThrows(ElasticsearchException.class, - () -> TrainedModelConfig.builder().setDefinition(randomFrom(TreeTests.createRandom())).setModelId(modelId).validate()); + () -> TrainedModelConfig.builder() + .setDefinition(TrainedModelDefinitionTests.createRandomBuilder()) + .setModelId(modelId).validate()); assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.ID_TOO_LONG, "model_id", modelId, MlStrings.ID_LENGTH_LIMIT))); } @@ -112,21 +115,21 @@ public void testValidateWithIllegallyUserProvidedFields() { String modelId = "simplemodel"; ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> TrainedModelConfig.builder() - .setDefinition(randomFrom(TreeTests.createRandom())) + .setDefinition(TrainedModelDefinitionTests.createRandomBuilder()) .setCreatedTime(Instant.now()) .setModelId(modelId).validate()); assertThat(ex.getMessage(), equalTo("illegal to set [created_time] at inference model creation")); ex = expectThrows(ElasticsearchException.class, () -> TrainedModelConfig.builder() - .setDefinition(randomFrom(TreeTests.createRandom())) + .setDefinition(TrainedModelDefinitionTests.createRandomBuilder()) .setVersion(Version.CURRENT) .setModelId(modelId).validate()); assertThat(ex.getMessage(), equalTo("illegal to set [version] at inference model creation")); ex = expectThrows(ElasticsearchException.class, () -> TrainedModelConfig.builder() - .setDefinition(randomFrom(TreeTests.createRandom())) + .setDefinition(TrainedModelDefinitionTests.createRandomBuilder()) .setCreatedBy("ml_user") .setModelId(modelId).validate()); assertThat(ex.getMessage(), equalTo("illegal to set [created_by] at inference model creation")); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java new file mode 100644 index 0000000000000..0ecb7c1e6c2e0 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java @@ -0,0 +1,91 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests; +import org.junit.Before; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.Stream; + + +public class TrainedModelDefinitionTests extends AbstractSerializingTestCase { + + private boolean lenient; + + @Before + public void chooseStrictOrLenient() { + lenient = randomBoolean(); + } + + @Override + protected TrainedModelDefinition doParseInstance(XContentParser parser) throws IOException { + return TrainedModelDefinition.fromXContent(parser, lenient).build(); + } + + @Override + protected boolean supportsUnknownFields() { + return lenient; + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + return field -> !field.isEmpty(); + } + + public static TrainedModelDefinition.Builder createRandomBuilder() { + int numberOfProcessors = randomIntBetween(1, 10); + return new TrainedModelDefinition.Builder() + .setPreProcessors( + randomBoolean() ? null : + Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(), + OneHotEncodingTests.createRandom(), + TargetMeanEncodingTests.createRandom())) + .limit(numberOfProcessors) + .collect(Collectors.toList())) + .setTrainedModel(randomFrom(TreeTests.createRandom())); + } + @Override + protected TrainedModelDefinition createTestInstance() { + return createRandomBuilder().build(); + } + + @Override + protected Writeable.Reader instanceReader() { + return TrainedModelDefinition::new; + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + return new NamedXContentRegistry(namedXContent); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + List entries = new ArrayList<>(); + entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); + return new NamedWriteableRegistry(entries); + } + +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java index 8da2f4a3b4ef3..53c10f05dc1de 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java @@ -11,7 +11,7 @@ import org.elasticsearch.search.SearchModule; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; @@ -93,7 +93,7 @@ public void testGetMissingTrainingModelConfig() throws Exception { private static TrainedModelConfig buildTrainedModelConfig(String modelId, long modelVersion) { return TrainedModelConfig.builder() .setCreatedBy("ml_test") - .setDefinition(TreeTests.createRandom()) + .setDefinition(TrainedModelDefinitionTests.createRandomBuilder()) .setDescription("trained model config for test") .setModelId(modelId) .setModelType("binary_decision_tree")