Skip to content

Commit

Permalink
Fixup local alleles interface
Browse files: browse the repository at this point in the history
- Add the `local_alleles` option to the top-level `convert` function
- Change the `local_alleles` parameters so there is a single source of truth for the default
  • Loading branch information
jeromekelleher committed Jul 11, 2024
1 parent d1e3e09 commit 0768962
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 6 deletions.
3 changes: 3 additions & 0 deletions bio2zarr/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,7 @@ def dencode_finalise(zarr_path, verbose, progress):
@verbose
@progress
@worker_processes
@local_alleles
def convert_vcf(
vcfs,
zarr_path,
Expand All @@ -491,6 +492,7 @@ def convert_vcf(
verbose,
progress,
worker_processes,
local_alleles,
):
"""
Convert input VCF(s) directly to vcfzarr (not recommended for large files).
Expand All @@ -504,6 +506,7 @@ def convert_vcf(
samples_chunk_size=samples_chunk_size,
show_progress=progress,
worker_processes=worker_processes,
local_alleles=local_alleles,
)


Expand Down
8 changes: 5 additions & 3 deletions bio2zarr/vcf2zarr/icf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1083,12 +1083,14 @@ def init(
target_num_partitions=None,
show_progress=False,
compressor=None,
local_alleles,
local_alleles=None,
):
if self.path.exists():
raise ValueError(f"ICF path already exists: {self.path}")
if compressor is None:
compressor = ICF_DEFAULT_COMPRESSOR
if local_alleles is None:
local_alleles = True
vcfs = [pathlib.Path(vcf) for vcf in vcfs]
target_num_partitions = max(target_num_partitions, len(vcfs))

Expand Down Expand Up @@ -1310,7 +1312,7 @@ def explode(
worker_processes=1,
show_progress=False,
compressor=None,
local_alleles=True,
local_alleles=None,
):
writer = IntermediateColumnarFormatWriter(icf_path)
writer.init(
Expand All @@ -1337,7 +1339,7 @@ def explode_init(
worker_processes=1,
show_progress=False,
compressor=None,
local_alleles=True,
local_alleles=None,
):
writer = IntermediateColumnarFormatWriter(icf_path)
return writer.init(
Expand Down
2 changes: 2 additions & 0 deletions bio2zarr/vcf2zarr/vcz.py
Original file line number Diff line number Diff line change
Expand Up @@ -1061,6 +1061,7 @@ def convert(
variants_chunk_size=None,
samples_chunk_size=None,
worker_processes=1,
local_alleles=None,
show_progress=False,
icf_path=None,
):
Expand All @@ -1075,6 +1076,7 @@ def convert(
vcfs,
worker_processes=worker_processes,
show_progress=show_progress,
local_alleles=local_alleles,
)
encode(
icf_path,
Expand Down
10 changes: 9 additions & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@
samples_chunk_size=None,
show_progress=True,
worker_processes=1,
local_alleles=True,
)

DEFAULT_PLINK_CONVERT_ARGS = dict(
variants_chunk_size=None,
samples_chunk_size=None,
show_progress=True,
worker_processes=1,
)


Expand Down Expand Up @@ -621,7 +629,7 @@ def test_convert_plink(self, mocked, progress, flag):
assert result.exit_code == 0
assert len(result.stdout) == 0
assert len(result.stderr) == 0
args = dict(DEFAULT_CONVERT_ARGS)
args = dict(DEFAULT_PLINK_CONVERT_ARGS)
args["show_progress"] = progress
mocked.assert_called_once_with("in", "out", **args)

Expand Down
16 changes: 14 additions & 2 deletions tests/test_vcf_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,10 +455,22 @@ def test_call_LAA(self, ds):
class TestTriploidExample:
data_path = "tests/data/vcf/triploid.vcf.gz"

def test_value_error(self, tmp_path_factory):
@pytest.fixture(scope="class")
def ds(self, tmp_path_factory):
out = tmp_path_factory.mktemp("data") / "example.vcf.zarr"
vcf2zarr.convert([self.data_path], out, local_alleles=False)
return sg.load_dataset(out)

def test_error_with_local_alleles(self, tmp_path_factory):
icf_path = tmp_path_factory.mktemp("data") / "triploid.icf"
with pytest.raises(ValueError, match=re.escape("Cannot handle ploidy = 3")):
vcf2zarr.explode(icf_path, [self.data_path], worker_processes=0)
vcf2zarr.explode(
icf_path, [self.data_path], worker_processes=0, local_alleles=True
)

def test_ok_without_local_alleles(self, ds):
nt.assert_array_equal(ds.call_genotype.values, [[[0, 0, 0]]])
nt.assert_array_equal(ds.call_PL.values, [[[0, 0, 0, 0]]])


class Test1000G2020Example:
Expand Down

0 comments on commit 0768962

Please sign in to comment.