Skip to content

Commit

Permalink
Merge pull request #360 from jeromekelleher/match-ui
Browse files Browse the repository at this point in the history
Match UI
  • Loading branch information
jeromekelleher authored Oct 11, 2024
2 parents d5b4ccf + 3ded540 commit fe64f87
Show file tree
Hide file tree
Showing 4 changed files with 234 additions and 60 deletions.
91 changes: 91 additions & 0 deletions sc2ts/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,96 @@ def annotate_recombinants(
tszip.compress(ts, out_tsz)


@dataclasses.dataclass(frozen=True)
class HmmRun:
    """Immutable record of a single HMM match run for one strain."""

    strain: str
    num_mismatches: int
    direction: str
    match: sc2ts.HmmMatch

    def asdict(self):
        """Return a plain-dict view, with the nested match also as a dict."""
        # dataclasses.asdict recurses into nested dataclasses; overwriting the
        # "match" entry explicitly keeps the original's exact construction.
        return {**dataclasses.asdict(self), "match": dataclasses.asdict(self.match)}

    def asjson(self):
        """Serialise this run to a JSON string."""
        return json.dumps(self.asdict())


@click.command()
@click.argument("alignments", type=click.Path(exists=True, dir_okay=False))
@click.argument("ts", type=click.Path(exists=True, dir_okay=False))
@click.argument("strains", nargs=-1)
@click.option(
    "--num-mismatches",
    default=3,
    type=int,
    # FIX: help text previously just repeated the option name ("num-mismatches").
    help="Number of mismatches used to parametrise the HMM",
)
@click.option(
    "--direction",
    type=click.Choice(["forward", "reverse"]),
    default="forward",
    help="Direction to run HMM in",
)
@click.option(
    "--num-threads",
    default=0,
    type=int,
    help="Number of match threads (default to one)",
)
@click.option("--progress/--no-progress", default=True)
@click.option("-v", "--verbose", count=True)
@click.option("-l", "--log-file", default=None, type=click.Path(dir_okay=False))
def run_match(
    alignments,
    ts,
    strains,
    num_mismatches,
    direction,
    num_threads,
    progress,
    verbose,
    log_file,
):
    """
    Run matches for a specified set of strains, outputting details to stdout as JSON.

    Each output line is one JSON object with the strain name, the
    num-mismatches setting, the run direction, and the HMM match details.
    Raises ValueError if any requested strain has no stored alignment.
    """
    setup_logging(verbose, log_file)
    ts = tszip.load(ts)
    # Nothing to do when no strains were requested.
    if len(strains) == 0:
        return
    progress_title = "Match"
    samples = sc2ts.preprocess(
        list(strains),
        alignments,
        show_progress=progress,
        progress_title=progress_title,
        keep_sites=ts.sites_position.astype(int),
        num_workers=num_threads,
    )
    # Fail fast if any alignment is missing, before running the (expensive) HMM.
    for sample in samples:
        if sample.haplotype is None:
            raise ValueError(f"No alignment stored for {sample.strain}")
    mu, rho = sc2ts.solve_num_mismatches(num_mismatches)
    matches = sc2ts.match_tsinfer(
        samples=samples,
        ts=ts,
        mu=mu,
        rho=rho,
        num_threads=num_threads,
        show_progress=progress,
        progress_title=progress_title,
        progress_phase="HMM",
        # Maximum possible precision
        likelihood_threshold=1e-200,
        mirror_coordinates=direction == "reverse",
    )
    # One JSON record per strain, written to stdout.
    for hmm_match, sample in zip(matches, samples):
        run = HmmRun(
            strain=sample.strain,
            num_mismatches=num_mismatches,
            direction=direction,
            match=hmm_match,
        )
        print(run.asjson())


@click.version_option(core.__version__)
@click.group()
def cli():
Expand All @@ -730,4 +820,5 @@ def cli():
# Register all subcommands on the top-level click group.
cli.add_command(extend)
cli.add_command(validate)
cli.add_command(annotate_recombinants)
cli.add_command(run_match)
cli.add_command(tally_lineages)
117 changes: 63 additions & 54 deletions sc2ts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import dataclasses
import collections
import concurrent.futures
import json
import pickle
import hashlib
import sqlite3
Expand All @@ -31,15 +32,15 @@
DELETION = core.ALLELES.index("-")


def get_progress(iterable, date, phase, show_progress, total=None):
def get_progress(iterable, title, phase, show_progress, total=None):
bar_format = (
"{desc:<22}{percentage:3.0f}%|{bar}"
"| {n_fmt}/{total_fmt} [{elapsed}, {rate_fmt}{postfix}]"
)
return tqdm.tqdm(
iterable,
total=total,
desc=f"{date}:{phase}",
desc=f"{title}:{phase}",
disable=not show_progress,
bar_format=bar_format,
dynamic_ncols=True,
Expand All @@ -56,7 +57,11 @@ def __init__(self, date, phase, *args, **kwargs):

def get(self, key, total):
    """
    Create and store a progress bar for the current phase.

    The scrape of this span interleaved the pre- and post-change call
    (positional vs keyword arguments), which is not valid Python; this is
    the reconstructed post-merge version using keyword arguments.
    """
    # NOTE(review): ``key`` is accepted but unused here — presumably required
    # by the tsinfer progress-monitor interface; confirm against tsinfer docs.
    self.current_instance = get_progress(
        None,
        title=self.date,
        phase=self.phase,
        show_progress=self.enabled,
        total=total,
    )
    return self.current_instance

Expand Down Expand Up @@ -432,8 +437,8 @@ def match_samples(
likelihood_threshold=likelihood_threshold,
num_threads=num_threads,
show_progress=show_progress,
date=date,
phase=f"match({k})",
progress_title=date,
progress_phase=f"match({k})",
)

exceeding_threshold = []
Expand Down Expand Up @@ -462,10 +467,9 @@ def match_samples(
rho=rho,
num_threads=num_threads,
show_progress=show_progress,
date=date,
phase=f"match(F)",
progress_title=date,
progress_phase=f"match(F)",
)
recombinants = []
for sample, hmm_match in zip(run_batch, hmm_matches):
sample.hmm_match = hmm_match
cost = hmm_match.get_hmm_cost(num_mismatches)
Expand All @@ -482,14 +486,13 @@ def check_base_ts(ts):
assert len(sc2ts_md["samples_strain"]) == ts.num_samples


def preprocess_worker(samples_md, alignment_store_path, keep_sites):
# print("preprocess worker", samples_md)
def preprocess_worker(strains, alignment_store_path, keep_sites):
assert keep_sites is not None
with alignments.AlignmentStore(alignment_store_path) as alignment_store:
samples = []
for md in samples_md:
strain = md["strain"]
for strain in strains:
alignment = alignment_store.get(strain, None)
sample = Sample(strain, metadata=md)
sample = Sample(strain)
if alignment is not None:
a = alignment[keep_sites]
sample.haplotype = alignments.encode_alignment(a)
Expand All @@ -501,47 +504,28 @@ def preprocess_worker(samples_md, alignment_store_path, keep_sites):


def preprocess(
    strains,
    alignment_store_path,
    *,
    keep_sites,
    progress_title="",
    show_progress=False,
    num_workers=0,
):
    """
    Encode the alignments for the given strains into Sample objects, in parallel.

    The scrape of this span interleaved the removed and added diff lines; this
    is the reconstructed post-merge version.

    :param strains: sequence of strain identifiers to fetch from the store.
    :param alignment_store_path: path to the AlignmentStore; each worker
        process opens it independently.
    :param keep_sites: site positions to retain from each alignment.
    :param progress_title: title shown on the progress bar.
    :param show_progress: whether to display the progress bar.
    :param num_workers: number of worker processes (values < 1 are clamped to 1).
    :return: list of Sample objects; samples with no stored alignment have
        haplotype None (callers are responsible for filtering).
    """
    num_workers = max(1, num_workers)
    # Split the work into more chunks than workers so faster workers can
    # pick up extra chunks as they finish.
    splits = min(len(strains), 2 * num_workers)
    work = np.array_split(strains, splits)
    samples = []

    bar = get_progress(strains, progress_title, "preprocess", show_progress)
    with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor:
        futures = [
            executor.submit(preprocess_worker, w, alignment_store_path, keep_sites)
            for w in work
        ]
        # Collect results as chunks complete; ordering follows completion,
        # not submission.
        for future in concurrent.futures.as_completed(futures):
            for s in future.result():
                bar.update()
                samples.append(s)
    bar.close()
    return samples

Expand Down Expand Up @@ -587,16 +571,37 @@ def extend(

logger.info(f"Got {len(metadata_matches)} metadata matches")

samples = preprocess(
metadata_matches,
base_ts,
date,
alignment_store,
pango_lineage_key="Viridian_pangolin", # TODO parametrise
preprocessed_samples = preprocess(
strains=[md["strain"] for md in metadata_matches],
alignment_store_path=alignment_store.path,
keep_sites=base_ts.sites_position.astype(int),
progress_title=date,
show_progress=show_progress,
max_missing_sites=max_missing_sites,
num_workers=num_threads,
)
# FIXME parametrise
pango_lineage_key = "Viridian_pangolin"

samples = []
for s, md in zip(preprocessed_samples, metadata_matches):
if s.haplotype is None:
logger.debug(f"No alignment stored for {s.strain}")
continue
s.metadata = md
s.pango = md.get(pango_lineage_key, "Unknown")
s.date = date
num_missing_sites = s.num_missing_sites
num_deletion_sites = s.num_deletion_sites
logger.debug(
f"Encoded {s.strain} {s.pango} missing={num_missing_sites} "
f"deletions={num_deletion_sites}"
)
if num_missing_sites <= max_missing_sites:
samples.append(s)
else:
logger.debug(
f"Filter {s.strain}: missing={num_missing_sites} > {max_missing_sites}"
)

if max_daily_samples is not None:
if max_daily_samples < len(samples):
Expand Down Expand Up @@ -1106,8 +1111,8 @@ def match_tsinfer(
likelihood_threshold=None,
num_threads=0,
show_progress=False,
date=None,
phase=None,
progress_title=None,
progress_phase=None,
mirror_coordinates=False,
):
if len(samples) == 0:
Expand All @@ -1128,7 +1133,7 @@ def match_tsinfer(
# we're interested in solving for exactly.
likelihood_threshold = rho**2 * mu**5

pm = TsinferProgressMonitor(date, phase, enabled=show_progress)
pm = TsinferProgressMonitor(progress_title, progress_phase, enabled=show_progress)

# This is just working around tsinfer's input checking logic. The actual value
# we're incrementing by has no effect.
Expand Down Expand Up @@ -1262,6 +1267,7 @@ def mutation_summary(self):
return "[" + ", ".join(str(mutation) for mutation in self.mutations) + "]"



def get_match_info(ts, sample_paths, sample_mutations):
tables = ts.tables
assert np.all(tables.sites.ancestral_state_offset == np.arange(ts.num_sites + 1))
Expand All @@ -1288,7 +1294,10 @@ def get_closest_mutation(node, site_id):

matches = []
for path, mutations in zip(sample_paths, sample_mutations):
sample_path = [PathSegment(*seg) for seg in path]
sample_path = [
PathSegment(int(left), int(right), int(parent))
for left, right, parent in path
]
sample_mutations = []
for site_pos, derived_state in mutations:
site_id = np.searchsorted(ts.sites_position, site_pos)
Expand All @@ -1311,7 +1320,7 @@ def get_closest_mutation(node, site_id):
assert inherited_state != derived_state
sample_mutations.append(
MatchMutation(
site_id=site_id,
site_id=int(site_id),
site_position=int(site_pos),
derived_state=derived_state,
inherited_state=inherited_state,
Expand Down
Loading

0 comments on commit fe64f87

Please sign in to comment.