diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java
index 6ed1ea62fe8a4..2a82ae7dcf963 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java
@@ -31,6 +31,7 @@
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.Iterator;
+import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
@@ -238,7 +239,9 @@ private ExtractedFields detectExtractedFields(Set<String> fields) {
         // We sort the fields to ensure the checksum for each document is deterministic
         Collections.sort(sortedFields);
         ExtractedFields extractedFields = ExtractedFields.build(sortedFields, Collections.emptySet(), fieldCapabilitiesResponse);
-        if (extractedFields.getDocValueFields().size() > docValueFieldsLimit) {
+        boolean preferSource = extractedFields.getDocValueFields().size() > docValueFieldsLimit;
+        extractedFields = deduplicateMultiFields(extractedFields, preferSource);
+        if (preferSource) {
             extractedFields = fetchFromSourceIfSupported(extractedFields);
             if (extractedFields.getDocValueFields().size() > docValueFieldsLimit) {
                 throw ExceptionsHelper.badRequestException("[{}] fields must be retrieved from doc_values but the limit is [{}]; " +
@@ -250,9 +253,59 @@ private ExtractedFields detectExtractedFields(Set<String> fields) {
         return extractedFields;
     }
 
+    private ExtractedFields deduplicateMultiFields(ExtractedFields extractedFields, boolean preferSource) {
+        Set<String> requiredFields = config.getAnalysis().getRequiredFields().stream().map(RequiredField::getName)
+            .collect(Collectors.toSet());
+        Map<String, ExtractedField> nameOrParentToField = new LinkedHashMap<>();
+        for (ExtractedField currentField : extractedFields.getAllFields()) {
+            String nameOrParent = currentField.isMultiField() ? currentField.getParentField() : currentField.getName();
+            ExtractedField existingField = nameOrParentToField.putIfAbsent(nameOrParent, currentField);
+            if (existingField != null) {
+                ExtractedField parent = currentField.isMultiField() ? existingField : currentField;
+                ExtractedField multiField = currentField.isMultiField() ? currentField : existingField;
+                nameOrParentToField.put(nameOrParent, chooseMultiFieldOrParent(preferSource, requiredFields, parent, multiField));
+            }
+        }
+        return new ExtractedFields(new ArrayList<>(nameOrParentToField.values()));
+    }
+
+    private ExtractedField chooseMultiFieldOrParent(boolean preferSource, Set<String> requiredFields,
+                                                    ExtractedField parent, ExtractedField multiField) {
+        // Check requirements first
+        if (requiredFields.contains(parent.getName())) {
+            return parent;
+        }
+        if (requiredFields.contains(multiField.getName())) {
+            return multiField;
+        }
+
+        // If both are multi-fields it means there are several. In this case parent is the previous multi-field
+        // we selected. We'll just keep that.
+        if (parent.isMultiField() && multiField.isMultiField()) {
+            return parent;
+        }
+
+        // If we prefer source only the parent may support it. If it does we pick it immediately.
+        if (preferSource && parent.supportsFromSource()) {
+            return parent;
+        }
+
+        // If any of the two is a doc_value field let's prefer it as it'd support aggregations.
+        // We check the parent first as it'd be a shorter field name.
+        if (parent.getMethod() == ExtractedField.Method.DOC_VALUE) {
+            return parent;
+        }
+        if (multiField.getMethod() == ExtractedField.Method.DOC_VALUE) {
+            return multiField;
+        }
+
+        // None is aggregatable. Let's pick the parent for its shorter name.
+        return parent;
+    }
+
     private ExtractedFields fetchFromSourceIfSupported(ExtractedFields extractedFields) {
         List<ExtractedField> adjusted = new ArrayList<>(extractedFields.getAllFields().size());
-        for (ExtractedField field : extractedFields.getDocValueFields()) {
+        for (ExtractedField field : extractedFields.getAllFields()) {
             adjusted.add(field.supportsFromSource() ? field.newFromSource() : field);
         }
         return new ExtractedFields(adjusted);
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java
index ce819e9e6d84a..7dc203dfe24b3 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java
@@ -534,6 +534,151 @@ public void testDetect_GivenBooleanField_BooleanMappedAsString() {
         assertThat(booleanField.value(hit), arrayContaining("false", "true", "false"));
     }
 
+    public void testDetect_GivenMultiFields() {
+        FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
+            .addAggregatableField("a_float", "float")
+            .addNonAggregatableField("text_without_keyword", "text")
+            .addNonAggregatableField("text_1", "text")
+            .addAggregatableField("text_1.keyword", "keyword")
+            .addNonAggregatableField("text_2", "text")
+            .addAggregatableField("text_2.keyword", "keyword")
+            .addAggregatableField("keyword_1", "keyword")
+            .addNonAggregatableField("keyword_1.text", "text")
+            .build();
+
+        ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
+            SOURCE_INDEX, buildRegressionConfig("a_float"), RESULTS_FIELD, true, 100, fieldCapabilities);
+        ExtractedFields extractedFields = extractedFieldsDetector.detect();
+
+        assertThat(extractedFields.getAllFields().size(), equalTo(5));
+        List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
+            .collect(Collectors.toList());
+        assertThat(extractedFieldNames, contains("a_float", "keyword_1", "text_1.keyword", "text_2.keyword", "text_without_keyword"));
+    }
+
+    public void testDetect_GivenMultiFieldAndParentIsRequired() {
+        FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
+            .addAggregatableField("field_1", "keyword")
+            .addAggregatableField("field_1.keyword", "keyword")
+            .addAggregatableField("field_2", "float")
+            .build();
+
+        ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
+            SOURCE_INDEX, buildClassificationConfig("field_1"), RESULTS_FIELD, true, 100, fieldCapabilities);
+        ExtractedFields extractedFields = extractedFieldsDetector.detect();
+
+        assertThat(extractedFields.getAllFields().size(), equalTo(2));
+        List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
+            .collect(Collectors.toList());
+        assertThat(extractedFieldNames, contains("field_1", "field_2"));
+    }
+
+    public void testDetect_GivenMultiFieldAndMultiFieldIsRequired() {
+        FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
+            .addAggregatableField("field_1", "keyword")
+            .addAggregatableField("field_1.keyword", "keyword")
+            .addAggregatableField("field_2", "float")
+            .build();
+
+        ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
+            SOURCE_INDEX, buildClassificationConfig("field_1.keyword"), RESULTS_FIELD, true, 100, fieldCapabilities);
+        ExtractedFields extractedFields = extractedFieldsDetector.detect();
+
+        assertThat(extractedFields.getAllFields().size(), equalTo(2));
+        List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
+            .collect(Collectors.toList());
+        assertThat(extractedFieldNames, contains("field_1.keyword", "field_2"));
+    }
+
+    public void testDetect_GivenSeveralMultiFields_ShouldPickFirstSorted() {
+        FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
+            .addNonAggregatableField("field_1", "text")
+            .addAggregatableField("field_1.keyword_3", "keyword")
+            .addAggregatableField("field_1.keyword_2", "keyword")
+            .addAggregatableField("field_1.keyword_1", "keyword")
+            .addAggregatableField("field_2", "float")
+            .build();
+
+        ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
+            SOURCE_INDEX, buildRegressionConfig("field_2"), RESULTS_FIELD, true, 100, fieldCapabilities);
+        ExtractedFields extractedFields = extractedFieldsDetector.detect();
+
+        assertThat(extractedFields.getAllFields().size(), equalTo(2));
+        List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
+            .collect(Collectors.toList());
+        assertThat(extractedFieldNames, contains("field_1.keyword_1", "field_2"));
+    }
+
+    public void testDetect_GivenMultiFields_OverDocValueLimit() {
+        FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
+            .addNonAggregatableField("field_1", "text")
+            .addAggregatableField("field_1.keyword_1", "keyword")
+            .addAggregatableField("field_2", "float")
+            .build();
+
+        ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
+            SOURCE_INDEX, buildRegressionConfig("field_2"), RESULTS_FIELD, true, 0, fieldCapabilities);
+        ExtractedFields extractedFields = extractedFieldsDetector.detect();
+
+        assertThat(extractedFields.getAllFields().size(), equalTo(2));
+        List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
+            .collect(Collectors.toList());
+        assertThat(extractedFieldNames, contains("field_1", "field_2"));
+    }
+
+    public void testDetect_GivenParentAndMultiFieldBothAggregatable() {
+        FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
+            .addAggregatableField("field_1", "keyword")
+            .addAggregatableField("field_1.keyword", "keyword")
+            .addAggregatableField("field_2.keyword", "float")
+            .addAggregatableField("field_2.double", "double")
+            .build();
+
+        ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
+            SOURCE_INDEX, buildRegressionConfig("field_2.double"), RESULTS_FIELD, true, 100, fieldCapabilities);
+        ExtractedFields extractedFields = extractedFieldsDetector.detect();
+
+        assertThat(extractedFields.getAllFields().size(), equalTo(2));
+        List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
+            .collect(Collectors.toList());
+        assertThat(extractedFieldNames, contains("field_1", "field_2.double"));
+    }
+
+    public void testDetect_GivenParentAndMultiFieldNoneAggregatable() {
+        FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
+            .addNonAggregatableField("field_1", "text")
+            .addNonAggregatableField("field_1.text", "text")
+            .addAggregatableField("field_2", "float")
+            .build();
+
+        ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
+            SOURCE_INDEX, buildRegressionConfig("field_2"), RESULTS_FIELD, true, 100, fieldCapabilities);
+        ExtractedFields extractedFields = extractedFieldsDetector.detect();
+
+        assertThat(extractedFields.getAllFields().size(), equalTo(2));
+        List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
+            .collect(Collectors.toList());
+        assertThat(extractedFieldNames, contains("field_1", "field_2"));
+    }
+
+    public void testDetect_GivenMultiFields_AndExplicitlyIncludedFields() {
+        FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
+            .addNonAggregatableField("field_1", "text")
+            .addAggregatableField("field_1.keyword", "keyword")
+            .addAggregatableField("field_2", "float")
+            .build();
+        FetchSourceContext analyzedFields = new FetchSourceContext(true, new String[] { "field_1", "field_2" }, new String[0]);
+
+        ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
+            SOURCE_INDEX, buildRegressionConfig("field_2", analyzedFields), RESULTS_FIELD, false, 100, fieldCapabilities);
+        ExtractedFields extractedFields = extractedFieldsDetector.detect();
+
+        assertThat(extractedFields.getAllFields().size(), equalTo(2));
+        List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
+            .collect(Collectors.toList());
+        assertThat(extractedFieldNames, contains("field_1", "field_2"));
+    }
+
     private static DataFrameAnalyticsConfig buildOutlierDetectionConfig() {
         return buildOutlierDetectionConfig(null);
     }
@@ -576,9 +721,17 @@ private static class MockFieldCapsResponseBuilder {
         private final Map<String, Map<String, FieldCapabilities>> fieldCaps = new HashMap<>();
 
         private MockFieldCapsResponseBuilder addAggregatableField(String field, String... types) {
+            return addField(field, true, types);
+        }
+
+        private MockFieldCapsResponseBuilder addNonAggregatableField(String field, String... types) {
+            return addField(field, false, types);
+        }
+
+        private MockFieldCapsResponseBuilder addField(String field, boolean isAggregatable, String... types) {
             Map<String, FieldCapabilities> caps = new HashMap<>();
             for (String type : types) {
-                caps.put(type, new FieldCapabilities(field, type, true, true));
+                caps.put(type, new FieldCapabilities(field, type, true, isAggregatable));
             }
             fieldCaps.put(field, caps);
             return this;