Smart execution providers #35
Conversation
@xenova This should do the trick. I still need to test the web build and check that the fallback works fine.
dependency is correct :)
This is great! Thanks for putting the time in to get it working. I'll test on my side and merge as soon as I can :)
Does the execution provider switch to wasm if the webgl backend fails? If so, then this is alright. If not, then I am slightly worried about using backends (webgl/cuda/webgpu) that do not fully support the necessary operations.
Can you confirm?
ONNX will fall back to the next one if it fails. However, this behaviour is flaky, and because of that I have not included all the backends. With the ones selected we should be ok.
I see you're accessing ONNX through the tensor utils file, but I think it would be better if we create a separate file (e.g., backend.js or onnx.js) which handles the loading and fallbacks of the various imports.
I had the same feeling, but I did not want the PR to be too big or out of scope.
I also think that we need to review tensor_utils; e.g., I do not think that we have to implement softmax ourselves
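For context, a hand-rolled softmax like the one being discussed in tensor_utils might look like the following minimal, numerically stable sketch (this is illustrative only; the repo's actual implementation may differ):

```javascript
// Numerically stable softmax: subtract the max logit before exponentiating
// so that large inputs do not overflow to Infinity.
function softmax(logits) {
    const max = Math.max(...logits);
    const exps = logits.map(x => Math.exp(x - max));
    const sum = exps.reduce((a, b) => a + b, 0);
    return exps.map(e => e / sum);
}

const probs = softmax([1, 2, 3]);
// The outputs form a probability distribution: they sum to 1
// and preserve the ordering of the input logits.
console.log(probs);
```

The point in the comment stands: a runtime library that already ships tensor ops would make maintaining helpers like this unnecessary.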
Looks good overall - just some questions about how fallbacks are handled and some organizational details.
@xenova I'm not totally happy with the PR unless we remove all the backends and allow the user to install them. Also, the fallback is not working if the model does not support a layer 🤦 I think that we should provide a way to expose the desired executor.

```js
let ONNX;
let executionProviders = [ 'wasm' ];
try {
    ONNX = require('onnxruntime-node');
    executionProviders = [ 'cuda', 'cpu' ];
} catch (err) {
    ONNX = require('onnxruntime-web');
    if (typeof process === 'object') {
        // https://github.com/microsoft/onnxruntime/issues/10311
        ONNX.env.wasm.numThreads = 1;
    }
}
```

With the code above we have at least fixed a rough edge and allow the user to use the node bindings if desired. What do you think?
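The try/catch `require` approach above generalizes to any ordered list of backends. A minimal sketch of that pattern, with the actual requires stubbed out so it runs anywhere (`firstAvailable` and the stub loaders are hypothetical and not part of the PR):

```javascript
// Attempt each loader in order and return the first one that succeeds.
// A loader is any zero-argument function that either returns a backend
// object or throws (as require() does for a missing module).
function firstAvailable(loaders) {
    for (const load of loaders) {
        try {
            return load();
        } catch (err) {
            // Module not installed (or failed to load); try the next one.
        }
    }
    throw new Error('No backend available');
}

// Simulated scenario: the node backend is missing, the web backend loads.
const backend = firstAvailable([
    () => { throw new Error("Cannot find module 'onnxruntime-node'"); },
    () => ({ name: 'onnxruntime-web', executionProviders: ['wasm'] }),
]);
console.log(backend.name); // → 'onnxruntime-web'
```

This keeps the selection logic in one place, which also matches the earlier suggestion of a dedicated backend.js/onnx.js file.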
Yeah I agree 👍 Once WebGPU releases, I'll be more focused on getting GPU support working (both for node and browser). I'll merge main into this PR, make some edits, then merge the PR back into main (hopefully soon haha). Thanks again for your contributions!
I ran some tests to see what kind of speedup these changes make, and it's amazing! 🎉 I'm doing some final merging and will hopefully get it published soon :)
@DavidGOrtega Can you grant me write access to your fork? I would like to push the changes without having to create a new fork and make a PR. (I think you can just add me as a collaborator?)
Smart execution providers (Merges #35 into main)
Got it working, PR merged! 🎉 Thanks again for your contributions!
The purpose of this PR is:
- Use `onnxruntime-node` in case it is installed as a dep in any node app using transformers.js (5x faster than WASM). If not installed, it will fall back to the WASM provider.
- Set `numThreads` to one.
- Gather the `onnxruntime-web` requirement dependency under one file.