diff --git a/src/common/raft/proto/raft_index.cuh b/src/common/raft/proto/raft_index.cuh index 76f8dd10a..ff25b7e04 100644 --- a/src/common/raft/proto/raft_index.cuh +++ b/src/common/raft/proto/raft_index.cuh @@ -253,8 +253,8 @@ struct raft_index { } } if (refine_ratio > 1.0f) { - if constexpr (vector_index_kind != raft_index_kind::cagra) { - if (dataset.has_value()) { + if (dataset.has_value()) { + if constexpr (std::is_same_v) { raft::neighbors::refine( res, *dataset, @@ -265,16 +265,42 @@ struct raft_index { underlying_index.metric() ); } else { - RAFT_LOG_WARN( - "Refinement requested, but no dataset provided. " - "Ignoring refinement request." + // https://github.com/rapidsai/raft/issues/1950 + raft::neighbors::refine( + res, + raft::make_device_matrix_view( + dataset->data_handle(), + IdxT(dataset->extent(0)), + IdxT(dataset->extent(1)) + ), + raft::make_device_matrix_view( + queries.data_handle(), + IdxT(queries.extent(0)), + IdxT(queries.extent(1)) + ), + raft::make_const_mdspan( + raft::make_device_matrix_view( + neighbors_tmp.data_handle(), + IdxT(neighbors_tmp.extent(0)), + IdxT(neighbors_tmp.extent(1)) + ) + ), + raft::make_device_matrix_view( + neighbors.data_handle(), + IdxT(neighbors.extent(0)), + IdxT(neighbors.extent(1)) + ), + raft::make_device_matrix_view( + distances.data_handle(), + IdxT(distances.extent(0)), + IdxT(distances.extent(1)) + ), + underlying_index.metric() ); } } else { - // TODO(wphicks): Determine why CAGRA refinement fails to compile RAFT_LOG_WARN( - "Refinement requested, but refinement is not yet implemented for " - "CAGRA. Ignoring refinement request." + "Refinement requested, but no dataset provided. Ignoring refinement request." ); } }