From f966936973fcab2f3f96114520a7bce8e5943cc2 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Fri, 1 Nov 2019 13:37:20 +0200 Subject: [PATCH] [7.x][ML] Deduplicate multi-fields for data frame analytics (#48799) In the case multi-fields exist in the source index, we pick all variants of them in our extracted fields detection for data frame analytics. This means we may have multiple instances of the same feature. The worse consequence of this is when the dependent variable (for regression or classification) is also duplicated which means we train a model on the dependent variable itself. Now that #48770 is merged, this commit is adding logic to only select one variant of multi-fields. Closes #48756 Backport of #48799 --- .../extractor/ExtractedFieldsDetector.java | 57 ++++++- .../ExtractedFieldsDetectorTests.java | 155 +++++++++++++++++- 2 files changed, 209 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java index 6ed1ea62fe8a4..2a82ae7dcf963 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java @@ -31,6 +31,7 @@ import java.util.Collections; import java.util.HashSet; import java.util.Iterator; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -238,7 +239,9 @@ private ExtractedFields detectExtractedFields(Set fields) { // We sort the fields to ensure the checksum for each document is deterministic Collections.sort(sortedFields); ExtractedFields extractedFields = ExtractedFields.build(sortedFields, Collections.emptySet(), fieldCapabilitiesResponse); - if (extractedFields.getDocValueFields().size() > docValueFieldsLimit) { + boolean preferSource = extractedFields.getDocValueFields().size() > docValueFieldsLimit; + extractedFields = deduplicateMultiFields(extractedFields, preferSource); + if (preferSource) { extractedFields = fetchFromSourceIfSupported(extractedFields); if (extractedFields.getDocValueFields().size() > docValueFieldsLimit) { throw ExceptionsHelper.badRequestException("[{}] fields must be retrieved from doc_values but the limit is [{}]; " + @@ -250,9 +253,59 @@ private ExtractedFields detectExtractedFields(Set fields) { return extractedFields; } + private ExtractedFields deduplicateMultiFields(ExtractedFields extractedFields, boolean preferSource) { + Set requiredFields = config.getAnalysis().getRequiredFields().stream().map(RequiredField::getName) + .collect(Collectors.toSet()); + Map nameOrParentToField = new LinkedHashMap<>(); + for (ExtractedField currentField : extractedFields.getAllFields()) { + String nameOrParent = currentField.isMultiField() ? currentField.getParentField() : currentField.getName(); + ExtractedField existingField = nameOrParentToField.putIfAbsent(nameOrParent, currentField); + if (existingField != null) { + ExtractedField parent = currentField.isMultiField() ? existingField : currentField; + ExtractedField multiField = currentField.isMultiField() ? currentField : existingField; + nameOrParentToField.put(nameOrParent, chooseMultiFieldOrParent(preferSource, requiredFields, parent, multiField)); + } + } + return new ExtractedFields(new ArrayList<>(nameOrParentToField.values())); + } + + private ExtractedField chooseMultiFieldOrParent(boolean preferSource, Set requiredFields, + ExtractedField parent, ExtractedField multiField) { + // Check requirements first + if (requiredFields.contains(parent.getName())) { + return parent; + } + if (requiredFields.contains(multiField.getName())) { + return multiField; + } + + // If both are multi-fields it means there are several. In this case parent is the previous multi-field + // we selected. We'll just keep that. + if (parent.isMultiField() && multiField.isMultiField()) { + return parent; + } + + // If we prefer source only the parent may support it. If it does we pick it immediately. + if (preferSource && parent.supportsFromSource()) { + return parent; + } + + // If any of the two is a doc_value field let's prefer it as it'd support aggregations. + // We check the parent first as it'd be a shorter field name. + if (parent.getMethod() == ExtractedField.Method.DOC_VALUE) { + return parent; + } + if (multiField.getMethod() == ExtractedField.Method.DOC_VALUE) { + return multiField; + } + + // None is aggregatable. Let's pick the parent for its shorter name. + return parent; + } + private ExtractedFields fetchFromSourceIfSupported(ExtractedFields extractedFields) { List adjusted = new ArrayList<>(extractedFields.getAllFields().size()); - for (ExtractedField field : extractedFields.getDocValueFields()) { + for (ExtractedField field : extractedFields.getAllFields()) { adjusted.add(field.supportsFromSource() ? field.newFromSource() : field); } return new ExtractedFields(adjusted); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java index ce819e9e6d84a..7dc203dfe24b3 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java @@ -534,6 +534,151 @@ public void testDetect_GivenBooleanField_BooleanMappedAsString() { assertThat(booleanField.value(hit), arrayContaining("false", "true", "false")); } + public void testDetect_GivenMultiFields() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() + .addAggregatableField("a_float", "float") + .addNonAggregatableField("text_without_keyword", "text") + .addNonAggregatableField("text_1", "text") + .addAggregatableField("text_1.keyword", "keyword") + .addNonAggregatableField("text_2", "text") + .addAggregatableField("text_2.keyword", "keyword") + .addAggregatableField("keyword_1", "keyword") + .addNonAggregatableField("keyword_1.text", "text") + .build(); + + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildRegressionConfig("a_float"), RESULTS_FIELD, true, 100, fieldCapabilities); + ExtractedFields extractedFields = extractedFieldsDetector.detect(); + + assertThat(extractedFields.getAllFields().size(), equalTo(5)); + List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) + .collect(Collectors.toList()); + assertThat(extractedFieldNames, contains("a_float", "keyword_1", "text_1.keyword", "text_2.keyword", "text_without_keyword")); + } + + public void testDetect_GivenMultiFieldAndParentIsRequired() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() + .addAggregatableField("field_1", "keyword") + .addAggregatableField("field_1.keyword", "keyword") + .addAggregatableField("field_2", "float") + .build(); + + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildClassificationConfig("field_1"), RESULTS_FIELD, true, 100, fieldCapabilities); + ExtractedFields extractedFields = extractedFieldsDetector.detect(); + + assertThat(extractedFields.getAllFields().size(), equalTo(2)); + List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) + .collect(Collectors.toList()); + assertThat(extractedFieldNames, contains("field_1", "field_2")); + } + + public void testDetect_GivenMultiFieldAndMultiFieldIsRequired() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() + .addAggregatableField("field_1", "keyword") + .addAggregatableField("field_1.keyword", "keyword") + .addAggregatableField("field_2", "float") + .build(); + + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildClassificationConfig("field_1.keyword"), RESULTS_FIELD, true, 100, fieldCapabilities); + ExtractedFields extractedFields = extractedFieldsDetector.detect(); + + assertThat(extractedFields.getAllFields().size(), equalTo(2)); + List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) + .collect(Collectors.toList()); + assertThat(extractedFieldNames, contains("field_1.keyword", "field_2")); + } + + public void testDetect_GivenSeveralMultiFields_ShouldPickFirstSorted() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() + .addNonAggregatableField("field_1", "text") + .addAggregatableField("field_1.keyword_3", "keyword") + .addAggregatableField("field_1.keyword_2", "keyword") + .addAggregatableField("field_1.keyword_1", "keyword") + .addAggregatableField("field_2", "float") + .build(); + + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildRegressionConfig("field_2"), RESULTS_FIELD, true, 100, fieldCapabilities); + ExtractedFields extractedFields = extractedFieldsDetector.detect(); + + assertThat(extractedFields.getAllFields().size(), equalTo(2)); + List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) + .collect(Collectors.toList()); + assertThat(extractedFieldNames, contains("field_1.keyword_1", "field_2")); + } + + public void testDetect_GivenMultiFields_OverDocValueLimit() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() + .addNonAggregatableField("field_1", "text") + .addAggregatableField("field_1.keyword_1", "keyword") + .addAggregatableField("field_2", "float") + .build(); + + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildRegressionConfig("field_2"), RESULTS_FIELD, true, 0, fieldCapabilities); + ExtractedFields extractedFields = extractedFieldsDetector.detect(); + + assertThat(extractedFields.getAllFields().size(), equalTo(2)); + List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) + .collect(Collectors.toList()); + assertThat(extractedFieldNames, contains("field_1", "field_2")); + } + + public void testDetect_GivenParentAndMultiFieldBothAggregatable() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() + .addAggregatableField("field_1", "keyword") + .addAggregatableField("field_1.keyword", "keyword") + .addAggregatableField("field_2.keyword", "float") + .addAggregatableField("field_2.double", "double") + .build(); + + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildRegressionConfig("field_2.double"), RESULTS_FIELD, true, 100, fieldCapabilities); + ExtractedFields extractedFields = extractedFieldsDetector.detect(); + + assertThat(extractedFields.getAllFields().size(), equalTo(2)); + List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) + .collect(Collectors.toList()); + assertThat(extractedFieldNames, contains("field_1", "field_2.double")); + } + + public void testDetect_GivenParentAndMultiFieldNoneAggregatable() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() + .addNonAggregatableField("field_1", "text") + .addNonAggregatableField("field_1.text", "text") + .addAggregatableField("field_2", "float") + .build(); + + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildRegressionConfig("field_2"), RESULTS_FIELD, true, 100, fieldCapabilities); + ExtractedFields extractedFields = extractedFieldsDetector.detect(); + + assertThat(extractedFields.getAllFields().size(), equalTo(2)); + List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) + .collect(Collectors.toList()); + assertThat(extractedFieldNames, contains("field_1", "field_2")); + } + + public void testDetect_GivenMultiFields_AndExplicitlyIncludedFields() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() + .addNonAggregatableField("field_1", "text") + .addAggregatableField("field_1.keyword", "keyword") + .addAggregatableField("field_2", "float") + .build(); + FetchSourceContext analyzedFields = new FetchSourceContext(true, new String[] { "field_1", "field_2" }, new String[0]); + + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildRegressionConfig("field_2", analyzedFields), RESULTS_FIELD, false, 100, fieldCapabilities); + ExtractedFields extractedFields = extractedFieldsDetector.detect(); + + assertThat(extractedFields.getAllFields().size(), equalTo(2)); + List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) + .collect(Collectors.toList()); + assertThat(extractedFieldNames, contains("field_1", "field_2")); + } + private static DataFrameAnalyticsConfig buildOutlierDetectionConfig() { return buildOutlierDetectionConfig(null); } @@ -576,9 +721,17 @@ private static class MockFieldCapsResponseBuilder { private final Map> fieldCaps = new HashMap<>(); private MockFieldCapsResponseBuilder addAggregatableField(String field, String... types) { + return addField(field, true, types); + } + + private MockFieldCapsResponseBuilder addNonAggregatableField(String field, String... types) { + return addField(field, false, types); + } + + private MockFieldCapsResponseBuilder addField(String field, boolean isAggregatable, String... types) { Map caps = new HashMap<>(); for (String type : types) { - caps.put(type, new FieldCapabilities(field, type, true, true)); + caps.put(type, new FieldCapabilities(field, type, true, isAggregatable)); } fieldCaps.put(field, caps); return this;