diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index 1fa16f5059d31..d2519086c5cff 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -248,6 +248,14 @@ public Map getFieldCardinalityLimits() { return Collections.singletonMap(dependentVariable, 2L); } + @Override + public Map getExplicitlyMappedFields(String resultsFieldName) { + return new HashMap<>() {{ + put(resultsFieldName + "." + predictionFieldName, dependentVariable); + put(resultsFieldName + ".top_classes.class_name", dependentVariable); + }}; + } + @Override public boolean supportsMissingValues() { return true; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java index d0af0a452a474..74cdc5824cbdc 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java @@ -41,6 +41,17 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable { */ Map getFieldCardinalityLimits(); + /** + * Returns fields for which the mappings should be copied from source index to destination index. + * Each entry of the returned {@link Map} is of the form: + * key - field path in the destination index + * value - field path in the source index from which the mapping should be taken + * + * @param resultsFieldName name of the results field under which all the results are stored + * @return {@link Map} containing fields for which the mappings should be copied from source index to destination index + */ + Map getExplicitlyMappedFields(String resultsFieldName); + /** * @return {@code true} if this analysis supports data frame rows with missing values */ diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java index 70b3cfb9fe246..81c4673809368 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java @@ -229,6 +229,11 @@ public Map getFieldCardinalityLimits() { return Collections.emptyMap(); } + @Override + public Map getExplicitlyMappedFields(String resultsFieldName) { + return Collections.emptyMap(); + } + @Override public boolean supportsMissingValues() { return false; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java index 27c8a3f2eb7ca..fe2927591312a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java @@ -186,6 +186,11 @@ public Map getFieldCardinalityLimits() { return Collections.emptyMap(); } + @Override + public Map getExplicitlyMappedFields(String resultsFieldName) { + return Collections.singletonMap(resultsFieldName + "." + predictionFieldName, dependentVariable); + } + @Override public boolean supportsMissingValues() { return true; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java index 8b23fe619efc8..7a0af05071b7f 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java @@ -25,7 +25,9 @@ import java.util.Map; import java.util.Set; +import static org.hamcrest.Matchers.anEmptyMap; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasEntry; import static org.hamcrest.Matchers.is; @@ -161,8 +163,16 @@ public void testGetParams() { hasEntry("prediction_field_type", "string"))); } - public void testFieldCardinalityLimitsIsNonNull() { - assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue()))); + public void testRequiredFieldsIsNonEmpty() { + assertThat(createTestInstance().getRequiredFields(), is(not(empty()))); + } + + public void testFieldCardinalityLimitsIsNonEmpty() { + assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(anEmptyMap()))); + } + + public void testFieldMappingsToCopyIsNonEmpty() { + assertThat(createTestInstance().getExplicitlyMappedFields(""), is(not(anEmptyMap()))); } public void testToXContent_GivenVersionBeforeRandomizeSeedWasIntroduced() throws IOException { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java index c35b9a3bad1af..5b7a23b46ff24 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java @@ -12,11 +12,11 @@ import java.io.IOException; import java.util.Map; +import static org.hamcrest.Matchers.anEmptyMap; import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.not; -import static org.hamcrest.Matchers.nullValue; public class OutlierDetectionTests extends AbstractSerializingTestCase { @@ -84,8 +84,16 @@ public void testGetParams_GivenExplicitValues() { assertThat(params.get(OutlierDetection.STANDARDIZATION_ENABLED.getPreferredName()), is(false)); } - public void testFieldCardinalityLimitsIsNonNull() { - assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue()))); + public void testRequiredFieldsIsEmpty() { + assertThat(createTestInstance().getRequiredFields(), is(empty())); + } + + public void testFieldCardinalityLimitsIsEmpty() { + assertThat(createTestInstance().getFieldCardinalityLimits(), is(anEmptyMap())); + } + + public void testFieldMappingsToCopyIsEmpty() { + assertThat(createTestInstance().getExplicitlyMappedFields(""), is(anEmptyMap())); } public void testGetStateDocId() { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java index d45125bbc3d7e..c123a0553d190 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java @@ -19,7 +19,9 @@ import java.util.Collections; import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.anEmptyMap; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasEntry; import static org.hamcrest.Matchers.is; @@ -100,8 +102,16 @@ public void testGetParams() { allOf(hasEntry("dependent_variable", "foo"), hasEntry("prediction_field_name", "foo_prediction"))); } - public void testFieldCardinalityLimitsIsNonNull() { - assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue()))); + public void testRequiredFieldsIsNonEmpty() { + assertThat(createTestInstance().getRequiredFields(), is(not(empty()))); + } + + public void testFieldCardinalityLimitsIsEmpty() { + assertThat(createTestInstance().getFieldCardinalityLimits(), is(anEmptyMap())); + } + + public void testFieldMappingsToCopyIsNonEmpty() { + assertThat(createTestInstance().getExplicitlyMappedFields(""), is(not(anEmptyMap()))); } public void testGetStateDocId() { diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index a95d104eee97f..96406434227bd 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -7,6 +7,8 @@ import com.google.common.collect.Ordering; import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.admin.indices.get.GetIndexAction; +import org.elasticsearch.action.admin.indices.get.GetIndexRequest; import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; import org.elasticsearch.action.bulk.BulkRequestBuilder; import org.elasticsearch.action.bulk.BulkResponse; @@ -39,6 +41,7 @@ import java.util.Set; import static java.util.stream.Collectors.toList; +import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue; import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; @@ -64,6 +67,7 @@ public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { private String jobId; private String sourceIndex; private String destIndex; + private boolean analysisUsesExistingDestIndex; @After public void cleanup() { @@ -72,6 +76,7 @@ public void cleanup() { public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws Exception { initialize("classification_single_numeric_feature_and_mixed_data_set"); + String predictedClassField = KEYWORD_FIELD + "_prediction"; indexData(sourceIndex, 300, 50, KEYWORD_FIELD); DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD)); @@ -88,12 +93,9 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get(); for (SearchHit hit : sourceData.getHits()) { Map destDoc = getDestDoc(config, hit); - Map resultsObject = getMlResultsObjectFromDestDoc(destDoc); - - assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true)); - assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES))); - assertThat(resultsObject.containsKey("is_training"), is(true)); - assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(KEYWORD_FIELD))); + Map resultsObject = getFieldValue(destDoc, "ml"); + assertThat(getFieldValue(resultsObject, predictedClassField), is(in(KEYWORD_FIELD_VALUES))); + assertThat(getFieldValue(resultsObject, "is_training"), is(destDoc.containsKey(KEYWORD_FIELD))); assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES); } @@ -101,19 +103,21 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertModelStatePersisted(stateDocId()); assertInferenceModelPersisted(jobId); + assertMlResultsFieldMappings(predictedClassField, "keyword"); assertThatAuditMessagesMatch(jobId, "Created analytics with analysis type [classification]", "Estimated memory usage for this analytics to be", "Starting analytics on node", "Started analytics", - "Creating destination index [" + destIndex + "]", + expectedDestIndexAuditMessage(), "Finished reindexing to destination index [" + destIndex + "]", "Finished analysis"); - assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml.keyword-field_prediction.keyword"); + assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField); } public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Exception { initialize("classification_only_training_data_and_training_percent_is_100"); + String predictedClassField = KEYWORD_FIELD + "_prediction"; indexData(sourceIndex, 300, 0, KEYWORD_FIELD); DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD)); @@ -129,12 +133,10 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti client().admin().indices().refresh(new RefreshRequest(destIndex)); SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get(); for (SearchHit hit : sourceData.getHits()) { - Map resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit)); - - assertThat(resultsObject.containsKey("keyword-field_prediction"), is(true)); - assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES))); - assertThat(resultsObject.containsKey("is_training"), is(true)); - assertThat(resultsObject.get("is_training"), is(true)); + Map destDoc = getDestDoc(config, hit); + Map resultsObject = getFieldValue(destDoc, "ml"); + assertThat(getFieldValue(resultsObject, predictedClassField), is(in(KEYWORD_FIELD_VALUES))); + assertThat(getFieldValue(resultsObject, "is_training"), is(true)); assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES); } @@ -142,19 +144,22 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertModelStatePersisted(stateDocId()); assertInferenceModelPersisted(jobId); + assertMlResultsFieldMappings(predictedClassField, "keyword"); assertThatAuditMessagesMatch(jobId, "Created analytics with analysis type [classification]", "Estimated memory usage for this analytics to be", "Starting analytics on node", "Started analytics", - "Creating destination index [" + destIndex + "]", + expectedDestIndexAuditMessage(), "Finished reindexing to destination index [" + destIndex + "]", "Finished analysis"); - assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml.keyword-field_prediction.keyword"); + assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField); } - public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty( - String jobId, String dependentVariable, List dependentVariableValues) throws Exception { + public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(String jobId, + String dependentVariable, + List dependentVariableValues, + String expectedMappingTypeForPredictedField) throws Exception { initialize(jobId); String predictedClassField = dependentVariable + "_prediction"; indexData(sourceIndex, 300, 0, dependentVariable); @@ -181,16 +186,13 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty( client().admin().indices().refresh(new RefreshRequest(destIndex)); SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get(); for (SearchHit hit : sourceData.getHits()) { - Map resultsObject = getMlResultsObjectFromDestDoc(getDestDoc(config, hit)); - assertThat(resultsObject.containsKey(predictedClassField), is(true)); - @SuppressWarnings("unchecked") - T predictedClassValue = (T) resultsObject.get(predictedClassField); - assertThat(predictedClassValue, is(in(dependentVariableValues))); + Map destDoc = getDestDoc(config, hit); + Map resultsObject = getFieldValue(destDoc, "ml"); + assertThat(getFieldValue(resultsObject, predictedClassField), is(in(dependentVariableValues))); assertTopClasses(resultsObject, numTopClasses, dependentVariable, dependentVariableValues); - assertThat(resultsObject.containsKey("is_training"), is(true)); // Let's just assert there's both training and non-training results - if ((boolean) resultsObject.get("is_training")) { + if (getFieldValue(resultsObject, "is_training")) { trainingRowsCount++; } else { nonTrainingRowsCount++; @@ -203,40 +205,39 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty( assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L)); assertModelStatePersisted(stateDocId()); assertInferenceModelPersisted(jobId); + assertMlResultsFieldMappings(predictedClassField, expectedMappingTypeForPredictedField); assertThatAuditMessagesMatch(jobId, "Created analytics with analysis type [classification]", "Estimated memory usage for this analytics to be", "Starting analytics on node", "Started analytics", - "Creating destination index [" + destIndex + "]", + expectedDestIndexAuditMessage(), "Finished reindexing to destination index [" + destIndex + "]", "Finished analysis"); + assertEvaluation(dependentVariable, dependentVariableValues, "ml." + predictedClassField); } public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsKeyword() throws Exception { testWithOnlyTrainingRowsAndTrainingPercentIsFifty( - "classification_training_percent_is_50_keyword", KEYWORD_FIELD, KEYWORD_FIELD_VALUES); - assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml.keyword-field_prediction.keyword"); + "classification_training_percent_is_50_keyword", KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "keyword"); } public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsInteger() throws Exception { testWithOnlyTrainingRowsAndTrainingPercentIsFifty( - "classification_training_percent_is_50_integer", DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES); - assertEvaluation(DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES, "ml.discrete-numerical-field_prediction"); + "classification_training_percent_is_50_integer", DISCRETE_NUMERICAL_FIELD, DISCRETE_NUMERICAL_FIELD_VALUES, "integer"); } public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsDouble() throws Exception { ElasticsearchStatusException e = expectThrows( ElasticsearchStatusException.class, () -> testWithOnlyTrainingRowsAndTrainingPercentIsFifty( - "classification_training_percent_is_50_double", NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES)); + "classification_training_percent_is_50_double", NUMERICAL_FIELD, NUMERICAL_FIELD_VALUES, null)); assertThat(e.getMessage(), startsWith("invalid types [double] for required field [numerical-field];")); } public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableIsBoolean() throws Exception { testWithOnlyTrainingRowsAndTrainingPercentIsFifty( - "classification_training_percent_is_50_boolean", BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES); - assertEvaluation(BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES, "ml.boolean-field_prediction"); + "classification_training_percent_is_50_boolean", BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES, "boolean"); } public void testDependentVariableCardinalityTooHighError() throws Exception { @@ -281,6 +282,7 @@ public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet() throws Exceptio String sourceIndex = "classification_two_jobs_with_same_randomize_seed_source"; String dependentVariable = KEYWORD_FIELD; + createIndex(sourceIndex); // We use 100 rows as we can't set this too low. If too low it is possible // we only train with rows of one of the two classes which leads to a failure. indexData(sourceIndex, 100, 0, dependentVariable); @@ -354,17 +356,24 @@ private void initialize(String jobId) { this.jobId = jobId; this.sourceIndex = jobId + "_source_index"; this.destIndex = sourceIndex + "_results"; + this.analysisUsesExistingDestIndex = randomBoolean(); + createIndex(sourceIndex); + if (analysisUsesExistingDestIndex) { + createIndex(destIndex); + } } - private static void indexData(String sourceIndex, int numTrainingRows, int numNonTrainingRows, String dependentVariable) { - client().admin().indices().prepareCreate(sourceIndex) + private static void createIndex(String index) { + client().admin().indices().prepareCreate(index) .addMapping("_doc", BOOLEAN_FIELD, "type=boolean", NUMERICAL_FIELD, "type=double", DISCRETE_NUMERICAL_FIELD, "type=integer", KEYWORD_FIELD, "type=keyword") .get(); + } + private static void indexData(String sourceIndex, int numTrainingRows, int numNonTrainingRows, String dependentVariable) { BulkRequestBuilder bulkRequestBuilder = client().prepareBulk() .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); for (int i = 0; i < numTrainingRows; i++) { @@ -407,34 +416,30 @@ private static Map getDestDoc(DataFrameAnalyticsConfig config, S Map sourceDoc = hit.getSourceAsMap(); Map destDoc = destDocGetResponse.getSource(); for (String field : sourceDoc.keySet()) { - assertThat(destDoc.containsKey(field), is(true)); + assertThat(destDoc, hasKey(field)); assertThat(destDoc.get(field), equalTo(sourceDoc.get(field))); } return destDoc; } - private static Map getMlResultsObjectFromDestDoc(Map destDoc) { - assertThat(destDoc.containsKey("ml"), is(true)); - @SuppressWarnings("unchecked") - Map resultsObject = (Map) destDoc.get("ml"); - return resultsObject; + /** + * Wrapper around extractValue with implicit casting to the appropriate type. + */ + private static T getFieldValue(Map doc, String... path) { + return (T)extractValue(doc, path); } - @SuppressWarnings("unchecked") - private static void assertTopClasses( - Map resultsObject, - int numTopClasses, - String dependentVariable, - List dependentVariableValues) { - assertThat(resultsObject.containsKey("top_classes"), is(true)); - List> topClasses = (List>) resultsObject.get("top_classes"); + private static void assertTopClasses(Map resultsObject, + int numTopClasses, + String dependentVariable, + List dependentVariableValues) { + List> topClasses = getFieldValue(resultsObject, "top_classes"); assertThat(topClasses, hasSize(numTopClasses)); List classNames = new ArrayList<>(topClasses.size()); List classProbabilities = new ArrayList<>(topClasses.size()); for (Map topClass : topClasses) { - assertThat(topClass, allOf(hasKey("class_name"), hasKey("class_probability"))); - classNames.add((T) topClass.get("class_name")); - classProbabilities.add((Double) topClass.get("class_probability")); + classNames.add(getFieldValue(topClass, "class_name")); + classProbabilities.add(getFieldValue(topClass, "class_probability")); } // Assert that all the predicted class names come from the set of dependent variable values. classNames.forEach(className -> assertThat(className, is(in(dependentVariableValues)))); @@ -507,7 +512,25 @@ private void assertEvaluation(String dependentVariable, List dependentVar } } - protected String stateDocId() { + private void assertMlResultsFieldMappings(String predictedClassField, String expectedType) { + Map mappings = + client() + .execute(GetIndexAction.INSTANCE, new GetIndexRequest().indices(destIndex)) + .actionGet() + .mappings() + .get(destIndex) + .sourceAsMap(); + assertThat(getFieldValue(mappings, "properties", "ml", "properties", predictedClassField, "type"), equalTo(expectedType)); + assertThat( + getFieldValue(mappings, "properties", "ml", "properties", "top_classes", "properties", "class_name", "type"), + equalTo(expectedType)); + } + + private String stateDocId() { return jobId + "_classification_state#1"; } + + private String expectedDestIndexAuditMessage() { + return (analysisUsesExistingDestIndex ? "Using existing" : "Creating") + " destination index [" + destIndex + "]"; + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportExplainDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportExplainDataFrameAnalyticsAction.java index 7f19deb8d5ba0..46393a3277153 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportExplainDataFrameAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportExplainDataFrameAnalyticsAction.java @@ -84,12 +84,12 @@ protected void doExecute(Task task, private void explain(Task task, PutDataFrameAnalyticsAction.Request request, ActionListener listener) { ExtractedFieldsDetectorFactory extractedFieldsDetectorFactory = new ExtractedFieldsDetectorFactory(client); - extractedFieldsDetectorFactory.createFromSource(request.getConfig(), true, ActionListener.wrap( - extractedFieldsDetector -> { - explain(task, request, extractedFieldsDetector, listener); - }, - listener::onFailure - )); + extractedFieldsDetectorFactory.createFromSource( + request.getConfig(), + ActionListener.wrap( + extractedFieldsDetector -> explain(task, request, extractedFieldsDetector, listener), + listener::onFailure) + ); } private void explain(Task task, PutDataFrameAnalyticsAction.Request request, ExtractedFieldsDetector extractedFieldsDetector, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java index ee7f096ed696c..02c63f85357a1 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java @@ -275,8 +275,7 @@ private void getStartContext(String id, ActionListener finalListen new SourceDestValidator(clusterService.state(), indexNameExpressionResolver).check(startContext.config); // Validate extraction is possible - boolean isTaskRestarting = startContext.startingState != DataFrameAnalyticsTask.StartingState.FIRST_TIME; - new ExtractedFieldsDetectorFactory(client).createFromSource(startContext.config, isTaskRestarting, ActionListener.wrap( + new ExtractedFieldsDetectorFactory(client).createFromSource(startContext.config, ActionListener.wrap( extractedFieldsDetector -> { startContext.extractedFields = extractedFieldsDetector.detect().v1(); toValidateDestEmptyListener.onResponse(startContext); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsIndex.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsIndex.java index a84585620ece0..2a54741a5780d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsIndex.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsIndex.java @@ -28,6 +28,8 @@ import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.time.Clock; import java.util.Collections; @@ -82,7 +84,7 @@ public static void createDestinationIndex(Client client, } private static void prepareCreateIndexRequest(Client client, Clock clock, DataFrameAnalyticsConfig config, - ActionListener listener) { + ActionListener listener) { AtomicReference settingsHolder = new AtomicReference<>(); ActionListener> mappingsListener = ActionListener.wrap( @@ -103,12 +105,13 @@ private static void prepareCreateIndexRequest(Client client, Clock clock, DataFr listener::onFailure ); - GetSettingsRequest getSettingsRequest = new GetSettingsRequest(); - getSettingsRequest.indices(config.getSource().getIndex()); - getSettingsRequest.indicesOptions(IndicesOptions.lenientExpandOpen()); - getSettingsRequest.names(PRESERVED_SETTINGS); - ClientHelper.executeWithHeadersAsync(config.getHeaders(), ML_ORIGIN, client, GetSettingsAction.INSTANCE, - getSettingsRequest, getSettingsResponseListener); + GetSettingsRequest getSettingsRequest = + new GetSettingsRequest() + .indices(config.getSource().getIndex()) + .indicesOptions(IndicesOptions.lenientExpandOpen()) + .names(PRESERVED_SETTINGS); + ClientHelper.executeWithHeadersAsync( + config.getHeaders(), ML_ORIGIN, client, GetSettingsAction.INSTANCE, getSettingsRequest, getSettingsResponseListener); } private static CreateIndexRequest createIndexRequest(Clock clock, DataFrameAnalyticsConfig config, Settings settings, @@ -119,8 +122,11 @@ private static CreateIndexRequest createIndexRequest(Clock clock, DataFrameAnaly String destinationIndex = config.getDest().getIndex(); String type = mappings.keysIt().next(); Map mappingsAsMap = mappings.valuesIt().next().sourceAsMap(); - addProperties(mappingsAsMap); - addMetaData(mappingsAsMap, config.getId(), clock); + Map properties = getOrPutDefault(mappingsAsMap, PROPERTIES, HashMap::new); + checkResultsFieldIsNotPresentInProperties(config, properties); + properties.putAll(createAdditionalMappings(config, Collections.unmodifiableMap(properties))); + Map metadata = getOrPutDefault(mappingsAsMap, META, HashMap::new); + metadata.putAll(createMetaData(config.getId(), clock)); return new CreateIndexRequest(destinationIndex, settings).mapping(type, mappingsAsMap); } @@ -154,21 +160,32 @@ private static Integer findMaxSettingValue(GetSettingsResponse settingsResponse, return maxValue; } - private static void addProperties(Map mappingsAsMap) { - Map properties = getOrPutDefault(mappingsAsMap, PROPERTIES, HashMap::new); + private static Map createAdditionalMappings(DataFrameAnalyticsConfig config, Map mappingsProperties) { + Map properties = new HashMap<>(); Map idCopyMapping = new HashMap<>(); idCopyMapping.put("type", "keyword"); properties.put(ID_COPY, idCopyMapping); + for (Map.Entry entry + : config.getAnalysis().getExplicitlyMappedFields(config.getDest().getResultsField()).entrySet()) { + String destFieldPath = entry.getKey(); + String sourceFieldPath = entry.getValue(); + Object sourceFieldMapping = mappingsProperties.get(sourceFieldPath); + if (sourceFieldMapping != null) { + properties.put(destFieldPath, sourceFieldMapping); + } + } + return properties; } - private static void addMetaData(Map mappingsAsMap, String analyticsId, Clock clock) { - Map metadata = getOrPutDefault(mappingsAsMap, META, HashMap::new); + private static Map createMetaData(String analyticsId, Clock clock) { + Map metadata = new HashMap<>(); metadata.put(CREATION_DATE_MILLIS, clock.millis()); metadata.put(CREATED_BY, "data-frame-analytics"); Map versionMapping = new HashMap<>(); versionMapping.put(CREATED, Version.CURRENT); metadata.put(VERSION, versionMapping); metadata.put(ANALYTICS, analyticsId); + return metadata; } @SuppressWarnings("unchecked") @@ -181,22 +198,44 @@ private static V getOrPutDefault(Map map, K key, Supplier v return value; } - public static void updateMappingsToDestIndex(Client client, DataFrameAnalyticsConfig analyticsConfig, GetIndexResponse getIndexResponse, + @SuppressWarnings("unchecked") + public static void updateMappingsToDestIndex(Client client, DataFrameAnalyticsConfig config, GetIndexResponse getIndexResponse, ActionListener listener) { // We have validated the destination index should match a single index assert getIndexResponse.indices().length == 1; - ImmutableOpenMap mappings = getIndexResponse.getMappings().get(getIndexResponse.indices()[0]); - String type = mappings.keysIt().next(); - - Map addedMappings = Collections.singletonMap(PROPERTIES, - Collections.singletonMap(ID_COPY, Collections.singletonMap("type", "keyword"))); + // Fetch mappings from destination index + String type = getIndexResponse.mappings().keysIt().next(); + Map destMappingsAsMap = getIndexResponse.mappings().valuesIt().next().sourceAsMap(); + Map destPropertiesAsMap = + (Map)destMappingsAsMap.getOrDefault(PROPERTIES, Collections.emptyMap()); + + // Verify that the results field does not exist in the dest index + checkResultsFieldIsNotPresentInProperties(config, destPropertiesAsMap); + + // Determine mappings to be added to the destination index + Map addedMappings = + Collections.singletonMap(PROPERTIES, createAdditionalMappings(config, Collections.unmodifiableMap(destPropertiesAsMap))); + + // Add the mappings to the destination index + PutMappingRequest putMappingRequest = + new PutMappingRequest(getIndexResponse.indices()) + .type(type); + .source(addedMappings); + ClientHelper.executeWithHeadersAsync( + config.getHeaders(), ML_ORIGIN, client, PutMappingAction.INSTANCE, putMappingRequest, listener); + } - PutMappingRequest putMappingRequest = new PutMappingRequest(getIndexResponse.indices()); - putMappingRequest.type(type); - putMappingRequest.source(addedMappings); - ClientHelper.executeWithHeadersAsync(analyticsConfig.getHeaders(), ML_ORIGIN, client, PutMappingAction.INSTANCE, - putMappingRequest, listener); + private static void checkResultsFieldIsNotPresentInProperties(DataFrameAnalyticsConfig config, Map properties) { + String resultsField = config.getDest().getResultsField(); + if (properties.containsKey(resultsField)) { + throw ExceptionsHelper.badRequestException( + "A field that matches the {}.{} [{}] already exists; please set a different {}", + DataFrameAnalyticsConfig.DEST.getPreferredName(), + DataFrameAnalyticsDest.RESULTS_FIELD.getPreferredName(), + resultsField, + DataFrameAnalyticsDest.RESULTS_FIELD.getPreferredName()); + } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java index 8e89113be7eba..6bcd22997fbdb 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java @@ -76,7 +76,7 @@ public void execute(DataFrameAnalyticsTask task, DataFrameAnalyticsState current // The task has fully reindexed the documents and we should continue on with our analyses case ANALYZING: LOGGER.debug("[{}] Reassigning job that was analyzing", config.getId()); - startAnalytics(task, config, true); + startAnalytics(task, config); break; // If we are already at REINDEXING, we are not 100% sure if we reindexed ALL the docs. // We will delete the destination index, recreate, reindex @@ -124,7 +124,7 @@ private void executeStartingJob(DataFrameAnalyticsTask task, DataFrameAnalyticsC )); break; case RESUMING_ANALYZING: - startAnalytics(task, config, true); + startAnalytics(task, config); break; case FINISHED: default: @@ -168,7 +168,7 @@ private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataF auditor.info( config.getId(), Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_AUDIT_FINISHED_REINDEXING, config.getDest().getIndex())); - startAnalytics(task, config, false); + startAnalytics(task, config); }, error -> task.updateState(DataFrameAnalyticsState.FAILED, error.getMessage()) ); @@ -223,7 +223,7 @@ private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataF new GetIndexRequest().indices(config.getDest().getIndex()), destIndexListener); } - private void startAnalytics(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, boolean isTaskRestarting) { + private void startAnalytics(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config) { // Ensure we mark reindexing is finished for the case we are recovering a task that had finished reindexing task.setReindexingFinished(); @@ -249,7 +249,7 @@ private void startAnalytics(DataFrameAnalyticsTask task, DataFrameAnalyticsConfi // TODO This could fail with errors. In that case we get stuck with the copied index. // We could delete the index in case of failure or we could try building the factory before reindexing // to catch the error early on. - DataFrameDataExtractorFactory.createForDestinationIndex(client, config, isTaskRestarting, dataExtractorFactoryListener); + DataFrameDataExtractorFactory.createForDestinationIndex(client, config, dataExtractorFactoryListener); } public void stop(DataFrameAnalyticsTask task) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java index 41b8f8293fcc5..7448a1af6eb99 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java @@ -267,10 +267,7 @@ private SearchRequestBuilder buildDataSummarySearchRequestBuilder() { } public Set getCategoricalFields(DataFrameAnalysis analysis) { - return context.extractedFields.getAllFields().stream() - .filter(extractedField -> analysis.getAllowedCategoricalTypes(extractedField.getName()).containsAll(extractedField.getTypes())) - .map(ExtractedField::getName) - .collect(Collectors.toSet()); + return ExtractedFieldsDetector.getCategoricalFields(context.extractedFields, analysis); } public static class DataSummary { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java index c7d27805c3b4e..3243d92bf77b6 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java @@ -31,8 +31,8 @@ public class DataFrameDataExtractorFactory { private final boolean includeRowsWithMissingValues; private DataFrameDataExtractorFactory(Client client, String analyticsId, List indices, QueryBuilder sourceQuery, - ExtractedFields extractedFields, Map headers, - boolean includeRowsWithMissingValues) { + ExtractedFields extractedFields, Map headers, + boolean includeRowsWithMissingValues) { this.client = Objects.requireNonNull(client); this.analyticsId = Objects.requireNonNull(analyticsId); this.indices = Objects.requireNonNull(indices); @@ -100,15 +100,13 @@ public static DataFrameDataExtractorFactory createForSourceIndices(Client client * * @param client ES Client used to make calls against the cluster * @param config The config from which to create the extractor factory - * @param isTaskRestarting Whether the task is restarting * @param listener The listener to notify on creation or failure */ public static void createForDestinationIndex(Client client, DataFrameAnalyticsConfig config, - boolean isTaskRestarting, ActionListener listener) { ExtractedFieldsDetectorFactory extractedFieldsDetectorFactory = new ExtractedFieldsDetectorFactory(client); - extractedFieldsDetectorFactory.createFromDest(config, isTaskRestarting, ActionListener.wrap( + extractedFieldsDetectorFactory.createFromDest(config, ActionListener.wrap( extractedFieldsDetector -> { ExtractedFields extractedFields = extractedFieldsDetector.detect().v1(); DataFrameDataExtractorFactory extractorFactory = new DataFrameDataExtractorFactory(client, config.getId(), diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java index b4bc63f5c0696..632efd6ab85ae 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java @@ -17,7 +17,7 @@ import org.elasticsearch.index.mapper.BooleanFieldMapper; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; -import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; import org.elasticsearch.xpack.core.ml.dataframe.analyses.RequiredField; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Types; import org.elasticsearch.xpack.core.ml.dataframe.explain.FieldSelection; @@ -53,16 +53,14 @@ public class ExtractedFieldsDetector { private final String[] index; private final DataFrameAnalyticsConfig config; - private final boolean isTaskRestarting; private final int docValueFieldsLimit; private final FieldCapabilitiesResponse fieldCapabilitiesResponse; private final Map fieldCardinalities; - ExtractedFieldsDetector(String[] index, DataFrameAnalyticsConfig config, boolean isTaskRestarting, int docValueFieldsLimit, + ExtractedFieldsDetector(String[] index, DataFrameAnalyticsConfig config, int docValueFieldsLimit, FieldCapabilitiesResponse fieldCapabilitiesResponse, Map fieldCardinalities) { this.index = Objects.requireNonNull(index); this.config = Objects.requireNonNull(config); - this.isTaskRestarting = isTaskRestarting; this.docValueFieldsLimit = docValueFieldsLimit; this.fieldCapabilitiesResponse = Objects.requireNonNull(fieldCapabilitiesResponse); this.fieldCardinalities = Objects.requireNonNull(fieldCardinalities); @@ -83,7 +81,6 @@ public Tuple> detect() { private Set getIncludedFields(Set fieldSelection) { Set fields = new TreeSet<>(fieldCapabilitiesResponse.get().keySet()); fields.removeAll(IGNORE_FIELDS); - checkResultsFieldIsNotPresent(); removeFieldsUnderResultsField(fields); applySourceFiltering(fields); FetchSourceContext analyzedFields = config.getAnalyzedFields(); @@ -115,24 +112,6 @@ private void removeFieldsUnderResultsField(Set fields) { fields.removeIf(field -> field.startsWith(resultsField + ".")); } - private void checkResultsFieldIsNotPresent() { - // If the task is restarting we do not mind the index containing the results field, we will overwrite all docs - if (isTaskRestarting) { - return; - } - - String resultsField = config.getDest().getResultsField(); - Map indexToFieldCaps = fieldCapabilitiesResponse.getField(resultsField); - if (indexToFieldCaps != null && indexToFieldCaps.isEmpty() == false) { - throw ExceptionsHelper.badRequestException( - "A field that matches the {}.{} [{}] already exists; please set a different {}", - DataFrameAnalyticsConfig.DEST.getPreferredName(), - DataFrameAnalyticsDest.RESULTS_FIELD.getPreferredName(), - resultsField, - DataFrameAnalyticsDest.RESULTS_FIELD.getPreferredName()); - } - } - private void applySourceFiltering(Set fields) { Iterator fieldsIterator = fields.iterator(); while (fieldsIterator.hasNext()) { @@ -395,7 +374,7 @@ private ExtractedFields fetchBooleanFieldsAsIntegers(ExtractedFields extractedFi private void addIncludedFields(ExtractedFields extractedFields, Set fieldSelection) { Set requiredFields = config.getAnalysis().getRequiredFields().stream().map(RequiredField::getName) .collect(Collectors.toSet()); - Set categoricalFields = getCategoricalFields(extractedFields); + Set categoricalFields = getCategoricalFields(extractedFields, config.getAnalysis()); for (ExtractedField includedField : extractedFields.getAllFields()) { FieldSelection.FeatureType featureType = categoricalFields.contains(includedField.getName()) ? FieldSelection.FeatureType.CATEGORICAL : FieldSelection.FeatureType.NUMERICAL; @@ -404,9 +383,9 @@ private void addIncludedFields(ExtractedFields extractedFields, Set getCategoricalFields(ExtractedFields extractedFields) { + static Set getCategoricalFields(ExtractedFields extractedFields, DataFrameAnalysis analysis) { return extractedFields.getAllFields().stream() - .filter(extractedField -> config.getAnalysis().getAllowedCategoricalTypes(extractedField.getName()) + .filter(extractedField -> analysis.getAllowedCategoricalTypes(extractedField.getName()) .containsAll(extractedField.getTypes())) .map(ExtractedField::getName) .collect(Collectors.toSet()); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorFactory.java index c44555921cf38..b2d9122ef5eb8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorFactory.java @@ -49,26 +49,24 @@ public ExtractedFieldsDetectorFactory(Client client) { this.client = Objects.requireNonNull(client); } - public void createFromSource(DataFrameAnalyticsConfig config, boolean isTaskRestarting, - ActionListener listener) { - create(config.getSource().getIndex(), config, isTaskRestarting, listener); + public void createFromSource(DataFrameAnalyticsConfig config, ActionListener listener) { + create(config.getSource().getIndex(), config, listener); } - public void createFromDest(DataFrameAnalyticsConfig config, boolean isTaskRestarting, - ActionListener listener) { - create(new String[] {config.getDest().getIndex()}, config, isTaskRestarting, listener); + public void createFromDest(DataFrameAnalyticsConfig config, ActionListener listener) { + create(new String[] {config.getDest().getIndex()}, config, listener); } - private void create(String[] index, DataFrameAnalyticsConfig config, boolean isTaskRestarting, - ActionListener listener) { + private void create(String[] index, DataFrameAnalyticsConfig config, ActionListener listener) { AtomicInteger docValueFieldsLimitHolder = new AtomicInteger(); AtomicReference fieldCapsResponseHolder = new AtomicReference<>(); // Step 4. Create cardinality by field map and build detector ActionListener> fieldCardinalitiesHandler = ActionListener.wrap( fieldCardinalities -> { - ExtractedFieldsDetector detector = new ExtractedFieldsDetector(index, config, isTaskRestarting, - docValueFieldsLimitHolder.get(), fieldCapsResponseHolder.get(), fieldCardinalities); + ExtractedFieldsDetector detector = + new ExtractedFieldsDetector( + index, config, docValueFieldsLimitHolder.get(), fieldCapsResponseHolder.get(), fieldCardinalities); listener.onResponse(detector); }, listener::onFailure diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsIndexTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsIndexTests.java index afe12aa4ce68f..65131a48501f5 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsIndexTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsIndexTests.java @@ -5,18 +5,22 @@ */ package org.elasticsearch.xpack.ml.dataframe; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.admin.indices.create.CreateIndexAction; import org.elasticsearch.action.admin.indices.create.CreateIndexRequest; -import org.elasticsearch.action.admin.indices.create.CreateIndexResponse; +import org.elasticsearch.action.admin.indices.get.GetIndexResponse; import org.elasticsearch.action.admin.indices.mapping.get.GetMappingsAction; import org.elasticsearch.action.admin.indices.mapping.get.GetMappingsRequest; import org.elasticsearch.action.admin.indices.mapping.get.GetMappingsResponse; +import org.elasticsearch.action.admin.indices.mapping.put.PutMappingAction; +import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequest; import org.elasticsearch.action.admin.indices.settings.get.GetSettingsAction; import org.elasticsearch.action.admin.indices.settings.get.GetSettingsRequest; import org.elasticsearch.action.admin.indices.settings.get.GetSettingsResponse; import org.elasticsearch.action.support.IndicesOptions; +import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.client.Client; import org.elasticsearch.cluster.metadata.IndexMetaData; import org.elasticsearch.cluster.metadata.MappingMetaData; @@ -30,8 +34,14 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; +import org.junit.Assert; +import org.junit.Before; import org.mockito.ArgumentCaptor; +import org.mockito.stubbing.Answer; import java.io.IOException; import java.time.Clock; @@ -43,13 +53,19 @@ import java.util.Map; import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue; +import static org.hamcrest.Matchers.arrayContaining; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.when; public class DataFrameAnalyticsIndexTests extends ESTestCase { @@ -57,13 +73,7 @@ public class DataFrameAnalyticsIndexTests extends ESTestCase { private static final String ANALYTICS_ID = "some-analytics-id"; private static final String[] SOURCE_INDEX = new String[] {"source-index"}; private static final String DEST_INDEX = "dest-index"; - private static final DataFrameAnalyticsConfig ANALYTICS_CONFIG = - new DataFrameAnalyticsConfig.Builder() - .setId(ANALYTICS_ID) - .setSource(new DataFrameAnalyticsSource(SOURCE_INDEX, null, null)) - .setDest(new DataFrameAnalyticsDest(DEST_INDEX, null)) - .setAnalysis(new OutlierDetection.Builder().build()) - .build(); + private static final String DEPENDENT_VARIABLE = "dep_var"; private static final int CURRENT_TIME_MILLIS = 123456789; private static final String CREATED_BY = "data-frame-analytics"; @@ -71,18 +81,17 @@ public class DataFrameAnalyticsIndexTests extends ESTestCase { private Client client = mock(Client.class); private Clock clock = Clock.fixed(Instant.ofEpochMilli(123456789L), ZoneId.systemDefault()); - public void testCreateDestinationIndex() throws IOException { + @Before + public void setUpMocks() { when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); + } + + private Map testCreateDestinationIndex(DataFrameAnalysis analysis) throws IOException { + DataFrameAnalyticsConfig config = createConfig(analysis); ArgumentCaptor createIndexRequestCaptor = ArgumentCaptor.forClass(CreateIndexRequest.class); - doAnswer( - invocationOnMock -> { - @SuppressWarnings("unchecked") - ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; - listener.onResponse(null); - return null; - }) + doAnswer(callListenerOnResponse(null)) .when(client).execute(eq(CreateIndexAction.INSTANCE), createIndexRequestCaptor.capture(), any()); Settings index1Settings = Settings.builder() @@ -106,24 +115,20 @@ public void testCreateDestinationIndex() throws IOException { GetSettingsResponse getSettingsResponse = new GetSettingsResponse(indexToSettings.build(), ImmutableOpenMap.of()); - doAnswer( - invocationOnMock -> { - @SuppressWarnings("unchecked") - ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; - listener.onResponse(getSettingsResponse); - return null; - } - ).when(client).execute(eq(GetSettingsAction.INSTANCE), getSettingsRequestCaptor.capture(), any()); + doAnswer(callListenerOnResponse(getSettingsResponse)) + .when(client).execute(eq(GetSettingsAction.INSTANCE), getSettingsRequestCaptor.capture(), any()); Map index1Properties = new HashMap<>(); index1Properties.put("field_1", "field_1_mappings"); index1Properties.put("field_2", "field_2_mappings"); + index1Properties.put(DEPENDENT_VARIABLE, Collections.singletonMap("type", "integer")); Map index1Mappings = Collections.singletonMap("properties", index1Properties); MappingMetaData index1MappingMetaData = new MappingMetaData("_doc", index1Mappings); Map index2Properties = new HashMap<>(); index2Properties.put("field_1", "field_1_mappings"); index2Properties.put("field_2", "field_2_mappings"); + index2Properties.put(DEPENDENT_VARIABLE, Collections.singletonMap("type", "integer")); Map index2Mappings = Collections.singletonMap("properties", index2Properties); MappingMetaData index2MappingMetaData = new MappingMetaData("_doc", index2Mappings); @@ -138,19 +143,13 @@ public void testCreateDestinationIndex() throws IOException { GetMappingsResponse getMappingsResponse = new GetMappingsResponse(mappings.build()); - doAnswer( - invocationOnMock -> { - @SuppressWarnings("unchecked") - ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; - listener.onResponse(getMappingsResponse); - return null; - } - ).when(client).execute(eq(GetMappingsAction.INSTANCE), getMappingsRequestCaptor.capture(), any()); + doAnswer(callListenerOnResponse(getMappingsResponse)) + .when(client).execute(eq(GetMappingsAction.INSTANCE), getMappingsRequestCaptor.capture(), any()); DataFrameAnalyticsIndex.createDestinationIndex( client, clock, - ANALYTICS_CONFIG, + config, ActionListener.wrap( response -> {}, e -> fail(e.getMessage()))); @@ -179,6 +178,141 @@ public void testCreateDestinationIndex() throws IOException { assertThat(extractValue("_doc._meta.analytics", map), equalTo(ANALYTICS_ID)); assertThat(extractValue("_doc._meta.creation_date_in_millis", map), equalTo(CURRENT_TIME_MILLIS)); assertThat(extractValue("_doc._meta.created_by", map), equalTo(CREATED_BY)); + return map; + } + } + + public void testCreateDestinationIndex_OutlierDetection() throws IOException { + testCreateDestinationIndex(new OutlierDetection.Builder().build()); + } + + public void testCreateDestinationIndex_Regression() throws IOException { + Map map = testCreateDestinationIndex(new Regression(DEPENDENT_VARIABLE)); + assertThat(extractValue("_doc.properties.ml.dep_var_prediction.type", map), equalTo("integer")); + } + + public void testCreateDestinationIndex_Classification() throws IOException { + Map map = testCreateDestinationIndex(new Classification(DEPENDENT_VARIABLE)); + assertThat(extractValue("_doc.properties.ml.dep_var_prediction.type", map), equalTo("integer")); + assertThat(extractValue("_doc.properties.ml.top_classes.class_name.type", map), equalTo("integer")); + } + + public void testCreateDestinationIndex_ResultsFieldsExistsInSourceIndex() { + DataFrameAnalyticsConfig config = createConfig(new OutlierDetection.Builder().build()); + + GetSettingsResponse getSettingsResponse = new GetSettingsResponse(ImmutableOpenMap.of(), ImmutableOpenMap.of()); + + ImmutableOpenMap.Builder mappings = ImmutableOpenMap.builder(); + mappings.put("", new MappingMetaData("_doc", Map.of("properties", Map.of("ml", "some-mapping")))); + GetMappingsResponse getMappingsResponse = new GetMappingsResponse(mappings.build()); + + doAnswer(callListenerOnResponse(getSettingsResponse)).when(client).execute(eq(GetSettingsAction.INSTANCE), any(), any()); + doAnswer(callListenerOnResponse(getMappingsResponse)).when(client).execute(eq(GetMappingsAction.INSTANCE), any(), any()); + + DataFrameAnalyticsIndex.createDestinationIndex( + client, + clock, + config, + ActionListener.wrap( + response -> fail("should not succeed"), + e -> assertThat( + e.getMessage(), + equalTo("A field that matches the dest.results_field [ml] already exists; please set a different results_field")) + ) + ); + } + + private Map testUpdateMappingsToDestIndex(DataFrameAnalysis analysis, + Map properties) throws IOException { + DataFrameAnalyticsConfig config = createConfig(analysis); + + ImmutableOpenMap.Builder mappings = ImmutableOpenMap.builder(); + mappings.put("", new MappingMetaData("_doc", Map.of("properties", properties))); + GetIndexResponse getIndexResponse = + new GetIndexResponse( + new String[] { DEST_INDEX }, mappings.build(), ImmutableOpenMap.of(), ImmutableOpenMap.of(), ImmutableOpenMap.of()); + + ArgumentCaptor putMappingRequestCaptor = ArgumentCaptor.forClass(PutMappingRequest.class); + + doAnswer(callListenerOnResponse(new AcknowledgedResponse(true))) + .when(client).execute(eq(PutMappingAction.INSTANCE), putMappingRequestCaptor.capture(), any()); + + DataFrameAnalyticsIndex.updateMappingsToDestIndex( + client, + config, + getIndexResponse, + ActionListener.wrap( + response -> assertThat(response.isAcknowledged(), is(true)), + e -> fail(e.getMessage()) + ) + ); + + verify(client, atLeastOnce()).threadPool(); + verify(client).execute(eq(PutMappingAction.INSTANCE), any(), any()); + verifyNoMoreInteractions(client); + + PutMappingRequest putMappingRequest = putMappingRequestCaptor.getValue(); + assertThat(putMappingRequest.indices(), arrayContaining(DEST_INDEX)); + try (XContentParser parser = createParser(JsonXContent.jsonXContent, putMappingRequest.source())) { + Map map = parser.map(); + assertThat(extractValue("properties.ml__id_copy.type", map), equalTo("keyword")); + return map; } } + + public void testUpdateMappingsToDestIndex_OutlierDetection() throws IOException { + testUpdateMappingsToDestIndex(new OutlierDetection.Builder().build(), Map.of(DEPENDENT_VARIABLE, Map.of("type", "integer"))); + } + + public void testUpdateMappingsToDestIndex_Regression() throws IOException { + Map map = + testUpdateMappingsToDestIndex(new Regression(DEPENDENT_VARIABLE), Map.of(DEPENDENT_VARIABLE, Map.of("type", "integer"))); + assertThat(extractValue("properties.ml.dep_var_prediction.type", map), equalTo("integer")); + } + + public void testUpdateMappingsToDestIndex_Classification() throws IOException { + Map map = + testUpdateMappingsToDestIndex(new Classification(DEPENDENT_VARIABLE), Map.of(DEPENDENT_VARIABLE, Map.of("type", "integer"))); + assertThat(extractValue("properties.ml.dep_var_prediction.type", map), equalTo("integer")); + assertThat(extractValue("properties.ml.top_classes.class_name.type", map), equalTo("integer")); + } + + public void testUpdateMappingsToDestIndex_ResultsFieldsExistsInSourceIndex() { + DataFrameAnalyticsConfig config = createConfig(new OutlierDetection.Builder().build()); + + ImmutableOpenMap.Builder mappings = ImmutableOpenMap.builder(); + mappings.put("", new MappingMetaData("_doc", Map.of("properties", Map.of("ml", "some-mapping")))); + GetIndexResponse getIndexResponse = + new GetIndexResponse( + new String[] { DEST_INDEX }, mappings.build(), ImmutableOpenMap.of(), ImmutableOpenMap.of(), ImmutableOpenMap.of()); + + ElasticsearchStatusException e = + expectThrows( + ElasticsearchStatusException.class, + () -> DataFrameAnalyticsIndex.updateMappingsToDestIndex( + client, config, getIndexResponse, ActionListener.wrap(Assert::fail))); + assertThat( + e.getMessage(), + equalTo("A field that matches the dest.results_field [ml] already exists; please set a different results_field")); + + verifyZeroInteractions(client); + } + + private static Answer callListenerOnResponse(Response response) { + return invocationOnMock -> { + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(response); + return null; + }; + } + + private static DataFrameAnalyticsConfig createConfig(DataFrameAnalysis analysis) { + return new DataFrameAnalyticsConfig.Builder() + .setId(ANALYTICS_ID) + .setSource(new DataFrameAnalyticsSource(SOURCE_INDEX, null, null)) + .setDest(new DataFrameAnalyticsDest(DEST_INDEX, null)) + .setAnalysis(analysis) + .build(); + } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java index 6b882d03f2919..1daea1e57ca14 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java @@ -54,7 +54,7 @@ public void testDetect_GivenFloatField() { .addAggregatableField("some_float", "float").build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildOutlierDetectionConfig(), 100, fieldCapabilities, Collections.emptyMap()); Tuple> fieldExtraction = extractedFieldsDetector.detect(); List allFields = fieldExtraction.v1().getAllFields(); @@ -72,7 +72,7 @@ public void testDetect_GivenNumericFieldWithMultipleTypes() { .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildOutlierDetectionConfig(), 100, fieldCapabilities, Collections.emptyMap()); Tuple> fieldExtraction = extractedFieldsDetector.detect(); List allFields = fieldExtraction.v1().getAllFields(); @@ -90,7 +90,7 @@ public void testDetect_GivenOutlierDetectionAndNonNumericField() { .addAggregatableField("some_keyword", "keyword").build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildOutlierDetectionConfig(), 100, fieldCapabilities, Collections.emptyMap()); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]." + @@ -102,7 +102,7 @@ public void testDetect_GivenOutlierDetectionAndFieldWithNumericAndNonNumericType .addAggregatableField("indecisive_field", "float", "keyword").build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildOutlierDetectionConfig(), 100, fieldCapabilities, Collections.emptyMap()); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]. " + @@ -118,7 +118,7 @@ public void testDetect_GivenOutlierDetectionAndMultipleFields() { .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildOutlierDetectionConfig(), 100, fieldCapabilities, Collections.emptyMap()); Tuple> fieldExtraction = extractedFieldsDetector.detect(); List allFields = fieldExtraction.v1().getAllFields(); @@ -147,7 +147,7 @@ public void testDetect_GivenRegressionAndMultipleFields() { .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildRegressionConfig("foo"), false, 100, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildRegressionConfig("foo"), 100, fieldCapabilities, Collections.emptyMap()); Tuple> fieldExtraction = extractedFieldsDetector.detect(); List allFields = fieldExtraction.v1().getAllFields(); @@ -174,7 +174,7 @@ public void testDetect_GivenRegressionAndRequiredFieldMissing() { .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildRegressionConfig("foo"), false, 100, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildRegressionConfig("foo"), 100, fieldCapabilities, Collections.emptyMap()); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); assertThat(e.getMessage(), equalTo("required field [foo] is missing; analysis requires fields [foo]")); @@ -190,7 +190,7 @@ public void testDetect_GivenRegressionAndRequiredFieldExcluded() { analyzedFields = new FetchSourceContext(true, new String[0], new String[] {"foo"}); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildRegressionConfig("foo"), false, 100, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildRegressionConfig("foo"), 100, fieldCapabilities, Collections.emptyMap()); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); assertThat(e.getMessage(), equalTo("required field [foo] is missing; analysis requires fields [foo]")); @@ -206,7 +206,7 @@ public void testDetect_GivenRegressionAndRequiredFieldNotIncluded() { analyzedFields = new FetchSourceContext(true, new String[] {"some_float", "some_keyword"}, new String[0]); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildRegressionConfig("foo"), false, 100, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildRegressionConfig("foo"), 100, fieldCapabilities, Collections.emptyMap()); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); assertThat(e.getMessage(), equalTo("required field [foo] is missing; analysis requires fields [foo]")); @@ -220,7 +220,7 @@ public void testDetect_GivenFieldIsBothIncludedAndExcluded() { analyzedFields = new FetchSourceContext(true, new String[] {"foo", "bar"}, new String[] {"foo"}); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildOutlierDetectionConfig(), 100, fieldCapabilities, Collections.emptyMap()); Tuple> fieldExtraction = extractedFieldsDetector.detect(); List allFields = fieldExtraction.v1().getAllFields(); @@ -241,7 +241,7 @@ public void testDetect_GivenFieldIsNotIncludedAndIsExcluded() { analyzedFields = new FetchSourceContext(true, new String[] {"foo"}, new String[] {"bar"}); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildOutlierDetectionConfig(), 100, fieldCapabilities, Collections.emptyMap()); Tuple> fieldExtraction = extractedFieldsDetector.detect(); List allFields = fieldExtraction.v1().getAllFields(); @@ -263,7 +263,7 @@ public void testDetect_GivenRegressionAndRequiredFieldHasInvalidType() { .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildRegressionConfig("foo"), false, 100, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildRegressionConfig("foo"), 100, fieldCapabilities, Collections.emptyMap()); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); assertThat(e.getMessage(), equalTo("invalid types [keyword] for required field [foo]; " + @@ -279,7 +279,7 @@ public void testDetect_GivenClassificationAndRequiredFieldHasInvalidType() { .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildClassificationConfig("some_float"), false, 100, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildClassificationConfig("some_float"), 100, fieldCapabilities, Collections.emptyMap()); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); assertThat(e.getMessage(), equalTo("invalid types [float] for required field [some_float]; " + @@ -294,7 +294,7 @@ public void testDetect_GivenClassificationAndDependentVariableHasInvalidCardinal .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(SOURCE_INDEX, - buildClassificationConfig("some_keyword"), false, 100, fieldCapabilities, Collections.singletonMap("some_keyword", 3L)); + buildClassificationConfig("some_keyword"), 100, fieldCapabilities, Collections.singletonMap("some_keyword", 3L)); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); assertThat(e.getMessage(), equalTo("Field [some_keyword] must have at most [2] distinct values but there were at least [3]")); @@ -305,7 +305,7 @@ public void testDetect_GivenIgnoredField() { .addAggregatableField("_id", "float").build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildOutlierDetectionConfig(), 100, fieldCapabilities, Collections.emptyMap()); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]. " + @@ -319,7 +319,7 @@ public void testDetect_GivenIncludedIgnoredField() { analyzedFields = new FetchSourceContext(true, new String[]{"_id"}, new String[0]); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildOutlierDetectionConfig(), 100, fieldCapabilities, Collections.emptyMap()); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); assertThat(e.getMessage(), equalTo("No field [_id] could be detected")); @@ -332,7 +332,7 @@ public void testDetect_GivenExcludedFieldIsMissing() { analyzedFields = new FetchSourceContext(true, new String[]{"*"}, new String[] {"bar"}); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildOutlierDetectionConfig(), 100, fieldCapabilities, Collections.emptyMap()); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); assertThat(e.getMessage(), equalTo("No field [bar] could be detected")); @@ -346,7 +346,7 @@ public void testDetect_GivenExcludedFieldIsUnsupported() { analyzedFields = new FetchSourceContext(true, null, new String[] {"categorical"}); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildOutlierDetectionConfig(), 100, fieldCapabilities, Collections.emptyMap()); Tuple> fieldExtraction = extractedFieldsDetector.detect(); @@ -377,7 +377,7 @@ public void testDetect_ShouldSortFieldsAlphabetically() { FieldCapabilitiesResponse fieldCapabilities = mockFieldCapsResponseBuilder.build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildOutlierDetectionConfig(), 100, fieldCapabilities, Collections.emptyMap()); Tuple> fieldExtraction = extractedFieldsDetector.detect(); List extractedFieldNames = fieldExtraction.v1().getAllFields().stream().map(ExtractedField::getName) @@ -394,7 +394,7 @@ public void testDetect_GivenIncludeWithMissingField() { analyzedFields = new FetchSourceContext(true, new String[]{"your_field1", "my*"}, new String[0]); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildOutlierDetectionConfig(), 100, fieldCapabilities, Collections.emptyMap()); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); assertThat(e.getMessage(), equalTo("No field [your_field1] could be detected")); @@ -409,7 +409,7 @@ public void testDetect_GivenExcludeAllValidFields() { analyzedFields = new FetchSourceContext(true, new String[0], new String[]{"my_*"}); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildOutlierDetectionConfig(), 100, fieldCapabilities, Collections.emptyMap()); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]. " + "Supported types are [boolean, byte, double, float, half_float, integer, long, scaled_float, short].")); @@ -425,7 +425,7 @@ public void testDetect_GivenInclusionsAndExclusions() { analyzedFields = new FetchSourceContext(true, new String[]{"your*", "my_*"}, new String[]{"*nope"}); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildOutlierDetectionConfig(), 100, fieldCapabilities, Collections.emptyMap()); Tuple> fieldExtraction = extractedFieldsDetector.detect(); List extractedFieldNames = fieldExtraction.v1().getAllFields().stream().map(ExtractedField::getName) @@ -450,7 +450,7 @@ public void testDetect_GivenIncludedFieldHasUnsupportedType() { analyzedFields = new FetchSourceContext(true, new String[]{"your*", "my_*"}, new String[]{"*nope"}); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildOutlierDetectionConfig(), 100, fieldCapabilities, Collections.emptyMap()); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); assertThat(e.getMessage(), equalTo("field [your_keyword] has unsupported type [keyword]. " + @@ -458,22 +458,6 @@ public void testDetect_GivenIncludedFieldHasUnsupportedType() { } public void testDetect_GivenIndexContainsResultsField() { - FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() - .addAggregatableField(RESULTS_FIELD, "float") - .addAggregatableField("my_field1", "float") - .addAggregatableField("your_field2", "float") - .addAggregatableField("your_keyword", "keyword") - .build(); - - ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities, Collections.emptyMap()); - ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); - - assertThat(e.getMessage(), equalTo("A field that matches the dest.results_field [ml] already exists; " + - "please set a different results_field")); - } - - public void testDetect_GivenIndexContainsResultsFieldAndTaskIsRestarting() { FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() .addAggregatableField(RESULTS_FIELD + ".outlier_score", "float") .addAggregatableField("my_field1", "float") @@ -482,7 +466,7 @@ public void testDetect_GivenIndexContainsResultsFieldAndTaskIsRestarting() { .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildOutlierDetectionConfig(), true, 100, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildOutlierDetectionConfig(), 100, fieldCapabilities, Collections.emptyMap()); Tuple> fieldExtraction = extractedFieldsDetector.detect(); List extractedFieldNames = fieldExtraction.v1().getAllFields().stream().map(ExtractedField::getName) @@ -498,23 +482,6 @@ public void testDetect_GivenIndexContainsResultsFieldAndTaskIsRestarting() { } public void testDetect_GivenIncludedResultsField() { - FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() - .addAggregatableField(RESULTS_FIELD, "float") - .addAggregatableField("my_field1", "float") - .addAggregatableField("your_field2", "float") - .addAggregatableField("your_keyword", "keyword") - .build(); - analyzedFields = new FetchSourceContext(true, new String[]{RESULTS_FIELD}, new String[0]); - - ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities, Collections.emptyMap()); - ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); - - assertThat(e.getMessage(), equalTo("A field that matches the dest.results_field [ml] already exists; " + - "please set a different results_field")); - } - - public void testDetect_GivenIncludedResultsFieldAndTaskIsRestarting() { FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() .addAggregatableField(RESULTS_FIELD + ".outlier_score", "float") .addAggregatableField("my_field1", "float") @@ -524,7 +491,7 @@ public void testDetect_GivenIncludedResultsFieldAndTaskIsRestarting() { analyzedFields = new FetchSourceContext(true, new String[]{RESULTS_FIELD}, new String[0]); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildOutlierDetectionConfig(), true, 100, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildOutlierDetectionConfig(), 100, fieldCapabilities, Collections.emptyMap()); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); assertThat(e.getMessage(), equalTo("No field [ml] could be detected")); @@ -539,7 +506,7 @@ public void testDetect_GivenLessFieldsThanDocValuesLimit() { .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildOutlierDetectionConfig(), true, 4, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildOutlierDetectionConfig(), 4, fieldCapabilities, Collections.emptyMap()); Tuple> fieldExtraction = extractedFieldsDetector.detect(); List extractedFieldNames = fieldExtraction.v1().getAllFields().stream().map(ExtractedField::getName) @@ -558,7 +525,7 @@ public void testDetect_GivenEqualFieldsToDocValuesLimit() { .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildOutlierDetectionConfig(), true, 3, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildOutlierDetectionConfig(), 3, fieldCapabilities, Collections.emptyMap()); Tuple> fieldExtraction = extractedFieldsDetector.detect(); List extractedFieldNames = fieldExtraction.v1().getAllFields().stream().map(ExtractedField::getName) @@ -577,7 +544,7 @@ public void testDetect_GivenMoreFieldsThanDocValuesLimit() { .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildOutlierDetectionConfig(), true, 2, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildOutlierDetectionConfig(), 2, fieldCapabilities, Collections.emptyMap()); Tuple> fieldExtraction = extractedFieldsDetector.detect(); List extractedFieldNames = fieldExtraction.v1().getAllFields().stream().map(ExtractedField::getName) @@ -594,7 +561,7 @@ private void testDetect_GivenBooleanField(DataFrameAnalyticsConfig config, boole .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, config, false, 100, fieldCapabilities, config.getAnalysis().getFieldCardinalityLimits()); + SOURCE_INDEX, config, 100, fieldCapabilities, config.getAnalysis().getFieldCardinalityLimits()); Tuple> fieldExtraction = extractedFieldsDetector.detect(); List allFields = fieldExtraction.v1().getAllFields(); @@ -650,7 +617,7 @@ public void testDetect_GivenMultiFields() { .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildRegressionConfig("a_float"), true, 100, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildRegressionConfig("a_float"), 100, fieldCapabilities, Collections.emptyMap()); Tuple> fieldExtraction = extractedFieldsDetector.detect(); assertThat(fieldExtraction.v1().getAllFields(), hasSize(5)); @@ -681,7 +648,7 @@ public void testDetect_GivenMultiFieldAndParentIsRequired() { .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildClassificationConfig("field_1"), true, 100, fieldCapabilities, Collections.singletonMap("field_1", 2L)); + SOURCE_INDEX, buildClassificationConfig("field_1"), 100, fieldCapabilities, Collections.singletonMap("field_1", 2L)); Tuple> fieldExtraction = extractedFieldsDetector.detect(); assertThat(fieldExtraction.v1().getAllFields(), hasSize(2)); @@ -705,7 +672,7 @@ public void testDetect_GivenMultiFieldAndMultiFieldIsRequired() { .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildClassificationConfig("field_1.keyword"), true, 100, fieldCapabilities, + SOURCE_INDEX, buildClassificationConfig("field_1.keyword"), 100, fieldCapabilities, Collections.singletonMap("field_1.keyword", 2L)); Tuple> fieldExtraction = extractedFieldsDetector.detect(); @@ -732,7 +699,7 @@ public void testDetect_GivenSeveralMultiFields_ShouldPickFirstSorted() { .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildRegressionConfig("field_2"), true, 100, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildRegressionConfig("field_2"), 100, fieldCapabilities, Collections.emptyMap()); Tuple> fieldExtraction = extractedFieldsDetector.detect(); assertThat(fieldExtraction.v1().getAllFields(), hasSize(2)); @@ -758,7 +725,7 @@ public void testDetect_GivenMultiFields_OverDocValueLimit() { .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildRegressionConfig("field_2"), true, 0, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildRegressionConfig("field_2"), 0, fieldCapabilities, Collections.emptyMap()); Tuple> fieldExtraction = extractedFieldsDetector.detect(); assertThat(fieldExtraction.v1().getAllFields(), hasSize(2)); @@ -783,7 +750,7 @@ public void testDetect_GivenParentAndMultiFieldBothAggregatable() { .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildRegressionConfig("field_2.double"), true, 100, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildRegressionConfig("field_2.double"), 100, fieldCapabilities, Collections.emptyMap()); Tuple> fieldExtraction = extractedFieldsDetector.detect(); assertThat(fieldExtraction.v1().getAllFields(), hasSize(2)); @@ -808,7 +775,7 @@ public void testDetect_GivenParentAndMultiFieldNoneAggregatable() { .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildRegressionConfig("field_2"), true, 100, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildRegressionConfig("field_2"), 100, fieldCapabilities, Collections.emptyMap()); Tuple> fieldExtraction = extractedFieldsDetector.detect(); assertThat(fieldExtraction.v1().getAllFields(), hasSize(2)); @@ -833,7 +800,7 @@ public void testDetect_GivenMultiFields_AndExplicitlyIncludedFields() { analyzedFields = new FetchSourceContext(true, new String[] { "field_1", "field_2" }, new String[0]); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildRegressionConfig("field_2"), false, 100, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildRegressionConfig("field_2"), 100, fieldCapabilities, Collections.emptyMap()); Tuple> fieldExtraction = extractedFieldsDetector.detect(); assertThat(fieldExtraction.v1().getAllFields(), hasSize(2)); @@ -858,7 +825,7 @@ public void testDetect_GivenSourceFilteringWithIncludes() { sourceFiltering = new FetchSourceContext(true, new String[] {"field_1*"}, null); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildOutlierDetectionConfig(), 100, fieldCapabilities, Collections.emptyMap()); Tuple> fieldExtraction = extractedFieldsDetector.detect(); List allFields = fieldExtraction.v1().getAllFields(); @@ -881,7 +848,7 @@ public void testDetect_GivenSourceFilteringWithExcludes() { sourceFiltering = new FetchSourceContext(true, null, new String[] {"field_1*"}); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildOutlierDetectionConfig(), false, 100, fieldCapabilities, Collections.emptyMap()); + SOURCE_INDEX, buildOutlierDetectionConfig(), 100, fieldCapabilities, Collections.emptyMap()); Tuple> fieldExtraction = extractedFieldsDetector.detect(); List allFields = fieldExtraction.v1().getAllFields();