Skip to content

Commit

Permalink
Store the expected number of test rows in the Trained* objects.
Browse files Browse the repository at this point in the history
All classify_* functions now check that the test dataset has the expected
number of rows, based on the arguments provided to the training functions.
This provides an extra safeguard against the use of the wrong test dataset for
a particular Trained object. We also check that all inputs for integrated
training are expecting the same number of rows for the test dataset.

A consequence of this change is that the various function overloads that accept
an intersection now also need to be given the expected number of rows in the
test dataset, so that it can be stored. The existing overloads have been
soft-deprecated, and new overloads that take this additional argument have been added.
  • Loading branch information
LTLA committed Dec 26, 2024
1 parent 5c14a6c commit c0447ce
Show file tree
Hide file tree
Showing 6 changed files with 310 additions and 32 deletions.
3 changes: 3 additions & 0 deletions include/singlepp/classify_integrated.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ void classify_integrated(
ClassifyIntegratedBuffers<RefLabel_, Float_>& buffers,
const ClassifyIntegratedOptions<Float_>& options)
{
if (trained.test_nrow != static_cast<Index_>(-1) && trained.test_nrow != test.nrow()) {
throw std::runtime_error("number of rows in 'test' is not the same as that used to build 'trained'");
}
internal::annotate_cells_integrated(
test,
trained,
Expand Down
6 changes: 6 additions & 0 deletions include/singlepp/classify_single.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ void classify_single(
const ClassifySingleBuffers<Label_, Float_>& buffers,
const ClassifySingleOptions<Float_>& options)
{
if (trained.get_test_nrow() != test.nrow()) {
throw std::runtime_error("number of rows in 'test' is not the same as that used to build 'trained'");
}
internal::annotate_cells_single(
test,
trained.get_subset(),
Expand Down Expand Up @@ -172,6 +175,9 @@ void classify_single_intersect(
const ClassifySingleBuffers<Label_, Float_>& buffers,
const ClassifySingleOptions<Float_>& options)
{
if (trained.get_test_nrow() != static_cast<Index_>(-1) && trained.get_test_nrow() != test.nrow()) {
throw std::runtime_error("number of rows in 'test' is not the same as that used to build 'trained'");
}
internal::annotate_cells_single(
test,
trained.get_test_subset(),
Expand Down
39 changes: 37 additions & 2 deletions include/singlepp/train_integrated.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ struct TrainIntegratedInput {
/**
* @cond
*/
Index_ test_nrow;

const tatami::Matrix<Value_, Index_>* ref;

const Label_* labels;
Expand Down Expand Up @@ -76,6 +78,7 @@ TrainIntegratedInput<Value_, Index_, Label_> prepare_integrated_input(
const TrainedSingle<Index_, Float_>& trained)
{
TrainIntegratedInput<Value_, Index_, Label_> output;
output.test_nrow = ref.nrow(); // remember, test and ref are assumed to have the same features.
output.ref = &ref;
output.labels = labels;

Expand Down Expand Up @@ -115,25 +118,29 @@ TrainIntegratedInput<Value_, Index_, Label_> prepare_integrated_input(
* @tparam Label_ Integer type for the reference labels.
* @tparam Float_ Floating-point type for the correlations and scores.
*
* @param test_nrow Number of features in the test dataset.
* @param intersection Vector defining the intersection of genes between the test and reference datasets.
* Each pair corresponds to a gene where the first and second elements represent the row indices of that gene in the test and reference matrices, respectively.
* The first element of each pair should be non-negative and less than `test_nrow`, while the second element should be non-negative and less than `ref.nrow()`.
* See `intersect_genes()` for more details.
* @param ref Matrix containing the reference expression values, where rows are genes and columns are reference profiles.
* The number and identity of genes should be consistent with `intersection`.
* @param[in] labels An array of length equal to the number of columns of `ref`, containing the label for each sample.
* Values should be integers in \f$[0, L)\f$ where \f$L\f$ is the number of unique labels.
* @param trained Classifier created by calling `train_single_intersect()` on `intersection`, `ref` and `labels`.
* @param trained Classifier created by calling `train_single_intersect()` on `test_nrow`, `intersection`, `ref` and `labels`.
*
* @return An opaque input object for `train_integrated()`.
*/
template<typename Index_, typename Value_, typename Label_, typename Float_>
TrainIntegratedInput<Value_, Index_, Label_> prepare_integrated_input_intersect(
Index_ test_nrow,
const Intersection<Index_>& intersection,
const tatami::Matrix<Value_, Index_>& ref,
const Label_* labels,
const TrainedSingleIntersect<Index_, Float_>& trained)
{
TrainIntegratedInput<Value_, Index_, Label_> output;
output.test_nrow = test_nrow;
output.ref = &ref;
output.labels = labels;

Expand Down Expand Up @@ -166,6 +173,23 @@ TrainIntegratedInput<Value_, Index_, Label_> prepare_integrated_input_intersect(
return output;
}

/**
* @cond
*/
// For back-compatibility only: soft-deprecated overload that does not accept 'test_nrow'.
// The expected number of rows in the test dataset is unknown here, so we forward the
// sentinel static_cast<Index_>(-1). Both train_integrated() and classify_integrated()
// treat this value as "unknown" and skip their row-count checks for it.
// The cast is explicit to match how the sentinel is spelled in those checks and to
// avoid an implicit sign conversion when Index_ is unsigned.
template<typename Index_, typename Value_, typename Label_, typename Float_>
TrainIntegratedInput<Value_, Index_, Label_> prepare_integrated_input_intersect(
    const Intersection<Index_>& intersection,
    const tatami::Matrix<Value_, Index_>& ref,
    const Label_* labels,
    const TrainedSingleIntersect<Index_, Float_>& trained)
{
    return prepare_integrated_input_intersect<Index_, Value_, Label_, Float_>(
        static_cast<Index_>(-1), // sentinel: number of test rows not known.
        intersection,
        ref,
        labels,
        trained
    );
}
/**
* @endcond
*/

/**
* Prepare a reference dataset for `train_integrated()`.
* This overload automatically identifies the intersection of genes between the test and reference datasets.
Expand Down Expand Up @@ -202,7 +226,7 @@ TrainIntegratedInput<Value_, Index_, Label_> prepare_integrated_input_intersect(
const TrainedSingleIntersect<Index_, Float_>& trained)
{
auto intersection = intersect_genes(test_nrow, test_id, ref.nrow(), ref_id);
auto output = prepare_integrated_input_intersect(intersection, ref, labels, trained);
auto output = prepare_integrated_input_intersect(test_nrow, intersection, ref, labels, trained);
output.user_intersection = NULL;
output.auto_intersection.swap(intersection);
return output;
Expand Down Expand Up @@ -248,6 +272,7 @@ class TrainedIntegrated {
*/
// Technically this should be private, but it's a pain to add
// templated friend functions, so I can't be bothered.
Index_ test_nrow;
std::vector<Index_> universe; // To be used by classify_integrated() for indexed extraction.

std::vector<uint8_t> check_availability;
Expand Down Expand Up @@ -411,6 +436,16 @@ TrainedIntegrated<Index_> train_integrated(Inputs_& inputs, const TrainIntegrate
output.markers.resize(nrefs);
output.ranked.resize(nrefs);

// Check that the number of rows in the test dataset is consistent across all inputs.
// static_cast<Index_>(-1) marks an input whose expected row count is unknown;
// such inputs are ignored, and the first known count becomes the reference value.
output.test_nrow = static_cast<Index_>(-1);
for (const auto& current : inputs) {
    if (current.test_nrow == static_cast<Index_>(-1)) {
        continue; // nothing to compare or record for this input.
    }
    if (output.test_nrow == static_cast<Index_>(-1)) {
        output.test_nrow = current.test_nrow;
    } else if (current.test_nrow != output.test_nrow) {
        throw std::runtime_error("inconsistent number of rows in the test dataset across entries of 'inputs'");
    }
}

// Identify the union of all marker genes.
std::unordered_map<Index_, Index_> remap_to_universe;
std::unordered_set<Index_> subset_tmp;
Expand Down
48 changes: 45 additions & 3 deletions include/singlepp/train_single.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,11 @@ class TrainedSingle {
* @cond
*/
TrainedSingle(
Index_ test_nrow,
Markers<Index_> markers,
std::vector<Index_> subset,
std::vector<internal::PerLabelReference<Index_, Float_> > references) :
my_test_nrow(test_nrow),
my_markers(std::move(markers)),
my_subset(std::move(subset)),
my_references(std::move(references))
Expand All @@ -102,11 +104,19 @@ class TrainedSingle {
*/

private:
Index_ my_test_nrow;
Markers<Index_> my_markers;
std::vector<Index_> my_subset;
std::vector<internal::PerLabelReference<Index_, Float_> > my_references;

public:
/**
* @return Number of rows that should be present in the test dataset.
* `train_single()` sets this to the number of rows of the reference matrix,
* as the test and reference datasets are assumed to have the same features.
* `classify_single()` throws a `std::runtime_error` if the test matrix has a different number of rows.
*/
Index_ get_test_nrow() const {
return my_test_nrow;
}

/**
* @return A vector of vectors of vectors of ranked marker genes to be used in the classification.
* In the innermost vectors, each value is an index into the subset vector (see `get_subset()`),
Expand Down Expand Up @@ -186,7 +196,8 @@ TrainedSingle<Index_, Float_> train_single(
{
auto subset = internal::subset_to_markers(markers, options.top);
auto subref = internal::build_references(ref, labels, subset, options);
return TrainedSingle<Index_, Float_>(std::move(markers), std::move(subset), std::move(subref));
Index_ test_nrow = ref.nrow(); // remember, test and ref are assumed to have the same features.
return TrainedSingle<Index_, Float_>(test_nrow, std::move(markers), std::move(subset), std::move(subref));
}

/**
Expand All @@ -206,10 +217,12 @@ class TrainedSingleIntersect {
* @cond
*/
TrainedSingleIntersect(
Index_ test_nrow,
Markers<Index_> markers,
std::vector<Index_> test_subset,
std::vector<Index_> ref_subset,
std::vector<internal::PerLabelReference<Index_, Float_> > references) :
my_test_nrow(test_nrow),
my_markers(std::move(markers)),
my_test_subset(std::move(test_subset)),
my_ref_subset(std::move(ref_subset)),
Expand All @@ -220,12 +233,20 @@ class TrainedSingleIntersect {
*/

private:
Index_ my_test_nrow;
Markers<Index_> my_markers;
std::vector<Index_> my_test_subset;
std::vector<Index_> my_ref_subset;
std::vector<internal::PerLabelReference<Index_, Float_> > my_references;

public:
/**
* @return Number of rows that should be present in the test dataset,
* as supplied via the `test_nrow` argument of `train_single_intersect()`.
* This is `static_cast<Index_>(-1)` if the classifier was built with the deprecated overload that does not accept `test_nrow`,
* in which case `classify_single_intersect()` does not check the number of rows in the test matrix.
*/
Index_ get_test_nrow() const {
return my_test_nrow;
}

/**
* @return A vector of vectors of ranked marker genes to be used in the classification.
* In the innermost vectors, each value is an index into the subset vectors (see `get_test_subset()` and `get_ref_subset()`).
Expand Down Expand Up @@ -295,8 +316,10 @@ class TrainedSingleIntersect {
* @tparam Label_ Integer type for the reference labels.
* @tparam Float_ Floating-point type for the correlations and scores.
*
* @param test_nrow Number of features in the test dataset.
* @param intersection Vector defining the intersection of genes between the test and reference datasets.
* Each pair corresponds to a gene where the first and second elements represent the row indices of that gene in the test and reference matrices, respectively.
* The first element of each pair should be non-negative and less than `test_nrow`, while the second element should be non-negative and less than `ref.nrow()`.
* See `intersect_genes()` for more details.
* @param ref An expression matrix for the reference expression profiles, where rows are genes and columns are cells.
* This should have non-zero columns.
Expand All @@ -309,6 +332,7 @@ class TrainedSingleIntersect {
*/
template<typename Index_, typename Value_, typename Label_, typename Float_>
TrainedSingleIntersect<Index_, Float_> train_single_intersect(
Index_ test_nrow,
const Intersection<Index_>& intersection,
const tatami::Matrix<Value_, Index_>& ref,
const Label_* labels,
Expand All @@ -317,9 +341,27 @@ TrainedSingleIntersect<Index_, Float_> train_single_intersect(
{
auto pairs = internal::subset_to_markers(intersection, markers, options.top);
auto subref = internal::build_references(ref, labels, pairs.second, options);
return TrainedSingleIntersect<Index_, Float_>(std::move(markers), std::move(pairs.first), std::move(pairs.second), std::move(subref));
return TrainedSingleIntersect<Index_, Float_>(test_nrow, std::move(markers), std::move(pairs.first), std::move(pairs.second), std::move(subref));
}

/**
* @cond
*/
// For back-compatibility only: soft-deprecated overload that does not accept 'test_nrow'.
// The expected number of rows in the test dataset is unknown here, so we forward the
// sentinel static_cast<Index_>(-1), which is then stored in the returned object.
// classify_single_intersect() treats this value as "unknown" and skips its row-count check.
// The cast is explicit to match how the sentinel is spelled in that check and to
// avoid an implicit sign conversion when Index_ is unsigned.
template<typename Index_, typename Value_, typename Label_, typename Float_>
TrainedSingleIntersect<Index_, Float_> train_single_intersect(
    const Intersection<Index_>& intersection,
    const tatami::Matrix<Value_, Index_>& ref,
    const Label_* labels,
    Markers<Index_> markers,
    const TrainSingleOptions<Index_, Float_>& options)
{
    return train_single_intersect<Index_, Value_, Label_, Float_>(
        static_cast<Index_>(-1), // sentinel: number of test rows not known.
        intersection,
        ref,
        labels,
        std::move(markers),
        options
    );
}
/**
* @endcond
*/

/**
* Variant of `train_single()` that uses the intersection of genes between the reference dataset and a (future) test dataset.
* This is useful when the genes are not in the same order and number across the test and reference datasets.
Expand Down Expand Up @@ -359,7 +401,7 @@ TrainedSingleIntersect<Index_, Float_> train_single_intersect(
const TrainSingleOptions<Index_, Float_>& options)
{
auto intersection = intersect_genes(test_nrow, test_id, ref.nrow(), ref_id);
return train_single_intersect(intersection, ref, labels, std::move(markers), options);
return train_single_intersect(test_nrow, intersection, ref, labels, std::move(markers), options);
}

}
Expand Down
Loading

0 comments on commit c0447ce

Please sign in to comment.