Skip to content

Commit

Permalink
[7.6] Do not copy mapping from dependent variable to prediction field…
Browse files Browse the repository at this point in the history
… in regression analysis (#51227) (#51289)
  • Loading branch information
przemekwitek authored Jan 22, 2020
1 parent d9cf8fc commit 83ffe96
Show file tree
Hide file tree
Showing 12 changed files with 179 additions and 101 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.mapper.FieldAliasMapper;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
Expand All @@ -28,6 +29,7 @@

import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue;

public class Classification implements DataFrameAnalysis {

Expand Down Expand Up @@ -248,12 +250,32 @@ public Map<String, Long> getFieldCardinalityLimits() {
return Collections.singletonMap(dependentVariable, 2L);
}

@SuppressWarnings("unchecked")
@Override
public Map<String, String> getExplicitlyMappedFields(String resultsFieldName) {
return new HashMap<String, String>() {{
put(resultsFieldName + "." + predictionFieldName, dependentVariable);
put(resultsFieldName + ".top_classes.class_name", dependentVariable);
}};
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
Object dependentVariableMapping = extractMapping(dependentVariable, mappingsProperties);
if ((dependentVariableMapping instanceof Map) == false) {
return Collections.emptyMap();
}
Map<String, Object> dependentVariableMappingAsMap = (Map) dependentVariableMapping;
// If the source field is an alias, fetch the concrete field that the alias points to.
if (FieldAliasMapper.CONTENT_TYPE.equals(dependentVariableMappingAsMap.get("type"))) {
String path = (String) dependentVariableMappingAsMap.get(FieldAliasMapper.Names.PATH);
dependentVariableMapping = extractMapping(path, mappingsProperties);
}
// We may have updated the value of {@code dependentVariableMapping} in the "if" block above.
// Hence, we need to check the "instanceof" condition again.
if ((dependentVariableMapping instanceof Map) == false) {
return Collections.emptyMap();
}
Map<String, Object> additionalProperties = new HashMap<>();
additionalProperties.put(resultsFieldName + "." + predictionFieldName, dependentVariableMapping);
additionalProperties.put(resultsFieldName + ".top_classes.class_name", dependentVariableMapping);
return additionalProperties;
}

private static Object extractMapping(String path, Map<String, Object> mappingsProperties) {
return extractValue(String.join(".properties.", path.split("\\.")), mappingsProperties);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,13 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
Map<String, Long> getFieldCardinalityLimits();

/**
* Returns fields for which the mappings should be copied from source index to destination index.
* Each entry of the returned {@link Map} is of the form:
* key - field path in the destination index
* value - field path in the source index from which the mapping should be taken
* Returns fields for which the mappings should be either predefined or copied from source index to destination index.
*
* @param mappingsProperties mappings.properties portion of the index mappings
* @param resultsFieldName name of the results field under which all the results are stored
* @return {@link Map} containing fields for which the mappings should be copied from source index to destination index
* @return {@link Map} containing fields for which the mappings should be handled explicitly
*/
Map<String, String> getExplicitlyMappedFields(String resultsFieldName);
Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName);

/**
* @return {@code true} if this analysis supports data frame rows with missing values
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ public Map<String, Long> getFieldCardinalityLimits() {
}

@Override
public Map<String, String> getExplicitlyMappedFields(String resultsFieldName) {
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
return Collections.emptyMap();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,10 @@ public Map<String, Long> getFieldCardinalityLimits() {
}

@Override
public Map<String, String> getExplicitlyMappedFields(String resultsFieldName) {
return Collections.singletonMap(resultsFieldName + "." + predictionFieldName, dependentVariable);
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
// Prediction field should be always mapped as "double" rather than "float" in order to increase precision in case of
// high (over 10M) values of dependent variable.
return Collections.singletonMap(resultsFieldName + "." + predictionFieldName, Collections.singletonMap("type", "double"));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.Map;
import java.util.Set;

import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.anEmptyMap;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
Expand Down Expand Up @@ -171,8 +172,40 @@ public void testFieldCardinalityLimitsIsNonEmpty() {
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(anEmptyMap())));
}

public void testFieldMappingsToCopyIsNonEmpty() {
assertThat(createTestInstance().getExplicitlyMappedFields(""), is(not(anEmptyMap())));
public void testGetExplicitlyMappedFields() {
assertThat(new Classification("foo").getExplicitlyMappedFields(null, "results"), is(anEmptyMap()));
assertThat(new Classification("foo").getExplicitlyMappedFields(Collections.emptyMap(), "results"), is(anEmptyMap()));
assertThat(
new Classification("foo").getExplicitlyMappedFields(Collections.singletonMap("foo", "not_a_map"), "results"),
is(anEmptyMap()));
assertThat(
new Classification("foo").getExplicitlyMappedFields(
Collections.singletonMap("foo", Collections.singletonMap("bar", "baz")),
"results"),
allOf(
hasEntry("results.foo_prediction", Collections.singletonMap("bar", "baz")),
hasEntry("results.top_classes.class_name", Collections.singletonMap("bar", "baz"))));
assertThat(
new Classification("foo").getExplicitlyMappedFields(
new HashMap<String, Object>() {{
put("foo", new HashMap<String, String>() {{
put("type", "alias");
put("path", "bar");
}});
put("bar", Collections.singletonMap("type", "long"));
}},
"results"),
allOf(
hasEntry("results.foo_prediction", Collections.singletonMap("type", "long")),
hasEntry("results.top_classes.class_name", Collections.singletonMap("type", "long"))));
assertThat(
new Classification("foo").getExplicitlyMappedFields(
Collections.singletonMap("foo", new HashMap<String, String>() {{
put("type", "alias");
put("path", "missing");
}}),
"results"),
is(anEmptyMap()));
}

public void testToXContent_GivenVersionBeforeRandomizeSeedWasIntroduced() throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ public void testFieldCardinalityLimitsIsEmpty() {
assertThat(createTestInstance().getFieldCardinalityLimits(), is(anEmptyMap()));
}

public void testFieldMappingsToCopyIsEmpty() {
assertThat(createTestInstance().getExplicitlyMappedFields(""), is(anEmptyMap()));
public void testGetExplicitlyMappedFields() {
assertThat(createTestInstance().getExplicitlyMappedFields(null, null), is(anEmptyMap()));
}

public void testGetStateDocId() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ protected Regression createTestInstance() {
return createRandom();
}

public static Regression createRandom() {
private static Regression createRandom() {
String dependentVariableName = randomAlphaOfLength(10);
BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom();
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
Expand Down Expand Up @@ -110,8 +110,10 @@ public void testFieldCardinalityLimitsIsEmpty() {
assertThat(createTestInstance().getFieldCardinalityLimits(), is(anEmptyMap()));
}

public void testFieldMappingsToCopyIsNonEmpty() {
assertThat(createTestInstance().getExplicitlyMappedFields(""), is(not(anEmptyMap())));
public void testGetExplicitlyMappedFields() {
assertThat(
new Regression("foo").getExplicitlyMappedFields(null, "results"),
hasEntry("results.foo_prediction", Collections.singletonMap("type", "double")));
}

public void testGetStateDocId() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
// Pending fix
//import com.google.common.collect.Ordering;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.admin.indices.get.GetIndexAction;
import org.elasticsearch.action.admin.indices.get.GetIndexRequest;
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.bulk.BulkResponse;
Expand Down Expand Up @@ -43,7 +41,6 @@
import java.util.Set;

import static java.util.stream.Collectors.toList;
import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.equalTo;
Expand Down Expand Up @@ -117,7 +114,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertMlResultsFieldMappings(predictedClassField, "keyword");
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [classification]",
"Estimated memory usage for this analytics to be",
Expand Down Expand Up @@ -158,7 +155,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertMlResultsFieldMappings(predictedClassField, "keyword");
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [classification]",
"Estimated memory usage for this analytics to be",
Expand Down Expand Up @@ -221,7 +218,7 @@ public <T> void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(String jobId,
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertMlResultsFieldMappings(predictedClassField, expectedMappingTypeForPredictedField);
assertMlResultsFieldMappings(destIndex, predictedClassField, expectedMappingTypeForPredictedField);
assertThatAuditMessagesMatch(jobId,
"Created analytics with analysis type [classification]",
"Estimated memory usage for this analytics to be",
Expand Down Expand Up @@ -309,7 +306,7 @@ public void testStopAndRestart() throws Exception {
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertMlResultsFieldMappings(predictedClassField, "keyword");
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
}

Expand Down Expand Up @@ -366,7 +363,7 @@ public void testDependentVariableIsNested() throws Exception {
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertMlResultsFieldMappings(predictedClassField, "keyword");
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
assertEvaluation(NESTED_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
}

Expand All @@ -385,7 +382,7 @@ public void testDependentVariableIsAliasToKeyword() throws Exception {
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertMlResultsFieldMappings(predictedClassField, "keyword");
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
assertEvaluation(ALIAS_TO_KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
}

Expand All @@ -404,7 +401,7 @@ public void testDependentVariableIsAliasToNested() throws Exception {
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
assertModelStatePersisted(stateDocId());
assertInferenceModelPersisted(jobId);
assertMlResultsFieldMappings(predictedClassField, "keyword");
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
assertEvaluation(ALIAS_TO_NESTED_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
}

Expand Down Expand Up @@ -565,15 +562,6 @@ private static Map<String, Object> getDestDoc(DataFrameAnalyticsConfig config, S
return destDoc;
}

/**
* Wrapper around extractValue that:
* - allows dots (".") in the path elements provided as arguments
* - supports implicit casting to the appropriate type
*/
private static <T> T getFieldValue(Map<String, Object> doc, String... path) {
return (T)extractValue(String.join(".", path), doc);
}

private static <T> void assertTopClasses(Map<String, Object> resultsObject,
int numTopClasses,
String dependentVariable,
Expand Down Expand Up @@ -657,27 +645,6 @@ private <T> void assertEvaluation(String dependentVariable, List<T> dependentVar
}
}

private void assertMlResultsFieldMappings(String predictedClassField, String expectedType) {
Map<String, Object> mappings =
client()
.execute(GetIndexAction.INSTANCE, new GetIndexRequest().indices(destIndex))
.actionGet()
.mappings()
.get(destIndex)
.get("_doc")
.sourceAsMap();
assertThat(
mappings.toString(),
getFieldValue(
mappings,
"properties", "ml", "properties", String.join(".properties.", predictedClassField.split("\\.")), "type"),
equalTo(expectedType));
assertThat(
mappings.toString(),
getFieldValue(mappings, "properties", "ml", "properties", "top_classes", "properties", "class_name", "type"),
equalTo(expectedType));
}

private String stateDocId() {
return jobId + "_classification_state#1";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
*/
package org.elasticsearch.xpack.ml.integration;

import org.elasticsearch.action.admin.indices.get.GetIndexAction;
import org.elasticsearch.action.admin.indices.get.GetIndexRequest;
import org.elasticsearch.action.admin.indices.refresh.RefreshAction;
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
import org.elasticsearch.action.admin.indices.refresh.RefreshResponse;
Expand Down Expand Up @@ -53,6 +55,7 @@
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue;
import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.arrayWithSize;
import static org.hamcrest.Matchers.equalTo;
Expand Down Expand Up @@ -281,4 +284,36 @@ protected static void assertModelStatePersisted(String stateDocId) {
.get();
assertThat(searchResponse.getHits().getHits().length, equalTo(1));
}

protected static void assertMlResultsFieldMappings(String index, String predictedClassField, String expectedType) {
Map<String, Object> mappings =
client()
.execute(GetIndexAction.INSTANCE, new GetIndexRequest().indices(index))
.actionGet()
.mappings()
.get(index)
.get("_doc")
.sourceAsMap();
assertThat(
mappings.toString(),
getFieldValue(
mappings,
"properties", "ml", "properties", String.join(".properties.", predictedClassField.split("\\.")), "type"),
equalTo(expectedType));
if (getFieldValue(mappings, "properties", "ml", "properties", "top_classes") != null) {
assertThat(
mappings.toString(),
getFieldValue(mappings, "properties", "ml", "properties", "top_classes", "properties", "class_name", "type"),
equalTo(expectedType));
}
}

/**
* Wrapper around extractValue that:
* - allows dots (".") in the path elements provided as arguments
* - supports implicit casting to the appropriate type
*/
protected static <T> T getFieldValue(Map<String, Object> doc, String... path) {
return (T)extractValue(String.join(".", path), doc);
}
}
Loading

0 comments on commit 83ffe96

Please sign in to comment.