From 67d3cb78a60dd5f0b46e7b87b8140c8d1e63b80a Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Thu, 12 Dec 2019 17:27:41 +0200 Subject: [PATCH] [ML] Persist/restore state for DFA classification (#50040) This commit adds state persist/restore for data frame analytics classification jobs. --- .../ml/dataframe/analyses/Classification.java | 4 ++-- .../analyses/ClassificationTests.java | 7 +++++++ .../xpack/ml/integration/ClassificationIT.java | 7 +++++++ ...lNativeDataFrameAnalyticsIntegTestCase.java | 7 +++++++ .../xpack/ml/integration/RegressionIT.java | 18 ++++++------------ 5 files changed, 29 insertions(+), 14 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index ed4cb1fe18f8e..cbd78b4f3baab 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -251,12 +251,12 @@ public boolean supportsMissingValues() { @Override public boolean persistsState() { - return false; + return true; } @Override public String getStateDocId(String jobId) { - throw new UnsupportedOperationException(); + return jobId + "_classification_state#1"; } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java index 8308ef8dad289..75a7410f181ba 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java @@ -209,4 +209,11 @@ public void testToXContent_GivenEmptyParams() throws IOException { assertThat(json, containsString("randomize_seed")); } } + + public void testGetStateDocId() { + Classification classification = createRandom(); + assertThat(classification.persistsState(), is(true)); + String randomId = randomAlphaOfLength(10); + assertThat(classification.getStateDocId(randomId), equalTo(randomId + "_classification_state#1")); + } } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index 0e49043fcfbe5..0c486fdeee678 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -95,6 +95,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws assertProgress(jobId, 100, 100, 100, 100); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); + assertModelStatePersisted(stateDocId()); assertInferenceModelPersisted(jobId); assertThatAuditMessagesMatch(jobId, "Created analytics with analysis type [classification]", @@ -135,6 +136,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti assertProgress(jobId, 100, 100, 100, 100); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); + assertModelStatePersisted(stateDocId()); assertInferenceModelPersisted(jobId); assertThatAuditMessagesMatch(jobId, "Created analytics with analysis type [classification]", @@ -195,6 +197,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty( assertProgress(jobId, 100, 100, 100, 100); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); + assertModelStatePersisted(stateDocId()); assertInferenceModelPersisted(jobId); assertThatAuditMessagesMatch(jobId, "Created analytics with analysis type [classification]", @@ -447,4 +450,8 @@ private void assertEvaluation(String dependentVariable, List dependentVar assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L)); } } + + protected String stateDocId() { + return jobId + "_classification_state#1"; + } } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java index 8ff82c28b36e0..980f5f4da5ecb 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java @@ -274,4 +274,11 @@ protected static Set getTrainingRowsIds(String index) { assertThat(trainingRowsIds.isEmpty(), is(false)); return trainingRowsIds; } + + protected static void assertModelStatePersisted(String stateDocId) { + SearchResponse searchResponse = client().prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern()) + .setQuery(QueryBuilders.idsQuery().addIds(stateDocId)) + .get(); + assertThat(searchResponse.getHits().getHits().length, equalTo(1)); + } } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java index 84d408daacc61..29480d711f37f 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java @@ -12,14 +12,12 @@ import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.common.unit.TimeValue; -import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.SearchHit; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams; import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; -import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.junit.After; import java.util.Arrays; @@ -82,7 +80,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws assertProgress(jobId, 100, 100, 100, 100); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); - assertModelStatePersisted(jobId); + assertModelStatePersisted(stateDocId()); assertInferenceModelPersisted(jobId); assertThatAuditMessagesMatch(jobId, "Created analytics with analysis type [regression]", @@ -119,7 +117,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti assertProgress(jobId, 100, 100, 100, 100); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); - assertModelStatePersisted(jobId); + assertModelStatePersisted(stateDocId()); assertInferenceModelPersisted(jobId); assertThatAuditMessagesMatch(jobId, "Created analytics with analysis type [regression]", @@ -171,7 +169,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception assertProgress(jobId, 100, 100, 100, 100); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); - assertModelStatePersisted(jobId); + assertModelStatePersisted(stateDocId()); assertInferenceModelPersisted(jobId); assertThatAuditMessagesMatch(jobId, "Created analytics with analysis type [regression]", @@ -233,7 +231,7 @@ public void testStopAndRestart() throws Exception { assertProgress(jobId, 100, 100, 100, 100); assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); - assertModelStatePersisted(jobId); + assertModelStatePersisted(stateDocId()); assertInferenceModelPersisted(jobId); } @@ -324,11 +322,7 @@ private static Map getMlResultsObjectFromDestDoc(Map