diff --git a/sc2ts/inference.py b/sc2ts/inference.py index 688922e..cd05191 100644 --- a/sc2ts/inference.py +++ b/sc2ts/inference.py @@ -31,7 +31,7 @@ DELETION = core.ALLELES.index("-") -def get_progress(iterable, date, phase, show_progress, total=None): +def get_progress(iterable, title, phase, show_progress, total=None): bar_format = ( "{desc:<22}{percentage:3.0f}%|{bar}" "| {n_fmt}/{total_fmt} [{elapsed}, {rate_fmt}{postfix}]" @@ -39,7 +39,7 @@ def get_progress(iterable, date, phase, show_progress, total=None): return tqdm.tqdm( iterable, total=total, - desc=f"{date}:{phase}", + desc=f"{title}:{phase}", disable=not show_progress, bar_format=bar_format, dynamic_ncols=True, @@ -56,7 +56,7 @@ def __init__(self, date, phase, *args, **kwargs): def get(self, key, total): self.current_instance = get_progress( - None, self.date, phase=self.phase, show_progress=self.enabled, total=total + None, title=self.date, phase=self.phase, show_progress=self.enabled, total=total ) return self.current_instance @@ -501,10 +501,10 @@ def preprocess_worker(strains, alignment_store_path, keep_sites): def preprocess( strains, - date, alignment_store_path, *, keep_sites, + progress_title="", show_progress=False, num_workers=0, ): @@ -512,7 +512,7 @@ def preprocess( splits = min(len(strains), 2 * num_workers) work = np.array_split(strains, splits) samples = [] - bar = get_progress(strains, date, f"preprocess", show_progress) + bar = get_progress(strains, progress_title, "preprocess", show_progress) with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor: futures = [ executor.submit(preprocess_worker, w, alignment_store_path, keep_sites) @@ -569,9 +569,9 @@ def extend( preprocessed_samples = preprocess( strains=[md["strain"] for md in metadata_matches], - date=date, alignment_store_path=alignment_store.path, keep_sites=base_ts.sites_position.astype(int), + progress_title=date, show_progress=show_progress, num_workers=num_threads, ) diff --git a/tests/test_inference.py b/tests/test_inference.py index afb330f..dc0c456 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -817,7 +817,6 @@ def test_exact_matches( self, fx_ts_map, fx_alignment_store, - fx_metadata_db, strain, parent, num_mismatches, @@ -825,7 +824,6 @@ def test_exact_matches( ts = fx_ts_map["2020-02-10"] samples = sc2ts.preprocess( [strain], - "2020-02-20", fx_alignment_store.path, keep_sites=ts.sites_position.astype(int), ) @@ -854,7 +852,6 @@ def test_one_mismatch( self, fx_ts_map, fx_alignment_store, - fx_metadata_db, strain, parent, position, @@ -864,7 +861,6 @@ def test_one_mismatch( ts = fx_ts_map["2020-02-10"] samples = sc2ts.preprocess( [strain], - "2020-02-20", fx_alignment_store.path, keep_sites=ts.sites_position.astype(int), ) @@ -889,14 +885,12 @@ def test_two_mismatches( self, fx_ts_map, fx_alignment_store, - fx_metadata_db, num_mismatches, ): strain = "SRR11597164" ts = fx_ts_map["2020-02-01"] samples = sc2ts.preprocess( [strain], - "2020-02-20", fx_alignment_store.path, keep_sites=ts.sites_position.astype(int), )