Skip to content

Commit

Permalink
Properly treat blank ancestral allele, and set "N" as the default "un…
Browse files Browse the repository at this point in the history
…known" state

Also document the class
  • Loading branch information
hyanwong committed Sep 9, 2024
1 parent 1d04fb8 commit d4c8f49
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 10 deletions.
34 changes: 32 additions & 2 deletions tests/test_variantdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import json
import sys
import tempfile
import warnings

import msprime
import numcodecs
Expand Down Expand Up @@ -626,12 +627,40 @@ def test_missing_ancestral_allele(tmp_path):
tsinfer.VariantData(str(zarr_path) + ".tmp", "variant_ancestral_allele")


# Test that ancestral alleles deliberately marked as unknown ("N" / "n") do not
# trigger a warning when a VariantData object is built, and that the affected
# sites end up with {"inference_type": "parsimony"} metadata after inference.
# NOTE(review): indentation was lost in this rendering of the diff, so the exact
# extent of the `with` block below cannot be seen here -- confirm against the
# repository before relying on which statements run with warnings as errors.
@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on Windows")
def test_deliberate_ancestral_missingness(tmp_path):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
ds = sgkit.load_dataset(zarr_path)
# Overwrite the ancestral allele of the first two sites with the upper- and
# lower-case forms of the "unknown" marker.
ancestral_allele = ds.variant_ancestral_allele.values
ancestral_allele[0] = "N"
ancestral_allele[1] = "n"
# Save a copy of the dataset without the ancestral allele array under a ".tmp"
# path, then add the modified array back to that copy.
ds = ds.drop_vars(["variant_ancestral_allele"])
sgkit.save_dataset(ds, str(zarr_path) + ".tmp")
tsutil.add_array_to_dataset(
"variant_ancestral_allele",
ancestral_allele,
str(zarr_path) + ".tmp",
["variants"],
)
# The reloaded dataset is bound to `ds` but not referenced again in this test.
ds = sgkit.load_dataset(str(zarr_path) + ".tmp")
# Promote every warning to an exception so that any warning fails the test.
with warnings.catch_warnings():
warnings.simplefilter("error") # No warning raised if AA deliberately missing
sd = tsinfer.VariantData(str(zarr_path) + ".tmp", "variant_ancestral_allele")
inf_ts = tsinfer.infer(sd)
# Walk the inferred and original variants in lockstep: sites 0 and 1 (the ones
# set to "N" / "n" above) must carry parsimony metadata; every other site must
# keep the ancestral state of the original tree sequence.
for i, (inf_var, var) in enumerate(zip(inf_ts.variants(), ts.variants())):
if i in [0, 1]:
assert inf_var.site.metadata == {"inference_type": "parsimony"}
else:
assert inf_var.site.ancestral_state == var.site.ancestral_state


@pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on Windows")
def test_ancestral_missingness(tmp_path):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
ds = sgkit.load_dataset(zarr_path)
ancestral_allele = ds.variant_ancestral_allele.values
ancestral_allele[0] = "N"
ancestral_allele[2] = ""
ancestral_allele[11] = "-"
ancestral_allele[12] = "💩"
ancestral_allele[15] = "💩"
Expand All @@ -646,13 +675,14 @@ def test_ancestral_missingness(tmp_path):
ds = sgkit.load_dataset(str(zarr_path) + ".tmp")
with pytest.warns(
UserWarning,
match=r"not found in the variant_allele array for the 4 [\s\S]*'💩': 2",
match=r"not found in the variant_allele array for the 5 [\s\S]*'💩': 2",
):
sd = tsinfer.VariantData(str(zarr_path) + ".tmp", "variant_ancestral_allele")
inf_ts = tsinfer.infer(sd)
for i, (inf_var, var) in enumerate(zip(inf_ts.variants(), ts.variants())):
if i in [0, 11, 12, 15]:
if i in [0, 2, 11, 12, 15]:
assert inf_var.site.metadata == {"inference_type": "parsimony"}
assert inf_var.site.ancestral_state in var.site.alleles
else:
assert inf_var.site.ancestral_state == var.site.ancestral_state

