Commit
fix: 1) improve error in transform when providing an empty hap file and a `--region` and 2) allow for calling `write()` on Genotypes objects without variants (#264)
aryarm authored Dec 11, 2024
1 parent 09a916a commit 4e84178
Showing 12 changed files with 144 additions and 33 deletions.
21 changes: 16 additions & 5 deletions .devcontainer.json
@@ -15,20 +15,31 @@
// "forwardPorts": [],

// Use 'postCreateCommand' to run commands after the container is created.
"postCreateCommand": "mamba env create -n haptools -f dev-env.yml && conda run -n haptools poetry install",
"postCreateCommand": "mamba env create -n haptools -f dev-env.yml && conda run -n haptools poetry config virtualenvs.in-project true && conda run -n haptools poetry install",

// Configure tool-specific properties.
"customizations": {
"vscode": {
"extensions": ["ms-python.python"],
"extensions": [
"ms-python.python",
"ms-python.black-formatter"
],
"settings": {
"python.analysis.typeCheckingMode": "off", // TODO: set to "strict"
"python.condaPath": "/opt/conda/condabin/conda",
"python.defaultInterpreterPath": "/opt/conda/envs/haptools/bin/python",
"python.terminal.activateEnvironment": true,
"python.terminal.activateEnvInCurrentTerminal": true,
"python.venvFolders": ["/home/vscode/.cache/pypoetry/virtualenvs"],
"python.venvPath": "/workspaces/haptools/.venv",
"python.defaultInterpreterPath": "/workspaces/haptools/.venv/bin/python",
"python.testing.pytestArgs": [
"tests"
],
"python.testing.pytestEnabled": true,
"python.testing.unittestEnabled": false,
"terminal.integrated.environmentChangesRelaunch": true,
"terminal.integrated.hideOnStartup": "always"
"editor.defaultFormatter": "ms-python.black-formatter",
"terminal.integrated.hideOnStartup": "always",
"files.eol": "\n"
}
}
}
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
@@ -39,6 +39,7 @@ jobs:
auto-activate-base: false
miniforge-version: latest
use-mamba: true
conda-remove-defaults: "true"

- name: Get Date
id: get-date
2 changes: 1 addition & 1 deletion haptools/clump.py
@@ -6,7 +6,7 @@

import numpy as np

from .data import Genotypes, GenotypesVCF, GenotypesTR, GenotypesPLINKTR
from .data import Genotypes, GenotypesVCF, GenotypesPLINK, GenotypesTR, GenotypesPLINKTR


class Variant:
2 changes: 1 addition & 1 deletion haptools/data/breakpoints.py
@@ -23,7 +23,7 @@
# This tuple lists the haplotype blocks in a sample, one set for each chromosome
# Let's define a type alias, "SampleBlocks", for future use...
SampleBlocks = NewType(
"SampleBlocks", "list[npt.NDArray[HapBlock], npt.NDArray[HapBlock]]]"
"SampleBlocks", "list[npt.NDArray[HapBlock], npt.NDArray[HapBlock]]]" # type: ignore
)


58 changes: 43 additions & 15 deletions haptools/data/genotypes.py
@@ -14,7 +14,7 @@
from cyvcf2 import VCF, Variant

try:
import trtools.utils.tr_harmonizer as trh
import trtools.utils.tr_harmonizer as trh # type: ignore
except ModuleNotFoundError:
from . import tr_harmonizer as trh

@@ -219,7 +219,7 @@ def _variant_arr(self, record: Variant):
Parameters
----------
record: Variant
A cyvcf2.Variant object from which to fetch metadata
A Variant object from which to fetch metadata
Returns
-------
@@ -231,20 +231,20 @@ def _variant_arr(self, record: Variant):
dtype=self.variants.dtype,
)

def _vcf_iter(self, vcf: cyvcf2.VCF, region: str):
def _vcf_iter(self, vcf: VCF, region: str):
"""
Yield all variants within a region in the VCF file.
Parameters
----------
vcf: VCF
The cyvcf2.VCF object from which to fetch variant records
The VCF object from which to fetch variant records
region : str, optional
See documentation for :py:meth:`~.Genotypes.read`
Returns
-------
vcffile : cyvcf2.VCF
vcffile : VCF
Iterable cyvcf2 instance.
"""
return vcf(region)
@@ -255,8 +255,8 @@ def _return_data(self, variant: Variant):
Parameters
----------
variant: cyvcf2.Variant
A cyvcf2.Variant object from which to fetch genotypes
variant: Variant
A Variant object from which to fetch genotypes
Returns
-------
@@ -274,7 +274,7 @@ def _iterate(self, vcf: VCF, region: str = None, variants: set[str] = None):
Parameters
----------
vcf: VCF
The cyvcf2.VCF object from which to fetch variant records
The VCF object from which to fetch variant records
region : str, optional
See documentation for :py:meth:`~.Genotypes.read`
variants : set[str], optional
@@ -805,7 +805,13 @@ def write(self):
record.samples[sample].phased = self.data[samp_idx, var_idx, 2]
# write the record to a file
vcf.write(record)
vcf.close()
try:
vcf.close()
except OSError as e:
if e.errno == 9 and len(self.variants) == 0:
self.log.warning(f"No variants in {self.fname}.")
else:
raise e


class TRRecordHarmonizerRegion(trh.TRRecordHarmonizer):
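
The guard above means that calling write() on a GenotypesVCF object with no variants now just logs a warning when cyvcf2 fails to close the output file, instead of raising. A minimal sketch of the new behavior, mirroring the test_write_empty test added further down (the output path is a placeholder):

    from pathlib import Path
    import numpy as np
    from haptools.data import GenotypesVCF

    gts = GenotypesVCF(fname=Path("empty.vcf"))  # placeholder output path
    gts.samples = ()
    gts.variants = np.array([], dtype=gts.variants.dtype)
    gts.data = np.empty((0, 0, 0), dtype=np.uint8)
    gts.write()  # previously surfaced as an OSError from vcf.close(); now warns "No variants in empty.vcf."
    gts.read()   # the empty file can still be read back
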
@@ -909,14 +915,14 @@ def load(
genotypes.check_phase()
return genotypes

def _vcf_iter(self, vcf: cyvcf2.VCF, region: str = None):
def _vcf_iter(self, vcf: VCF, region: str = None):
"""
Collect GTs (trh.TRRecord objects) to iterate over
Parameters
----------
vcf: VCF
The cyvcf2.VCF object from which to fetch variant records
The VCF object from which to fetch variant records
region : str, optional
See documentation for :py:meth:`~.Genotypes.read`
@@ -1066,7 +1072,7 @@ def read_samples(self, samples: set[str] = None):
self.samples = {
ct: samp[col_idx]
for ct, samp in enumerate(psamples)
if (samples is None) or (samp[col_idx] in samples)
if len(samp) and ((samples is None) or (samp[col_idx] in samples))
}
indices = np.array(list(self.samples.keys()), dtype=np.uint32)
self.samples = tuple(self.samples.values())
@@ -1289,7 +1295,18 @@ def read(
super(Genotypes, self).read()

sample_idxs = self.read_samples(samples)
pv = pgenlib.PvarReader(bytes(str(self.fname.with_suffix(".pvar")), "utf8"))
pvar_fname = bytes(str(self.fname.with_suffix(".pvar")), "utf8")
try:
pv = pgenlib.PvarReader(pvar_fname)
except RuntimeError as e:
if e.args[0].decode("utf8").startswith("No variants in"):
self.log.warning(f"No variants in {pvar_fname}.")
self.data = np.empty(
(len(sample_idxs), 0, (2 + (not self._prephased))), dtype=np.uint8
)
return
else:
raise e

with pgenlib.PgenReader(
bytes(str(self.fname), "utf8"), sample_subset=sample_idxs, pvar=pv
@@ -1544,12 +1561,23 @@ def write(self):
chunks = len(self.variants)

# write the pgen file
pv = pgenlib.PvarReader(bytes(str(self.fname.with_suffix(".pvar")), "utf8"))
try:
max_allele_ct = pgenlib.PvarReader(
bytes(str(self.fname.with_suffix(".pvar")), "utf8")
).get_max_allele_ct()
except RuntimeError as e:
if len(self.variants) == 0:
# write an empty pgen file
with open(self.fname, "wb"):
pass
return
else:
raise e
with pgenlib.PgenWriter(
filename=bytes(str(self.fname), "utf8"),
sample_ct=len(self.samples),
variant_ct=len(self.variants),
allele_ct_limit=pv.get_max_allele_ct(),
allele_ct_limit=max_allele_ct,
nonref_flags=False,
hardcall_phase_present=True,
) as pgen:
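
Together, the two PLINK changes above let an empty fileset round-trip: write() produces an empty .pgen rather than failing inside PvarReader, and read() warns and returns zero variants. A minimal sketch, mirroring the new test_write_genotypes_empty test below (the path is a placeholder):

    from pathlib import Path
    import numpy as np
    from haptools.data import GenotypesPLINK

    gts = GenotypesPLINK(fname=Path("empty.pgen"))  # placeholder path
    gts.samples = ()
    gts.variants = np.empty(0, dtype=gts.variants.dtype)
    gts.data = np.empty((0, 0, 0), dtype=np.uint8)
    gts.write()  # writes an empty .pgen instead of raising a RuntimeError
    gts.read()   # logs "No variants in ..." and leaves gts.data with 0 variants
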
20 changes: 12 additions & 8 deletions haptools/data/haplotypes.py
@@ -461,7 +461,7 @@ def extras_order(cls) -> tuple[str]:
"""
return tuple(extra.name for extra in cls._extras)

def transform(self, genotypes: GenotypesVCF) -> npt.NDArray[bool]:
def transform(self, genotypes: GenotypesVCF) -> npt.NDArray:
"""
Transform a genotypes matrix via the current haplotype
@@ -478,9 +478,9 @@ def transform(self, genotypes: GenotypesVCF) -> npt.NDArray[bool]:
Returns
-------
npt.NDArray[bool]
A 2D matrix of shape (num_samples, 2) where each entry in the matrix
denotes the presence of the haplotype in one chromosome of a sample
npt.NDArray
A 2D matrix of shape (num_samples, 2) where each entry in the matrix is a
bool denoting the presence of the haplotype in one chromosome of a sample
"""
var_IDs = self.varIDs
# ensure the variants in the Genotypes object are ordered according to var_IDs
@@ -1198,12 +1198,16 @@ def __iter__(
indexed = True
try:
haps_file = TabixFile(str(self.fname))
if region is not None:
haps_file.fetch(region=region, multiple_iterators=True)
except OSError:
indexed = False
# if the user requested a specific region or subset of haplotypes and the file
except ValueError:
indexed = False
# If the user requested a specific region or subset of haplotypes and the file
# is indexed, then we should handle it using tabix
# else, we use a regular text opener - b/c there's no benefit to using tabix
if region or (haplotypes and indexed):
if (region or haplotypes) and indexed:
haps_file = TabixFile(str(self.fname))
metas, extras = self.check_header(list(haps_file.header))
types = self._get_field_types(extras, metas.get("order"))
@@ -1232,8 +1236,8 @@ def __iter__(
)
haps_file.close()
else:
# the file is not indexed, so we can't assume it's sorted, either
# use hook_compressed to automatically handle gz files
# The file is not indexed, so we can't assume it's sorted, either
# Use hook_compressed to automatically handle gz files
with self.hook_compressed(self.fname, mode="r") as haps:
self.log.info("Not taking advantage of indexing.")
header_lines = []
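
The reworked index detection above probes the tabix index with an explicit fetch() before committing to it, so a .hap file whose index cannot actually serve the requested region (for instance, a file with no records) now falls back to the plain-text reader instead of erroring out. A hedged sketch of the usual region query, assuming Haplotypes.read() forwards the same region argument as __iter__ and that the bgzipped, tabix-indexed file exists:

    from haptools.data import Haplotypes

    hp = Haplotypes("haplotypes.hap.gz")  # placeholder: a bgzipped + indexed .hap file
    hp.read(region="1:10116-10122")       # served via tabix when the index is usable
    print(len(hp.data))                   # 0 if no haplotypes fall within the region
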
1 change: 1 addition & 0 deletions haptools/data/tr_harmonizer.py
@@ -3,6 +3,7 @@
Handles VCFs generated by various TR genotyping tools
"""
import re
import math
import enum
import warnings
from typing import (
2 changes: 2 additions & 0 deletions haptools/transform.py
@@ -593,6 +593,8 @@ def transform_haps(
"that the IDs in your .hap file correspond with those you provided. "
f"Here are the first few missing haplotypes: {diff[:first_few]}"
)
if len(hp.data) == 0:
raise ValueError("Didn't load any haplotypes from the .hap file")

log.info("Extracting variants from haplotypes")
variants = {vr.id for id in hp.type_ids["H"] for vr in hp.data[id].variants}
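
With this check, transform fails fast with a clear error when the .hap file yields no haplotypes (for example, an empty file, or a --region the file cannot serve), rather than proceeding with nothing to transform. A minimal CLI-level sketch, assuming the click entry point main in haptools.__main__ and placeholder input paths (this is essentially what the new test_transform_empty_hap below exercises):

    from click.testing import CliRunner
    from haptools.__main__ import main

    runner = CliRunner()
    cmd = "transform --region 1:10116-10122 genotypes.vcf.gz empty.hap"  # placeholder files
    result = runner.invoke(main, cmd.split(" "))
    assert result.exit_code != 0  # ValueError: Didn't load any haplotypes from the .hap file
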
6 changes: 3 additions & 3 deletions noxfile.py
@@ -3,9 +3,9 @@
import shutil
from pathlib import Path

import nox
from nox_poetry import Session
from nox_poetry import session
import nox # type: ignore
from nox_poetry import Session # type: ignore
from nox_poetry import session # type: ignore


package = "haptools"
1 change: 1 addition & 0 deletions tests/bench_transform.py
@@ -7,6 +7,7 @@
from datetime import datetime

import click
import matplotlib
import numpy as np
import matplotlib.pyplot as plt

24 changes: 24 additions & 0 deletions tests/test_data.py
@@ -569,6 +569,20 @@ def test_write_genotypes(self):
fname.with_suffix(".pvar").unlink()
fname.unlink()

def test_write_genotypes_empty(self):
fname = DATADIR / "test_write.pgen"
gts = GenotypesPLINK(fname=fname)
gts.data = np.empty((0, 0, 0), dtype=np.uint8)
gts.samples = ()
gts.variants = np.empty(0, dtype=gts.variants.dtype)
gts.write()
gts.read()

# clean up afterwards: delete the files we created
fname.with_suffix(".psam").unlink()
fname.with_suffix(".pvar").unlink()
fname.unlink()

def test_write_genotypes_prephased(self):
gts = self._get_fake_genotypes_plink()

@@ -1851,6 +1865,16 @@ def test_write_ref_alt(self, multiallelic=False):

fname.unlink()

def test_write_empty(self):
fname = Path("test.vcf")
gts = GenotypesVCF(fname=fname)
gts.samples = ()
gts.variants = np.array([], dtype=gts.variants.dtype)
gts.data = np.empty((0, 0, 0), dtype=np.uint8)
gts.write()
gts.read()
fname.unlink()

def test_write_multiallelic(self):
self.test_write_ref_alt(multiallelic=True)

39 changes: 39 additions & 0 deletions tests/test_transform.py
@@ -371,3 +371,42 @@ def test_ancestry_from_bp(capfd):
captured = capfd.readouterr()
assert captured.out == ancestry_results
assert result.exit_code == 0


def test_transform_empty_hap(capfd):
gt_file = DATADIR / "simple.vcf.gz"
hp_file = Path("empty.hap")
hp_file_gz = Path("empty.hap.gz")
hp_file_idx = Path("empty.hap.gz.tbi")

# create an empty .hap file
with open(hp_file, "w") as f:
f.write("")

# can we run transform with the empty hap file?
cmd = f"transform --region 1:10116-10122 {gt_file} {hp_file}"
runner = CliRunner()
result = runner.invoke(main, cmd.split(" "))
captured = capfd.readouterr()
assert all(line for line in captured.out.split("\n") if line.startswith("#"))
assert result.exit_code != 0

# now, index the empty hap file and try again
cmd = f"index {hp_file}"
runner = CliRunner()
result = runner.invoke(main, cmd.split(" "), catch_exceptions=False)
captured = capfd.readouterr()
assert result.exit_code == 0
assert hp_file_gz.exists()
assert hp_file_idx.exists()

# what about now? does it still fail?
cmd = f"transform --region 1:10116-10122 {gt_file} {hp_file_gz}"
runner = CliRunner()
result = runner.invoke(main, cmd.split(" "))
captured = capfd.readouterr()
assert result.exit_code != 0

hp_file.unlink()
hp_file_gz.unlink()
hp_file_idx.unlink()
