Skip to content

Commit

Permalink
Throw an error if min_remaining is greater than num_neighbors.
Browse files Browse the repository at this point in the history
Also coerce to the correct type for integer comparison.
  • Loading branch information
LTLA committed Aug 1, 2024
1 parent 7648bb8 commit e7a175c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
9 changes: 7 additions & 2 deletions include/nenesub/nenesub.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ void compute(Index_ num_obs, GetNeighbors_ get_neighbors, GetIndex_ get_index, G

selected.clear();
std::vector<uint8_t> tainted(num_obs);
Index_ min_remaining = options.min_remaining;
while (!store.empty()) {
auto payload = store.top();
store.pop();
Expand All @@ -137,7 +138,7 @@ void compute(Index_ num_obs, GetNeighbors_ get_neighbors, GetIndex_ get_index, G
const auto& neighbors = get_neighbors(payload.identity);
Index_ new_remaining = remaining[payload.identity];

if (new_remaining >= options.min_remaining) {
if (new_remaining >= min_remaining) {
payload.remaining = new_remaining;
if (!store.empty() && cmp(payload, store.top())) {
store.push(payload);
Expand Down Expand Up @@ -206,10 +207,14 @@ std::vector<Index_> compute(const knncolle::NeighborList<Index_, Distance_>& nei
*/
template<typename Dim_, typename Index_, typename Float_>
std::vector<Index_> compute(const knncolle::Prebuilt<Dim_, Index_, Float_>& prebuilt, const Options& options) {
int k = options.num_neighbors;
if (k < options.min_remaining) {
throw std::runtime_error("number of neighbors is less than 'min_remaining'");
}

Index_ nobs = prebuilt.num_observations();
std::vector<std::vector<Index_> > nn_indices(nobs);
std::vector<Float_> max_distance(nobs);
int k = options.num_neighbors;

#ifndef KNNCOLLE_CUSTOM_PARALLEL
#ifdef _OPENMP
Expand Down
7 changes: 7 additions & 0 deletions tests/src/nenesub.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,10 @@ TEST(Nenesub, Sanity) {
auto pselected = nenesub::compute(NR, NC, vec.data(), builder, opt);
EXPECT_EQ(selected, pselected);
}

TEST(Nenesub, OptCheck) {
knncolle::VptreeBuilder builder;
nenesub::Options opt;
opt.min_remaining = 1000;
scran_tests::expect_error([&]() { nenesub::compute(5, 0, static_cast<double*>(NULL), builder, opt); }, "number of neighbors");
}

0 comments on commit e7a175c

Please sign in to comment.