Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Memory-efficient posterior generation #263

Merged
merged 8 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 10 additions & 35 deletions cellbender/remove_background/estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,11 +463,9 @@ def estimate_noise(self,
if use_multiple_processes:

logger.info('Dividing dataset into chunks of genes')
chunk_logic_list = list(
self._gene_chunk_iterator(
noise_log_prob_coo=noise_log_prob_coo,
n_chunks=n_chunks,
)
chunk_logic_list = self._gene_chunk_iterator(
noise_log_prob_coo=noise_log_prob_coo,
n_chunks=n_chunks,
)

logger.info('Computing the output in asynchronous chunks in parallel...')
Expand Down Expand Up @@ -538,10 +536,9 @@ def estimate_noise(self,
def _gene_chunk_iterator(self,
noise_log_prob_coo: sp.coo_matrix,
n_chunks: int) \
-> Generator[np.ndarray, None, None]:
"""Yields chunks of the posterior that can be treated as independent,
from the standpoint of MCKP count estimation. That is, they contain all
matrix entries for any genes they include.
-> List[np.ndarray]:
"""Return a list of logical (size m) arrays used to select gene chunks
on which to compute the MCKP estimate. These chunks are independent.

Args:
noise_log_prob_coo: Full noise log prob posterior COO
Expand All @@ -551,36 +548,14 @@ def _gene_chunk_iterator(self,
Logical array which indexes elements of coo posterior for the chunk
"""

# TODO this generator is way too slow

# approximate number of entries in a chunk
# approx_chunk_entries = (noise_log_prob_coo.data.size - 1) // n_chunks

# get gene annotations
_, genes = self.index_converter.get_ng_indices(m_inds=noise_log_prob_coo.row)
genes_series = pd.Series(genes)

# things we need to keep track of for each chunk
# current_chunk_genes = []
# entry_logic = np.zeros(noise_log_prob_coo.data.size, dtype=bool)

# TODO eliminate for loop to speed this up
# take the list of genes from the coo, sort it, and divide it evenly
# somehow break ties for genes overlapping boundaries of divisions
sorted_genes = np.sort(genes)
gene_arrays = np.array_split(sorted_genes, n_chunks)
last_gene_set = {}
for gene_array in gene_arrays:
gene_set = set(gene_array)
gene_set = gene_set.difference(last_gene_set) # only the new stuff
# if there is a second chunk, make sure there is a gene unique to it
if (n_chunks > 1) and (len(gene_set) == len(set(genes))): # all genes in first set
# this mainly exists for tests
gene_set = gene_set - {gene_arrays[-1][-1]}
last_gene_set = gene_set
entry_logic = genes_series.isin(gene_set).values
if sum(entry_logic) > 0:
yield entry_logic
gene_chunk_arrays = np.array_split(np.arange(self.index_converter.total_n_genes), n_chunks)

gene_logic_arrays = [genes_series.isin(x).values for x in gene_chunk_arrays]
return gene_logic_arrays

def _chunk_estimate_noise(self,
noise_log_prob_coo: sp.coo_matrix,
Expand Down
61 changes: 31 additions & 30 deletions cellbender/remove_background/posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ def _get_cell_noise_count_posterior_coo(
f'accurate for your dataset.')
raise RuntimeError('Zero cells found!')

dataloader_index_to_analyzed_bc_index = np.where(cell_logic)[0]
dataloader_index_to_analyzed_bc_index = torch.where(torch.tensor(cell_logic))[0]
cell_data_loader = DataLoader(
count_matrix[cell_logic],
empty_drop_dataset=None,
Expand All @@ -468,6 +468,12 @@ def _get_cell_noise_count_posterior_coo(
log_probs = []
ind = 0
n_minibatches = len(cell_data_loader)
analyzed_gene_inds = torch.tensor(self.analyzed_gene_inds.copy())
if analyzed_bcs_only:
barcode_inds = torch.tensor(self.dataset_obj.analyzed_barcode_inds.copy())
else:
barcode_inds = torch.tensor(self.barcode_inds.copy())
nonzero_noise_offset_dict = {}

logger.info('Computing posterior noise count probabilities in mini-batches.')

Expand Down Expand Up @@ -505,46 +511,43 @@ def _get_cell_noise_count_posterior_coo(
)

# Get the original gene index from gene index in the trimmed dataset.
genes_i = self.analyzed_gene_inds[genes_i_analyzed]
genes_i = analyzed_gene_inds[genes_i_analyzed.cpu()]

# Barcode index in the dataloader.
bcs_i = bcs_i_chunk + ind
bcs_i = (bcs_i_chunk + ind).cpu()

# Obtain the real barcode index since we only use cells.
bcs_i = dataloader_index_to_analyzed_bc_index[bcs_i]

# Translate chunk barcode inds to overall inds.
if analyzed_bcs_only:
bcs_i = self.dataset_obj.analyzed_barcode_inds[bcs_i]
else:
bcs_i = self.barcode_inds[bcs_i]
bcs_i = barcode_inds[bcs_i]

# Add sparse matrix values to lists.
try:
bcs.extend(bcs_i.tolist())
genes.extend(genes_i.tolist())
c.extend(c_i.tolist())
log_probs.extend(log_prob_i.tolist())
c_offset.extend(noise_count_offset_NG[bcs_i_chunk, genes_i_analyzed]
.detach().cpu().numpy())
except TypeError as e:
# edge case of a single value
bcs.append(bcs_i)
genes.append(genes_i)
c.append(c_i)
log_probs.append(log_prob_i)
c_offset.append(noise_count_offset_NG[bcs_i_chunk, genes_i_analyzed]
.detach().cpu().numpy())
bcs.append(bcs_i.detach())
genes.append(genes_i.detach())
c.append(c_i.detach().cpu())
log_probs.append(log_prob_i.detach().cpu())

# Update offset dict with any nonzeros.
nonzero_offset_inds, nonzero_noise_count_offsets = dense_to_sparse_op_torch(
noise_count_offset_NG[bcs_i_chunk, genes_i_analyzed].detach().flatten(),
)
m_i = self.index_converter.get_m_indices(cell_inds=bcs_i, gene_inds=genes_i)

nonzero_noise_offset_dict.update(
dict(zip(m_i[nonzero_offset_inds.detach().cpu()].tolist(),
nonzero_noise_count_offsets.detach().cpu().tolist()))
)
c_offset.append(noise_count_offset_NG[bcs_i_chunk, genes_i_analyzed].detach().cpu())

# Increment barcode index counter.
ind += data.shape[0] # Same as data_loader.batch_size

# Convert the lists to numpy arrays.
log_probs = np.array(log_probs, dtype=float)
c = np.array(c, dtype=np.uint32)
barcodes = np.array(bcs, dtype=np.uint64) # uint32 is too small!
genes = np.array(genes, dtype=np.uint64) # use same as above for IndexConverter
noise_count_offsets = np.array(c_offset, dtype=np.uint32)
# Concatenate lists.
log_probs = torch.cat(log_probs)
c = torch.cat(c)
barcodes = torch.cat(bcs)
genes = torch.cat(genes)

# Translate (barcode, gene) inds to 'm' format index.
m = self.index_converter.get_m_indices(cell_inds=barcodes, gene_inds=genes)
Expand All @@ -554,8 +557,6 @@ def _get_cell_noise_count_posterior_coo(
(log_probs, (m, c)),
shape=[np.prod(self.count_matrix_shape), n_counts_max],
)
noise_offset_dict = dict(zip(m, noise_count_offsets))
nonzero_noise_offset_dict = {k: v for k, v in noise_offset_dict.items() if (v > 0)}
self._noise_count_posterior_coo_offsets = nonzero_noise_offset_dict
return self._noise_count_posterior_coo

Expand Down
6 changes: 3 additions & 3 deletions cellbender/remove_background/sparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
@torch.no_grad()
def dense_to_sparse_op_torch(t: torch.Tensor,
tensor_for_nonzeros: Optional[torch.Tensor] = None) \
-> Tuple[np.ndarray, ...]:
-> Tuple[torch.Tensor, ...]:
"""Converts dense matrix to sparse COO format tuple of numpy arrays (*indices, data)

Args:
Expand All @@ -28,9 +28,9 @@ def dense_to_sparse_op_torch(t: torch.Tensor,
tensor_for_nonzeros = t

nonzero_inds_tuple = torch.nonzero(tensor_for_nonzeros, as_tuple=True)
nonzero_values = t[nonzero_inds_tuple].flatten()
nonzero_values = t[nonzero_inds_tuple].flatten().clone()

return tuple([ten.cpu().numpy() for ten in (nonzero_inds_tuple + (nonzero_values,))])
return nonzero_inds_tuple + (nonzero_values,)


def log_prob_sparse_to_dense(coo: sp.coo_matrix) -> np.ndarray:
Expand Down
6 changes: 3 additions & 3 deletions cellbender/remove_background/tests/test_dataprep.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ def test_dataloader_sorting(simulated_dataset, cuda):
bcs_i = loader.unsort_inds(bcs_i)

# Add sparse matrix values to lists.
barcodes.append(bcs_i)
genes.append(genes_i)
counts.append(counts_i)
barcodes.append(bcs_i.detach().cpu())
genes.append(genes_i.detach().cpu())
counts.append(counts_i.detach().cpu())

# Increment barcode index counter.
ind += data.shape[0] # Same as data_loader.batch_size
Expand Down
6 changes: 3 additions & 3 deletions cellbender/remove_background/tests/test_sparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ def test_dense_to_sparse_op_torch(simulated_dataset, cuda):
bcs_i = data_loader.unsort_inds(bcs_i)

# Add sparse matrix values to lists.
barcodes.append(bcs_i)
genes.append(genes_i)
counts.append(counts_i)
barcodes.append(bcs_i.detach().cpu())
genes.append(genes_i.detach().cpu())
counts.append(counts_i.detach().cpu())

# Increment barcode index counter.
ind += data.shape[0] # Same as data_loader.batch_size
Expand Down