-
Notifications
You must be signed in to change notification settings - Fork 138
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add more parameters for text embedding model #640
Conversation
Signed-off-by: Yaliang Wu <[email protected]>
Signed-off-by: Yaliang Wu <[email protected]>
@@ -74,12 +91,21 @@ public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOExc | |||
case ALL_CONFIG_FIELD: | |||
allConfig = parser.text(); | |||
break; | |||
case POOLING_METHOD_FIELD: | |||
poolingMethod = PoolingMethod.from(parser.text().toUpperCase()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add Locale.ROOT same as FRAMEWORK_TYPE_FIELD?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sure, will add this
} | ||
|
||
public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOException { | ||
String modelType = null; | ||
Integer embeddingDimension = null; | ||
FrameworkType frameworkType = null; | ||
String allConfig = null; | ||
PoolingMethod poolingMethod = null; | ||
boolean normalizeResult = false; | ||
Integer modelMaxLength = null; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why don't we have a default value for modelMaxLength?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We depend on DJL engine to set the default value.
} | ||
|
||
public static TextEmbeddingModelConfig parse(XContentParser parser) throws IOException { | ||
String modelType = null; | ||
Integer embeddingDimension = null; | ||
FrameworkType frameworkType = null; | ||
String allConfig = null; | ||
PoolingMethod poolingMethod = null; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Set the default value PoolingMethod.MEAN here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually the constructor will set default value. I will set default value here to make it more clear.
} else { | ||
this.poolingMethod = PoolingMethod.MEAN; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we set the default value below, We don't need this else branch.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess you mean if we set default value in line 72 PoolingMethod poolingMethod = null;
, we don't need line 61 this.poolingMethod = PoolingMethod.MEAN;
?
I think we still need this. This is constructor method, user can create a new object directly without calling parse
method
try { | ||
return PoolingMethod.valueOf(value); | ||
} catch (Exception e) { | ||
throw new IllegalArgumentException("Wrong framework type"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copy error?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch, will fix
@@ -31,6 +31,18 @@ public class HuggingfaceTextEmbeddingTranslatorFactory implements TranslatorFact | |||
SUPPORTED_TYPES.add(new Pair<>(Input.class, Output.class)); | |||
} | |||
|
|||
private final TextEmbeddingModelConfig.PoolingMethod poolingMethod; | |||
private boolean normalizeResult; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
final?
if (ONNX_ENGINE.equals(engine)) { //ONNX | ||
criteriaBuilder.optTranslator(new ONNXSentenceTransformerTextEmbeddingTranslator()); | ||
criteriaBuilder.optTranslator(new ONNXSentenceTransformerTextEmbeddingTranslator(poolingMethod, normalizeResult, modelType)); | ||
} else { // pytorch | ||
if (transformersType == SENTENCE_TRANSFORMERS) { | ||
criteriaBuilder.optTranslator(new SentenceTransformerTextEmbeddingTranslator()); | ||
} else { | ||
criteriaBuilder.optTranslatorFactory(new HuggingfaceTextEmbeddingTranslatorFactory()); | ||
boolean neuron = false; | ||
if (transformersType.name().endsWith("_NEURON")) { | ||
neuron = true; | ||
} | ||
criteriaBuilder.optTranslatorFactory(new HuggingfaceTextEmbeddingTranslatorFactory(poolingMethod, normalizeResult, modelType, neuron)); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could refactor this part to support more engines better in the future.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, we can refactor this when we support new engines, add some todo now
StringBuilder builder = new StringBuilder(); | ||
for (int j=0;j<modelMaxLength;j++) { | ||
builder.append("sentence "); | ||
} | ||
input.add(builder.toString()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about just replacing this part with one line as below?
input.add("sentence ".repeat(modelMaxLength));
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sure, good point
Signed-off-by: Yaliang Wu <[email protected]>
Codecov Report
@@ Coverage Diff @@
## 2.x #640 +/- ##
============================================
- Coverage 84.68% 84.57% -0.12%
+ Complexity 984 982 -2
============================================
Files 92 92
Lines 3540 3540
Branches 326 326
============================================
- Hits 2998 2994 -4
- Misses 407 410 +3
- Partials 135 136 +1
Flags with carried forward coverage won't be shown. Click here to find out more.
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. |
* add more parameters for text embedding model Signed-off-by: Yaliang Wu <[email protected]> * upgrade junit version to 4.13.2 Signed-off-by: Yaliang Wu <[email protected]> * address comments Signed-off-by: Yaliang Wu <[email protected]> Signed-off-by: Yaliang Wu <[email protected]>
* add more parameters for text embedding model Signed-off-by: Yaliang Wu <[email protected]> * upgrade junit version to 4.13.2 Signed-off-by: Yaliang Wu <[email protected]> * address comments Signed-off-by: Yaliang Wu <[email protected]> Signed-off-by: Yaliang Wu <[email protected]>
* add more parameters for text embedding model * upgrade junit version to 4.13.2 * address comments Signed-off-by: Yaliang Wu <[email protected]>
Signed-off-by: Yaliang Wu [email protected]
Description
Add more parameters for text embedding model
Issues Resolved
[List any issues this PR will resolve]
Check List
By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.
For more information on following Developer Certificate of Origin and signing off your commits, please check here.