From e821c16e6220b874a48f00b32e640f03badbac02 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 25 Sep 2019 11:19:43 -0400 Subject: [PATCH 1/5] [ML][Inference] adding ensemble model objects --- .../MlInferenceNamedXContentProvider.java | 13 + .../ml/inference/NamedXContentObject.java | 34 ++ .../inference/NamedXContentObjectHelper.java | 57 +++ .../ml/inference/trainedmodel/TargetType.java | 35 ++ .../inference/trainedmodel/TrainedModel.java | 4 +- .../trainedmodel/ensemble/Ensemble.java | 188 +++++++++ .../ensemble/OutputAggregator.java | 28 ++ .../trainedmodel/ensemble/WeightedMode.java | 84 ++++ .../trainedmodel/ensemble/WeightedSum.java | 84 ++++ .../ml/inference/trainedmodel/tree/Tree.java | 66 +++- .../client/RestHighLevelClientTests.java | 14 +- .../trainedmodel/ensemble/EnsembleTests.java | 99 +++++ .../ensemble/WeightedModeTests.java | 55 +++ .../ensemble/WeightedSumTests.java | 54 +++ .../trainedmodel/tree/TreeTests.java | 26 +- .../MlInferenceNamedXContentProvider.java | 33 ++ .../ml/inference/trainedmodel/TargetType.java | 36 ++ .../inference/trainedmodel/TrainedModel.java | 37 +- .../trainedmodel/ensemble/Ensemble.java | 315 +++++++++++++++ .../LenientlyParsedOutputAggregator.java | 10 + .../ensemble/OutputAggregator.java | 47 +++ .../StrictlyParsedOutputAggregator.java | 10 + .../trainedmodel/ensemble/WeightedMode.java | 145 +++++++ .../trainedmodel/ensemble/WeightedSum.java | 114 ++++++ .../ml/inference/trainedmodel/tree/Tree.java | 227 +++++++---- .../inference/trainedmodel/tree/TreeNode.java | 2 +- .../core/ml/inference/utils/Statistics.java | 52 +++ .../ml/utils/NamedXContentObjectHelper.java | 44 +++ .../inference/NamedXContentObjectsTests.java | 3 +- .../trainedmodel/ensemble/EnsembleTests.java | 368 ++++++++++++++++++ .../ensemble/WeightedAggregatorTests.java | 51 +++ .../ensemble/WeightedModeTests.java | 56 +++ .../ensemble/WeightedSumTests.java | 56 +++ .../trainedmodel/tree/TreeTests.java | 
97 +++-- .../ml/inference/utils/StatisticsTests.java | 26 ++ 35 files changed, 2450 insertions(+), 120 deletions(-) create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/NamedXContentObject.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/NamedXContentObjectHelper.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/TargetType.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/Ensemble.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/OutputAggregator.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedMode.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedSum.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedModeTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedSumTests.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TargetType.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LenientlyParsedOutputAggregator.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java create mode 100644 
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/StrictlyParsedOutputAggregator.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/NamedXContentObjectHelper.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedAggregatorTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/utils/StatisticsTests.java diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java index be7c3c00af2c2..2325bbf27baa0 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/MlInferenceNamedXContentProvider.java @@ -19,6 +19,10 @@ package org.elasticsearch.client.ml.inference; import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel; +import 
org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble; +import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.OutputAggregator; +import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode; +import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedSum; import org.elasticsearch.client.ml.inference.trainedmodel.tree.Tree; import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding; import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding; @@ -47,6 +51,15 @@ public List getNamedXContentParsers() { // Model namedXContent.add(new NamedXContentRegistry.Entry(TrainedModel.class, new ParseField(Tree.NAME), Tree::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(TrainedModel.class, new ParseField(Ensemble.NAME), Ensemble::fromXContent)); + + // Aggregating output + namedXContent.add(new NamedXContentRegistry.Entry(OutputAggregator.class, + new ParseField(WeightedMode.NAME), + WeightedMode::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(OutputAggregator.class, + new ParseField(WeightedSum.NAME), + WeightedSum::fromXContent)); return namedXContent; } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/NamedXContentObject.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/NamedXContentObject.java new file mode 100644 index 0000000000000..969add5254766 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/NamedXContentObject.java @@ -0,0 +1,34 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference; + +import org.elasticsearch.common.xcontent.ToXContentObject; + +/** + * Simple interface for XContent Objects that are named. + * + * This affords more general handling when serializing and de-serializing this type of XContent when it is used in a NamedObjects + * parser. + */ +public interface NamedXContentObject extends ToXContentObject { + /** + * @return The name of the XContentObject that is to be serialized + */ + String getName(); +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/NamedXContentObjectHelper.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/NamedXContentObjectHelper.java new file mode 100644 index 0000000000000..1795f5da49511 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/NamedXContentObjectHelper.java @@ -0,0 +1,57 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference; + +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; + +public final class NamedXContentObjectHelper { + + private NamedXContentObjectHelper() {} + + public static XContentBuilder writeNamedObjects(XContentBuilder builder, + ToXContent.Params params, + boolean useExplicitOrder, + String namedObjectsName, + List namedObjects) throws IOException { + if (useExplicitOrder) { + builder.startArray(namedObjectsName); + } else { + builder.startObject(namedObjectsName); + } + for (NamedXContentObject object : namedObjects) { + if (useExplicitOrder) { + builder.startObject(); + } + builder.field(object.getName(), object, params); + if (useExplicitOrder) { + builder.endObject(); + } + } + if (useExplicitOrder) { + builder.endArray(); + } else { + builder.endObject(); + } + return builder; + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/TargetType.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/TargetType.java new file mode 100644 index 0000000000000..694a72f1cc5f8 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/TargetType.java @@ -0,0 +1,35 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference.trainedmodel; + +import java.util.Locale; + +public enum TargetType { + + REGRESSION, CLASSIFICATION; + + public static TargetType fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/TrainedModel.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/TrainedModel.java index fb1f5c3b4ab92..43ff877089b51 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/TrainedModel.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/TrainedModel.java @@ -18,11 +18,11 @@ */ package org.elasticsearch.client.ml.inference.trainedmodel; -import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.client.ml.inference.NamedXContentObject; import java.util.List; -public interface TrainedModel extends ToXContentObject { +public interface TrainedModel extends NamedXContentObject { /** * @return List of featureNames expected by the model. 
In the order that they are expected diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/Ensemble.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/Ensemble.java new file mode 100644 index 0000000000000..89a3815e72810 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/Ensemble.java @@ -0,0 +1,188 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.client.ml.inference.trainedmodel.ensemble; + +import org.elasticsearch.client.ml.inference.NamedXContentObjectHelper; +import org.elasticsearch.client.ml.inference.trainedmodel.TargetType; +import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +public class Ensemble implements TrainedModel { + + public static final String NAME = "ensemble"; + public static final ParseField FEATURE_NAMES = new ParseField("feature_names"); + public static final ParseField TRAINED_MODELS = new ParseField("trained_models"); + public static final ParseField AGGREGATE_OUTPUT = new ParseField("aggregate_output"); + public static final ParseField TARGET_TYPE = new ParseField("target_type"); + public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels"); + + private static final ObjectParser PARSER = new ObjectParser<>( + NAME, + true, + Ensemble.Builder::new); + + static { + PARSER.declareStringArray(Ensemble.Builder::setFeatureNames, FEATURE_NAMES); + PARSER.declareNamedObjects(Ensemble.Builder::setTrainedModels, + (p, c, n) -> + p.namedObject(TrainedModel.class, n, null), + (ensembleBuilder) -> { /* Noop does not matter client side */ }, + TRAINED_MODELS); + PARSER.declareNamedObjects(Ensemble.Builder::setOutputAggregatorFromParser, + (p, c, n) -> p.namedObject(OutputAggregator.class, n, null), + (ensembleBuilder) -> { /* Noop does not matter client side */ }, + AGGREGATE_OUTPUT); + PARSER.declareString(Ensemble.Builder::setTargetType, TARGET_TYPE); + 
PARSER.declareStringArray(Ensemble.Builder::setClassificationLabels, CLASSIFICATION_LABELS); + } + + public static Ensemble fromXContent(XContentParser parser) { + return PARSER.apply(parser, null).build(); + } + + private final List featureNames; + private final List models; + private final OutputAggregator outputAggregator; + private final TargetType targetType; + private final List classificationLabels; + + Ensemble(List featureNames, + List models, + @Nullable OutputAggregator outputAggregator, + TargetType targetType, + @Nullable List classificationLabels) { + this.featureNames = featureNames; + this.models = models; + this.outputAggregator = outputAggregator; + this.targetType = targetType; + this.classificationLabels = classificationLabels; + } + + @Override + public List getFeatureNames() { + return featureNames; + } + + @Override + public String getName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + if (featureNames != null) { + builder.field(FEATURE_NAMES.getPreferredName(), featureNames); + } + if (models != null) { + NamedXContentObjectHelper.writeNamedObjects(builder, params, true, TRAINED_MODELS.getPreferredName(), models); + } + if (outputAggregator != null) { + NamedXContentObjectHelper.writeNamedObjects(builder, + params, + false, + AGGREGATE_OUTPUT.getPreferredName(), + Collections.singletonList(outputAggregator)); + } + if (targetType != null) { + builder.field(TARGET_TYPE.getPreferredName(), targetType); + } + if (classificationLabels != null) { + builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Ensemble that = (Ensemble) o; + return Objects.equals(featureNames, that.featureNames) + && 
Objects.equals(models, that.models) + && Objects.equals(targetType, that.targetType) + && Objects.equals(classificationLabels, that.classificationLabels) + && Objects.equals(outputAggregator, that.outputAggregator); + } + + @Override + public int hashCode() { + return Objects.hash(featureNames, models, outputAggregator, classificationLabels, targetType); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + List featureNames; + List trainedModels; + OutputAggregator outputAggregator; + TargetType targetType; + List classificationLabels; + + public Builder setFeatureNames(List featureNames) { + this.featureNames = featureNames; + return this; + } + + public Builder setTrainedModels(List trainedModels) { + this.trainedModels = trainedModels; + return this; + } + + public Builder setOutputAggregator(OutputAggregator outputAggregator) { + this.outputAggregator = outputAggregator; + return this; + } + + public Builder setTargetType(TargetType targetType) { + this.targetType = targetType; + return this; + } + + public Builder setClassificationLabels(List classificationLabels) { + this.classificationLabels = classificationLabels; + return this; + } + + private void setOutputAggregatorFromParser(List outputAggregators) { + this.setOutputAggregator(outputAggregators.get(0)); + } + + private void setTargetType(String targetType) { + this.targetType = TargetType.fromString(targetType); + } + + public Ensemble build() { + return new Ensemble(featureNames, trainedModels, outputAggregator, targetType, classificationLabels); + } + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/OutputAggregator.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/OutputAggregator.java new file mode 100644 index 0000000000000..955def1999ae3 --- /dev/null +++ 
b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/OutputAggregator.java @@ -0,0 +1,28 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference.trainedmodel.ensemble; + +import org.elasticsearch.client.ml.inference.NamedXContentObject; + +public interface OutputAggregator extends NamedXContentObject { + /** + * @return The name of the output aggregator + */ + String getName(); +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedMode.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedMode.java new file mode 100644 index 0000000000000..f5ad4a3f99ed5 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedMode.java @@ -0,0 +1,84 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. 
Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference.trainedmodel.ensemble; + + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + + +public class WeightedMode implements OutputAggregator { + + public static final String NAME = "weighted_mode"; + public static final ParseField WEIGHTS = new ParseField("weights"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME, + true, + a -> new WeightedMode((List)a[0])); + static { + PARSER.declareDoubleArray(ConstructingObjectParser.constructorArg(), WEIGHTS); + } + + public static WeightedMode fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final List weights; + + public WeightedMode(List weights) { + this.weights = weights; + } + + @Override + public String getName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + if (weights != null) { + builder.field(WEIGHTS.getPreferredName(), 
weights); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + WeightedMode that = (WeightedMode) o; + return Objects.equals(weights, that.weights); + } + + @Override + public int hashCode() { + return Objects.hash(weights); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedSum.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedSum.java new file mode 100644 index 0000000000000..d6132ca00fbc1 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedSum.java @@ -0,0 +1,84 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.client.ml.inference.trainedmodel.ensemble; + + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public class WeightedSum implements OutputAggregator { + + public static final String NAME = "weighted_sum"; + public static final ParseField WEIGHTS = new ParseField("weights"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME, + true, + a -> new WeightedSum((List)a[0])); + + static { + PARSER.declareDoubleArray(ConstructingObjectParser.constructorArg(), WEIGHTS); + } + + public static WeightedSum fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final List weights; + + public WeightedSum(List weights) { + this.weights = weights; + } + + @Override + public String getName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + if (weights != null) { + builder.field(WEIGHTS.getPreferredName(), weights); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + WeightedSum that = (WeightedSum) o; + return Objects.equals(weights, that.weights); + } + + @Override + public int hashCode() { + return Objects.hash(weights); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/Tree.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/Tree.java index 
de040ec6f9ed7..5a1e07f34e256 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/Tree.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/Tree.java @@ -18,7 +18,9 @@ */ package org.elasticsearch.client.ml.inference.trainedmodel.tree; +import org.elasticsearch.client.ml.inference.trainedmodel.TargetType; import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel; +import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.ObjectParser; @@ -28,7 +30,6 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.stream.Collectors; @@ -39,12 +40,16 @@ public class Tree implements TrainedModel { public static final ParseField FEATURE_NAMES = new ParseField("feature_names"); public static final ParseField TREE_STRUCTURE = new ParseField("tree_structure"); + public static final ParseField TARGET_TYPE = new ParseField("target_type"); + public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels"); private static final ObjectParser PARSER = new ObjectParser<>(NAME, true, Builder::new); static { PARSER.declareStringArray(Builder::setFeatureNames, FEATURE_NAMES); PARSER.declareObjectArray(Builder::setNodes, (p, c) -> TreeNode.fromXContent(p), TREE_STRUCTURE); + PARSER.declareString(Builder::setTargetType, TARGET_TYPE); + PARSER.declareStringArray(Builder::setClassificationLabels, CLASSIFICATION_LABELS); } public static Tree fromXContent(XContentParser parser) { @@ -53,10 +58,14 @@ public static Tree fromXContent(XContentParser parser) { private final List featureNames; private final List nodes; - - Tree(List featureNames, List nodes) { - this.featureNames = 
Collections.unmodifiableList(Objects.requireNonNull(featureNames)); - this.nodes = Collections.unmodifiableList(Objects.requireNonNull(nodes)); + private final TargetType targetType; + private final List classificationLabels; + + Tree(List featureNames, List nodes, TargetType targetType, List classificationLabels) { + this.featureNames = featureNames; + this.nodes = nodes; + this.targetType = targetType; + this.classificationLabels = classificationLabels; } @Override @@ -73,11 +82,30 @@ public List getNodes() { return nodes; } + @Nullable + public List getClassificationLabels() { + return classificationLabels; + } + + public TargetType getTargetType() { + return targetType; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(FEATURE_NAMES.getPreferredName(), featureNames); - builder.field(TREE_STRUCTURE.getPreferredName(), nodes); + if (featureNames != null) { + builder.field(FEATURE_NAMES.getPreferredName(), featureNames); + } + if (nodes != null) { + builder.field(TREE_STRUCTURE.getPreferredName(), nodes); + } + if (classificationLabels != null) { + builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels); + } + if (targetType != null) { + builder.field(TARGET_TYPE.getPreferredName(), targetType.toString()); + } builder.endObject(); return builder; } @@ -93,12 +121,14 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; Tree that = (Tree) o; return Objects.equals(featureNames, that.featureNames) + && Objects.equals(classificationLabels, that.classificationLabels) + && Objects.equals(targetType, that.targetType) && Objects.equals(nodes, that.nodes); } @Override public int hashCode() { - return Objects.hash(featureNames, nodes); + return Objects.hash(featureNames, nodes, targetType, classificationLabels); } public static Builder builder() { @@ -109,6 +139,8 @@ public static class Builder { private List 
featureNames; private ArrayList nodes; private int numNodes; + private TargetType targetType; + private List classificationLabels; public Builder() { nodes = new ArrayList<>(); @@ -137,6 +169,20 @@ public Builder setNodes(TreeNode.Builder... nodes) { return setNodes(Arrays.asList(nodes)); } + public Builder setTargetType(TargetType targetType) { + this.targetType = targetType; + return this; + } + + public Builder setClassificationLabels(List classificationLabels) { + this.classificationLabels = classificationLabels; + return this; + } + + private void setTargetType(String targetType) { + this.targetType = TargetType.fromString(targetType); + } + /** * Add a decision node. Space for the child nodes is allocated * @param nodeIndex Where to place the node. This is either 0 (root) or an existing child node index @@ -185,7 +231,9 @@ public Builder addLeaf(int nodeIndex, double value) { public Tree build() { return new Tree(featureNames, - nodes.stream().map(TreeNode.Builder::build).collect(Collectors.toList())); + nodes.stream().map(TreeNode.Builder::build).collect(Collectors.toList()), + targetType, + classificationLabels); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index 7641dd3032c83..d9c40bb19c055 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -65,6 +65,9 @@ import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric; import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric; +import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble; +import 
org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode; +import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedSum; import org.elasticsearch.client.ml.inference.trainedmodel.tree.Tree; import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding; import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding; @@ -681,7 +684,7 @@ public void testDefaultNamedXContents() { public void testProvidedNamedXContents() { List namedXContents = RestHighLevelClient.getProvidedNamedXContents(); - assertEquals(41, namedXContents.size()); + assertEquals(44, namedXContents.size()); Map, Integer> categories = new HashMap<>(); List names = new ArrayList<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) { @@ -691,7 +694,7 @@ public void testProvidedNamedXContents() { categories.put(namedXContent.categoryClass, counter + 1); } } - assertEquals("Had: " + categories, 11, categories.size()); + assertEquals("Had: " + categories, 12, categories.size()); assertEquals(Integer.valueOf(3), categories.get(Aggregation.class)); assertTrue(names.contains(ChildrenAggregationBuilder.NAME)); assertTrue(names.contains(MatrixStatsAggregationBuilder.NAME)); @@ -740,8 +743,11 @@ public void testProvidedNamedXContents() { RSquaredMetric.NAME)); assertEquals(Integer.valueOf(3), categories.get(org.elasticsearch.client.ml.inference.preprocessing.PreProcessor.class)); assertThat(names, hasItems(FrequencyEncoding.NAME, OneHotEncoding.NAME, TargetMeanEncoding.NAME)); - assertEquals(Integer.valueOf(1), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class)); - assertThat(names, hasItems(Tree.NAME)); + assertEquals(Integer.valueOf(2), categories.get(org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel.class)); + assertThat(names, hasItems(Tree.NAME, Ensemble.NAME)); + assertEquals(Integer.valueOf(2), + 
categories.get(org.elasticsearch.client.ml.inference.trainedmodel.ensemble.OutputAggregator.class)); + assertThat(names, hasItems(WeightedMode.NAME, WeightedSum.NAME)); } public void testApiNamingConventions() throws Exception { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java new file mode 100644 index 0000000000000..91d7816f33c4d --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -0,0 +1,99 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.elasticsearch.client.ml.inference.trainedmodel.ensemble;

import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.client.ml.inference.trainedmodel.tree.TreeTests;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.AbstractXContentTestCase;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.function.Predicate;

/**
 * Round-trip xContent tests for the client-side {@link Ensemble} model.
 */
public class EnsembleTests extends AbstractXContentTestCase<Ensemble> {

    @Override
    protected boolean supportsUnknownFields() {
        return true;
    }

    @Override
    protected Predicate<String> getRandomFieldsExcludeFilter() {
        // Only inject unknown fields at the top level; nested model/aggregator
        // objects are parsed by named-xcontent parsers of their own.
        return field -> field.isEmpty() == false;
    }

    @Override
    protected Ensemble doParseInstance(XContentParser parser) throws IOException {
        return Ensemble.fromXContent(parser);
    }

    public static Ensemble createRandom() {
        // All member models must share one ordered feature-name list.
        int featureCount = randomIntBetween(1, 10);
        List<String> names = new ArrayList<>();
        for (int i = 0; i < featureCount; i++) {
            names.add(randomAlphaOfLength(10));
        }

        int modelCount = randomIntBetween(1, 10);
        List<TrainedModel> members = new ArrayList<>(modelCount);
        for (int i = 0; i < modelCount; i++) {
            members.add(TreeTests.buildRandomTree(names, 6));
        }

        // The aggregator is optional; when present it carries one weight per model.
        OutputAggregator aggregator = null;
        if (randomBoolean()) {
            List<Double> weights = new ArrayList<>(modelCount);
            for (int i = 0; i < modelCount; i++) {
                weights.add(randomDouble());
            }
            aggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights));
        }

        // Classification labels are likewise optional.
        List<String> labels = null;
        if (randomBoolean()) {
            labels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
        }

        return new Ensemble(names,
            members,
            aggregator,
            randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION),
            labels);
    }

    @Override
    protected Ensemble createTestInstance() {
        return createRandom();
    }

    @Override
    protected NamedXContentRegistry xContentRegistry() {
        // Register the ML inference parsers plus the default search ones so
        // nested trained models and aggregators can be resolved by name.
        List<NamedXContentRegistry.Entry> entries = new ArrayList<>();
        entries.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
        entries.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents());
        return new NamedXContentRegistry(entries);
    }

}
/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.elasticsearch.client.ml.inference.trainedmodel.ensemble;

import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
 * Round-trip xContent tests for the client-side {@link WeightedMode} aggregator.
 *
 * Note: this test must exercise the client-package WeightedMode. The original
 * version imported the x-pack core class
 * (org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode)
 * and called fromXContentLenient, which is the server-side API; the client
 * class exposes fromXContent, matching WeightedSumTests.
 */
public class WeightedModeTests extends AbstractXContentTestCase<WeightedMode> {

    /** Builds an instance with the given number of random weights. */
    WeightedMode createTestInstance(int numberOfWeights) {
        List<Double> weights = new ArrayList<>(numberOfWeights);
        for (int i = 0; i < numberOfWeights; i++) {
            weights.add(randomDouble());
        }
        return new WeightedMode(weights);
    }

    @Override
    protected WeightedMode doParseInstance(XContentParser parser) throws IOException {
        // Use the client parser under test, not the x-pack core one.
        return WeightedMode.fromXContent(parser);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return true;
    }

    @Override
    protected WeightedMode createTestInstance() {
        return createTestInstance(randomIntBetween(1, 100));
    }

}
/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.elasticsearch.client.ml.inference.trainedmodel.ensemble;

import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
 * Round-trip xContent tests for the client-side {@link WeightedSum} aggregator.
 */
public class WeightedSumTests extends AbstractXContentTestCase<WeightedSum> {

    /** Builds an instance carrying {@code count} random weights. */
    WeightedSum createTestInstance(int count) {
        List<Double> weights = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            weights.add(randomDouble());
        }
        return new WeightedSum(weights);
    }

    @Override
    protected WeightedSum doParseInstance(XContentParser parser) throws IOException {
        return WeightedSum.fromXContent(parser);
    }

    @Override
    protected boolean supportsUnknownFields() {
        return true;
    }

    @Override
    protected WeightedSum createTestInstance() {
        return createTestInstance(randomIntBetween(1, 100));
    }

}
org.elasticsearch.client.ml.inference.trainedmodel.TargetType; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractXContentTestCase; import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.function.Predicate; @@ -50,16 +52,17 @@ protected Tree createTestInstance() { } public static Tree createRandom() { - return buildRandomTree(randomIntBetween(2, 15), 6); + int numberOfFeatures = randomIntBetween(1, 10); + List featureNames = new ArrayList<>(); + for (int i = 0; i < numberOfFeatures; i++) { + featureNames.add(randomAlphaOfLength(10)); + } + return buildRandomTree(featureNames, 6); } - public static Tree buildRandomTree(int numFeatures, int depth) { - + public static Tree buildRandomTree(List featureNames, int depth) { + int numFeatures = featureNames.size(); Tree.Builder builder = Tree.builder(); - List featureNames = new ArrayList<>(numFeatures); - for(int i = 0; i < numFeatures; i++) { - featureNames.add(randomAlphaOfLength(10)); - } builder.setFeatureNames(featureNames); TreeNode.Builder node = builder.addJunction(0, randomInt(numFeatures), true, randomDouble()); @@ -80,8 +83,13 @@ public static Tree buildRandomTree(int numFeatures, int depth) { } childNodes = nextNodes; } - - return builder.build(); + List categoryLabels = null; + if (randomBoolean()) { + categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false)); + } + return builder.setClassificationLabels(categoryLabels) + .setTargetType(randomFrom(TargetType.REGRESSION, TargetType.CLASSIFICATION)) + .build(); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java index 7f14077a1504e..7fff4d6abbd3b 100644 --- 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java @@ -11,6 +11,12 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LenientlyParsedOutputAggregator; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.StrictlyParsedOutputAggregator; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding; import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor; @@ -46,9 +52,27 @@ public List getNamedXContentParsers() { // Model Lenient namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedTrainedModel.class, Tree.NAME, Tree::fromXContentLenient)); + namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedTrainedModel.class, Ensemble.NAME, Ensemble::fromXContentLenient)); + + // Output Aggregator Lenient + namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedOutputAggregator.class, + WeightedMode.NAME, + WeightedMode::fromXContentLenient)); + namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedOutputAggregator.class, + WeightedSum.NAME, + WeightedSum::fromXContentLenient)); // Model Strict 
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedTrainedModel.class, Tree.NAME, Tree::fromXContentStrict)); + namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedTrainedModel.class, Ensemble.NAME, Ensemble::fromXContentStrict)); + + // Output Aggregator Strict + namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedOutputAggregator.class, + WeightedMode.NAME, + WeightedMode::fromXContentStrict)); + namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedOutputAggregator.class, + WeightedSum.NAME, + WeightedSum::fromXContentStrict)); return namedXContent; } @@ -66,6 +90,15 @@ public List getNamedWriteables() { // Model namedWriteables.add(new NamedWriteableRegistry.Entry(TrainedModel.class, Tree.NAME.getPreferredName(), Tree::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(TrainedModel.class, Ensemble.NAME.getPreferredName(), Ensemble::new)); + + // Output Aggregator + namedWriteables.add(new NamedWriteableRegistry.Entry(OutputAggregator.class, + WeightedSum.NAME.getPreferredName(), + WeightedSum::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(OutputAggregator.class, + WeightedMode.NAME.getPreferredName(), + WeightedMode::new)); return namedWriteables; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TargetType.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TargetType.java new file mode 100644 index 0000000000000..9897231f5911a --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TargetType.java @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;

