Skip to content

Commit

Permalink
[ML] Persist/restore state for DFA classification (elastic#50040)
Browse files Browse the repository at this point in the history
This commit adds state persist/restore for data frame analytics classification jobs.
  • Loading branch information
dimitris-athanasiou authored and SivagurunathanV committed Jan 21, 2020
1 parent dbcab44 commit 67d3cb7
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -251,12 +251,12 @@ public boolean supportsMissingValues() {

// NOTE(review): this is a rendered diff with the +/- markers stripped. The hunk header reports
// 12 lines on each side, so `return false;` is presumably the removed line and `return true;`
// the added one — only one of the two return statements exists in the real file.
@Override
public boolean persistsState() {
return false;
return true;
}

// NOTE(review): diff residue — the `throw` is presumably the removed line and the `return` the
// added one. The added form yields "<jobId>_classification_state#1", which is the value that
// testGetStateDocId asserts.
@Override
public String getStateDocId(String jobId) {
throw new UnsupportedOperationException();
return jobId + "_classification_state#1";
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,4 +209,11 @@ public void testToXContent_GivenEmptyParams() throws IOException {
assertThat(json, containsString("randomize_seed"));
}
}

/**
 * Checks that a {@code Classification} obtained from {@code createRandom()} reports that it
 * persists state, and that its state document ID is the given job ID followed by
 * {@code "_classification_state#1"}.
 */
public void testGetStateDocId() {
Classification classification = createRandom();
assertThat(classification.persistsState(), is(true));
String randomId = randomAlphaOfLength(10);
assertThat(classification.getStateDocId(randomId), equalTo(randomId + "_classification_state#1"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws

assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [classification]",
Expand Down Expand Up @@ -135,6 +136,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti

assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [classification]",
Expand Down Expand Up @@ -195,6 +197,7 @@ public <T> void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(

assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [classification]",
Expand Down Expand Up @@ -447,4 +450,8 @@ private <T> void assertEvaluation(String dependentVariable, List<T> dependentVar
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L));
}
}

/**
 * Returns the ID of the state document this test expects for its job: the enclosing test's
 * {@code jobId} followed by {@code "_classification_state#1"}.
 */
protected String stateDocId() {
return jobId + "_classification_state#1";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -274,4 +274,11 @@ protected static Set<String> getTrainingRowsIds(String index) {
assertThat(trainingRowsIds.isEmpty(), is(false));
return trainingRowsIds;
}

/**
 * Asserts that a search by document ID over the indices matched by
 * {@code AnomalyDetectorsIndex.jobStateIndexPattern()} returns exactly one hit.
 *
 * @param stateDocId the ID of the state document to look up
 */
protected static void assertModelStatePersisted(String stateDocId) {
SearchResponse searchResponse = client().prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
.setQuery(QueryBuilders.idsQuery().addIds(stateDocId))
.get();
// Exactly one hit: the document must exist and must not appear more than once.
assertThat(searchResponse.getHits().getHits().length, equalTo(1));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,12 @@
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.junit.After;

import java.util.Arrays;
Expand Down Expand Up @@ -82,7 +80,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws

assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(jobId);
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [regression]",
Expand Down Expand Up @@ -119,7 +117,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti

assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(jobId);
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [regression]",
Expand Down Expand Up @@ -171,7 +169,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception

assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(jobId);
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [regression]",
Expand Down Expand Up @@ -233,7 +231,7 @@ public void testStopAndRestart() throws Exception {

assertProgress(jobId, 100, 100, 100, 100);
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(jobId);
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
}

Expand Down Expand Up @@ -324,11 +322,7 @@ private static Map<String, Object> getMlResultsObjectFromDestDoc(Map<String, Obj
return resultsObject;
}

// NOTE(review): diff residue with the +/- markers stripped. The hunk header reports 11 old
// lines and 7 new ones, so the six lines from the `assertModelStatePersisted(String jobId)`
// signature through its `assertThat` call are presumably the removed private helper.
private static void assertModelStatePersisted(String jobId) {
String docId = jobId + "_regression_state#1";
SearchResponse searchResponse = client().prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
.setQuery(QueryBuilders.idsQuery().addIds(docId))
.get();
assertThat(searchResponse.getHits().getHits().length, equalTo(1));
// Presumably the added replacement: returns the enclosing test's jobId followed by
// "_regression_state#1". The closing brace below is shared by the old and new versions.
protected String stateDocId() {
return jobId + "_regression_state#1";
}
}

0 comments on commit 67d3cb7

Please sign in to comment.