pqCodeDistances read out of bounds of pqCentroids in loadingThreads subset #1421
Summary: This diff removes a long-standing limitation of GpuIndexIVFPQ: only a limited number of dimensions per sub-quantizer were supported when not using precomputed codes. This is part of the general cleanup and extension/optimization that I am performing on the GPU PQ code.

We keep the same specialized distance computations as before, but if a number of dimensions per sub-quantizer is requested that is not specialized, we now fall back to a general implementation based on batch matrix multiplication for computing PQ distances per code. The batch MM PQ distance computation is enabled automatically if you use an odd number of dimensions per sub-quantizer (say, 7, 11, 53, ...). It can also be enabled manually via the `useMMCodeDistance` option in `GpuIndexIVFPQConfig` for testing purposes, though the result should be within some epsilon of the other implementation.

This diff also removes the iterated GEMM wrapper. I honestly don't know why I was using that instead of `cublasGemmStridedBatchedEx`; perhaps I couldn't find it, or this was originally implemented against a much older version of CUDA. The iterated GEMM call was used in a few other places as well (e.g., precomputed code computation). These (and the PQ distance computation) now use batch MM, which is a single CUDA call.

This diff also adds stream synchronization to the temporary memory manager, since the fallback PQ distance computation needs temporary memory and there were too many buffers to pre-allocate. It also fixes the bug in #1421.

Reviewed By: mdouze

Differential Revision: D24130629

fbshipit-source-id: 1c8bc53c86d0523832ad89c8bd4fa4b5fc187cae
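For illustration, here is a minimal sketch of opting into the MM-based code distance path via the config option described above; the resources setup, nlist, and metric are placeholders, not values taken from this diff:

```cpp
#include <faiss/gpu/GpuIndexIVFPQ.h>
#include <faiss/gpu/StandardGpuResources.h>

int main() {
    faiss::gpu::StandardGpuResources res;

    faiss::gpu::GpuIndexIVFPQConfig config;
    // Force the batch-MM code distance path even for dims-per-sub-quantizer
    // values that have specialized kernels; per the summary above, the result
    // should match the specialized path to within some epsilon.
    config.useMMCodeDistance = true;

    // d = 128, nlist = 1024 (placeholder), M = 32 sub-quantizers, 8 bits each.
    faiss::gpu::GpuIndexIVFPQ index(
            &res, 128, 1024, 32, 8, faiss::METRIC_L2, config);

    // With an odd dims-per-sub-quantizer (e.g. d = 77, M = 7 -> 11 dims each),
    // the MM-based path would be selected automatically even without the flag.
    return 0;
}
```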
Fixed in 9b007c7
I've recently become very interested in faiss and have run some tests with faiss GPU IVFPQ on 128-dimensional vectors, with M=32 and 256 centroids per sub-quantizer. Here are some questions of mine.
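For reference, a rough sketch of a test setup like the one described; nlist, the dataset sizes, and k are placeholders, not values from this issue:

```cpp
#include <faiss/gpu/GpuIndexIVFPQ.h>
#include <faiss/gpu/StandardGpuResources.h>

#include <cstdint>
#include <random>
#include <vector>

int main() {
    const int d = 128;         // vector dimension (from the issue)
    const int M = 32;          // sub-quantizers (from the issue): 128 / 32 = 4 dims each
    const int nbits = 8;       // 256 centroids per sub-quantizer (from the issue)
    const int nlist = 1024;    // coarse centroids: placeholder, not given in the issue
    const size_t nb = 100000;  // database size: placeholder
    const size_t nq = 16;      // number of queries: placeholder
    const int k = 5;           // results per query: placeholder

    // Random data stands in for whatever dataset was used in the test.
    std::mt19937 rng(123);
    std::uniform_real_distribution<float> dist(0.0f, 1.0f);
    std::vector<float> xb(nb * d), xq(nq * d);
    for (auto& v : xb) v = dist(rng);
    for (auto& v : xq) v = dist(rng);

    faiss::gpu::StandardGpuResources res;
    faiss::gpu::GpuIndexIVFPQ index(&res, d, nlist, M, nbits, faiss::METRIC_L2);

    index.train(nb, xb.data());
    index.add(nb, xb.data());

    std::vector<float> distances(nq * k);
    std::vector<int64_t> labels(nq * k);  // faiss idx_t is int64_t
    index.search(nq, xq.data(), k, distances.data(), labels.data());
    return 0;
}
```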
faiss/faiss/gpu/impl/PQCodeDistances-inl.cuh, line 51 in d8af513
faiss/faiss/gpu/impl/PQCodeDistances-inl.cuh, line 58 in d8af513
In the test there are 288 threads and 256 centroids, which means `code` can reach values >= 256 (up to 287) at line 51. What confuses me most is why line 58 doesn't go wrong when it reads `pqCentroids` (32x4x256) out of range (when `code >= 256}`is removed:`code >= 256`); wouldn't it go wrong at some point?

The 2nd question is: why do the loading threads need `subQuantizerData`, given that they don't process `subQuantizerData` at all?

The 3rd question is: would it be better if lines 56-59 were moved to line 155? Why or why not?
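To make the indexing question concrete, here is a simplified, hypothetical CUDA sketch of the pattern being asked about; the names and data layout are illustrative only, not the actual code in PQCodeDistances-inl.cuh:

```cpp
// Hypothetical, simplified sketch (NOT the real faiss kernel): one block of
// 288 threads = 256 "code" threads + 32 loading threads, as in the test above.
// pqCentroids is assumed to be laid out [subQuantizer][dim][code], 32 x 4 x 256.
__global__ void codeIndexSketch(const float* pqCentroids) {
    constexpr int kCodesPerSubQuantizer = 256;  // centroids per sub-quantizer
    constexpr int kDimsPerSubQuantizer = 4;     // 128 dims / 32 sub-quantizers

    // Threads [0, 256) each own one PQ code; threads [256, 288) are the
    // loading threads, used only to prefetch other data.
    bool isLoadingThread = threadIdx.x >= kCodesPerSubQuantizer;

    // threadIdx.x runs 0..287, so for the loading threads this index is
    // 256..287, i.e. past the last valid centroid index (255).
    int code = threadIdx.x;
    int subQuantizer = blockIdx.y;  // one sub-quantizer per grid row in this sketch

    float subQuantizerData[kDimsPerSubQuantizer];

    // If this load were unguarded, the 32 loading threads would read out of
    // bounds of pqCentroids even though they never use subQuantizerData,
    // which is the behaviour the questions above are about.
    if (!isLoadingThread) {
        for (int i = 0; i < kDimsPerSubQuantizer; ++i) {
            subQuantizerData[i] = pqCentroids[
                    (subQuantizer * kDimsPerSubQuantizer + i) * kCodesPerSubQuantizer
                    + code];
        }
        // ... the per-code distance computation would use subQuantizerData here ...
        (void)subQuantizerData;  // silence unused-variable warnings in this sketch
    }
}
```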