From 322971d81eb8ae745601515aa15227cc7af6d951 Mon Sep 17 00:00:00 2001
From: Stephen Fleming
Date: Tue, 31 Oct 2023 10:47:02 -0400
Subject: [PATCH] Memory-efficient posterior generation (#263)

---
 cellbender/remove_background/estimation.py    | 45 +++-----------
 cellbender/remove_background/posterior.py     | 61 ++++++++++---------
 cellbender/remove_background/sparse_utils.py  |  6 +-
 .../remove_background/tests/test_dataprep.py  |  6 +-
 .../tests/test_sparse_utils.py                |  6 +-
 5 files changed, 50 insertions(+), 74 deletions(-)

diff --git a/cellbender/remove_background/estimation.py b/cellbender/remove_background/estimation.py
index 2443724..371a7fa 100644
--- a/cellbender/remove_background/estimation.py
+++ b/cellbender/remove_background/estimation.py
@@ -463,11 +463,9 @@ def estimate_noise(self,
         if use_multiple_processes:
 
             logger.info('Dividing dataset into chunks of genes')
-            chunk_logic_list = list(
-                self._gene_chunk_iterator(
-                    noise_log_prob_coo=noise_log_prob_coo,
-                    n_chunks=n_chunks,
-                )
+            chunk_logic_list = self._gene_chunk_iterator(
+                noise_log_prob_coo=noise_log_prob_coo,
+                n_chunks=n_chunks,
             )
 
             logger.info('Computing the output in asynchronous chunks in parallel...')
@@ -538,10 +536,9 @@ def estimate_noise(self,
     def _gene_chunk_iterator(self,
                              noise_log_prob_coo: sp.coo_matrix,
                              n_chunks: int) \
-            -> Generator[np.ndarray, None, None]:
-        """Yields chunks of the posterior that can be treated as independent,
-        from the standpoint of MCKP count estimation. That is, they contain all
-        matrix entries for any genes they include.
+            -> List[np.ndarray]:
+        """Return a list of logical (size m) arrays used to select gene chunks
+        on which to compute the MCKP estimate. These chunks are independent.
 
         Args:
             noise_log_prob_coo: Full noise log prob posterior COO
@@ -551,36 +548,14 @@
             Logical array which indexes elements of coo posterior for the chunk
 
         """
-        # TODO this generator is way too slow
-
-        # approximate number of entries in a chunk
-        # approx_chunk_entries = (noise_log_prob_coo.data.size - 1) // n_chunks
-
         # get gene annotations
         _, genes = self.index_converter.get_ng_indices(m_inds=noise_log_prob_coo.row)
        genes_series = pd.Series(genes)
 
-        # things we need to keep track of for each chunk
-        # current_chunk_genes = []
-        # entry_logic = np.zeros(noise_log_prob_coo.data.size, dtype=bool)
-
-        # TODO eliminate for loop to speed this up
-        # take the list of genes from the coo, sort it, and divide it evenly
-        # somehow break ties for genes overlapping boundaries of divisions
-        sorted_genes = np.sort(genes)
-        gene_arrays = np.array_split(sorted_genes, n_chunks)
-        last_gene_set = {}
-        for gene_array in gene_arrays:
-            gene_set = set(gene_array)
-            gene_set = gene_set.difference(last_gene_set)  # only the new stuff
-            # if there is a second chunk, make sure there is a gene unique to it
-            if (n_chunks > 1) and (len(gene_set) == len(set(genes))):  # all genes in first set
-                # this mainly exists for tests
-                gene_set = gene_set - {gene_arrays[-1][-1]}
-            last_gene_set = gene_set
-            entry_logic = genes_series.isin(gene_set).values
-            if sum(entry_logic) > 0:
-                yield entry_logic
+        gene_chunk_arrays = np.array_split(np.arange(self.index_converter.total_n_genes), n_chunks)
+
+        gene_logic_arrays = [genes_series.isin(x).values for x in gene_chunk_arrays]
+        return gene_logic_arrays
 
     def _chunk_estimate_noise(self,
                               noise_log_prob_coo: sp.coo_matrix,
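Note on the chunking change above: the new _gene_chunk_iterator replaces the slow per-gene set bookkeeping with two vectorized calls. np.array_split partitions the full gene index range into contiguous blocks, and pd.Series.isin builds one boolean mask per block over the COO entries. A minimal standalone sketch of the same idea (the values of n_genes, genes, and n_chunks are toy inputs, not CellBender API):

import numpy as np
import pandas as pd

n_genes = 10                        # total genes known to the index converter
genes = np.array([0, 3, 3, 7, 9])   # gene index of each posterior COO entry
n_chunks = 2

# Partition gene ids into contiguous blocks, then mask COO entries per block.
gene_chunk_arrays = np.array_split(np.arange(n_genes), n_chunks)
gene_logic_arrays = [pd.Series(genes).isin(chunk).values for chunk in gene_chunk_arrays]

# Every entry for a given gene lands in exactly one chunk, so each chunk
# contains all matrix entries for the genes it covers.
assert sum(mask.sum() for mask in gene_logic_arrays) == genes.size

Because each chunk holds all entries for its genes, the chunks remain independent from the standpoint of MCKP count estimation, which is what allows them to be processed in parallel.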
diff --git a/cellbender/remove_background/posterior.py b/cellbender/remove_background/posterior.py
index 3758702..6aaeefa 100644
--- a/cellbender/remove_background/posterior.py
+++ b/cellbender/remove_background/posterior.py
@@ -451,7 +451,7 @@ def _get_cell_noise_count_posterior_coo(
                          f'accurate for your dataset.')
             raise RuntimeError('Zero cells found!')
 
-        dataloader_index_to_analyzed_bc_index = np.where(cell_logic)[0]
+        dataloader_index_to_analyzed_bc_index = torch.where(torch.tensor(cell_logic))[0]
         cell_data_loader = DataLoader(
             count_matrix[cell_logic],
             empty_drop_dataset=None,
@@ -468,6 +468,12 @@ def _get_cell_noise_count_posterior_coo(
         log_probs = []
         ind = 0
         n_minibatches = len(cell_data_loader)
+        analyzed_gene_inds = torch.tensor(self.analyzed_gene_inds.copy())
+        if analyzed_bcs_only:
+            barcode_inds = torch.tensor(self.dataset_obj.analyzed_barcode_inds.copy())
+        else:
+            barcode_inds = torch.tensor(self.barcode_inds.copy())
+        nonzero_noise_offset_dict = {}
 
         logger.info('Computing posterior noise count probabilities in mini-batches.')
 
@@ -505,46 +511,43 @@ def _get_cell_noise_count_posterior_coo(
                 )
 
                 # Get the original gene index from gene index in the trimmed dataset.
-                genes_i = self.analyzed_gene_inds[genes_i_analyzed]
+                genes_i = analyzed_gene_inds[genes_i_analyzed.cpu()]
 
                 # Barcode index in the dataloader.
-                bcs_i = bcs_i_chunk + ind
+                bcs_i = (bcs_i_chunk + ind).cpu()
 
                 # Obtain the real barcode index since we only use cells.
                 bcs_i = dataloader_index_to_analyzed_bc_index[bcs_i]
 
                 # Translate chunk barcode inds to overall inds.
-                if analyzed_bcs_only:
-                    bcs_i = self.dataset_obj.analyzed_barcode_inds[bcs_i]
-                else:
-                    bcs_i = self.barcode_inds[bcs_i]
+                bcs_i = barcode_inds[bcs_i]
 
                 # Add sparse matrix values to lists.
-                try:
-                    bcs.extend(bcs_i.tolist())
-                    genes.extend(genes_i.tolist())
-                    c.extend(c_i.tolist())
-                    log_probs.extend(log_prob_i.tolist())
-                    c_offset.extend(noise_count_offset_NG[bcs_i_chunk, genes_i_analyzed]
-                                    .detach().cpu().numpy())
-                except TypeError as e:
-                    # edge case of a single value
-                    bcs.append(bcs_i)
-                    genes.append(genes_i)
-                    c.append(c_i)
-                    log_probs.append(log_prob_i)
-                    c_offset.append(noise_count_offset_NG[bcs_i_chunk, genes_i_analyzed]
-                                    .detach().cpu().numpy())
+                bcs.append(bcs_i.detach())
+                genes.append(genes_i.detach())
+                c.append(c_i.detach().cpu())
+                log_probs.append(log_prob_i.detach().cpu())
+
+                # Update offset dict with any nonzeros.
+                nonzero_offset_inds, nonzero_noise_count_offsets = dense_to_sparse_op_torch(
+                    noise_count_offset_NG[bcs_i_chunk, genes_i_analyzed].detach().flatten(),
+                )
+                m_i = self.index_converter.get_m_indices(cell_inds=bcs_i, gene_inds=genes_i)
+
+                nonzero_noise_offset_dict.update(
+                    dict(zip(m_i[nonzero_offset_inds.detach().cpu()].tolist(),
+                             nonzero_noise_count_offsets.detach().cpu().tolist()))
+                )
+                c_offset.append(noise_count_offset_NG[bcs_i_chunk, genes_i_analyzed].detach().cpu())
 
                 # Increment barcode index counter.
                 ind += data.shape[0]  # Same as data_loader.batch_size
 
-        # Convert the lists to numpy arrays.
-        log_probs = np.array(log_probs, dtype=float)
-        c = np.array(c, dtype=np.uint32)
-        barcodes = np.array(bcs, dtype=np.uint64)  # uint32 is too small!
-        genes = np.array(genes, dtype=np.uint64)  # use same as above for IndexConverter
-        noise_count_offsets = np.array(c_offset, dtype=np.uint32)
+        # Concatenate lists.
+        log_probs = torch.cat(log_probs)
+        c = torch.cat(c)
+        barcodes = torch.cat(bcs)
+        genes = torch.cat(genes)
 
         # Translate (barcode, gene) inds to 'm' format index.
         m = self.index_converter.get_m_indices(cell_inds=barcodes, gene_inds=genes)
@@ -554,8 +557,6 @@ def _get_cell_noise_count_posterior_coo(
             (log_probs, (m, c)),
             shape=[np.prod(self.count_matrix_shape), n_counts_max],
         )
-        noise_offset_dict = dict(zip(m, noise_count_offsets))
-        nonzero_noise_offset_dict = {k: v for k, v in noise_offset_dict.items() if (v > 0)}
         self._noise_count_posterior_coo_offsets = nonzero_noise_offset_dict
 
         return self._noise_count_posterior_coo
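Note on the loop rewrite above: the minibatch loop now keeps results as torch tensors and folds noise count offsets into a dict incrementally, rather than accumulating Python lists and filtering a full-size dict at the end. A rough sketch of the nonzero-offset bookkeeping, with toy tensors standing in for the flattened noise_count_offset_NG slice and the converted 'm' indices:

import torch

offsets = torch.tensor([0, 2, 0, 1])   # stand-in for flattened noise count offsets
m_i = torch.tensor([10, 11, 12, 13])   # stand-in for 'm' format indices

# On a 1-d input, dense_to_sparse_op_torch reduces to torch.nonzero plus indexing.
(nonzero_inds,) = torch.nonzero(offsets, as_tuple=True)
nonzero_vals = offsets[nonzero_inds]

nonzero_noise_offset_dict = dict(zip(m_i[nonzero_inds].tolist(), nonzero_vals.tolist()))
assert nonzero_noise_offset_dict == {11: 2, 13: 1}

Only the (typically few) nonzero offsets are ever converted to Python objects, instead of one dict entry per posterior element as before, which is where the memory savings come from.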
diff --git a/cellbender/remove_background/sparse_utils.py b/cellbender/remove_background/sparse_utils.py
index 4a0f26f..ca31329 100644
--- a/cellbender/remove_background/sparse_utils.py
+++ b/cellbender/remove_background/sparse_utils.py
@@ -10,7 +10,7 @@
 @torch.no_grad()
 def dense_to_sparse_op_torch(t: torch.Tensor,
                              tensor_for_nonzeros: Optional[torch.Tensor] = None) \
-        -> Tuple[np.ndarray, ...]:
-    """Converts dense matrix to sparse COO format tuple of numpy arrays (*indices, data)
+        -> Tuple[torch.Tensor, ...]:
+    """Converts dense matrix to sparse COO format tuple of torch tensors (*indices, data)
 
     Args:
@@ -28,9 +28,9 @@ def dense_to_sparse_op_torch(t: torch.Tensor,
         tensor_for_nonzeros = t
 
     nonzero_inds_tuple = torch.nonzero(tensor_for_nonzeros, as_tuple=True)
-    nonzero_values = t[nonzero_inds_tuple].flatten()
+    nonzero_values = t[nonzero_inds_tuple].flatten().clone()
 
-    return tuple([ten.cpu().numpy() for ten in (nonzero_inds_tuple + (nonzero_values,))])
+    return nonzero_inds_tuple + (nonzero_values,)
 
 
 def log_prob_sparse_to_dense(coo: sp.coo_matrix) -> np.ndarray:
diff --git a/cellbender/remove_background/tests/test_dataprep.py b/cellbender/remove_background/tests/test_dataprep.py
index 8fcb156..5d11380 100644
--- a/cellbender/remove_background/tests/test_dataprep.py
+++ b/cellbender/remove_background/tests/test_dataprep.py
@@ -75,9 +75,9 @@ def test_dataloader_sorting(simulated_dataset, cuda):
         bcs_i = loader.unsort_inds(bcs_i)
 
         # Add sparse matrix values to lists.
-        barcodes.append(bcs_i)
-        genes.append(genes_i)
-        counts.append(counts_i)
+        barcodes.append(bcs_i.detach().cpu())
+        genes.append(genes_i.detach().cpu())
+        counts.append(counts_i.detach().cpu())
 
         # Increment barcode index counter.
         ind += data.shape[0]  # Same as data_loader.batch_size
diff --git a/cellbender/remove_background/tests/test_sparse_utils.py b/cellbender/remove_background/tests/test_sparse_utils.py
index 01230da..2f2e13e 100644
--- a/cellbender/remove_background/tests/test_sparse_utils.py
+++ b/cellbender/remove_background/tests/test_sparse_utils.py
@@ -76,9 +76,9 @@ def test_dense_to_sparse_op_torch(simulated_dataset, cuda):
         bcs_i = data_loader.unsort_inds(bcs_i)
 
         # Add sparse matrix values to lists.
-        barcodes.append(bcs_i)
-        genes.append(genes_i)
-        counts.append(counts_i)
+        barcodes.append(bcs_i.detach().cpu())
+        genes.append(genes_i.detach().cpu())
+        counts.append(counts_i.detach().cpu())
 
         # Increment barcode index counter.
         ind += data.shape[0]  # Same as data_loader.batch_size
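Note on the sparse_utils change above: with the new return type, dense_to_sparse_op_torch is device-agnostic. Callers receive torch tensors on whatever device t lives on and decide themselves when to move data to the CPU, as the updated tests now do inside their accumulation loops. A toy usage sketch (the 2x2 input is illustrative):

import torch
from cellbender.remove_background.sparse_utils import dense_to_sparse_op_torch

t = torch.tensor([[0., 1.],
                  [2., 0.]])
rows, cols, vals = dense_to_sparse_op_torch(t)   # all torch.Tensors now

assert rows.tolist() == [0, 1]   # nonzero row indices, row-major order
assert cols.tolist() == [1, 0]   # nonzero column indices
assert vals.tolist() == [1., 2.]

# Convert explicitly only where numpy is required, e.g. to build a scipy COO:
rows_np, cols_np, vals_np = (x.cpu().numpy() for x in (rows, cols, vals))

The added .clone() is a defensive copy ensuring the returned values tensor owns its storage and holds no view into intermediate buffers created by the indexing-plus-flatten chain.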