Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add more parameters for text embedding model #640

Merged
merged 3 commits into from
Dec 20, 2022

Conversation

ylwu-amzn
Copy link
Collaborator

Signed-off-by: Yaliang Wu [email protected]

Description

Add more parameters for text embedding model

  1. model max length: how many tokens the model can support at most
  2. pooling method: we only support mean pooling method in 2.4 release, this PR add cls pooling support
  3. normalize result: boolean, will normalize result if this is true.

Issues Resolved

[List any issues this PR will resolve]

Check List

  • New functionality includes testing.
    • All tests pass
  • New functionality has been documented.
    • New functionality has javadoc added
  • Commits are signed per the DCO using --signoff

By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.
For more information on following Developer Certificate of Origin and signing off your commits, please check here.

@ylwu-amzn ylwu-amzn requested a review from a team December 19, 2022 22:49
@ylwu-amzn ylwu-amzn added the enhancement New feature or request label Dec 19, 2022
rbhavna
rbhavna previously approved these changes Dec 20, 2022
@@ -74,12 +91,21 @@ public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOExc
case ALL_CONFIG_FIELD:
allConfig = parser.text();
break;
case POOLING_METHOD_FIELD:
poolingMethod = PoolingMethod.from(parser.text().toUpperCase());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add Locale.ROOT same as FRAMEWORK_TYPE_FIELD?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, will add this

}

public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOException {
String modelType = null;
Integer embeddingDimension = null;
FrameworkType frameworkType = null;
String allConfig = null;
PoolingMethod poolingMethod = null;
boolean normalizeResult = false;
Integer modelMaxLength = null;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we have a default value for modelMaxLength?

Copy link
Collaborator Author

@ylwu-amzn ylwu-amzn Dec 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We depend on DJL engine to set the default value.

}

public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOException {
String modelType = null;
Integer embeddingDimension = null;
FrameworkType frameworkType = null;
String allConfig = null;
PoolingMethod poolingMethod = null;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Set the default value PoolingMethod.MEAN here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually the constructor will set default value. I will set default value here to make it more clear.

Comment on lines +60 to +62
} else {
this.poolingMethod = PoolingMethod.MEAN;
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we set the default value below, We don't need this else branch.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess you mean if we set default value in line 72 PoolingMethod poolingMethod = null;, we don't need line 61 this.poolingMethod = PoolingMethod.MEAN; ?

I think we still need this. This is constructor method, user can create a new object directly without calling parse method

try {
return PoolingMethod.valueOf(value);
} catch (Exception e) {
throw new IllegalArgumentException("Wrong framework type");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy error?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, will fix

@@ -31,6 +31,18 @@ public class HuggingfaceTextEmbeddingTranslatorFactory implements TranslatorFact
SUPPORTED_TYPES.add(new Pair<>(Input.class, Output.class));
}

private final TextEmbeddingModelConfig.PoolingMethod poolingMethod;
private boolean normalizeResult;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

final?

Comment on lines 196 to 208
if (ONNX_ENGINE.equals(engine)) { //ONNX
criteriaBuilder.optTranslator(new ONNXSentenceTransformerTextEmbeddingTranslator());
criteriaBuilder.optTranslator(new ONNXSentenceTransformerTextEmbeddingTranslator(poolingMethod, normalizeResult, modelType));
} else { // pytorch
if (transformersType == SENTENCE_TRANSFORMERS) {
criteriaBuilder.optTranslator(new SentenceTransformerTextEmbeddingTranslator());
} else {
criteriaBuilder.optTranslatorFactory(new HuggingfaceTextEmbeddingTranslatorFactory());
boolean neuron = false;
if (transformersType.name().endsWith("_NEURON")) {
neuron = true;
}
criteriaBuilder.optTranslatorFactory(new HuggingfaceTextEmbeddingTranslatorFactory(poolingMethod, normalizeResult, modelType, neuron));
}
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could refactor this part to support more engines better in the future.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, we can refactor this when we support new engines, add some todo now

Comment on lines 217 to 221
StringBuilder builder = new StringBuilder();
for (int j=0;j<modelMaxLength;j++) {
builder.append("sentence ");
}
input.add(builder.toString());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about just replacing this part with one line as below?
input.add("sentence ".repeat(modelMaxLength));

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, good point

Signed-off-by: Yaliang Wu <[email protected]>
@codecov-commenter
Copy link

codecov-commenter commented Dec 20, 2022

Codecov Report

Merging #640 (cd8d30a) into 2.x (d92e229) will decrease coverage by 0.11%.
The diff coverage is n/a.

@@             Coverage Diff              @@
##                2.x     #640      +/-   ##
============================================
- Coverage     84.68%   84.57%   -0.12%     
+ Complexity      984      982       -2     
============================================
  Files            92       92              
  Lines          3540     3540              
  Branches        326      326              
============================================
- Hits           2998     2994       -4     
- Misses          407      410       +3     
- Partials        135      136       +1     
Flag Coverage Δ
ml-commons 84.57% <ø> (-0.12%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
.../cluster/MLCommonsClusterManagerEventListener.java 65.62% <0.00%> (-12.50%) ⬇️

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

@ylwu-amzn ylwu-amzn merged commit b3ae98d into opensearch-project:2.x Dec 20, 2022
ylwu-amzn added a commit to ylwu-amzn/ml-commons that referenced this pull request Feb 17, 2023
* add more parameters for text embedding model

Signed-off-by: Yaliang Wu <[email protected]>

* upgrade junit version to 4.13.2

Signed-off-by: Yaliang Wu <[email protected]>

* address comments

Signed-off-by: Yaliang Wu <[email protected]>

Signed-off-by: Yaliang Wu <[email protected]>
ylwu-amzn added a commit to ylwu-amzn/ml-commons that referenced this pull request Feb 28, 2023
* add more parameters for text embedding model

Signed-off-by: Yaliang Wu <[email protected]>

* upgrade junit version to 4.13.2

Signed-off-by: Yaliang Wu <[email protected]>

* address comments

Signed-off-by: Yaliang Wu <[email protected]>

Signed-off-by: Yaliang Wu <[email protected]>
ylwu-amzn added a commit that referenced this pull request Feb 28, 2023
* add more parameters for text embedding model



* upgrade junit version to 4.13.2



* address comments

Signed-off-by: Yaliang Wu <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants