Skip to content

Commit

Permalink
[ML] unifying NLP task output with existing inference outputs (elasti…
Browse files Browse the repository at this point in the history
…c#78530)

This unifies our current inference results and configuration options with our existing inference responses and configuration.

Now results_field is configurable and updatable for all our inference configs.

The following responses now resemble our classification response.

text_classification
zero_shot_classification (does not support "num_top_classes", as labels are fully configurable)
mask_fill. This inference response also has <results_field>_sequence for the predicted value in the sequence.
pass_through and text_embedding now write their results to the results_field parameter.

For all, if results_field is not provided, the default is predicted_value.

NER is not really unified yet, its a unique situation as its multiple token multi-class.
  • Loading branch information
benwtrent authored Oct 1, 2021
1 parent 9e0299f commit 5725bb3
Show file tree
Hide file tree
Showing 26 changed files with 355 additions and 372 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextClassificationResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
Expand Down Expand Up @@ -295,9 +294,6 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class,
PyTorchPassThroughResults.NAME,
PyTorchPassThroughResults::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class,
TextClassificationResults.NAME,
TextClassificationResults::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class,
TextEmbeddingResults.NAME,
TextEmbeddingResults::new));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults

public static final String NAME = "classification";

public static final String PREDICTION_PROBABILITY = "prediction_probability";
public static final String PREDICTION_SCORE = "prediction_score";

private final String topNumClassesField;
private final String resultsField;
// Accessed in sub-classes
protected final String resultsField;
private final String classificationLabel;
private final Double predictionProbability;
private final Double predictionScore;
Expand Down Expand Up @@ -60,15 +60,41 @@ private ClassificationInferenceResults(double value,
ClassificationConfig classificationConfig,
Double predictionProbability,
Double predictionScore) {
this(
value,
classificationLabel,
topClasses,
featureImportance,
classificationConfig.getTopClassesResultsField(),
classificationConfig.getResultsField(),
classificationConfig.getPredictionFieldType(),
classificationConfig.getNumTopFeatureImportanceValues(),
predictionProbability,
predictionScore
);
}

public ClassificationInferenceResults(
double value,
String classificationLabel,
List<TopClassEntry> topClasses,
List<ClassificationFeatureImportance> featureImportance,
String topNumClassesField,
String resultsField,
PredictionFieldType predictionFieldType,
int numTopFeatureImportanceValues,
Double predictionProbability,
Double predictionScore
) {
super(value);
this.classificationLabel = classificationLabel;
this.topClasses = topClasses == null ? Collections.emptyList() : Collections.unmodifiableList(topClasses);
this.topNumClassesField = classificationConfig.getTopClassesResultsField();
this.resultsField = classificationConfig.getResultsField();
this.predictionFieldType = classificationConfig.getPredictionFieldType();
this.topNumClassesField = topNumClassesField;
this.resultsField = resultsField;
this.predictionFieldType = predictionFieldType;
this.predictionProbability = predictionProbability;
this.predictionScore = predictionScore;
this.featureImportance = takeTopFeatureImportances(featureImportance, classificationConfig.getNumTopFeatureImportanceValues());
this.featureImportance = takeTopFeatureImportances(featureImportance, numTopFeatureImportanceValues);
}

