Skip to content

Commit

Permalink
Bugfix for chooseHVGs when keep_ties=false and bounds are set.
Browse files Browse the repository at this point in the history
Consolidated the boolean and index functions to avoid code duplication. Also
reorganized the test suite to cover more combinations of the options.
  • Loading branch information
LTLA committed Dec 21, 2024
1 parent c491ae7 commit 867fc6b
Show file tree
Hide file tree
Showing 2 changed files with 212 additions and 164 deletions.
183 changes: 85 additions & 98 deletions include/scran_variances/choose_highly_variable_genes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ struct ChooseHighlyVariableGenesOptions {
* Note that the actual number of chosen genes may be:
*
* - smaller than `top`, if the latter is greater than the total number of genes in the dataset.
* - smaller than `top`, if `ChooseHighlyVariableGenesOptions::use_bound = true` and `top` is greater than the total number of genes in the dataset with statistics greater than `ChooseHighlyVariableGenesOptions::bound`.
* - smaller than `top`, if `ChooseHighlyVariableGenesOptions::use_bound = true` and `top` is greater than `N`,
* where `N` is the number of genes in the dataset with statistics greater than `ChooseHighlyVariableGenesOptions::bound`
* (or less than the bound, if `ChosenHighlyVariableGenesOptions::larger = false`).
* - larger than `top`, if `ChooseHighlyVariableGenesOptions::keep_ties = true` and there are multiple ties at the `top`-th chosen gene.
*/
size_t top = 4000;
Expand Down Expand Up @@ -57,133 +59,114 @@ struct ChooseHighlyVariableGenesOptions {
*/
namespace internal {

template<typename Index_, typename Stat_, class Cmp_>
std::vector<Index_> create_semisorted_indices(size_t n, const Stat_* statistic, Cmp_ cmp, size_t top) {
std::vector<Index_> collected(n);
std::iota(collected.begin(), collected.end(), static_cast<Index_>(0));
auto cBegin = collected.begin(), cMid = cBegin + top - 1, cEnd = collected.end();
std::nth_element(cBegin, cMid, cEnd, [&](Index_ l, Index_ r) -> bool {
auto L = statistic[l], R = statistic[r];
if (L == R) {
return l < r; // always favor the earlier index for a stable sort, even if options.larger = false.
template<bool keep_index_, typename Index_, typename Stat_, class Output_, class Cmp_, class CmpEqual_>
void choose_highly_variable_genes(Index_ n, const Stat_* statistic, Output_& output, Cmp_ cmp, CmpEqual_ cmpeq, const ChooseHighlyVariableGenesOptions& options) {
if (options.top == 0) {
if constexpr(keep_index_) {
; // no-op, we assume it's already empty.
} else {
return cmp(L, R);
std::fill_n(output, n, false);
}
});
return collected;
}

template<typename Stat_, class Output_, class Cmp_, class CmpEqual_>
void choose_highly_variable_genes(size_t n, const Stat_* statistic, Output_* output, Cmp_ cmp, CmpEqual_ cmpeq, const ChooseHighlyVariableGenesOptions& options) {
if (options.top == 0) {
std::fill_n(output, n, false);
return;
}

Stat_ bound = options.bound;
if (options.top >= n) {
if (static_cast<size_t>(options.top) >= static_cast<size_t>(n)) {
if (options.use_bound) {
for (size_t i = 0; i < n; ++i) {
output[i] = cmp(statistic[i], bound);
for (Index_ i = 0; i < n; ++i) {
bool ok = cmp(statistic[i], bound);
if constexpr(keep_index_) {
if (ok) {
output.push_back(i);
}
} else {
output[i] = ok;
}
}
} else {
std::fill_n(output, n, true);
if constexpr(keep_index_) {
output.resize(n);
std::iota(output.begin(), output.end(), static_cast<Index_>(0));
} else {
std::fill_n(output, n, true);
}
}
return;
}

auto collected = create_semisorted_indices<size_t>(n, statistic, cmp, options.top);
Stat_ threshold = statistic[collected[options.top - 1]];
std::vector<Index_> semi_sorted(n);
std::iota(semi_sorted.begin(), semi_sorted.end(), static_cast<Index_>(0));
auto cBegin = semi_sorted.begin(), cMid = cBegin + options.top - 1, cEnd = semi_sorted.end();
std::nth_element(cBegin, cMid, cEnd, [&](Index_ l, Index_ r) -> bool {
auto L = statistic[l], R = statistic[r];
if (L == R) {
return l < r; // always favor the earlier index for a stable sort, even if options.larger = false.
} else {
return cmp(L, R);
}
});

Stat_ threshold = statistic[semi_sorted[options.top - 1]];

if (options.keep_ties) {
if (options.use_bound && !cmp(threshold, bound)) {
for (size_t i = 0; i < n; ++i) {
output[i] = cmp(statistic[i], bound);
for (Index_ i = 0; i < n; ++i) {
bool ok = cmp(statistic[i], bound);
if constexpr(keep_index_) {
if (ok) {
output.push_back(i);
}
} else {
output[i] = ok;
}
}
} else {
for (size_t i = 0; i < n; ++i) {
output[i] = cmpeq(statistic[i], threshold);
for (Index_ i = 0; i < n; ++i) {
bool ok = cmpeq(statistic[i], threshold);
if constexpr(keep_index_) {
if (ok) {
output.push_back(i);
}
} else {
output[i] = ok;
}
}
}
return;
}

std::fill_n(output, n, false);
size_t counter = options.top;
if (options.use_bound && !cmp(threshold, bound)) {
--counter;
while (counter > 0) {
--counter;
if (cmp(statistic[collected[counter]], bound)) {
++counter;
break;
}
}
}

for (size_t i = 0; i < counter; ++i) {
output[collected[i]] = true;
}
}

template<typename Index_, typename Stat_, class Cmp_, class CmpEqual_>
std::vector<Index_> choose_highly_variable_genes_index(size_t n, const Stat_* statistic, Cmp_ cmp, CmpEqual_ cmpeq, const ChooseHighlyVariableGenesOptions& options) {
std::vector<Index_> output;
if (options.top == 0) {
return output;
if constexpr(keep_index_) {
output.reserve(options.top);
} else {
std::fill_n(output, n, false);
}

Stat_ bound = options.bound;
if (options.top >= n) {
if (options.use_bound) {
for (size_t i = 0; i < n; ++i) {
if (options.use_bound && cmp(statistic[i], bound)) {
output.push_back(i);
if (options.use_bound) {
Index_ counter = options.top;
while (counter > 0) {
--counter;
auto pos = semi_sorted[counter];
if (cmp(statistic[pos], bound)) {
if constexpr(keep_index_) {
output.push_back(pos);
} else {
output[pos] = true;
}
}
} else {
output.resize(n);
std::iota(output.begin(), output.end(), static_cast<Index_>(0));
}
return output;
}

output = create_semisorted_indices<Index_>(n, statistic, cmp, options.top);
Stat_ threshold = statistic[output[options.top - 1]];

if (options.keep_ties) {
output.clear();
if (options.use_bound && !cmp(threshold, bound)) {
for (size_t i = 0; i < n; ++i) {
if (cmp(statistic[i], bound)) {
output.push_back(i);
}
}
} else {
if constexpr(keep_index_) {
output.insert(output.end(), semi_sorted.begin(), semi_sorted.begin() + options.top);
} else {
for (size_t i = 0; i < n; ++i) {
if (cmpeq(statistic[i], threshold)) {
output.push_back(i);
}
for (Index_ i = 0, end = options.top; i < end; ++i) {
output[semi_sorted[i]] = true;
}
}
return output;
}

size_t counter = options.top;
if (options.use_bound && !cmp(threshold, bound)) {
--counter;
while (counter > 0) {
--counter;
if (cmp(statistic[output[counter]], bound)) {
++counter;
break;
}
}
if constexpr(keep_index_) {
std::sort(output.begin(), output.end());
}

output.resize(counter);
std::sort(output.begin(), output.end());
return output;
}

}
Expand All @@ -204,7 +187,7 @@ std::vector<Index_> choose_highly_variable_genes_index(size_t n, const Stat_* st
template<typename Stat_, typename Bool_>
void choose_highly_variable_genes(size_t n, const Stat_* statistic, Bool_* output, const ChooseHighlyVariableGenesOptions& options) {
if (options.larger) {
internal::choose_highly_variable_genes(
internal::choose_highly_variable_genes<false>(
n,
statistic,
output,
Expand All @@ -213,7 +196,7 @@ void choose_highly_variable_genes(size_t n, const Stat_* statistic, Bool_* outpu
options
);
} else {
internal::choose_highly_variable_genes(
internal::choose_highly_variable_genes<false>(
n,
statistic,
output,
Expand Down Expand Up @@ -254,23 +237,27 @@ std::vector<Bool_> choose_highly_variable_genes(size_t n, const Stat_* statistic
*/
template<typename Index_, typename Stat_>
std::vector<Index_> choose_highly_variable_genes_index(Index_ n, const Stat_* statistic, const ChooseHighlyVariableGenesOptions& options) {
std::vector<Index_> output;
if (options.larger) {
return internal::choose_highly_variable_genes_index<Index_>(
internal::choose_highly_variable_genes<true>(
n,
statistic,
output,
[](Stat_ l, Stat_ r) -> bool { return l > r; },
[](Stat_ l, Stat_ r) -> bool { return l >= r; },
options
);
} else {
return internal::choose_highly_variable_genes_index<Index_>(
internal::choose_highly_variable_genes<true>(
n,
statistic,
output,
[](Stat_ l, Stat_ r) -> bool { return l < r; },
[](Stat_ l, Stat_ r) -> bool { return l <= r; },
options
);
}
return output;
}

}
Expand Down
Loading

0 comments on commit 867fc6b

Please sign in to comment.