Skip to content

Commit

Permalink
Allow amplicon schemes in artic bed format
Browse files Browse the repository at this point in the history
  • Loading branch information
martinghunt committed Jun 11, 2024
1 parent 919be0a commit 141b1bc
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 66 deletions.
76 changes: 55 additions & 21 deletions tests/scheme_id_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,9 @@
]


def test_scheme_init_from_tsv():
tmp_tsv = "test_scheme_init_from_tsv.tsv"
with open(tmp_tsv, "w") as f:
print(*SCHEME_TSV_HEADER_FIELDS, sep="\t", file=f)
print("amp1", "amp1_left", "left", "ACGT", "20", sep="\t", file=f)
print("amp1", "amp1_left_alt", "left", "CAG", "15", sep="\t", file=f)
print("amp1", "amp1_left_alt2", "left", "CA", "15", sep="\t", file=f)
print("amp1", "amp1_right", "right", "AAA", "50", sep="\t", file=f)
print("amp2", "amp2_left", "left", "A", "40", sep="\t", file=f)
print("amp2", "amp2_right", "right", "ATGTT", "90", sep="\t", file=f)
print("amp2", "amp2_right_alt", "right", "GTA", "101", sep="\t", file=f)
print("amp2", "amp2_right_alt2", "right", "GGTA", "100", sep="\t", file=f)

scheme = scheme_id.Scheme(tsv_file=tmp_tsv)
assert scheme.amplicons == [
@pytest.fixture()
def amp_scheme_data():
amps = [
{
"name": "amp1",
"start": 15,
Expand All @@ -53,14 +41,60 @@ def test_scheme_init_from_tsv():
},
},
]
assert scheme.left_starts == {15: 0, 20: 0, 40: 1}
assert scheme.right_ends == {52: 0, 94: 1, 103: 1}
assert scheme.amplicon_name_indexes == {"amp1": 0, "amp2": 1}
mean_amp_length = statistics.mean([52 - 15 + 1, 40 - 15, 103 - 52, 103 - 40 + 1])
assert scheme.mean_amp_length == mean_amp_length
return {
"amplicons": amps,
"left_starts": {15: 0, 20: 0, 40: 1},
"right_ends": {52: 0, 94: 1, 103: 1},
"amplicon_name_indexes": {"amp1": 0, "amp2": 1},
"mean_amp_length": statistics.mean(
[52 - 15 + 1, 40 - 15, 103 - 52, 103 - 40 + 1]
),
}


def test_scheme_init_from_tsv(amp_scheme_data):
tmp_tsv = "test_scheme_init_from_tsv.tsv"
with open(tmp_tsv, "w") as f:
print(*SCHEME_TSV_HEADER_FIELDS, sep="\t", file=f)
print("amp1", "amp1_left", "left", "ACGT", "20", sep="\t", file=f)
print("amp1", "amp1_left_alt", "left", "CAG", "15", sep="\t", file=f)
print("amp1", "amp1_left_alt2", "left", "CA", "15", sep="\t", file=f)
print("amp1", "amp1_right", "right", "AAA", "50", sep="\t", file=f)
print("amp2", "amp2_left", "left", "A", "40", sep="\t", file=f)
print("amp2", "amp2_right", "right", "ATGTT", "90", sep="\t", file=f)
print("amp2", "amp2_right_alt", "right", "GTA", "101", sep="\t", file=f)
print("amp2", "amp2_right_alt2", "right", "GGTA", "100", sep="\t", file=f)

scheme = scheme_id.Scheme(amp_scheme_file=tmp_tsv)
assert scheme.amplicons == amp_scheme_data["amplicons"]
assert scheme.left_starts == amp_scheme_data["left_starts"]
assert scheme.right_ends == amp_scheme_data["right_ends"]
assert scheme.amplicon_name_indexes == amp_scheme_data["amplicon_name_indexes"]
assert scheme.mean_amp_length == amp_scheme_data["mean_amp_length"]
os.unlink(tmp_tsv)


