From d1589817547c86c50e3a9c0b61890813df1791c9 Mon Sep 17 00:00:00 2001 From: LTLA Date: Thu, 10 Oct 2024 10:34:50 -0700 Subject: [PATCH] Perform the intersection with strings on the R side. This is more natural given that we're already doing match()'ing to translate the strings to integer indices; so we might as well do the translation on the intersection of gene names. --- DESCRIPTION | 4 ++-- R/RcppExports.R | 4 ++-- R/combineRecomputedResults.R | 13 ++++++++++--- R/trainSingleR.R | 6 +++--- R/utils.R | 13 +++++++++++++ src/RcppExports.cpp | 10 +++++----- src/train_integrated.cpp | 26 +++++++++++++++++--------- src/train_single.cpp | 25 ++++++++++++++++++++----- 8 files changed, 72 insertions(+), 29 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 0ecfe25..b1e26b0 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,7 +1,7 @@ Package: SingleR Title: Reference-Based Single-Cell RNA-Seq Annotation -Version: 2.7.5 -Date: 2024-10-09 +Version: 2.7.6 +Date: 2024-10-10 Authors@R: c(person("Dvir", "Aran", email="dvir.aran@ucsf.edu", role=c("aut", "cph")), person("Aaron", "Lun", email="infinite.monkeys.with.keyboards@gmail.com", role=c("ctb", "cre")), person("Daniel", "Bunis", role="ctb"), diff --git a/R/RcppExports.R b/R/RcppExports.R index 77266b6..077b5bb 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -19,8 +19,8 @@ grouped_medians <- function(ref, groups, ngroups, nthreads) { .Call('_SingleR_grouped_medians', PACKAGE = 'SingleR', ref, groups, ngroups, nthreads) } -train_integrated <- function(test_features, references, ref_ids, labels, prebuilt, nthreads) { - .Call('_SingleR_train_integrated', PACKAGE = 'SingleR', test_features, references, ref_ids, labels, prebuilt, nthreads) +train_integrated <- function(test_features, references, ref_features, labels, prebuilt, nthreads) { + .Call('_SingleR_train_integrated', PACKAGE = 'SingleR', test_features, references, ref_features, labels, prebuilt, nthreads) } #' @importFrom Rcpp sourceCpp diff --git a/R/combineRecomputedResults.R b/R/combineRecomputedResults.R index 512da13..983a2ad 100644 --- a/R/combineRecomputedResults.R +++ b/R/combineRecomputedResults.R @@ -149,12 +149,19 @@ combineRecomputedResults <- function( } } + all.inter.test <- all.inter.ref <- vector("list", length(trained)) + test.genes <- rownames(test) + for (i in seq_along(all.refnames)) { + inter <- .create_intersection(test.genes, all.refnames[[i]]) + all.inter.test[[i]] <- inter$test - 1L + all.inter.ref[[i]] <- inter$reference - 1L + } + # Applying the integration. - universe <- Reduce(union, c(list(rownames(test)), all.refnames)) ibuilt <- train_integrated( - test_features=match(rownames(test), universe) - 1L, + test_features=all.inter.test, references=lapply(trained, function(x) initializeCpp(x$ref)), - ref_ids=lapply(all.refnames, function(x) match(x, universe) - 1L), + ref_features=all.inter.ref, labels=lapply(trained, function(x) match(x$labels$full, x$labels$unique) - 1L), prebuilt=lapply(trained, function(x) rebuildIndex(x)$built), nthreads = num.threads diff --git a/R/trainSingleR.R b/R/trainSingleR.R index ad243fc..8f0b9dd 100644 --- a/R/trainSingleR.R +++ b/R/trainSingleR.R @@ -340,9 +340,9 @@ trainSingleR <- function( if (is.null(test.genes)) { test.genes <- ref.genes <- seq_len(nrow(ref)) } else { - universe <- union(test.genes, rownames(ref)) - test.genes <- match(test.genes, universe) - ref.genes <- match(rownames(ref), universe) + intersection <- .create_intersection(test.genes, rownames(ref)) + test.genes <- intersection$test + ref.genes <- intersection$reference } builder <- defineBuilder(BNPARAM) diff --git a/R/utils.R b/R/utils.R index 9ae0a49..ad1573f 100644 --- a/R/utils.R +++ b/R/utils.R @@ -35,6 +35,19 @@ x } +.create_intersection <- function(test, reference) { + # Effectively an NA-safe intersect() that preserves ordering in 'test'. + common <- test[test %in% reference] + common <- common[!is.na(common)] + common <- common[!duplicated(common)] + + # match() takes the first occurrence, consistent with internal behavior in singlepp. + list( + test = match(common, test), + reference = match(common, reference) + ) +} + #' @importFrom methods is #' @importClassesFrom S4Vectors List .is_list <- function(val) { diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 4382809..6f4a85a 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -70,17 +70,17 @@ BEGIN_RCPP END_RCPP } // train_integrated -SEXP train_integrated(Rcpp::IntegerVector test_features, Rcpp::List references, Rcpp::List ref_ids, Rcpp::List labels, Rcpp::List prebuilt, int nthreads); -RcppExport SEXP _SingleR_train_integrated(SEXP test_featuresSEXP, SEXP referencesSEXP, SEXP ref_idsSEXP, SEXP labelsSEXP, SEXP prebuiltSEXP, SEXP nthreadsSEXP) { +SEXP train_integrated(Rcpp::List test_features, Rcpp::List references, Rcpp::List ref_features, Rcpp::List labels, Rcpp::List prebuilt, int nthreads); +RcppExport SEXP _SingleR_train_integrated(SEXP test_featuresSEXP, SEXP referencesSEXP, SEXP ref_featuresSEXP, SEXP labelsSEXP, SEXP prebuiltSEXP, SEXP nthreadsSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; - Rcpp::traits::input_parameter< Rcpp::IntegerVector >::type test_features(test_featuresSEXP); + Rcpp::traits::input_parameter< Rcpp::List >::type test_features(test_featuresSEXP); Rcpp::traits::input_parameter< Rcpp::List >::type references(referencesSEXP); - Rcpp::traits::input_parameter< Rcpp::List >::type ref_ids(ref_idsSEXP); + Rcpp::traits::input_parameter< Rcpp::List >::type ref_features(ref_featuresSEXP); Rcpp::traits::input_parameter< Rcpp::List >::type labels(labelsSEXP); Rcpp::traits::input_parameter< Rcpp::List >::type prebuilt(prebuiltSEXP); Rcpp::traits::input_parameter< int >::type nthreads(nthreadsSEXP); - rcpp_result_gen = Rcpp::wrap(train_integrated(test_features, references, ref_ids, labels, prebuilt, nthreads)); + rcpp_result_gen = Rcpp::wrap(train_integrated(test_features, references, ref_features, labels, prebuilt, nthreads)); return rcpp_result_gen; END_RCPP } diff --git a/src/train_integrated.cpp b/src/train_integrated.cpp index 0705b33..259d9c5 100644 --- a/src/train_integrated.cpp +++ b/src/train_integrated.cpp @@ -4,29 +4,37 @@ #include //[[Rcpp::export(rng=false)]] -SEXP train_integrated(Rcpp::IntegerVector test_features, Rcpp::List references, Rcpp::List ref_ids, Rcpp::List labels, Rcpp::List prebuilt, int nthreads) { +SEXP train_integrated(Rcpp::List test_features, Rcpp::List references, Rcpp::List ref_features, Rcpp::List labels, Rcpp::List prebuilt, int nthreads) { size_t nrefs = references.size(); std::vector > inputs; inputs.reserve(nrefs); - std::vector holding_labs; - holding_labs.reserve(nrefs); + std::vector > intersections(nrefs); + std::vector holding_labs(nrefs); for (size_t r = 0; r < nrefs; ++r) { Rcpp::RObject curref(references[r]); Rtatami::BoundNumericPointer parsed(curref); - Rcpp::IntegerVector curids(ref_ids[r]); - holding_labs.emplace_back(labels[r]); + Rcpp::IntegerVector test_ids(test_features[r]); + Rcpp::IntegerVector ref_ids(ref_features[r]); + size_t ninter = test_ids.size(); + if (ninter != static_cast(ref_ids.size())) { + throw std::runtime_error("length of each entry of 'test_features' and 'ref_features' should be the same"); + } + auto& curinter = intersections[r]; + for (size_t i = 0; i < ninter; ++i) { + curinter.emplace_back(test_ids[i], ref_ids[i]); + } + + holding_labs[r] = labels[r]; Rcpp::RObject built = prebuilt[r]; TrainedSingleIntersectPointer curbuilt(built); inputs.push_back(singlepp::prepare_integrated_input_intersect( - static_cast(test_features.size()), - static_cast(test_features.begin()), + curinter, *(parsed->ptr), - static_cast(curids.begin()), - static_cast(holding_labs.back().begin()), + static_cast(holding_labs[r].begin()), *curbuilt )); } diff --git a/src/train_single.cpp b/src/train_single.cpp index e2d0486..4fba77f 100644 --- a/src/train_single.cpp +++ b/src/train_single.cpp @@ -15,6 +15,13 @@ SEXP train_single(Rcpp::IntegerVector test_features, Rcpp::RObject ref, Rcpp::In BiocNeighbors::BuilderPointer bptr(builder); opts.trainer = std::shared_ptr(std::shared_ptr{}, bptr.get()); // make a no-op shared pointer. + Rtatami::BoundNumericPointer parsed(ref); + int NR = parsed->ptr->nrow(); + int NC = parsed->ptr->ncol(); + if (static_cast(labels.size()) != NC) { + throw std::runtime_error("length of 'labels' is equal to the number of columns of 'ref'"); + } + // Setting up the markers. We assume that these are already 0-indexed on the R side. size_t ngroups = markers.size(); singlepp::Markers markers2(ngroups); @@ -26,18 +33,26 @@ SEXP train_single(Rcpp::IntegerVector test_features, Rcpp::RObject ref, Rcpp::In for (size_t n = 0; n < inner_ngroups; ++n) { Rcpp::IntegerVector seq(curmarkers[n]); - auto& seq2 = curmarkers2[n]; + auto& seq2 = curmarkers2[n]; seq2.insert(seq2.end(), seq.begin(), seq.end()); } } + // Preparing the features. + size_t ninter = test_features.size(); + if (ninter != static_cast(ref_features.size())) { + throw std::runtime_error("length of 'test_features' and 'ref_features' should be the same"); + } + singlepp::Intersection inter; + inter.reserve(test_features.size()); + for (size_t i = 0; i < ninter; ++i) { + inter.emplace_back(test_features[i], ref_features[i]); + } + // Building the indices. - Rtatami::BoundNumericPointer parsed(ref); auto built = singlepp::train_single_intersect( - static_cast(test_features.size()), - static_cast(test_features.begin()), + inter, *(parsed->ptr), - static_cast(ref_features.begin()), static_cast(labels.begin()), std::move(markers2), opts