diff --git a/pyproject.toml b/pyproject.toml index 957ed6f..d3cc5a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "rich", "rich-click", "structlog", + "tqdm", "urllib3", "us", "zstandard", diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index 03f8694..81efd8b 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -134,6 +134,8 @@ six==1.16.0 # via python-dateutil structlog==24.4.0 # via cladetime (pyproject.toml) +tqdm==4.67.1 + # via cladetime (pyproject.toml) types-awscrt==0.23.0 # via botocore-stubs types-python-dateutil==2.9.0.20241003 diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 103e96c..9d3b654 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -81,6 +81,8 @@ six==1.16.0 # via python-dateutil structlog==24.4.0 # via cladetime (pyproject.toml) +tqdm==4.67.1 + # via cladetime (pyproject.toml) typing-extensions==4.12.2 # via rich-click tzdata==2024.2 diff --git a/src/cladetime/cladetime.py b/src/cladetime/cladetime.py index ac1243d..9710196 100644 --- a/src/cladetime/cladetime.py +++ b/src/cladetime/cladetime.py @@ -285,8 +285,23 @@ def assign_clades(self, sequence_metadata: pl.LazyFrame, output_file: str | None "Starting clade assignment pipeline", sequence_as_of=self.sequence_as_of, tree_as_of=self.tree_as_of ) + # drop any clade-related columns from sequence_metadata (if any exists, it will be replaced + # by the results of the clade assignment) + sequence_metadata = sequence_metadata.drop( + [ + col + for col in sequence_metadata.collect_schema().names() + if col not in self._config.nextstrain_standard_metadata_fields + ] + ) + + # from the sequence metadata, derive a set of sequence IDs (the "strain") + # column for use when filtering sequences in the .fasta file + logger.info("Collecting sequence IDs from metadata") + ids: set = sequence.get_metadata_ids(sequence_metadata) + 
sequence_count = len(ids) + # if there are no sequences in the filtered metadata, stop the clade assignment - sequence_count = sequence_metadata.select(pl.len()).collect().item() if sequence_count == 0: msg = "Sequence_metadata is empty or missing 'strain' columns \n" "Stopping clade assignment...." warnings.warn( @@ -309,17 +324,6 @@ def assign_clades(self, sequence_metadata: pl.LazyFrame, output_file: str | None category=CladeTimeSequenceWarning, ) - # drop any clade-related columns from sequence_metadata (if any exists, it will be replaced - # by the results of the clade assignment) - sequence_metadata = sequence_metadata.drop( - [ - col - for col in sequence_metadata.collect_schema().names() - if col not in self._config.nextstrain_standard_metadata_fields - ] - ) - - ids = sequence.get_metadata_ids(sequence_metadata) tree = Tree(self.tree_as_of, self.url_sequence) with tempfile.TemporaryDirectory() as tmpdir: diff --git a/src/cladetime/sequence.py b/src/cladetime/sequence.py index 675002e..6c07d1a 100644 --- a/src/cladetime/sequence.py +++ b/src/cladetime/sequence.py @@ -1,5 +1,6 @@ """Functions for retrieving and parsing SARS-CoV-2 virus genome data.""" +import io import lzma import os import re @@ -12,6 +13,7 @@ import requests import structlog import us +import zstandard as zstd from Bio import SeqIO from Bio.SeqIO import FastaIO from requests import Session @@ -25,14 +27,34 @@ @time_function -def _download_from_url(session: Session, url: str, data_path: Path) -> Path: - """Download a file from the specified URL and save it to data_path.""" +def _download_from_url( + session: Session, + url: str, + data_path: Path, +) -> Path: + """Download a file from the specified URL and save it to data_path. 
+ + Parameters + ---------- + session : Session + Requests session for making HTTP requests + url : str + URL of the file to download + data_path : Path + Path where the downloaded file will be saved + + Returns + ------- + Path + Path to the downloaded file + """ parsed_url = urlparse(url) url_filename = os.path.basename(parsed_url.path) - data_path.mkdir(parents=True, exist_ok=True) filename = data_path / url_filename + data_path.mkdir(parents=True, exist_ok=True) + with session.get(url, stream=True) as result: result.raise_for_status() with open(filename, "wb") as f: @@ -386,7 +408,7 @@ def parse_sequence_assignments(df_assignments: pl.DataFrame) -> pl.DataFrame: @time_function -def filter(sequence_ids: set, url_sequence: str, output_path: Path) -> Path: +def filter(sequence_ids: set, url_sequence: str, output_path: Path, stream: bool = True) -> Path: """Filter a fasta file against a specific set of sequences. Download a sequence file (in FASTA format) from Nexstrain, filter @@ -408,31 +430,51 @@ def filter(sequence_ids: set, url_sequence: str, output_path: Path) -> Path: ------- pathlib.Path Full path to the filtered sequence file + + Raises + ------ + ValueError + If url_sequence points to a file that doesn't have a + .zst or .xz extension. 
""" session = _get_session() - # FIXME: validate url_sequence (should be in filename.fasta.xz format) - # alternately, we could expand this function to handle other types - # of compression schemas (ZSTD) or none at all + # If URL doesn't indicate a file compression format used + # by Nextstrain, exit before downloading + parsed_sequence_url = urlparse(url_sequence) + file_extension = Path(parsed_sequence_url.path).suffix.lower() + if file_extension not in [".xz", ".zst"]: + raise ValueError(f"Unsupported compression format: {file_extension}") + filtered_sequence_file = output_path / "sequences_filtered.fasta" - logger.info("Starting sequence file download", url=url_sequence) + logger.info("Downloading sequence file", url=url_sequence) sequence_file = _download_from_url(session, url_sequence, output_path) - logger.info("Sequence file saved", path=sequence_file) - filtered_sequence_file = output_path / "sequences_filtered.fasta" - # create a second fasta file with only those sequences in the metadata - logger.info("Starting sequence filter", sequence_file=sequence_file, filtered_sequence_file=filtered_sequence_file) + logger.info("Starting sequence filter", filtered_sequence_file=filtered_sequence_file) sequence_count = 0 sequence_match_count = 0 + with open(filtered_sequence_file, "w") as fasta_output: - with lzma.open(sequence_file, mode="rt") as handle: - for record in FastaIO.FastaIterator(handle): - sequence_count += 1 - if record.id in sequence_ids: - sequence_match_count += 1 - SeqIO.write(record, fasta_output, "fasta") + if file_extension == ".xz": + with lzma.open(sequence_file, mode="rt") as handle: + for record in FastaIO.FastaIterator(handle): + sequence_count += 1 + if record.id in sequence_ids: + sequence_match_count += 1 + SeqIO.write(record, fasta_output, "fasta") + else: + with open(sequence_file, "rb") as handle: + dctx = zstd.ZstdDecompressor() + with dctx.stream_reader(handle) as reader: + text_stream = io.TextIOWrapper(reader, encoding="utf-8") 
+ for record in FastaIO.FastaIterator(text_stream): + sequence_count += 1 + if record.id in sequence_ids: + sequence_match_count += 1 + SeqIO.write(record, fasta_output, "fasta") + logger.info( "Filtered sequence file saved", num_sequences=sequence_count, diff --git a/src/cladetime/util/config.py b/src/cladetime/util/config.py index b1d5828..de1affd 100644 --- a/src/cladetime/util/config.py +++ b/src/cladetime/util/config.py @@ -18,7 +18,7 @@ class Config: nextstrain_ncov_bucket = "nextstrain-data" nextstrain_ncov_metadata_key = "files/ncov/open/metadata_version.json" nextstrain_genome_metadata_key = "files/ncov/open/metadata.tsv.zst" - nextstrain_genome_sequence_key = "files/ncov/open/sequences.fasta.xz" + nextstrain_genome_sequence_key = "files/ncov/open/sequences.fasta.zst" nextclade_data_url = "https://data.clades.nextstrain.org" nextclade_data_url_version = "v3" nextclade_base_url: str = "https://nextstrain.org/nextclade/sars-cov-2" diff --git a/tests/data/README.md b/tests/data/README.md index 40b6544..190eb48 100644 --- a/tests/data/README.md +++ b/tests/data/README.md @@ -5,7 +5,8 @@ This directory contains test files used by CladeTime's test suite. * `moto_fixture` directory contains files used when recreating Nextstrain/Nextclade data in the moto mocked S3 bucket * `test_metadata.tsv` was used to test `get_clade_list` before that functionality moved to variant-nowcast-hub * `metadata.tsv.xz` and `metadata.tsv.xz` are used to test setting CladeTime's sequence_metadata property. 
-* `test_sequence.xz` is used to test the sequence filter function +* `test_sequences.fasta` isn't used by tests directly, but is the human-readable version of test_sequences.fasta.[xz|zst] below +* `test_sequences.fasta.xz` and `test_sequences.fasta.zst` are used to test the sequence filter function * `test_sequences.fasta`, `test_sequences_fake.fasta`, and `test_nexclade_dataset.zip` are used in Nextclade integration tests * `test_sequences_updated.fasta` is used to test clade assignments with prior reference trees * it contains 3 sequence strains with clade assignments that changed between 2024-08-02 and 2024-11-07 diff --git a/tests/data/test_sequence.xz b/tests/data/test_sequence.xz deleted file mode 100644 index ac15e0b..0000000 Binary files a/tests/data/test_sequence.xz and /dev/null differ diff --git a/tests/data/test_sequences.fasta.xz b/tests/data/test_sequences.fasta.xz new file mode 100644 index 0000000..a6c48f5 Binary files /dev/null and b/tests/data/test_sequences.fasta.xz differ diff --git a/tests/data/test_sequences.fasta.zst b/tests/data/test_sequences.fasta.zst new file mode 100644 index 0000000..84da059 Binary files /dev/null and b/tests/data/test_sequences.fasta.zst differ diff --git a/tests/integration/test_nextclade_integration.py b/tests/integration/test_nextclade_integration.py index 12c199e..bedde6a 100644 --- a/tests/integration/test_nextclade_integration.py +++ b/tests/integration/test_nextclade_integration.py @@ -29,7 +29,7 @@ def test_get_clade_assignments(test_file_path, tmp_path): "USA/NJ-CDC-LC1124615/2024", } - sequence_file = test_file_path / "test_sequences.fasta" + sequence_file = test_file_path / "test_sequences.fasta.xz" nextclade_dataset = test_file_path / "test_nextclade_dataset.zip" # _get_clade_assignments should create the output directory if it doesn't exist output_file = tmp_path / "clade_assignments" / "nextclade_assignments.csv" diff --git a/tests/unit/test_sequence.py b/tests/unit/test_sequence.py index 
e4842de..e494eca 100644 --- a/tests/unit/test_sequence.py +++ b/tests/unit/test_sequence.py @@ -5,6 +5,7 @@ import polars as pl import pytest +import zstandard as zstd from Bio import SeqIO from polars.testing import assert_frame_equal @@ -241,8 +242,9 @@ def test_get_metadata_ids_bad_data(bad_input): assert seq_set == set() -def test_filter(test_file_path, tmpdir): - test_sequence_file = test_file_path / "test_sequence.xz" +@pytest.mark.parametrize("sequence_file", ["test_sequences.fasta.xz", "test_sequences.fasta.zst"]) +def test_filter(test_file_path, tmpdir, sequence_file): + test_sequence_file = test_file_path / sequence_file test_sequence_set = { "USA/MD-MDH-1820/2021", "USA/CA-CDPH-A3000000297958/2023", @@ -252,7 +254,7 @@ def test_filter(test_file_path, tmpdir): } mock_download = MagicMock(return_value=test_sequence_file, name="_download_from_url_mock") with patch("cladetime.sequence._download_from_url", mock_download): - filtered_sequence_file = sequence.filter(test_sequence_set, "http://thisismocked.com", tmpdir) + filtered_sequence_file = sequence.filter(test_sequence_set, f"http://thisismocked/{test_sequence_file}", tmpdir) test_sequence_set.remove("STARFLEET/DS9-DS9-001/2024") actual_headers = [] @@ -262,19 +264,20 @@ def test_filter(test_file_path, tmpdir): assert set(actual_headers) == test_sequence_set -def test_filter_no_sequences(test_file_path, tmpdir): +@pytest.mark.parametrize("sequence_file", ["test_sequences.fasta.xz", "test_sequences.fasta.zst"]) +def test_filter_no_sequences(test_file_path, tmpdir, sequence_file): """Test filter with empty sequence set.""" - test_sequence_file = test_file_path / "test_sequence.xz" + test_sequence_file = test_file_path / sequence_file test_sequence_set = {} mock_download = MagicMock(return_value=test_sequence_file, name="_download_from_url_mock") with patch("cladetime.sequence._download_from_url", mock_download): - filtered_no_sequence = sequence.filter(test_sequence_set, "http://thisismocked.com", tmpdir) + 
filtered_no_sequence = sequence.filter(test_sequence_set, f"http://thisismocked.com/{sequence_file}", tmpdir) contents = filtered_no_sequence.read_text(encoding=None) assert len(contents) == 0 -def test_filter_empty_fasta(tmpdir): +def test_filter_empty_fasta_xz(tmpdir): # sequence file is empty test_sequence_set = {"A", "B", "C", "D"} empty_sequence_file = tmpdir / "empty_sequence_file.xz" @@ -282,11 +285,36 @@ def test_filter_empty_fasta(tmpdir): pass mock_download = MagicMock(return_value=empty_sequence_file, name="_download_from_url_mock") with patch("cladetime.sequence._download_from_url", mock_download): - seq_filtered = sequence.filter(test_sequence_set, "http://thisismocked.com", tmpdir) + seq_filtered = sequence.filter(test_sequence_set, "http://thisismocked.com/mocky.xz", tmpdir) contents = seq_filtered.read_text(encoding=None) assert len(contents) == 0 +def test_filter_empty_fasta_zst(tmpdir): + # sequence file is empty + test_sequence_set = {"A", "B", "C", "D"} + empty_sequence_file = tmpdir / "empty_sequence_file.zst" + + cctx = zstd.ZstdCompressor() + with open(empty_sequence_file, "wb") as f: + with cctx.stream_writer(f) as compressor: + compressor.write(b"") + mock_download = MagicMock(return_value=empty_sequence_file, name="_download_from_url_mock") + with patch("cladetime.sequence._download_from_url", mock_download): + seq_filtered = sequence.filter(test_sequence_set, "http://thisismocked.com/mocky.zst", tmpdir) + contents = seq_filtered.read_text(encoding=None) + assert len(contents) == 0 + + +def test_filter_invalid_fasta_compression(test_file_path, tmpdir): + test_sequence_file = test_file_path / "test_sequences.fasta.xz" + mock_download = MagicMock(return_value=test_sequence_file, name="_download_from_url_mock") + with pytest.raises(ValueError): + # sequence file has an invalid compression format + with patch("cladetime.sequence._download_from_url", mock_download): + sequence.filter(set(), "http://thisismocked.com/mocky.zip", tmpdir) + + def 
test_summarize_clades(): test_metadata = pl.DataFrame( {