Memory-efficient posterior generation (#263)
sjfleming authored Oct 31, 2023
1 parent cf71148 commit 322971d
Showing 5 changed files with 50 additions and 74 deletions.
45 changes: 10 additions & 35 deletions cellbender/remove_background/estimation.py
@@ -463,11 +463,9 @@ def estimate_noise(self,
         if use_multiple_processes:
 
             logger.info('Dividing dataset into chunks of genes')
-            chunk_logic_list = list(
-                self._gene_chunk_iterator(
-                    noise_log_prob_coo=noise_log_prob_coo,
-                    n_chunks=n_chunks,
-                )
+            chunk_logic_list = self._gene_chunk_iterator(
+                noise_log_prob_coo=noise_log_prob_coo,
+                n_chunks=n_chunks,
             )
 
             logger.info('Computing the output in asynchronous chunks in parallel...')
@@ -538,10 +536,9 @@ def estimate_noise(self,
     def _gene_chunk_iterator(self,
                              noise_log_prob_coo: sp.coo_matrix,
                              n_chunks: int) \
-            -> Generator[np.ndarray, None, None]:
-        """Yields chunks of the posterior that can be treated as independent,
-        from the standpoint of MCKP count estimation. That is, they contain all
-        matrix entries for any genes they include.
+            -> List[np.ndarray]:
+        """Return a list of logical (size m) arrays used to select gene chunks
+        on which to compute the MCKP estimate. These chunks are independent.
 
         Args:
             noise_log_prob_coo: Full noise log prob posterior COO
@@ -551,36 +548,14 @@ def _gene_chunk_iterator(self,
             Logical array which indexes elements of coo posterior for the chunk
         """
 
-        # TODO this generator is way too slow
-
-        # approximate number of entries in a chunk
-        # approx_chunk_entries = (noise_log_prob_coo.data.size - 1) // n_chunks
-
         # get gene annotations
         _, genes = self.index_converter.get_ng_indices(m_inds=noise_log_prob_coo.row)
         genes_series = pd.Series(genes)
 
-        # things we need to keep track of for each chunk
-        # current_chunk_genes = []
-        # entry_logic = np.zeros(noise_log_prob_coo.data.size, dtype=bool)
-
-        # TODO eliminate for loop to speed this up
-        # take the list of genes from the coo, sort it, and divide it evenly
-        # somehow break ties for genes overlapping boundaries of divisions
-        sorted_genes = np.sort(genes)
-        gene_arrays = np.array_split(sorted_genes, n_chunks)
-        last_gene_set = {}
-        for gene_array in gene_arrays:
-            gene_set = set(gene_array)
-            gene_set = gene_set.difference(last_gene_set)  # only the new stuff
-            # if there is a second chunk, make sure there is a gene unique to it
-            if (n_chunks > 1) and (len(gene_set) == len(set(genes))):  # all genes in first set
-                # this mainly exists for tests
-                gene_set = gene_set - {gene_arrays[-1][-1]}
-            last_gene_set = gene_set
-            entry_logic = genes_series.isin(gene_set).values
-            if sum(entry_logic) > 0:
-                yield entry_logic
+        gene_chunk_arrays = np.array_split(np.arange(self.index_converter.total_n_genes), n_chunks)
+
+        gene_logic_arrays = [genes_series.isin(x).values for x in gene_chunk_arrays]
+        return gene_logic_arrays

     def _chunk_estimate_noise(self,
                               noise_log_prob_coo: sp.coo_matrix,
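The rewritten _gene_chunk_iterator above is the heart of the speedup in this file: instead of sorting every entry's gene and growing per-chunk sets in a Python loop, it splits the full gene index range once and builds one boolean mask per chunk. A minimal self-contained sketch of the same idea (gene_chunk_masks, total_n_genes, and the toy data are illustrative stand-ins, not CellBender API):

    import numpy as np
    import pandas as pd

    def gene_chunk_masks(genes: np.ndarray, total_n_genes: int, n_chunks: int):
        """Boolean masks selecting the posterior COO entries for each gene chunk."""
        genes_series = pd.Series(genes)
        # Split the gene index range [0, total_n_genes) into nearly-equal pieces.
        gene_chunk_arrays = np.array_split(np.arange(total_n_genes), n_chunks)
        # Mark which COO entries carry a gene from each piece.
        return [genes_series.isin(chunk).values for chunk in gene_chunk_arrays]

    # 10 posterior entries drawn from 6 genes, split into 3 independent chunks.
    genes = np.array([0, 0, 1, 2, 2, 3, 4, 4, 5, 5])
    masks = gene_chunk_masks(genes, total_n_genes=6, n_chunks=3)
    assert sum(m.sum() for m in masks) == genes.size  # each entry lands in exactly one chunk

Returning a concrete list rather than a generator also lets estimate_noise hand every chunk to the worker pool up front, which is why the list() wrapper at the call site became redundant.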
61 changes: 31 additions & 30 deletions cellbender/remove_background/posterior.py
@@ -451,7 +451,7 @@ def _get_cell_noise_count_posterior_coo(
                          f'accurate for your dataset.')
             raise RuntimeError('Zero cells found!')
 
-        dataloader_index_to_analyzed_bc_index = np.where(cell_logic)[0]
+        dataloader_index_to_analyzed_bc_index = torch.where(torch.tensor(cell_logic))[0]
         cell_data_loader = DataLoader(
             count_matrix[cell_logic],
             empty_drop_dataset=None,
@@ -468,6 +468,12 @@ def _get_cell_noise_count_posterior_coo(
         log_probs = []
         ind = 0
         n_minibatches = len(cell_data_loader)
+        analyzed_gene_inds = torch.tensor(self.analyzed_gene_inds.copy())
+        if analyzed_bcs_only:
+            barcode_inds = torch.tensor(self.dataset_obj.analyzed_barcode_inds.copy())
+        else:
+            barcode_inds = torch.tensor(self.barcode_inds.copy())
+        nonzero_noise_offset_dict = {}
 
         logger.info('Computing posterior noise count probabilities in mini-batches.')
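The six lines added in this hunk hoist one-time conversions out of the minibatch loop: the numpy index arrays become torch tensors once, so each batch can index them with torch tensors directly instead of bouncing through numpy and Python lists. A small illustration of that indexing pattern (the array values here are invented for the example):

    import numpy as np
    import torch

    # One-time setup, as in the diff: explicit .copy() so torch builds its
    # tensor from a fresh buffer independent of the source numpy array.
    analyzed_gene_inds_np = np.array([5, 17, 42, 99])
    analyzed_gene_inds = torch.tensor(analyzed_gene_inds_np.copy())

    # Per-batch: index a CPU tensor with (CPU-moved) torch indices; the result
    # stays a torch tensor, ready for torch.cat at the end.
    genes_i_analyzed = torch.tensor([2, 0, 3])
    genes_i = analyzed_gene_inds[genes_i_analyzed.cpu()]
    print(genes_i)  # tensor([42,  5, 99])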

@@ -505,46 +511,43 @@ def _get_cell_noise_count_posterior_coo(
             )
 
             # Get the original gene index from gene index in the trimmed dataset.
-            genes_i = self.analyzed_gene_inds[genes_i_analyzed]
+            genes_i = analyzed_gene_inds[genes_i_analyzed.cpu()]
 
             # Barcode index in the dataloader.
-            bcs_i = bcs_i_chunk + ind
+            bcs_i = (bcs_i_chunk + ind).cpu()
 
             # Obtain the real barcode index since we only use cells.
             bcs_i = dataloader_index_to_analyzed_bc_index[bcs_i]
 
             # Translate chunk barcode inds to overall inds.
-            if analyzed_bcs_only:
-                bcs_i = self.dataset_obj.analyzed_barcode_inds[bcs_i]
-            else:
-                bcs_i = self.barcode_inds[bcs_i]
+            bcs_i = barcode_inds[bcs_i]
 
             # Add sparse matrix values to lists.
-            try:
-                bcs.extend(bcs_i.tolist())
-                genes.extend(genes_i.tolist())
-                c.extend(c_i.tolist())
-                log_probs.extend(log_prob_i.tolist())
-                c_offset.extend(noise_count_offset_NG[bcs_i_chunk, genes_i_analyzed]
-                                .detach().cpu().numpy())
-            except TypeError as e:
-                # edge case of a single value
-                bcs.append(bcs_i)
-                genes.append(genes_i)
-                c.append(c_i)
-                log_probs.append(log_prob_i)
-                c_offset.append(noise_count_offset_NG[bcs_i_chunk, genes_i_analyzed]
-                                .detach().cpu().numpy())
+            bcs.append(bcs_i.detach())
+            genes.append(genes_i.detach())
+            c.append(c_i.detach().cpu())
+            log_probs.append(log_prob_i.detach().cpu())
+
+            # Update offset dict with any nonzeros.
+            nonzero_offset_inds, nonzero_noise_count_offsets = dense_to_sparse_op_torch(
+                noise_count_offset_NG[bcs_i_chunk, genes_i_analyzed].detach().flatten(),
+            )
+            m_i = self.index_converter.get_m_indices(cell_inds=bcs_i, gene_inds=genes_i)
+
+            nonzero_noise_offset_dict.update(
+                dict(zip(m_i[nonzero_offset_inds.detach().cpu()].tolist(),
+                         nonzero_noise_count_offsets.detach().cpu().tolist()))
+            )
+            c_offset.append(noise_count_offset_NG[bcs_i_chunk, genes_i_analyzed].detach().cpu())
 
             # Increment barcode index counter.
             ind += data.shape[0]  # Same as data_loader.batch_size
 
-        # Convert the lists to numpy arrays.
-        log_probs = np.array(log_probs, dtype=float)
-        c = np.array(c, dtype=np.uint32)
-        barcodes = np.array(bcs, dtype=np.uint64)  # uint32 is too small!
-        genes = np.array(genes, dtype=np.uint64)  # use same as above for IndexConverter
-        noise_count_offsets = np.array(c_offset, dtype=np.uint32)
+        # Concatenate lists.
+        log_probs = torch.cat(log_probs)
+        c = torch.cat(c)
+        barcodes = torch.cat(bcs)
+        genes = torch.cat(genes)
 
         # Translate (barcode, gene) inds to 'm' format index.
         m = self.index_converter.get_m_indices(cell_inds=barcodes, gene_inds=genes)
@@ -554,8 +557,6 @@ def _get_cell_noise_count_posterior_coo(
             (log_probs, (m, c)),
             shape=[np.prod(self.count_matrix_shape), n_counts_max],
         )
-        noise_offset_dict = dict(zip(m, noise_count_offsets))
-        nonzero_noise_offset_dict = {k: v for k, v in noise_offset_dict.items() if (v > 0)}
         self._noise_count_posterior_coo_offsets = nonzero_noise_offset_dict
         return self._noise_count_posterior_coo

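A note on the 'm' index used just above: it flattens each (barcode, gene) pair into a single row of one large scipy COO matrix of shape (n_barcodes * n_genes, n_counts_max), with the column giving the noise count value. A minimal sketch, assuming the conventional row-major flattening m = barcode * n_genes + gene (the authoritative rule lives in IndexConverter.get_m_indices):

    import numpy as np
    import torch
    import scipy.sparse as sp

    n_barcodes, n_genes, n_counts_max = 10, 100, 4

    # Concatenated minibatch outputs, still torch tensors after this commit.
    barcodes = torch.tensor([0, 0, 3])
    genes = torch.tensor([7, 9, 7])
    c = torch.tensor([0, 1, 2])                   # noise count values
    log_probs = torch.tensor([-0.1, -2.3, -0.7])  # log prob of each count

    m = barcodes * n_genes + genes  # assumed flattening; see IndexConverter
    coo = sp.coo_matrix(
        (log_probs.numpy(), (m.numpy(), c.numpy())),
        shape=[n_barcodes * n_genes, n_counts_max],
    )

Updating the offset dict incrementally per minibatch, rather than zipping two full-size arrays at the end as the deleted lines did, is where the memory saving in this file comes from: only the nonzero offsets are ever materialized.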
6 changes: 3 additions & 3 deletions cellbender/remove_background/sparse_utils.py
@@ -10,7 +10,7 @@
 @torch.no_grad()
 def dense_to_sparse_op_torch(t: torch.Tensor,
                              tensor_for_nonzeros: Optional[torch.Tensor] = None) \
-        -> Tuple[np.ndarray, ...]:
+        -> Tuple[torch.Tensor, ...]:
     """Converts dense matrix to sparse COO format tuple of numpy arrays (*indices, data)
 
     Args:
@@ -28,9 +28,9 @@ def dense_to_sparse_op_torch(t: torch.Tensor,
         tensor_for_nonzeros = t
 
     nonzero_inds_tuple = torch.nonzero(tensor_for_nonzeros, as_tuple=True)
-    nonzero_values = t[nonzero_inds_tuple].flatten()
+    nonzero_values = t[nonzero_inds_tuple].flatten().clone()
 
-    return tuple([ten.cpu().numpy() for ten in (nonzero_inds_tuple + (nonzero_values,))])
+    return nonzero_inds_tuple + (nonzero_values,)


 def log_prob_sparse_to_dense(coo: sp.coo_matrix) -> np.ndarray:
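Since dense_to_sparse_op_torch now returns torch tensors rather than numpy arrays, the .cpu().numpy() conversion happens at call sites, and only when numpy is actually needed — which is exactly how the posterior code above uses it. A usage sketch (restating the function from the diff so the example runs standalone):

    import torch

    @torch.no_grad()
    def dense_to_sparse_op_torch(t, tensor_for_nonzeros=None):
        if tensor_for_nonzeros is None:
            tensor_for_nonzeros = t
        nonzero_inds_tuple = torch.nonzero(tensor_for_nonzeros, as_tuple=True)
        # .clone() is a defensive copy so the result owns its memory outright.
        nonzero_values = t[nonzero_inds_tuple].flatten().clone()
        return nonzero_inds_tuple + (nonzero_values,)

    t = torch.tensor([[0., 2., 0.],
                      [1., 0., 0.]])
    rows, cols, vals = dense_to_sparse_op_torch(t)
    # rows = tensor([0, 1]), cols = tensor([1, 0]), vals = tensor([2., 1.])
    vals_np = vals.cpu().numpy()  # convert only where numpy is required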
6 changes: 3 additions & 3 deletions cellbender/remove_background/tests/test_dataprep.py
@@ -75,9 +75,9 @@ def test_dataloader_sorting(simulated_dataset, cuda):
         bcs_i = loader.unsort_inds(bcs_i)
 
         # Add sparse matrix values to lists.
-        barcodes.append(bcs_i)
-        genes.append(genes_i)
-        counts.append(counts_i)
+        barcodes.append(bcs_i.detach().cpu())
+        genes.append(genes_i.detach().cpu())
+        counts.append(counts_i.detach().cpu())
 
         # Increment barcode index counter.
         ind += data.shape[0]  # Same as data_loader.batch_size
6 changes: 3 additions & 3 deletions cellbender/remove_background/tests/test_sparse_utils.py
@@ -76,9 +76,9 @@ def test_dense_to_sparse_op_torch(simulated_dataset, cuda):
         bcs_i = data_loader.unsort_inds(bcs_i)
 
         # Add sparse matrix values to lists.
-        barcodes.append(bcs_i)
-        genes.append(genes_i)
-        counts.append(counts_i)
+        barcodes.append(bcs_i.detach().cpu())
+        genes.append(genes_i.detach().cpu())
+        counts.append(counts_i.detach().cpu())
 
         # Increment barcode index counter.
         ind += data.shape[0]  # Same as data_loader.batch_size