import java.io.IOException;
import java.util.Locale;

/**
 * The kind of value a trained model predicts: a continuous value
 * ({@link #REGRESSION}) or a class label ({@link #CLASSIFICATION}).
 *
 * Serialized over the wire as an enum ordinal and rendered in xContent
 * as its lower-case name.
 */
public enum TargetType implements Writeable {

    REGRESSION, CLASSIFICATION;

    /**
     * Case-insensitive, whitespace-tolerant lookup.
     * @throws IllegalArgumentException if the name matches no constant
     */
    public static TargetType fromString(String name) {
        return valueOf(name.trim().toUpperCase(Locale.ROOT));
    }

    /** Reads a value previously written by {@link #writeTo(StreamOutput)}. */
    public static TargetType fromStream(StreamInput in) throws IOException {
        return in.readEnum(TargetType.class);
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeEnum(this);
    }

    /** Lower-case form used in xContent (e.g. "regression"). */
    @Override
    public String toString() {
        return name().toLowerCase(Locale.ROOT);
    }
}
*/ - boolean isClassification(); + double infer(List fields); + + /** + * @return {@link TargetType} for the model. + */ + TargetType targetType(); /** * This gathers the probabilities for each potential classification value. * + * The probabilities are indexed by classification ordinal label encoding. + * The length of this list is equal to the number of classification labels. + * * This only should return if the implementation model is inferring classification values and not regression * @param fields The fields and their values to infer against * @return The probabilities of each classification value */ - List inferProbabilities(Map fields); + List classificationProbability(Map fields); + + /** + * @param fields similar to {@link TrainedModel#classificationProbability(Map)} but the fields are already in order and doubles + * @return The probabilities of each classification value + */ + List classificationProbability(List fields); + /** + * The ordinal encoded list of the classification labels. + * @return Oridinal encoded list of classification labels. + */ + @Nullable + List classificationLabels(); + + /** + * Runs validations against the model. + * + * Example: {@link org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree} should check if there are any loops + * + * @throws org.elasticsearch.ElasticsearchException if validations fail + */ + void validate(); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java new file mode 100644 index 0000000000000..05ff090839acd --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java @@ -0,0 +1,315 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;

import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

/**
 * A trained model composed of member models whose individual inferences are
 * combined by an optional {@link OutputAggregator}. All members must use the
 * same ordered list of feature names (see {@link #validate()}).
 */
