From 98f29aadb5902071f623bcc021ee60d321931d74 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Thu, 10 Oct 2024 23:26:08 +0100 Subject: [PATCH 1/5] Refactor preprocess to decouple from metadata --- sc2ts/inference.py | 77 +++++++++++++++++++++++----------------------- 1 file changed, 39 insertions(+), 38 deletions(-) diff --git a/sc2ts/inference.py b/sc2ts/inference.py index 28595dd..dba1c18 100644 --- a/sc2ts/inference.py +++ b/sc2ts/inference.py @@ -482,14 +482,13 @@ def check_base_ts(ts): assert len(sc2ts_md["samples_strain"]) == ts.num_samples -def preprocess_worker(samples_md, alignment_store_path, keep_sites): +def preprocess_worker(strains, alignment_store_path, keep_sites): # print("preprocess worker", samples_md) with alignments.AlignmentStore(alignment_store_path) as alignment_store: samples = [] - for md in samples_md: - strain = md["strain"] + for strain in strains: alignment = alignment_store.get(strain, None) - sample = Sample(strain, metadata=md) + sample = Sample(strain) if alignment is not None: a = alignment[keep_sites] sample.haplotype = alignments.encode_alignment(a) @@ -501,47 +500,28 @@ def preprocess_worker(samples_md, alignment_store_path, keep_sites): def preprocess( - samples_md, - base_ts, + strains, date, - alignment_store, - pango_lineage_key="pango", + alignment_store_path, + keep_sites=None, show_progress=False, - max_missing_sites=np.inf, num_workers=0, ): num_workers = max(1, num_workers) - keep_sites = base_ts.sites_position.astype(int) - splits = min(len(samples_md), 2 * num_workers) - work = np.array_split(samples_md, splits) + splits = min(len(strains), 2 * num_workers) + work = np.array_split(strains, splits) samples = [] - bar = get_progress(samples_md, date, f"preprocess", show_progress) + bar = get_progress(strains, date, f"preprocess", show_progress) with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor: futures = [ - executor.submit(preprocess_worker, w, alignment_store.path, keep_sites) + executor.submit(preprocess_worker, w, alignment_store_path, keep_sites) for w in work ] for future in concurrent.futures.as_completed(futures): for s in future.result(): bar.update() - if s.haplotype is None: - logger.debug(f"No alignment stored for {s.strain}") - continue - s.date = date - s.pango = s.metadata.get(pango_lineage_key, "PangoUnknown") - num_missing_sites = s.num_missing_sites - num_deletion_sites = s.num_deletion_sites - logger.debug( - f"Encoded {s.strain} {s.pango} missing={num_missing_sites} " - f"deletions={num_deletion_sites}" - ) - if num_missing_sites <= max_missing_sites: - samples.append(s) - else: - logger.debug( - f"Filter {s.strain}: missing={num_missing_sites} > {max_missing_sites}" - ) + samples.append(s) bar.close() return samples @@ -587,16 +567,37 @@ def extend( logger.info(f"Got {len(metadata_matches)} metadata matches") - samples = preprocess( - metadata_matches, - base_ts, - date, - alignment_store, - pango_lineage_key="Viridian_pangolin", # TODO parametrise + preprocessed_samples = preprocess( + strains=[md["strain"] for md in metadata_matches], + date=date, + alignment_store_path=alignment_store.path, + keep_sites=base_ts.sites_position.astype(int), show_progress=show_progress, - max_missing_sites=max_missing_sites, num_workers=num_threads, ) + # FIXME parametrise + pango_lineage_key = "Viridian_pangolin" + + samples = [] + for s, md in zip(preprocessed_samples, metadata_matches): + if s.haplotype is None: + logger.debug(f"No alignment stored for {s.strain}") + continue + s.metadata = md + s.pango = md.get(pango_lineage_key, "Unknown") + s.date = date + num_missing_sites = s.num_missing_sites + num_deletion_sites = s.num_deletion_sites + logger.debug( + f"Encoded {s.strain} {s.pango} missing={num_missing_sites} " + f"deletions={num_deletion_sites}" + ) + if num_missing_sites <= max_missing_sites: + samples.append(s) + else: + logger.debug( + f"Filter {s.strain}: missing={num_missing_sites} > {max_missing_sites}" + ) if max_daily_samples is not None: if max_daily_samples < len(samples): From e372488cc23789c08049cbf6233cc185ddd09ff2 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Thu, 10 Oct 2024 23:40:45 +0100 Subject: [PATCH 2/5] Add basic match UI --- sc2ts/cli.py | 34 ++++++++++++++++++++++++++++++++++ sc2ts/inference.py | 3 +-- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/sc2ts/cli.py b/sc2ts/cli.py index dd319bc..d7c2351 100644 --- a/sc2ts/cli.py +++ b/sc2ts/cli.py @@ -710,6 +710,39 @@ def annotate_recombinants( tszip.compress(ts, out_tsz) +@click.command() +@click.argument("alignments", type=click.Path(exists=True, dir_okay=False)) +@click.argument("ts", type=click.Path(exists=True, dir_okay=False)) +@click.argument("strains", nargs=-1) +@click.option("--num-mismatches", default=3, type=float, help="num-mismatches") +@click.option("-v", "--verbose", count=True) +def run_match(alignments, ts, strains, num_mismatches, verbose): + """ + Run matches for a specified set of strains, outputting details to stdout as JSON. + """ + ts = tszip.load(ts) + samples = sc2ts.preprocess( + list(strains), "xx", alignments, keep_sites=ts.sites_position.astype(int) + ) + for sample in samples: + if sample.haplotype is None: + raise ValueError(f"No alignment stored for {sample.strain}") + mu, rho = sc2ts.solve_num_mismatches(num_mismatches) + matches = sc2ts.match_tsinfer( + samples=samples, + ts=ts, + mu=mu, + rho=rho, + # num_threads=num_threads, + show_progress=True, + # Maximum possible precision + likelihood_threshold=1e-200, + # mirror_coordinates=hmm_pass == "reverse", + ) + for hmm_match, sample in zip(matches, samples): + print(sample.strain, hmm_match.summary()) + + @click.version_option(core.__version__) @click.group() def cli(): @@ -730,4 +763,5 @@ def cli(): cli.add_command(extend) cli.add_command(validate) cli.add_command(annotate_recombinants) +cli.add_command(run_match) cli.add_command(tally_lineages) diff --git a/sc2ts/inference.py b/sc2ts/inference.py index dba1c18..ccd7d1a 100644 --- a/sc2ts/inference.py +++ b/sc2ts/inference.py @@ -483,7 +483,7 @@ def check_base_ts(ts): def preprocess_worker(strains, alignment_store_path, keep_sites): - # print("preprocess worker", samples_md) + assert keep_sites is not None with alignments.AlignmentStore(alignment_store_path) as alignment_store: samples = [] for strain in strains: @@ -511,7 +511,6 @@ def preprocess( splits = min(len(strains), 2 * num_workers) work = np.array_split(strains, splits) samples = [] - bar = get_progress(strains, date, f"preprocess", show_progress) with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor: futures = [ From 9ef2638d600b2b0a6b6cbd68b79db7e6cd503c59 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 11 Oct 2024 10:19:15 +0100 Subject: [PATCH 3/5] Fixup tests --- sc2ts/inference.py | 3 ++- tests/test_inference.py | 15 ++++++++++++--- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/sc2ts/inference.py b/sc2ts/inference.py index ccd7d1a..688922e 100644 --- a/sc2ts/inference.py +++ b/sc2ts/inference.py @@ -503,7 +503,8 @@ def preprocess( strains, date, alignment_store_path, - keep_sites=None, + *, + keep_sites, show_progress=False, num_workers=0, ): diff --git a/tests/test_inference.py b/tests/test_inference.py index 929ce87..afb330f 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -824,7 +824,10 @@ def test_exact_matches( ): ts = fx_ts_map["2020-02-10"] samples = sc2ts.preprocess( - [fx_metadata_db[strain]], ts, "2020-02-20", fx_alignment_store + [strain], + "2020-02-20", + fx_alignment_store.path, + keep_sites=ts.sites_position.astype(int), ) mu, rho = sc2ts.solve_num_mismatches(num_mismatches) matches = sc2ts.match_tsinfer( @@ -860,7 +863,10 @@ def test_one_mismatch( ): ts = fx_ts_map["2020-02-10"] samples = sc2ts.preprocess( - [fx_metadata_db[strain]], ts, "2020-02-20", fx_alignment_store + [strain], + "2020-02-20", + fx_alignment_store.path, + keep_sites=ts.sites_position.astype(int), ) mu, rho = sc2ts.solve_num_mismatches(num_mismatches) matches = sc2ts.match_tsinfer( @@ -889,7 +895,10 @@ def test_two_mismatches( strain = "SRR11597164" ts = fx_ts_map["2020-02-01"] samples = sc2ts.preprocess( - [fx_metadata_db[strain]], ts, "2020-02-20", fx_alignment_store + [strain], + "2020-02-20", + fx_alignment_store.path, + keep_sites=ts.sites_position.astype(int), ) mu, rho = sc2ts.solve_num_mismatches(num_mismatches) matches = sc2ts.match_tsinfer( From 72e92f269b665a3003cb30b8504cc57e3659ca82 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 11 Oct 2024 10:23:53 +0100 Subject: [PATCH 4/5] Simplify preprocess interface --- sc2ts/inference.py | 12 ++++++------ tests/test_inference.py | 6 ------ 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/sc2ts/inference.py b/sc2ts/inference.py index 688922e..cd05191 100644 --- a/sc2ts/inference.py +++ b/sc2ts/inference.py @@ -31,7 +31,7 @@ DELETION = core.ALLELES.index("-") -def get_progress(iterable, date, phase, show_progress, total=None): +def get_progress(iterable, title, phase, show_progress, total=None): bar_format = ( "{desc:<22}{percentage:3.0f}%|{bar}" "| {n_fmt}/{total_fmt} [{elapsed}, {rate_fmt}{postfix}]" @@ -39,7 +39,7 @@ def get_progress(iterable, date, phase, show_progress, total=None): return tqdm.tqdm( iterable, total=total, - desc=f"{date}:{phase}", + desc=f"{title}:{phase}", disable=not show_progress, bar_format=bar_format, dynamic_ncols=True, @@ -56,7 +56,7 @@ def __init__(self, date, phase, *args, **kwargs): def get(self, key, total): self.current_instance = get_progress( - None, self.date, phase=self.phase, show_progress=self.enabled, total=total + None, title=self.date, phase=self.phase, show_progress=self.enabled, total=total ) return self.current_instance @@ -501,10 +501,10 @@ def preprocess_worker(strains, alignment_store_path, keep_sites): def preprocess( strains, - date, alignment_store_path, *, keep_sites, + progress_title="", show_progress=False, num_workers=0, ): @@ -512,7 +512,7 @@ def preprocess( splits = min(len(strains), 2 * num_workers) work = np.array_split(strains, splits) samples = [] - bar = get_progress(strains, date, f"preprocess", show_progress) + bar = get_progress(strains, progress_title, "preprocess", show_progress) with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor: futures = [ executor.submit(preprocess_worker, w, alignment_store_path, keep_sites) @@ -569,9 +569,9 @@ def extend( preprocessed_samples = preprocess( strains=[md["strain"] for md in metadata_matches], - date=date, alignment_store_path=alignment_store.path, keep_sites=base_ts.sites_position.astype(int), + progress_title=date, show_progress=show_progress, num_workers=num_threads, ) diff --git a/tests/test_inference.py b/tests/test_inference.py index afb330f..dc0c456 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -817,7 +817,6 @@ def test_exact_matches( self, fx_ts_map, fx_alignment_store, - fx_metadata_db, strain, parent, num_mismatches, @@ -825,7 +824,6 @@ def test_exact_matches( ts = fx_ts_map["2020-02-10"] samples = sc2ts.preprocess( [strain], - "2020-02-20", fx_alignment_store.path, keep_sites=ts.sites_position.astype(int), ) @@ -854,7 +852,6 @@ def test_one_mismatch( self, fx_ts_map, fx_alignment_store, - fx_metadata_db, strain, parent, position, @@ -864,7 +861,6 @@ def test_one_mismatch( ts = fx_ts_map["2020-02-10"] samples = sc2ts.preprocess( [strain], - "2020-02-20", fx_alignment_store.path, keep_sites=ts.sites_position.astype(int), ) @@ -889,14 +885,12 @@ def test_two_mismatches( self, fx_ts_map, fx_alignment_store, - fx_metadata_db, num_mismatches, ): strain = "SRR11597164" ts = fx_ts_map["2020-02-01"] samples = sc2ts.preprocess( [strain], - "2020-02-20", fx_alignment_store.path, keep_sites=ts.sites_position.astype(int), ) From 3ded54015dbb35b2bcbfaac67273503857470107 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 11 Oct 2024 11:18:53 +0100 Subject: [PATCH 5/5] Finalise and test match UI --- sc2ts/cli.py | 71 +++++++++++++++++++++++++++++++++++++++++----- sc2ts/inference.py | 30 +++++++++++++------- tests/test_cli.py | 71 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 154 insertions(+), 18 deletions(-) diff --git a/sc2ts/cli.py b/sc2ts/cli.py index d7c2351..70c25eb 100644 --- a/sc2ts/cli.py +++ b/sc2ts/cli.py @@ -710,19 +710,68 @@ def annotate_recombinants( tszip.compress(ts, out_tsz) +@dataclasses.dataclass(frozen=True) +class HmmRun: + strain: str + num_mismatches: int + direction: str + match: sc2ts.HmmMatch + + def asdict(self): + d = dataclasses.asdict(self) + d["match"] = dataclasses.asdict(self.match) + return d + + def asjson(self): + return json.dumps(self.asdict()) + + @click.command() @click.argument("alignments", type=click.Path(exists=True, dir_okay=False)) @click.argument("ts", type=click.Path(exists=True, dir_okay=False)) @click.argument("strains", nargs=-1) -@click.option("--num-mismatches", default=3, type=float, help="num-mismatches") +@click.option("--num-mismatches", default=3, type=int, help="num-mismatches") +@click.option( + "--direction", + type=click.Choice(["forward", "reverse"]), + default="forward", + help="Direction to run HMM in", +) +@click.option( + "--num-threads", + default=0, + type=int, + help="Number of match threads (default to one)", +) +@click.option("--progress/--no-progress", default=True) @click.option("-v", "--verbose", count=True) -def run_match(alignments, ts, strains, num_mismatches, verbose): +@click.option("-l", "--log-file", default=None, type=click.Path(dir_okay=False)) +def run_match( + alignments, + ts, + strains, + num_mismatches, + direction, + num_threads, + progress, + verbose, + log_file, +): """ Run matches for a specified set of strains, outputting details to stdout as JSON. """ + setup_logging(verbose, log_file) ts = tszip.load(ts) + if len(strains) == 0: + return + progress_title = "Match" samples = sc2ts.preprocess( - list(strains), "xx", alignments, keep_sites=ts.sites_position.astype(int) + list(strains), + alignments, + show_progress=progress, + progress_title=progress_title, + keep_sites=ts.sites_position.astype(int), + num_workers=num_threads, ) for sample in samples: if sample.haplotype is None: @@ -733,14 +782,22 @@ def run_match(alignments, ts, strains, num_mismatches, verbose): ts=ts, mu=mu, rho=rho, - # num_threads=num_threads, - show_progress=True, + num_threads=num_threads, + show_progress=progress, + progress_title=progress_title, + progress_phase="HMM", # Maximum possible precision likelihood_threshold=1e-200, - # mirror_coordinates=hmm_pass == "reverse", + mirror_coordinates=direction == "reverse", ) for hmm_match, sample in zip(matches, samples): - print(sample.strain, hmm_match.summary()) + run = HmmRun( + strain=sample.strain, + num_mismatches=num_mismatches, + direction=direction, + match=hmm_match, + ) + print(run.asjson()) @click.version_option(core.__version__) diff --git a/sc2ts/inference.py b/sc2ts/inference.py index cd05191..f6e8b11 100644 --- a/sc2ts/inference.py +++ b/sc2ts/inference.py @@ -5,6 +5,7 @@ import dataclasses import collections import concurrent.futures +import json import pickle import hashlib import sqlite3 @@ -56,7 +57,11 @@ def __init__(self, date, phase, *args, **kwargs): def get(self, key, total): self.current_instance = get_progress( - None, title=self.date, phase=self.phase, show_progress=self.enabled, total=total + None, + title=self.date, + phase=self.phase, + show_progress=self.enabled, + total=total, ) return self.current_instance @@ -432,8 +437,8 @@ def match_samples( likelihood_threshold=likelihood_threshold, num_threads=num_threads, show_progress=show_progress, - date=date, - phase=f"match({k})", + progress_title=date, + progress_phase=f"match({k})", ) exceeding_threshold = [] @@ -462,10 +467,9 @@ def match_samples( rho=rho, num_threads=num_threads, show_progress=show_progress, - date=date, - phase=f"match(F)", + progress_title=date, + progress_phase=f"match(F)", ) - recombinants = [] for sample, hmm_match in zip(run_batch, hmm_matches): sample.hmm_match = hmm_match cost = hmm_match.get_hmm_cost(num_mismatches) @@ -1107,8 +1111,8 @@ def match_tsinfer( likelihood_threshold=None, num_threads=0, show_progress=False, - date=None, - phase=None, + progress_title=None, + progress_phase=None, mirror_coordinates=False, ): if len(samples) == 0: @@ -1129,7 +1133,7 @@ def match_tsinfer( # we're interested in solving for exactly. likelihood_threshold = rho**2 * mu**5 - pm = TsinferProgressMonitor(date, phase, enabled=show_progress) + pm = TsinferProgressMonitor(progress_title, progress_phase, enabled=show_progress) # This is just working around tsinfer's input checking logic. The actual value # we're incrementing by has no effect. @@ -1263,6 +1267,7 @@ def mutation_summary(self): return "[" + ", ".join(str(mutation) for mutation in self.mutations) + "]" + def get_match_info(ts, sample_paths, sample_mutations): tables = ts.tables assert np.all(tables.sites.ancestral_state_offset == np.arange(ts.num_sites + 1)) @@ -1289,7 +1294,10 @@ def get_closest_mutation(node, site_id): matches = [] for path, mutations in zip(sample_paths, sample_mutations): - sample_path = [PathSegment(*seg) for seg in path] + sample_path = [ + PathSegment(int(left), int(right), int(parent)) + for left, right, parent in path + ] sample_mutations = [] for site_pos, derived_state in mutations: site_id = np.searchsorted(ts.sites_position, site_pos) @@ -1312,7 +1320,7 @@ def get_closest_mutation(node, site_id): assert inherited_state != derived_state sample_mutations.append( MatchMutation( - site_id=site_id, + site_id=int(site_id), site_position=int(site_pos), derived_state=derived_state, inherited_state=inherited_state, diff --git a/tests/test_cli.py b/tests/test_cli.py index edf1dab..0d7bc64 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -104,6 +104,77 @@ def test_provenance(self, tmp_path): assert "max_memory" in resources +class TestRunMatch: + + def test_single_defaults(self, tmp_path, fx_ts_map, fx_alignment_store): + strain = "ERR4206593" + ts = fx_ts_map["2020-02-04"] + ts_path = tmp_path / "ts.ts" + ts.dump(ts_path) + runner = ct.CliRunner(mix_stderr=False) + result = runner.invoke( + cli.cli, + f"run-match {fx_alignment_store.path} {ts_path} {strain}", + catch_exceptions=False, + ) + assert result.exit_code == 0 + lines = result.stdout.splitlines() + assert len(lines) == 1 + d = json.loads(lines[0]) + assert d["strain"] == strain + assert d["num_mismatches"] == 3 + assert d["direction"] == "forward" + assert len(d["match"]["path"]) == 1 + assert len(d["match"]["mutations"]) == 5 + + def test_multi_defaults(self, tmp_path, fx_ts_map, fx_alignment_store): + copies = 10 + strains = ["ERR4206593"] * 10 + ts = fx_ts_map["2020-02-13"] + ts_path = tmp_path / "ts.ts" + ts.dump(ts_path) + runner = ct.CliRunner(mix_stderr=False) + result = runner.invoke( + cli.cli, + f"run-match {fx_alignment_store.path} {ts_path} " + " ".join(strains), + catch_exceptions=False, + ) + assert result.exit_code == 0 + lines = result.stdout.splitlines() + assert len(lines) == copies + d = json.loads(lines[0]) + assert d["strain"] == strains[0] + assert d["num_mismatches"] == 3 + assert d["direction"] == "forward" + assert len(d["match"]["path"]) == 1 + assert len(d["match"]["mutations"]) == 0 + for line in lines[1:]: + d2 = json.loads(line) + assert d == d2 + + def test_single_options(self, tmp_path, fx_ts_map, fx_alignment_store): + strain = "ERR4206593" + ts = fx_ts_map["2020-02-04"] + ts_path = tmp_path / "ts.ts" + ts.dump(ts_path) + runner = ct.CliRunner(mix_stderr=False) + result = runner.invoke( + cli.cli, + f"run-match {fx_alignment_store.path} {ts_path} {strain}" + " --direction=reverse --num-mismatches=5 --num-threads=4", + catch_exceptions=False, + ) + assert result.exit_code == 0 + lines = result.stdout.splitlines() + assert len(lines) == 1 + d = json.loads(lines[0]) + assert d["strain"] == strain + assert d["num_mismatches"] == 5 + assert d["direction"] == "reverse" + assert len(d["match"]["path"]) == 1 + assert len(d["match"]["mutations"]) == 5 + + class TestListDates: def test_defaults(self, fx_metadata_db): runner = ct.CliRunner(mix_stderr=False)