Correct signatures for torch allocator plug-in
Since pytorch/pytorch#91398, the signatures of the pluggable allocate and deallocate functions must accept the device id. The current version only accepts a device id for allocate, which means that when using a stream-ordered allocator with devices other than device zero, we pass an invalid stream into the deallocation function. To fix this, adapt the signatures to match the ones pytorch expects.

Now that the device is available during both allocation and deallocation, we would like to use it to obtain the appropriate memory resource. Unfortunately, since RMM's cuda_device_id does not have a nullary constructor, we cannot use it in Cython without hacky workarounds. However, since we don't actually need to build a Python module, just a single shared library that exposes two extern "C" functions, let's write the allocator hooks directly in C++, as sketched below.

- Closes #1405
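A minimal sketch of what the two extern "C" hooks might look like, assuming RMM's per-device resource lookup (`rmm::mr::get_per_device_resource`) is used to select the memory resource for the given device; the function and file names here are illustrative and may not match this commit exactly:

```cpp
// torch_allocator_sketch.cpp (hypothetical file name)
//
// Hooks matching the signatures expected by torch's pluggable CUDA
// allocator since pytorch/pytorch#91398: both allocate and deallocate
// receive the device ordinal in addition to the stream.
#include <cuda_runtime_api.h>

#include <rmm/cuda_device.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/mr/device/per_device_resource.hpp>

#include <cstddef>

extern "C" void* allocate(std::size_t size, int device, cudaStream_t stream)
{
  // Look up the memory resource registered for this device and perform a
  // stream-ordered allocation on the stream torch handed us.
  rmm::cuda_device_id const device_id{device};
  auto* mr = rmm::mr::get_per_device_resource(device_id);
  return mr->allocate(size, rmm::cuda_stream_view{stream});
}

extern "C" void deallocate(void* ptr, std::size_t size, int device, cudaStream_t stream)
{
  // With the device id now part of the signature, the stream argument is no
  // longer misinterpreted when freeing on devices other than device zero.
  rmm::cuda_device_id const device_id{device};
  auto* mr = rmm::mr::get_per_device_resource(device_id);
  mr->deallocate(ptr, size, rmm::cuda_stream_view{stream});
}
```

On the Python side, a shared library built from hooks like these can be registered via torch.cuda.memory.CUDAPluggableAllocator, passing the library path and the names of the two exported functions, and then installed with torch.cuda.memory.change_current_allocator.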