public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {

    public static final ParseField NAME = new ParseField("ensemble");
    public static final ParseField FEATURE_NAMES = new ParseField("feature_names");
    public static final ParseField TRAINED_MODELS = new ParseField("trained_models");
    public static final ParseField AGGREGATE_OUTPUT = new ParseField("aggregate_output");
    public static final ParseField TARGET_TYPE = new ParseField("target_type");
    public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");

    private static final ObjectParser<Ensemble.Builder, Void> LENIENT_PARSER = createParser(true);
    private static final ObjectParser<Ensemble.Builder, Void> STRICT_PARSER = createParser(false);

    private static ObjectParser<Ensemble.Builder, Void> createParser(boolean lenient) {
        ObjectParser<Ensemble.Builder, Void> parser = new ObjectParser<>(
            NAME.getPreferredName(),
            lenient,
            Ensemble.Builder::builderForParser);
        parser.declareStringArray(Ensemble.Builder::setFeatureNames, FEATURE_NAMES);
        parser.declareNamedObjects(Ensemble.Builder::setTrainedModels,
            (p, c, n) ->
                lenient ? p.namedObject(LenientlyParsedTrainedModel.class, n, null) :
                    p.namedObject(StrictlyParsedTrainedModel.class, n, null),
            // The ordered-mode callback only fires when the named objects arrive
            // as an array; build() uses this flag to reject unordered input.
            (ensembleBuilder) -> ensembleBuilder.setModelsAreOrdered(true),
            TRAINED_MODELS);
        parser.declareNamedObjects(Ensemble.Builder::setOutputAggregatorFromParser,
            (p, c, n) ->
                lenient ? p.namedObject(LenientlyParsedOutputAggregator.class, n, null) :
                    p.namedObject(StrictlyParsedOutputAggregator.class, n, null),
            // No-op: the aggregator may be written as an array or an object;
            // either way exactly one entry is required (checked in the setter).
            (ensembleBuilder) -> {},
            AGGREGATE_OUTPUT);
        parser.declareString(Ensemble.Builder::setTargetType, TARGET_TYPE);
        parser.declareStringArray(Ensemble.Builder::setClassificationLabels, CLASSIFICATION_LABELS);
        return parser;
    }

    public static Ensemble fromXContentStrict(XContentParser parser) {
        return STRICT_PARSER.apply(parser, null).build();
    }

    public static Ensemble fromXContentLenient(XContentParser parser) {
        return LENIENT_PARSER.apply(parser, null).build();
    }

    private final List<String> featureNames;
    private final List<TrainedModel> models;
    private final OutputAggregator outputAggregator;
    private final TargetType targetType;
    private final List<String> classificationLabels;

    Ensemble(List<String> featureNames,
             List<TrainedModel> models,
             @Nullable OutputAggregator outputAggregator,
             TargetType targetType,
             @Nullable List<String> classificationLabels) {
        this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES));
        this.models = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(models, TRAINED_MODELS));
        this.outputAggregator = outputAggregator;
        this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE);
        this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels);
    }

    public Ensemble(StreamInput in) throws IOException {
        this.featureNames = Collections.unmodifiableList(in.readStringList());
        this.models = Collections.unmodifiableList(in.readNamedWriteableList(TrainedModel.class));
        this.outputAggregator = in.readOptionalNamedWriteable(OutputAggregator.class);
        this.targetType = TargetType.fromStream(in);
        if (in.readBoolean()) {
            // Wrap unmodifiable for consistency with the primary constructor;
            // previously this field escaped as a mutable list on this path.
            this.classificationLabels = Collections.unmodifiableList(in.readStringList());
        } else {
            this.classificationLabels = null;
        }
    }

    @Override
    public List<String> getFeatureNames() {
        return featureNames;
    }

    @Override
    public double infer(Map<String, Object> fields) {
        // Project the map onto the ordered feature list. A missing feature yields
        // a null entry here; member models are assumed to handle that — TODO confirm.
        List<Double> features = featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList());
        return infer(features);
    }

    @Override
    public double infer(List<Double> fields) {
        List<Double> processedInferences = inferAndProcess(fields);
        if (outputAggregator != null) {
            return outputAggregator.aggregate(processedInferences);
        }
        // No aggregator configured: fall back to a plain sum of member outputs.
        return processedInferences.stream().mapToDouble(Double::doubleValue).sum();
    }

    @Override
    public TargetType targetType() {
        return targetType;
    }

    @Override
    public List<Double> classificationProbability(Map<String, Object> fields) {
        if ((targetType == TargetType.CLASSIFICATION) == false) {
            throw new UnsupportedOperationException(
                "Cannot determine classification probability with target_type [" + targetType.toString() + "]");
        }
        List<Double> features = featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList());
        return classificationProbability(features);
    }

    @Override
    public List<Double> classificationProbability(List<Double> fields) {
        if ((targetType == TargetType.CLASSIFICATION) == false) {
            throw new UnsupportedOperationException(
                "Cannot determine classification probability with target_type [" + targetType.toString() + "]");
        }
        return inferAndProcess(fields);
    }

    @Override
    public List<String> classificationLabels() {
        return classificationLabels;
    }

    /** Runs every member model, then lets the aggregator pre-process the raw outputs. */
    private List<Double> inferAndProcess(List<Double> fields) {
        List<Double> modelInferences = models.stream().map(m -> m.infer(fields)).collect(Collectors.toList());
        if (outputAggregator != null) {
            return outputAggregator.processValues(modelInferences);
        }
        return modelInferences;
    }

    @Override
    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeStringCollection(featureNames);
        out.writeNamedWriteableList(models);
        out.writeOptionalNamedWriteable(outputAggregator);
        targetType.writeTo(out);
        out.writeBoolean(classificationLabels != null);
        if (classificationLabels != null) {
            out.writeStringCollection(classificationLabels);
        }
    }

    @Override
    public String getName() {
        return NAME.getPreferredName();
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(FEATURE_NAMES.getPreferredName(), featureNames);
        NamedXContentObjectHelper.writeNamedObjects(builder, params, true, TRAINED_MODELS.getPreferredName(), models);
        if (outputAggregator != null) {
            NamedXContentObjectHelper.writeNamedObjects(builder,
                params,
                false,
                AGGREGATE_OUTPUT.getPreferredName(),
                Collections.singletonList(outputAggregator));
        }
        builder.field(TARGET_TYPE.getPreferredName(), targetType.toString());
        if (classificationLabels != null) {
            builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
        }
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        Ensemble that = (Ensemble) o;
        return Objects.equals(featureNames, that.featureNames)
            && Objects.equals(models, that.models)
            && Objects.equals(targetType, that.targetType)
            && Objects.equals(classificationLabels, that.classificationLabels)
            && Objects.equals(outputAggregator, that.outputAggregator);
    }

    @Override
    public int hashCode() {
        return Objects.hash(featureNames, models, outputAggregator, targetType, classificationLabels);
    }

    /**
     * Checks that every member model uses the same ordered feature list, that
     * the aggregator's expected value count (when declared) matches the model
     * count, and recursively validates each member.
     *
     * @throws org.elasticsearch.ElasticsearchException if any check fails
     */
    @Override
    public void validate() {
        if (this.featureNames != null) {
            if (this.models.stream()
                .anyMatch(trainedModel -> trainedModel.getFeatureNames().equals(this.featureNames) == false)) {
                throw ExceptionsHelper.badRequestException(
                    "[{}] must be the same and in the same order for each of the {}",
                    FEATURE_NAMES.getPreferredName(),
                    TRAINED_MODELS.getPreferredName());
            }
        }
        if (outputAggregator != null &&
            outputAggregator.expectedValueSize() != null &&
            outputAggregator.expectedValueSize() != models.size()) {
            throw ExceptionsHelper.badRequestException(
                "[{}] expects value array of size [{}] but number of models is [{}]",
                AGGREGATE_OUTPUT.getPreferredName(),
                outputAggregator.expectedValueSize(),
                models.size());
        }
        this.models.forEach(TrainedModel::validate);
    }

    public static Builder builder() {
        return new Builder();
    }

    public static class Builder {
        List<String> featureNames;
        List<TrainedModel> trainedModels;
        OutputAggregator outputAggregator;
        TargetType targetType = TargetType.REGRESSION;
        List<String> classificationLabels;
        boolean modelsAreOrdered;

        private Builder(boolean modelsAreOrdered) {
            this.modelsAreOrdered = modelsAreOrdered;
        }

        // Parser entry point: ordering must be proven by the array callback.
        private static Builder builderForParser() {
            return new Builder(false);
        }

        // Programmatic use: the caller supplies an ordered List directly.
        public Builder() {
            this(true);
        }

        public Builder setFeatureNames(List<String> featureNames) {
            this.featureNames = featureNames;
            return this;
        }

        public Builder setTrainedModels(List<TrainedModel> trainedModels) {
            this.trainedModels = trainedModels;
            return this;
        }

        public Builder setOutputAggregator(OutputAggregator outputAggregator) {
            this.outputAggregator = outputAggregator;
            return this;
        }

        public Builder setTargetType(TargetType targetType) {
            this.targetType = targetType;
            return this;
        }

        public Builder setClassificationLabels(List<String> classificationLabels) {
            this.classificationLabels = classificationLabels;
            return this;
        }

        private void setOutputAggregatorFromParser(List<OutputAggregator> outputAggregators) {
            if ((outputAggregators.size() == 1) == false) {
                throw ExceptionsHelper.badRequestException("[{}] must have exactly one aggregator defined.",
                    AGGREGATE_OUTPUT.getPreferredName());
            }
            this.setOutputAggregator(outputAggregators.get(0));
        }

        private void setTargetType(String targetType) {
            this.targetType = TargetType.fromString(targetType);
        }

        private void setModelsAreOrdered(boolean value) {
            this.modelsAreOrdered = value;
        }

        public Ensemble build() {
            // This is essentially a serialization error but the underlying xcontent parsing does not allow us to inject this requirement
            // So, we verify the models were parsed in an ordered fashion here instead.
            if (modelsAreOrdered == false && trainedModels != null && trainedModels.size() > 1) {
                throw ExceptionsHelper.badRequestException("[trained_models] needs to be an array of objects");
            }
            return new Ensemble(featureNames, trainedModels, outputAggregator, targetType, classificationLabels);
        }
    }
}
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;

import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;

import java.util.List;

/**
 * Strategy that reduces the individual model outputs of an ensemble into a
 * single inferred value. Implementations are named writeables so they can be
 * serialized over the wire and parsed back from XContent by name.
 */
public interface OutputAggregator extends NamedXContentObject, NamedWriteable {

    /**
     * @return The expected size of the values array when aggregating. {@code null} implies there is no expected size.
     */
    Integer expectedValueSize();

    /**
     * This pre-processes the values so that they may be passed directly to the {@link OutputAggregator#aggregate(List)} method.
     *
     * Two major types of pre-processed values could be returned:
     *  - The confidence/probability scaled values given the input values (see {@link WeightedMode#processValues(List)})
     *  - A simple transformation of the passed values in preparation for aggregation (see {@link WeightedSum#processValues(List)})
     *
     * @param values the values to process
     * @return A new list containing the processed values or the same list if no processing is required
     */
    List<Double> processValues(List<Double> values);

    /**
     * Function to aggregate the processed values into a single double.
     *
     * This may be as simple as returning the index of the maximum value.
     * Or as complex as a mathematical reduction of all the passed values (i.e. summation, average, etc.).
     *
     * @param processedValues The values to aggregate
     * @return the aggregated value
     */
    double aggregate(List<Double> processedValues);

    /**
     * @return The name of the output aggregator
     */
    String getName();
}
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.softMax;

/**
 * Aggregates classification votes from ensemble members via a weighted mode:
 * each member's (whole-valued) output is treated as a class label, the member
 * weights accumulate per label, and the softmax of the accumulated weights is
 * the per-label confidence.
 */
public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator {

    public static final ParseField NAME = new ParseField("weighted_mode");
    public static final ParseField WEIGHTS = new ParseField("weights");

    private static final ConstructingObjectParser<WeightedMode, Void> LENIENT_PARSER = createParser(true);
    private static final ConstructingObjectParser<WeightedMode, Void> STRICT_PARSER = createParser(false);

    @SuppressWarnings("unchecked")
    private static ConstructingObjectParser<WeightedMode, Void> createParser(boolean lenient) {
        ConstructingObjectParser<WeightedMode, Void> parser = new ConstructingObjectParser<>(
            NAME.getPreferredName(),
            lenient,
            a -> new WeightedMode((List<Double>) a[0]));
        parser.declareDoubleArray(ConstructingObjectParser.constructorArg(), WEIGHTS);
        return parser;
    }

    public static WeightedMode fromXContentStrict(XContentParser parser) {
        return STRICT_PARSER.apply(parser, null);
    }

    public static WeightedMode fromXContentLenient(XContentParser parser) {
        return LENIENT_PARSER.apply(parser, null);
    }

    // One weight per ensemble member; immutable after construction.
    private final List<Double> weights;

    public WeightedMode(List<Double> weights) {
        this.weights = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(weights, WEIGHTS.getPreferredName()));
    }

    public WeightedMode(StreamInput in) throws IOException {
        this.weights = Collections.unmodifiableList(in.readList(StreamInput::readDouble));
    }

    @Override
    public Integer expectedValueSize() {
        return this.weights.size();
    }

    /**
     * Accumulates member weights per class label and returns the softmax of the
     * accumulated weights (one entry per label, 0..maxLabel).
     *
     * @param values one whole, finite, non-negative class label per member
     * @throws IllegalArgumentException if sizes mismatch or a value is null/fractional/non-finite/negative
     */
    @Override
    public List<Double> processValues(List<Double> values) {
        Objects.requireNonNull(values, "values must not be null");
        if (values.size() != weights.size()) {
            throw new IllegalArgumentException("values must be the same length as weights.");
        }
        List<Integer> freqArray = new ArrayList<>();
        int maxVal = 0;
        for (Double value : values) {
            if (value == null) {
                throw new IllegalArgumentException("values must not contain null values");
            }
            if (Double.isNaN(value) || Double.isInfinite(value) || value < 0.0 || value != Math.rint(value)) {
                throw new IllegalArgumentException("values must be whole, non-infinite, and positive");
            }
            int integerValue = value.intValue();
            freqArray.add(integerValue);
            if (integerValue > maxVal) {
                maxVal = integerValue;
            }
        }
        // NEGATIVE_INFINITY marks "no votes yet"; softMax maps it to 0.0.
        List<Double> frequencies = new ArrayList<>(Collections.nCopies(maxVal + 1, Double.NEGATIVE_INFINITY));
        for (int i = 0; i < freqArray.size(); i++) {
            double weight = weights.get(i);
            int value = freqArray.get(i);
            double frequency = frequencies.get(value) == Double.NEGATIVE_INFINITY ? weight : frequencies.get(value) + weight;
            frequencies.set(value, frequency);
        }
        return softMax(frequencies);
    }

    /**
     * @return the index (class label) of the largest processed value
     */
    @Override
    public double aggregate(List<Double> values) {
        int bestValue = 0;
        double bestFreq = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < values.size(); i++) {
            if (values.get(i) > bestFreq) {
                bestFreq = values.get(i);
                bestValue = i;
            }
        }
        return bestValue;
    }

    @Override
    public String getName() {
        return NAME.getPreferredName();
    }

    @Override
    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeCollection(weights, StreamOutput::writeDouble);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(WEIGHTS.getPreferredName(), weights);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        WeightedMode that = (WeightedMode) o;
        return Objects.equals(weights, that.weights);
    }

    @Override
    public int hashCode() {
        return Objects.hash(weights);
    }
}
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Aggregates regression outputs of ensemble members by scaling each output
 * with its member weight and summing the results.
 */
