Skip to content

Commit

Permalink
Added Workspace class to re-use allocations in repeated trend fits.
Browse files Browse the repository at this point in the history
This is useful for re-using memory when fitting across multiple blocks.
Removed the need for dedicated buffers for the fitted values and
residuals; we just shift things around on the same buffer now.

Renamed the fixed_width mode to minimum_width, which is a bit more
accurate given that the window expands if there aren't enough points.
Also pass along the number of threads to the LOWESS fitter.

Added some tests for correct behavior with mean-based filtering.
  • Loading branch information
LTLA committed Jun 25, 2024
1 parent 1af262e commit 6196a72
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 45 deletions.
103 changes: 68 additions & 35 deletions include/scran/fit_variance_trend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,34 +56,56 @@ struct Options {

/**
* Span for the LOWESS smoother, as a proportion of the total number of points.
* This is only used if `Options::use_fixed_width = false`.
* This is only used if `Options::use_minimum_width = false`.
*/
double span = 0.3;

/**
* Should a minimum-width constraint be applied to the LOWESS smoother?
* This forces each window to be a minimum width (see `Options::fixed_width`) and avoids problems with large differences in density.
* This forces each window to be a minimum width (see `Options::minimum_width`) and avoids problems with large differences in density.
* For example, the default smoother performs poorly at high abundances where there are few genes.
*/
bool use_fixed_width = false;
bool use_minimum_width = false;

/**
* Width of the window to use when `Options::use_fixed_width = true`.
* Width of the window to use when `Options::use_minimum_width = true`.
* This should be relative to the range of `mean` values in `compute()`;
* the default value is chosen based on the typical range in single-cell RNA-seq data.
*/
double fixed_width = 1;
double minimum_width = 1;

/**
* Minimum number of observations in each window when `Options::use_fixed_width = true`.
* Minimum number of observations in each window when `Options::use_minimum_width = true`.
* This ensures that each window contains at least a given number of observations;
* if it does not, it is extended using the standard LOWESS logic until the minimum number is achieved.
*/
int minimum_window_count = 200;

/**
* Number of threads to use in the LOWESS fit.
*/
int num_threads = 1;
};

/**
 * @brief Re-usable scratch space for `compute()`.
 *
 * Pass the same instance to successive `compute()` calls so that the temporary
 * data structures are not allocated afresh on every call.
 *
 * @tparam Float_ Floating-point type of the buffers, matching the `Float_` used in `compute()`.
 */
template<typename Float_>
struct Workspace {
    /**
     * @cond
     */
    // Sorting helper; compute() re-targets it at the current x-values via set().
    WeightedLowess::SortBy sorter;

    // Scratch space handed to the sorter's permute()/unpermute() calls.
    std::vector<uint8_t> sort_workspace;

    // x- and y-coordinates passed to the LOWESS fit; compute() resizes both to `n`.
    std::vector<Float_> xbuffer;
    std::vector<Float_> ybuffer;
    /**
     * @endcond
     */
};

