-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding ortvalue features support for MGX EP #81
base: rocm6.3_internal_testing
Are you sure you want to change the base?
Conversation
if (!IsRocmDeviceIdValid(logging::LoggingManager::DefaultLogger(), device.Id())) { | ||
throw std::runtime_error("The provided device id doesn't match any available GPUs on the machine."); | ||
} | ||
allocator = GetRocmAllocator(device.Id()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We might be in an odd situation here as our offering has both MIGraphx and ROCm EPs include, thus we should we get both allocators? Did you test this when we build both MIGraphX and ROCm EPs? How does the allocator work for that?
#elif USE_MIGRAPHX | ||
// InputDeflist is null because OrtValue creation is not tied to a specific model | ||
// Likewise, there is no need to specify the name (as the name was previously used to lookup the def list) | ||
// TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in CUDA |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Put this comment in reference to MIGraphX and not CUDA
|
||
AllocatorPtr GetMIGraphXAllocator(OrtDevice::DeviceId id) { | ||
// Current approach is not thread-safe, but there are some bigger infra pieces to put together in order to make | ||
// multi-threaded MIGraphX allocation work we need to maintain a per-thread MIGraphX allocator |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make this an issue and attach it to the ticket if we want to make this on a per thread allocation. We should roadmap this out so we can tackle these pieces in the new year
// make it stream aware | ||
true, | ||
// enable cross stream sharing? | ||
false); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this something we want to make controllable from he API later?
// The function will return once the pageable buffer has been copied to the staging memory for DMA transfer | ||
// to device memory, but the DMA to final destination may not have completed. | ||
|
||
HIP_CALL_THROW(hipStreamSynchronize(0)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we always want to be using hipstream 0 for this?
(static_cast<size_t>(info.model_cache_enable) << 21) ^ | ||
(static_cast<size_t>(info.save_compiled_model) << 22) ^ | ||
(static_cast<size_t>(info.load_compiled_model) << 23) ^ | ||
(static_cast<size_t>(info.exhaustive_tune) << 24); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Going forward is the intent to add the other flags (fp16/int8) and other quantize modes in here as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the contribution!
Few questions about this. Overall looks good.
I've added questions/comments. One detail about combined ROCm/MIGraphX EP builds and if you've tested this with both.
also if you can, download and use lintrunner in your env to solve the lint issue. It'll make upstreaming easier
|
Created PR request with implementation of
ortvalue_from_numpy()
andortvalue_from_shape_and_type()
features for MGX EP on Windows and Linux in order of getting better performance forllama2 int 4
model execution. Some methods have been overridden and some of them implemented similar like it was done in ROCm EP. Implementing these features we significantly decreased amount of time needed for creating and copying tensors, almost whole time is dedicated to GPU now, which caused much better performance in tok/s for our GPUs. Similar option added for ROCM EP.