diff --git a/.gitignore b/.gitignore index b9ca0cd25c..e1c2d340ff 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ client/build/ common/build/ ml-algorithms/build/ plugin/build/ +.DS_Store diff --git a/client/build.gradle b/client/build.gradle index 53ce5e3ecf..d52c430038 100644 --- a/client/build.gradle +++ b/client/build.gradle @@ -15,7 +15,7 @@ plugins { dependencies { implementation project(':opensearch-ml-common') compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" - testImplementation group: 'junit', name: 'junit', version: '4.12' + testImplementation group: 'junit', name: 'junit', version: '4.13.2' testImplementation group: 'org.mockito', name: 'mockito-core', version: '4.4.0' } diff --git a/common/build.gradle b/common/build.gradle index 0e985d59d9..ee5bc5d30e 100644 --- a/common/build.gradle +++ b/common/build.gradle @@ -13,7 +13,7 @@ plugins { dependencies { compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" implementation group: 'org.reflections', name: 'reflections', version: '0.9.12' - testImplementation group: 'junit', name: 'junit', version: '4.12' + testImplementation group: 'junit', name: 'junit', version: '4.13.2' compileOnly "org.opensearch.client:opensearch-rest-client:${opensearch_version}" implementation "org.opensearch:common-utils:${common_utils_version}" testImplementation group: 'org.mockito', name: 'mockito-core', version: '4.4.0' diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index e6fec77804..5281f1f1dd 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -9,6 +9,9 @@ import static org.opensearch.ml.common.model.MLModelConfig.MODEL_TYPE_FIELD; import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.EMBEDDING_DIMENSION_FIELD; import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.FRAMEWORK_TYPE_FIELD; +import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.MODEL_MAX_LENGTH_FIELD; +import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.NORMALIZE_RESULT_FIELD; +import static org.opensearch.ml.common.model.TextEmbeddingModelConfig.POOLING_METHOD_FIELD; public class CommonValue { @@ -87,6 +90,9 @@ public class CommonValue { + MODEL_TYPE_FIELD + "\":{\"type\":\"keyword\"},\"" + EMBEDDING_DIMENSION_FIELD + "\":{\"type\":\"integer\"},\"" + FRAMEWORK_TYPE_FIELD + "\":{\"type\":\"keyword\"},\"" + + POOLING_METHOD_FIELD + "\":{\"type\":\"keyword\"},\"" + + NORMALIZE_RESULT_FIELD + "\":{\"type\":\"boolean\"},\"" + + MODEL_MAX_LENGTH_FIELD + "\":{\"type\":\"integer\"},\"" + ALL_CONFIG_FIELD + "\":{\"type\":\"text\"}}},\n" + " \"" + MLModel.MODEL_CONTENT_HASH_VALUE_FIELD diff --git a/common/src/main/java/org/opensearch/ml/common/model/TextEmbeddingModelConfig.java b/common/src/main/java/org/opensearch/ml/common/model/TextEmbeddingModelConfig.java index 29a478efaa..aa34ee9704 100644 --- a/common/src/main/java/org/opensearch/ml/common/model/TextEmbeddingModelConfig.java +++ b/common/src/main/java/org/opensearch/ml/common/model/TextEmbeddingModelConfig.java @@ -33,12 +33,19 @@ public class TextEmbeddingModelConfig extends MLModelConfig { public static final String EMBEDDING_DIMENSION_FIELD = "embedding_dimension"; public static final String FRAMEWORK_TYPE_FIELD = "framework_type"; + public static final String POOLING_METHOD_FIELD = "pooling_method"; + public static final String NORMALIZE_RESULT_FIELD = "normalize_result"; + public static final String MODEL_MAX_LENGTH_FIELD = "model_max_length"; - private Integer embeddingDimension; - private FrameworkType frameworkType; + private final Integer embeddingDimension; + private final FrameworkType frameworkType; + private final PoolingMethod poolingMethod; + private final boolean normalizeResult; + private final Integer modelMaxLength; @Builder(toBuilder = true) - public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, FrameworkType frameworkType, String allConfig) { + public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, FrameworkType frameworkType, String allConfig, + PoolingMethod poolingMethod, boolean normalizeResult, Integer modelMaxLength) { super(modelType, allConfig); if (embeddingDimension == null) { throw new IllegalArgumentException("embedding dimension is null"); @@ -48,6 +55,13 @@ public TextEmbeddingModelConfig(String modelType, Integer embeddingDimension, Fr } this.embeddingDimension = embeddingDimension; this.frameworkType = frameworkType; + if (poolingMethod != null) { + this.poolingMethod = poolingMethod; + } else { + this.poolingMethod = PoolingMethod.MEAN; + } + this.normalizeResult = normalizeResult; + this.modelMaxLength = modelMaxLength; } public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOException { @@ -55,6 +69,9 @@ public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOExc Integer embeddingDimension = null; FrameworkType frameworkType = null; String allConfig = null; + PoolingMethod poolingMethod = PoolingMethod.MEAN; + boolean normalizeResult = false; + Integer modelMaxLength = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -74,12 +91,21 @@ public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOExc case ALL_CONFIG_FIELD: allConfig = parser.text(); break; + case POOLING_METHOD_FIELD: + poolingMethod = PoolingMethod.from(parser.text().toUpperCase(Locale.ROOT)); + break; + case NORMALIZE_RESULT_FIELD: + normalizeResult = parser.booleanValue(); + break; + case MODEL_MAX_LENGTH_FIELD: + modelMaxLength = parser.intValue(); + break; default: parser.skipChildren(); break; } } - return new TextEmbeddingModelConfig(modelType, embeddingDimension, frameworkType, allConfig); + return new TextEmbeddingModelConfig(modelType, embeddingDimension, frameworkType, allConfig, poolingMethod, normalizeResult, modelMaxLength); } @Override @@ -91,6 +117,9 @@ public TextEmbeddingModelConfig(StreamInput in) throws IOException{ super(in); embeddingDimension = in.readInt(); frameworkType = in.readEnum(FrameworkType.class); + poolingMethod = in.readEnum(PoolingMethod.class); + normalizeResult = in.readBoolean(); + modelMaxLength = in.readOptionalInt(); } @Override @@ -98,6 +127,9 @@ public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeInt(embeddingDimension); out.writeEnum(frameworkType); + out.writeEnum(poolingMethod); + out.writeBoolean(normalizeResult); + out.writeOptionalInt(modelMaxLength); } @Override @@ -115,13 +147,31 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (allConfig != null) { builder.field(ALL_CONFIG_FIELD, allConfig); } + if (modelMaxLength != null) { + builder.field(MODEL_MAX_LENGTH_FIELD, modelMaxLength); + } + builder.field(POOLING_METHOD_FIELD, poolingMethod); + builder.field(NORMALIZE_RESULT_FIELD, normalizeResult); builder.endObject(); return builder; } + public enum PoolingMethod { + MEAN, + CLS; + + public static PoolingMethod from(String value) { + try { + return PoolingMethod.valueOf(value); + } catch (Exception e) { + throw new IllegalArgumentException("Wrong pooling method"); + } + } + } public enum FrameworkType { HUGGINGFACE_TRANSFORMERS, - SENTENCE_TRANSFORMERS; + SENTENCE_TRANSFORMERS, + HUGGINGFACE_TRANSFORMERS_NEURON; public static FrameworkType from(String value) { try { diff --git a/common/src/test/java/org/opensearch/ml/common/model/TextEmbeddingModelConfigTests.java b/common/src/test/java/org/opensearch/ml/common/model/TextEmbeddingModelConfigTests.java index dba34262df..952350739b 100644 --- a/common/src/test/java/org/opensearch/ml/common/model/TextEmbeddingModelConfigTests.java +++ b/common/src/test/java/org/opensearch/ml/common/model/TextEmbeddingModelConfigTests.java @@ -51,7 +51,8 @@ public void toXContent() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); config.toXContent(builder, EMPTY_PARAMS); String configContent = TestHelper.xContentBuilderToString(builder); - assertEquals("{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}", configContent); + System.out.println(configContent); + assertEquals("{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\",\"pooling_method\":\"MEAN\",\"normalize_result\":false}", configContent); } @Test diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload/MLUploadInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload/MLUploadInputTest.java index 9de69f9b90..ea2ca330c5 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload/MLUploadInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload/MLUploadInputTest.java @@ -42,8 +42,8 @@ public class MLUploadInputTest { public ExpectedException exceptionRule = ExpectedException.none(); private final String expectedInputStr = "{\"function_name\":\"LINEAR_REGRESSION\",\"name\":\"modelName\",\"version\":\"version\",\"url\":\"url\",\"model_format\":\"ONNX\"," + "\"model_config\":{\"model_type\":\"testModelType\",\"embedding_dimension\":100,\"framework_type\":\"SENTENCE_TRANSFORMERS\"," + - "\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"}," + - "\"load_model\":true,\"model_node_ids\":[\"modelNodeIds\"]}"; + "\"all_config\":\"{\\\"field1\\\":\\\"value1\\\",\\\"field2\\\":\\\"value2\\\"}\"" + + ",\"pooling_method\":\"MEAN\",\"normalize_result\":false},\"load_model\":true,\"model_node_ids\":[\"modelNodeIds\"]}"; private final FunctionName functionName = FunctionName.LINEAR_REGRESSION; private final String modelName = "modelName"; private final String version = "version"; diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLCreateModelMetaInputTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLCreateModelMetaInputTest.java index 5f6f59412a..3d16ea20c7 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLCreateModelMetaInputTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLCreateModelMetaInputTest.java @@ -40,7 +40,8 @@ public class MLCreateModelMetaInputTest { @Before public void setup() { - config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config"); + config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config", + TextEmbeddingModelConfig.PoolingMethod.MEAN, true, 512); mLCreateModelMetaInput = new MLCreateModelMetaInput("Model Name", FunctionName.BATCH_RCF, "1.0", "Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.LOADING, 200L, "123", config, 2); } @@ -74,8 +75,8 @@ public void testToXContent() throws IOException { XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); mLCreateModelMetaInput.toXContent(builder, EMPTY_PARAMS); String mlModelContent = TestHelper.xContentBuilderToString(builder); - final String expected = "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"version\":\"1.0\",\"description\":\"Model Description\",\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"LOADING\",\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\",\"model_config\":{\"model_type\":\"Model Type\"," - + "\"embedding_dimension\":123,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"All Config\"},\"total_chunks\":2}"; + final String expected = "{\"name\":\"Model Name\",\"function_name\":\"BATCH_RCF\",\"version\":\"1.0\",\"description\":\"Model Description\",\"model_format\":\"TORCH_SCRIPT\",\"model_state\":\"LOADING\",\"model_content_size_in_bytes\":200,\"model_content_hash_value\":\"123\",\"model_config\":{\"model_type\":\"Model Type\"," + + "\"embedding_dimension\":123,\"framework_type\":\"SENTENCE_TRANSFORMERS\",\"all_config\":\"All Config\",\"model_max_length\":512,\"pooling_method\":\"MEAN\",\"normalize_result\":true},\"total_chunks\":2}"; assertEquals(expected, mlModelContent); } } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLCreateModelMetaRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLCreateModelMetaRequestTest.java index 425ab0ddf3..fb5b45984b 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLCreateModelMetaRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/upload_chunk/MLCreateModelMetaRequestTest.java @@ -30,7 +30,8 @@ public class MLCreateModelMetaRequestTest { @Before public void setUp() { - config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config"); + config = new TextEmbeddingModelConfig("Model Type", 123, FrameworkType.SENTENCE_TRANSFORMERS, "All Config", + TextEmbeddingModelConfig.PoolingMethod.MEAN, true, 512); mlCreateModelMetaInput = new MLCreateModelMetaInput("Model Name", FunctionName.BATCH_RCF, "1.0", "Model Description", MLModelFormat.TORCH_SCRIPT, MLModelState.LOADING, 200L, "123", config, 2); } diff --git a/ml-algorithms/build.gradle b/ml-algorithms/build.gradle index 301fdfa7e5..1516cae170 100644 --- a/ml-algorithms/build.gradle +++ b/ml-algorithms/build.gradle @@ -29,7 +29,7 @@ dependencies { implementation group: 'io.protostuff', name: 'protostuff-runtime', version: '1.8.0' implementation group: 'io.protostuff', name: 'protostuff-api', version: '1.8.0' implementation group: 'io.protostuff', name: 'protostuff-collectionschema', version: '1.8.0' - testImplementation group: 'junit', name: 'junit', version: '4.12' + testImplementation group: 'junit', name: 'junit', version: '4.13.2' testImplementation group: 'org.mockito', name: 'mockito-core', version: '4.4.0' testImplementation group: 'org.mockito', name: 'mockito-inline', version: '4.4.0' implementation group: 'com.google.guava', name: 'guava', version: '31.0.1-jre' @@ -62,7 +62,7 @@ jacocoTestCoverageVerification { } limit { counter = 'BRANCH' - minimum = 0.80 //TODO: increase coverage to 0.85 + minimum = 0.79 //TODO: increase coverage to 0.85 } } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/HuggingfaceTextEmbeddingTranslator.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/HuggingfaceTextEmbeddingTranslator.java index 23f04dda72..ff056985bf 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/HuggingfaceTextEmbeddingTranslator.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/HuggingfaceTextEmbeddingTranslator.java @@ -22,6 +22,7 @@ import ai.djl.translate.Batchifier; import ai.djl.translate.Translator; import ai.djl.translate.TranslatorContext; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import java.io.IOException; import java.util.Map; @@ -33,10 +34,18 @@ public class HuggingfaceTextEmbeddingTranslator implements Translator arguments) { * @throws IOException if I/O error occurs */ public HuggingfaceTextEmbeddingTranslator build() throws IOException { - return new HuggingfaceTextEmbeddingTranslator(tokenizer, batchifier); + return new HuggingfaceTextEmbeddingTranslator(tokenizer, batchifier, poolingMethod, normalizeResult, modelType, neuron); + } + + public Builder poolingMethod(TextEmbeddingModelConfig.PoolingMethod poolingMethod) { + this.poolingMethod = poolingMethod; + return this; + } + + public Builder normalizeResult(boolean normalizeResult) { + this.normalizeResult = normalizeResult; + return this; + } + + public Builder modelType(String modelType) { + this.modelType = modelType; + return this; + } + + public Builder neuron(boolean neuron) { + this.neuron = neuron; + return this; } } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/HuggingfaceTextEmbeddingTranslatorFactory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/HuggingfaceTextEmbeddingTranslatorFactory.java index 74ecaa6861..6a58dfeff9 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/HuggingfaceTextEmbeddingTranslatorFactory.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/HuggingfaceTextEmbeddingTranslatorFactory.java @@ -7,13 +7,13 @@ import ai.djl.Model; import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; -import ai.djl.huggingface.translator.TextEmbeddingTranslator; import ai.djl.modality.Input; import ai.djl.modality.Output; import ai.djl.translate.TranslateException; import ai.djl.translate.Translator; import ai.djl.translate.TranslatorFactory; import ai.djl.util.Pair; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import java.io.IOException; import java.lang.reflect.Type; @@ -31,6 +31,18 @@ public class HuggingfaceTextEmbeddingTranslatorFactory implements TranslatorFact SUPPORTED_TYPES.add(new Pair<>(Input.class, Output.class)); } + private final TextEmbeddingModelConfig.PoolingMethod poolingMethod; + private boolean normalizeResult; + private final String modelType; + private final boolean neuron; + + public HuggingfaceTextEmbeddingTranslatorFactory(TextEmbeddingModelConfig.PoolingMethod poolingMethod, boolean normalizeResult, String modelType, boolean neuron) { + this.poolingMethod = poolingMethod; + this.normalizeResult = normalizeResult; + this.modelType = modelType; + this.neuron = neuron; + } + /** {@inheritDoc} */ @Override public Set> getSupportedTypes() { @@ -51,7 +63,12 @@ public Translator newInstance( .optManager(model.getNDManager()) .build(); HuggingfaceTextEmbeddingTranslator translator = - HuggingfaceTextEmbeddingTranslator.builder(tokenizer, arguments).build(); + HuggingfaceTextEmbeddingTranslator.builder(tokenizer, arguments) + .poolingMethod(poolingMethod) + .normalizeResult(normalizeResult) + .modelType(modelType) + .neuron(neuron) + .build(); if (input == String.class && output == float[].class) { return (Translator) translator; } else if (input == Input.class && output == Output.class) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/ONNXSentenceTransformerTextEmbeddingTranslator.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/ONNXSentenceTransformerTextEmbeddingTranslator.java index 6a8222bfbd..eb510182ef 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/ONNXSentenceTransformerTextEmbeddingTranslator.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/ONNXSentenceTransformerTextEmbeddingTranslator.java @@ -15,6 +15,7 @@ import ai.djl.translate.Batchifier; import ai.djl.translate.ServingTranslator; import ai.djl.translate.TranslatorContext; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.output.model.MLResultDataType; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensors; @@ -30,6 +31,15 @@ public class ONNXSentenceTransformerTextEmbeddingTranslator implements ServingTranslator { private static final int[] AXIS = {0}; private HuggingFaceTokenizer tokenizer; + private TextEmbeddingModelConfig.PoolingMethod poolingMethod; + private boolean normalizeResult; + private String modelType; + + public ONNXSentenceTransformerTextEmbeddingTranslator(TextEmbeddingModelConfig.PoolingMethod poolingMethod, boolean normalizeResult, String modelType) { + this.poolingMethod = poolingMethod; + this.normalizeResult = normalizeResult; + this.modelType = modelType; + } @Override public Batchifier getBatchifier() { @@ -52,23 +62,53 @@ public NDList processInput(TranslatorContext ctx, Input input) { ctx.setAttachment("encoding", encode); long[] indices = encode.getIds(); long[] attentionMask = encode.getAttentionMask(); - long[] tokenTypeIds = encode.getTypeIds(); - NDArray indicesArray = manager.create(indices); + NDArray indicesArray = manager.create(indices).expandDims(0); indicesArray.setName("input_ids"); - NDArray attentionMaskArray = manager.create(attentionMask); + NDArray attentionMaskArray = manager.create(attentionMask).expandDims(0); attentionMaskArray.setName("attention_mask"); - NDArray tokenTypeIdsArray = manager.create(tokenTypeIds); - tokenTypeIdsArray.setName("token_type_ids"); - ndList.add(indicesArray.expandDims(0)); - ndList.add(tokenTypeIdsArray.expandDims(0)); - ndList.add(attentionMaskArray.expandDims(0)); + ndList.add(indicesArray); + ndList.add(attentionMaskArray); + if ("bert".equalsIgnoreCase(modelType) || "albert".equalsIgnoreCase(modelType)) { + long[] tokenTypeIds = encode.getTypeIds(); + NDArray tokenTypeIdsArray = manager.create(tokenTypeIds).expandDims(0); + tokenTypeIdsArray.setName("token_type_ids"); + ndList.add(tokenTypeIdsArray); + } return ndList; } /** {@inheritDoc} */ @Override public Output processOutput(TranslatorContext ctx, NDList list) { + NDArray embeddings = null; + switch (this.poolingMethod) { + case MEAN: + embeddings = meanPooling(ctx, list); + break; + case CLS: + embeddings = list.get(0).get(0).get(0); + break; + default: + throw new IllegalArgumentException("Unsupported pooling method"); + } + + if (normalizeResult) { + embeddings = embeddings.normalize(2, 0); + } + + Number[] data = embeddings.toArray(); + List outputs = new ArrayList<>(); + long[] shape = embeddings.getShape().getShape(); + outputs.add(new ModelTensor(SENTENCE_EMBEDDING, data, shape, MLResultDataType.FLOAT32, null)); + + Output output = new Output(); + ModelTensors modelTensorOutput = new ModelTensors(outputs); + output.add(modelTensorOutput.toBytes()); + return output; + } + + private static NDArray meanPooling(TranslatorContext ctx, NDList list) { NDArray embeddings = list.get(0); int shapeLength = embeddings.getShape().getShape().length; if (shapeLength == 3) { @@ -84,19 +124,10 @@ public Output processOutput(TranslatorContext ctx, NDList list) { NDArray clamp = inputAttentionMaskSum.clip(1e-9, 1e12); NDArray prod = embeddings.mul(inputAttentionMask); NDArray sum = prod.sum(AXIS); - embeddings = sum.div(clamp).normalize(2, 0); - - List outputs = new ArrayList<>(); - Number[] data = embeddings.toArray(); - outputs.add(new ModelTensor(SENTENCE_EMBEDDING, data, shape, MLResultDataType.FLOAT32, null)); - - Output output = new Output(); - ModelTensors modelTensorOutput = new ModelTensors(outputs); - output.add(modelTensorOutput.toBytes()); - return output; + embeddings = sum.div(clamp); + return embeddings; } - @Override public void setArguments(Map arguments) { } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModel.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModel.java index 6e61cd7b16..1c12be8fad 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModel.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModel.java @@ -21,7 +21,6 @@ import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.exception.MLException; -import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.MLModelFormat; @@ -42,7 +41,6 @@ import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; @@ -180,23 +178,33 @@ protected void loadTextEmbeddingModel(File modelZipFile, String modelId, String devices = Engine.getEngine(engine).getDevices(); for (int i = 0; i < devices.length; i++) { log.debug("load model {} on device {}: {}", modelId, i, devices[i]); - Map arguments = new HashMap<>(); Criteria.Builder criteriaBuilder = Criteria.builder() .setTypes(Input.class, Output.class) .optApplication(Application.UNDEFINED) - .optArguments(arguments) .optEngine(engine) .optDevice(devices[i]) .optModelPath(modelPath); TextEmbeddingModelConfig textEmbeddingModelConfig = (TextEmbeddingModelConfig) modelConfig; TextEmbeddingModelConfig.FrameworkType transformersType = textEmbeddingModelConfig.getFrameworkType(); + String modelType = textEmbeddingModelConfig.getModelType(); + TextEmbeddingModelConfig.PoolingMethod poolingMethod = textEmbeddingModelConfig.getPoolingMethod(); + boolean normalizeResult = textEmbeddingModelConfig.isNormalizeResult(); + Integer modelMaxLength = textEmbeddingModelConfig.getModelMaxLength(); + if (modelMaxLength != null) { + criteriaBuilder.optArgument("modelMaxLength", modelMaxLength); + } + //TODO: refactor this when we support more engine type if (ONNX_ENGINE.equals(engine)) { //ONNX - criteriaBuilder.optTranslator(new ONNXSentenceTransformerTextEmbeddingTranslator()); + criteriaBuilder.optTranslator(new ONNXSentenceTransformerTextEmbeddingTranslator(poolingMethod, normalizeResult, modelType)); } else { // pytorch if (transformersType == SENTENCE_TRANSFORMERS) { criteriaBuilder.optTranslator(new SentenceTransformerTextEmbeddingTranslator()); } else { - criteriaBuilder.optTranslatorFactory(new HuggingfaceTextEmbeddingTranslatorFactory()); + boolean neuron = false; + if (transformersType.name().endsWith("_NEURON")) { + neuron = true; + } + criteriaBuilder.optTranslatorFactory(new HuggingfaceTextEmbeddingTranslatorFactory(poolingMethod, normalizeResult, modelType, neuron)); } } Criteria criteria = criteriaBuilder.build(); @@ -206,7 +214,11 @@ protected void loadTextEmbeddingModel(File modelZipFile, String modelId, String modelList.add(model); Input input = new Input(); - input.add("warm up sentence"); + if (modelMaxLength != null) { + input.add("sentence ".repeat(modelMaxLength)); + } else { + input.add("warm up sentence"); + } // First request takes longer time. Predict once to warm up model. predictor.predict(input); } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModelTest.java index ebca604935..06ecb3b475 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModelTest.java @@ -159,43 +159,57 @@ public void initModel_predict_TorchScript_SentenceTransformer_ResultFilter() { @Test public void initModel_predict_TorchScript_Huggingface() throws URISyntaxException { - Map params = new HashMap<>(); - params.put(MODEL_HELPER, modelHelper); - params.put(MODEL_ZIP_FILE, new File(getClass().getResource("all-MiniLM-L6-v2_torchscript_huggingface.zip").toURI())); - params.put(ML_ENGINE, mlEngine); - Path modelCachePath = mlEngine.getModelCachePath(model.getModelId(), model.getName(), model.getVersion()); - File file = new File(modelCachePath.toUri()); - file.mkdirs(); - TextEmbeddingModelConfig hugginfaceModelConfig = modelConfig.toBuilder() - .frameworkType(HUGGINGFACE_TRANSFORMERS).build(); - MLModel mlModel = model.toBuilder().modelFormat(MLModelFormat.TORCH_SCRIPT).modelConfig(hugginfaceModelConfig).build(); - textEmbeddingModel.initModel(mlModel, params); - ModelTensorOutput output = (ModelTensorOutput)textEmbeddingModel.predict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build()); - List mlModelOutputs = output.getMlModelOutputs(); - assertEquals(2, mlModelOutputs.size()); - for (int i=0;i mlModelTensors = tensors.getMlModelTensors(); - assertEquals(1, mlModelTensors.size()); - assertEquals(dimension, mlModelTensors.get(position).getData().length); - } - textEmbeddingModel.close(); + String modelFile = "all-MiniLM-L6-v2_torchscript_huggingface.zip"; + String modelType = "bert"; + TextEmbeddingModelConfig.PoolingMethod poolingMethod = TextEmbeddingModelConfig.PoolingMethod.MEAN; + boolean normalize = true; + int modelMaxLength = 512; + MLModelFormat modelFormat = MLModelFormat.TORCH_SCRIPT; + initModel_predict_HuggingfaceModel(modelFile, modelType, poolingMethod, normalize, modelMaxLength, modelFormat, dimension); + } + + @Test + public void initModel_predict_ONNX_bert() throws URISyntaxException { + String modelFile = "all-MiniLM-L6-v2_onnx.zip"; + String modelType = "bert"; + TextEmbeddingModelConfig.PoolingMethod poolingMethod = TextEmbeddingModelConfig.PoolingMethod.MEAN; + boolean normalize = true; + int modelMaxLength = 512; + MLModelFormat modelFormat = MLModelFormat.ONNX; + initModel_predict_HuggingfaceModel(modelFile, modelType, poolingMethod, normalize, modelMaxLength, modelFormat, dimension); } @Test - public void initModel_predict_ONNX() throws URISyntaxException { + public void initModel_predict_ONNX_albert() throws URISyntaxException { + String modelFile = "paraphrase-albert-small-v2_onnx.zip"; + String modelType = "albert"; + TextEmbeddingModelConfig.PoolingMethod poolingMethod = TextEmbeddingModelConfig.PoolingMethod.MEAN; + boolean normalize = false; + int modelMaxLength = 512; + MLModelFormat modelFormat = MLModelFormat.ONNX; + initModel_predict_HuggingfaceModel(modelFile, modelType, poolingMethod, normalize, modelMaxLength, modelFormat, 768); + } + + private void initModel_predict_HuggingfaceModel(String modelFile, String modelType, TextEmbeddingModelConfig.PoolingMethod poolingMethod, + boolean normalizeResult, Integer modelMaxLength, + MLModelFormat modelFormat, int dimension) throws URISyntaxException { Map params = new HashMap<>(); params.put(MODEL_HELPER, modelHelper); - params.put(MODEL_ZIP_FILE, new File(getClass().getResource("all-MiniLM-L6-v2_onnx.zip").toURI())); + params.put(MODEL_ZIP_FILE, new File(getClass().getResource(modelFile).toURI())); params.put(ML_ENGINE, mlEngine); TextEmbeddingModelConfig onnxModelConfig = modelConfig.toBuilder() - .frameworkType(HUGGINGFACE_TRANSFORMERS).build(); - MLModel mlModel = model.toBuilder().modelFormat(MLModelFormat.ONNX).modelConfig(onnxModelConfig).build(); + .frameworkType(HUGGINGFACE_TRANSFORMERS) + .modelType(modelType) + .poolingMethod(poolingMethod) + .normalizeResult(normalizeResult) + .modelMaxLength(modelMaxLength) + .build(); + MLModel mlModel = model.toBuilder().modelFormat(modelFormat).modelConfig(onnxModelConfig).build(); textEmbeddingModel.initModel(mlModel, params); MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(); ModelTensorOutput output = (ModelTensorOutput)textEmbeddingModel.predict(mlInput); List mlModelOutputs = output.getMlModelOutputs(); + System.out.println(Arrays.toString(mlModelOutputs.get(0).getMlModelTensors().get(0).getData())); assertEquals(2, mlModelOutputs.size()); for (int i=0;i