Skip to content

Commit

Permalink
Finalise and test match UI
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Oct 11, 2024
1 parent 72e92f2 commit 3ded540
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 18 deletions.
71 changes: 64 additions & 7 deletions sc2ts/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,19 +710,68 @@ def annotate_recombinants(
tszip.compress(ts, out_tsz)


@dataclasses.dataclass(frozen=True)
class HmmRun:
    """
    Immutable record describing one HMM match for a single strain,
    together with the settings it was run with, serialisable as JSON.
    """

    # Name of the strain that was matched.
    strain: str
    # Value of the num-mismatches setting used for the run.
    num_mismatches: int
    # Direction the HMM was run in, as given on the command line.
    direction: str
    # The match produced by the HMM for this strain.
    match: sc2ts.HmmMatch

    def asdict(self):
        """
        Return this run as a dictionary, with the ``match`` entry holding
        the dictionary form of the nested match dataclass.
        """
        fields = dataclasses.asdict(self)
        match_fields = dataclasses.asdict(self.match)
        # Re-inserting under the existing key keeps the original key order.
        return {**fields, "match": match_fields}

    def asjson(self):
        """
        Return the dictionary form of this run encoded as a JSON string.
        """
        payload = self.asdict()
        return json.dumps(payload)


@click.command()
@click.argument("alignments", type=click.Path(exists=True, dir_okay=False))
@click.argument("ts", type=click.Path(exists=True, dir_okay=False))
@click.argument("strains", nargs=-1)
@click.option("--num-mismatches", default=3, type=float, help="num-mismatches")
@click.option("--num-mismatches", default=3, type=int, help="num-mismatches")
@click.option(
"--direction",
type=click.Choice(["forward", "reverse"]),
default="forward",
help="Direction to run HMM in",
)
@click.option(
"--num-threads",
default=0,
type=int,
help="Number of match threads (default to one)",
)
@click.option("--progress/--no-progress", default=True)
@click.option("-v", "--verbose", count=True)
def run_match(alignments, ts, strains, num_mismatches, verbose):
@click.option("-l", "--log-file", default=None, type=click.Path(dir_okay=False))
def run_match(
alignments,
ts,
strains,
num_mismatches,
direction,
num_threads,
progress,
verbose,
log_file,
):
"""
Run matches for a specified set of strains, outputting details to stdout as JSON.
"""
setup_logging(verbose, log_file)
ts = tszip.load(ts)
if len(strains) == 0:
return
progress_title = "Match"
samples = sc2ts.preprocess(
list(strains), "xx", alignments, keep_sites=ts.sites_position.astype(int)
list(strains),
alignments,
show_progress=progress,
progress_title=progress_title,
keep_sites=ts.sites_position.astype(int),
num_workers=num_threads,
)
for sample in samples:
if sample.haplotype is None:
Expand All @@ -733,14 +782,22 @@ def run_match(alignments, ts, strains, num_mismatches, verbose):
ts=ts,
mu=mu,
rho=rho,
# num_threads=num_threads,
show_progress=True,
num_threads=num_threads,
show_progress=progress,
progress_title=progress_title,
progress_phase="HMM",
# Maximum possible precision
likelihood_threshold=1e-200,
# mirror_coordinates=hmm_pass == "reverse",
mirror_coordinates=direction == "reverse",
)
for hmm_match, sample in zip(matches, samples):
print(sample.strain, hmm_match.summary())
run = HmmRun(
strain=sample.strain,
num_mismatches=num_mismatches,
direction=direction,
match=hmm_match,
)
print(run.asjson())


@click.version_option(core.__version__)
Expand Down
30 changes: 19 additions & 11 deletions sc2ts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import dataclasses
import collections
import concurrent.futures
import json
import pickle
import hashlib
import sqlite3
Expand Down Expand Up @@ -56,7 +57,11 @@ def __init__(self, date, phase, *args, **kwargs):

def get(self, key, total):
self.current_instance = get_progress(
None, title=self.date, phase=self.phase, show_progress=self.enabled, total=total
None,
title=self.date,
phase=self.phase,
show_progress=self.enabled,
total=total,
)
return self.current_instance

Expand Down Expand Up @@ -432,8 +437,8 @@ def match_samples(
likelihood_threshold=likelihood_threshold,
num_threads=num_threads,
show_progress=show_progress,
date=date,
phase=f"match({k})",
progress_title=date,
progress_phase=f"match({k})",
)

