Skip to content

Commit

Permalink
[ML] Deduplicate multi-fields for data frame analytics (#48799)
Browse files Browse the repository at this point in the history
In the case multi-fields exist in the source index, we pick
all variants of them in our extracted fields detection for
data frame analytics. This means we may have multiple instances
of the same feature. The worse consequence of this is when the
dependent variable (for regression or classification) is also
duplicated which means we train a model on the dependent variable
itself.

Now that #48770 is merged, this commit is adding logic to
only select one variant of multi-fields.

Closes #48756
  • Loading branch information
dimitris-athanasiou authored Nov 1, 2019
1 parent 44c74bb commit ed39fa6
Show file tree
Hide file tree
Showing 2 changed files with 209 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -238,7 +239,9 @@ private ExtractedFields detectExtractedFields(Set<String> fields) {
// We sort the fields to ensure the checksum for each document is deterministic
Collections.sort(sortedFields);
ExtractedFields extractedFields = ExtractedFields.build(sortedFields, Collections.emptySet(), fieldCapabilitiesResponse);
if (extractedFields.getDocValueFields().size() > docValueFieldsLimit) {
boolean preferSource = extractedFields.getDocValueFields().size() > docValueFieldsLimit;
extractedFields = deduplicateMultiFields(extractedFields, preferSource);
if (preferSource) {
extractedFields = fetchFromSourceIfSupported(extractedFields);
if (extractedFields.getDocValueFields().size() > docValueFieldsLimit) {
throw ExceptionsHelper.badRequestException("[{}] fields must be retrieved from doc_values but the limit is [{}]; " +
Expand All @@ -250,9 +253,59 @@ private ExtractedFields detectExtractedFields(Set<String> fields) {
return extractedFields;
}

private ExtractedFields deduplicateMultiFields(ExtractedFields extractedFields, boolean preferSource) {
Set<String> requiredFields = config.getAnalysis().getRequiredFields().stream().map(RequiredField::getName)
.collect(Collectors.toSet());
Map<String, ExtractedField> nameOrParentToField = new LinkedHashMap<>();
for (ExtractedField currentField : extractedFields.getAllFields()) {
String nameOrParent = currentField.isMultiField() ? currentField.getParentField() : currentField.getName();
ExtractedField existingField = nameOrParentToField.putIfAbsent(nameOrParent, currentField);
if (existingField != null) {
ExtractedField parent = currentField.isMultiField() ? existingField : currentField;
ExtractedField multiField = currentField.isMultiField() ? currentField : existingField;
nameOrParentToField.put(nameOrParent, chooseMultiFieldOrParent(preferSource, requiredFields, parent, multiField));
}
}
return new ExtractedFields(new ArrayList<>(nameOrParentToField.values()));
}

private ExtractedField chooseMultiFieldOrParent(boolean preferSource, Set<String> requiredFields,
ExtractedField parent, ExtractedField multiField) {
// Check requirements first
if (requiredFields.contains(parent.getName())) {
return parent;
}
if (requiredFields.contains(multiField.getName())) {
return multiField;
}

// If both are multi-fields it means there are several. In this case parent is the previous multi-field
// we selected. We'll just keep that.
if (parent.isMultiField() && multiField.isMultiField()) {
return parent;
}

// If we prefer source only the parent may support it. If it does we pick it immediately.
if (preferSource && parent.supportsFromSource()) {
return parent;
}

// If any of the two is a doc_value field let's prefer it as it'd support aggregations.
// We check the parent first as it'd be a shorter field name.
if (parent.getMethod() == ExtractedField.Method.DOC_VALUE) {
return parent;
}
if (multiField.getMethod() == ExtractedField.Method.DOC_VALUE) {
return multiField;
}

// None is aggregatable. Let's pick the parent for its shorter name.
return parent;
}

private ExtractedFields fetchFromSourceIfSupported(ExtractedFields extractedFields) {
List<ExtractedField> adjusted = new ArrayList<>(extractedFields.getAllFields().size());
for (ExtractedField field : extractedFields.getDocValueFields()) {
for (ExtractedField field : extractedFields.getAllFields()) {
adjusted.add(field.supportsFromSource() ? field.newFromSource() : field);
}
return new ExtractedFields(adjusted);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,151 @@ public void testDetect_GivenBooleanField_BooleanMappedAsString() {
assertThat(booleanField.value(hit), arrayContaining("false", "true", "false"));
}

public void testDetect_GivenMultiFields() {
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("a_float", "float")
.addNonAggregatableField("text_without_keyword", "text")
.addNonAggregatableField("text_1", "text")
.addAggregatableField("text_1.keyword", "keyword")
.addNonAggregatableField("text_2", "text")
.addAggregatableField("text_2.keyword", "keyword")
.addAggregatableField("keyword_1", "keyword")
.addNonAggregatableField("keyword_1.text", "text")
.build();

ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildRegressionConfig("a_float"), RESULTS_FIELD, true, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();

assertThat(extractedFields.getAllFields().size(), equalTo(5));
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
.collect(Collectors.toList());
assertThat(extractedFieldNames, contains("a_float", "keyword_1", "text_1.keyword", "text_2.keyword", "text_without_keyword"));
}

public void testDetect_GivenMultiFieldAndParentIsRequired() {
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("field_1", "keyword")
.addAggregatableField("field_1.keyword", "keyword")
.addAggregatableField("field_2", "float")
.build();

ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildClassificationConfig("field_1"), RESULTS_FIELD, true, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();

assertThat(extractedFields.getAllFields().size(), equalTo(2));
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
.collect(Collectors.toList());
assertThat(extractedFieldNames, contains("field_1", "field_2"));
}

public void testDetect_GivenMultiFieldAndMultiFieldIsRequired() {
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("field_1", "keyword")
.addAggregatableField("field_1.keyword", "keyword")
.addAggregatableField("field_2", "float")
.build();

ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildClassificationConfig("field_1.keyword"), RESULTS_FIELD, true, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();

assertThat(extractedFields.getAllFields().size(), equalTo(2));
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
.collect(Collectors.toList());
assertThat(extractedFieldNames, contains("field_1.keyword", "field_2"));
}

public void testDetect_GivenSeveralMultiFields_ShouldPickFirstSorted() {
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addNonAggregatableField("field_1", "text")
.addAggregatableField("field_1.keyword_3", "keyword")
.addAggregatableField("field_1.keyword_2", "keyword")
.addAggregatableField("field_1.keyword_1", "keyword")
.addAggregatableField("field_2", "float")
.build();

ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildRegressionConfig("field_2"), RESULTS_FIELD, true, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();

assertThat(extractedFields.getAllFields().size(), equalTo(2));
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
.collect(Collectors.toList());
assertThat(extractedFieldNames, contains("field_1.keyword_1", "field_2"));
}

public void testDetect_GivenMultiFields_OverDocValueLimit() {
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addNonAggregatableField("field_1", "text")
.addAggregatableField("field_1.keyword_1", "keyword")
.addAggregatableField("field_2", "float")
.build();

ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildRegressionConfig("field_2"), RESULTS_FIELD, true, 0, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();

assertThat(extractedFields.getAllFields().size(), equalTo(2));
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
.collect(Collectors.toList());
assertThat(extractedFieldNames, contains("field_1", "field_2"));
}

public void testDetect_GivenParentAndMultiFieldBothAggregatable() {
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addAggregatableField("field_1", "keyword")
.addAggregatableField("field_1.keyword", "keyword")
.addAggregatableField("field_2.keyword", "float")
.addAggregatableField("field_2.double", "double")
.build();

ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildRegressionConfig("field_2.double"), RESULTS_FIELD, true, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();

assertThat(extractedFields.getAllFields().size(), equalTo(2));
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
.collect(Collectors.toList());
assertThat(extractedFieldNames, contains("field_1", "field_2.double"));
}

public void testDetect_GivenParentAndMultiFieldNoneAggregatable() {
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addNonAggregatableField("field_1", "text")
.addNonAggregatableField("field_1.text", "text")
.addAggregatableField("field_2", "float")
.build();

ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildRegressionConfig("field_2"), RESULTS_FIELD, true, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();

assertThat(extractedFields.getAllFields().size(), equalTo(2));
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
.collect(Collectors.toList());
assertThat(extractedFieldNames, contains("field_1", "field_2"));
}

public void testDetect_GivenMultiFields_AndExplicitlyIncludedFields() {
FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder()
.addNonAggregatableField("field_1", "text")
.addAggregatableField("field_1.keyword", "keyword")
.addAggregatableField("field_2", "float")
.build();
FetchSourceContext analyzedFields = new FetchSourceContext(true, new String[] { "field_1", "field_2" }, new String[0]);

ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(
SOURCE_INDEX, buildRegressionConfig("field_2", analyzedFields), RESULTS_FIELD, false, 100, fieldCapabilities);
ExtractedFields extractedFields = extractedFieldsDetector.detect();

assertThat(extractedFields.getAllFields().size(), equalTo(2));
List<String> extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName)
.collect(Collectors.toList());
assertThat(extractedFieldNames, contains("field_1", "field_2"));
}

private static DataFrameAnalyticsConfig buildOutlierDetectionConfig() {
return buildOutlierDetectionConfig(null);
}
Expand Down Expand Up @@ -576,9 +721,17 @@ private static class MockFieldCapsResponseBuilder {
private final Map<String, Map<String, FieldCapabilities>> fieldCaps = new HashMap<>();

private MockFieldCapsResponseBuilder addAggregatableField(String field, String... types) {
return addField(field, true, types);
}

private MockFieldCapsResponseBuilder addNonAggregatableField(String field, String... types) {
return addField(field, false, types);
}

private MockFieldCapsResponseBuilder addField(String field, boolean isAggregatable, String... types) {
Map<String, FieldCapabilities> caps = new HashMap<>();
for (String type : types) {
caps.put(type, new FieldCapabilities(field, type, true, true));
caps.put(type, new FieldCapabilities(field, type, true, isAggregatable));
}
fieldCaps.put(field, caps);
return this;
Expand Down

0 comments on commit ed39fa6

Please sign in to comment.