-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Rework memcpy transformer to support WebGPU EP being added #22329
Conversation
@@ -16,16 +16,16 @@ using namespace ONNX_NAMESPACE; | |||
namespace onnxruntime { | |||
namespace test { | |||
|
|||
typedef std::vector<onnxruntime::NodeArg*> ArgMap; | |||
typedef std::vector<NodeArg*> ArgMap; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Most diffs here are from removing the unnecessary onnxruntime::
prefix.
2 tests are updated to assign the 'If' node in the main graph to an EP. The existing test is unchanged apart from the assignment. Test is moved to a lambda with the 'If' node being assigned to the CPU and CUDA EPs.
- lines 195 and 350 is where the existing test was moved into a lambda
- lines 258 and 400 is where the assignment of the 'If' node occurs
ORT_ENFORCE(!incompatible_gpu_eps, "Mixing CUDA/TensorRT, ROCm/MIGraphX, and WebGPU is not supported."); | ||
|
||
for (auto& provider : provider_types_) { | ||
if (utils::ProviderIsCpuBased(provider) == false) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Key aspect when reviewing is that the transformer only runs for GPU based EPs.
bool operator!=(const ConstIterator& other) const noexcept { return current_ != other.current_; } | ||
bool operator==(const ConstIterator& rhs) const noexcept { return current_ == rhs.current_; } | ||
bool operator!=(const ConstIterator& rhs) const noexcept { return current_ != rhs.current_; } | ||
size_t operator-(const ConstIterator& rhs) const noexcept { return current_ - rhs.current_; } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should it be ptrdiff_t?
Description
Rework the memcpy transformer to simplify and support an additional GPU based EP.
Miscellaneous:
Easier to review with whitespace diffs hidden.
Motivation and Context
Fix CI failures when WebGPU and CUDA EPs are enabled in the same build.