From 71a0500d27e4dd3a87fc98e996869d491aecfcae Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Tue, 10 Dec 2019 10:22:53 +0200 Subject: [PATCH] [ML] Introduce randomize_seed setting for regression and classification (#49990) This adds a new `randomize_seed` for regression and classification. When not explicitly set, the seed is randomly generated. One can reuse the seed in a similar job in order to ensure the same docs are picked for training. --- .../client/ml/dataframe/Classification.java | 27 +++++- .../client/ml/dataframe/Regression.java | 29 +++++-- .../client/MachineLearningIT.java | 2 + .../MlClientDocumentationIT.java | 4 +- .../ml/dataframe/ClassificationTests.java | 1 + .../ml/put-data-frame-analytics.asciidoc | 4 +- .../apis/dfanalyticsresources.asciidoc | 4 + .../apis/put-dfanalytics.asciidoc | 4 +- docs/reference/ml/ml-shared.asciidoc | 9 ++ .../dataframe/DataFrameAnalyticsConfig.java | 3 +- .../dataframe/analyses/BoostedTreeParams.java | 4 +- .../ml/dataframe/analyses/Classification.java | 41 +++++++-- .../ml/dataframe/analyses/Regression.java | 41 +++++++-- .../DataFrameAnalyticsConfigTests.java | 47 ++++++++++- .../analyses/ClassificationTests.java | 84 +++++++++++++++---- .../dataframe/analyses/RegressionTests.java | 71 ++++++++++++++-- .../ml/integration/ClassificationIT.java | 50 +++++++++-- ...NativeDataFrameAnalyticsIntegTestCase.java | 22 +++++ .../xpack/ml/integration/RegressionIT.java | 41 ++++++++- .../TransportPutDataFrameAnalyticsAction.java | 12 +-- .../CustomProcessorFactory.java | 4 +- .../DatasetSplittingCustomProcessor.java | 6 +- .../DatasetSplittingCustomProcessorTests.java | 10 ++- .../test/ml/data_frame_analytics_crud.yml | 16 ++-- 24 files changed, 460 insertions(+), 76 deletions(-) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java index d4e7bce5ec442..9d384e6d86786 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java @@ -49,6 +49,7 @@ public static Builder builder(String dependentVariable) { static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name"); static final ParseField TRAINING_PERCENT = new ParseField("training_percent"); static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); + static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed"); private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( @@ -63,7 +64,8 @@ public static Builder builder(String dependentVariable) { (Double) a[5], (String) a[6], (Double) a[7], - (Integer) a[8])); + (Integer) a[8], + (Long) a[9])); static { PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE); @@ -75,6 +77,7 @@ public static Builder builder(String dependentVariable) { PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME); PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT); PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES); + PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED); } private final String dependentVariable; @@ -86,10 +89,11 @@ public static Builder builder(String dependentVariable) { private final String predictionFieldName; private final Double trainingPercent; private final Integer numTopClasses; + private final Long randomizeSeed; private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta, @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName, - @Nullable Double trainingPercent, @Nullable Integer numTopClasses) { + @Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed) { this.dependentVariable = Objects.requireNonNull(dependentVariable); this.lambda = lambda; this.gamma = gamma; @@ -99,6 +103,7 @@ private Classification(String dependentVariable, @Nullable Double lambda, @Nulla this.predictionFieldName = predictionFieldName; this.trainingPercent = trainingPercent; this.numTopClasses = numTopClasses; + this.randomizeSeed = randomizeSeed; } @Override @@ -138,6 +143,10 @@ public Double getTrainingPercent() { return trainingPercent; } + public Long getRandomizeSeed() { + return randomizeSeed; + } + public Integer getNumTopClasses() { return numTopClasses; } @@ -167,6 +176,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (trainingPercent != null) { builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent); } + if (randomizeSeed != null) { + builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed); + } if (numTopClasses != null) { builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses); } @@ -177,7 +189,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws @Override public int hashCode() { return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName, - trainingPercent, numTopClasses); + trainingPercent, randomizeSeed, numTopClasses); } @Override @@ -193,6 +205,7 @@ public boolean equals(Object o) { && Objects.equals(featureBagFraction, that.featureBagFraction) && Objects.equals(predictionFieldName, that.predictionFieldName) && Objects.equals(trainingPercent, that.trainingPercent) + && Objects.equals(randomizeSeed, that.randomizeSeed) && Objects.equals(numTopClasses, that.numTopClasses); } @@ -211,6 +224,7 @@ public static class Builder { private String predictionFieldName; private Double trainingPercent; private Integer numTopClasses; + private Long randomizeSeed; private Builder(String dependentVariable) { this.dependentVariable = Objects.requireNonNull(dependentVariable); @@ -251,6 +265,11 @@ public Builder setTrainingPercent(Double trainingPercent) { return this; } + public Builder setRandomizeSeed(Long randomizeSeed) { + this.randomizeSeed = randomizeSeed; + return this; + } + public Builder setNumTopClasses(Integer numTopClasses) { this.numTopClasses = numTopClasses; return this; @@ -258,7 +277,7 @@ public Builder setNumTopClasses(Integer numTopClasses) { public Classification build() { return new Classification(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName, - trainingPercent, numTopClasses); + trainingPercent, numTopClasses, randomizeSeed); } } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java index 3c1edece6fc16..fa55ee40b27fb 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java @@ -48,6 +48,7 @@ public static Builder builder(String dependentVariable) { static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction"); static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name"); static final ParseField TRAINING_PERCENT = new ParseField("training_percent"); + static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed"); private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( @@ -61,7 +62,8 @@ public static Builder builder(String dependentVariable) { (Integer) a[4], (Double) a[5], (String) a[6], - (Double) a[7])); + (Double) a[7], + (Long) a[8])); static { PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE); @@ -72,6 +74,7 @@ public static Builder builder(String dependentVariable) { PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION); PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME); PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT); + PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED); } private final String dependentVariable; @@ -82,10 +85,11 @@ public static Builder builder(String dependentVariable) { private final Double featureBagFraction; private final String predictionFieldName; private final Double trainingPercent; + private final Long randomizeSeed; private Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta, @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName, - @Nullable Double trainingPercent) { + @Nullable Double trainingPercent, @Nullable Long randomizeSeed) { this.dependentVariable = Objects.requireNonNull(dependentVariable); this.lambda = lambda; this.gamma = gamma; @@ -94,6 +98,7 @@ private Regression(String dependentVariable, @Nullable Double lambda, @Nullable this.featureBagFraction = featureBagFraction; this.predictionFieldName = predictionFieldName; this.trainingPercent = trainingPercent; + this.randomizeSeed = randomizeSeed; } @Override @@ -133,6 +138,10 @@ public Double getTrainingPercent() { return trainingPercent; } + public Long getRandomizeSeed() { + return randomizeSeed; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); @@ -158,6 +167,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (trainingPercent != null) { builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent); } + if (randomizeSeed != null) { + builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed); + } builder.endObject(); return builder; } @@ -165,7 +177,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws @Override public int hashCode() { return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName, - trainingPercent); + trainingPercent, randomizeSeed); } @Override @@ -180,7 +192,8 @@ public boolean equals(Object o) { && Objects.equals(maximumNumberTrees, that.maximumNumberTrees) && Objects.equals(featureBagFraction, that.featureBagFraction) && Objects.equals(predictionFieldName, that.predictionFieldName) - && Objects.equals(trainingPercent, that.trainingPercent); + && Objects.equals(trainingPercent, that.trainingPercent) + && Objects.equals(randomizeSeed, that.randomizeSeed); } @Override @@ -197,6 +210,7 @@ public static class Builder { private Double featureBagFraction; private String predictionFieldName; private Double trainingPercent; + private Long randomizeSeed; private Builder(String dependentVariable) { this.dependentVariable = Objects.requireNonNull(dependentVariable); @@ -237,9 +251,14 @@ public Builder setTrainingPercent(Double trainingPercent) { return this; } + public Builder setRandomizeSeed(Long randomizeSeed) { + this.randomizeSeed = randomizeSeed; + return this; + } + public Regression build() { return new Regression(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName, - trainingPercent); + trainingPercent, randomizeSeed); } } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 6ed3734831aa2..29e69c5095cbd 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -1291,6 +1291,7 @@ public void testPutDataFrameAnalyticsConfig_GivenRegression() throws Exception { .setAnalysis(org.elasticsearch.client.ml.dataframe.Regression.builder("my_dependent_variable") .setPredictionFieldName("my_dependent_variable_prediction") .setTrainingPercent(80.0) + .setRandomizeSeed(42L) .build()) .setDescription("this is a regression") .build(); @@ -1326,6 +1327,7 @@ public void testPutDataFrameAnalyticsConfig_GivenClassification() throws Excepti .setAnalysis(org.elasticsearch.client.ml.dataframe.Classification.builder("my_dependent_variable") .setPredictionFieldName("my_dependent_variable_prediction") .setTrainingPercent(80.0) + .setRandomizeSeed(42L) .setNumTopClasses(1) .build()) .setDescription("this is a classification") diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index 1d9a151cf8ae3..13185e221633b 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -2975,7 +2975,8 @@ public void testPutDataFrameAnalytics() throws Exception { .setFeatureBagFraction(0.4) // <6> .setPredictionFieldName("my_prediction_field_name") // <7> .setTrainingPercent(50.0) // <8> - .setNumTopClasses(1) // <9> + .setRandomizeSeed(1234L) // <9> + .setNumTopClasses(1) // <10> .build(); // end::put-data-frame-analytics-classification @@ -2988,6 +2989,7 @@ public void testPutDataFrameAnalytics() throws Exception { .setFeatureBagFraction(0.4) // <6> .setPredictionFieldName("my_prediction_field_name") // <7> .setTrainingPercent(50.0) // <8> + .setRandomizeSeed(1234L) // <9> .build(); // end::put-data-frame-analytics-regression diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java index 98f060cc8534a..5ef8fdaef5a27 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java @@ -34,6 +34,7 @@ public static Classification randomClassification() { .setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false)) .setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10)) .setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true)) + .setRandomizeSeed(randomBoolean() ? null : randomLong()) .setNumTopClasses(randomBoolean() ? null : randomIntBetween(0, 10)) .build(); } diff --git a/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc b/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc index 91a97ad604cee..2152eff5c0850 100644 --- a/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc +++ b/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc @@ -119,7 +119,8 @@ include-tagged::{doc-tests-file}[{api}-classification] <6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1]. <7> The name of the prediction field in the results object. <8> The percentage of training-eligible rows to be used in training. Defaults to 100%. -<9> The number of top classes to be reported in the results. Defaults to 2. +<9> The seed to be used by the random generator that picks which rows are used in training. +<10> The number of top classes to be reported in the results. Defaults to 2. ===== Regression @@ -138,6 +139,7 @@ include-tagged::{doc-tests-file}[{api}-regression] <6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1]. <7> The name of the prediction field in the results object. <8> The percentage of training-eligible rows to be used in training. Defaults to 100%. +<9> The seed to be used by the random generator that picks which rows are used in training. ==== Analyzed fields diff --git a/docs/reference/ml/df-analytics/apis/dfanalyticsresources.asciidoc b/docs/reference/ml/df-analytics/apis/dfanalyticsresources.asciidoc index e8ee463c66af7..111953b8321ab 100644 --- a/docs/reference/ml/df-analytics/apis/dfanalyticsresources.asciidoc +++ b/docs/reference/ml/df-analytics/apis/dfanalyticsresources.asciidoc @@ -204,6 +204,8 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=prediction_field_name] include::{docdir}/ml/ml-shared.asciidoc[tag=training_percent] +include::{docdir}/ml/ml-shared.asciidoc[tag=randomize_seed] + [float] [[regression-resources-advanced]] @@ -252,6 +254,8 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=prediction_field_name] include::{docdir}/ml/ml-shared.asciidoc[tag=training_percent] +include::{docdir}/ml/ml-shared.asciidoc[tag=randomize_seed] + [float] [[classification-resources-advanced]] diff --git a/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc b/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc index 5b0987e41c4bc..123eb6633e37b 100644 --- a/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc +++ b/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc @@ -397,7 +397,8 @@ PUT _ml/data_frame/analytics/student_performance_mathematics_0.3 { "regression": { "dependent_variable": "G3", - "training_percent": 70 <1> + "training_percent": 70, <1> + "randomize_seed": 19673948271 <2> } } } @@ -406,6 +407,7 @@ PUT _ml/data_frame/analytics/student_performance_mathematics_0.3 <1> The `training_percent` defines the percentage of the data set that will be used for training the model. +<2> The `randomize_seed` is the seed used to randomly pick which data is used for training. [[ml-put-dfanalytics-example-c]] diff --git a/docs/reference/ml/ml-shared.asciidoc b/docs/reference/ml/ml-shared.asciidoc index 11e062796afa6..bea970078d06b 100644 --- a/docs/reference/ml/ml-shared.asciidoc +++ b/docs/reference/ml/ml-shared.asciidoc @@ -681,6 +681,15 @@ those that contain arrays) won’t be included in the calculation for used percentage. Defaults to `100`. end::training_percent[] +tag::randomize_seed[] +`randomize_seed`:: +(Optional, long) Defines the seed to the random generator that is used to pick +which documents will be used for training. By default it is randomly generated. +Set it to a specific value to ensure the same documents are used for training +assuming other related parameters (e.g. `source`, `analyzed_fields`, etc.) are the same. +end::randomize_seed[] + + tag::use-null[] Defines whether a new series is used as the null series when there is no value for the by or partition fields. The default value is `false`. diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java index 9fd7f8aa86fcb..1142b5411fb0c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java @@ -225,7 +225,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(DEST.getPreferredName(), dest); builder.startObject(ANALYSIS.getPreferredName()); - builder.field(analysis.getWriteableName(), analysis); + builder.field(analysis.getWriteableName(), analysis, + new MapParams(Collections.singletonMap(VERSION.getPreferredName(), version == null ? null : version.toString()))); builder.endObject(); if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParams.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParams.java index ed3cff7d73c0c..0f06b08444f53 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParams.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParams.java @@ -49,7 +49,7 @@ static void declareFields(AbstractObjectParser parser) { private final Integer maximumNumberTrees; private final Double featureBagFraction; - BoostedTreeParams(@Nullable Double lambda, + public BoostedTreeParams(@Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta, @Nullable Integer maximumNumberTrees, @@ -76,7 +76,7 @@ static void declareFields(AbstractObjectParser parser) { this.featureBagFraction = featureBagFraction; } - BoostedTreeParams() { + public BoostedTreeParams() { this(null, null, null, null, null); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index b4b258ea161fa..cd96b815fc11e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -5,8 +5,10 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.analyses; +import org.elasticsearch.Version; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Randomness; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ConstructingObjectParser; @@ -35,6 +37,7 @@ public class Classification implements DataFrameAnalysis { public static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name"); public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); public static final ParseField TRAINING_PERCENT = new ParseField("training_percent"); + public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed"); private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); private static final ConstructingObjectParser STRICT_PARSER = createParser(false); @@ -48,12 +51,14 @@ private static ConstructingObjectParser createParser(boole new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5]), (String) a[6], (Integer) a[7], - (Double) a[8])); + (Double) a[8], + (Long) a[9])); parser.declareString(constructorArg(), DEPENDENT_VARIABLE); BoostedTreeParams.declareFields(parser); parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME); parser.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES); parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT); + parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED); return parser; } @@ -82,12 +87,14 @@ public static Classification fromXContent(XContentParser parser, boolean ignoreU private final String predictionFieldName; private final int numTopClasses; private final double trainingPercent; + private final long randomizeSeed; public Classification(String dependentVariable, BoostedTreeParams boostedTreeParams, @Nullable String predictionFieldName, @Nullable Integer numTopClasses, - @Nullable Double trainingPercent) { + @Nullable Double trainingPercent, + @Nullable Long randomizeSeed) { if (numTopClasses != null && (numTopClasses < 0 || numTopClasses > 1000)) { throw ExceptionsHelper.badRequestException("[{}] must be an integer in [0, 1000]", NUM_TOP_CLASSES.getPreferredName()); } @@ -99,10 +106,11 @@ public Classification(String dependentVariable, this.predictionFieldName = predictionFieldName == null ? dependentVariable + "_prediction" : predictionFieldName; this.numTopClasses = numTopClasses == null ? DEFAULT_NUM_TOP_CLASSES : numTopClasses; this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent; + this.randomizeSeed = randomizeSeed == null ? Randomness.get().nextLong() : randomizeSeed; } public Classification(String dependentVariable) { - this(dependentVariable, new BoostedTreeParams(), null, null, null); + this(dependentVariable, new BoostedTreeParams(), null, null, null, null); } public Classification(StreamInput in) throws IOException { @@ -111,12 +119,21 @@ public Classification(StreamInput in) throws IOException { predictionFieldName = in.readOptionalString(); numTopClasses = in.readOptionalVInt(); trainingPercent = in.readDouble(); + if (in.getVersion().onOrAfter(Version.CURRENT)) { + randomizeSeed = in.readOptionalLong(); + } else { + randomizeSeed = Randomness.get().nextLong(); + } } public String getDependentVariable() { return dependentVariable; } + public BoostedTreeParams getBoostedTreeParams() { + return boostedTreeParams; + } + public String getPredictionFieldName() { return predictionFieldName; } @@ -129,6 +146,11 @@ public double getTrainingPercent() { return trainingPercent; } + @Nullable + public Long getRandomizeSeed() { + return randomizeSeed; + } + @Override public String getWriteableName() { return NAME.getPreferredName(); @@ -141,10 +163,15 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(predictionFieldName); out.writeOptionalVInt(numTopClasses); out.writeDouble(trainingPercent); + if (out.getVersion().onOrAfter(Version.CURRENT)) { + out.writeOptionalLong(randomizeSeed); + } } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + Version version = Version.fromString(params.param("version", Version.CURRENT.toString())); + builder.startObject(); builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable); boostedTreeParams.toXContent(builder, params); @@ -153,6 +180,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); } builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent); + if (version.onOrAfter(Version.CURRENT)) { + builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed); + } builder.endObject(); return builder; } @@ -238,11 +268,12 @@ public boolean equals(Object o) { && Objects.equals(boostedTreeParams, that.boostedTreeParams) && Objects.equals(predictionFieldName, that.predictionFieldName) && Objects.equals(numTopClasses, that.numTopClasses) - && trainingPercent == that.trainingPercent; + && trainingPercent == that.trainingPercent + && randomizeSeed == that.randomizeSeed; } @Override public int hashCode() { - return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, numTopClasses, trainingPercent); + return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, numTopClasses, trainingPercent, randomizeSeed); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java index 01388f01d807c..dd8f6a91272c2 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java @@ -5,8 +5,10 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.analyses; +import org.elasticsearch.Version; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Randomness; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ConstructingObjectParser; @@ -32,6 +34,7 @@ public class Regression implements DataFrameAnalysis { public static final ParseField DEPENDENT_VARIABLE = new ParseField("dependent_variable"); public static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name"); public static final ParseField TRAINING_PERCENT = new ParseField("training_percent"); + public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed"); private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); private static final ConstructingObjectParser STRICT_PARSER = createParser(false); @@ -44,11 +47,13 @@ private static ConstructingObjectParser createParser(boolean l (String) a[0], new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5]), (String) a[6], - (Double) a[7])); + (Double) a[7], + (Long) a[8])); parser.declareString(constructorArg(), DEPENDENT_VARIABLE); BoostedTreeParams.declareFields(parser); parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME); parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT); + parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED); return parser; } @@ -60,11 +65,13 @@ public static Regression fromXContent(XContentParser parser, boolean ignoreUnkno private final BoostedTreeParams boostedTreeParams; private final String predictionFieldName; private final double trainingPercent; + private final long randomizeSeed; public Regression(String dependentVariable, BoostedTreeParams boostedTreeParams, @Nullable String predictionFieldName, - @Nullable Double trainingPercent) { + @Nullable Double trainingPercent, + @Nullable Long randomizeSeed) { if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0)) { throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName()); } @@ -72,10 +79,11 @@ public Regression(String dependentVariable, this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME); this.predictionFieldName = predictionFieldName == null ? dependentVariable + "_prediction" : predictionFieldName; this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent; + this.randomizeSeed = randomizeSeed == null ? Randomness.get().nextLong() : randomizeSeed; } public Regression(String dependentVariable) { - this(dependentVariable, new BoostedTreeParams(), null, null); + this(dependentVariable, new BoostedTreeParams(), null, null, null); } public Regression(StreamInput in) throws IOException { @@ -83,12 +91,21 @@ public Regression(StreamInput in) throws IOException { boostedTreeParams = new BoostedTreeParams(in); predictionFieldName = in.readOptionalString(); trainingPercent = in.readDouble(); + if (in.getVersion().onOrAfter(Version.CURRENT)) { + randomizeSeed = in.readOptionalLong(); + } else { + randomizeSeed = Randomness.get().nextLong(); + } } public String getDependentVariable() { return dependentVariable; } + public BoostedTreeParams getBoostedTreeParams() { + return boostedTreeParams; + } + public String getPredictionFieldName() { return predictionFieldName; } @@ -97,6 +114,11 @@ public double getTrainingPercent() { return trainingPercent; } + @Nullable + public Long getRandomizeSeed() { + return randomizeSeed; + } + @Override public String getWriteableName() { return NAME.getPreferredName(); @@ -108,10 +130,15 @@ public void writeTo(StreamOutput out) throws IOException { boostedTreeParams.writeTo(out); out.writeOptionalString(predictionFieldName); out.writeDouble(trainingPercent); + if (out.getVersion().onOrAfter(Version.CURRENT)) { + out.writeOptionalLong(randomizeSeed); + } } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + Version version = Version.fromString(params.param("version", Version.CURRENT.toString())); + builder.startObject(); builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable); boostedTreeParams.toXContent(builder, params); @@ -119,6 +146,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); } builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent); + if (version.onOrAfter(Version.CURRENT)) { + builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed); + } builder.endObject(); return builder; } @@ -177,11 +207,12 @@ public boolean equals(Object o) { return Objects.equals(dependentVariable, that.dependentVariable) && Objects.equals(boostedTreeParams, that.boostedTreeParams) && Objects.equals(predictionFieldName, that.predictionFieldName) - && trainingPercent == that.trainingPercent; + && trainingPercent == that.trainingPercent + && randomizeSeed == randomizeSeed; } @Override public int hashCode() { - return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, trainingPercent); + return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java index d6b2c077388e3..880bea8884658 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.Version; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; @@ -20,17 +21,20 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.XContentParseException; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.common.xcontent.json.JsonXContent; import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import org.junit.Before; @@ -42,10 +46,13 @@ import java.util.List; import java.util.Map; -import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasEntry; import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.startsWith; public class DataFrameAnalyticsConfigTests extends AbstractSerializingTestCase { @@ -339,6 +346,44 @@ public void testPreventVersionInjection() throws IOException { } } + public void testToXContent_GivenAnalysisWithRandomizeSeedAndVersionIsCurrent() throws IOException { + Regression regression = new Regression("foo"); + assertThat(regression.getRandomizeSeed(), is(notNullValue())); + + DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder() + .setVersion(Version.CURRENT) + .setId("test_config") + .setSource(new DataFrameAnalyticsSource(new String[] {"source_index"}, null, null)) + .setDest(new DataFrameAnalyticsDest("dest_index", null)) + .setAnalysis(regression) + .build(); + + try (XContentBuilder builder = JsonXContent.contentBuilder()) { + config.toXContent(builder, ToXContent.EMPTY_PARAMS); + String json = Strings.toString(builder); + assertThat(json, containsString("randomize_seed")); + } + } + + public void testToXContent_GivenAnalysisWithRandomizeSeedAndVersionIsBeforeItWasIntroduced() throws IOException { + Regression regression = new Regression("foo"); + assertThat(regression.getRandomizeSeed(), is(notNullValue())); + + DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder() + .setVersion(Version.V_7_5_0) + .setId("test_config") + .setSource(new DataFrameAnalyticsSource(new String[] {"source_index"}, null, null)) + .setDest(new DataFrameAnalyticsDest("dest_index", null)) + .setAnalysis(regression) + .build(); + + try (XContentBuilder builder = JsonXContent.contentBuilder()) { + config.toXContent(builder, ToXContent.EMPTY_PARAMS); + String json = Strings.toString(builder); + assertThat(json, not(containsString("randomize_seed"))); + } + } + private static void assertTooSmall(ElasticsearchStatusException e) { assertThat(e.getMessage(), startsWith("model_memory_limit must be at least 1kb.")); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java index 61d6b4dfe3f7a..8308ef8dad289 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java @@ -6,20 +6,28 @@ package org.elasticsearch.xpack.core.ml.dataframe.analyses; import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.Version; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.json.JsonXContent; import org.elasticsearch.index.mapper.BooleanFieldMapper; import org.elasticsearch.index.mapper.KeywordFieldMapper; import org.elasticsearch.index.mapper.NumberFieldMapper; import org.elasticsearch.test.AbstractSerializingTestCase; import java.io.IOException; +import java.util.Collections; import java.util.Map; import java.util.Set; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.nullValue; public class ClassificationTests extends AbstractSerializingTestCase { @@ -42,7 +50,9 @@ public static Classification createRandom() { String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10); Integer numTopClasses = randomBoolean() ? null : randomIntBetween(0, 1000); Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true); - return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, numTopClasses, trainingPercent); + Long randomizeSeed = randomBoolean() ? null : randomLong(); + return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, numTopClasses, trainingPercent, + randomizeSeed); } @Override @@ -52,71 +62,71 @@ protected Writeable.Reader instanceReader() { public void testConstructor_GivenTrainingPercentIsLessThanOne() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 0.999)); + () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 0.999, randomLong())); assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]")); } public void testConstructor_GivenTrainingPercentIsGreaterThan100() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0001)); + () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0001, randomLong())); assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]")); } public void testConstructor_GivenNumTopClassesIsLessThanZero() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", -1, 1.0)); + () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", -1, 1.0, randomLong())); assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]")); } public void testConstructor_GivenNumTopClassesIsGreaterThan1000() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1001, 1.0)); + () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1001, 1.0, randomLong())); assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]")); } public void testGetPredictionFieldName() { - Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0); + Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0, randomLong()); assertThat(classification.getPredictionFieldName(), equalTo("result")); - classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, 3, 50.0); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, 3, 50.0, randomLong()); assertThat(classification.getPredictionFieldName(), equalTo("foo_prediction")); } public void testGetNumTopClasses() { - Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 7, 1.0); + Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 7, 1.0, randomLong()); assertThat(classification.getNumTopClasses(), equalTo(7)); // Boundary condition: num_top_classes == 0 - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 0, 1.0); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 0, 1.0, randomLong()); assertThat(classification.getNumTopClasses(), equalTo(0)); // Boundary condition: num_top_classes == 1000 - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1000, 1.0); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1000, 1.0, randomLong()); assertThat(classification.getNumTopClasses(), equalTo(1000)); // num_top_classes == null, default applied - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1.0); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1.0, randomLong()); assertThat(classification.getNumTopClasses(), equalTo(2)); } public void testGetTrainingPercent() { - Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0); + Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0, randomLong()); assertThat(classification.getTrainingPercent(), equalTo(50.0)); // Boundary condition: training_percent == 1.0 - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 1.0); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 1.0, randomLong()); assertThat(classification.getTrainingPercent(), equalTo(1.0)); // Boundary condition: training_percent == 100.0 - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0, randomLong()); assertThat(classification.getTrainingPercent(), equalTo(100.0)); // training_percent == null, default applied - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, null); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, null, randomLong()); assertThat(classification.getTrainingPercent(), equalTo(100.0)); } @@ -155,4 +165,48 @@ public void testGetParams() { public void testFieldCardinalityLimitsIsNonNull() { assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue()))); } + + public void testToXContent_GivenVersionBeforeRandomizeSeedWasIntroduced() throws IOException { + Classification classification = createRandom(); + assertThat(classification.getRandomizeSeed(), is(notNullValue())); + + try (XContentBuilder builder = JsonXContent.contentBuilder()) { + classification.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap("version", "7.5.0"))); + String json = Strings.toString(builder); + assertThat(json, not(containsString("randomize_seed"))); + } + } + + public void testToXContent_GivenVersionAfterRandomizeSeedWasIntroduced() throws IOException { + Classification classification = createRandom(); + assertThat(classification.getRandomizeSeed(), is(notNullValue())); + + try (XContentBuilder builder = JsonXContent.contentBuilder()) { + classification.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap("version", Version.CURRENT.toString()))); + String json = Strings.toString(builder); + assertThat(json, containsString("randomize_seed")); + } + } + + public void testToXContent_GivenVersionIsNull() throws IOException { + Classification classification = createRandom(); + assertThat(classification.getRandomizeSeed(), is(notNullValue())); + + try (XContentBuilder builder = JsonXContent.contentBuilder()) { + classification.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap("version", null))); + String json = Strings.toString(builder); + assertThat(json, containsString("randomize_seed")); + } + } + + public void testToXContent_GivenEmptyParams() throws IOException { + Classification classification = createRandom(); + assertThat(classification.getRandomizeSeed(), is(notNullValue())); + + try (XContentBuilder builder = JsonXContent.contentBuilder()) { + classification.toXContent(builder, ToXContent.EMPTY_PARAMS); + String json = Strings.toString(builder); + assertThat(json, containsString("randomize_seed")); + } + } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java index f3d5312280e88..58e19f6ef6a2a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java @@ -6,16 +6,24 @@ package org.elasticsearch.xpack.core.ml.dataframe.analyses; import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.Version; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.json.JsonXContent; import org.elasticsearch.test.AbstractSerializingTestCase; import java.io.IOException; +import java.util.Collections; import java.util.Map; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.nullValue; public class RegressionTests extends AbstractSerializingTestCase { @@ -37,7 +45,8 @@ public static Regression createRandom() { BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom(); String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10); Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true); - return new Regression(dependentVariableName, boostedTreeParams, predictionFieldName, trainingPercent); + Long randomizeSeed = randomBoolean() ? null : randomLong(); + return new Regression(dependentVariableName, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed); } @Override @@ -47,40 +56,40 @@ protected Writeable.Reader instanceReader() { public void testConstructor_GivenTrainingPercentIsLessThanOne() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.999)); + () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 0.999, randomLong())); assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]")); } public void testConstructor_GivenTrainingPercentIsGreaterThan100() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001)); + () -> new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0001, randomLong())); assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]")); } public void testGetPredictionFieldName() { - Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0); + Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0, randomLong()); assertThat(regression.getPredictionFieldName(), equalTo("result")); - regression = new Regression("foo", BOOSTED_TREE_PARAMS, null, 50.0); + regression = new Regression("foo", BOOSTED_TREE_PARAMS, null, 50.0, randomLong()); assertThat(regression.getPredictionFieldName(), equalTo("foo_prediction")); } public void testGetTrainingPercent() { - Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0); + Regression regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 50.0, randomLong()); assertThat(regression.getTrainingPercent(), equalTo(50.0)); // Boundary condition: training_percent == 1.0 - regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 1.0); + regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 1.0, randomLong()); assertThat(regression.getTrainingPercent(), equalTo(1.0)); // Boundary condition: training_percent == 100.0 - regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0); + regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", 100.0, randomLong()); assertThat(regression.getTrainingPercent(), equalTo(100.0)); // training_percent == null, default applied - regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", null); + regression = new Regression("foo", BOOSTED_TREE_PARAMS, "result", null, randomLong()); assertThat(regression.getTrainingPercent(), equalTo(100.0)); } @@ -100,4 +109,48 @@ public void testGetStateDocId() { String randomId = randomAlphaOfLength(10); assertThat(regression.getStateDocId(randomId), equalTo(randomId + "_regression_state#1")); } + + public void testToXContent_GivenVersionBeforeRandomizeSeedWasIntroduced() throws IOException { + Regression regression = createRandom(); + assertThat(regression.getRandomizeSeed(), is(notNullValue())); + + try (XContentBuilder builder = JsonXContent.contentBuilder()) { + regression.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap("version", "7.5.0"))); + String json = Strings.toString(builder); + assertThat(json, not(containsString("randomize_seed"))); + } + } + + public void testToXContent_GivenVersionAfterRandomizeSeedWasIntroduced() throws IOException { + Regression regression = createRandom(); + assertThat(regression.getRandomizeSeed(), is(notNullValue())); + + try (XContentBuilder builder = JsonXContent.contentBuilder()) { + regression.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap("version", Version.CURRENT.toString()))); + String json = Strings.toString(builder); + assertThat(json, containsString("randomize_seed")); + } + } + + public void testToXContent_GivenVersionIsNull() throws IOException { + Regression regression = createRandom(); + assertThat(regression.getRandomizeSeed(), is(notNullValue())); + + try (XContentBuilder builder = JsonXContent.contentBuilder()) { + regression.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap("version", null))); + String json = Strings.toString(builder); + assertThat(json, containsString("randomize_seed")); + } + } + + public void testToXContent_GivenEmptyParams() throws IOException { + Regression regression = createRandom(); + assertThat(regression.getRandomizeSeed(), is(notNullValue())); + + try (XContentBuilder builder = JsonXContent.contentBuilder()) { + regression.toXContent(builder, ToXContent.EMPTY_PARAMS); + String json = Strings.toString(builder); + assertThat(json, containsString("randomize_seed")); + } + } } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index f5db9ae690a96..e7c0ccd0e0554 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -20,6 +20,7 @@ import org.elasticsearch.search.SearchHit; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams; import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy; @@ -31,6 +32,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Set; import static java.util.stream.Collectors.toList; import static org.hamcrest.Matchers.allOf; @@ -158,7 +160,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty( sourceIndex, destIndex, null, - new Classification(dependentVariable, BoostedTreeParamsTests.createRandom(), null, numTopClasses, 50.0)); + new Classification(dependentVariable, BoostedTreeParamsTests.createRandom(), null, numTopClasses, 50.0, null)); registerAnalytics(config); putAnalytics(config); @@ -269,6 +271,44 @@ public void testDependentVariableCardinalityTooHighButWithQueryMakesItWithinRang assertProgress(jobId, 100, 100, 100, 100); } + public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet() throws Exception { + String sourceIndex = "classification_two_jobs_with_same_randomize_seed_source"; + String dependentVariable = KEYWORD_FIELD; + indexData(sourceIndex, 10, 0, dependentVariable); + + String firstJobId = "classification_two_jobs_with_same_randomize_seed_1"; + String firstJobDestIndex = firstJobId + "_dest"; + + BoostedTreeParams boostedTreeParams = new BoostedTreeParams(1.0, 1.0, 1.0, 1, 1.0); + + DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null, + new Classification(dependentVariable, boostedTreeParams, null, 1, 50.0, null)); + registerAnalytics(firstJob); + putAnalytics(firstJob); + + String secondJobId = "classification_two_jobs_with_same_randomize_seed_2"; + String secondJobDestIndex = secondJobId + "_dest"; + + long randomizeSeed = ((Classification) firstJob.getAnalysis()).getRandomizeSeed(); + DataFrameAnalyticsConfig secondJob = buildAnalytics(secondJobId, sourceIndex, secondJobDestIndex, null, + new Classification(dependentVariable, boostedTreeParams, null, 1, 50.0, randomizeSeed)); + + registerAnalytics(secondJob); + putAnalytics(secondJob); + + // Let's run both jobs in parallel and wait until they are finished + startAnalytics(firstJobId); + startAnalytics(secondJobId); + waitUntilAnalyticsIsStopped(firstJobId); + waitUntilAnalyticsIsStopped(secondJobId); + + // Now we compare they both used the same training rows + Set firstRunTrainingRowsIds = getTrainingRowsIds(firstJobDestIndex); + Set secondRunTrainingRowsIds = getTrainingRowsIds(secondJobDestIndex); + + assertThat(secondRunTrainingRowsIds, equalTo(firstRunTrainingRowsIds)); + } + private void initialize(String jobId) { this.jobId = jobId; this.sourceIndex = jobId + "_source_index"; @@ -340,10 +380,10 @@ private static Map getMlResultsObjectFromDestDoc(Map void assertTopClasses( - Map resultsObject, - int numTopClasses, - String dependentVariable, - List dependentVariableValues) { + Map resultsObject, + int numTopClasses, + String dependentVariable, + List dependentVariableValues) { assertThat(resultsObject.containsKey("top_classes"), is(true)); List> topClasses = (List>) resultsObject.get("top_classes"); assertThat(topClasses, hasSize(numTopClasses)); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java index 29ef54d3f7524..99223247d7305 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java @@ -18,6 +18,7 @@ import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.xpack.core.ml.action.DeleteDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; @@ -45,7 +46,10 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.HashSet; import java.util.List; +import java.util.Map; +import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; @@ -252,4 +256,22 @@ private static List fetchAllAuditMessages(String dataFrameAnalyticsId) { .map(hit -> (String) hit.getSourceAsMap().get("message")) .collect(Collectors.toList()); } + + protected static Set getTrainingRowsIds(String index) { + Set trainingRowsIds = new HashSet<>(); + SearchResponse hits = client().prepareSearch(index).get(); + for (SearchHit hit : hits.getHits()) { + Map sourceAsMap = hit.getSourceAsMap(); + assertThat(sourceAsMap.containsKey("ml"), is(true)); + @SuppressWarnings("unchecked") + Map resultsObject = (Map) sourceAsMap.get("ml"); + + assertThat(resultsObject.containsKey("is_training"), is(true)); + if (Boolean.TRUE.equals(resultsObject.get("is_training"))) { + trainingRowsIds.add(hit.getId()); + } + } + assertThat(trainingRowsIds.isEmpty(), is(false)); + return trainingRowsIds; + } } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java index 71ea840c53ea8..84d408daacc61 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java @@ -16,6 +16,7 @@ import org.elasticsearch.search.SearchHit; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams; import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; @@ -25,6 +26,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Set; import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.equalTo; @@ -139,7 +141,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception sourceIndex, destIndex, null, - new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, 50.0)); + new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, 50.0, null)); registerAnalytics(config); putAnalytics(config); @@ -235,6 +237,43 @@ public void testStopAndRestart() throws Exception { assertInferenceModelPersisted(jobId); } + public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet() throws Exception { + String sourceIndex = "regression_two_jobs_with_same_randomize_seed_source"; + indexData(sourceIndex, 10, 0); + + String firstJobId = "regression_two_jobs_with_same_randomize_seed_1"; + String firstJobDestIndex = firstJobId + "_dest"; + + BoostedTreeParams boostedTreeParams = new BoostedTreeParams(1.0, 1.0, 1.0, 1, 1.0); + + DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null, + new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, null)); + registerAnalytics(firstJob); + putAnalytics(firstJob); + + String secondJobId = "regression_two_jobs_with_same_randomize_seed_2"; + String secondJobDestIndex = secondJobId + "_dest"; + + long randomizeSeed = ((Regression) firstJob.getAnalysis()).getRandomizeSeed(); + DataFrameAnalyticsConfig secondJob = buildAnalytics(secondJobId, sourceIndex, secondJobDestIndex, null, + new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, randomizeSeed)); + + registerAnalytics(secondJob); + putAnalytics(secondJob); + + // Let's run both jobs in parallel and wait until they are finished + startAnalytics(firstJobId); + startAnalytics(secondJobId); + waitUntilAnalyticsIsStopped(firstJobId); + waitUntilAnalyticsIsStopped(secondJobId); + + // Now we compare they both used the same training rows + Set firstRunTrainingRowsIds = getTrainingRowsIds(firstJobDestIndex); + Set secondRunTrainingRowsIds = getTrainingRowsIds(secondJobDestIndex); + + assertThat(secondRunTrainingRowsIds, equalTo(firstRunTrainingRowsIds)); + } + private void initialize(String jobId) { this.jobId = jobId; this.sourceIndex = jobId + "_source_index"; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutDataFrameAnalyticsAction.java index 2884cd331779e..1cbed7ed76613 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutDataFrameAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutDataFrameAnalyticsAction.java @@ -111,7 +111,7 @@ protected ClusterBlockException checkBlock(PutDataFrameAnalyticsAction.Request r protected void masterOperation(Task task, PutDataFrameAnalyticsAction.Request request, ClusterState state, ActionListener listener) { validateConfig(request.getConfig()); - DataFrameAnalyticsConfig memoryCappedConfig = + DataFrameAnalyticsConfig preparedForPutConfig = new DataFrameAnalyticsConfig.Builder(request.getConfig(), maxModelMemoryLimit) .setCreateTime(Instant.now()) .setVersion(Version.CURRENT) @@ -120,11 +120,11 @@ protected void masterOperation(Task task, PutDataFrameAnalyticsAction.Request re if (licenseState.isAuthAllowed()) { final String username = securityContext.getUser().principal(); RoleDescriptor.IndicesPrivileges sourceIndexPrivileges = RoleDescriptor.IndicesPrivileges.builder() - .indices(memoryCappedConfig.getSource().getIndex()) + .indices(preparedForPutConfig.getSource().getIndex()) .privileges("read") .build(); RoleDescriptor.IndicesPrivileges destIndexPrivileges = RoleDescriptor.IndicesPrivileges.builder() - .indices(memoryCappedConfig.getDest().getIndex()) + .indices(preparedForPutConfig.getDest().getIndex()) .privileges("read", "index", "create_index") .build(); @@ -135,16 +135,16 @@ protected void masterOperation(Task task, PutDataFrameAnalyticsAction.Request re privRequest.indexPrivileges(sourceIndexPrivileges, destIndexPrivileges); ActionListener privResponseListener = ActionListener.wrap( - r -> handlePrivsResponse(username, memoryCappedConfig, r, listener), + r -> handlePrivsResponse(username, preparedForPutConfig, r, listener), listener::onFailure); client.execute(HasPrivilegesAction.INSTANCE, privRequest, privResponseListener); } else { updateDocMappingAndPutConfig( - memoryCappedConfig, + preparedForPutConfig, threadPool.getThreadContext().getHeaders(), ActionListener.wrap( - indexResponse -> listener.onResponse(new PutDataFrameAnalyticsAction.Response(memoryCappedConfig)), + indexResponse -> listener.onResponse(new PutDataFrameAnalyticsAction.Response(preparedForPutConfig)), listener::onFailure )); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/CustomProcessorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/CustomProcessorFactory.java index fd52a3fd8da58..77f0b127a2638 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/CustomProcessorFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/CustomProcessorFactory.java @@ -24,12 +24,12 @@ public CustomProcessor create(DataFrameAnalysis analysis) { if (analysis instanceof Regression) { Regression regression = (Regression) analysis; return new DatasetSplittingCustomProcessor( - fieldNames, regression.getDependentVariable(), regression.getTrainingPercent()); + fieldNames, regression.getDependentVariable(), regression.getTrainingPercent(), regression.getRandomizeSeed()); } if (analysis instanceof Classification) { Classification classification = (Classification) analysis; return new DatasetSplittingCustomProcessor( - fieldNames, classification.getDependentVariable(), classification.getTrainingPercent()); + fieldNames, classification.getDependentVariable(), classification.getTrainingPercent(), classification.getRandomizeSeed()); } return row -> {}; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/DatasetSplittingCustomProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/DatasetSplittingCustomProcessor.java index ed42cf5198854..bf6284aa7a5c8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/DatasetSplittingCustomProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/DatasetSplittingCustomProcessor.java @@ -5,7 +5,6 @@ */ package org.elasticsearch.xpack.ml.dataframe.process.customprocessing; -import org.elasticsearch.common.Randomness; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.util.List; @@ -23,12 +22,13 @@ class DatasetSplittingCustomProcessor implements CustomProcessor { private final int dependentVariableIndex; private final double trainingPercent; - private final Random random = Randomness.get(); + private final Random random; private boolean isFirstRow = true; - DatasetSplittingCustomProcessor(List fieldNames, String dependentVariable, double trainingPercent) { + DatasetSplittingCustomProcessor(List fieldNames, String dependentVariable, double trainingPercent, long randomizeSeed) { this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable); this.trainingPercent = trainingPercent; + this.random = new Random(randomizeSeed); } private static int findDependentVariableIndex(List fieldNames, String dependentVariable) { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/DatasetSplittingCustomProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/DatasetSplittingCustomProcessorTests.java index d5973f8782461..d18adc3dcdb48 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/DatasetSplittingCustomProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/DatasetSplittingCustomProcessorTests.java @@ -24,6 +24,7 @@ public class DatasetSplittingCustomProcessorTests extends ESTestCase { private List fields; private int dependentVariableIndex; private String dependentVariable; + private long randomizeSeed; @Before public void setUpTests() { @@ -34,10 +35,11 @@ public void setUpTests() { } dependentVariableIndex = randomIntBetween(0, fieldCount - 1); dependentVariable = fields.get(dependentVariableIndex); + randomizeSeed = randomLong(); } public void testProcess_GivenRowsWithoutDependentVariableValue() { - CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 50.0); + CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 50.0, randomizeSeed); for (int i = 0; i < 100; i++) { String[] row = new String[fields.size()]; @@ -55,7 +57,7 @@ public void testProcess_GivenRowsWithoutDependentVariableValue() { } public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsHundred() { - CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 100.0); + CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 100.0, randomizeSeed); for (int i = 0; i < 100; i++) { String[] row = new String[fields.size()]; @@ -75,7 +77,7 @@ public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIs public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsRandom() { double trainingPercent = randomDoubleBetween(1.0, 100.0, true); double trainingFraction = trainingPercent / 100; - CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, trainingPercent); + CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, trainingPercent, randomizeSeed); int runCount = 20; int rowsCount = 1000; @@ -121,7 +123,7 @@ public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIs } public void testProcess_ShouldHaveAtLeastOneTrainingRow() { - CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 1.0); + CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 1.0, randomizeSeed); // We have some non-training rows and then a training row to check // we maintain the first training row and not just the first row diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml index a1d78b7444057..4335a50382a94 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml @@ -1456,7 +1456,8 @@ setup: "eta": 0.5, "maximum_number_trees": 400, "feature_bag_fraction": 0.3, - "training_percent": 60.3 + "training_percent": 60.3, + "randomize_seed": 42 } } } @@ -1472,7 +1473,8 @@ setup: "maximum_number_trees": 400, "feature_bag_fraction": 0.3, "prediction_field_name": "foo_prediction", - "training_percent": 60.3 + "training_percent": 60.3, + "randomize_seed": 42 } }} - is_true: create_time @@ -1796,7 +1798,8 @@ setup: "eta": 0.5, "maximum_number_trees": 400, "feature_bag_fraction": 0.3, - "training_percent": 60.3 + "training_percent": 60.3, + "randomize_seed": 24 } } } @@ -1813,6 +1816,7 @@ setup: "feature_bag_fraction": 0.3, "prediction_field_name": "foo_prediction", "training_percent": 60.3, + "randomize_seed": 24, "num_top_classes": 2 } }} @@ -1836,7 +1840,8 @@ setup: }, "analysis": { "regression": { - "dependent_variable": "foo" + "dependent_variable": "foo", + "randomize_seed": 42 } } } @@ -1848,7 +1853,8 @@ setup: "regression":{ "dependent_variable": "foo", "prediction_field_name": "foo_prediction", - "training_percent": 100.0 + "training_percent": 100.0, + "randomize_seed": 42 } }} - is_true: create_time