CosineExpanded Metric for IVF-PQ (normalize inputs) #346
Conversation
Thanks for the PR!
Couple nitpicks and a more substantial question about the cluster center handling here.
@@ -156,6 +172,35 @@ void select_clusters(raft::resources const& handle,
      n_lists,
      stream);

    if (metric == distance::DistanceType::CosineExpanded) {
      // TODO: store dataset norms in a different manner for the cosine metric to avoid the copy here
I don't get it. What's the difference to the inner product here?
We should choose which clusters to probe using cosine distance.
We don't currently have cosine distance for ivf-pq (see rapidsai/cuvs#346), and we also don't have correlation distance support at all. Re-add the metricprocessor code to handle this.
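To illustrate why the distinction matters, here is a toy, hypothetical example (not from the PR): when cluster centers are not unit-norm, ranking by raw inner product and ranking by cosine similarity can pick different clusters to probe.

    // Toy example: inner product vs. cosine ranking of two cluster centers.
    #include <cmath>
    #include <cstdio>

    int main() {
      float q[2]  = {1.f, 0.f};
      float c1[2] = {3.f, 4.f};  // norm 5, poorly aligned with q
      float c2[2] = {1.f, 0.f};  // norm 1, perfectly aligned with q
      auto dot = [](const float* a, const float* b) { return a[0] * b[0] + a[1] * b[1]; };
      auto nrm = [&](const float* a) { return std::sqrt(dot(a, a)); };
      // Inner product prefers c1 (3 > 1); cosine prefers c2 (0.6 < 1.0).
      std::printf("ip : c1=%.1f c2=%.1f\n", dot(q, c1), dot(q, c2));
      std::printf("cos: c1=%.2f c2=%.2f\n",
                  dot(q, c1) / (nrm(q) * nrm(c1)),
                  dot(q, c2) / (nrm(q) * nrm(c2)));
    }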
      batch_labels_view,
      utils::mapping<float>{});

    if (index->metric() == cuvs::distance::DistanceType::CosineExpanded) {
cuvs::cluster::kmeans_balanced::predict already supports the Cosine metric, so there is no need to add normalization and switch to inner product.
Yes, I tried that. I also tried normalizing the cluster centers, but that does not give good recall. I get the best recall when I normalize the inputs and use inner product.
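For reference, the identity behind the normalize-then-inner-product approach: with q' = q / ||q|| and x' = x / ||x||, we have <q', x'> = <q, x> / (||q|| * ||x||) = cos(q, x), so an inner-product search over the normalized vectors reproduces exactly the cosine ranking over the original data. Any recall differences then come from how the (generally non-unit-norm) k-means centers and the PQ codebooks see the normalized vs. raw data, not from the ranking itself.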
@@ -1754,7 +1796,13 @@ auto build(raft::resources const& handle,
      cluster_centers, index.n_lists(), index.dim());
    cuvs::cluster::kmeans::balanced_params kmeans_params;
    kmeans_params.n_iters = params.kmeans_n_iters;
    kmeans_params.metric = static_cast<cuvs::distance::DistanceType>((int)index.metric());
    if (index.metric() == distance::DistanceType::CosineExpanded) {
cuvs::cluster::kmeans_balanced::fit already supports the Cosine metric, so there is no need to add normalization and switch to inner product.
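A minimal sketch of that suggestion, reusing the names from the diff above (the exact fit call is assumed to match the one already used in build()):

    cuvs::cluster::kmeans::balanced_params kmeans_params;
    kmeans_params.n_iters = params.kmeans_n_iters;
    // Pass the cosine metric straight through instead of normalizing the
    // trainset and switching to InnerProduct:
    kmeans_params.metric = static_cast<cuvs::distance::DistanceType>((int)index.metric());
    // ...then call cuvs::cluster::kmeans_balanced::fit with the raw (unnormalized)
    // trainset and cluster_centers, exactly as for the other metrics.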
@@ -137,6 +141,18 @@ void select_clusters(raft::resources const& handle,
        alpha = -1.0;
        beta = 0.0;
      } break;
      case cuvs::distance::DistanceType::CosineExpanded: {
As @achirkin noted, the norms of both the centers and the queries should be accounted for when computing the cosine distance. Right now only the norms of the queries are used, and this can result in the wrong clusters getting selected.
In IVF-Flat: https://github.com/rapidsai/cuvs/blob/branch-24.10/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh#L166
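Roughly what that adjustment looks like (a hypothetical helper, not the PR's or IVF-Flat's actual code): fold the center norm into the coarse score produced by the GEMM, so that cluster selection ranks by cosine rather than by raw inner product.

    // score for query i and center j, given ip = <query_i, center_j>
    inline float cosine_coarse_score(float ip, float query_norm, float center_norm)
    {
      float denom = query_norm * center_norm;
      // guard against empty / zero-norm centers
      return denom > 0.f ? ip / denom : 0.f;
    }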
I tried to do that, but that gives poorer recall. Simply using inner product to select the clusters to probe gives better recall in the tests.
Among all of the things that I tried, normalizing the dataset and queries and using inner product directly works the best.
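A rough sketch of that approach (not the PR's exact code); it assumes raft::linalg::row_normalize from <raft/linalg/normalize.cuh> with a (handle, in, out, norm_type) signature and float row-major inputs:

    #include <raft/core/resources.hpp>
    #include <raft/core/device_mdspan.hpp>
    #include <raft/linalg/normalize.cuh>

    void l2_normalize_rows(raft::resources const& handle,
                           raft::device_matrix_view<const float, int64_t> in,
                           raft::device_matrix_view<float, int64_t> out)
    {
      // L2-normalize every row; a subsequent inner-product search over the
      // normalized rows then ranks candidates exactly as cosine similarity would.
      raft::linalg::row_normalize(handle, in, out, raft::linalg::L2Norm);
    }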
Thanks @tarang-jain for the updates. I think this is a long-awaited feature, but it's not so urgent to squeeze it in 24.10. I'd suggest we take a bit more time to make sure it has good, well understood performance from day one in the main branch.
If we decide to push this to 24.12, it would be nice to run a few benchmarks to see how cosine metric fares against other metrics and against the cuVS main branch.
    auto float_vec_batch = raft::make_device_mdarray<float, internal_extents_t>(
      handle,
      device_memory,
      raft::make_extents<internal_extents_t>(vec_batch.size(), index->dim()));
    raft::linalg::map(handle,
                      float_vec_batch.view(),
                      utils::mapping<float>{},
                      raft::make_device_matrix_view<const T, internal_extents_t>(
                        vec_batch.data(), vec_batch.size(), index->dim()));
This extra new code adds overhead to the already existing metrics; the extra allocation uses the device memory, which is otherwise carefully accounted for in the calculation above. This means (1) we may see slowdowns in important use cases (e.g. CAGRA build using IVF-PQ), and (2) we may get OOM errors under some conditions.
If an extra allocation here is really unavoidable for the cosine metric, I'd suggest limiting it to this metric only, using batches_mr for the allocation, and then adjusting the estimate of the required workspace size above.
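A sketch of that suggestion, following the names in the surrounding extend() code (batches_mr is assumed to be the memory resource already used for the input batches):

    if (index->metric() == cuvs::distance::DistanceType::CosineExpanded) {
      // Only the cosine path pays for the extra float buffer; drawing it from
      // batches_mr keeps it covered by the workspace-size estimate above.
      auto float_vec_batch = raft::make_device_mdarray<float, internal_extents_t>(
        handle,
        batches_mr,
        raft::make_extents<internal_extents_t>(vec_batch.size(), index->dim()));
      // ...convert / normalize into float_vec_batch and use it only on this branch...
    }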
@achirkin for float datatype, we can normalize in place. We just need the extra memory for the float batch when the data type is uint8 or int8.
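A sketch of that branching (T, vec_batch and float_vec_batch come from the diff above; the rest is assumed):

    if constexpr (std::is_same_v<T, float>) {
      // float input: the rows can be L2-normalized in place (or into an existing
      // working buffer), so the cosine path needs no additional allocation here.
    } else {
      // uint8_t / int8_t input: the batch must first be widened to float, so the
      // float_vec_batch buffer is genuinely needed, but only on this branch.
    }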
Need this for code freeze
/merge