static List<ClassificationFeatureImportance> takeTopFeatureImportances(List<ClassificationFeatureImportance> featureImportances,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,159 +7,91 @@

package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.xcontent.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;

import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

public class FillMaskResults implements InferenceResults {
public class FillMaskResults extends ClassificationInferenceResults {

public static final String NAME = "fill_mask_result";
public static final String DEFAULT_RESULTS_FIELD = "results";

private final List<Prediction> predictions;

public FillMaskResults(List<Prediction> predictions) {
this.predictions = predictions;
private final String predictedSequence;

public FillMaskResults(
double value,
String classificationLabel,
String predictedSequence,
List<TopClassEntry> topClasses,
String topNumClassesField,
String resultsField,
Double predictionProbability
) {
super(
value,
classificationLabel,
topClasses,
List.of(),
topNumClassesField,
resultsField,
PredictionFieldType.STRING,
0,
predictionProbability,
null
);
this.predictedSequence = predictedSequence;
}

public FillMaskResults(StreamInput in) throws IOException {
this.predictions = in.readList(Prediction::new);
}

public List<Prediction> getPredictions() {
return predictions;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startArray(DEFAULT_RESULTS_FIELD);
for (Prediction prediction : predictions) {
prediction.toXContent(builder, params);
}
builder.endArray();
return builder;
super(in);
this.predictedSequence = in.readString();
}

@Override
public String getWriteableName() {
return NAME;
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(predictedSequence);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeList(predictions);
public String getPredictedSequence() {
return predictedSequence;
}

@Override
public Map<String, Object> asMap() {
Map<String, Object> map = new LinkedHashMap<>();
map.put(DEFAULT_RESULTS_FIELD, predictions.stream().map(Prediction::toMap).collect(Collectors.toList()));
map.put(resultsField + "_sequence", predictedSequence);
map.putAll(super.asMap());
return map;
}

@Override
public Object predictedValue() {
if (predictions.isEmpty()) {
return null;
}
return predictions.get(0).token;
public String getWriteableName() {
return NAME;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
return super.toXContent(builder, params).field(resultsField + "_sequence", predictedSequence);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
if (super.equals(o) == false) return false;
FillMaskResults that = (FillMaskResults) o;
return Objects.equals(predictions, that.predictions);
return Objects.equals(predictedSequence, that.predictedSequence);
}

@Override
public int hashCode() {
return Objects.hash(predictions);
}

public static class Prediction implements ToXContentObject, Writeable {

private static final ParseField TOKEN = new ParseField("token");
private static final ParseField SCORE = new ParseField("score");
private static final ParseField SEQUENCE = new ParseField("sequence");

private final String token;
private final double score;
private final String sequence;

public Prediction(String token, double score, String sequence) {
this.token = Objects.requireNonNull(token);
this.score = score;
this.sequence = Objects.requireNonNull(sequence);
}

public Prediction(StreamInput in) throws IOException {
token = in.readString();
score = in.readDouble();
sequence = in.readString();
}

public double getScore() {
return score;
}

public String getSequence() {
return sequence;
}

public String getToken() {
return token;
}

public Map<String, Object> toMap() {
Map<String, Object> map = new LinkedHashMap<>();
map.put(TOKEN.getPreferredName(), token);
map.put(SCORE.getPreferredName(), score);
map.put(SEQUENCE.getPreferredName(), sequence);
return map;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(TOKEN.getPreferredName(), token);
builder.field(SCORE.getPreferredName(), score);
builder.field(SEQUENCE.getPreferredName(), sequence);
builder.endObject();
return builder;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(token);
out.writeDouble(score);
out.writeString(sequence);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Prediction result = (Prediction) o;
return Double.compare(result.score, score) == 0 &&
Objects.equals(token, result.token) &&
Objects.equals(sequence, result.sequence);
}

@Override
public int hashCode() {
return Objects.hash(token, score, sequence);
}
return Objects.hash(super.hashCode(), predictedSequence);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import java.util.Map;

public interface InferenceResults extends NamedWriteable, ToXContentFragment {
String PREDICTION_PROBABILITY = "prediction_probability";
String MODEL_ID_RESULTS_FIELD = "model_id";

static void writeResult(InferenceResults results, IngestDocument ingestDocument, String resultField, String modelId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import java.util.Objects;
import java.util.stream.Collectors;

import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;

public class NerResults implements InferenceResults {

public static final String NAME = "ner_result";
Expand Down Expand Up @@ -58,7 +60,7 @@ public void writeTo(StreamOutput out) throws IOException {
@Override
public Map<String, Object> asMap() {
Map<String, Object> map = new LinkedHashMap<>();
map.put(FillMaskResults.DEFAULT_RESULTS_FIELD, entityGroups.stream().map(EntityGroup::toMap).collect(Collectors.toList()));
map.put(DEFAULT_RESULTS_FIELD, entityGroups.stream().map(EntityGroup::toMap).collect(Collectors.toList()));
return map;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.xcontent.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
Expand All @@ -16,22 +15,23 @@
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;

public class PyTorchPassThroughResults implements InferenceResults {

public static final String NAME = "pass_through_result";
static final String DEFAULT_RESULTS_FIELD = "results";

private static final ParseField INFERENCE = new ParseField("inference");

private final double[][] inference;
private final String resultsField;

public PyTorchPassThroughResults(double[][] inference) {
public PyTorchPassThroughResults(String resultsField, double[][] inference) {
this.inference = inference;
this.resultsField = resultsField;
}

public PyTorchPassThroughResults(StreamInput in) throws IOException {
inference = in.readArray(StreamInput::readDoubleArray, length -> new double[length][]);
inference = in.readArray(StreamInput::readDoubleArray, double[][]::new);
resultsField = in.readString();
}

public double[][] getInference() {
Expand All @@ -40,7 +40,7 @@ public double[][] getInference() {

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(INFERENCE.getPreferredName(), inference);
builder.field(resultsField, inference);
return builder;
}

Expand All @@ -52,12 +52,13 @@ public String getWriteableName() {
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeArray(StreamOutput::writeDoubleArray, inference);
out.writeString(resultsField);
}

@Override
public Map<String, Object> asMap() {
Map<String, Object> map = new LinkedHashMap<>();
map.put(DEFAULT_RESULTS_FIELD, inference);
map.put(resultsField, inference);
return map;
}

Expand All @@ -71,11 +72,11 @@ public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
PyTorchPassThroughResults that = (PyTorchPassThroughResults) o;
return Arrays.deepEquals(inference, that.inference);
return Arrays.deepEquals(inference, that.inference) && Objects.equals(resultsField, that.resultsField);
}

@Override
public int hashCode() {
return Arrays.deepHashCode(inference);
return Objects.hash(Arrays.deepHashCode(inference), resultsField);
}
}
Loading

0 comments on commit 5725bb3

Please sign in to comment.