Skip to content

Commit

Permalink
Perform the intersection with strings on the R side.
Browse files Browse the repository at this point in the history
This is more natural given that we're already using match() to
translate the strings to integer indices, so we might as well do the
translation on the intersection of gene names.
  • Loading branch information
LTLA committed Oct 10, 2024
1 parent 3ddcfb5 commit d158981
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 29 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: SingleR
Title: Reference-Based Single-Cell RNA-Seq Annotation
Version: 2.7.5
Date: 2024-10-09
Version: 2.7.6
Date: 2024-10-10
Authors@R: c(person("Dvir", "Aran", email="[email protected]", role=c("aut", "cph")),
person("Aaron", "Lun", email="[email protected]", role=c("ctb", "cre")),
person("Daniel", "Bunis", role="ctb"),
Expand Down
4 changes: 2 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ grouped_medians <- function(ref, groups, ngroups, nthreads) {
.Call('_SingleR_grouped_medians', PACKAGE = 'SingleR', ref, groups, ngroups, nthreads)
}

train_integrated <- function(test_features, references, ref_ids, labels, prebuilt, nthreads) {
.Call('_SingleR_train_integrated', PACKAGE = 'SingleR', test_features, references, ref_ids, labels, prebuilt, nthreads)
train_integrated <- function(test_features, references, ref_features, labels, prebuilt, nthreads) {
.Call('_SingleR_train_integrated', PACKAGE = 'SingleR', test_features, references, ref_features, labels, prebuilt, nthreads)
}

#' @importFrom Rcpp sourceCpp
Expand Down
13 changes: 10 additions & 3 deletions R/combineRecomputedResults.R
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,19 @@ combineRecomputedResults <- function(
}
}

all.inter.test <- all.inter.ref <- vector("list", length(trained))
test.genes <- rownames(test)
for (i in seq_along(all.refnames)) {
inter <- .create_intersection(test.genes, all.refnames[[i]])
all.inter.test[[i]] <- inter$test - 1L
all.inter.ref[[i]] <- inter$reference - 1L
}

# Applying the integration.
universe <- Reduce(union, c(list(rownames(test)), all.refnames))
ibuilt <- train_integrated(
test_features=match(rownames(test), universe) - 1L,
test_features=all.inter.test,
references=lapply(trained, function(x) initializeCpp(x$ref)),
ref_ids=lapply(all.refnames, function(x) match(x, universe) - 1L),
ref_features=all.inter.ref,
labels=lapply(trained, function(x) match(x$labels$full, x$labels$unique) - 1L),
prebuilt=lapply(trained, function(x) rebuildIndex(x)$built),
nthreads = num.threads
Expand Down
6 changes: 3 additions & 3 deletions R/trainSingleR.R
Original file line number Diff line number Diff line change
Expand Up @@ -340,9 +340,9 @@ trainSingleR <- function(
if (is.null(test.genes)) {
test.genes <- ref.genes <- seq_len(nrow(ref))
} else {
universe <- union(test.genes, rownames(ref))
test.genes <- match(test.genes, universe)
ref.genes <- match(rownames(ref), universe)
intersection <- .create_intersection(test.genes, rownames(ref))
test.genes <- intersection$test
ref.genes <- intersection$reference
}

builder <- defineBuilder(BNPARAM)
Expand Down
13 changes: 13 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,19 @@
x
}

.create_intersection <- function(test, reference) {
    # Behaves like intersect(test, reference), except that NAs are dropped and
    # the shared names are kept in their order of first appearance in 'test'.
    keep <- !is.na(test) & !duplicated(test) & test %in% reference
    shared <- test[keep]

    # Report the 1-based position of each shared name in both inputs. match()
    # returns the first occurrence, consistent with internal behavior in singlepp.
    test.idx <- match(shared, test)
    ref.idx <- match(shared, reference)
    list(test = test.idx, reference = ref.idx)
}

#' @importFrom methods is
#' @importClassesFrom S4Vectors List
.is_list <- function(val) {
Expand Down
10 changes: 5 additions & 5 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,17 @@ BEGIN_RCPP
END_RCPP
}
// train_integrated
SEXP train_integrated(Rcpp::IntegerVector test_features, Rcpp::List references, Rcpp::List ref_ids, Rcpp::List labels, Rcpp::List prebuilt, int nthreads);
RcppExport SEXP _SingleR_train_integrated(SEXP test_featuresSEXP, SEXP referencesSEXP, SEXP ref_idsSEXP, SEXP labelsSEXP, SEXP prebuiltSEXP, SEXP nthreadsSEXP) {
SEXP train_integrated(Rcpp::List test_features, Rcpp::List references, Rcpp::List ref_features, Rcpp::List labels, Rcpp::List prebuilt, int nthreads);
RcppExport SEXP _SingleR_train_integrated(SEXP test_featuresSEXP, SEXP referencesSEXP, SEXP ref_featuresSEXP, SEXP labelsSEXP, SEXP prebuiltSEXP, SEXP nthreadsSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::traits::input_parameter< Rcpp::IntegerVector >::type test_features(test_featuresSEXP);
Rcpp::traits::input_parameter< Rcpp::List >::type test_features(test_featuresSEXP);
Rcpp::traits::input_parameter< Rcpp::List >::type references(referencesSEXP);
Rcpp::traits::input_parameter< Rcpp::List >::type ref_ids(ref_idsSEXP);
Rcpp::traits::input_parameter< Rcpp::List >::type ref_features(ref_featuresSEXP);
Rcpp::traits::input_parameter< Rcpp::List >::type labels(labelsSEXP);
Rcpp::traits::input_parameter< Rcpp::List >::type prebuilt(prebuiltSEXP);
Rcpp::traits::input_parameter< int >::type nthreads(nthreadsSEXP);
rcpp_result_gen = Rcpp::wrap(train_integrated(test_features, references, ref_ids, labels, prebuilt, nthreads));
rcpp_result_gen = Rcpp::wrap(train_integrated(test_features, references, ref_features, labels, prebuilt, nthreads));
return rcpp_result_gen;
END_RCPP
}
Expand Down
26 changes: 17 additions & 9 deletions src/train_integrated.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,37 @@
#include <memory>

//[[Rcpp::export(rng=false)]]
SEXP train_integrated(Rcpp::IntegerVector test_features, Rcpp::List references, Rcpp::List ref_ids, Rcpp::List labels, Rcpp::List prebuilt, int nthreads) {
SEXP train_integrated(Rcpp::List test_features, Rcpp::List references, Rcpp::List ref_features, Rcpp::List labels, Rcpp::List prebuilt, int nthreads) {
size_t nrefs = references.size();

std::vector<singlepp::TrainIntegratedInput<double, int, int> > inputs;
inputs.reserve(nrefs);
std::vector<Rcpp::IntegerVector> holding_labs;
holding_labs.reserve(nrefs);
std::vector<singlepp::Intersection<int> > intersections(nrefs);
std::vector<Rcpp::IntegerVector> holding_labs(nrefs);

for (size_t r = 0; r < nrefs; ++r) {
Rcpp::RObject curref(references[r]);
Rtatami::BoundNumericPointer parsed(curref);

Rcpp::IntegerVector curids(ref_ids[r]);
holding_labs.emplace_back(labels[r]);
Rcpp::IntegerVector test_ids(test_features[r]);
Rcpp::IntegerVector ref_ids(ref_features[r]);
size_t ninter = test_ids.size();
if (ninter != static_cast<size_t>(ref_ids.size())) {
throw std::runtime_error("length of each entry of 'test_features' and 'ref_features' should be the same");
}
auto& curinter = intersections[r];
for (size_t i = 0; i < ninter; ++i) {
curinter.emplace_back(test_ids[i], ref_ids[i]);
}

holding_labs[r] = labels[r];
Rcpp::RObject built = prebuilt[r];
TrainedSingleIntersectPointer curbuilt(built);

inputs.push_back(singlepp::prepare_integrated_input_intersect(
static_cast<int>(test_features.size()),
static_cast<const int*>(test_features.begin()),
curinter,
*(parsed->ptr),
static_cast<const int*>(curids.begin()),
static_cast<const int*>(holding_labs.back().begin()),
static_cast<const int*>(holding_labs[r].begin()),
*curbuilt
));
}
Expand Down
25 changes: 20 additions & 5 deletions src/train_single.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ SEXP train_single(Rcpp::IntegerVector test_features, Rcpp::RObject ref, Rcpp::In
BiocNeighbors::BuilderPointer bptr(builder);
opts.trainer = std::shared_ptr<BiocNeighbors::Builder>(std::shared_ptr<BiocNeighbors::Builder>{}, bptr.get()); // make a no-op shared pointer.

Rtatami::BoundNumericPointer parsed(ref);
int NR = parsed->ptr->nrow();
int NC = parsed->ptr->ncol();
if (static_cast<int>(labels.size()) != NC) {
throw std::runtime_error("length of 'labels' is equal to the number of columns of 'ref'");
}

// Setting up the markers. We assume that these are already 0-indexed on the R side.
size_t ngroups = markers.size();
singlepp::Markers<int> markers2(ngroups);
Expand All @@ -26,18 +33,26 @@ SEXP train_single(Rcpp::IntegerVector test_features, Rcpp::RObject ref, Rcpp::In

for (size_t n = 0; n < inner_ngroups; ++n) {
Rcpp::IntegerVector seq(curmarkers[n]);
auto& seq2 = curmarkers2[n];
auto& seq2 = curmarkers2[n];
seq2.insert(seq2.end(), seq.begin(), seq.end());
}
}

// Preparing the features.
size_t ninter = test_features.size();
if (ninter != static_cast<size_t>(ref_features.size())) {
throw std::runtime_error("length of 'test_features' and 'ref_features' should be the same");
}
singlepp::Intersection<int> inter;
inter.reserve(test_features.size());
for (size_t i = 0; i < ninter; ++i) {
inter.emplace_back(test_features[i], ref_features[i]);
}

// Building the indices.
Rtatami::BoundNumericPointer parsed(ref);
auto built = singlepp::train_single_intersect(
static_cast<int>(test_features.size()),
static_cast<const int*>(test_features.begin()),
inter,
*(parsed->ptr),
static_cast<const int*>(ref_features.begin()),
static_cast<const int*>(labels.begin()),
std::move(markers2),
opts
Expand Down

0 comments on commit d158981

Please sign in to comment.