public class WeightedSum implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator {

    public static final ParseField NAME = new ParseField("weighted_sum");
    public static final ParseField WEIGHTS = new ParseField("weights");

    private static final ConstructingObjectParser<WeightedSum, Void> LENIENT_PARSER = createParser(true);
    private static final ConstructingObjectParser<WeightedSum, Void> STRICT_PARSER = createParser(false);

    @SuppressWarnings("unchecked")
    private static ConstructingObjectParser<WeightedSum, Void> createParser(boolean lenient) {
        ConstructingObjectParser<WeightedSum, Void> parser = new ConstructingObjectParser<>(
            NAME.getPreferredName(),
            lenient,
            a -> new WeightedSum((List<Double>) a[0]));
        parser.declareDoubleArray(ConstructingObjectParser.constructorArg(), WEIGHTS);
        return parser;
    }

    public static WeightedSum fromXContentStrict(XContentParser parser) {
        return STRICT_PARSER.apply(parser, null);
    }

    public static WeightedSum fromXContentLenient(XContentParser parser) {
        return LENIENT_PARSER.apply(parser, null);
    }

    // One weight per ensemble member; immutable after construction.
    private final List<Double> weights;

    public WeightedSum(List<Double> weights) {
        this.weights = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(weights, WEIGHTS.getPreferredName()));
    }

    public WeightedSum(StreamInput in) throws IOException {
        this.weights = Collections.unmodifiableList(in.readList(StreamInput::readDouble));
    }

    /**
     * Scales each value by its corresponding weight.
     *
     * @throws IllegalArgumentException if the value count differs from the weight count
     */
    @Override
    public List<Double> processValues(List<Double> values) {
        Objects.requireNonNull(values, "values must not be null");
        if (values.size() != weights.size()) {
            throw new IllegalArgumentException("values must be the same length as weights.");
        }
        return IntStream.range(0, weights.size()).mapToDouble(i -> values.get(i) * weights.get(i)).boxed().collect(Collectors.toList());
    }

    /**
     * Sums the already-weighted values. Uses a primitive sum so an empty list
     * yields 0.0 instead of Optional.get() throwing NoSuchElementException.
     */
    @Override
    public double aggregate(List<Double> values) {
        return values.stream().mapToDouble(Double::doubleValue).sum();
    }

    @Override
    public String getName() {
        return NAME.getPreferredName();
    }

    @Override
    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeCollection(weights, StreamOutput::writeDouble);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(WEIGHTS.getPreferredName(), weights);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        WeightedSum that = (WeightedSum) o;
        return Objects.equals(weights, that.weights);
    }

    @Override
    public int hashCode() {
        return Objects.hash(weights);
    }

    @Override
    public Integer expectedValueSize() {
        return this.weights.size();
    }
}
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java @@ -9,11 +9,13 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.util.CachedSupplier; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; @@ -35,6 +37,8 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM public static final ParseField FEATURE_NAMES = new ParseField("feature_names"); public static final ParseField TREE_STRUCTURE = new ParseField("tree_structure"); + public static final ParseField TARGET_TYPE = new ParseField("target_type"); + public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels"); private static final ObjectParser LENIENT_PARSER = createParser(true); private static final ObjectParser STRICT_PARSER = createParser(false); @@ -46,6 +50,8 @@ private static ObjectParser createParser(boolean lenient) { Tree.Builder::new); parser.declareStringArray(Tree.Builder::setFeatureNames, FEATURE_NAMES); parser.declareObjectArray(Tree.Builder::setNodes, (p, c) -> TreeNode.fromXContent(p, lenient), TREE_STRUCTURE); + parser.declareString(Tree.Builder::setTargetType, TARGET_TYPE); + parser.declareStringArray(Tree.Builder::setClassificationLabels, CLASSIFICATION_LABELS); return parser; } @@ -59,15 +65,28 @@ public static Tree fromXContentLenient(XContentParser parser) { private final List featureNames; 
private final List nodes; + private final TargetType targetType; + private final List classificationLabels; + private final CachedSupplier highestOrderCategory; - Tree(List featureNames, List nodes) { + Tree(List featureNames, List nodes, TargetType targetType, List classificationLabels) { this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES)); this.nodes = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE)); + this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE); + this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels); + this.highestOrderCategory = new CachedSupplier<>(() -> this.maxLeafValue()); } public Tree(StreamInput in) throws IOException { this.featureNames = Collections.unmodifiableList(in.readStringList()); this.nodes = Collections.unmodifiableList(in.readList(TreeNode::new)); + this.targetType = TargetType.fromStream(in); + if (in.readBoolean()) { + this.classificationLabels = Collections.unmodifiableList(in.readStringList()); + } else { + this.classificationLabels = null; + } + this.highestOrderCategory = new CachedSupplier<>(() -> this.maxLeafValue()); } @Override @@ -90,7 +109,8 @@ public double infer(Map fields) { return infer(features); } - private double infer(List features) { + @Override + public double infer(List features) { TreeNode node = nodes.get(0); while(node.isLeaf() == false) { node = nodes.get(node.compare(features)); @@ -115,13 +135,40 @@ public List trace(List features) { } @Override - public boolean isClassification() { - return false; + public TargetType targetType() { + return targetType; + } + + @Override + public List classificationProbability(Map fields) { + if ((targetType == TargetType.CLASSIFICATION) == false) { + throw new UnsupportedOperationException( + "Cannot determine classification probability with target_type [" + targetType.toString() + "]"); + } 
+ return classificationProbability(featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList())); + } + + @Override + public List classificationProbability(List fields) { + if ((targetType == TargetType.CLASSIFICATION) == false) { + throw new UnsupportedOperationException( + "Cannot determine classification probability with target_type [" + targetType.toString() + "]"); + } + double label = infer(fields); + // If we are classification, we should assume that the inference return value is whole. + assert label == Math.rint(label); + double maxCategory = this.highestOrderCategory.get(); + // If we are classification, we should assume that the largest leaf value is whole. + assert maxCategory == Math.rint(maxCategory); + List list = new ArrayList<>(Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0)); + // TODO, eventually have TreeNodes contain confidence levels + list.set(Double.valueOf(label).intValue(), 1.0); + return list; } @Override - public List inferProbabilities(Map fields) { - throw new UnsupportedOperationException("Cannot infer probabilities against a regression model."); + public List classificationLabels() { + return classificationLabels; } @Override @@ -133,6 +180,11 @@ public String getWriteableName() { public void writeTo(StreamOutput out) throws IOException { out.writeStringCollection(featureNames); out.writeCollection(nodes); + targetType.writeTo(out); + out.writeBoolean(classificationLabels != null); + if (classificationLabels != null) { + out.writeStringCollection(classificationLabels); + } } @Override @@ -140,6 +192,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); builder.field(FEATURE_NAMES.getPreferredName(), featureNames); builder.field(TREE_STRUCTURE.getPreferredName(), nodes); + builder.field(TARGET_TYPE.getPreferredName(), targetType.toString()); + if(classificationLabels != null) { + builder.field(CLASSIFICATION_LABELS.getPreferredName(), 
classificationLabels); + } builder.endObject(); return builder; } @@ -155,22 +211,91 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; Tree that = (Tree) o; return Objects.equals(featureNames, that.featureNames) - && Objects.equals(nodes, that.nodes); + && Objects.equals(nodes, that.nodes) + && Objects.equals(targetType, that.targetType) + && Objects.equals(classificationLabels, that.classificationLabels); } @Override public int hashCode() { - return Objects.hash(featureNames, nodes); + return Objects.hash(featureNames, nodes, targetType, classificationLabels); } public static Builder builder() { return new Builder(); } + @Override + public void validate() { + detectNullOrMissingNode(); + detectCycle(); + } + + private void detectCycle() { + if (nodes.isEmpty()) { + return; + } + Set visited = new HashSet<>(); + Queue toVisit = new ArrayDeque<>(nodes.size()); + toVisit.add(0); + while(toVisit.isEmpty() == false) { + Integer nodeIdx = toVisit.remove(); + if (visited.contains(nodeIdx)) { + throw ExceptionsHelper.badRequestException("[tree] contains cycle at node {}", nodeIdx); + } + visited.add(nodeIdx); + TreeNode treeNode = nodes.get(nodeIdx); + if (treeNode.getLeftChild() >= 0) { + toVisit.add(treeNode.getLeftChild()); + } + if (treeNode.getRightChild() >= 0) { + toVisit.add(treeNode.getRightChild()); + } + } + } + + private void detectNullOrMissingNode() { + if (nodes.isEmpty()) { + return; + } + + List missingNodes = new ArrayList<>(); + for (int i = 0; i < nodes.size(); i++) { + TreeNode currentNode = nodes.get(i); + if (currentNode == null) { + continue; + } + if (nodeMissing(currentNode.getLeftChild(), nodes)) { + missingNodes.add(currentNode.getLeftChild()); + } + if (nodeMissing(currentNode.getRightChild(), nodes)) { + missingNodes.add(currentNode.getRightChild()); + } + } + if (missingNodes.isEmpty() == false) { + throw ExceptionsHelper.badRequestException("[tree] contains missing nodes {}", missingNodes); + } 
+ } + + private static boolean nodeMissing(int nodeIdx, List nodes) { + if (nodeIdx < 0) { + return false; + } + return nodeIdx >= nodes.size(); + } + + private Double maxLeafValue() { + return targetType == TargetType.CLASSIFICATION ? + this.nodes.stream().filter(TreeNode::isLeaf).mapToDouble(TreeNode::getLeafValue).max().getAsDouble() : + null; + } + public static class Builder { private List featureNames; private ArrayList nodes; private int numNodes; + private TargetType targetType = TargetType.REGRESSION; + private List classificationLabels; public Builder() { nodes = new ArrayList<>(); @@ -185,13 +310,18 @@ public Builder setFeatureNames(List featureNames) { return this; } + public Builder setRoot(TreeNode.Builder root) { + nodes.set(0, root); + return this; + } + public Builder addNode(TreeNode.Builder node) { nodes.add(node); return this; } public Builder setNodes(List nodes) { - this.nodes = new ArrayList<>(nodes); + this.nodes = new ArrayList<>(ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE.getPreferredName())); return this; } @@ -199,6 +329,21 @@ public Builder setNodes(TreeNode.Builder... nodes) { return setNodes(Arrays.asList(nodes)); } + + public Builder setTargetType(TargetType targetType) { + this.targetType = targetType; + return this; + } + + public Builder setClassificationLabels(List classificationLabels) { + this.classificationLabels = classificationLabels; + return this; + } + + private void setTargetType(String targetType) { + this.targetType = TargetType.fromString(targetType); + } + /** * Add a decision node. Space for the child nodes is allocated * @param nodeIndex Where to place the node. 
This is either 0 (root) or an existing child node index @@ -231,61 +376,6 @@ TreeNode.Builder addJunction(int nodeIndex, int featureIndex, boolean isDefaultL return node; } - void detectCycle(List nodes) { - if (nodes.isEmpty()) { - return; - } - Set visited = new HashSet<>(); - Queue toVisit = new ArrayDeque<>(nodes.size()); - toVisit.add(0); - while(toVisit.isEmpty() == false) { - Integer nodeIdx = toVisit.remove(); - if (visited.contains(nodeIdx)) { - throw new IllegalArgumentException("[tree] contains cycle at node " + nodeIdx); - } - visited.add(nodeIdx); - TreeNode.Builder treeNode = nodes.get(nodeIdx); - if (treeNode.getLeftChild() != null) { - toVisit.add(treeNode.getLeftChild()); - } - if (treeNode.getRightChild() != null) { - toVisit.add(treeNode.getRightChild()); - } - } - } - - void detectNullOrMissingNode(List nodes) { - if (nodes.isEmpty()) { - return; - } - if (nodes.get(0) == null) { - throw new IllegalArgumentException("[tree] must have non-null root node."); - } - List nullOrMissingNodes = new ArrayList<>(); - for (int i = 0; i < nodes.size(); i++) { - TreeNode.Builder currentNode = nodes.get(i); - if (currentNode == null) { - continue; - } - if (nodeNullOrMissing(currentNode.getLeftChild())) { - nullOrMissingNodes.add(currentNode.getLeftChild()); - } - if (nodeNullOrMissing(currentNode.getRightChild())) { - nullOrMissingNodes.add(currentNode.getRightChild()); - } - } - if (nullOrMissingNodes.isEmpty() == false) { - throw new IllegalArgumentException("[tree] contains null or missing nodes " + nullOrMissingNodes); - } - } - - private boolean nodeNullOrMissing(Integer nodeIdx) { - if (nodeIdx == null) { - return false; - } - return nodeIdx >= nodes.size() || nodes.get(nodeIdx) == null; - } - /** * Sets the node at {@code nodeIndex} to a leaf node. 
* @param nodeIndex The index as allocated by a call to {@link #addJunction(int, int, boolean, double)} @@ -301,10 +391,13 @@ Tree.Builder addLeaf(int nodeIndex, double value) { } public Tree build() { - detectNullOrMissingNode(nodes); - detectCycle(nodes); + if (nodes.stream().anyMatch(Objects::isNull)) { + throw ExceptionsHelper.badRequestException("[tree] cannot contain null nodes"); + } return new Tree(featureNames, - nodes.stream().map(TreeNode.Builder::build).collect(Collectors.toList())); + nodes.stream().map(TreeNode.Builder::build).collect(Collectors.toList()), + targetType, + classificationLabels); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNode.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNode.java index f0dbb0617503b..9beda88e2c50a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNode.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeNode.java @@ -143,7 +143,7 @@ public int getRightChild() { } public boolean isLeaf() { - return leftChild < 1; + return leftChild < 0; } public int compare(List features) { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java new file mode 100644 index 0000000000000..cb44d03e22bb2 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java @@ -0,0 +1,52 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
import java.util.List;
import java.util.stream.Collectors;

/**
 * Utility math routines shared by the inference models.
 */
public final class Statistics {

    private Statistics() {}

    /**
     * Calculates the softmax of the passed values.
     *
     * Any infinite, NaN, or null values are ignored in the calculation and
     * returned as 0.0 in the resulting list.
     *
     * @param values values on which to run softmax; must contain at least one valid (non-null, finite) entry
     * @return a new list containing the softmax of the passed values
     * @throws IllegalArgumentException if no valid values are present
     */
    public static List<Double> softMax(List<Double> values) {
        // Shift by the max before exponentiating for numerical stability.
        Double max = values.stream().filter(v -> isInvalid(v) == false).max(Double::compareTo).orElse(null);
        if (max == null) {
            throw new IllegalArgumentException("no valid values present");
        }
        // Invalid entries become NEGATIVE_INFINITY so the later isInvalid check zeroes them out.
        List<Double> exps = values.stream()
            .map(v -> isInvalid(v) ? Double.NEGATIVE_INFINITY : v - max)
            .collect(Collectors.toList());
        double expSum = 0.0; // primitive accumulator avoids repeated autoboxing
        for (int i = 0; i < exps.size(); i++) {
            if (isInvalid(exps.get(i)) == false) {
                double exp = Math.exp(exps.get(i));
                expSum += exp;
                exps.set(i, exp);
            }
        }
        for (int i = 0; i < exps.size(); i++) {
            if (isInvalid(exps.get(i))) {
                exps.set(i, 0.0);
            } else {
                exps.set(i, exps.get(i) / expSum);
            }
        }
        return exps;
    }

    /**
     * @return true when the value cannot participate in arithmetic (null, infinite, or NaN)
     */
    public static boolean isInvalid(Double v) {
        return v == null || Double.isInfinite(v) || Double.isNaN(v);
    }

}
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.core.ml.utils;

import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.List;

/**
 * Helper for serializing {@link NamedXContentObject}s either as an ordered array of
 * single-field objects (when order must survive a parse round-trip) or as a plain
 * object keyed by each entry's name.
 */
public final class NamedXContentObjectHelper {

    private NamedXContentObjectHelper() {}

    /**
     * Writes {@code namedObjects} under the field {@code namedObjectsName}.
     *
     * @param useExplicitOrder when true, emits an array of one-field objects so the
     *                         parsing side can preserve ordering; otherwise a single object
     * @return the builder, for chaining
     */
    public static XContentBuilder writeNamedObjects(XContentBuilder builder,
                                                    ToXContent.Params params,
                                                    boolean useExplicitOrder,
                                                    String namedObjectsName,
                                                    List<? extends NamedXContentObject> namedObjects) throws IOException {
        if (useExplicitOrder) {
            builder.startArray(namedObjectsName);
        } else {
            builder.startObject(namedObjectsName);
        }
        for (NamedXContentObject object : namedObjects) {
            if (useExplicitOrder) {
                builder.startObject();
            }
            builder.field(object.getName(), object, params);
            if (useExplicitOrder) {
                builder.endObject();
            }
        }
        if (useExplicitOrder) {
            builder.endArray();
        } else {
            builder.endObject();
        }
        return builder;
    }
}
org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests; import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests; import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor; @@ -157,7 +158,7 @@ public NamedObjectContainer createTestInstance() { NamedObjectContainer container = new NamedObjectContainer(); container.setPreProcessors(preProcessors); container.setUseExplicitPreprocessorOrder(true); - container.setModel(TreeTests.buildRandomTree(5, 4)); + container.setModel(randomFrom(TreeTests.createRandom(), EnsembleTests.createRandom())); return container; } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java new file mode 100644 index 0000000000000..a4860b8e97292 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -0,0 +1,368 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests; +import org.junit.Before; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.hamcrest.Matchers.equalTo; + +public class EnsembleTests extends AbstractSerializingTestCase { + + private boolean lenient; + + @Before + public void chooseStrictOrLenient() { + lenient = randomBoolean(); + } + + @Override + protected boolean supportsUnknownFields() { + return lenient; + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + return field -> !field.isEmpty(); + } + + @Override + protected Ensemble doParseInstance(XContentParser parser) throws IOException { + return lenient ? 
Ensemble.fromXContentLenient(parser) : Ensemble.fromXContentStrict(parser); + } + + public static Ensemble createRandom() { + int numberOfFeatures = randomIntBetween(1, 10); + List featureNames = new ArrayList<>(); + for (int i = 0; i < numberOfFeatures; i++) { + featureNames.add(randomAlphaOfLength(10)); + } + int numberOfModels = randomIntBetween(1, 10); + List models = new ArrayList<>(numberOfModels); + for (int i = 0; i < numberOfModels; i++) { + models.add(TreeTests.buildRandomTree(featureNames, 6)); + } + OutputAggregator outputAggregator = null; + if (randomBoolean()) { + List weights = new ArrayList<>(numberOfModels); + for (int i = 0; i < numberOfModels; i++) { + weights.add(randomDouble()); + } + outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights)); + } + List categoryLabels = null; + if (randomBoolean()) { + categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false)); + } + + return new Ensemble(featureNames, + models, + outputAggregator, + randomFrom(TargetType.REGRESSION, TargetType.CLASSIFICATION), + categoryLabels); + } + + @Override + protected Ensemble createTestInstance() { + return createRandom(); + } + + @Override + protected Writeable.Reader instanceReader() { + return Ensemble::new; + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + return new NamedXContentRegistry(namedXContent); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + List entries = new ArrayList<>(); + entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); + return new NamedWriteableRegistry(entries); + } + + public void testEnsembleWithModelsThatHaveDifferentFeatureNames() { + 
List featureNames = Arrays.asList("foo", "bar", "baz", "farequote"); + ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> { + Ensemble.builder().setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList(TreeTests.buildRandomTree(Arrays.asList("bar", "foo", "baz", "farequote"), 6))) + .build() + .validate(); + }); + assertThat(ex.getMessage(), equalTo("[feature_names] must be the same and in the same order for each of the trained_models")); + + ex = expectThrows(ElasticsearchException.class, () -> { + Ensemble.builder().setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList(TreeTests.buildRandomTree(Arrays.asList("completely_different"), 6))) + .build() + .validate(); + }); + assertThat(ex.getMessage(), equalTo("[feature_names] must be the same and in the same order for each of the trained_models")); + } + + public void testEnsembleWithAggregatedOutputDifferingFromTrainedModels() { + List featureNames = Arrays.asList("foo", "bar"); + int numberOfModels = 5; + List weights = new ArrayList<>(numberOfModels + 2); + for (int i = 0; i < numberOfModels + 2; i++) { + weights.add(randomDouble()); + } + OutputAggregator outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights)); + + List models = new ArrayList<>(numberOfModels); + for (int i = 0; i < numberOfModels; i++) { + models.add(TreeTests.buildRandomTree(featureNames, 6)); + } + ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> { + Ensemble.builder() + .setTrainedModels(models) + .setOutputAggregator(outputAggregator) + .setFeatureNames(featureNames) + .build() + .validate(); + }); + assertThat(ex.getMessage(), equalTo("[aggregate_output] expects value array of size [7] but number of models is [5]")); + } + + public void testEnsembleWithInvalidModel() { + List featureNames = Arrays.asList("foo", "bar"); + expectThrows(ElasticsearchException.class, () -> { + Ensemble.builder() + .setFeatureNames(featureNames) + 
.setTrainedModels(Arrays.asList( + // Tree with loop + Tree.builder() + .setNodes(TreeNode.builder(0) + .setLeftChild(1) + .setSplitFeature(1) + .setThreshold(randomDouble()), + TreeNode.builder(0) + .setLeftChild(0) + .setSplitFeature(1) + .setThreshold(randomDouble())) + .setFeatureNames(featureNames) + .build())) + .build() + .validate(); + }); + } + + public void testClassificationProbability() { + List featureNames = Arrays.asList("foo", "bar"); + Tree tree1 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(0) + .setThreshold(0.5)) + .addNode(TreeNode.builder(1).setLeafValue(1.0)) + .addNode(TreeNode.builder(2) + .setThreshold(0.8) + .setSplitFeature(1) + .setLeftChild(3) + .setRightChild(4)) + .addNode(TreeNode.builder(3).setLeafValue(0.0)) + .addNode(TreeNode.builder(4).setLeafValue(1.0)).build(); + Tree tree2 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(0) + .setThreshold(0.5)) + .addNode(TreeNode.builder(1).setLeafValue(0.0)) + .addNode(TreeNode.builder(2).setLeafValue(1.0)) + .build(); + Tree tree3 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(1) + .setThreshold(1.0)) + .addNode(TreeNode.builder(1).setLeafValue(1.0)) + .addNode(TreeNode.builder(2).setLeafValue(0.0)) + .build(); + Ensemble ensemble = Ensemble.builder() + .setTargetType(TargetType.CLASSIFICATION) + .setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList(tree1, tree2, tree3)) + .setOutputAggregator(new WeightedMode(Arrays.asList(0.7, 0.5, 1.0))) + .build(); + + List featureVector = Arrays.asList(0.4, 0.0); + Map featureMap = zipObjMap(featureNames, featureVector); + List expected = Arrays.asList(0.23147521, 0.768524783); + List probabilities = ensemble.classificationProbability(featureMap); + for(int i = 0; i < 
expected.size(); i++) { + assertEquals(expected.get(i), probabilities.get(i), 0.000001); + } + + featureVector = Arrays.asList(2.0, 0.7); + featureMap = zipObjMap(featureNames, featureVector); + expected = Arrays.asList(0.3100255188, 0.689974481); + probabilities = ensemble.classificationProbability(featureMap); + for(int i = 0; i < expected.size(); i++) { + assertEquals(expected.get(i), probabilities.get(i), 0.000001); + } + + featureVector = Arrays.asList(0.0, 1.0); + featureMap = zipObjMap(featureNames, featureVector); + expected = Arrays.asList(0.231475216, 0.768524783); + probabilities = ensemble.classificationProbability(featureMap); + for(int i = 0; i < expected.size(); i++) { + assertEquals(expected.get(i), probabilities.get(i), 0.000001); + } + } + + public void testClassificationInference() { + List featureNames = Arrays.asList("foo", "bar"); + Tree tree1 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(0) + .setThreshold(0.5)) + .addNode(TreeNode.builder(1).setLeafValue(1.0)) + .addNode(TreeNode.builder(2) + .setThreshold(0.8) + .setSplitFeature(1) + .setLeftChild(3) + .setRightChild(4)) + .addNode(TreeNode.builder(3).setLeafValue(0.0)) + .addNode(TreeNode.builder(4).setLeafValue(1.0)).build(); + Tree tree2 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(0) + .setThreshold(0.5)) + .addNode(TreeNode.builder(1).setLeafValue(0.0)) + .addNode(TreeNode.builder(2).setLeafValue(1.0)) + .build(); + Tree tree3 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(1) + .setThreshold(1.0)) + .addNode(TreeNode.builder(1).setLeafValue(1.0)) + .addNode(TreeNode.builder(2).setLeafValue(0.0)) + .build(); + Ensemble ensemble = Ensemble.builder() + .setTargetType(TargetType.CLASSIFICATION) + 
.setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList(tree1, tree2, tree3)) + .setOutputAggregator(new WeightedMode(Arrays.asList(0.7, 0.5, 1.0))) + .build(); + + List featureVector = Arrays.asList(0.4, 0.0); + Map featureMap = zipObjMap(featureNames, featureVector); + assertEquals(1.0, ensemble.infer(featureMap), 0.00001); + + featureVector = Arrays.asList(2.0, 0.7); + featureMap = zipObjMap(featureNames, featureVector); + assertEquals(1.0, ensemble.infer(featureMap), 0.00001); + + featureVector = Arrays.asList(0.0, 1.0); + featureMap = zipObjMap(featureNames, featureVector); + assertEquals(1.0, ensemble.infer(featureMap), 0.00001); + } + + public void testRegressionInference() { + List featureNames = Arrays.asList("foo", "bar"); + Tree tree1 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(0) + .setThreshold(0.5)) + .addNode(TreeNode.builder(1).setLeafValue(0.3)) + .addNode(TreeNode.builder(2) + .setThreshold(0.8) + .setSplitFeature(1) + .setLeftChild(3) + .setRightChild(4)) + .addNode(TreeNode.builder(3).setLeafValue(0.1)) + .addNode(TreeNode.builder(4).setLeafValue(0.2)).build(); + Tree tree2 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(0) + .setThreshold(0.5)) + .addNode(TreeNode.builder(1).setLeafValue(1.5)) + .addNode(TreeNode.builder(2).setLeafValue(0.9)) + .build(); + Ensemble ensemble = Ensemble.builder() + .setTargetType(TargetType.REGRESSION) + .setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList(tree1, tree2)) + .setOutputAggregator(new WeightedSum(Arrays.asList(0.5, 0.5))) + .build(); + + List featureVector = Arrays.asList(0.4, 0.0); + Map featureMap = zipObjMap(featureNames, featureVector); + assertEquals(0.9, ensemble.infer(featureMap), 0.00001); + + featureVector = Arrays.asList(2.0, 0.7); + featureMap = zipObjMap(featureNames, featureVector); + 
assertEquals(0.5, ensemble.infer(featureMap), 0.00001); + + ensemble = Ensemble.builder() + .setTargetType(TargetType.REGRESSION) + .setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList(tree1, tree2)) + .build(); + + featureVector = Arrays.asList(0.4, 0.0); + featureMap = zipObjMap(featureNames, featureVector); + assertEquals(1.8, ensemble.infer(featureMap), 0.00001); + + featureVector = Arrays.asList(2.0, 0.7); + featureMap = zipObjMap(featureNames, featureVector); + assertEquals(1.0, ensemble.infer(featureMap), 0.00001); + } + + private static Map zipObjMap(List keys, List values) { + return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get)); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedAggregatorTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedAggregatorTests.java new file mode 100644 index 0000000000000..02bfe2797d990 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedAggregatorTests.java @@ -0,0 +1,51 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; + +import org.elasticsearch.test.AbstractSerializingTestCase; +import org.junit.Before; + +import java.util.ArrayList; +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; + +public abstract class WeightedAggregatorTests extends AbstractSerializingTestCase { + + protected boolean lenient; + + @Before + public void chooseStrictOrLenient() { + lenient = randomBoolean(); + } + + @Override + protected boolean supportsUnknownFields() { + return lenient; + } + + public void testWithNullValues() { + OutputAggregator outputAggregator = createTestInstance(); + NullPointerException ex = expectThrows(NullPointerException.class, () -> outputAggregator.processValues(null)); + assertThat(ex.getMessage(), equalTo("values must not be null")); + } + + public void testWithValuesOfWrongLength() { + int numberOfValues = randomIntBetween(5, 10); + List values = new ArrayList<>(numberOfValues); + for (int i = 0; i < numberOfValues; i++) { + values.add(randomDouble()); + } + + OutputAggregator outputAggregatorWithTooFewWeights = createTestInstance(randomIntBetween(1, numberOfValues - 1)); + expectThrows(IllegalArgumentException.class, () -> outputAggregatorWithTooFewWeights.processValues(values)); + + OutputAggregator outputAggregatorWithTooManyWeights = createTestInstance(randomIntBetween(numberOfValues + 1, numberOfValues + 10)); + expectThrows(IllegalArgumentException.class, () -> outputAggregatorWithTooManyWeights.processValues(values)); + } + + abstract T createTestInstance(int numberOfWeights); +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java new file mode 100644 index 0000000000000..22e39c13d2e67 --- /dev/null +++ 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java @@ -0,0 +1,56 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; + +public class WeightedModeTests extends WeightedAggregatorTests { + + @Override + WeightedMode createTestInstance(int numberOfWeights) { + List weights = new ArrayList<>(numberOfWeights); + for (int i = 0; i < numberOfWeights; i++) { + weights.add(randomDouble()); + } + return new WeightedMode(weights); + } + + @Override + protected WeightedMode doParseInstance(XContentParser parser) throws IOException { + return lenient ? 
WeightedMode.fromXContentLenient(parser) : WeightedMode.fromXContentStrict(parser); + } + + @Override + protected WeightedMode createTestInstance() { + return createTestInstance(randomIntBetween(1, 100)); + } + + @Override + protected Writeable.Reader instanceReader() { + return WeightedMode::new; + } + + public void testAggregate() { + List ones = Arrays.asList(1.0, 1.0, 1.0, 1.0, 1.0); + List values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0); + + WeightedMode weightedMode = new WeightedMode(ones); + assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0)); + + List variedWeights = Arrays.asList(1.0, -1.0, .5, 1.0, 5.0); + + weightedMode = new WeightedMode(variedWeights); + assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(5.0)); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java new file mode 100644 index 0000000000000..01755f316e7a1 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java @@ -0,0 +1,56 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; + +public class WeightedSumTests extends WeightedAggregatorTests { + + @Override + WeightedSum createTestInstance(int numberOfWeights) { + List weights = new ArrayList<>(numberOfWeights); + for (int i = 0; i < numberOfWeights; i++) { + weights.add(randomDouble()); + } + return new WeightedSum(weights); + } + + @Override + protected WeightedSum doParseInstance(XContentParser parser) throws IOException { + return lenient ? WeightedSum.fromXContentLenient(parser) : WeightedSum.fromXContentStrict(parser); + } + + @Override + protected WeightedSum createTestInstance() { + return createTestInstance(randomIntBetween(1, 100)); + } + + @Override + protected Writeable.Reader instanceReader() { + return WeightedSum::new; + } + + public void testAggregate() { + List ones = Arrays.asList(1.0, 1.0, 1.0, 1.0, 1.0); + List values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0); + + WeightedSum weightedSum = new WeightedSum(ones); + assertThat(weightedSum.aggregate(weightedSum.processValues(values)), equalTo(13.0)); + + List variedWeights = Arrays.asList(1.0, -1.0, .5, 1.0, 5.0); + + weightedSum = new WeightedSum(variedWeights); + assertThat(weightedSum.aggregate(weightedSum.processValues(values)), equalTo(28.0)); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java index 391f2e4b7e59a..160d89d692d7a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java +++ 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java @@ -5,9 +5,11 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.junit.Before; import java.io.IOException; @@ -47,23 +49,23 @@ protected Predicate getRandomFieldsExcludeFilter() { return field -> field.startsWith("feature_names"); } - @Override protected Tree createTestInstance() { return createRandom(); } public static Tree createRandom() { - return buildRandomTree(randomIntBetween(2, 15), 6); + int numberOfFeatures = randomIntBetween(1, 10); + List featureNames = new ArrayList<>(); + for (int i = 0; i < numberOfFeatures; i++) { + featureNames.add(randomAlphaOfLength(10)); + } + return buildRandomTree(featureNames, 6); } - public static Tree buildRandomTree(int numFeatures, int depth) { - + public static Tree buildRandomTree(List featureNames, int depth) { Tree.Builder builder = Tree.builder(); - List featureNames = new ArrayList<>(numFeatures); - for(int i = 0; i < numFeatures; i++) { - featureNames.add(randomAlphaOfLength(10)); - } + int numFeatures = featureNames.size() - 1; builder.setFeatureNames(featureNames); TreeNode.Builder node = builder.addJunction(0, randomInt(numFeatures), true, randomDouble()); @@ -84,8 +86,14 @@ public static Tree buildRandomTree(int numFeatures, int depth) { } childNodes = nextNodes; } + List categoryLabels = null; + if (randomBoolean()) { + categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false)); + } - return builder.build(); + return builder.setTargetType(randomFrom(TargetType.REGRESSION, TargetType.CLASSIFICATION)) + 
.setClassificationLabels(categoryLabels) + .build(); } @Override @@ -96,7 +104,7 @@ protected Writeable.Reader instanceReader() { public void testInfer() { // Build a tree with 2 nodes and 3 leaves using 2 features // The leaves have unique values 0.1, 0.2, 0.3 - Tree.Builder builder = Tree.builder(); + Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION); TreeNode.Builder rootNode = builder.addJunction(0, 0, true, 0.5); builder.addLeaf(rootNode.getRightChild(), 0.3); TreeNode.Builder leftChildNode = builder.addJunction(rootNode.getLeftChild(), 1, true, 0.8); @@ -124,37 +132,76 @@ public void testInfer() { assertEquals(0.2, tree.infer(featureMap), 0.00001); } + public void testTreeClassificationProbability() { + // Build a tree with 2 nodes and 3 leaves using 2 features + // The leaves have unique values 0.1, 0.2, 0.3 + Tree.Builder builder = Tree.builder().setTargetType(TargetType.CLASSIFICATION); + TreeNode.Builder rootNode = builder.addJunction(0, 0, true, 0.5); + builder.addLeaf(rootNode.getRightChild(), 1.0); + TreeNode.Builder leftChildNode = builder.addJunction(rootNode.getLeftChild(), 1, true, 0.8); + builder.addLeaf(leftChildNode.getLeftChild(), 1.0); + builder.addLeaf(leftChildNode.getRightChild(), 0.0); + + List featureNames = Arrays.asList("foo", "bar"); + Tree tree = builder.setFeatureNames(featureNames).build(); + + // This feature vector should hit the right child of the root node + List featureVector = Arrays.asList(0.6, 0.0); + Map featureMap = zipObjMap(featureNames, featureVector); + assertEquals(Arrays.asList(0.0, 1.0), tree.classificationProbability(featureMap)); + + // This should hit the left child of the left child of the root node + // i.e. 
it takes the path left, left + featureVector = Arrays.asList(0.3, 0.7); + featureMap = zipObjMap(featureNames, featureVector); + assertEquals(Arrays.asList(0.0, 1.0), tree.classificationProbability(featureMap)); + + // This should hit the right child of the left child of the root node + // i.e. it takes the path left, right + featureVector = Arrays.asList(0.3, 0.9); + featureMap = zipObjMap(featureNames, featureVector); + assertEquals(Arrays.asList(1.0, 0.0), tree.classificationProbability(featureMap)); + } + public void testTreeWithNullRoot() { - IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, - () -> Tree.builder().setNodes(Collections.singletonList(null)) + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, + () -> Tree.builder() + .setNodes(Collections.singletonList(null)) + .setFeatureNames(Arrays.asList("foo", "bar")) .build()); - assertThat(ex.getMessage(), equalTo("[tree] must have non-null root node.")); + assertThat(ex.getMessage(), equalTo("[tree] cannot contain null nodes")); } public void testTreeWithInvalidNode() { - IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, - () -> Tree.builder().setNodes(TreeNode.builder(0) + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, + () -> Tree.builder() + .setNodes(TreeNode.builder(0) .setLeftChild(1) .setSplitFeature(1) .setThreshold(randomDouble())) - .build()); - assertThat(ex.getMessage(), equalTo("[tree] contains null or missing nodes [1]")); + .setFeatureNames(Arrays.asList("foo", "bar")) + .build().validate()); + assertThat(ex.getMessage(), equalTo("[tree] contains missing nodes [1]")); } public void testTreeWithNullNode() { - IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, - () -> Tree.builder().setNodes(TreeNode.builder(0) + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, + () -> Tree.builder() + .setNodes(TreeNode.builder(0) 
.setLeftChild(1) .setSplitFeature(1) .setThreshold(randomDouble()), null) - .build()); - assertThat(ex.getMessage(), equalTo("[tree] contains null or missing nodes [1]")); + .setFeatureNames(Arrays.asList("foo", "bar")) + .build() + .validate()); + assertThat(ex.getMessage(), equalTo("[tree] cannot contain null nodes")); } public void testTreeWithCycle() { - IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, - () -> Tree.builder().setNodes(TreeNode.builder(0) + ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class, + () -> Tree.builder() + .setNodes(TreeNode.builder(0) .setLeftChild(1) .setSplitFeature(1) .setThreshold(randomDouble()), @@ -162,7 +209,9 @@ public void testTreeWithCycle() { .setLeftChild(0) .setSplitFeature(1) .setThreshold(randomDouble())) - .build()); + .setFeatureNames(Arrays.asList("foo", "bar")) + .build() + .validate()); assertThat(ex.getMessage(), equalTo("[tree] contains cycle at node 0")); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/utils/StatisticsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/utils/StatisticsTests.java new file mode 100644 index 0000000000000..feed5a6438c5d --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/utils/StatisticsTests.java @@ -0,0 +1,26 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.inference.utils; + +import org.elasticsearch.test.ESTestCase; + +import java.util.Arrays; +import java.util.List; + +public class StatisticsTests extends ESTestCase { + + public void testSoftMax() { + List values = Arrays.asList(Double.NEGATIVE_INFINITY, 1.0, -0.5, null, Double.NaN, Double.POSITIVE_INFINITY, 1.0, 5.0); + List softMax = Statistics.softMax(values); + + List expected = Arrays.asList(0.0, 0.017599040, 0.003926876, 0.0, 0.0, 0.0, 0.017599040, 0.960875042); + + for(int i = 0; i < expected.size(); i++) { + assertEquals(expected.get(i), softMax.get(i), 0.000001); + } + } + +} From e6dd87d0052759b88563b553dee1e22a0d992a52 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Mon, 30 Sep 2019 08:42:58 -0400 Subject: [PATCH 2/5] addressing PR comments --- .../trainedmodel/ensemble/EnsembleTests.java | 24 ++++---- .../ensemble/WeightedModeTests.java | 11 ++-- .../ensemble/WeightedSumTests.java | 11 ++-- .../trainedmodel/ensemble/Ensemble.java | 4 ++ .../trainedmodel/ensemble/WeightedMode.java | 4 ++ .../trainedmodel/ensemble/WeightedSum.java | 8 ++- .../ml/inference/trainedmodel/tree/Tree.java | 11 +++- .../trainedmodel/ensemble/EnsembleTests.java | 55 ++++++++++++++----- .../ensemble/WeightedModeTests.java | 9 ++- .../ensemble/WeightedSumTests.java | 9 ++- .../trainedmodel/tree/TreeTests.java | 29 +++++++++- .../ml/inference/utils/StatisticsTests.java | 9 ++- 12 files changed, 128 insertions(+), 56 deletions(-) diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java index 91d7816f33c4d..774ab26bc17c7 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java +++ 
b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -27,6 +27,7 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.AbstractXContentTestCase; +import org.elasticsearch.test.ESTestCase; import java.io.IOException; import java.util.ArrayList; @@ -34,6 +35,8 @@ import java.util.Collections; import java.util.List; import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.Stream; public class EnsembleTests extends AbstractXContentTestCase { @@ -55,21 +58,16 @@ protected Ensemble doParseInstance(XContentParser parser) throws IOException { public static Ensemble createRandom() { int numberOfFeatures = randomIntBetween(1, 10); - List featureNames = new ArrayList<>(); - for (int i = 0; i < numberOfFeatures; i++) { - featureNames.add(randomAlphaOfLength(10)); - } + List featureNames = Stream.generate(() -> randomAlphaOfLength(10)) + .limit(numberOfFeatures) + .collect(Collectors.toList()); int numberOfModels = randomIntBetween(1, 10); - List models = new ArrayList<>(numberOfModels); - for (int i = 0; i < numberOfModels; i++) { - models.add(TreeTests.buildRandomTree(featureNames, 6)); - } + List models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6)) + .limit(numberOfFeatures) + .collect(Collectors.toList()); OutputAggregator outputAggregator = null; if (randomBoolean()) { - List weights = new ArrayList<>(numberOfModels); - for (int i = 0; i < numberOfModels; i++) { - weights.add(randomDouble()); - } + List weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList()); outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights)); } List categoryLabels = null; @@ -79,7 +77,7 @@ public static Ensemble createRandom() { return new Ensemble(featureNames, models, outputAggregator, - 
randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION), + randomFrom(TargetType.values()), categoryLabels); } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedModeTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedModeTests.java index 860042bb42c19..8c1fb7f6f2517 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedModeTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedModeTests.java @@ -20,21 +20,18 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractXContentTestCase; +import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode; import java.io.IOException; -import java.util.ArrayList; -import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; public class WeightedModeTests extends AbstractXContentTestCase { WeightedMode createTestInstance(int numberOfWeights) { - List weights = new ArrayList<>(numberOfWeights); - for (int i = 0; i < numberOfWeights; i++) { - weights.add(randomDouble()); - } - return new WeightedMode(weights); + return new WeightedMode(Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList())); } @Override diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedSumTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedSumTests.java index d597d510b1df7..0ceac07149ec3 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedSumTests.java +++ 
b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedSumTests.java @@ -20,20 +20,17 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractXContentTestCase; +import org.elasticsearch.test.ESTestCase; import java.io.IOException; -import java.util.ArrayList; -import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; public class WeightedSumTests extends AbstractXContentTestCase { WeightedSum createTestInstance(int numberOfWeights) { - List weights = new ArrayList<>(numberOfWeights); - for (int i = 0; i < numberOfWeights; i++) { - weights.add(randomDouble()); - } - return new WeightedSum(weights); + return new WeightedSum(Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList())); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java index 05ff090839acd..942a366347a1f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java @@ -235,6 +235,10 @@ public void validate() { outputAggregator.expectedValueSize(), models.size()); } + if ((this.targetType == TargetType.CLASSIFICATION) != (this.classificationLabels != null)) { + throw ExceptionsHelper.badRequestException( + "[target_type] should be [classification] if [classification_labels] is provided, and vice versa"); + } this.models.forEach(TrainedModel::validate); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java index f99eba4e3031a..42fb075b28377 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java @@ -96,9 +96,13 @@ public List processValues(List values) { @Override public double aggregate(List values) { + Objects.requireNonNull(values, "values must not be null"); int bestValue = 0; double bestFreq = Double.NEGATIVE_INFINITY; for (int i = 0; i < values.size(); i++) { + if (values.get(i) == null) { + throw new IllegalArgumentException("values must not contain null values"); + } if (values.get(i) > bestFreq) { bestFreq = values.get(i); bestValue = i; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java index e4e72e4e82cc5..3a35276d06e36 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java @@ -18,6 +18,7 @@ import java.util.Collections; import java.util.List; import java.util.Objects; +import java.util.Optional; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -68,7 +69,12 @@ public List processValues(List values) { @Override public double aggregate(List values) { - return values.stream().reduce((memo, v) -> memo + v).get(); + Objects.requireNonNull(values, "values must not be null"); + Optional summation = values.stream().reduce((memo, v) -> memo + v); + if (summation.isPresent()) { + return summation.get(); + } + throw new IllegalArgumentException("values must not 
contain null values"); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java index a36a62c36d2b6..c3957c4a029b8 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java @@ -227,10 +227,18 @@ public static Builder builder() { @Override public void validate() { + checkTargetType(); detectNullOrMissingNode(); detectCycle(); } + private void checkTargetType() { + if ((this.targetType == TargetType.CLASSIFICATION) != (this.classificationLabels != null)) { + throw ExceptionsHelper.badRequestException( + "[target_type] should be [classification] if [classification_labels] is provided, and vice versa"); + } + } + private void detectCycle() { if (nodes.isEmpty()) { return; @@ -278,9 +286,6 @@ private void detectNullOrMissingNode() { } private static boolean nodeMissing(int nodeIdx, List nodes) { - if (nodeIdx < 0) { - return false; - } return nodeIdx >= nodes.size(); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java index a4860b8e97292..85f208491de25 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.test.ESTestCase; import 
org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; @@ -30,6 +31,7 @@ import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.IntStream; +import java.util.stream.Stream; import static org.hamcrest.Matchers.equalTo; @@ -59,21 +61,14 @@ protected Ensemble doParseInstance(XContentParser parser) throws IOException { public static Ensemble createRandom() { int numberOfFeatures = randomIntBetween(1, 10); - List featureNames = new ArrayList<>(); - for (int i = 0; i < numberOfFeatures; i++) { - featureNames.add(randomAlphaOfLength(10)); - } + List featureNames = Stream.generate(() -> randomAlphaOfLength(10)).limit(numberOfFeatures).collect(Collectors.toList()); int numberOfModels = randomIntBetween(1, 10); - List models = new ArrayList<>(numberOfModels); - for (int i = 0; i < numberOfModels; i++) { - models.add(TreeTests.buildRandomTree(featureNames, 6)); - } + List models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6)) + .limit(numberOfModels) + .collect(Collectors.toList()); OutputAggregator outputAggregator = null; if (randomBoolean()) { - List weights = new ArrayList<>(numberOfModels); - for (int i = 0; i < numberOfModels; i++) { - weights.add(randomDouble()); - } + List weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList()); outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights)); } List categoryLabels = null; @@ -84,7 +79,7 @@ public static Ensemble createRandom() { return new Ensemble(featureNames, models, outputAggregator, - randomFrom(TargetType.REGRESSION, TargetType.CLASSIFICATION), + randomFrom(TargetType.values()), categoryLabels); } @@ -179,6 +174,40 @@ public void testEnsembleWithInvalidModel() { }); } + public void 
testEnsembleWithTargetTypeAndLabelsMismatch() { + List featureNames = Arrays.asList("foo", "bar"); + expectThrows(ElasticsearchException.class, () -> { + Ensemble.builder() + .setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList( + Tree.builder() + .setNodes(TreeNode.builder(0) + .setLeftChild(1) + .setSplitFeature(1) + .setThreshold(randomDouble())) + .setFeatureNames(featureNames) + .build())) + .setClassificationLabels(Arrays.asList("label1", "label2")) + .build() + .validate(); + }); + expectThrows(ElasticsearchException.class, () -> { + Ensemble.builder() + .setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList( + Tree.builder() + .setNodes(TreeNode.builder(0) + .setLeftChild(1) + .setSplitFeature(1) + .setThreshold(randomDouble())) + .setFeatureNames(featureNames) + .build())) + .setTargetType(TargetType.CLASSIFICATION) + .build() + .validate(); + }); + } + public void testClassificationProbability() { List featureNames = Arrays.asList("foo", "bar"); Tree tree1 = Tree.builder() diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java index 22e39c13d2e67..7c22d6cd62b7e 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java @@ -7,11 +7,13 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.ESTestCase; import java.io.IOException; -import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.hamcrest.Matchers.equalTo; @@ -19,10 +21,7 @@ public class 
WeightedModeTests extends WeightedAggregatorTests { @Override WeightedMode createTestInstance(int numberOfWeights) { - List weights = new ArrayList<>(numberOfWeights); - for (int i = 0; i < numberOfWeights; i++) { - weights.add(randomDouble()); - } + List weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList()); return new WeightedMode(weights); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java index 01755f316e7a1..f52ff11077a6d 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java @@ -7,11 +7,13 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.ESTestCase; import java.io.IOException; -import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.hamcrest.Matchers.equalTo; @@ -19,10 +21,7 @@ public class WeightedSumTests extends WeightedAggregatorTests { @Override WeightedSum createTestInstance(int numberOfWeights) { - List weights = new ArrayList<>(numberOfWeights); - for (int i = 0; i < numberOfWeights; i++) { - weights.add(randomDouble()); - } + List weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList()); return new WeightedSum(weights); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java index 
160d89d692d7a..3c893a04cd534 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; @@ -91,7 +92,7 @@ public static Tree buildRandomTree(List featureNames, int depth) { categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false)); } - return builder.setTargetType(randomFrom(TargetType.REGRESSION, TargetType.CLASSIFICATION)) + return builder.setTargetType(randomFrom(TargetType.values())) .setClassificationLabels(categoryLabels) .build(); } @@ -215,6 +216,32 @@ public void testTreeWithCycle() { assertThat(ex.getMessage(), equalTo("[tree] contains cycle at node 0")); } + public void testTreeWithTargetTypeAndLabelsMismatch() { + List featureNames = Arrays.asList("foo", "bar"); + expectThrows(ElasticsearchException.class, () -> { + Tree.builder() + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setSplitFeature(1) + .setThreshold(randomDouble())) + .setFeatureNames(Arrays.asList("foo", "bar")) + .setClassificationLabels(Arrays.asList("label1", "label2")) + .build() + .validate(); + }); + expectThrows(ElasticsearchException.class, () -> { + Tree.builder() + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setSplitFeature(1) + .setThreshold(randomDouble())) + .setFeatureNames(Arrays.asList("foo", "bar")) + .setTargetType(TargetType.CLASSIFICATION) + .build() + .validate(); + }); + } + private static Map zipObjMap(List keys, List values) { return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, 
values::get)); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/utils/StatisticsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/utils/StatisticsTests.java index feed5a6438c5d..5fb69238b1579 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/utils/StatisticsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/utils/StatisticsTests.java @@ -10,6 +10,8 @@ import java.util.Arrays; import java.util.List; +import static org.hamcrest.Matchers.closeTo; + public class StatisticsTests extends ESTestCase { public void testSoftMax() { @@ -19,8 +21,13 @@ public void testSoftMax() { List expected = Arrays.asList(0.0, 0.017599040, 0.003926876, 0.0, 0.0, 0.0, 0.017599040, 0.960875042); for(int i = 0; i < expected.size(); i++) { - assertEquals(expected.get(i), softMax.get(i), 0.000001); + assertThat(softMax.get(i), closeTo(expected.get(i), 0.000001)); } } + public void testSoftMaxWithNoValidValues() { + List values = Arrays.asList(Double.NEGATIVE_INFINITY, null, Double.NaN, Double.POSITIVE_INFINITY); + expectThrows(IllegalArgumentException.class, () -> Statistics.softMax(values)); + } + } From ecb310f5102d3e2a4fd236f5c1ab7b48efd2d55b Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Tue, 1 Oct 2019 07:10:56 -0400 Subject: [PATCH 3/5] Update TreeTests.java --- .../client/ml/inference/trainedmodel/tree/TreeTests.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java index 1ce1af4a5b7b4..cb06469eaeaf1 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java +++ 
b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/tree/TreeTests.java @@ -88,7 +88,7 @@ public static Tree buildRandomTree(List featureNames, int depth) { categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false)); } return builder.setClassificationLabels(categoryLabels) - .setTargetType(randomFrom(TargetType.REGRESSION, TargetType.CLASSIFICATION)) + .setTargetType(randomFrom(TargetType.values())) .build(); } From 8bbeaf3b52f674419e744c14611d5c40c5cf296e Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 1 Oct 2019 09:43:22 -0400 Subject: [PATCH 4/5] addressing PR comments --- .../trainedmodel/ensemble/Ensemble.java | 10 ++-- .../trainedmodel/ensemble/WeightedMode.java | 2 +- .../trainedmodel/ensemble/WeightedSum.java | 2 +- .../ensemble/WeightedModeTests.java | 5 +- .../ensemble/WeightedSumTests.java | 2 +- .../trainedmodel/ensemble/Ensemble.java | 50 ++++++++----------- .../trainedmodel/ensemble/WeightedMode.java | 30 +++++++---- .../trainedmodel/ensemble/WeightedSum.java | 34 ++++++++++--- .../ml/inference/trainedmodel/tree/Tree.java | 7 +-- .../trainedmodel/ensemble/EnsembleTests.java | 27 ++++++---- .../ensemble/WeightedModeTests.java | 5 +- .../ensemble/WeightedSumTests.java | 5 +- .../trainedmodel/tree/TreeTests.java | 11 ++-- 13 files changed, 113 insertions(+), 77 deletions(-) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/Ensemble.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/Ensemble.java index 89a3815e72810..d16d758769c2b 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/Ensemble.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/Ensemble.java @@ -142,11 +142,11 @@ 
public static Builder builder() { } public static class Builder { - List featureNames; - List trainedModels; - OutputAggregator outputAggregator; - TargetType targetType; - List classificationLabels; + private List featureNames; + private List trainedModels; + private OutputAggregator outputAggregator; + private TargetType targetType; + private List classificationLabels; public Builder setFeatureNames(List featureNames) { this.featureNames = featureNames; diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedMode.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedMode.java index f5ad4a3f99ed5..37d589badd1e4 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedMode.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedMode.java @@ -41,7 +41,7 @@ public class WeightedMode implements OutputAggregator { true, a -> new WeightedMode((List)a[0])); static { - PARSER.declareDoubleArray(ConstructingObjectParser.constructorArg(), WEIGHTS); + PARSER.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS); } public static WeightedMode fromXContent(XContentParser parser) { diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedSum.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedSum.java index d6132ca00fbc1..534eb8d4def2d 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedSum.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedSum.java @@ -41,7 +41,7 @@ public class WeightedSum implements OutputAggregator { a -> new WeightedSum((List)a[0])); static { - 
PARSER.declareDoubleArray(ConstructingObjectParser.constructorArg(), WEIGHTS); + PARSER.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS); } public static WeightedSum fromXContent(XContentParser parser) { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedModeTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedModeTests.java index 8c1fb7f6f2517..a04652c1d3813 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedModeTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedModeTests.java @@ -21,7 +21,6 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode; import java.io.IOException; import java.util.stream.Collectors; @@ -36,7 +35,7 @@ WeightedMode createTestInstance(int numberOfWeights) { @Override protected WeightedMode doParseInstance(XContentParser parser) throws IOException { - return WeightedMode.fromXContentLenient(parser); + return WeightedMode.fromXContent(parser); } @Override @@ -46,7 +45,7 @@ protected boolean supportsUnknownFields() { @Override protected WeightedMode createTestInstance() { - return createTestInstance(randomIntBetween(1, 100)); + return randomBoolean() ? 
new WeightedMode(null) : createTestInstance(randomIntBetween(1, 100)); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedSumTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedSumTests.java index 0ceac07149ec3..ddc4aeccfd34d 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedSumTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedSumTests.java @@ -45,7 +45,7 @@ protected boolean supportsUnknownFields() { @Override protected WeightedSum createTestInstance() { - return createTestInstance(randomIntBetween(1, 100)); + return randomBoolean() ? new WeightedSum(null) : createTestInstance(randomIntBetween(1, 100)); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java index 942a366347a1f..7f2a7cc9a02ce 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java @@ -28,6 +28,7 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel { + // TODO should we have regression/classification sub-classes that accept the builder? 
public static final ParseField NAME = new ParseField("ensemble"); public static final ParseField FEATURE_NAMES = new ParseField("feature_names"); public static final ParseField TRAINED_MODELS = new ParseField("trained_models"); @@ -77,12 +78,12 @@ public static Ensemble fromXContentLenient(XContentParser parser) { Ensemble(List featureNames, List models, - @Nullable OutputAggregator outputAggregator, + OutputAggregator outputAggregator, TargetType targetType, @Nullable List classificationLabels) { this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES)); this.models = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(models, TRAINED_MODELS)); - this.outputAggregator = outputAggregator; + this.outputAggregator = ExceptionsHelper.requireNonNull(outputAggregator, AGGREGATE_OUTPUT); this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE); this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels); } @@ -90,7 +91,7 @@ public static Ensemble fromXContentLenient(XContentParser parser) { public Ensemble(StreamInput in) throws IOException { this.featureNames = Collections.unmodifiableList(in.readStringList()); this.models = Collections.unmodifiableList(in.readNamedWriteableList(TrainedModel.class)); - this.outputAggregator = in.readOptionalNamedWriteable(OutputAggregator.class); + this.outputAggregator = in.readNamedWriteable(OutputAggregator.class); this.targetType = TargetType.fromStream(in); if (in.readBoolean()) { this.classificationLabels = in.readStringList(); @@ -113,10 +114,7 @@ public double infer(Map fields) { @Override public double infer(List fields) { List processedInferences = inferAndProcess(fields); - if (outputAggregator != null) { - return outputAggregator.aggregate(processedInferences); - } - return processedInferences.stream().mapToDouble(Double::doubleValue).sum(); + return 
outputAggregator.aggregate(processedInferences); } @Override @@ -150,10 +148,7 @@ public List classificationLabels() { private List inferAndProcess(List fields) { List modelInferences = models.stream().map(m -> m.infer(fields)).collect(Collectors.toList()); - if (outputAggregator != null) { - return outputAggregator.processValues(modelInferences); - } - return modelInferences; + return outputAggregator.processValues(modelInferences); } @Override @@ -165,7 +160,7 @@ public String getWriteableName() { public void writeTo(StreamOutput out) throws IOException { out.writeStringCollection(featureNames); out.writeNamedWriteableList(models); - out.writeOptionalNamedWriteable(outputAggregator); + out.writeNamedWriteable(outputAggregator); targetType.writeTo(out); out.writeBoolean(classificationLabels != null); if (classificationLabels != null) { @@ -183,13 +178,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); builder.field(FEATURE_NAMES.getPreferredName(), featureNames); NamedXContentObjectHelper.writeNamedObjects(builder, params, true, TRAINED_MODELS.getPreferredName(), models); - if (outputAggregator != null) { - NamedXContentObjectHelper.writeNamedObjects(builder, - params, - false, - AGGREGATE_OUTPUT.getPreferredName(), - Collections.singletonList(outputAggregator)); - } + NamedXContentObjectHelper.writeNamedObjects(builder, + params, + false, + AGGREGATE_OUTPUT.getPreferredName(), + Collections.singletonList(outputAggregator)); builder.field(TARGET_TYPE.getPreferredName(), targetType.toString()); if (classificationLabels != null) { builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels); @@ -226,8 +219,7 @@ public void validate() { TRAINED_MODELS.getPreferredName()); } } - if (outputAggregator != null && - outputAggregator.expectedValueSize() != null && + if (outputAggregator.expectedValueSize() != null && outputAggregator.expectedValueSize() != models.size()) { throw 
ExceptionsHelper.badRequestException( "[{}] expects value array of size [{}] but number of models is [{}]", @@ -247,12 +239,12 @@ public static Builder builder() { } public static class Builder { - List featureNames; - List trainedModels; - OutputAggregator outputAggregator; - TargetType targetType = TargetType.REGRESSION; - List classificationLabels; - boolean modelsAreOrdered; + private List featureNames; + private List trainedModels; + private OutputAggregator outputAggregator = new WeightedSum(); + private TargetType targetType = TargetType.REGRESSION; + private List classificationLabels; + private boolean modelsAreOrdered; private Builder (boolean modelsAreOrdered) { this.modelsAreOrdered = modelsAreOrdered; @@ -277,7 +269,7 @@ public Builder setTrainedModels(List trainedModels) { } public Builder setOutputAggregator(OutputAggregator outputAggregator) { - this.outputAggregator = outputAggregator; + this.outputAggregator = ExceptionsHelper.requireNonNull(outputAggregator, AGGREGATE_OUTPUT); return this; } @@ -292,7 +284,7 @@ public Builder setClassificationLabels(List classificationLabels) { } private void setOutputAggregatorFromParser(List outputAggregators) { - if ((outputAggregators.size() == 1) == false) { + if (outputAggregators.size() != 1) { throw ExceptionsHelper.badRequestException("[{}] must have exactly one aggregator defined.", AGGREGATE_OUTPUT.getPreferredName()); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java index 42fb075b28377..739a4e13d8659 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java @@ -12,7 +12,6 @@ import 
org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; import java.util.ArrayList; @@ -36,7 +35,7 @@ private static ConstructingObjectParser createParser(boolean NAME.getPreferredName(), lenient, a -> new WeightedMode((List)a[0])); - parser.declareDoubleArray(ConstructingObjectParser.constructorArg(), WEIGHTS); + parser.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS); return parser; } @@ -50,23 +49,31 @@ public static WeightedMode fromXContentLenient(XContentParser parser) { private final List weights; + WeightedMode() { + this.weights = null; + } + public WeightedMode(List weights) { - this.weights = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(weights, WEIGHTS.getPreferredName())); + this.weights = weights == null ? null : Collections.unmodifiableList(weights); } public WeightedMode(StreamInput in) throws IOException { - this.weights = Collections.unmodifiableList(in.readList(StreamInput::readDouble)); + if (in.readBoolean()) { + this.weights = Collections.unmodifiableList(in.readList(StreamInput::readDouble)); + } else { + this.weights = null; + } } @Override public Integer expectedValueSize() { - return this.weights.size(); + return this.weights == null ? 
null : this.weights.size(); } @Override public List processValues(List values) { Objects.requireNonNull(values, "values must not be null"); - if (values.size() != weights.size()) { + if (weights != null && values.size() != weights.size()) { throw new IllegalArgumentException("values must be the same length as weights."); } List freqArray = new ArrayList<>(); @@ -86,7 +93,7 @@ public List processValues(List values) { } List frequencies = new ArrayList<>(Collections.nCopies(maxVal + 1, Double.NEGATIVE_INFINITY)); for (int i = 0; i < freqArray.size(); i++) { - Double weight = weights.get(i); + Double weight = weights == null ? 1.0 : weights.get(i); Integer value = freqArray.get(i); Double frequency = frequencies.get(value) == Double.NEGATIVE_INFINITY ? weight : frequencies.get(value) + weight; frequencies.set(value, frequency); @@ -123,13 +130,18 @@ public String getWriteableName() { @Override public void writeTo(StreamOutput out) throws IOException { - out.writeCollection(weights, StreamOutput::writeDouble); + out.writeBoolean(weights != null); + if (weights != null) { + out.writeCollection(weights, StreamOutput::writeDouble); + } } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(WEIGHTS.getPreferredName(), weights); + if (weights != null) { + builder.field(WEIGHTS.getPreferredName(), weights); + } builder.endObject(); return builder; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java index 3a35276d06e36..f5812dabf88f2 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java @@ -12,7 +12,6 @@ 
import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; import java.util.Collections; @@ -36,7 +35,7 @@ private static ConstructingObjectParser createParser(boolean NAME.getPreferredName(), lenient, a -> new WeightedSum((List)a[0])); - parser.declareDoubleArray(ConstructingObjectParser.constructorArg(), WEIGHTS); + parser.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS); return parser; } @@ -50,17 +49,28 @@ public static WeightedSum fromXContentLenient(XContentParser parser) { private final List weights; + WeightedSum() { + this.weights = null; + } + public WeightedSum(List weights) { - this.weights = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(weights, WEIGHTS.getPreferredName())); + this.weights = weights == null ? null : Collections.unmodifiableList(weights); } public WeightedSum(StreamInput in) throws IOException { - this.weights = Collections.unmodifiableList(in.readList(StreamInput::readDouble)); + if (in.readBoolean()) { + this.weights = Collections.unmodifiableList(in.readList(StreamInput::readDouble)); + } else { + this.weights = null; + } } @Override public List processValues(List values) { Objects.requireNonNull(values, "values must not be null"); + if (weights == null) { + return values; + } if (values.size() != weights.size()) { throw new IllegalArgumentException("values must be the same length as weights."); } @@ -70,7 +80,10 @@ public List processValues(List values) { @Override public double aggregate(List values) { Objects.requireNonNull(values, "values must not be null"); - Optional summation = values.stream().reduce((memo, v) -> memo + v); + if (values.isEmpty()) { + throw new IllegalArgumentException("values must not be empty"); + } + Optional summation = 
values.stream().reduce(Double::sum); if (summation.isPresent()) { return summation.get(); } @@ -89,13 +102,18 @@ public String getWriteableName() { @Override public void writeTo(StreamOutput out) throws IOException { - out.writeCollection(weights, StreamOutput::writeDouble); + out.writeBoolean(weights != null); + if (weights != null) { + out.writeCollection(weights, StreamOutput::writeDouble); + } } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(WEIGHTS.getPreferredName(), weights); + if (weights != null) { + builder.field(WEIGHTS.getPreferredName(), weights); + } builder.endObject(); return builder; } @@ -115,6 +133,6 @@ public int hashCode() { @Override public Integer expectedValueSize() { - return this.weights.size(); + return weights == null ? null : this.weights.size(); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java index c3957c4a029b8..5dca29d58437e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java @@ -33,6 +33,7 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel { + // TODO should we have regression/classification sub-classes that accept the builder? 
public static final ParseField NAME = new ParseField("tree"); public static final ParseField FEATURE_NAMES = new ParseField("feature_names"); @@ -228,7 +229,7 @@ public static Builder builder() { @Override public void validate() { checkTargetType(); - detectNullOrMissingNode(); + detectMissingNodes(); detectCycle(); } @@ -243,7 +244,7 @@ private void detectCycle() { if (nodes.isEmpty()) { return; } - Set visited = new HashSet<>(); + Set visited = new HashSet<>(nodes.size()); Queue toVisit = new ArrayDeque<>(nodes.size()); toVisit.add(0); while(toVisit.isEmpty() == false) { @@ -262,7 +263,7 @@ private void detectCycle() { } } - private void detectNullOrMissingNode() { + private void detectMissingNodes() { if (nodes.isEmpty()) { return; } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java index 85f208491de25..1e1b1f8f7286d 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -33,6 +33,7 @@ import java.util.stream.IntStream; import java.util.stream.Stream; +import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.equalTo; public class EnsembleTests extends AbstractSerializingTestCase { @@ -66,11 +67,10 @@ public static Ensemble createRandom() { List models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6)) .limit(numberOfModels) .collect(Collectors.toList()); - OutputAggregator outputAggregator = null; - if (randomBoolean()) { - List weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList()); - outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights)); - } + List weights 
= randomBoolean() ? + null : + Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList()); + OutputAggregator outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights)); List categoryLabels = null; if (randomBoolean()) { categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false)); @@ -176,7 +176,8 @@ public void testEnsembleWithInvalidModel() { public void testEnsembleWithTargetTypeAndLabelsMismatch() { List featureNames = Arrays.asList("foo", "bar"); - expectThrows(ElasticsearchException.class, () -> { + String msg = "[target_type] should be [classification] if [classification_labels] is provided, and vice versa"; + ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> { Ensemble.builder() .setFeatureNames(featureNames) .setTrainedModels(Arrays.asList( @@ -191,7 +192,8 @@ public void testEnsembleWithTargetTypeAndLabelsMismatch() { .build() .validate(); }); - expectThrows(ElasticsearchException.class, () -> { + assertThat(ex.getMessage(), equalTo(msg)); + ex = expectThrows(ElasticsearchException.class, () -> { Ensemble.builder() .setFeatureNames(featureNames) .setTrainedModels(Arrays.asList( @@ -206,6 +208,7 @@ public void testEnsembleWithTargetTypeAndLabelsMismatch() { .build() .validate(); }); + assertThat(ex.getMessage(), equalTo(msg)); } public void testClassificationProbability() { @@ -254,10 +257,11 @@ public void testClassificationProbability() { List featureVector = Arrays.asList(0.4, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); - List expected = Arrays.asList(0.23147521, 0.768524783); + List expected = Arrays.asList(0.231475216, 0.768524783); + double eps = 0.000001; List probabilities = ensemble.classificationProbability(featureMap); for(int i = 0; i < expected.size(); i++) { - assertEquals(expected.get(i), probabilities.get(i), 0.000001); + assertThat(probabilities.get(i), 
closeTo(expected.get(i), eps)); } featureVector = Arrays.asList(2.0, 0.7); @@ -265,7 +269,7 @@ public void testClassificationProbability() { expected = Arrays.asList(0.3100255188, 0.689974481); probabilities = ensemble.classificationProbability(featureMap); for(int i = 0; i < expected.size(); i++) { - assertEquals(expected.get(i), probabilities.get(i), 0.000001); + assertThat(probabilities.get(i), closeTo(expected.get(i), eps)); } featureVector = Arrays.asList(0.0, 1.0); @@ -273,7 +277,7 @@ public void testClassificationProbability() { expected = Arrays.asList(0.231475216, 0.768524783); probabilities = ensemble.classificationProbability(featureMap); for(int i = 0; i < expected.size(); i++) { - assertEquals(expected.get(i), probabilities.get(i), 0.000001); + assertThat(probabilities.get(i), closeTo(expected.get(i), eps)); } } @@ -376,6 +380,7 @@ public void testRegressionInference() { featureMap = zipObjMap(featureNames, featureVector); assertEquals(0.5, ensemble.infer(featureMap), 0.00001); + // Test with NO aggregator supplied, verifies default behavior of non-weighted sum ensemble = Ensemble.builder() .setTargetType(TargetType.REGRESSION) .setFeatureNames(featureNames) diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java index 7c22d6cd62b7e..7849d6d071ef1 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java @@ -32,7 +32,7 @@ protected WeightedMode doParseInstance(XContentParser parser) throws IOException @Override protected WeightedMode createTestInstance() { - return createTestInstance(randomIntBetween(1, 100)); + return randomBoolean() ? 
new WeightedMode() : createTestInstance(randomIntBetween(1, 100)); } @Override @@ -51,5 +51,8 @@ public void testAggregate() { weightedMode = new WeightedMode(variedWeights); assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(5.0)); + + weightedMode = new WeightedMode(); + assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0)); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java index f52ff11077a6d..89222365c83d8 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java @@ -32,7 +32,7 @@ protected WeightedSum doParseInstance(XContentParser parser) throws IOException @Override protected WeightedSum createTestInstance() { - return createTestInstance(randomIntBetween(1, 100)); + return randomBoolean() ? 
new WeightedSum() : createTestInstance(randomIntBetween(1, 100)); } @Override @@ -51,5 +51,8 @@ public void testAggregate() { weightedSum = new WeightedSum(variedWeights); assertThat(weightedSum.aggregate(weightedSum.processValues(values)), equalTo(28.0)); + + weightedSum = new WeightedSum(); + assertThat(weightedSum.aggregate(weightedSum.processValues(values)), equalTo(13.0)); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java index 3c893a04cd534..ce27120d671be 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java @@ -218,28 +218,31 @@ public void testTreeWithCycle() { public void testTreeWithTargetTypeAndLabelsMismatch() { List featureNames = Arrays.asList("foo", "bar"); - expectThrows(ElasticsearchException.class, () -> { + String msg = "[target_type] should be [classification] if [classification_labels] is provided, and vice versa"; + ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> { Tree.builder() .setRoot(TreeNode.builder(0) .setLeftChild(1) .setSplitFeature(1) .setThreshold(randomDouble())) - .setFeatureNames(Arrays.asList("foo", "bar")) + .setFeatureNames(featureNames) .setClassificationLabels(Arrays.asList("label1", "label2")) .build() .validate(); }); - expectThrows(ElasticsearchException.class, () -> { + assertThat(ex.getMessage(), equalTo(msg)); + ex = expectThrows(ElasticsearchException.class, () -> { Tree.builder() .setRoot(TreeNode.builder(0) .setLeftChild(1) .setSplitFeature(1) .setThreshold(randomDouble())) - .setFeatureNames(Arrays.asList("foo", "bar")) + .setFeatureNames(featureNames) .setTargetType(TargetType.CLASSIFICATION) .build() .validate(); }); + 
assertThat(ex.getMessage(), equalTo(msg)); } private static Map zipObjMap(List keys, List values) { From 2925524cae0cb2256752cb3b1291ae48f1e32e83 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 1 Oct 2019 11:44:31 -0400 Subject: [PATCH 5/5] fixing test --- .../java/org/elasticsearch/client/RestHighLevelClientTests.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index 8d9542e001b9d..e632b1f8165ab 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -686,7 +686,7 @@ public void testDefaultNamedXContents() { public void testProvidedNamedXContents() { List namedXContents = RestHighLevelClient.getProvidedNamedXContents(); - assertEquals(44, namedXContents.size()); + assertEquals(47, namedXContents.size()); Map, Integer> categories = new HashMap<>(); List names = new ArrayList<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) {