Skip to content

Commit

Permalink
[ML][Inference] adjust so target_field always has inference result an…
Browse files Browse the repository at this point in the history
…d optionally allow new top classes field in the classification config (elastic#49923)
  • Loading branch information
benwtrent authored and SivagurunathanV committed Jan 21, 2020
1 parent a470b6e commit dde9e59
Show file tree
Hide file tree
Showing 12 changed files with 102 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
Expand Down Expand Up @@ -90,13 +92,15 @@ public String valueAsString() {
}

@Override
public void writeResult(IngestDocument document, String resultField) {
public void writeResult(IngestDocument document, String resultField, InferenceConfig config) {
assert config instanceof ClassificationConfig;
ClassificationConfig classificationConfig = (ClassificationConfig)config;
ExceptionsHelper.requireNonNull(document, "document");
ExceptionsHelper.requireNonNull(resultField, "resultField");
if (topClasses.isEmpty()) {
document.setFieldValue(resultField, valueAsString());
} else {
document.setFieldValue(resultField, topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
document.setFieldValue(resultField, valueAsString());
if (topClasses.isEmpty() == false) {
document.setFieldValue(classificationConfig.getTopClassesResultsField(),
topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@

import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;

public interface InferenceResults extends NamedXContentObject, NamedWriteable {

void writeResult(IngestDocument document, String resultField);
void writeResult(IngestDocument document, String resultField, InferenceConfig config);

}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;

import java.io.IOException;
import java.util.Objects;
Expand Down Expand Up @@ -49,7 +50,7 @@ public int hashCode() {
}

@Override
public void writeResult(IngestDocument document, String resultField) {
public void writeResult(IngestDocument document, String resultField, InferenceConfig config) {
throw new UnsupportedOperationException("[raw] does not support writing inference results");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
Expand Down Expand Up @@ -50,7 +51,7 @@ public int hashCode() {
}

@Override
public void writeResult(IngestDocument document, String resultField) {
public void writeResult(IngestDocument document, String resultField, InferenceConfig config) {
ExceptionsHelper.requireNonNull(document, "document");
ExceptionsHelper.requireNonNull(resultField, "resultField");
document.setFieldValue(resultField, value());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,50 +21,65 @@ public class ClassificationConfig implements InferenceConfig {

public static final String NAME = "classification";

public static final String DEFAULT_TOP_CLASSES_RESULT_FIELD = "top_classes";
public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
public static final ParseField TOP_CLASSES_RESULT_FIELD = new ParseField("top_classes_result_field");
private static final Version MIN_SUPPORTED_VERSION = Version.V_7_6_0;

public static ClassificationConfig EMPTY_PARAMS = new ClassificationConfig(0);
public static ClassificationConfig EMPTY_PARAMS = new ClassificationConfig(0, DEFAULT_TOP_CLASSES_RESULT_FIELD);

private final int numTopClasses;
private final String topClassesResultsField;

public static ClassificationConfig fromMap(Map<String, Object> map) {
Map<String, Object> options = new HashMap<>(map);
Integer numTopClasses = (Integer)options.remove(NUM_TOP_CLASSES.getPreferredName());
String topClassesResultsField = (String)options.remove(TOP_CLASSES_RESULT_FIELD.getPreferredName());
if (options.isEmpty() == false) {
throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet());
}
return new ClassificationConfig(numTopClasses);
return new ClassificationConfig(numTopClasses, topClassesResultsField);
}

public ClassificationConfig(Integer numTopClasses) {
this(numTopClasses, null);
}

public ClassificationConfig(Integer numTopClasses, String topClassesResultsField) {
this.numTopClasses = numTopClasses == null ? 0 : numTopClasses;
this.topClassesResultsField = topClassesResultsField == null ? DEFAULT_TOP_CLASSES_RESULT_FIELD : topClassesResultsField;
}

public ClassificationConfig(StreamInput in) throws IOException {
this.numTopClasses = in.readInt();
this.topClassesResultsField = in.readString();
}

public int getNumTopClasses() {
return numTopClasses;
}

public String getTopClassesResultsField() {
return topClassesResultsField;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeInt(numTopClasses);
out.writeString(topClassesResultsField);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ClassificationConfig that = (ClassificationConfig) o;
return Objects.equals(numTopClasses, that.numTopClasses);
return Objects.equals(numTopClasses, that.numTopClasses) && Objects.equals(topClassesResultsField, that.topClassesResultsField);
}

@Override
public int hashCode() {
return Objects.hash(numTopClasses);
return Objects.hash(numTopClasses, topClassesResultsField);
}

@Override
Expand All @@ -73,6 +88,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (numTopClasses != 0) {
builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
}
builder.field(TOP_CLASSES_RESULT_FIELD.getPreferredName(), topClassesResultsField);
builder.endObject();
return builder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;

import java.util.Arrays;
import java.util.Collections;
Expand Down Expand Up @@ -37,15 +38,15 @@ private static ClassificationInferenceResults.TopClassEntry createRandomClassEnt
public void testWriteResultsWithClassificationLabel() {
ClassificationInferenceResults result = new ClassificationInferenceResults(1.0, "foo", Collections.emptyList());
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
result.writeResult(document, "result_field");
result.writeResult(document, "result_field", ClassificationConfig.EMPTY_PARAMS);

assertThat(document.getFieldValue("result_field", String.class), equalTo("foo"));
}

public void testWriteResultsWithoutClassificationLabel() {
ClassificationInferenceResults result = new ClassificationInferenceResults(1.0, null, Collections.emptyList());
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
result.writeResult(document, "result_field");
result.writeResult(document, "result_field", ClassificationConfig.EMPTY_PARAMS);

assertThat(document.getFieldValue("result_field", String.class), equalTo("1.0"));
}
Expand All @@ -60,15 +61,17 @@ public void testWriteResultsWithTopClasses() {
"foo",
entries);
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
result.writeResult(document, "result_field");
result.writeResult(document, "result_field", new ClassificationConfig(3, "bar"));

List<?> list = document.getFieldValue("result_field", List.class);
List<?> list = document.getFieldValue("bar", List.class);
assertThat(list.size(), equalTo(3));

for(int i = 0; i < 3; i++) {
Map<String, Object> map = (Map<String, Object>)list.get(i);
assertThat(map, equalTo(entries.get(i).asValueMap()));
}

assertThat(document.getFieldValue("result_field", String.class), equalTo("foo"));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;

import java.util.HashMap;

Expand All @@ -24,7 +25,7 @@ public static RegressionInferenceResults createRandomResults() {
public void testWriteResults() {
RegressionInferenceResults result = new RegressionInferenceResults(0.3);
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
result.writeResult(document, "result_field");
result.writeResult(document, "result_field", new RegressionConfig());

assertThat(document.getFieldValue("result_field", Double.class), equalTo(0.3));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,27 @@
import org.elasticsearch.test.AbstractWireSerializingTestCase;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import static org.hamcrest.Matchers.equalTo;

public class ClassificationConfigTests extends AbstractWireSerializingTestCase<ClassificationConfig> {

public static ClassificationConfig randomClassificationConfig() {
return new ClassificationConfig(randomBoolean() ? null : randomIntBetween(-1, 10));
return new ClassificationConfig(randomBoolean() ? null : randomIntBetween(-1, 10),
randomBoolean() ? null : randomAlphaOfLength(10));
}

public void testFromMap() {
ClassificationConfig expected = new ClassificationConfig(0);
ClassificationConfig expected = ClassificationConfig.EMPTY_PARAMS;
assertThat(ClassificationConfig.fromMap(Collections.emptyMap()), equalTo(expected));

expected = new ClassificationConfig(3);
assertThat(ClassificationConfig.fromMap(Collections.singletonMap(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 3)),
equalTo(expected));
expected = new ClassificationConfig(3, "foo");
Map<String, Object> configMap = new HashMap<>();
configMap.put(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 3);
configMap.put(ClassificationConfig.TOP_CLASSES_RESULT_FIELD.getPreferredName(), "foo");
assertThat(ClassificationConfig.fromMap(configMap), equalTo(expected));
}

public void testFromMapWithUnknownField() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,20 +147,8 @@ public void testSimulate() {
" {\n" +
" \"inference\": {\n" +
" \"target_field\": \"result_class\",\n" +
" \"inference_config\": {\"classification\":{}},\n" +
" \"model_id\": \"test_classification\",\n" +
" \"field_mappings\": {\n" +
" \"col1\": \"col1\",\n" +
" \"col2\": \"col2\",\n" +
" \"col3\": \"col3\",\n" +
" \"col4\": \"col4\"\n" +
" }\n" +
" }\n" +
" },\n" +
" {\n" +
" \"inference\": {\n" +
" \"target_field\": \"result_class_prob\",\n" +
" \"inference_config\": {\"classification\": {\"num_top_classes\":2}},\n" +
" \"inference_config\": {\"classification\": " +
" {\"num_top_classes\":2, \"top_classes_result_field\": \"result_class_prob\"}},\n" +
" \"model_id\": \"test_classification\",\n" +
" \"field_mappings\": {\n" +
" \"col1\": \"col1\",\n" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
Expand Down Expand Up @@ -154,7 +153,7 @@ void mutateDocument(InternalInferModelAction.Response response, IngestDocument i
if (response.getInferenceResults().isEmpty()) {
throw new ElasticsearchStatusException("Unexpected empty inference response", RestStatus.INTERNAL_SERVER_ERROR);
}
response.getInferenceResults().get(0).writeResult(ingestDocument, this.targetField);
response.getInferenceResults().get(0).writeResult(ingestDocument, this.targetField, inferenceConfig);
if (includeModelMetadata) {
ingestDocument.setFieldValue(modelInfoField + "." + MODEL_ID, modelId);
}
Expand Down Expand Up @@ -227,8 +226,7 @@ int numInferenceProcessors() {
}

@Override
public InferenceProcessor create(Map<String, Processor.Factory> processorFactories, String tag, Map<String, Object> config)
throws Exception {
public InferenceProcessor create(Map<String, Processor.Factory> processorFactories, String tag, Map<String, Object> config) {

if (this.maxIngestProcessors <= currentInferenceProcessors) {
throw new ElasticsearchStatusException("Max number of inference processors reached, total inference processors [{}]. " +
Expand Down Expand Up @@ -267,7 +265,7 @@ void setMaxIngestProcessors(int maxIngestProcessors) {
this.maxIngestProcessors = maxIngestProcessors;
}

InferenceConfig inferenceConfigFromMap(Map<String, Object> inferenceConfig) throws IOException {
InferenceConfig inferenceConfigFromMap(Map<String, Object> inferenceConfig) {
ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG);

if (inferenceConfig.size() != 1) {
Expand All @@ -284,7 +282,7 @@ InferenceConfig inferenceConfigFromMap(Map<String, Object> inferenceConfig) thro
Map<String, Object> valueMap = (Map<String, Object>)value;

if (inferenceConfig.containsKey(ClassificationConfig.NAME)) {
checkSupportedVersion(new ClassificationConfig(0));
checkSupportedVersion(ClassificationConfig.EMPTY_PARAMS);
return ClassificationConfig.fromMap(valueMap);
} else if (inferenceConfig.containsKey(RegressionConfig.NAME)) {
checkSupportedVersion(new RegressionConfig());
Expand Down
Loading

0 comments on commit dde9e59

Please sign in to comment.