Skip to content

Commit

Permalink
Simplify preprocess interface
Browse files: browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Oct 11, 2024
1 parent 9ef2638 commit 72e92f2
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 12 deletions.
12 changes: 6 additions & 6 deletions sc2ts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@
DELETION = core.ALLELES.index("-")


def get_progress(iterable, date, phase, show_progress, total=None):
def get_progress(iterable, title, phase, show_progress, total=None):
bar_format = (
"{desc:<22}{percentage:3.0f}%|{bar}"
"| {n_fmt}/{total_fmt} [{elapsed}, {rate_fmt}{postfix}]"
)
return tqdm.tqdm(
iterable,
total=total,
desc=f"{date}:{phase}",
desc=f"{title}:{phase}",
disable=not show_progress,
bar_format=bar_format,
dynamic_ncols=True,
Expand All @@ -56,7 +56,7 @@ def __init__(self, date, phase, *args, **kwargs):

def get(self, key, total):
self.current_instance = get_progress(
None, self.date, phase=self.phase, show_progress=self.enabled, total=total
None, title=self.date, phase=self.phase, show_progress=self.enabled, total=total
)
return self.current_instance

Expand Down Expand Up @@ -501,18 +501,18 @@ def preprocess_worker(strains, alignment_store_path, keep_sites):

def preprocess(
strains,
date,
alignment_store_path,
*,
keep_sites,
progress_title="",
show_progress=False,
num_workers=0,
):
num_workers = max(1, num_workers)
splits = min(len(strains), 2 * num_workers)
work = np.array_split(strains, splits)
samples = []
bar = get_progress(strains, date, f"preprocess", show_progress)
bar = get_progress(strains, progress_title, "preprocess", show_progress)
with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor:
futures = [
executor.submit(preprocess_worker, w, alignment_store_path, keep_sites)
Expand Down Expand Up @@ -569,9 +569,9 @@ def extend(

preprocessed_samples = preprocess(
strains=[md["strain"] for md in metadata_matches],
date=date,
alignment_store_path=alignment_store.path,
keep_sites=base_ts.sites_position.astype(int),
progress_title=date,
show_progress=show_progress,
num_workers=num_threads,
)
Expand Down
6 changes: 0 additions & 6 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,15 +817,13 @@ def test_exact_matches(
self,
fx_ts_map,
fx_alignment_store,
fx_metadata_db,
strain,
parent,
num_mismatches,
):
ts = fx_ts_map["2020-02-10"]
samples = sc2ts.preprocess(
[strain],
"2020-02-20",
fx_alignment_store.path,
keep_sites=ts.sites_position.astype(int),
)
Expand Down Expand Up @@ -854,7 +852,6 @@ def test_one_mismatch(
self,
fx_ts_map,
fx_alignment_store,
fx_metadata_db,
strain,
parent,
position,
Expand All @@ -864,7 +861,6 @@ def test_one_mismatch(
ts = fx_ts_map["2020-02-10"]
samples = sc2ts.preprocess(
[strain],
"2020-02-20",
fx_alignment_store.path,
keep_sites=ts.sites_position.astype(int),
)
Expand All @@ -889,14 +885,12 @@ def test_two_mismatches(
self,
fx_ts_map,
fx_alignment_store,
fx_metadata_db,
num_mismatches,
):
strain = "SRR11597164"
ts = fx_ts_map["2020-02-01"]
samples = sc2ts.preprocess(
[strain],
"2020-02-20",
fx_alignment_store.path,
keep_sites=ts.sites_position.astype(int),
)
Expand Down

0 comments on commit 72e92f2

Please sign in to comment.