Skip to content

Commit

Permalink
[ML] Weighted token model support (#93186)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkyle authored Jan 24, 2023
1 parent 54fe770 commit cc86b16
Show file tree
Hide file tree
Showing 22 changed files with 828 additions and 104 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults;
import org.elasticsearch.xpack.core.ml.inference.results.QuestionAnsweringInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.SlimResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextSimilarityInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
Expand Down Expand Up @@ -56,6 +57,8 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ResultsFieldUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RobertaTokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RobertaTokenizationUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.SlimConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.SlimConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModelLocation;
Expand Down Expand Up @@ -315,6 +318,20 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
FillMaskConfig::fromXContentStrict
)
);
namedXContent.add(
new NamedXContentRegistry.Entry(
LenientlyParsedInferenceConfig.class,
new ParseField(SlimConfig.NAME),
SlimConfig::fromXContentLenient
)
);
namedXContent.add(
new NamedXContentRegistry.Entry(
StrictlyParsedInferenceConfig.class,
new ParseField(SlimConfig.NAME),
SlimConfig::fromXContentStrict
)
);
namedXContent.add(
new NamedXContentRegistry.Entry(
LenientlyParsedInferenceConfig.class,
Expand Down Expand Up @@ -436,6 +453,13 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
RegressionConfigUpdate::fromXContentStrict
)
);
namedXContent.add(
new NamedXContentRegistry.Entry(
InferenceConfigUpdate.class,
new ParseField(SlimConfigUpdate.NAME),
SlimConfigUpdate::fromXContentStrict
)
);
namedXContent.add(
new NamedXContentRegistry.Entry(
InferenceConfigUpdate.class,
Expand Down Expand Up @@ -588,6 +612,7 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceResults.class, PyTorchPassThroughResults.NAME, PyTorchPassThroughResults::new)
);
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, SlimResults.NAME, SlimResults::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, TextEmbeddingResults.NAME, TextEmbeddingResults::new));
namedWriteables.add(
new NamedWriteableRegistry.Entry(
Expand Down Expand Up @@ -619,6 +644,7 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
);
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class, NerConfig.NAME, NerConfig::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class, FillMaskConfig.NAME, FillMaskConfig::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class, SlimConfig.NAME, SlimConfig::new));
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceConfig.class, TextClassificationConfig.NAME, TextClassificationConfig::new)
);
Expand Down Expand Up @@ -658,6 +684,7 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
namedWriteables.add(
new NamedWriteableRegistry.Entry(InferenceConfigUpdate.class, ResultsFieldUpdate.NAME, ResultsFieldUpdate::new)
);
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfigUpdate.class, SlimConfigUpdate.NAME, SlimConfigUpdate::new));
namedWriteables.add(
new NamedWriteableRegistry.Entry(
InferenceConfigUpdate.class,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ abstract class NlpInferenceResults implements InferenceResults {

abstract void addMapFields(Map<String, Object> map);

/**
 * Returns the truncation flag recorded for this result.
 * NOTE(review): presumably set when the model input exceeded the tokenizer's
 * max sequence length and was cut short before inference — confirm against the producer.
 */
public boolean isTruncated() {
    return isTruncated;
}

@Override
public final void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(isTruncated);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

/**
 * Inference result for the SLIM model: a sparse expansion expressed as a list of
 * (token id, weight) pairs written under {@code resultsField}.
 */
public class SlimResults extends NlpInferenceResults {

    public static final String NAME = "slim_result";

    /**
     * A single vocabulary token id paired with the weight the model assigned to it.
     */
    public record WeightedToken(int token, float weight) implements Writeable, ToXContentObject {

        public static final String TOKEN = "token";
        public static final String WEIGHT = "weight";

        public WeightedToken(StreamInput in) throws IOException {
            this(in.readVInt(), in.readFloat());
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            // Order must mirror the StreamInput constructor: token id first, then weight.
            out.writeVInt(token);
            out.writeFloat(weight);
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
            return builder.startObject().field(TOKEN, token).field(WEIGHT, weight).endObject();
        }

        /** Map form used when flattening results into an ingest document. */
        public Map<String, Object> asMap() {
            return Map.of(TOKEN, token, WEIGHT, weight);
        }

        @Override
        public String toString() {
            return Strings.toString(this);
        }
    }

    private final String resultsField;
    private final List<WeightedToken> weightedTokens;

    public SlimResults(String resultField, List<WeightedToken> weightedTokens, boolean isTruncated) {
        super(isTruncated);
        this.resultsField = resultField;
        this.weightedTokens = weightedTokens;
    }

    public SlimResults(StreamInput in) throws IOException {
        super(in);
        // Read order must mirror doWriteTo: results field name, then the token list.
        resultsField = in.readString();
        weightedTokens = in.readList(WeightedToken::new);
    }

    public List<WeightedToken> getWeightedTokens() {
        return weightedTokens;
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    @Override
    public String getResultsField() {
        return resultsField;
    }

    @Override
    public Object predictedValue() {
        // A sparse token expansion has no meaningful single value.
        throw new UnsupportedOperationException("[" + NAME + "] does not support a single predicted value");
    }

    @Override
    void doXContentBody(XContentBuilder builder, Params params) throws IOException {
        builder.startArray(resultsField);
        for (WeightedToken token : weightedTokens) {
            token.toXContent(builder, params);
        }
        builder.endArray();
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (o == null || o.getClass() != getClass()) {
            return false;
        }
        if (super.equals(o) == false) {
            return false;
        }
        SlimResults other = (SlimResults) o;
        return Objects.equals(resultsField, other.resultsField) && Objects.equals(weightedTokens, other.weightedTokens);
    }

    @Override
    public int hashCode() {
        return Objects.hash(super.hashCode(), resultsField, weightedTokens);
    }

    @Override
    void doWriteTo(StreamOutput out) throws IOException {
        // Write order must mirror the StreamInput constructor above.
        out.writeString(resultsField);
        out.writeList(weightedTokens);
    }

    @Override
    void addMapFields(Map<String, Object> map) {
        map.put(resultsField, weightedTokens.stream().map(WeightedToken::asMap).collect(Collectors.toList()));
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.Version;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;

import java.io.IOException;
import java.util.Objects;
import java.util.Optional;

/**
 * Inference configuration for the SLIM (sparse) model. Requires a BERT tokenizer
 * and, for now, rejects span-based windowing of long sequences.
 */
public class SlimConfig implements NlpConfig {

    public static final String NAME = "slim";

    public static SlimConfig fromXContentStrict(XContentParser parser) {
        return STRICT_PARSER.apply(parser, null);
    }

    public static SlimConfig fromXContentLenient(XContentParser parser) {
        return LENIENT_PARSER.apply(parser, null);
    }

    private static final ConstructingObjectParser<SlimConfig, Void> STRICT_PARSER = createParser(false);
    private static final ConstructingObjectParser<SlimConfig, Void> LENIENT_PARSER = createParser(true);

    private static ConstructingObjectParser<SlimConfig, Void> createParser(boolean ignoreUnknownFields) {
        ConstructingObjectParser<SlimConfig, Void> objParser = new ConstructingObjectParser<>(
            NAME,
            ignoreUnknownFields,
            args -> new SlimConfig((VocabularyConfig) args[0], (Tokenization) args[1], (String) args[2])
        );
        objParser.declareObject(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> {
            // The vocabulary location is internal metadata: only the lenient (internal)
            // parser may read it; a user supplying it at model creation is an error.
            if (ignoreUnknownFields == false) {
                throw ExceptionsHelper.badRequestException(
                    "illegal setting [{}] on inference model creation",
                    VOCABULARY.getPreferredName()
                );
            }
            return VocabularyConfig.fromXContentLenient(p);
        }, VOCABULARY);
        objParser.declareNamedObject(
            ConstructingObjectParser.optionalConstructorArg(),
            (p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields),
            TOKENIZATION
        );
        objParser.declareString(ConstructingObjectParser.optionalConstructorArg(), RESULTS_FIELD);
        return objParser;
    }

    private final VocabularyConfig vocabularyConfig;
    private final Tokenization tokenization;
    private final String resultsField;

    public SlimConfig(@Nullable VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization, @Nullable String resultsField) {
        this.vocabularyConfig = vocabularyConfig != null
            ? vocabularyConfig
            : new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore());
        this.tokenization = tokenization != null ? tokenization : Tokenization.createDefault();
        if (this.tokenization instanceof BertTokenization == false) {
            throw ExceptionsHelper.badRequestException(
                "SLIM must be configured with BERT tokenizer, [{}] given",
                this.tokenization.getName()
            );
        }
        // TODO support spanning
        if (this.tokenization.span != -1) {
            throw ExceptionsHelper.badRequestException(
                "[{}] does not support windowing long text sequences; configured span [{}]",
                NAME,
                this.tokenization.span
            );
        }
        this.resultsField = resultsField;
    }

    public SlimConfig(StreamInput in) throws IOException {
        // Read order must mirror writeTo: vocabulary, tokenization, optional results field.
        vocabularyConfig = new VocabularyConfig(in);
        tokenization = in.readNamedWriteable(Tokenization.class);
        resultsField = in.readOptionalString();
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        vocabularyConfig.writeTo(out);
        out.writeNamedWriteable(tokenization);
        out.writeOptionalString(resultsField);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params);
        NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization);
        if (resultsField != null) {
            builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
        }
        return builder.endObject();
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    @Override
    public boolean isTargetTypeSupported(TargetType targetType) {
        // TargetType relates to boosted tree models
        return false;
    }

    @Override
    public boolean isAllocateOnly() {
        // NOTE(review): SLIM runs only on started deployments, not in-process — confirm.
        return true;
    }

    @Override
    public Version getMinimalSupportedVersion() {
        return Version.V_8_7_0;
    }

    @Override
    public String getResultsField() {
        return resultsField;
    }

    @Override
    public VocabularyConfig getVocabularyConfig() {
        return vocabularyConfig;
    }

    @Override
    public Tokenization getTokenization() {
        return tokenization;
    }

    @Override
    public String getName() {
        return NAME;
    }

    @Override
    public String toString() {
        return Strings.toString(this);
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (o == null || o.getClass() != getClass()) {
            return false;
        }
        SlimConfig other = (SlimConfig) o;
        return Objects.equals(vocabularyConfig, other.vocabularyConfig)
            && Objects.equals(tokenization, other.tokenization)
            && Objects.equals(resultsField, other.resultsField);
    }

    @Override
    public int hashCode() {
        return Objects.hash(vocabularyConfig, tokenization, resultsField);
    }
}
Loading

0 comments on commit cc86b16

Please sign in to comment.