diff --git a/include/singlepp/classify_integrated.hpp b/include/singlepp/classify_integrated.hpp index 0608383..91ec459 100644 --- a/include/singlepp/classify_integrated.hpp +++ b/include/singlepp/classify_integrated.hpp @@ -130,6 +130,9 @@ void classify_integrated( ClassifyIntegratedBuffers& buffers, const ClassifyIntegratedOptions& options) { + if (trained.test_nrow != static_cast(-1) && trained.test_nrow != test.nrow()) { + throw std::runtime_error("number of rows in 'test' is not the same as that used to build 'trained'"); + } internal::annotate_cells_integrated( test, trained, diff --git a/include/singlepp/classify_single.hpp b/include/singlepp/classify_single.hpp index 2d94508..a8160d8 100644 --- a/include/singlepp/classify_single.hpp +++ b/include/singlepp/classify_single.hpp @@ -134,6 +134,9 @@ void classify_single( const ClassifySingleBuffers& buffers, const ClassifySingleOptions& options) { + if (trained.get_test_nrow() != test.nrow()) { + throw std::runtime_error("number of rows in 'test' is not the same as that used to build 'trained'"); + } internal::annotate_cells_single( test, trained.get_subset(), @@ -172,6 +175,9 @@ void classify_single_intersect( const ClassifySingleBuffers& buffers, const ClassifySingleOptions& options) { + if (trained.get_test_nrow() != static_cast(-1) && trained.get_test_nrow() != test.nrow()) { + throw std::runtime_error("number of rows in 'test' is not the same as that used to build 'trained'"); + } internal::annotate_cells_single( test, trained.get_test_subset(), diff --git a/include/singlepp/train_integrated.hpp b/include/singlepp/train_integrated.hpp index f827be8..8b779e4 100644 --- a/include/singlepp/train_integrated.hpp +++ b/include/singlepp/train_integrated.hpp @@ -35,6 +35,8 @@ struct TrainIntegratedInput { /** * @cond */ + Index_ test_nrow; + const tatami::Matrix* ref; const Label_* labels; @@ -76,6 +78,7 @@ TrainIntegratedInput prepare_integrated_input( const TrainedSingle& trained) { TrainIntegratedInput output; + 
output.test_nrow = ref.nrow(); // remember, test and ref are assumed to have the same features. output.ref = &ref; output.labels = labels; @@ -115,25 +118,29 @@ TrainIntegratedInput prepare_integrated_input( * @tparam Label_ Integer type for the reference labels. * @tparam Float_ Floating-point type for the correlations and scores. * + * @param test_nrow Number of features in the test dataset. * @param intersection Vector defining the intersection of genes between the test and reference datasets. * Each pair corresponds to a gene where the first and second elements represent the row indices of that gene in the test and reference matrices, respectively. + * The first element of each pair should be non-negative and less than `test_nrow`, while the second element should be non-negative and less than `ref->nrow()`. * See `intersect_genes()` for more details. * @param ref Matrix containing the reference expression values, where rows are genes and columns are reference profiles. * The number and identity of genes should be consistent with `intersection`. * @param[in] labels An array of length equal to the number of columns of `ref`, containing the label for each sample. * Values should be integers in \f$[0, L)\f$ where \f$L\f$ is the number of unique labels. - * @param trained Classifier created by calling `train_single_intersect()` on `intersection`, `ref` and `labels`. + * @param trained Classifier created by calling `train_single_intersect()` on `test_nrow`, `intersection`, `ref` and `labels`. * * @return An opaque input object for `train_integrated()`. 
*/ template TrainIntegratedInput prepare_integrated_input_intersect( + Index_ test_nrow, const Intersection& intersection, const tatami::Matrix& ref, const Label_* labels, const TrainedSingleIntersect& trained) { TrainIntegratedInput output; + output.test_nrow = test_nrow; output.ref = &ref; output.labels = labels; @@ -166,6 +173,23 @@ TrainIntegratedInput prepare_integrated_input_intersect( return output; } +/** + * @cond + */ +// For back-compatibility only. +template +TrainIntegratedInput prepare_integrated_input_intersect( + const Intersection& intersection, + const tatami::Matrix& ref, + const Label_* labels, + const TrainedSingleIntersect& trained) +{ + return prepare_integrated_input_intersect(-1, intersection, ref, labels, trained); +} +/** + * @endcond + */ + /** * Prepare a reference dataset for `train_integrated()`. * This overload automatically identifies the intersection of genes between the test and reference datasets. @@ -202,7 +226,7 @@ TrainIntegratedInput prepare_integrated_input_intersect( const TrainedSingleIntersect& trained) { auto intersection = intersect_genes(test_nrow, test_id, ref.nrow(), ref_id); - auto output = prepare_integrated_input_intersect(intersection, ref, labels, trained); + auto output = prepare_integrated_input_intersect(test_nrow, intersection, ref, labels, trained); output.user_intersection = NULL; output.auto_intersection.swap(intersection); return output; @@ -248,6 +272,7 @@ class TrainedIntegrated { */ // Technically this should be private, but it's a pain to add // templated friend functions, so I can't be bothered. + Index_ test_nrow; std::vector universe; // To be used by classify_integrated() for indexed extraction. std::vector check_availability; @@ -411,6 +436,16 @@ TrainedIntegrated train_integrated(Inputs_& inputs, const TrainIntegrate output.markers.resize(nrefs); output.ranked.resize(nrefs); + // Checking that the number of genes in the test dataset are consistent. 
+ output.test_nrow = -1; + for (const auto& in : inputs) { + if (output.test_nrow == static_cast(-1)) { + output.test_nrow = in.test_nrow; + } else if (in.test_nrow != static_cast(-1) && in.test_nrow != output.test_nrow) { + throw std::runtime_error("inconsistent number of rows in the test dataset across entries of 'inputs'"); + } + } + // Identify the union of all marker genes. std::unordered_map remap_to_universe; std::unordered_set subset_tmp; diff --git a/include/singlepp/train_single.hpp b/include/singlepp/train_single.hpp index 3030007..79cd276 100644 --- a/include/singlepp/train_single.hpp +++ b/include/singlepp/train_single.hpp @@ -90,9 +90,11 @@ class TrainedSingle { * @cond */ TrainedSingle( + Index_ test_nrow, Markers markers, std::vector subset, std::vector > references) : + my_test_nrow(test_nrow), my_markers(std::move(markers)), my_subset(std::move(subset)), my_references(std::move(references)) @@ -102,11 +104,19 @@ class TrainedSingle { */ private: + Index_ my_test_nrow; Markers my_markers; std::vector my_subset; std::vector > my_references; public: + /** + * @return Number of rows that should be present in the test dataset. + */ + Index_ get_test_nrow() const { + return my_test_nrow; + } + /** * @return A vector of vectors of vectors of ranked marker genes to be used in the classification. * In the innermost vectors, each value is an index into the subset vector (see `get_subset()`), @@ -186,7 +196,8 @@ TrainedSingle train_single( { auto subset = internal::subset_to_markers(markers, options.top); auto subref = internal::build_references(ref, labels, subset, options); - return TrainedSingle(std::move(markers), std::move(subset), std::move(subref)); + Index_ test_nrow = ref.nrow(); // remember, test and ref are assumed to have the same features. 
+ return TrainedSingle(test_nrow, std::move(markers), std::move(subset), std::move(subref)); } /** @@ -206,10 +217,12 @@ class TrainedSingleIntersect { * @cond */ TrainedSingleIntersect( + Index_ test_nrow, Markers markers, std::vector test_subset, std::vector ref_subset, std::vector > references) : + my_test_nrow(test_nrow), my_markers(std::move(markers)), my_test_subset(std::move(test_subset)), my_ref_subset(std::move(ref_subset)), @@ -220,12 +233,20 @@ class TrainedSingleIntersect { */ private: + Index_ my_test_nrow; Markers my_markers; std::vector my_test_subset; std::vector my_ref_subset; std::vector > my_references; public: + /** + * @return Number of rows that should be present in the test dataset. + */ + Index_ get_test_nrow() const { + return my_test_nrow; + } + /** * @return A vector of vectors of ranked marker genes to be used in the classification. * In the innermost vectors, each value is an index into the subset vectors (see `get_test_subset()` and `get_ref_subset()`). @@ -295,8 +316,10 @@ class TrainedSingleIntersect { * @tparam Label_ Integer type for the reference labels. * @tparam Float_ Floating-point type for the correlations and scores. * + * @param test_nrow Number of features in the test dataset. * @param intersection Vector defining the intersection of genes between the test and reference datasets. * Each pair corresponds to a gene where the first and second elements represent the row indices of that gene in the test and reference matrices, respectively. + * The first element of each pair should be non-negative and less than `test_nrow`, while the second element should be non-negative and less than `ref->nrow()`. * See `intersect_genes()` for more details. * @param ref An expression matrix for the reference expression profiles, where rows are genes and columns are cells. * This should have non-zero columns. 
@@ -309,6 +332,7 @@ class TrainedSingleIntersect { */ template TrainedSingleIntersect train_single_intersect( + Index_ test_nrow, const Intersection& intersection, const tatami::Matrix& ref, const Label_* labels, @@ -317,9 +341,27 @@ TrainedSingleIntersect train_single_intersect( { auto pairs = internal::subset_to_markers(intersection, markers, options.top); auto subref = internal::build_references(ref, labels, pairs.second, options); - return TrainedSingleIntersect(std::move(markers), std::move(pairs.first), std::move(pairs.second), std::move(subref)); + return TrainedSingleIntersect(test_nrow, std::move(markers), std::move(pairs.first), std::move(pairs.second), std::move(subref)); } +/** + * @cond + */ +// For back-compatibility only. +template +TrainedSingleIntersect train_single_intersect( + const Intersection& intersection, + const tatami::Matrix& ref, + const Label_* labels, + Markers markers, + const TrainSingleOptions& options) +{ + return train_single_intersect(-1, intersection, ref, labels, std::move(markers), options); +} +/** + * @endcond + */ + /** * Variant of `train_single()` that uses the intersection of genes between the reference dataset and a (future) test dataset. * This is useful when the genes are not in the same order and number across the test and reference datasets. 
@@ -359,7 +401,7 @@ TrainedSingleIntersect train_single_intersect( const TrainSingleOptions& options) { auto intersection = intersect_genes(test_nrow, test_id, ref.nrow(), ref_id); - return train_single_intersect(intersection, ref, labels, std::move(markers), options); + return train_single_intersect(test_nrow, intersection, ref, labels, std::move(markers), options); } } diff --git a/tests/src/classify_integrated.cpp b/tests/src/classify_integrated.cpp index 39da536..031ccfe 100644 --- a/tests/src/classify_integrated.cpp +++ b/tests/src/classify_integrated.cpp @@ -54,17 +54,6 @@ class IntegratedTestCore { static std::vector simulate_ref_ids(size_t ngenes, int seed) { return simulate_ids(ngenes, seed, MISSING_REF_ID); // -1 => unique to the reference. } - - static singlepp::Markers truncate_markers(singlepp::Markers remarkers, int ntop) { - for (auto& x : remarkers) { - for (auto& y : x) { - if (y.size() > static_cast(ntop)) { - y.resize(ntop); - } - } - } - return remarkers; - } }; /********************************************/ @@ -286,6 +275,84 @@ INSTANTIATE_TEST_SUITE_P( /********************************************/ +class TrainIntegratedMismatchTest : public ::testing::Test, public IntegratedTestCore { +protected: + inline static std::vector > > sub_references; + + void SetUp() { + assemble(); + for (size_t r = 0; r < nrefs; ++r) { + sub_references.emplace_back(new tatami::DelayedSubsetBlock(references[r], r, ngenes - r, true)); + } + } +}; + +TEST_F(TrainIntegratedMismatchTest, Simple) { + singlepp::TrainSingleOptions bopt; + std::vector > inputs; + for (size_t r = 0; r < nrefs; ++r) { + const auto& sref = *(sub_references[r]); + auto pre = singlepp::train_single(sref, labels[r].data(), markers[r], bopt); + inputs.push_back(singlepp::prepare_integrated_input(sref, labels[r].data(), pre)); + } + + bool failed = false; + singlepp::TrainIntegratedOptions iopt; + try { + singlepp::train_integrated(std::move(inputs), iopt); + } catch (std::exception& e) { + 
EXPECT_TRUE(std::string(e.what()).find("inconsistent number of rows") != std::string::npos); + failed = true; + } + EXPECT_TRUE(failed); +} + +TEST_F(TrainIntegratedMismatchTest, Intersect) { + singlepp::TrainSingleOptions bopt; + std::vector > single_ref; + std::vector > inputs; + + int base_seed = 6969; + for (size_t r = 0; r < nrefs; ++r) { + const auto& srefmat = *sub_references[r]; + size_t curgenes = srefmat.nrow(); + auto test_ids = simulate_test_ids(curgenes, base_seed * 10 + r); + auto ref_ids = simulate_ref_ids(curgenes, base_seed * 20 + r); + auto labptr = labels[r].data(); + auto pre = singlepp::train_single_intersect(test_ids.size(), test_ids.data(), srefmat, ref_ids.data(), labptr, markers[r], bopt); + inputs.push_back(singlepp::prepare_integrated_input_intersect(curgenes, test_ids.data(), srefmat, ref_ids.data(), labptr, pre)); + single_ref.push_back(std::move(pre)); + } + + bool failed = false; + singlepp::TrainIntegratedOptions iopt; + try { + singlepp::train_integrated(std::move(inputs), iopt); + } catch (std::exception& e) { + EXPECT_TRUE(std::string(e.what()).find("inconsistent number of rows") != std::string::npos); + failed = true; + } + EXPECT_TRUE(failed); +} + +/********************************************/ + +template +static std::vector > mock_best_choices(size_t ntest, const std::vector& prebuilts, size_t seed) { + size_t nrefs = prebuilts.size(); + std::vector > chosen(nrefs); + + std::mt19937_64 rng(seed); + for (size_t r = 0; r < nrefs; ++r) { + size_t nlabels = prebuilts[r].get_markers().size(); + for (size_t t = 0; t < ntest; ++t) { + chosen[r].push_back(rng() % nlabels); + } + } + + return chosen; +} + class ClassifyIntegratedTest : public ::testing::TestWithParam >, public IntegratedTestCore { protected: static void SetUpTestSuite() { @@ -297,22 +364,6 @@ class ClassifyIntegratedTest : public ::testing::TestWithParam > test; protected: - template - static std::vector > mock_best_choices(size_t ntest, const std::vector& prebuilts, 
size_t seed) { - size_t nrefs = prebuilts.size(); - std::vector > chosen(nrefs); - - std::mt19937_64 rng(seed); - for (size_t r = 0; r < nrefs; ++r) { - size_t nlabels = prebuilts[r].get_markers().size(); - for (size_t t = 0; t < ntest; ++t) { - chosen[r].push_back(rng() % nlabels); - } - } - - return chosen; - } - static auto split_by_labels(const std::vector >& labels) { std::vector > > by_labels(labels.size()); for (size_t r = 0, nrefs = labels.size(); r < nrefs; ++r) { @@ -526,6 +577,30 @@ TEST_P(ClassifyIntegratedTest, Intersected) { EXPECT_EQ(output.scores[r], poutput.scores[r]); } } + + // Back-compatibility check: builds the inputs via the legacy overload that takes no 'test_nrow'. + { + std::vector > inputs; + std::vector > intersections; + intersections.reserve(nrefs); // no reallocation below; the inputs presumably keep a pointer to their intersection. + auto idptr = test_ids.data(); + for (size_t r = 0; r < nrefs; ++r) { + auto refptr = ref_ids[r].data(); + auto labptr = labels[r].data(); + const auto& refmat = *references[r]; + intersections.push_back(singlepp::intersect_genes(test_ids.size(), idptr, refmat.nrow(), refptr)); + inputs.push_back(singlepp::prepare_integrated_input_intersect(intersections.back(), refmat, labptr, prebuilts[r])); + } + + singlepp::TrainIntegratedOptions iopt; + auto integrated = singlepp::train_integrated(std::move(inputs), iopt); // train on the legacy inputs built above. + auto boutput = singlepp::classify_integrated(*test, chosen_ptrs, integrated, copt); + EXPECT_EQ(boutput.best, output.best); + EXPECT_EQ(boutput.delta, output.delta); + for (size_t r = 0; r < nrefs; ++r) { + EXPECT_EQ(boutput.scores[r], output.scores[r]); + } + } } TEST_P(ClassifyIntegratedTest, IntersectedComparison) { @@ -620,3 +695,49 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(0.5, 0.8, 0.9) // number of quantiles.
) ); + +/********************************************/ + +class ClassifyIntegratedMismatchTest : public ::testing::Test, public IntegratedTestCore { +protected: + static void SetUpTestSuite() { + assemble(); + } +}; + +TEST_F(ClassifyIntegratedMismatchTest, Basic) { + singlepp::TrainSingleOptions bopt; + std::vector > prebuilts; + std::vector > integrated_inputs; + + for (size_t r = 0; r < nrefs; ++r) { + auto pre = singlepp::train_single(*(references[r]), labels[r].data(), markers[r], bopt); + prebuilts.push_back(std::move(pre)); + integrated_inputs.push_back(singlepp::prepare_integrated_input(*(references[r]), labels[r].data(), prebuilts.back())); + } + + singlepp::TrainIntegratedOptions iopt; + auto integrated = singlepp::train_integrated(std::move(integrated_inputs), iopt); + + // Mocking up the test dataset and its choices. + size_t ntest = 20; + auto test = spawn_matrix(ngenes * 2, ntest, 69); // more genes than expected. + + int base_seed = 70; + auto chosen = mock_best_choices(ntest, prebuilts, /* seed = */ base_seed); + std::vector chosen_ptrs(nrefs); + for (size_t r = 0; r < nrefs; ++r) { + chosen_ptrs[r] = chosen[r].data(); + } + + // Verifying that it does, in fact, fail. 
+ singlepp::ClassifyIntegratedOptions copt; + bool failed = false; + try { + singlepp::classify_integrated(*test, chosen_ptrs, integrated, copt); + } catch (std::exception& e) { + EXPECT_TRUE(std::string(e.what()).find("number of rows") != std::string::npos); + failed = true; + } + EXPECT_TRUE(failed); +} diff --git a/tests/src/classify_single.cpp b/tests/src/classify_single.cpp index 14934e5..98d696b 100644 --- a/tests/src/classify_single.cpp +++ b/tests/src/classify_single.cpp @@ -33,6 +33,7 @@ TEST_P(ClassifySingleSimpleTest, Simple) { singlepp::TrainSingleOptions bopt; bopt.top = top; auto trained = singlepp::train_single(*refs, labels.data(), markers, bopt); + EXPECT_EQ(trained.get_test_nrow(), 200); singlepp::ClassifySingleOptions copt; copt.fine_tune = false; @@ -115,6 +116,9 @@ TEST_P(ClassifySingleIntersectTest, Intersect) { singlepp::TrainSingleOptions bopt; bopt.top = top; auto trained = singlepp::train_single_intersect(left.size(), left.data(), *refs, right.data(), labels.data(), markers, bopt); + EXPECT_EQ(trained.get_test_nrow(), left.size()); + EXPECT_EQ(trained.get_ref_subset().size(), trained.get_test_subset().size()); + EXPECT_GE(trained.get_test_subset().size(), 10); // should be, on average, 'ngenes * prop^2' overlapping genes. 
singlepp::ClassifySingleOptions copt; copt.quantile = quantile; @@ -166,6 +170,20 @@ TEST_P(ClassifySingleIntersectTest, Intersect) { { auto intersection = singlepp::intersect_genes(left.size(), left.data(), right.size(), right.data()); std::shuffle(intersection.begin(), intersection.end(), rng); + auto trained2 = singlepp::train_single_intersect(left.size(), intersection, *refs, labels.data(), markers, bopt); + + singlepp::ClassifySingleOptions copt; + copt.quantile = quantile; + auto result2 = singlepp::classify_single_intersect(*mat, trained, copt); + + EXPECT_EQ(result2.scores[0], result.scores[0]); + EXPECT_EQ(result2.best, result.best); + EXPECT_EQ(result2.delta, result.delta); + } + + // Back-compatibility check for the soft-deprecated intersection method. + { + auto intersection = singlepp::intersect_genes(left.size(), left.data(), right.size(), right.data()); auto trained2 = singlepp::train_single_intersect(intersection, *refs, labels.data(), markers, bopt); singlepp::ClassifySingleOptions copt; @@ -272,3 +290,57 @@ TEST(ClassifySingleTest, Nulls) { EXPECT_EQ(best, full.best); } + +TEST(ClassifySingleTest, SimpleMismatch) { + size_t ngenes = 200; + size_t nlabels = 3; + size_t nrefs = 50; + + auto refs = spawn_matrix(ngenes, nrefs, 100); + auto labels = spawn_labels(nrefs, nlabels, 1000); + auto markers = mock_markers(nlabels, 50, ngenes); + + singlepp::TrainSingleOptions bopt; + auto trained = singlepp::train_single(*refs, labels.data(), markers, bopt); + + auto test = spawn_matrix(ngenes + 10, nrefs, 100); + singlepp::ClassifySingleOptions copt; + copt.quantile = 1; + + bool failed = false; + try { + singlepp::classify_single(*test, trained, copt); + } catch (std::exception& e) { + failed = true; + EXPECT_TRUE(std::string(e.what()).find("number of rows") != std::string::npos); + } + EXPECT_TRUE(failed); +} + +TEST(ClassifySingleTest, IntersectMismatch) { + size_t ngenes = 200; + size_t nlabels = 3; + size_t nrefs = 50; + + auto refs = 
spawn_matrix(ngenes, nrefs, 100); + auto labels = spawn_labels(nrefs, nlabels, 1000); + auto markers = mock_markers(nlabels, 50, ngenes); + + std::vector ids(ngenes); + std::iota(ids.begin(), ids.end(), 0); + singlepp::TrainSingleOptions bopt; + auto trained = singlepp::train_single_intersect(ngenes, ids.data(), *refs, ids.data(), labels.data(), markers, bopt); + + auto test = spawn_matrix(ngenes + 10, nrefs, 100); + singlepp::ClassifySingleOptions copt; + copt.quantile = 1; + + bool failed = false; + try { + singlepp::classify_single_intersect(*test, trained, copt); + } catch (std::exception& e) { + failed = true; + EXPECT_TRUE(std::string(e.what()).find("number of rows") != std::string::npos); + } + EXPECT_TRUE(failed); +}