Skip to content

Commit

Permalink
Merge pull request #351 from hyanwong/viz-subtree
Browse files Browse the repository at this point in the history
Show trees of Pango lineages
  • Loading branch information
jeromekelleher authored Oct 9, 2024
2 parents 6f3d69f + 016ca90 commit bdf0a94
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 1 deletion.
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ dependencies = [
# FIXME
"tsinfer @ git+https://github.com/jeromekelleher/tsinfer.git@experimental-hmm",
"pyfaidx",
"tskit>=0.5.3",
# FIXME - reinstate when 0.5.9 is released
# "tskit>=0.5.9",
"tskit @ git+https://github.com/tskit-dev/tskit.git@main#subdirectory=python",
"tszip",
"pandas",
"numba",
Expand Down
128 changes: 128 additions & 0 deletions sc2ts/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -1543,6 +1543,134 @@ def plot_recombinants_per_day(self):
ax2.set_ylabel("Fraction of samples recombinant")
ax2.set_ylim(0, 0.01)

def plot_pango_lineage_subtree(
self,
lineage,
position=None,
collapse_tracked=None,
remove_clones=None,
*,
pack_untracked_polytomies=True,
time_scale="rank",
title=None,
mutation_labels=None,
size=None,
style="",
**kwargs
):
if position is None:
position = 21563 # pick the start of the spike
if size is None:
size = (1000, 1000)
if remove_clones:
# TODO
raise NotImplementedError("remove_clones not implemented")
# remove mutation times, so they get spaced evenly along a branch
tables = self.ts.dump_tables()
time = tables.mutations.time
time[:] = tskit.UNKNOWN_TIME
tables.mutations.time = time
ts = tables.tree_sequence()
tracked_nodes = self.pango_lineage_samples[lineage]
tree = ts.at(position, tracked_samples=tracked_nodes)
order = np.array(list(tskit.drawing._postorder_tracked_minlex_traversal(
tree, collapse_tracked=collapse_tracked)))
if title is None:
simplified_ts = ts.simplify(order[np.where(ts.nodes_flags[order] & tskit.NODE_IS_SAMPLE)[0]])
num_trees = simplified_ts.num_trees
tree_pos = simplified_ts.at(position).index
title = (
f"Sc2ts genealogy of {len(tracked_nodes)} {lineage} samples "
f"at position {position} (tree {tree_pos}/{num_trees})"
# f" --- file: " # TODO - show filename
)

# Find the actually shown nodes (i.e. if polytomies are packed, we may not
# see some tips. This is copied from tskit.drawing.SvgTree.assign_x_coordinates
shown_nodes = order
if pack_untracked_polytomies:
shown_nodes = []
untracked_children = collections.defaultdict(list)
prev = tree.virtual_root
for u in order:
parent = tree.parent(u)
assert parent != prev
if tree.parent(prev) != u: # is a tip
if tree.num_tracked_samples(u) == 0:
untracked_children[parent].append(u)
else:
shown_nodes.append(u)
else:
if len(untracked_children[u]) == 1:
# If only a single non-focal lineage, we might as well show it
for child in untracked_children[u]:
shown_nodes.append(child)
shown_nodes.append(u)
prev = u

if mutation_labels is None:
mutation_labels = collections.defaultdict(list)
multiple_mutations = []
reverted_mutations = []
use_mutations = np.where(np.isin(ts.mutations_node, shown_nodes))[0]
sites = ts.mutations_site[use_mutations]
for mut_id in use_mutations:
# TODO Viz the recurrent mutations
mut = ts.mutation(mut_id)
site = ts.site(mut.site)
if len(sites == site.id) > 1:
multiple_mutations.append(mut.id)
inherited_state = site.ancestral_state
if mut.parent >= 0:
parent = ts.mutation(mut.parent)
inherited_state = parent.derived_state
parent_inherited_state = site.ancestral_state
if parent.parent >= 0:
parent_inherited_state = ts.mutation(parent.parent).derived_state
if parent_inherited_state == mut.derived_state:
reverted_mutations.append(mut.id)
# Reverse map label name to mutation id, so we can count duplicates
label = f"{inherited_state}{int(site.position)}{mut.derived_state}"
mutation_labels[label].append(mut.id)
# If more than one mutation has the same label, add a prefix with the counts
mutation_labels = {
m_id: label + (f" ({i+1}/{len(ids)})" if len(ids) > 1 else "")
for label, ids in mutation_labels.items()
for i, m_id in enumerate(ids)
}
# some default styles
styles = [
"".join(f".n{u} > .sym {{fill: cyan}}" for u in tracked_nodes),
".lab.summary {font-size: 12px}",
".polytomy {font-size: 10px}",
".mut .lab {font-size: 10px}",
".y-axis .lab {font-size: 12px}",
".mut .lab {fill: darkred} .mut .sym {stroke: darkred} .background path {fill: white}"
]
if len(multiple_mutations) > 0:
lab_css = ", ".join(f".mut.m{m} .lab" for m in multiple_mutations)
sym_css = ", ".join(f".mut.m{m} .sym" for m in multiple_mutations)
styles.append(lab_css + "{fill: red}" + sym_css + "{stroke: red}")
if len(reverted_mutations) > 0:
lab_css = ", ".join(f".mut.m{m} .lab" for m in reverted_mutations)
sym_css = ", ".join(f".mut.m{m} .sym" for m in reverted_mutations)
styles.append(lab_css + "{fill: magenta}" + sym_css + "{stroke: magenta}")

return tree.draw_svg(
time_scale=time_scale,
y_axis=True,
x_axis=False,
title=title,
size=size,
order=order,
mutation_labels=mutation_labels,
all_edge_mutations=True,
symbol_size=4,
pack_untracked_polytomies=pack_untracked_polytomies,
style="".join(styles) + style,
**kwargs,
)

def get_sample_group_info(self, group_id):
samples = []

Expand Down

0 comments on commit bdf0a94

Please sign in to comment.