-
Notifications
You must be signed in to change notification settings - Fork 24.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ML] Weighted token model support (#93186)
- Loading branch information
Showing
22 changed files
with
828 additions
and
104 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
128 changes: 128 additions & 0 deletions
128
...gin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SlimResults.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.core.ml.inference.results; | ||
|
||
import org.elasticsearch.common.Strings; | ||
import org.elasticsearch.common.io.stream.StreamInput; | ||
import org.elasticsearch.common.io.stream.StreamOutput; | ||
import org.elasticsearch.common.io.stream.Writeable; | ||
import org.elasticsearch.xcontent.ToXContentObject; | ||
import org.elasticsearch.xcontent.XContentBuilder; | ||
|
||
import java.io.IOException; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Objects; | ||
import java.util.stream.Collectors; | ||
|
||
public class SlimResults extends NlpInferenceResults { | ||
|
||
public static final String NAME = "slim_result"; | ||
|
||
public record WeightedToken(int token, float weight) implements Writeable, ToXContentObject { | ||
|
||
public static final String TOKEN = "token"; | ||
public static final String WEIGHT = "weight"; | ||
|
||
public WeightedToken(StreamInput in) throws IOException { | ||
this(in.readVInt(), in.readFloat()); | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
out.writeVInt(token); | ||
out.writeFloat(weight); | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
builder.startObject(); | ||
builder.field(TOKEN, token); | ||
builder.field(WEIGHT, weight); | ||
builder.endObject(); | ||
return builder; | ||
} | ||
|
||
public Map<String, Object> asMap() { | ||
return Map.of(TOKEN, token, WEIGHT, weight); | ||
} | ||
|
||
@Override | ||
public String toString() { | ||
return Strings.toString(this); | ||
} | ||
} | ||
|
||
private final String resultsField; | ||
private final List<WeightedToken> weightedTokens; | ||
|
||
public SlimResults(String resultField, List<WeightedToken> weightedTokens, boolean isTruncated) { | ||
super(isTruncated); | ||
this.resultsField = resultField; | ||
this.weightedTokens = weightedTokens; | ||
} | ||
|
||
public SlimResults(StreamInput in) throws IOException { | ||
super(in); | ||
this.resultsField = in.readString(); | ||
this.weightedTokens = in.readList(WeightedToken::new); | ||
} | ||
|
||
public List<WeightedToken> getWeightedTokens() { | ||
return weightedTokens; | ||
} | ||
|
||
@Override | ||
public String getWriteableName() { | ||
return NAME; | ||
} | ||
|
||
@Override | ||
public String getResultsField() { | ||
return resultsField; | ||
} | ||
|
||
@Override | ||
public Object predictedValue() { | ||
throw new UnsupportedOperationException("[" + NAME + "] does not support a single predicted value"); | ||
} | ||
|
||
@Override | ||
void doXContentBody(XContentBuilder builder, Params params) throws IOException { | ||
builder.startArray(resultsField); | ||
for (var weightedToken : weightedTokens) { | ||
weightedToken.toXContent(builder, params); | ||
} | ||
builder.endArray(); | ||
} | ||
|
||
@Override | ||
public boolean equals(Object o) { | ||
if (this == o) return true; | ||
if (o == null || getClass() != o.getClass()) return false; | ||
if (super.equals(o) == false) return false; | ||
SlimResults that = (SlimResults) o; | ||
return Objects.equals(resultsField, that.resultsField) && Objects.equals(weightedTokens, that.weightedTokens); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(super.hashCode(), resultsField, weightedTokens); | ||
} | ||
|
||
@Override | ||
void doWriteTo(StreamOutput out) throws IOException { | ||
out.writeString(resultsField); | ||
out.writeList(weightedTokens); | ||
} | ||
|
||
@Override | ||
void addMapFields(Map<String, Object> map) { | ||
map.put(resultsField, weightedTokens.stream().map(WeightedToken::asMap).collect(Collectors.toList())); | ||
} | ||
} |
175 changes: 175 additions & 0 deletions
175
...core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/SlimConfig.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel; | ||
|
||
import org.elasticsearch.Version; | ||
import org.elasticsearch.common.Strings; | ||
import org.elasticsearch.common.io.stream.StreamInput; | ||
import org.elasticsearch.common.io.stream.StreamOutput; | ||
import org.elasticsearch.core.Nullable; | ||
import org.elasticsearch.xcontent.ConstructingObjectParser; | ||
import org.elasticsearch.xcontent.XContentBuilder; | ||
import org.elasticsearch.xcontent.XContentParser; | ||
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; | ||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; | ||
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper; | ||
|
||
import java.io.IOException; | ||
import java.util.Objects; | ||
import java.util.Optional; | ||
|
||
public class SlimConfig implements NlpConfig { | ||
|
||
public static final String NAME = "slim"; | ||
|
||
public static SlimConfig fromXContentStrict(XContentParser parser) { | ||
return STRICT_PARSER.apply(parser, null); | ||
} | ||
|
||
public static SlimConfig fromXContentLenient(XContentParser parser) { | ||
return LENIENT_PARSER.apply(parser, null); | ||
} | ||
|
||
private static final ConstructingObjectParser<SlimConfig, Void> STRICT_PARSER = createParser(false); | ||
private static final ConstructingObjectParser<SlimConfig, Void> LENIENT_PARSER = createParser(true); | ||
|
||
private static ConstructingObjectParser<SlimConfig, Void> createParser(boolean ignoreUnknownFields) { | ||
ConstructingObjectParser<SlimConfig, Void> parser = new ConstructingObjectParser<>( | ||
NAME, | ||
ignoreUnknownFields, | ||
a -> new SlimConfig((VocabularyConfig) a[0], (Tokenization) a[1], (String) a[2]) | ||
); | ||
parser.declareObject(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> { | ||
if (ignoreUnknownFields == false) { | ||
throw ExceptionsHelper.badRequestException( | ||
"illegal setting [{}] on inference model creation", | ||
VOCABULARY.getPreferredName() | ||
); | ||
} | ||
return VocabularyConfig.fromXContentLenient(p); | ||
}, VOCABULARY); | ||
parser.declareNamedObject( | ||
ConstructingObjectParser.optionalConstructorArg(), | ||
(p, c, n) -> p.namedObject(Tokenization.class, n, ignoreUnknownFields), | ||
TOKENIZATION | ||
); | ||
parser.declareString(ConstructingObjectParser.optionalConstructorArg(), RESULTS_FIELD); | ||
return parser; | ||
} | ||
|
||
private final VocabularyConfig vocabularyConfig; | ||
private final Tokenization tokenization; | ||
private final String resultsField; | ||
|
||
public SlimConfig(@Nullable VocabularyConfig vocabularyConfig, @Nullable Tokenization tokenization, @Nullable String resultsField) { | ||
this.vocabularyConfig = Optional.ofNullable(vocabularyConfig) | ||
.orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore())); | ||
this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization; | ||
if (this.tokenization instanceof BertTokenization == false) { | ||
throw ExceptionsHelper.badRequestException( | ||
"SLIM must be configured with BERT tokenizer, [{}] given", | ||
this.tokenization.getName() | ||
); | ||
} | ||
// TODO support spanning | ||
if (this.tokenization.span != -1) { | ||
throw ExceptionsHelper.badRequestException( | ||
"[{}] does not support windowing long text sequences; configured span [{}]", | ||
NAME, | ||
this.tokenization.span | ||
); | ||
} | ||
this.resultsField = resultsField; | ||
} | ||
|
||
public SlimConfig(StreamInput in) throws IOException { | ||
vocabularyConfig = new VocabularyConfig(in); | ||
tokenization = in.readNamedWriteable(Tokenization.class); | ||
resultsField = in.readOptionalString(); | ||
} | ||
|
||
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
vocabularyConfig.writeTo(out); | ||
out.writeNamedWriteable(tokenization); | ||
out.writeOptionalString(resultsField); | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
builder.startObject(); | ||
builder.field(VOCABULARY.getPreferredName(), vocabularyConfig, params); | ||
NamedXContentObjectHelper.writeNamedObject(builder, params, TOKENIZATION.getPreferredName(), tokenization); | ||
if (resultsField != null) { | ||
builder.field(RESULTS_FIELD.getPreferredName(), resultsField); | ||
} | ||
builder.endObject(); | ||
return builder; | ||
} | ||
|
||
@Override | ||
public String getWriteableName() { | ||
return NAME; | ||
} | ||
|
||
@Override | ||
public boolean isTargetTypeSupported(TargetType targetType) { | ||
// TargetType relates to boosted tree models | ||
return false; | ||
} | ||
|
||
@Override | ||
public boolean isAllocateOnly() { | ||
return true; | ||
} | ||
|
||
@Override | ||
public Version getMinimalSupportedVersion() { | ||
return Version.V_8_7_0; | ||
} | ||
|
||
@Override | ||
public String getResultsField() { | ||
return resultsField; | ||
} | ||
|
||
@Override | ||
public VocabularyConfig getVocabularyConfig() { | ||
return vocabularyConfig; | ||
} | ||
|
||
@Override | ||
public Tokenization getTokenization() { | ||
return tokenization; | ||
} | ||
|
||
@Override | ||
public String getName() { | ||
return NAME; | ||
} | ||
|
||
@Override | ||
public String toString() { | ||
return Strings.toString(this); | ||
} | ||
|
||
@Override | ||
public boolean equals(Object o) { | ||
if (this == o) return true; | ||
if (o == null || getClass() != o.getClass()) return false; | ||
SlimConfig that = (SlimConfig) o; | ||
return Objects.equals(vocabularyConfig, that.vocabularyConfig) | ||
&& Objects.equals(tokenization, that.tokenization) | ||
&& Objects.equals(resultsField, that.resultsField); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(vocabularyConfig, tokenization, resultsField); | ||
} | ||
} |
Oops, something went wrong.