Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-enable IVF random sampling #2225

Closed
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
3f5e149
Make subsampling use less memory
tfeher Feb 5, 2024
1d2a681
Add subsample benchmark
tfeher Feb 5, 2024
cabd94f
Merge branch 'branch-24.04' into ivf_subsample2
tfeher Mar 11, 2024
4040a96
debug
tfeher Mar 12, 2024
e09c9f7
Fix bug
tfeher Mar 12, 2024
a6f9083
add tests
tfeher Mar 12, 2024
941e165
cleanup
tfeher Mar 13, 2024
eb73ef5
added sample_rows to matrix namespace
tfeher Mar 13, 2024
cc2cf24
add test for sample rows
tfeher Mar 13, 2024
eb7e6d1
Add mdspan input API, fix cmakelists
tfeher Mar 13, 2024
7857f2f
corrections
tfeher Mar 13, 2024
93ff94f
Add test to sample_rows
tfeher Mar 13, 2024
f2c28ce
Revert "[HOTFIX] 24.02 Revert Random Sampling (#2144)"
tfeher Mar 14, 2024
47eefd4
Use the new matrix::sample_rows API
tfeher Mar 14, 2024
3f9cbc3
Address issues
tfeher Mar 15, 2024
57cb99c
change member variables in test to local vars
tfeher Mar 15, 2024
84e307e
Fix omp gather and add bench
tfeher Mar 18, 2024
1dd9e13
Merge branch 'branch-24.04' into ivf_subsample2
tfeher Mar 18, 2024
c369149
Merge remote-tracking branch 'tfeher/ivf_subsample2' into re_enable_i…
tfeher Mar 18, 2024
4ced8c4
Merge remote-tracking branch 'origin/branch-24.04' into re_enable_ivf…
tfeher Mar 18, 2024
84609de
Adjust comment
tfeher Mar 18, 2024
6ab5f9a
Fix params for sample_rows
tfeher Mar 19, 2024
f01fa61
Change IVF cluster warning messages to debug msg
tfeher Mar 19, 2024
739ff05
Merge branch 'branch-24.04' into re_enable_ivf_random_sampling
tfeher Mar 19, 2024
28a0ed7
Remove changes from ann_utils.cuh
tfeher Mar 19, 2024
d88e2d3
allocate trainset using default allocator
tfeher Mar 19, 2024
66e696a
Merge branch 'branch-24.06' into re_enable_ivf_random_sampling
tfeher Mar 21, 2024
4d7d7dd
Merge branch 'branch-24.06' into re_enable_ivf_random_sampling
tfeher Apr 8, 2024
bf67643
Merge branch 'branch-24.06' into re_enable_ivf_random_sampling
tfeher Apr 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions cpp/include/raft/cluster/detail/kmeans_balanced.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -881,10 +881,10 @@ auto build_fine_clusters(const raft::resources& handle,
if (labels_mptr[j] == LabelT(i)) { mc_trainset_ids[k++] = j; }
}
if (k != static_cast<IdxT>(mesocluster_sizes[i]))
RAFT_LOG_WARN("Incorrect mesocluster size at %d. %zu vs %zu",
static_cast<int>(i),
static_cast<size_t>(k),
static_cast<size_t>(mesocluster_sizes[i]));
RAFT_LOG_DEBUG("Incorrect mesocluster size at %d. %zu vs %zu",
tfeher marked this conversation as resolved.
Show resolved Hide resolved
static_cast<int>(i),
static_cast<size_t>(k),
static_cast<size_t>(mesocluster_sizes[i]));
if (k == 0) {
RAFT_LOG_DEBUG("Empty cluster %d", i);
RAFT_EXPECTS(fine_clusters_nums[i] == 0,
Expand Down Expand Up @@ -1030,7 +1030,7 @@ void build_hierarchical(const raft::resources& handle,
const IdxT mesocluster_size_max_balanced = div_rounding_up_safe<size_t>(
2lu * size_t(n_rows), std::max<size_t>(size_t(n_mesoclusters), 1lu));
if (mesocluster_size_max > mesocluster_size_max_balanced) {
RAFT_LOG_WARN(
RAFT_LOG_DEBUG(
"build_hierarchical: built unbalanced mesoclusters (max_mesocluster_size == %u > %u). "
"At most %u points will be used for training within each mesocluster. "
"Consider increasing the number of training iterations `n_iters`.",
Expand Down
6 changes: 4 additions & 2 deletions cpp/include/raft/neighbors/detail/ivf_pq_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1682,12 +1682,14 @@ auto build(raft::resources const& handle,
handle, device_mr, make_extents<internal_extents_t>(n_rows_train, dim));

if constexpr (std::is_same_v<T, float>) {
raft::matrix::detail::sample_rows(handle, random_state, dataset, n_rows, trainset.view());
raft::matrix::detail::sample_rows<T, int64_t>(
handle, random_state, dataset, n_rows, trainset.view());
} else {
// TODO(tfeher): Enable codebook generation with any type T, and then remove trainset tmp.
auto trainset_tmp = make_device_mdarray<T>(
handle, &managed_mr, make_extents<internal_extents_t>(n_rows_train, dim));
raft::matrix::detail::sample_rows(handle, random_state, dataset, n_rows, trainset_tmp.view());
raft::matrix::detail::sample_rows<T, int64_t>(
tfeher marked this conversation as resolved.
Show resolved Hide resolved
handle, random_state, dataset, n_rows, trainset_tmp.view());

raft::linalg::unaryOp(trainset.data_handle(),
trainset_tmp.data_handle(),
Expand Down
Loading