Skip to content

Commit

Permalink
Merge pull request #348 from jeromekelleher/add-mask-options
Browse files Browse the repository at this point in the history
Add mask options
  • Loading branch information
jeromekelleher authored Oct 9, 2024
2 parents de36b11 + ecc24d1 commit 7a04007
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 24 deletions.
42 changes: 35 additions & 7 deletions sc2ts/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def get_resources():
"elapsed_time": wall_time,
"user_time": user_time,
"sys_time": sys_time,
"max_memory": max_mem, # bytes
"max_memory": max_mem, # bytes
}


Expand Down Expand Up @@ -264,24 +264,52 @@ def add_provenance(ts, output_file):
type=click.Path(exists=True, dir_okay=False),
help="File containing the list of additional problematic sites to exclude.",
)
@click.option(
"--mask-flanks",
is_flag=True,
flag_value=True,
help=(
"If true, add the non-genic regions at either end of the genome to "
"problematic sites"
),
)
@click.option(
"--mask-problematic-regions",
is_flag=True,
flag_value=True,
help=(
"If true, add the problematic regions problematic sites"
),
)
@click.option("-v", "--verbose", count=True)
@click.option("-l", "--log-file", default=None, type=click.Path(dir_okay=False))
def initialise(ts, match_db, additional_problematic_sites, verbose, log_file):
def initialise(
ts, match_db, additional_problematic_sites, mask_flanks, mask_problematic_regions, verbose, log_file
):
"""
Initialise a new base tree sequence to begin inference.
"""
setup_logging(verbose, log_file)

additional_problematic = []
additional_problematic = np.array([], dtype=int)
if additional_problematic_sites is not None:
additional_problematic = (
np.loadtxt(additional_problematic_sites, ndmin=1).astype(int).tolist()
)
additional_problematic = np.loadtxt(
additional_problematic_sites, ndmin=1
).astype(int)
logger.info(
f"Excluding additional {len(additional_problematic)} problematic sites"
)
if mask_flanks:
additional_problematic = np.concatenate(
(core.get_flank_coordinates(), additional_problematic)
)
if mask_problematic_regions:
additional_problematic = np.concatenate(
(core.get_problematic_regions(), additional_problematic)
)

base_ts = sc2ts.initial_ts(additional_problematic)
additional_problematic = np.unique(additional_problematic)
base_ts = sc2ts.initial_ts(additional_problematic.tolist())
add_provenance(base_ts, ts)
logger.info(f"New base ts at {ts}")
sc2ts.MatchDb.initialise(match_db)
Expand Down
35 changes: 31 additions & 4 deletions sc2ts/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,35 @@ def get_problematic_regions():
https://github.com/jeromekelleher/sc2ts/issues/231#issuecomment-2401405355
Region: NTD domain
Coords: [21602-22472)
Multiple highly recurrent deleted regions in NTD domain in Spike
https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7971772/
Region: ORF8
https://virological.org/t/repeated-loss-of-orf8-expression-in-circulating-sars-cov-2-lineages/931/1
The 1-based (half-open) coordinates were taken from the UCSC Genome Browser.
"""
return np.concatenate([
np.arange(21602, 22472, dtype=np.int64), # NTD domain in S
np.arange(27939, 28257, dtype=np.int64), # ORF8
])
orf8 = get_gene_coordinates()["ORF8"]
return np.concatenate(
[
np.arange(21602, 22472, dtype=np.int64), # NTD domain in S
np.arange(*orf8, dtype=np.int64),
]
)


def get_flank_coordinates():
"""
Return the coordinates at either end of the genome for masking out.
"""
genes = get_gene_coordinates()
start = genes["ORF1ab"][0]
end = genes["ORF10"][1]
return np.concatenate(
(np.arange(1, start), np.arange(end, REFERENCE_SEQUENCE_LENGTH))
)


@dataclasses.dataclass
Expand Down Expand Up @@ -116,6 +139,10 @@ def get_reference_sequence(as_array=False):


def get_gene_coordinates():
"""
Returns a map of gene name to interval, (start, stop). These are
half-open, left-inclusive, right-exclusive.
"""
global __cached_genes
if __cached_genes is None:
d = {}
Expand Down
22 changes: 11 additions & 11 deletions sc2ts/data/annotation.csv
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
start,end,gene
266,21555,ORF1ab
21563,25384,S
25393,26220,ORF3a
26245,26472,E
26523,27191,M
27202,27387,ORF6
27394,27759,ORF7a
27756,27887,ORF7b
27894,28259,ORF8
28274,29533,N
29558,29674,ORF10
266,21556,ORF1ab
21563,25385,S
25393,26221,ORF3a
26245,26473,E
26523,27192,M
27202,27388,ORF6
27394,27760,ORF7a
27756,27888,ORF7b
27894,28260,ORF8
28274,29534,N
29558,29675,ORF10
2 changes: 1 addition & 1 deletion sc2ts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def initial_ts(additional_problematic_sites=list()):
problematic_sites = set(core.get_problematic_sites()) | set(
additional_problematic_sites
)

logger.info(f"Masking out {len(problematic_sites)} sites (additional={len(additional_problematic_sites)})")
tables = tskit.TableCollection(L)
tables.time_units = core.TIME_UNITS

Expand Down
2 changes: 1 addition & 1 deletion tests/test_alignments.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_in(self, fx_alignment_store):
def test_get_gene_coordinates():
d = core.get_gene_coordinates()
assert len(d) == 11
assert d["S"] == (21563, 25384)
assert d["S"] == (21563, 25385)


class TestEncodeAligment:
Expand Down
32 changes: 32 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,38 @@ def test_additional_problematic_sites(self, tmp_path, additional):
match_db = sc2ts.MatchDb(match_db_path)
assert len(match_db) == 0

def test_mask_flanks(self, tmp_path):
ts_path = tmp_path / "trees.ts"
match_db_path = tmp_path / "match.db"
runner = ct.CliRunner(mix_stderr=False)
result = runner.invoke(
cli.cli,
f"initialise {ts_path} {match_db_path} --mask-flanks",
catch_exceptions=False,
)
assert result.exit_code == 0
ts = tskit.load(ts_path)
sites = ts.metadata["sc2ts"]["additional_problematic_sites"]
# < 266 (leftmost coordinate of ORF1a)
# > 29674 (rightmost coordinate of ORF10)
assert sites == list(range(1, 266)) + list(range(29675, 29904))

def test_mask_problematic_regions(self, tmp_path):
ts_path = tmp_path / "trees.ts"
match_db_path = tmp_path / "match.db"
runner = ct.CliRunner(mix_stderr=False)
result = runner.invoke(
cli.cli,
f"initialise {ts_path} {match_db_path} --mask-problematic-regions",
catch_exceptions=False,
)
assert result.exit_code == 0
ts = tskit.load(ts_path)
sites = ts.metadata["sc2ts"]["additional_problematic_sites"]
# NTD: [21602-22472)
# ORF8: [27894, 28260)
assert sites == list(range(21602, 22472)) + list(range(27894, 28260))

def test_provenance(self, tmp_path):
ts_path = tmp_path / "trees.ts"
match_db_path = tmp_path / "match.db"
Expand Down

0 comments on commit 7a04007

Please sign in to comment.