exceeding_threshold = []
Expand Down Expand Up @@ -462,10 +467,9 @@ def match_samples(
rho=rho,
num_threads=num_threads,
show_progress=show_progress,
date=date,
phase=f"match(F)",
progress_title=date,
progress_phase=f"match(F)",
)
recombinants = []
for sample, hmm_match in zip(run_batch, hmm_matches):
sample.hmm_match = hmm_match
cost = hmm_match.get_hmm_cost(num_mismatches)
Expand Down Expand Up @@ -1107,8 +1111,8 @@ def match_tsinfer(
likelihood_threshold=None,
num_threads=0,
show_progress=False,
date=None,
phase=None,
progress_title=None,
progress_phase=None,
mirror_coordinates=False,
):
if len(samples) == 0:
Expand All @@ -1129,7 +1133,7 @@ def match_tsinfer(
# we're interested in solving for exactly.
likelihood_threshold = rho**2 * mu**5

pm = TsinferProgressMonitor(date, phase, enabled=show_progress)
pm = TsinferProgressMonitor(progress_title, progress_phase, enabled=show_progress)

# This is just working around tsinfer's input checking logic. The actual value
# we're incrementing by has no effect.
Expand Down Expand Up @@ -1263,6 +1267,7 @@ def mutation_summary(self):
return "[" + ", ".join(str(mutation) for mutation in self.mutations) + "]"



def get_match_info(ts, sample_paths, sample_mutations):
tables = ts.tables
assert np.all(tables.sites.ancestral_state_offset == np.arange(ts.num_sites + 1))
Expand All @@ -1289,7 +1294,10 @@ def get_closest_mutation(node, site_id):

matches = []
for path, mutations in zip(sample_paths, sample_mutations):
sample_path = [PathSegment(*seg) for seg in path]
sample_path = [
PathSegment(int(left), int(right), int(parent))
for left, right, parent in path
]
sample_mutations = []
for site_pos, derived_state in mutations:
site_id = np.searchsorted(ts.sites_position, site_pos)
Expand All @@ -1312,7 +1320,7 @@ def get_closest_mutation(node, site_id):
assert inherited_state != derived_state
sample_mutations.append(
MatchMutation(
site_id=site_id,
site_id=int(site_id),
site_position=int(site_pos),
derived_state=derived_state,
inherited_state=inherited_state,
Expand Down
71 changes: 71 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,77 @@ def test_provenance(self, tmp_path):
assert "max_memory" in resources


class TestRunMatch:
    """
    End-to-end tests for the ``run-match`` CLI command, checking the JSON
    records it writes to stdout (one line per requested strain).
    """

    def _run_match(self, tmp_path, ts, alignment_store, strains, *options):
        """
        Dump ``ts`` to a temporary file, invoke ``run-match`` for the given
        strains with any extra command line options, assert a zero exit
        code and return the decoded JSON record from each stdout line.
        """
        ts_path = tmp_path / "ts.ts"
        ts.dump(ts_path)
        runner = ct.CliRunner(mix_stderr=False)
        # Pass the arguments as a list rather than a single string: Click
        # shell-splits string arguments, which would break on paths that
        # contain whitespace or backslashes.
        args = ["run-match", str(alignment_store.path), str(ts_path)]
        args.extend(strains)
        args.extend(options)
        result = runner.invoke(cli.cli, args, catch_exceptions=False)
        assert result.exit_code == 0
        return [json.loads(line) for line in result.stdout.splitlines()]

    def test_single_defaults(self, tmp_path, fx_ts_map, fx_alignment_store):
        strain = "ERR4206593"
        records = self._run_match(
            tmp_path, fx_ts_map["2020-02-04"], fx_alignment_store, [strain]
        )
        assert len(records) == 1
        d = records[0]
        assert d["strain"] == strain
        assert d["num_mismatches"] == 3
        assert d["direction"] == "forward"
        assert len(d["match"]["path"]) == 1
        assert len(d["match"]["mutations"]) == 5

    def test_multi_defaults(self, tmp_path, fx_ts_map, fx_alignment_store):
        copies = 10
        # Derive the list from ``copies`` so the expected line count below
        # cannot drift from the number of strains actually requested.
        strains = ["ERR4206593"] * copies
        records = self._run_match(
            tmp_path, fx_ts_map["2020-02-13"], fx_alignment_store, strains
        )
        assert len(records) == copies
        d = records[0]
        assert d["strain"] == strains[0]
        assert d["num_mismatches"] == 3
        assert d["direction"] == "forward"
        assert len(d["match"]["path"]) == 1
        assert len(d["match"]["mutations"]) == 0
        # The same strain was requested each time, so every record must
        # be identical to the first.
        for d2 in records[1:]:
            assert d == d2

    def test_single_options(self, tmp_path, fx_ts_map, fx_alignment_store):
        strain = "ERR4206593"
        records = self._run_match(
            tmp_path,
            fx_ts_map["2020-02-04"],
            fx_alignment_store,
            [strain],
            "--direction=reverse",
            "--num-mismatches=5",
            "--num-threads=4",
        )
        assert len(records) == 1
        d = records[0]
        assert d["strain"] == strain
        # The non-default settings must be echoed back in the output.
        assert d["num_mismatches"] == 5
        assert d["direction"] == "reverse"
        assert len(d["match"]["path"]) == 1
        assert len(d["match"]["mutations"]) == 5


class TestListDates:
def test_defaults(self, fx_metadata_db):
runner = ct.CliRunner(mix_stderr=False)
Expand Down

0 comments on commit 3ded540

Please sign in to comment.