Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add more parameters for text embedding model #640

Merged
merged 3 commits into from
Dec 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ client/build/
common/build/
ml-algorithms/build/
plugin/build/
.DS_Store
2 changes: 1 addition & 1 deletion client/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ plugins {
dependencies {
implementation project(':opensearch-ml-common')
compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
testImplementation group: 'junit', name: 'junit', version: '4.12'
testImplementation group: 'junit', name: 'junit', version: '4.13.2'
testImplementation group: 'org.mockito', name: 'mockito-core', version: '4.4.0'

}
Expand Down
2 changes: 1 addition & 1 deletion common/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ plugins {
dependencies {
compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
implementation group: 'org.reflections', name: 'reflections', version: '0.9.12'
testImplementation group: 'junit', name: 'junit', version: '4.12'
testImplementation group: 'junit', name: 'junit', version: '4.13.2'
compileOnly "org.opensearch.client:opensearch-rest-client:${opensearch_version}"
implementation "org.opensearch:common-utils:${common_utils_version}"
testImplementation group: 'org.mockito', name: 'mockito-core', version: '4.4.0'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
import static org.opensearch.ml.common.model.MLModelConfig.MODEL_TYPE_FIELD;
import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.EMBEDDING_DIMENSION_FIELD;
import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.FRAMEWORK_TYPE_FIELD;
import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.MODEL_MAX_LENGTH_FIELD;
import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.NORMALIZE_RESULT_FIELD;
import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.POOLING_METHOD_FIELD;

public class CommonValue {

Expand Down Expand Up @@ -87,6 +90,9 @@ public class CommonValue {
+ MODEL_TYPE_FIELD + "\":{\"type\":\"keyword\"},\""
+ EMBEDDING_DIMENSION_FIELD + "\":{\"type\":\"integer\"},\""
+ FRAMEWORK_TYPE_FIELD + "\":{\"type\":\"keyword\"},\""
+ POOLING_METHOD_FIELD + "\":{\"type\":\"keyword\"},\""
+ NORMALIZE_RESULT_FIELD + "\":{\"type\":\"boolean\"},\""
+ MODEL_MAX_LENGTH_FIELD + "\":{\"type\":\"integer\"},\""
+ ALL_CONFIG_FIELD + "\":{\"type\":\"text\"}}},\n"
+ " \""
+ MLModel.MODEL_CONTENT_HASH_VALUE_FIELD
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,19 @@ public class TextEmbeddingModelConfig extends MLModelConfig {

public static final String EMBEDDING_DIMENSION_FIELD = "embedding_dimension";
public static final String FRAMEWORK_TYPE_FIELD = "framework_type";
public static final String POOLING_METHOD_FIELD = "pooling_method";
public static final String NORMALIZE_RESULT_FIELD = "normalize_result";
public static final String MODEL_MAX_LENGTH_FIELD = "model_max_length";

private Integer embeddingDimension;
private FrameworkType frameworkType;
private final Integer embeddingDimension;
private final FrameworkType frameworkType;
private final PoolingMethod poolingMethod;
private final boolean normalizeResult;
private final Integer modelMaxLength;

@Builder(toBuilder = true)
public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, FrameworkType frameworkType, String allConfig) {
public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, FrameworkType frameworkType, String allConfig,
PoolingMethod poolingMethod, boolean normalizeResult, Integer modelMaxLength) {
super(modelType, allConfig);
if (embeddingDimension == null) {
throw new IllegalArgumentException("embedding dimension is null");
Expand All @@ -48,13 +55,23 @@ public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, Fr
}
this.embeddingDimension = embeddingDimension;
this.frameworkType = frameworkType;
if (poolingMethod != null) {
this.poolingMethod = poolingMethod;
} else {
this.poolingMethod = PoolingMethod.MEAN;
}
Comment on lines +60 to +62
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we set the default value below, We don't need this else branch.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess you mean if we set default value in line 72 PoolingMethod poolingMethod = null;, we don't need line 61 this.poolingMethod = PoolingMethod.MEAN; ?

I think we still need this. This is constructor method, user can create a new object directly without calling parse method

this.normalizeResult = normalizeResult;
this.modelMaxLength = modelMaxLength;
}

public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOException {
String modelType = null;
Integer embeddingDimension = null;
FrameworkType frameworkType = null;
String allConfig = null;
PoolingMethod poolingMethod = PoolingMethod.MEAN;
boolean normalizeResult = false;
Integer modelMaxLength = null;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we have a default value for modelMaxLength?

Copy link
Collaborator Author

@ylwu-amzn ylwu-amzn Dec 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We depend on DJL engine to set the default value.


ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -74,12 +91,21 @@ public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOExc
case ALL_CONFIG_FIELD:
allConfig = parser.text();
break;
case POOLING_METHOD_FIELD:
poolingMethod = PoolingMethod.from(parser.text().toUpperCase(Locale.ROOT));
break;
case NORMALIZE_RESULT_FIELD:
normalizeResult = parser.booleanValue();
break;
case MODEL_MAX_LENGTH_FIELD:
modelMaxLength = parser.intValue();
break;
default:
parser.skipChildren();
break;
}
}
return new TextEmbeddingModelConfig(modelType, embeddingDimension, frameworkType, allConfig);
return new TextEmbeddingModelConfig(modelType, embeddingDimension, frameworkType, allConfig, poolingMethod, normalizeResult, modelMaxLength);
}

@Override
Expand All @@ -91,13 +117,19 @@ public TextEmbeddingModelConfig(StreamInput in) throws IOException{
super(in);
embeddingDimension = in.readInt();
frameworkType = in.readEnum(FrameworkType.class);
poolingMethod = in.readEnum(PoolingMethod.class);
normalizeResult = in.readBoolean();
modelMaxLength = in.readOptionalInt();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeInt(embeddingDimension);
out.writeEnum(frameworkType);
out.writeEnum(poolingMethod);
out.writeBoolean(normalizeResult);
out.writeOptionalInt(modelMaxLength);
}

@Override
Expand All @@ -115,13 +147,31 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (allConfig != null) {
builder.field(ALL_CONFIG_FIELD, allConfig);
}
if (modelMaxLength != null) {
builder.field(MODEL_MAX_LENGTH_FIELD, modelMaxLength);
}
builder.field(POOLING_METHOD_FIELD, poolingMethod);
builder.field(NORMALIZE_RESULT_FIELD, normalizeResult);
builder.endObject();
return builder;
}

public enum PoolingMethod {
MEAN,
CLS;

public static PoolingMethod from(String value) {
try {
return PoolingMethod.valueOf(value);
} catch (Exception e) {
throw new IllegalArgumentException("Wrong pooling method");
}
}
}
public enum FrameworkType {
HUGGINGFACE_TRANSFORMERS,
SENTENCE_TRANSFORMERS;
SENTENCE_TRANSFORMERS,
HUGGINGFACE_TRANSFORMERS_NEURON;

public static FrameworkType from(String value) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ public void toXContent() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
config.toXContent(builder, EMPTY_PARAMS);
String configContent = TestHelper.xContentBuilderToString(builder);
assertEquals("{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}", configContent);
System.out.println(configContent);
assertEquals("{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\",\"pooling_method\":\"MEAN\",\"normalize_result\":false}", configContent);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ public class MLUploadInputTest {
public ExpectedException exceptionRule = ExpectedException.none();
private final String expectedInputStr = "{\"function_name\":\"LINEAR_REGRESSION\",\"name\":\"modelName\",\"version\":\"version\",\"url\":\"url\",\"model_format\":\"ONNX\"," +
"\"model_config\":{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\"," +
"\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}," +
"\"load_model\":true,\"model_node_ids\":[\"modelNodeIds\"]}";
"\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"" +
",\"pooling_method\":\"MEAN\",\"normalize_result\":false},\"load_model\":true,\"model_node_ids\":[\"modelNodeIds\"]}";
private final FunctionName functionName = FunctionName.LINEAR_REGRESSION;
private final String modelName = "modelName";
private final String version = "version";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ public class MLCreateModelMetaInputTest {

@Before
public void setup() {
config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config");
config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config",
TextEmbeddingModelConfig.PoolingMethod.MEAN, true, 512);
mLCreateModelMetaInput = new MLCreateModelMetaInput("Model Name", FunctionName.BATCH_RCF, "1.0",
"Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.LOADING, 200L, "123", config, 2);
}
Expand Down Expand Up @@ -74,8 +75,8 @@ public void testToXContent() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
mLCreateModelMetaInput.toXContent(builder, EMPTY_PARAMS);
String mlModelContent = TestHelper.xContentBuilderToString(builder);
final String expected = "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"version\":\"1.0\",\"description\":\"Model Description\",\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"LOADING\",\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\",\"model_config\":{\"model_type\":\"Model Type\","
+ "\"embedding_dimension\":123,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"All Config\"},\"total_chunks\":2}";
final String expected = "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"version\":\"1.0\",\"description\":\"Model Description\",\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"LOADING\",\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\",\"model_config\":{\"model_type\":\"Model Type\"," +
"\"embedding_dimension\":123,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"All Config\",\"model_max_length\":512,\"pooling_method\":\"MEAN\",\"normalize_result\":true},\"total_chunks\":2}";
assertEquals(expected, mlModelContent);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ public class MLCreateModelMetaRequestTest {

@Before
public void setUp() {
config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config");
config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config",
TextEmbeddingModelConfig.PoolingMethod.MEAN, true, 512);
mlCreateModelMetaInput = new MLCreateModelMetaInput("Model Name", FunctionName.BATCH_RCF, "1.0",
"Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.LOADING, 200L, "123", config, 2);
}
Expand Down
4 changes: 2 additions & 2 deletions ml-algorithms/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ dependencies {
implementation group: 'io.protostuff', name: 'protostuff-runtime', version: '1.8.0'
implementation group: 'io.protostuff', name: 'protostuff-api', version: '1.8.0'
implementation group: 'io.protostuff', name: 'protostuff-collectionschema', version: '1.8.0'
testImplementation group: 'junit', name: 'junit', version: '4.12'
testImplementation group: 'junit', name: 'junit', version: '4.13.2'
testImplementation group: 'org.mockito', name: 'mockito-core', version: '4.4.0'
testImplementation group: 'org.mockito', name: 'mockito-inline', version: '4.4.0'
implementation group: 'com.google.guava', name: 'guava', version: '31.0.1-jre'
Expand Down Expand Up @@ -62,7 +62,7 @@ jacocoTestCoverageVerification {
}
limit {
counter = 'BRANCH'
minimum = 0.80 //TODO: increase coverage to 0.85
minimum = 0.79 //TODO: increase coverage to 0.85
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;

import java.io.IOException;
import java.util.Map;
Expand All @@ -33,10 +34,18 @@ public class HuggingfaceTextEmbeddingTranslator implements Translator<String, fl

private HuggingFaceTokenizer tokenizer;
private Batchifier batchifier;
private TextEmbeddingModelConfig.PoolingMethod poolingMethod;
private boolean normalizeResult;
private String modelType;
private boolean neuron;

HuggingfaceTextEmbeddingTranslator(HuggingFaceTokenizer tokenizer, Batchifier batchifier) {
HuggingfaceTextEmbeddingTranslator(HuggingFaceTokenizer tokenizer, Batchifier batchifier, TextEmbeddingModelConfig.PoolingMethod poolingMethod, boolean normalizeResult, String modelType, boolean neuron) {
this.tokenizer = tokenizer;
this.batchifier = batchifier;
this.poolingMethod = poolingMethod;
this.normalizeResult = normalizeResult;
this.modelType = modelType;
this.neuron = neuron;
}

/** {@inheritDoc} */
Expand All @@ -56,13 +65,39 @@ public NDList processInput(TranslatorContext ctx, String input) {
NDList ndList = new NDList(2);
ndList.add(manager.create(indices));
ndList.add(manager.create(attentionMask));
if (neuron && ("bert".equalsIgnoreCase(modelType) || "albert".equalsIgnoreCase(modelType))) {
long[] tokenTypeIds = encoding.getTypeIds();
ndList.add(manager.create(tokenTypeIds));
}
return ndList;
}

/** {@inheritDoc} */
@Override
public float[] processOutput(TranslatorContext ctx, NDList list) {
NDArray embeddings = null;
switch (this.poolingMethod) {
case MEAN:
embeddings = meanPooling(ctx, list);
break;
case CLS:
embeddings = list.get(0).get(0);
break;
default:
throw new IllegalArgumentException("Unsupported pooling method");
}

if (normalizeResult) {
embeddings = embeddings.normalize(2, 0);
}
return embeddings.toFloatArray();
}

private static NDArray meanPooling(TranslatorContext ctx, NDList list) {
NDArray embeddings = list.get("last_hidden_state");
if (embeddings == null) {
embeddings = list.get(0);
}
Encoding encoding = (Encoding) ctx.getAttachment("encoding");
long[] attentionMask = encoding.getAttentionMask();
NDManager manager = ctx.getNDManager();
Expand All @@ -73,11 +108,9 @@ public float[] processOutput(TranslatorContext ctx, NDList list) {
NDArray clamp = inputAttentionMaskSum.clip(1e-9, 1e12);
NDArray prod = embeddings.mul(inputAttentionMask);
NDArray sum = prod.sum(AXIS);
embeddings = sum.div(clamp).normalize(2, 0);

return embeddings.toFloatArray();
embeddings = sum.div(clamp);
return embeddings;
}

/**
* Creates a builder to build a {@code TextEmbeddingTranslator}.
*
Expand Down Expand Up @@ -107,6 +140,10 @@ public static final class Builder {

private HuggingFaceTokenizer tokenizer;
private Batchifier batchifier = Batchifier.STACK;
private TextEmbeddingModelConfig.PoolingMethod poolingMethod;
private boolean normalizeResult;
private String modelType;
private boolean neuron;

Builder(HuggingFaceTokenizer tokenizer) {
this.tokenizer = tokenizer;
Expand Down Expand Up @@ -140,7 +177,27 @@ public void configure(Map<String, ?> arguments) {
* @throws IOException if I/O error occurs
*/
public HuggingfaceTextEmbeddingTranslator build() throws IOException {
return new HuggingfaceTextEmbeddingTranslator(tokenizer, batchifier);
return new HuggingfaceTextEmbeddingTranslator(tokenizer, batchifier, poolingMethod, normalizeResult, modelType, neuron);
}

public Builder poolingMethod(TextEmbeddingModelConfig.PoolingMethod poolingMethod) {
this.poolingMethod = poolingMethod;
return this;
}

public Builder normalizeResult(boolean normalizeResult) {
this.normalizeResult = normalizeResult;
return this;
}

public Builder modelType(String modelType) {
this.modelType = modelType;
return this;
}

public Builder neuron(boolean neuron) {
this.neuron = neuron;
return this;
}
}
}
Loading