Read function Name from pretrained model #1529
Conversation
@@ -87,7 +87,8 @@ public void downloadPrebuiltModelConfig(String taskId, MLRegisterModelInput regi
         .url(modelZipFileUrl)
         .deployModel(deployModel)
         .modelNodeIds(modelNodeIds)
-        .modelGroupId(modelGroupId);
+        .modelGroupId(modelGroupId)
+        .functionName(FunctionName.from((String) config.get("model_task_type")));
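The pattern in the hunk above can be sketched in isolation. This is a minimal stand-in, not the actual ml-commons classes: it assumes a `FunctionName` enum with a `from(String)` helper that parses the `model_task_type` string read from the prebuilt config map.

```java
import java.util.Locale;
import java.util.Map;

public class FunctionNameDemo {
    // Hypothetical stand-in for ml-commons' FunctionName enum.
    enum FunctionName {
        TEXT_EMBEDDING, SPARSE_ENCODING;

        // Illustrative parser: case-insensitive, throws on unknown values.
        static FunctionName from(String value) {
            return FunctionName.valueOf(value.toUpperCase(Locale.ROOT));
        }
    }

    public static void main(String[] args) {
        // Prebuilt-config map as it might look after parsing the config JSON.
        Map<String, Object> config = Map.of("model_task_type", "SPARSE_ENCODING");
        FunctionName fn = FunctionName.from((String) config.get("model_task_type"));
        System.out.println(fn); // prints SPARSE_ENCODING
    }
}
```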
Can't we get the function name from registerModelInput, the same way we get the other inputs on lines 57-62?
The registerModelInput is built from the request body JSON, so we could ask customers to provide it. But if we want to keep the request convention of only containing "name, version, model_format" for our pretrained models, we can only read it from the pretrained config.
What if config.get("model_task_type") is null?
I think every pretrained model should have that field.
I feel like we should address this section: currently we default to text embedding, which doesn't seem right. What happens as we start adding different pre-trained models, like the SPLADE model we just added?
I feel like we should address this section: currently we default to text embedding, which doesn't seem right. What happens as we start adding different pre-trained models, like the SPLADE model we just added?
Can we make the function name mandatory and throw an exception when it's null, instead of setting a default value?
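The two options being debated here can be sketched side by side. This is an illustrative sketch only, not the ml-commons implementation: one helper treats `model_task_type` as mandatory and throws, the other falls back to a default.

```java
import java.util.Map;

public class TaskTypeResolution {
    // Simplified stand-in enum for illustration.
    enum FunctionName { TEXT_EMBEDDING, SPARSE_ENCODING }

    // Option 1: mandatory field, throw when absent.
    static FunctionName requireTaskType(Map<String, Object> config) {
        Object raw = config.get("model_task_type");
        if (raw == null) {
            throw new IllegalArgumentException("model_task_type is required in the pretrained config");
        }
        return FunctionName.valueOf((String) raw);
    }

    // Option 2: fall back to TEXT_EMBEDDING when absent (the current default behavior).
    static FunctionName taskTypeOrDefault(Map<String, Object> config) {
        Object raw = config.get("model_task_type");
        return raw == null ? FunctionName.TEXT_EMBEDDING : FunctionName.valueOf((String) raw);
    }

    public static void main(String[] args) {
        System.out.println(taskTypeOrDefault(Map.of()));                                   // TEXT_EMBEDDING
        System.out.println(requireTaskType(Map.of("model_task_type", "SPARSE_ENCODING"))); // SPARSE_ENCODING
    }
}
```

The thread below settles on keeping the default, since pretrained registration requests intentionally omit the function name.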
One solution could be to add function_name to our model listing and then get the function name from there.
I feel like we should address this section: currently we default to text embedding, which doesn't seem right. What happens as we start adding different pre-trained models, like the SPLADE model we just added? Can we make the function name mandatory and throw an exception when it's null, instead of setting a default value?
No. When we register a pretrained model, we don't provide a function name, so in some scenarios it would be null. We need a default value.
One solution could be to add function_name to our model listing and then get the function name from there.
I think both work. The model listing and the pretrained config are both files we maintain, so it's just a choice of which file in our S3 bucket to read from.
I see what you mean now. We get this model_task_type from config.json like this. Yeah, this should work.
@@ -609,7 +609,7 @@ private void uploadModel(MLRegisterModelInput registerModelInput, MLTask mlTask,

     private void registerModelFromUrl(MLRegisterModelInput registerModelInput, MLTask mlTask, String modelVersion) {
         String taskId = mlTask.getTaskId();
-        FunctionName functionName = mlTask.getFunctionName();
+        FunctionName functionName = registerModelInput.getFunctionName();
Any reason to change to registerModelInput? Will it break BWC (backwards compatibility)?
If we don't provide a URL, the function name in the ML task is read from the request body, while the model input is generated from the config. If we kept using the function name from the ML task, it would be null and default to "text_embedding". I have tried locally with request bodies both with and without a URL; both worked for me.
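The explanation above can be sketched as follows. The classes here are simplified stand-ins, not the real MLTask / MLRegisterModelInput types: the task carries the stale default from the request body, while the input (rebuilt from the prebuilt config) carries the real function name, which is why the diff reads it from the input.

```java
public class RegisterFromUrlSketch {
    // Simplified stand-ins for illustration only.
    enum FunctionName { TEXT_EMBEDDING, SPARSE_ENCODING }

    record MLTask(String taskId, FunctionName functionName) {}
    record MLRegisterModelInput(FunctionName functionName) {}

    static FunctionName resolve(MLRegisterModelInput input, MLTask task) {
        // Before the change: task.functionName() — a stale TEXT_EMBEDDING
        // default when the request body omits function_name.
        // After the change: read from the input, which was enriched
        // from the prebuilt config.
        return input.functionName();
    }

    public static void main(String[] args) {
        MLTask task = new MLTask("task-1", FunctionName.TEXT_EMBEDDING);       // stale default
        MLRegisterModelInput input = new MLRegisterModelInput(FunctionName.SPARSE_ENCODING);
        System.out.println(resolve(input, task)); // prints SPARSE_ENCODING
    }
}
```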
MLTask must track the correct function name. Can you check whether the function name in MLTask is correct?
No, I have tested it again. If the request body is:
{
"name": "amazon/neural-sparse/opensearch-neural-sparse-encoding-v1",
"version": "1.0.0",
"model_format": "TORCH_SCRIPT"
}
then the function name inside the ML task would be text_embedding.
I mean: after changing to registerModelInput.getFunctionName(), is the function name in MLTask now correct for both text embedding and sparse models?
Already done, by setting the function name inside MLTask and rewriting it to the ML index.
e0b2487 to 8dc7d88
Signed-off-by: xinyual <[email protected]>
8dc7d88 to cb35dd8
There's a merge conflict.
Signed-off-by: xinyual <[email protected]>
I guess we could do it now. I have merged from the main branch.
Signed-off-by: xinyual <[email protected]>
The backport to 2.x failed.
To backport manually, run these commands in your terminal:
# Fetch latest updates from GitHub
git fetch
# Create a new working tree
git worktree add .worktrees/backport-2.x 2.x
# Navigate to the new working tree
cd .worktrees/backport-2.x
# Create a new branch
git switch --create backport/backport-1529-to-2.x
# Cherry-pick the merged commit of this pull request and resolve the conflicts
git cherry-pick -x --mainline 1 4d53db5d987b1940102c0f1eba12295a2f1bd5ca
# Push it to GitHub
git push --set-upstream origin backport/backport-1529-to-2.x
# Go back to the original working tree
cd ../..
# Delete the working tree
git worktree remove .worktrees/backport-2.x
Then, create a pull request where the
…roject#1529) Signed-off-by: xinyual <[email protected]>
Signed-off-by: xinyual <[email protected]>
* read Function name from pretrained config
* rewrite mltask
* optimize import
* apply spotless
* add test for function name
* apply spotless
* maintain single import
* add more test
* apply spot less
* apply spot less
---------
Signed-off-by: xinyual <[email protected]>
Description
Currently, if we register a pretrained model without a URL, the function name defaults to text_embedding, since until now we only had text embedding pretrained models. Now that we also have sparse encoding, we need to read the function name from the pretrained model config.
Issues Resolved
[List any issues this PR will resolve]
Check List
By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.
For more information on following Developer Certificate of Origin and signing off your commits, please check here.