Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support for dosages in Genotypes classes #266

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions haptools/data/genotypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1543,14 +1543,39 @@ def _num_unique_alleles(self, arr: npt.NDArray):
allele_cts[allele_cts < 2] = 2
return allele_cts

def write(self):
def write(self, dosages: npt.NDArray = None):
"""
Write the variants in this class to PLINK2 files at
:py:attr:`~.GenotypesPLINK.fname`

Parameters
----------
dosages : npt.NDArray[np.float32|np.float64]
A matrix of dosages of shape (num_variants, num_samples) and float dtype
"""
# write the psam and pvar files
self.write_samples()
self.write_variants()
if dosages is not None:
self.log.info("Writing dosages instead of genotypes")
if (dosages.shape[0] != self.data.shape[1]) and (
dosages.shape[1] != self.data.shape[0]
):
raise ValueError(
"Dosage array must be of shape num_variants x num_samples"
)
pv = pgenlib.PvarReader(bytes(str(self.fname.with_suffix(".pvar")), "utf8"))
with pgenlib.PgenWriter(
filename=bytes(str(self.fname), "utf8"),
sample_ct=len(self.samples),
variant_ct=len(self.variants),
allele_ct_limit=pv.get_max_allele_ct(),
nonref_flags=False,
hardcall_phase_present=True,
dosage_present=True,
) as pgen:
pgen.append_dosages_batch(dosages)
return
self.log.debug(f"Transposing genotype matrix of size {self.data.shape}")
# transpose the data b/c pgenwriter expects things in "variant-major" order
# (ie where variants are rows instead of samples)
Expand Down Expand Up @@ -1822,7 +1847,7 @@ def _iterate(
variant.data[:, :2][missing] = np.iinfo(np.uint8).max
yield variant

def write(self):
def write(self, dosages: npt.NDArray = None):
raise NotImplementedError

def write_variants(self):
Expand Down
29 changes: 29 additions & 0 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,35 @@ def test_write_genotypes_biallelic(self):
fname.with_suffix(".pvar").unlink()
fname.unlink()

def test_write_genotypes_biallelic_dosage(self):
gts = self._get_fake_genotypes_plink()

# convert to bool - we should be able to handle both
gts.data = gts.data.astype(np.bool_)

fname = DATADIR / "test_write.pgen"
gts.fname = fname
fake_dosages = gts.data[:, :, :2].sum(axis=2)
gts.write(dosages=fake_dosages.T.copy().astype(np.float32))

new_gts = GenotypesPLINK(fname)
new_gts.read()

# check that everything matches what we expected
np.testing.assert_allclose(
gts.data[:, :, :2].sum(axis=2),
new_gts.data[:, :, :2].sum(axis=2),
)
assert gts.samples == new_gts.samples
for i in range(len(new_gts.variants)):
for col in ("chrom", "pos", "id", "alleles"):
assert gts.variants[col][i] == new_gts.variants[col][i]

# clean up afterwards: delete the files we created
fname.with_suffix(".psam").unlink()
fname.with_suffix(".pvar").unlink()
fname.unlink()

def test_write_multiallelic(self):
# Create fake multi-allelic genotype data to write to the PLINK file
gts_multiallelic = self._get_fake_genotypes_multiallelic()
Expand Down
Loading