Expand Down
61 changes: 53 additions & 8 deletions tsinfer/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import sys
import threading
import warnings
from typing import Union # noqa: F401

import attr
import humanize
Expand Down Expand Up @@ -2293,6 +2294,48 @@ def populations(self):


class VariantData(SampleData):
"""
Class representing input variant data used for inference. This is
mostly a thin wrapper for a Zarr dataset storing information in
the VCF Zarr (.vcz) format, plus information specifying the ancestral allele
and (optional) data masks. It then provides various derived properties and
methods for accessing the data in a form suitable for inference.
.. note::
In the VariantData object, "samples" refer to the individuals in the dataset,
each of which can be of arbitrary ploidy. This is in contrast to ``tskit``,
in which each *haploid genome* is treated as a separate "sample". For example
in a diploid dataset, the inferred tree sequence returned at the end of
the inference process will have ``inferred_ts.num_samples`` equal to double
the number returned by ``VariantData.num_samples``.
:param str path: The path to the file containing the input dataset in VCF-Zarr
format.
:param Union(array, str) ancestral_allele: A numpy array of strings specifying
the ancestral alleles used in inference. This must be the same length as
the number of unmasked sites in the dataset. Alternatively, a single string
can be provided, giving the name of an array in the input dataset which contains
the ancestral alleles. Unknown ancestral alleles can be specified using "N".
Any ancestral alleles which do not match any of the known alleles at that site
will be tallied, and a warning will be issued summarizing the unknown ancestral states.
:param Union(array, str) sample_mask: A numpy array of booleans specifying which
samples to mask out (exclude) from the dataset. Alternatively, a string
can be provided, giving the name of an array in the input dataset which contains
the sample mask. If ``None`` (default), all samples are included.
:param Union(array, str) site_mask: A numpy array of booleans specifying which
sites to mask out (exclude) from the dataset. Alternatively, a string
can be provided, giving the name of an array in the input dataset which contains
the site mask. If ``None`` (default), all sites are included.
:param Union(array, str) sites_time: A numpy array of floats specifying the relative
time of occurrence of the mutation to the derived state at each site. This must
be of the same length as the number of unmasked sites. Alternatively, a
string can be provided, giving the name of an array in the input dataset
which contains the site times. If ``None`` (default), the frequency of the
derived allele is used as a proxy for the time of occurrence: this is usually a
reasonable approximation to the relative order of ancestors used for inference.
Time values are ignored for sites not used in inference, such as singletons,
sites with more than two alleles, or sites with an unknown ancestral allele.
"""

FORMAT_NAME = "tsinfer-variant-data"
FORMAT_VERSION = (0, 1)
Expand Down Expand Up @@ -2412,16 +2455,18 @@ def __init__(
)
self._sites_ancestral_allele = self._sites_ancestral_allele.astype(str)
unknown_alleles = collections.Counter()
converted = np.zeros(self.num_sites, dtype=np.int8)
converted = np.full(self.num_sites, -1, dtype=np.int8)
for i, allele in enumerate(self._sites_ancestral_allele):
allele_index = -1
try:
allele_index = np.where(allele == self.sites_alleles[i])[0][0]
except IndexError:
unknown_alleles[allele] += 1
converted[i] = allele_index
if not (allele in {"", "N", "n"}): # All these must represent unknown
try:
converted[i] = np.where(allele == self.sites_alleles[i])[0][0]
continue
except IndexError:
pass
unknown_alleles[allele] += 1
deliberately_unknown = sum([unknown_alleles.get(c, 0) for c in ("N", "n")])
tot = sum(unknown_alleles.values())
if tot > 0:
if tot != deliberately_unknown:
frac_bad = tot / self.num_sites
frac_bad_per_type = [v / self.num_sites for v in unknown_alleles.values()]
summarise_unknown = [
Expand Down

0 comments on commit d4c8f49

Please sign in to comment.