def test_scheme_init_from_bed(amp_scheme_data):
tmp_bed = "test_scheme_init_from_tsv.bed"
with open(tmp_bed, "w") as f:
print("REF", 20, 24, "amp1_LEFT_1", "1", "+", "ACGT", sep="\t", file=f)
print("REF", 15, 18, "amp1_LEFT_alt", "1", "+", "CAG", sep="\t", file=f)
print("REF", 15, 17, "amp1_LEFT_alt2", "1", "+", "CA", sep="\t", file=f)
print("REF", 50, 53, "amp1_RIGHT_1", "1", "+", "AAA", sep="\t", file=f)
print("REF", 40, 41, "amp2_LEFT_1", "1", "+", "A", sep="\t", file=f)
print("REF", 90, 95, "amp2_RIGHT_1", "1", "+", "ATGTT", sep="\t", file=f)
print("REF", 101, 104, "amp2_RIGHT_1", "1", "+", "GTA", sep="\t", file=f)
print("REF", 100, 104, "amp2_RIGHT_2", "1", "+", "GGTA", sep="\t", file=f)

scheme = scheme_id.Scheme(amp_scheme_file=tmp_bed)
assert scheme.amplicons == amp_scheme_data["amplicons"]
assert scheme.left_starts == amp_scheme_data["left_starts"]
assert scheme.right_ends == amp_scheme_data["right_ends"]
assert scheme.amplicon_name_indexes == amp_scheme_data["amplicon_name_indexes"]
assert scheme.mean_amp_length == amp_scheme_data["mean_amp_length"]
os.unlink(tmp_bed)


def test_scheme_init_distance_lists():
ref_length = 14
scheme = scheme_id.Scheme(end_tolerance=0)
Expand Down Expand Up @@ -123,7 +157,7 @@ def test_count_primer_hits():
print("amp1", "amp1_left_alt2", "left", "AAAAAA", "30", sep="\t", file=f)
print("amp1", "amp1_right", "right", "ACGTACG", "50", sep="\t", file=f)
print("amp1", "amp1_right_alt1", "right", "ACGT", "54", sep="\t", file=f)
scheme = scheme_id.Scheme(tsv_file=tmp_tsv)
scheme = scheme_id.Scheme(amp_scheme_file=tmp_tsv)
scheme.init_distance_lists(65)
expect_amplicons = copy.deepcopy(scheme.amplicons)
left_hits = [0] * 65
Expand Down
129 changes: 85 additions & 44 deletions viridian/scheme_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


class Scheme:
def __init__(self, tsv_file=None, end_tolerance=3):
def __init__(self, amp_scheme_file=None, end_tolerance=3):
self.left_starts = {}
self.right_ends = {}
self.left_dists = []
Expand All @@ -35,64 +35,105 @@ def __init__(self, tsv_file=None, end_tolerance=3):
self.end_tolerance = end_tolerance
self.last_amplicon_end = -1

if tsv_file is not None:
if amp_scheme_file is not None:
try:
self.load_from_tsv_file(tsv_file)
self.load_from_file(amp_scheme_file)
except:
raise Exception(f"Error loading primer scheme from TSV file {tsv_file}")
raise Exception(
f"Error loading primer scheme from file {amp_scheme_file}"
)

self.amp_coords = [(a["start"], a["end"]) for a in self.amplicons]
self.amp_coords.sort()
self._calculate_mean_amp_length()

def load_from_tsv_file(self, tsv_file):
def read_tsv_lines(self, tsv_file):
with open(tsv_file) as f:
for d in csv.DictReader(f, delimiter="\t"):
if d["Left_or_right"] not in ["left", "right"]:
raise Exception(
f"Left_or_right column not left or right. Got: {d['Left_or_right']}"
)
yield d

