Skip to content

Commit

Permalink
Add workaround for refinement issue
Browse files Browse the repository at this point in the history
  • Loading branch information
wphicks committed Nov 16, 2023
1 parent c869f27 commit e988f47
Showing 1 changed file with 34 additions and 8 deletions.
42 changes: 34 additions & 8 deletions src/common/raft/proto/raft_index.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,8 @@ struct raft_index {
}
}
if (refine_ratio > 1.0f) {
if constexpr (vector_index_kind != raft_index_kind::cagra) {
if (dataset.has_value()) {
if (dataset.has_value()) {
if constexpr (std::is_same_v<IdxT, InputIdxT>) {
raft::neighbors::refine(
res,
*dataset,
Expand All @@ -265,16 +265,42 @@ struct raft_index {
underlying_index.metric()
);
} else {
RAFT_LOG_WARN(
"Refinement requested, but no dataset provided. "
"Ignoring refinement request."
// https://github.com/rapidsai/raft/issues/1950
raft::neighbors::refine(
res,
raft::make_device_matrix_view(
dataset->data_handle(),
IdxT(dataset->extent(0)),
IdxT(dataset->extent(1))
),
raft::make_device_matrix_view(
queries.data_handle(),
IdxT(queries.extent(0)),
IdxT(queries.extent(1))
),
raft::make_const_mdspan(
raft::make_device_matrix_view(
neighbors_tmp.data_handle(),
IdxT(neighbors_tmp.extent(0)),
IdxT(neighbors_tmp.extent(1))
)
),
raft::make_device_matrix_view(
neighbors.data_handle(),
IdxT(neighbors.extent(0)),
IdxT(neighbors.extent(1))
),
raft::make_device_matrix_view(
distances.data_handle(),
IdxT(distances.extent(0)),
IdxT(distances.extent(1))
),
underlying_index.metric()
);
}
} else {
// TODO(wphicks): Determine why CAGRA refinement fails to compile
RAFT_LOG_WARN(
"Refinement requested, but refinement is not yet implemented for "
"CAGRA. Ignoring refinement request."
"Refinement requested, but no dataset provided. Ignoring refinement request."
);
}
}
Expand Down

0 comments on commit e988f47

Please sign in to comment.