/**
* @brief Fit a mean-variance trend to log-count data.
Expand All @@ -95,11 +117,16 @@ struct Options {
* @param[in] variance Pointer to an array of length `n`, containing the variances for all features.
* @param[out] fitted Pointer to an array of length `n`, to store the fitted values.
* @param[out] residuals Pointer to an array of length `n`, to store the residuals.
* @param workspace Collection of temporary data structures.
* This can be re-used across multiple `compute()` calls.
* @param options Further options.
*/
template<typename Float_>
void compute(size_t n, const Float_* mean, const Float_* variance, Float_* fitted, Float_* residuals, const Options& options) {
std::vector<Float_> xbuffer(n), ybuffer(n);
void compute(size_t n, const Float_* mean, const Float_* variance, Float_* fitted, Float_* residuals, Workspace<Float_>& workspace, const Options& options) {
auto& xbuffer = workspace.xbuffer;
xbuffer.resize(n);
auto& ybuffer = workspace.ybuffer;
ybuffer.resize(n);

auto quad = [](Float_ x) -> Float_ {
return x * x * x * x;
Expand All @@ -123,43 +150,48 @@ void compute(size_t n, const Float_* mean, const Float_* variance, Float_* fitte
throw std::runtime_error("not enough observations above the minimum mean");
}


// Determining the left edge. This needs to be done before
// SortBy::permute on the xbuffer.
size_t left_index = std::min_element(xbuffer.begin(), xbuffer.begin() + counter) - xbuffer.begin();
Float_ left_x = xbuffer[left_index];

WeightedLowess::SortBy sorter(counter, xbuffer.data());
std::vector<uint8_t> work;
auto& sorter = workspace.sorter;
sorter.set(counter, xbuffer.data());
auto& work = workspace.sort_workspace;
sorter.permute(xbuffer.data(), work);
sorter.permute(ybuffer.data(), work);

WeightedLowess::Options<Float_> smooth_opt;
if (options.use_fixed_width) {
if (options.use_minimum_width) {
smooth_opt.span = options.minimum_window_count;
smooth_opt.span_as_proportion = false;
smooth_opt.minimum_width = options.fixed_width;
smooth_opt.minimum_width = options.minimum_width;
} else {
smooth_opt.span = options.span;
}

std::vector<Float_> fbuffer(counter), rbuffer(counter);
WeightedLowess::compute(counter, xbuffer.data(), ybuffer.data(), fbuffer.data(), rbuffer.data(), smooth_opt);

sorter.unpermute(rbuffer.data(), work);
sorter.unpermute(fbuffer.data(), work);

// Identifying the left-most fitted value.
Float_ left_fitted = (options.transform ? quad(fbuffer[left_index]) : fbuffer[left_index]);

counter = 0;
for (size_t i = 0; i < n; ++i) {
if (!options.mean_filter || mean[i] >= min_mean) {
fitted[i] = (options.transform ? quad(fbuffer[counter]) : fbuffer[counter]);
++counter;
smooth_opt.num_threads = options.num_threads;

// Using the residual array to store the robustness weights as a placeholder;
// we'll be overwriting this later.
WeightedLowess::compute(counter, xbuffer.data(), ybuffer.data(), fitted, residuals, smooth_opt);

// Determining the left edge before we unpermute.
Float_ left_x = xbuffer[0];
Float_ left_fitted = (options.transform ? quad(fitted[0]) : fitted[0]);

sorter.unpermute(fitted, work);

// Walking backwards to shift the elements back to their original position
// (i.e., before filtering on the mean) on the same array. We need to walk
// backwards to ensure that writing to the original position on this array
// doesn't clobber the first 'counter' positions containing the fitted
// values, at least not until each value is shifted to its original place.
for (size_t i = n; i > 0; --i) {
auto j = i - 1;
if (!options.mean_filter || mean[j] >= min_mean) {
--counter;
fitted[j] = (options.transform ? quad(fitted[counter]) : fitted[counter]);
} else {
fitted[i] = mean[i] / left_x * left_fitted; // draw a y = x line to the origin from the left of the fitted trend.
fitted[j] = mean[j] / left_x * left_fitted; // draw a y = x line to the origin from the left of the fitted trend.
}
}

for (size_t i = 0; i < n; ++i) {
residuals[i] = variance[i] - fitted[i];
}
return;
Expand Down Expand Up @@ -211,7 +243,8 @@ struct Results {
template<typename Float_>
Results<Float_> compute(size_t n, const Float_* mean, const Float_* variance, const Options& options) {
// Convenience overload: allocates a Results of length n, fills its fitted/residuals
// arrays through the pointer-based compute(), and returns it by value.
Results<Float_> output(n);
// NOTE(review): this 6-argument call (no workspace) appears to be the pre-change line
// retained by the diff view; the call below matches the workspace-taking signature - confirm.
compute(n, mean, variance, output.fitted.data(), output.residuals.data(), options);
// A fresh workspace is created on every call, so this overload does not re-use
// allocations between calls.
Workspace<Float_> work;
compute(n, mean, variance, output.fitted.data(), output.residuals.data(), work, options);
return output;
}

Expand Down
4 changes: 3 additions & 1 deletion include/scran/model_gene_variances.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,10 +306,12 @@ void compute_blocked(
internal::compute(mat, means, variances, block, block_size, options.num_threads);
}

fit_variance_trend::Workspace<Stat_> work;
auto fopt = options.fit_variance_trend_options;
fopt.num_threads = options.num_threads;
for (size_t b = 0, nblocks = block_size.size(); b < nblocks; ++b) {
if (block_size[b] >= 2) {
fit_variance_trend::compute(NR, means[b], variances[b], fitted[b], residuals[b], fopt);
fit_variance_trend::compute(NR, means[b], variances[b], fitted[b], residuals[b], work, fopt);
} else {
std::fill(fitted[b], fitted[b] + NR, std::numeric_limits<double>::quiet_NaN());
std::fill(residuals[b], residuals[b] + NR, std::numeric_limits<double>::quiet_NaN());
Expand Down
44 changes: 35 additions & 9 deletions tests/src/fit_variance_trend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ TEST(FitVarianceTrendTest, Extrapolation) {
}

TEST(FitVarianceTrendTest, Residuals) {
auto x = simulate_vector(101, /* lower = */ 0.0, /* upper = */ 1.0);
auto y = simulate_vector(101, /* lower = */ 0.1, /* upper = */ 2.0);
auto x = simulate_vector(101, /* lower = */ 0.0, /* upper = */ 1.0, /* seed = */ 42);
auto y = simulate_vector(101, /* lower = */ 0.1, /* upper = */ 2.0, /* seed = */ 69);

scran::fit_variance_trend::Options opt;
auto output = scran::fit_variance_trend::compute(x.size(), x.data(), y.data(), opt);
Expand All @@ -69,16 +69,42 @@ TEST(FitVarianceTrendTest, Residuals) {
EXPECT_EQ(output.residuals, ref);
}

TEST(FitVarianceTrendTest, FixedMode) {
auto x = simulate_vector(101, /* lower = */ 0.0, /* upper = */ 1.0);
auto y = simulate_vector(101, /* lower = */ 0.1, /* upper = */ 2.0);
// Checks that fitting with the mean filter gives the same fitted values and residuals,
// for the retained points, as an unfiltered fit on a manually subsetted data set.
TEST(FitVarianceTrendTest, Filtering) {
    auto means = simulate_vector(1001, /* lower = */ 0.0, /* upper = */ 1.0, /* seed = */ 420);
    auto variances = simulate_vector(1001, /* lower = */ 0.1, /* upper = */ 2.0, /* seed = */ 8008);

    // Reference fit with default-constructed options.
    scran::fit_variance_trend::Options opt;
    auto filtered = scran::fit_variance_trend::compute(means.size(), means.data(), variances.data(), opt);

    // Disabling the mean filter on the full data set must change the result.
    opt.mean_filter = false;
    auto unfiltered = scran::fit_variance_trend::compute(means.size(), means.data(), variances.data(), opt);
    EXPECT_NE(filtered.residuals, unfiltered.residuals);

    // Keep only the points at or above the minimum mean, together with the
    // reference fitted values and residuals for those same points.
    std::vector<double> kept_mean, kept_var, expected_fit, expected_resid;
    for (size_t g = 0; g < means.size(); ++g) {
        if (!(means[g] >= opt.minimum_mean)) {
            continue;
        }
        kept_mean.push_back(means[g]);
        kept_var.push_back(variances[g]);
        expected_fit.push_back(filtered.fitted[g]);
        expected_resid.push_back(filtered.residuals[g]);
    }

    // The unfiltered fit on the subset should reproduce the reference values exactly.
    auto manual = scran::fit_variance_trend::compute(kept_mean.size(), kept_mean.data(), kept_var.data(), opt);
    EXPECT_EQ(manual.residuals, expected_resid);
    EXPECT_EQ(manual.fitted, expected_fit);
}

TEST(FitVarianceTrendTest, MinWidth) {
auto x = simulate_vector(101, /* lower = */ 0.0, /* upper = */ 1.0, /* seed = */ 12345);
auto y = simulate_vector(101, /* lower = */ 0.1, /* upper = */ 2.0, /* seed = */ 9876);

scran::fit_variance_trend::Options opt;
auto output = scran::fit_variance_trend::compute(x.size(), x.data(), y.data(), opt);

opt.use_fixed_width = true;
opt.use_minimum_width = true;
opt.minimum_window_count = 10;
opt.fixed_width = 0.2;
opt.minimum_width = 0.2;
auto foutput = scran::fit_variance_trend::compute(x.size(), x.data(), y.data(), opt);

EXPECT_NE(output.residuals, foutput.residuals);
Expand All @@ -89,13 +115,13 @@ TEST(FitVarianceTrendTest, FixedMode) {
opt2.span = 1;
auto output2 = scran::fit_variance_trend::compute(x.size(), x.data(), y.data(), opt2);

opt2.use_fixed_width = true;
opt2.use_minimum_width = true;
opt2.minimum_window_count = 200;
auto foutput2 = scran::fit_variance_trend::compute(x.size(), x.data(), y.data(), opt2);
EXPECT_EQ(output2.residuals, foutput2.residuals);

opt2.minimum_window_count = 0;
opt2.fixed_width = 10;
opt2.minimum_width = 10;
foutput2 = scran::fit_variance_trend::compute(x.size(), x.data(), y.data(), opt2);
EXPECT_EQ(output2.residuals, foutput2.residuals);
}

0 comments on commit 6196a72

Please sign in to comment.