diff --git a/haptools/data/genotypes.py b/haptools/data/genotypes.py
index 4d812d10..f6179f13 100644
--- a/haptools/data/genotypes.py
+++ b/haptools/data/genotypes.py
@@ -1543,14 +1543,48 @@ def _num_unique_alleles(self, arr: npt.NDArray):
         allele_cts[allele_cts < 2] = 2
         return allele_cts
 
-    def write(self):
+    def write(self, dosages: npt.NDArray = None):
         """
         Write the variants in this class to PLINK2 files at
         :py:attr:`~.GenotypesPLINK.fname`
+
+        Parameters
+        ----------
+        dosages : npt.NDArray[np.float32|np.float64], optional
+            A matrix of dosages of shape (num_variants, num_samples) and float dtype
+
+            If provided, the dosages are written to the PGEN file instead of the
+            genotypes stored in this object
+
+        Raises
+        ------
+        ValueError
+            If the dosages are not of shape (num_variants, num_samples)
         """
+        if dosages is not None:
+            # validate before writing anything so that we don't leave files behind
+            # note that *both* dimensions must match, not just one of them
+            if dosages.shape != (len(self.variants), len(self.samples)):
+                raise ValueError(
+                    "Dosage array must be of shape num_variants x num_samples"
+                )
         # write the psam and pvar files
         self.write_samples()
         self.write_variants()
+        if dosages is not None:
+            self.log.info("Writing dosages instead of genotypes")
+            pv = pgenlib.PvarReader(bytes(str(self.fname.with_suffix(".pvar")), "utf8"))
+            with pgenlib.PgenWriter(
+                filename=bytes(str(self.fname), "utf8"),
+                sample_ct=len(self.samples),
+                variant_ct=len(self.variants),
+                allele_ct_limit=pv.get_max_allele_ct(),
+                nonref_flags=False,
+                hardcall_phase_present=True,
+                dosage_present=True,
+            ) as pgen:
+                pgen.append_dosages_batch(dosages)
+            return
         self.log.debug(f"Transposing genotype matrix of size {self.data.shape}")
         # transpose the data b/c pgenwriter expects things in "variant-major" order
         # (ie where variants are rows instead of samples)
@@ -1822,7 +1856,7 @@ def _iterate(
             variant.data[:, :2][missing] = np.iinfo(np.uint8).max
             yield variant
 
-    def write(self):
+    def write(self, dosages: npt.NDArray = None):
         raise NotImplementedError
 
     def write_variants(self):
diff --git a/tests/test_data.py b/tests/test_data.py
index 36377ce3..ec1e39b2 100644
--- a/tests/test_data.py
+++ b/tests/test_data.py
@@ -658,6 +658,35 @@ def test_write_genotypes_biallelic(self):
         fname.with_suffix(".pvar").unlink()
         fname.unlink()
 
+    def test_write_genotypes_biallelic_dosage(self):
+        gts = self._get_fake_genotypes_plink()
+
+        # convert to bool so that each allele is 0/1 and the dosages lie in {0, 1, 2}
+        gts.data = gts.data.astype(np.bool_)
+
+        fname = DATADIR / "test_write.pgen"
+        gts.fname = fname
+        fake_dosages = gts.data[:, :, :2].sum(axis=2)
+        gts.write(dosages=fake_dosages.T.copy().astype(np.float32))
+
+        new_gts = GenotypesPLINK(fname)
+        new_gts.read()
+
+        # check that everything matches what we expected
+        np.testing.assert_allclose(
+            fake_dosages,
+            new_gts.data[:, :, :2].sum(axis=2),
+        )
+        assert gts.samples == new_gts.samples
+        for i in range(len(new_gts.variants)):
+            for col in ("chrom", "pos", "id", "alleles"):
+                assert gts.variants[col][i] == new_gts.variants[col][i]
+
+        # clean up afterwards: delete the files we created
+        fname.with_suffix(".psam").unlink()
+        fname.with_suffix(".pvar").unlink()
+        fname.unlink()
+
     def test_write_multiallelic(self):
         # Create fake multi-allelic genotype data to write to the PLINK file
         gts_multiallelic = self._get_fake_genotypes_multiallelic()