Skip to content

Commit

Permalink
add more parameters for text embedding model (opensearch-project#640)
Browse files Browse the repository at this point in the history
* add more parameters for text embedding model

Signed-off-by: Yaliang Wu <[email protected]>

* upgrade junit version to 4.13.2

Signed-off-by: Yaliang Wu <[email protected]>

* address comments

Signed-off-by: Yaliang Wu <[email protected]>

Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Feb 28, 2023
1 parent de60f53 commit bef97c8
Show file tree
Hide file tree
Showing 20 changed files with 313 additions and 80 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ client/build/
common/build/
ml-algorithms/build/
plugin/build/
.DS_Store
2 changes: 1 addition & 1 deletion client/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ plugins {
dependencies {
implementation project(':opensearch-ml-common')
compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
testImplementation group: 'junit', name: 'junit', version: '4.13.1'
testImplementation group: 'junit', name: 'junit', version: '4.13.2'
testImplementation group: 'org.mockito', name: 'mockito-core', version: '4.4.0'

}
Expand Down
2 changes: 1 addition & 1 deletion common/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ plugins {
dependencies {
compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
implementation group: 'org.reflections', name: 'reflections', version: '0.9.12'
testImplementation group: 'junit', name: 'junit', version: '4.13.1'
testImplementation group: 'junit', name: 'junit', version: '4.13.2'
compileOnly "org.opensearch.client:opensearch-rest-client:${opensearch_version}"
compileOnly "org.opensearch:common-utils:${common_utils_version}"
testImplementation group: 'org.mockito', name: 'mockito-core', version: '4.4.0'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
import static org.opensearch.ml.common.model.MLModelConfig.MODEL_TYPE_FIELD;
import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.EMBEDDING_DIMENSION_FIELD;
import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.FRAMEWORK_TYPE_FIELD;
import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.MODEL_MAX_LENGTH_FIELD;
import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.NORMALIZE_RESULT_FIELD;
import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.POOLING_METHOD_FIELD;

public class CommonValue {

Expand Down Expand Up @@ -87,6 +90,9 @@ public class CommonValue {
+ MODEL_TYPE_FIELD + "\":{\"type\":\"keyword\"},\""
+ EMBEDDING_DIMENSION_FIELD + "\":{\"type\":\"integer\"},\""
+ FRAMEWORK_TYPE_FIELD + "\":{\"type\":\"keyword\"},\""
+ POOLING_METHOD_FIELD + "\":{\"type\":\"keyword\"},\""
+ NORMALIZE_RESULT_FIELD + "\":{\"type\":\"boolean\"},\""
+ MODEL_MAX_LENGTH_FIELD + "\":{\"type\":\"integer\"},\""
+ ALL_CONFIG_FIELD + "\":{\"type\":\"text\"}}},\n"
+ " \""
+ MLModel.MODEL_CONTENT_HASH_VALUE_FIELD
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,19 @@ public class TextEmbeddingModelConfig extends MLModelConfig {

public static final String EMBEDDING_DIMENSION_FIELD = "embedding_dimension";
public static final String FRAMEWORK_TYPE_FIELD = "framework_type";
public static final String POOLING_METHOD_FIELD = "pooling_method";
public static final String NORMALIZE_RESULT_FIELD = "normalize_result";
public static final String MODEL_MAX_LENGTH_FIELD = "model_max_length";

private Integer embeddingDimension;
private FrameworkType frameworkType;
private final Integer embeddingDimension;
private final FrameworkType frameworkType;
private final PoolingMethod poolingMethod;
private final boolean normalizeResult;
private final Integer modelMaxLength;

@Builder(toBuilder = true)
public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, FrameworkType frameworkType, String allConfig) {
public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, FrameworkType frameworkType, String allConfig,
PoolingMethod poolingMethod, boolean normalizeResult, Integer modelMaxLength) {
super(modelType, allConfig);
if (embeddingDimension == null) {
throw new IllegalArgumentException("embedding dimension is null");
Expand All @@ -48,13 +55,23 @@ public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, Fr
}
this.embeddingDimension = embeddingDimension;
this.frameworkType = frameworkType;
if (poolingMethod != null) {
this.poolingMethod = poolingMethod;
} else {
this.poolingMethod = PoolingMethod.MEAN;
}
this.normalizeResult = normalizeResult;
this.modelMaxLength = modelMaxLength;
}

public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOException {
String modelType = null;
Integer embeddingDimension = null;
FrameworkType frameworkType = null;
String allConfig = null;
PoolingMethod poolingMethod = PoolingMethod.MEAN;
boolean normalizeResult = false;
Integer modelMaxLength = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -74,12 +91,21 @@ public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOExc
case ALL_CONFIG_FIELD:
allConfig = parser.text();
break;
case POOLING_METHOD_FIELD:
poolingMethod = PoolingMethod.from(parser.text().toUpperCase(Locale.ROOT));
break;
case NORMALIZE_RESULT_FIELD:
normalizeResult = parser.booleanValue();
break;
case MODEL_MAX_LENGTH_FIELD:
modelMaxLength = parser.intValue();
break;
default:
parser.skipChildren();
break;
}
}
return new TextEmbeddingModelConfig(modelType, embeddingDimension, frameworkType, allConfig);
return new TextEmbeddingModelConfig(modelType, embeddingDimension, frameworkType, allConfig, poolingMethod, normalizeResult, modelMaxLength);
}

@Override
Expand All @@ -91,13 +117,19 @@ public TextEmbeddingModelConfig(StreamInput in) throws IOException{
super(in);
embeddingDimension = in.readInt();
frameworkType = in.readEnum(FrameworkType.class);
poolingMethod = in.readEnum(PoolingMethod.class);
normalizeResult = in.readBoolean();
modelMaxLength = in.readOptionalInt();
}

/**
 * Writes this config to the given stream.
 *
 * <p>The superclass state is written first, followed by the embedding dimension,
 * framework type, pooling method, normalize flag and the optional model max length.
 * The stream-reading constructor reads the fields back in exactly this order, so the
 * two must be kept in sync.
 *
 * @param out stream to write to
 * @throws IOException if writing to the stream fails
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeInt(embeddingDimension);
out.writeEnum(frameworkType);
out.writeEnum(poolingMethod);
out.writeBoolean(normalizeResult);
// modelMaxLength may be null, so it is written as an optional int.
out.writeOptionalInt(modelMaxLength);
}

@Override
Expand All @@ -115,13 +147,31 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (allConfig != null) {
builder.field(ALL_CONFIG_FIELD, allConfig);
}
if (modelMaxLength != null) {
builder.field(MODEL_MAX_LENGTH_FIELD, modelMaxLength);
}
builder.field(POOLING_METHOD_FIELD, poolingMethod);
builder.field(NORMALIZE_RESULT_FIELD, normalizeResult);
builder.endObject();
return builder;
}

/**
 * Pooling method applied to the model output to produce the embedding.
 */
public enum PoolingMethod {
    MEAN,
    CLS;

    /**
     * Resolves a pooling method from its exact constant name.
     *
     * @param value constant name such as {@code "MEAN"}; matching is case-sensitive,
     *              so callers are expected to normalize the case first
     * @return the matching constant
     * @throws IllegalArgumentException if {@code value} is {@code null} or is not the
     *                                  name of a constant; the original failure is
     *                                  attached as the cause
     */
    public static PoolingMethod from(String value) {
        try {
            return PoolingMethod.valueOf(value);
        } catch (IllegalArgumentException | NullPointerException e) {
            // Chain the cause so the rejected input stays visible in the stack trace.
            throw new IllegalArgumentException("Wrong pooling method", e);
        }
    }
}
public enum FrameworkType {
HUGGINGFACE_TRANSFORMERS,
SENTENCE_TRANSFORMERS;
SENTENCE_TRANSFORMERS,
HUGGINGFACE_TRANSFORMERS_NEURON;

public static FrameworkType from(String value) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ public void toXContent() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
config.toXContent(builder, EMPTY_PARAMS);
String configContent = TestHelper.xContentBuilderToString(builder);
assertEquals("{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}", configContent);
System.out.println(configContent);
assertEquals("{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\",\"pooling_method\":\"MEAN\",\"normalize_result\":false}", configContent);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ public class MLUploadInputTest {
public ExpectedException exceptionRule = ExpectedException.none();
private final String expectedInputStr = "{\"function_name\":\"LINEAR_REGRESSION\",\"name\":\"modelName\",\"version\":\"version\",\"url\":\"url\",\"model_format\":\"ONNX\"," +
"\"model_config\":{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\"," +
"\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}," +
"\"load_model\":true,\"model_node_ids\":[\"modelNodeIds\"]}";
"\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"" +
",\"pooling_method\":\"MEAN\",\"normalize_result\":false},\"load_model\":true,\"model_node_ids\":[\"modelNodeIds\"]}";
private final FunctionName functionName = FunctionName.LINEAR_REGRESSION;
private final String modelName = "modelName";
private final String version = "version";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ public class MLCreateModelMetaInputTest {

@Before
public void setup() {
config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config");
config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config",
TextEmbeddingModelConfig.PoolingMethod.MEAN, true, 512);
mLCreateModelMetaInput = new MLCreateModelMetaInput("Model Name", FunctionName.BATCH_RCF, "1.0",
"Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.LOADING, 200L, "123", config, 2);
}
Expand Down Expand Up @@ -74,8 +75,8 @@ public void testToXContent() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
mLCreateModelMetaInput.toXContent(builder, EMPTY_PARAMS);
String mlModelContent = TestHelper.xContentBuilderToString(builder);
final String expected = "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"version\":\"1.0\",\"description\":\"Model Description\",\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"LOADING\",\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\",\"model_config\":{\"model_type\":\"Model Type\","
+ "\"embedding_dimension\":123,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"All Config\"},\"total_chunks\":2}";
final String expected = "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"version\":\"1.0\",\"description\":\"Model Description\",\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"LOADING\",\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\",\"model_config\":{\"model_type\":\"Model Type\"," +
"\"embedding_dimension\":123,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"All Config\",\"model_max_length\":512,\"pooling_method\":\"MEAN\",\"normalize_result\":true},\"total_chunks\":2}";
assertEquals(expected, mlModelContent);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ public class MLCreateModelMetaRequestTest {

@Before
public void setUp() {
config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config");
config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config",
TextEmbeddingModelConfig.PoolingMethod.MEAN, true, 512);
mlCreateModelMetaInput = new MLCreateModelMetaInput("Model Name", FunctionName.BATCH_RCF, "1.0",
"Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.LOADING, 200L, "123", config, 2);
}
Expand Down
4 changes: 2 additions & 2 deletions ml-algorithms/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ dependencies {
implementation group: 'io.protostuff', name: 'protostuff-runtime', version: '1.8.0'
implementation group: 'io.protostuff', name: 'protostuff-api', version: '1.8.0'
implementation group: 'io.protostuff', name: 'protostuff-collectionschema', version: '1.8.0'
testImplementation group: 'junit', name: 'junit', version: '4.13.1'
testImplementation group: 'junit', name: 'junit', version: '4.13.2'
testImplementation group: 'org.mockito', name: 'mockito-core', version: '4.4.0'
testImplementation group: 'org.mockito', name: 'mockito-inline', version: '4.4.0'
implementation group: 'com.google.guava', name: 'guava', version: '31.0.1-jre'
Expand Down Expand Up @@ -62,7 +62,7 @@ jacocoTestCoverageVerification {
}
limit {
counter = 'BRANCH'
minimum = 0.80 //TODO: increase coverage to 0.85
minimum = 0.79 //TODO: increase coverage to 0.85
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;

import java.io.IOException;
import java.util.Map;
Expand All @@ -33,10 +34,18 @@ public class HuggingfaceTextEmbeddingTranslator implements Translator<String, fl

private HuggingFaceTokenizer tokenizer;
private Batchifier batchifier;
private TextEmbeddingModelConfig.PoolingMethod poolingMethod;
private boolean normalizeResult;
private String modelType;
private boolean neuron;

HuggingfaceTextEmbeddingTranslator(HuggingFaceTokenizer tokenizer, Batchifier batchifier) {
HuggingfaceTextEmbeddingTranslator(HuggingFaceTokenizer tokenizer, Batchifier batchifier, TextEmbeddingModelConfig.PoolingMethod poolingMethod, boolean normalizeResult, String modelType, boolean neuron) {
this.tokenizer = tokenizer;
this.batchifier = batchifier;
this.poolingMethod = poolingMethod;
this.normalizeResult = normalizeResult;
this.modelType = modelType;
this.neuron = neuron;
}

/** {@inheritDoc} */
Expand All @@ -56,13 +65,39 @@ public NDList processInput(TranslatorContext ctx, String input) {
NDList ndList = new NDList(2);
ndList.add(manager.create(indices));
ndList.add(manager.create(attentionMask));
if (neuron && ("bert".equalsIgnoreCase(modelType) || "albert".equalsIgnoreCase(modelType))) {
long[] tokenTypeIds = encoding.getTypeIds();
ndList.add(manager.create(tokenTypeIds));
}
return ndList;
}

/**
 * {@inheritDoc}
 *
 * <p>Pools the model output with the configured pooling method and, when
 * {@code normalizeResult} is set, normalizes the pooled array before converting it to
 * a float array. A pooling method that was never configured is treated as
 * {@code MEAN}.
 *
 * @throws IllegalArgumentException if the pooling method is not supported
 */
@Override
public float[] processOutput(TranslatorContext ctx, NDList list) {
    // The builder leaves poolingMethod null unless it is set explicitly, and switching
    // on null throws a NullPointerException. Fall back to MEAN, which is the same
    // default TextEmbeddingModelConfig applies when no pooling method is given.
    TextEmbeddingModelConfig.PoolingMethod method = this.poolingMethod != null
        ? this.poolingMethod
        : TextEmbeddingModelConfig.PoolingMethod.MEAN;

    NDArray embeddings;
    switch (method) {
        case MEAN:
            embeddings = meanPooling(ctx, list);
            break;
        case CLS:
            // Index 0 of the first output array; presumably the [CLS] token vector.
            // NOTE(review): confirm against the model's output layout.
            embeddings = list.get(0).get(0);
            break;
        default:
            throw new IllegalArgumentException("Unsupported pooling method");
    }

    if (normalizeResult) {
        // normalize(p = 2, dim = 0)
        embeddings = embeddings.normalize(2, 0);
    }
    return embeddings.toFloatArray();
}

private static NDArray meanPooling(TranslatorContext ctx, NDList list) {
NDArray embeddings = list.get("last_hidden_state");
if (embeddings == null) {
embeddings = list.get(0);
}
Encoding encoding = (Encoding) ctx.getAttachment("encoding");
long[] attentionMask = encoding.getAttentionMask();
NDManager manager = ctx.getNDManager();
Expand All @@ -73,11 +108,9 @@ public float[] processOutput(TranslatorContext ctx, NDList list) {
NDArray clamp = inputAttentionMaskSum.clip(1e-9, 1e12);
NDArray prod = embeddings.mul(inputAttentionMask);
NDArray sum = prod.sum(AXIS);
embeddings = sum.div(clamp).normalize(2, 0);

return embeddings.toFloatArray();
embeddings = sum.div(clamp);
return embeddings;
}

/**
* Creates a builder to build a {@code TextEmbeddingTranslator}.
*
Expand Down Expand Up @@ -107,6 +140,10 @@ public static final class Builder {

private HuggingFaceTokenizer tokenizer;
private Batchifier batchifier = Batchifier.STACK;
private TextEmbeddingModelConfig.PoolingMethod poolingMethod;
private boolean normalizeResult;
private String modelType;
private boolean neuron;

Builder(HuggingFaceTokenizer tokenizer) {
this.tokenizer = tokenizer;
Expand Down Expand Up @@ -140,7 +177,27 @@ public void configure(Map<String, ?> arguments) {
* @throws IOException if I/O error occurs
*/
public HuggingfaceTextEmbeddingTranslator build() throws IOException {
return new HuggingfaceTextEmbeddingTranslator(tokenizer, batchifier);
return new HuggingfaceTextEmbeddingTranslator(tokenizer, batchifier, poolingMethod, normalizeResult, modelType, neuron);
}

/**
 * Sets the pooling method handed to the translator created by {@code build()}.
 * When this is not called, the builder's pooling method stays {@code null}.
 *
 * @param poolingMethod pooling method to use
 * @return this builder
 */
public Builder poolingMethod(TextEmbeddingModelConfig.PoolingMethod poolingMethod) {
this.poolingMethod = poolingMethod;
return this;
}

/**
 * Sets whether the translator normalizes the pooled embedding before returning it.
 * Defaults to {@code false} when not called.
 *
 * @param normalizeResult {@code true} to normalize the result
 * @return this builder
 */
public Builder normalizeResult(boolean normalizeResult) {
this.normalizeResult = normalizeResult;
return this;
}

/**
 * Sets the model type handed to the translator. The translator compares it,
 * ignoring case, against {@code "bert"} and {@code "albert"} when deciding whether
 * to add token type ids to the model input.
 *
 * @param modelType model type name
 * @return this builder
 */
public Builder modelType(String modelType) {
this.modelType = modelType;
return this;
}

/**
 * Sets the neuron flag handed to the translator. Token type ids are added to the
 * model input only when this flag is {@code true} and the model type is
 * {@code bert} or {@code albert}. Defaults to {@code false} when not called.
 *
 * @param neuron value of the neuron flag
 * @return this builder
 */
public Builder neuron(boolean neuron) {
this.neuron = neuron;
return this;
}
}
}
Loading

0 comments on commit bef97c8

Please sign in to comment.