Skip to content

Commit

Permalink
Improve execution flow for running neighborhood suite operations (#79).
Browse files Browse the repository at this point in the history
Modifying the execution of neighborhood suite operations (computing t-SNE, UMAP, and SNN graph). When the package is run on macOS, we switch to "fork" to circumvent the need for `if __name__ == '__main__':` in user scripts. We completely avoid multiprocessing and execute all operations sequentially if the number of available threads is 1.
jkanche authored Feb 5, 2024
1 parent c6cdc50 commit 6bfa315
Showing 1 changed file with 63 additions and 27 deletions.
90 changes: 63 additions & 27 deletions src/scranpy/analyze/run_neighbor_suite.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from concurrent.futures import ProcessPoolExecutor, wait
from copy import copy
from typing import Callable, Tuple

@@ -87,46 +86,83 @@ def run_neighbor_suite(
for k in set([umap_nn, tsne_nn]):
serialized_dict[k] = nn_dict[k].serialize()

# Attempting to evenly distribute threads across the tasks. t-SNE and UMAP
# Attempting to evenly distribute threads across the tasks. t-SNE and UMAP
# are run on separate processes while the SNN graph construction is kept on
# the main thread because we'll need the output for marker detection.
threads_per_task = max(1, int(num_threads / 3))
executor = ProcessPoolExecutor(max_workers=min(2, num_threads))
max_workers = min(2, num_threads)
_tasks = []

run_tsne_copy = copy(run_tsne_options)
run_tsne_copy.set_threads(threads_per_task)
_tasks.append(
executor.submit(
_unserialize_neighbors_before_run,
dimred.run_tsne,
serialized_dict[tsne_nn],
run_tsne_copy,
)
)

run_umap_copy = copy(run_umap_options)
run_umap_copy.set_threads(threads_per_task)
_tasks.append(
executor.submit(
_unserialize_neighbors_before_run,
dimred.run_umap,
serialized_dict[umap_nn],
run_umap_copy,

if max_workers > 1:
import multiprocessing as mp
import platform
from concurrent.futures import ProcessPoolExecutor, wait

pp = platform.platform()
extra_args = {}
if "macos" in pp.lower():
extra_args["mp_context"] = mp.get_context("fork")

executor = ProcessPoolExecutor(max_workers=max_workers, **extra_args)

_tasks.append(
executor.submit(
_unserialize_neighbors_before_run,
dimred.run_tsne,
serialized_dict[tsne_nn],
run_tsne_copy,
)
)
)

def retrieve():
wait(_tasks)
executor.shutdown()
_tasks.append(
executor.submit(
_unserialize_neighbors_before_run,
dimred.run_umap,
serialized_dict[umap_nn],
run_umap_copy,
)
)

def retrieve():
wait(_tasks)
executor.shutdown()

def get_tsne():
retrieve()
return _tasks[0].result()

def get_umap():
retrieve()
return _tasks[1].result()

else:
_tasks.append(
_unserialize_neighbors_before_run(
dimred.run_tsne,
serialized_dict[tsne_nn],
run_tsne_copy,
)
)

_tasks.append(
_unserialize_neighbors_before_run(
dimred.run_umap,
serialized_dict[umap_nn],
run_umap_copy,
)
)

def get_tsne():
retrieve()
return _tasks[0].result()
def get_tsne():
return _tasks[0]

def get_umap():
retrieve()
return _tasks[1].result()
def get_umap():
return _tasks[1]

build_snn_graph_copy = copy(build_snn_graph_options)
remaining_threads = max(1, num_threads - threads_per_task * 2)

0 comments on commit 6bfa315

Please sign in to comment.