if d["Amplicon_name"] not in self.amplicon_name_indexes:
self.amplicons.append(
{
"name": d["Amplicon_name"],
"start": float("inf"),
"end": -1,
"primers": {"left": [], "right": []},
}
def read_bed_lines(self, bed_file):
with open(bed_file) as f:
for line in f:
fields = line.rstrip().split("\t")
if len(fields) != 7:
raise Exception(
f"Error reading amplicon scheme BED file {bed_file}. Expected 7 columns, but got {len(fields)}:\n{line}"
)
self.amplicon_name_indexes[d["Amplicon_name"]] = (
len(self.amplicons) - 1

d = {}
try:
d["Position"] = int(fields[1])
except:
raise Exception(
f"Error reading amplicon scheme BED file {bed_file}. Could not get amplicon start position from second column:\n{line}"
)

amp_index = self.amplicon_name_indexes[d["Amplicon_name"]]
amp = self.amplicons[amp_index]
primer_start = int(d["Position"])
primer_end = primer_start + len(d["Sequence"]) - 1

if d["Left_or_right"] == "left":
primers = amp["primers"]["left"]
same = [x for x in primers if x[0] == primer_start]
if len(same):
same[0][1] = max(primer_end, same[0][1])
else:
primers.append([primer_start, primer_end])
primers.sort()
self.left_starts[primer_start] = amp_index
amp["start"] = min(amp["start"], primer_start)
elif d["Left_or_right"] == "right":
primers = amp["primers"]["right"]
same = [x for x in primers if x[1] == primer_end]
if len(same):
same[0][0] = min(primer_start, same[0][0])
else:
primers.append([primer_start, primer_end])
primers.sort()

self.right_ends[primer_end] = amp_index
amp["end"] = max(amp["end"], primer_end)
self.last_amplicon_end = max(amp["end"], self.last_amplicon_end)
try:
d["Amplicon_name"], d["Left_or_right"], primer_name = fields[
3
].rsplit("_", maxsplit=2)
except:
raise Exception(
f"Error reading amplicon scheme BED file {bed_file}. Could not get amplicon name, left/right, primer name from column 4:\n{line}"
)

d["Left_or_right"] = d["Left_or_right"].lower()
if d["Left_or_right"] not in ["left", "right"]:
raise Exception(
f"Error reading amplicon scheme BED file {bed_file}. Could not get left/right from column 4:\n{line}"
)

d["Sequence"] = fields[6]
yield d

def load_from_file(self, filename):
read_func = (
self.read_bed_lines if filename.endswith(".bed") else self.read_tsv_lines
)
for d in read_func(filename):
if d["Amplicon_name"] not in self.amplicon_name_indexes:
self.amplicons.append(
{
"name": d["Amplicon_name"],
"start": float("inf"),
"end": -1,
"primers": {"left": [], "right": []},
}
)
self.amplicon_name_indexes[d["Amplicon_name"]] = len(self.amplicons) - 1

amp_index = self.amplicon_name_indexes[d["Amplicon_name"]]
amp = self.amplicons[amp_index]
primer_start = int(d["Position"])
primer_end = primer_start + len(d["Sequence"]) - 1

if d["Left_or_right"] == "left":
primers = amp["primers"]["left"]
same = [x for x in primers if x[0] == primer_start]
if len(same):
same[0][1] = max(primer_end, same[0][1])
else:
primers.append([primer_start, primer_end])
primers.sort()
self.left_starts[primer_start] = amp_index
amp["start"] = min(amp["start"], primer_start)
elif d["Left_or_right"] == "right":
primers = amp["primers"]["right"]
same = [x for x in primers if x[1] == primer_end]
if len(same):
same[0][0] = min(primer_start, same[0][0])
else:
primers.append([primer_start, primer_end])
primers.sort()

self.right_ends[primer_end] = amp_index
amp["end"] = max(amp["end"], primer_end)
self.last_amplicon_end = max(amp["end"], self.last_amplicon_end)

def _calculate_mean_amp_length(self):
left_lengths = [
Expand Down Expand Up @@ -261,7 +302,7 @@ def simulate_reads(
random.seed(42)

with open(outfile, "w") as f:
for (start, end) in self.amp_coords:
for start, end in self.amp_coords:
if read_length is None:
print(f">{start}_{end}", file=f)
print(ref_seq[start : end + 1], file=f)
Expand Down Expand Up @@ -506,7 +547,7 @@ def analyse_bam(
for scheme_name, scheme_tsv in scheme_tsvs.items():
logging.info(f"{LOG_PREFIX} Analysing amplicon scheme {scheme_name}")
logging.debug(f"{LOG_PREFIX} {scheme_name} Load TSV file {scheme_tsv}")
scheme = Scheme(tsv_file=scheme_tsv, end_tolerance=end_tolerance)
scheme = Scheme(amp_scheme_file=scheme_tsv, end_tolerance=end_tolerance)
if scheme.last_amplicon_end > ref_length:
return (
json_dict,
Expand Down
2 changes: 1 addition & 1 deletion viridian/scheme_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def simulate_all_schemes(

for scheme_name, scheme_tsv in amplicon_scheme_name_to_tsv.items():
logging.info(f"Processing scheme {scheme_name}")
scheme = scheme_id.Scheme(tsv_file=scheme_tsv)
scheme = scheme_id.Scheme(amp_scheme_file=scheme_tsv)

for fragment in False, True:
if fragment:
Expand Down

0 comments on commit 141b1bc

Please sign in to comment.