Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Show trees of Pango lineages #351

Merged
merged 2 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ dependencies = [
# FIXME
"tsinfer @ git+https://github.com/jeromekelleher/tsinfer.git@experimental-hmm",
"pyfaidx",
"tskit>=0.5.3",
# FIXME - reinstate when 0.5.9 is released
# "tskit>=0.5.9",
"tskit @ git+https://github.com/tskit-dev/tskit.git@main#subdirectory=python",
"tszip",
"pandas",
"numba",
Expand Down
128 changes: 128 additions & 0 deletions sc2ts/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -1543,6 +1543,134 @@ def plot_recombinants_per_day(self):
ax2.set_ylabel("Fraction of samples recombinant")
ax2.set_ylim(0, 0.01)

def plot_pango_lineage_subtree(
    self,
    lineage,
    position=None,
    collapse_tracked=None,
    remove_clones=None,
    *,
    pack_untracked_polytomies=True,
    time_scale="rank",
    title=None,
    mutation_labels=None,
    size=None,
    style="",
    **kwargs,
):
    """
    Draw the local tree at ``position``, highlighting the samples of ``lineage``.

    The samples listed in ``self.pango_lineage_samples[lineage]`` are used as
    the tracked samples of the tree and are filled in cyan. Mutations on
    the drawn nodes are labelled ``<inherited state><position><derived state>``;
    mutations sharing a site with another drawn mutation are coloured red and
    mutations that restore the state their parent mutation replaced are
    coloured magenta.

    :param lineage: Key into ``self.pango_lineage_samples``.
    :param position: Genomic position of the tree to draw. Defaults to 21563
        (the start of the spike).
    :param collapse_tracked: Passed through to tskit's tracked minlex
        postorder traversal, which determines the node drawing order.
    :param remove_clones: Not implemented; any truthy value raises.
    :param pack_untracked_polytomies: Passed to ``draw_svg``; also used here
        to work out which nodes (and hence mutations) are actually drawn.
    :param time_scale: Passed to ``draw_svg``.
    :param title: Plot title. If None, a title giving the number of tracked
        samples, the position and the tree index is generated.
    :param mutation_labels: Optional mapping of mutation id to label, passed
        to ``draw_svg``. If None, labels are generated as described above.
    :param size: ``(width, height)`` passed to ``draw_svg``; defaults to
        ``(1000, 1000)``.
    :param style: Extra CSS appended after the default styles.
    :param kwargs: Further keyword arguments passed to ``draw_svg``.
    :return: Whatever ``tree.draw_svg`` returns (presumably the SVG drawing
        -- confirm against the tskit documentation).
    :raises NotImplementedError: If ``remove_clones`` is truthy.
    """
    if position is None:
        position = 21563  # pick the start of the spike
    if size is None:
        size = (1000, 1000)
    if remove_clones:
        # TODO
        raise NotImplementedError("remove_clones not implemented")

    # Remove mutation times, so they get spaced evenly along a branch
    tables = self.ts.dump_tables()
    time = tables.mutations.time
    time[:] = tskit.UNKNOWN_TIME
    tables.mutations.time = time
    ts = tables.tree_sequence()

    tracked_nodes = self.pango_lineage_samples[lineage]
    tree = ts.at(position, tracked_samples=tracked_nodes)
    order = np.array(
        list(
            tskit.drawing._postorder_tracked_minlex_traversal(
                tree, collapse_tracked=collapse_tracked
            )
        )
    )

    if title is None:
        # Report the tree index relative to the tree sequence simplified down
        # to the sample nodes that appear in the drawing order.
        is_sample = (ts.nodes_flags[order] & tskit.NODE_IS_SAMPLE) != 0
        simplified_ts = ts.simplify(order[is_sample])
        num_trees = simplified_ts.num_trees
        tree_pos = simplified_ts.at(position).index
        title = (
            f"Sc2ts genealogy of {len(tracked_nodes)} {lineage} samples "
            f"at position {position} (tree {tree_pos}/{num_trees})"
            # f" --- file: "  # TODO - show filename
        )

    # Find the actually shown nodes (i.e. if polytomies are packed, we may not
    # see some tips). This is copied from tskit.drawing.SvgTree.assign_x_coordinates
    shown_nodes = order
    if pack_untracked_polytomies:
        shown_nodes = []
        untracked_children = collections.defaultdict(list)
        prev = tree.virtual_root
        for u in order:
            parent = tree.parent(u)
            assert parent != prev
            if tree.parent(prev) != u:  # is a tip
                if tree.num_tracked_samples(u) == 0:
                    # Untracked tips are held back: they are only drawn if
                    # they turn out to be the sole untracked child.
                    untracked_children[parent].append(u)
                else:
                    shown_nodes.append(u)
            else:
                if len(untracked_children[u]) == 1:
                    # If only a single non-focal lineage, we might as well show it
                    shown_nodes.extend(untracked_children[u])
                shown_nodes.append(u)
            prev = u

    # Only generate labels when the caller did not supply any; the highlight
    # lists below are needed for styling in either case.
    generate_labels = mutation_labels is None
    label_to_ids = collections.defaultdict(list)
    multiple_mutations = []
    reverted_mutations = []
    use_mutations = np.where(np.isin(ts.mutations_node, shown_nodes))[0]
    sites = ts.mutations_site[use_mutations]
    # Number of shown mutations at each site, computed once up front.
    # (Taking len() of the boolean mask ``sites == site.id`` would give the
    # total number of shown mutations, not the number at this site.)
    mutations_per_site = np.bincount(sites, minlength=ts.num_sites)
    for mut_id in use_mutations:
        # TODO Viz the recurrent mutations
        mut = ts.mutation(mut_id)
        site = ts.site(mut.site)
        if mutations_per_site[site.id] > 1:
            multiple_mutations.append(mut.id)
        inherited_state = site.ancestral_state
        if mut.parent >= 0:
            parent = ts.mutation(mut.parent)
            inherited_state = parent.derived_state
            # A reversion puts back the state that the parent mutation replaced
            parent_inherited_state = site.ancestral_state
            if parent.parent >= 0:
                parent_inherited_state = ts.mutation(parent.parent).derived_state
            if parent_inherited_state == mut.derived_state:
                reverted_mutations.append(mut.id)
        if generate_labels:
            # Reverse map label name to mutation id, so we can count duplicates
            label = f"{inherited_state}{int(site.position)}{mut.derived_state}"
            label_to_ids[label].append(mut.id)
    if generate_labels:
        # If more than one mutation has the same label, add a suffix with the counts
        mutation_labels = {
            m_id: label + (f" ({i+1}/{len(ids)})" if len(ids) > 1 else "")
            for label, ids in label_to_ids.items()
            for i, m_id in enumerate(ids)
        }

    # some default styles
    styles = [
        "".join(f".n{u} > .sym {{fill: cyan}}" for u in tracked_nodes),
        ".lab.summary {font-size: 12px}",
        ".polytomy {font-size: 10px}",
        ".mut .lab {font-size: 10px}",
        ".y-axis .lab {font-size: 12px}",
        ".mut .lab {fill: darkred} .mut .sym {stroke: darkred} "
        ".background path {fill: white}",
    ]
    if len(multiple_mutations) > 0:
        lab_css = ", ".join(f".mut.m{m} .lab" for m in multiple_mutations)
        sym_css = ", ".join(f".mut.m{m} .sym" for m in multiple_mutations)
        styles.append(lab_css + "{fill: red}" + sym_css + "{stroke: red}")
    if len(reverted_mutations) > 0:
        # Appended after the "multiple" rule, so magenta wins for mutations
        # that are in both lists.
        lab_css = ", ".join(f".mut.m{m} .lab" for m in reverted_mutations)
        sym_css = ", ".join(f".mut.m{m} .sym" for m in reverted_mutations)
        styles.append(lab_css + "{fill: magenta}" + sym_css + "{stroke: magenta}")

    return tree.draw_svg(
        time_scale=time_scale,
        y_axis=True,
        x_axis=False,
        title=title,
        size=size,
        order=order,
        mutation_labels=mutation_labels,
        all_edge_mutations=True,
        symbol_size=4,
        pack_untracked_polytomies=pack_untracked_polytomies,
        style="".join(styles) + style,
        **kwargs,
    )

def get_sample_group_info(self, group_id):
samples = []

Expand Down