Skip to content

Commit

Permalink
adding UTRs for multi-exon transcripts seems to be working
Browse files Browse the repository at this point in the history
  • Loading branch information
KatharinaHoff committed Sep 29, 2023
1 parent d3601ab commit 78d736b
Showing 1 changed file with 106 additions and 50 deletions.
156 changes: 106 additions & 50 deletions scripts/stringtie2utr.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import argparse
import re


def read_gtf(gtf_file):
"""
Reads a GTF file and extracts gene and non-gene features.
Expand All @@ -11,15 +12,18 @@ def read_gtf(gtf_file):
- gtf_file (str): Path to the GTF file.
Returns:
tuple: (non_gene_dict, gene_dict) where:
tuple: (non_gene_dict, gene_dict, tx_to_gene_dict, tx_dict) where:
- non_gene_dict (dict): Dictionary with transcript IDs as keys and lists of non-gene feature lines as values.
- gene_dict (dict): Dictionary with transcript IDs as keys and the corresponding gene line as value.
- gene_dict (dict): Dictionary with gene IDs extracted from the last column of gene feature lines as keys and the corresponding gene line as value.
- tx_to_gene_dict (dict): Dictionary with transcript IDs as keys and the corresponding gene ID as value, extracted from the last column of transcript feature lines.
- tx_dict (dict): Dictionary with transcript IDs as keys and the corresponding transcript line as value. The entire last column is used as the transcript ID.
"""
non_gene_dict = {}
gene_dict = {}
temp_gene_storage = {} # Temporary storage for gene lines

tx_to_gene_dict = {}
tx_dict = {}
transcript_id_pattern = re.compile(r'transcript_id "([^"]+)"')
gene_id_pattern = re.compile(r'gene_id "([^"]+)"')

with open(gtf_file, 'r') as f:
for line in f:
Expand All @@ -28,35 +32,28 @@ def read_gtf(gtf_file):
feature_type = fields[2]
last_field = fields[-1]

transcript_id_match = transcript_id_pattern.search(last_field)

gene_id = None
transcript_id = None

# Extract gene_id and transcript_id irrespective of their order in the last_field
for item in last_field.split(';'):
if "gene_id" in item:
gene_id = item.split(' ')[1].replace('"', '').strip()
if "transcript_id" in item:
transcript_id = transcript_id_match.group(1) if transcript_id_match else None

if feature_type == "gene":
temp_gene_storage[gene_id] = line.strip()
gene_dict[last_field] = line.strip() # Use the entire last field as gene_id
elif feature_type == "transcript":
if transcript_id:
if gene_id in temp_gene_storage:
gene_dict[transcript_id] = temp_gene_storage[gene_id]
else:
# Format 2, no explicit gene line but inferred from the transcript
gene_dict[transcript_id] = line.strip()
non_gene_dict[transcript_id] = [line.strip()]
tx_dict[last_field] = line.strip() # Use the entire last field as transcript_id
else:
transcript_id_match = transcript_id_pattern.search(last_field)
gene_id_match = gene_id_pattern.search(last_field)
gene_id = None
transcript_id = None
if transcript_id_match:
transcript_id = transcript_id_match.group(1)
if gene_id_match:
gene_id = gene_id_match.group(1)
if transcript_id: # Ensure we have a transcript ID
non_gene_dict.setdefault(transcript_id, []).append(line.strip())
if transcript_id not in tx_to_gene_dict:
tx_to_gene_dict[transcript_id] = gene_id

return non_gene_dict, gene_dict, tx_to_gene_dict, tx_dict


return non_gene_dict, gene_dict

import re

def add_intron_features(gtf_dict):
"""
Expand Down Expand Up @@ -262,8 +259,6 @@ def merge_features(braker_gtf, stringtie_gtf, selected_transcripts):
return braker_gtf


import re

def compute_utr_features(braker_gtf):
"""
Compute the UTR features for each transcript in braker_gtf based on strand information.
Expand Down Expand Up @@ -298,23 +293,23 @@ def compute_utr_features(braker_gtf):
# Check for UTR based on strand
if strand == "+":
if start < cds_start:
utr5 = "\t".join(fields[:2] + ["5'UTR"] + fields[3:])
utr5 = "\t".join(fields[:2] + ["five_prime_UTR"] + fields[3:])
utr5 = utr5.replace(str(end), str(cds_start - 1))
utr_features.append(utr5)

if end > cds_end:
utr3 = "\t".join(fields[:2] + ["3'UTR"] + fields[3:])
utr3 = "\t".join(fields[:2] + ["three_prime_UTR"] + fields[3:])
utr3 = utr3.replace(str(start), str(cds_end + 1))
utr_features.append(utr3)

elif strand == "-":
if end > cds_end:
utr5 = "\t".join(fields[:2] + ["5'UTR"] + fields[3:])
utr5 = "\t".join(fields[:2] + ["five_prime_UTR"] + fields[3:])
utr5 = utr5.replace(str(start), str(cds_end + 1))
utr_features.append(utr5)

if start < cds_start:
utr3 = "\t".join(fields[:2] + ["3'UTR"] + fields[3:])
utr3 = "\t".join(fields[:2] + ["three_prime_UTR"] + fields[3:])
utr3 = utr3.replace(str(end), str(cds_start - 1))
utr_features.append(utr3)

Expand All @@ -324,33 +319,92 @@ def compute_utr_features(braker_gtf):
return braker_gtf


def print_gtf(gtf_dict, gene_dict):
def fix_feature_coordinates(gtf_dict, gene_dict, tx_to_gene_dict, tx_dict):
    """
    Widen transcript and gene lines so they span all of their features.

    The feature lists in gtf_dict may have grown (e.g. UTR features were
    added), so the start/end columns of the transcript and gene lines can be
    stale. For every transcript the leftmost start and rightmost end over its
    features are written into its line in tx_dict; afterwards every gene line
    in gene_dict is set to the leftmost start and rightmost end over its
    transcripts.

    Args:
    - gtf_dict (dict): Transcript IDs as keys and lists of tab-separated
      feature lines as values.
    - gene_dict (dict): Gene IDs as keys and the tab-separated gene line as
      value.
    - tx_to_gene_dict (dict): Transcript IDs as keys and the gene ID of the
      transcript as value.
    - tx_dict (dict): Transcript IDs as keys and the tab-separated transcript
      line as value.

    Returns:
    tuple: (gene_dict, tx_dict) with updated coordinates. Both dictionaries
    are modified in place and returned for convenience.

    Transcripts without features or without a transcript line, and genes
    without any usable transcript, are left unchanged instead of raising.
    """
    # 1. Update transcript coordinates based on feature coordinates in gtf_dict
    for tx_id, features in gtf_dict.items():
        # min()/max() raise ValueError on an empty list, and a transcript
        # without a transcript line has nothing to update.
        if not features or tx_id not in tx_dict:
            continue

        starts = []
        ends = []
        for feature in features:
            fields = feature.split('\t')
            starts.append(int(fields[3]))
            ends.append(int(fields[4]))

        tx_fields = tx_dict[tx_id].split('\t')
        tx_fields[3] = str(min(starts))
        tx_fields[4] = str(max(ends))
        tx_dict[tx_id] = '\t'.join(tx_fields)

    # 2. Collect the span of every gene in a single pass over the transcripts
    #    (instead of rescanning tx_to_gene_dict once per gene).
    gene_spans = {}
    for tx_id, gene_id in tx_to_gene_dict.items():
        tx_line = tx_dict.get(tx_id)
        if tx_line is None:
            continue
        tx_fields = tx_line.split('\t')
        tx_start = int(tx_fields[3])
        tx_end = int(tx_fields[4])
        if gene_id in gene_spans:
            gene_start, gene_end = gene_spans[gene_id]
            gene_spans[gene_id] = (min(gene_start, tx_start), max(gene_end, tx_end))
        else:
            gene_spans[gene_id] = (tx_start, tx_end)

    # 3. Write the spans back; genes without transcripts keep their coordinates
    for gene_id, (gene_start, gene_end) in gene_spans.items():
        if gene_id not in gene_dict:
            continue
        gene_fields = gene_dict[gene_id].split('\t')
        gene_fields[3] = str(gene_start)
        gene_fields[4] = str(gene_end)
        gene_dict[gene_id] = '\t'.join(gene_fields)

    return gene_dict, tx_dict


def print_gtf(gtf_dict, gene_dict, tx_to_gene_dict, tx_dict):
"""
Print GTF lines based on gene_dict and gtf_dict.
Args:
- gtf_dict (dict): Dictionary with transcript IDs as keys and lists of GTF feature lines as values.
- gene_dict (dict): Dictionary with transcript IDs as keys and the corresponding gene line as value.
- gtf_dict (dict): Dictionary with transcript IDs as keys and lists of non-gene feature lines as values.
- gene_dict (dict): Dictionary with gene IDs extracted from the last column of gene feature lines as keys and the corresponding gene line as value.
- tx_to_gene_dict (dict): Dictionary with transcript IDs as keys and the corresponding gene ID as value, extracted from the last column of transcript feature lines.
- tx_dict (dict): Dictionary with transcript IDs as keys and the corresponding transcript line as value. The entire last column is used as the transcript ID.
Returns:
None: Prints the GTF lines to stdout.
"""

printed_gene = {}
# Iterate over gene_dict entries
for transcript_id, gene_line in gene_dict.items():
print("I am in a gene")
# Print the gene entry
print(gene_line)

# Retrieve the features and sort them by start position (4th column in GTF)
sorted_features = sorted(gtf_dict.get(transcript_id, []), key=lambda x: int(x.split('\t')[3]))

# Print corresponding transcript and other feature lines from gtf_dict
for tx_id, tx_line in tx_dict.items():
if not tx_id in printed_gene:
print(gene_dict[tx_to_gene_dict[tx_id]])
gene_printed = True
print(tx_line)
sorted_features = sorted(gtf_dict.get(tx_id, []), key=lambda x: int(x.split('\t')[3]))
for feature in sorted_features:
# If the feature is a UTR line, remove the exon_number
if "UTR" in feature:
feature = re.sub(r'exon_number "[0-9]+";', '', feature).strip()
print(feature)
# split line into fields
fields = feature.split("\t")
# build new gtf line
print(fields[0] + "\tstringtie2utr\t", "\t".join(fields[2:8]), "\ttranscript_id \"" + tx_id + "\"; gene_id \"" + tx_to_gene_dict[tx_id] + "\";")
elif "StringTie" not in feature:
print(feature)



Expand All @@ -368,10 +422,10 @@ def main():
stringtie_file = args.stringtie

# read the braker_file
braker_non_gene_dict, braker_gene_dict = read_gtf(braker_file)
braker_non_gene_dict, braker_gene_line_dict, braker_tx_to_gene_dict, braker_tx_dict = read_gtf(braker_file)

# read the stringtie_file
stringtie_non_gene_dict, stringtie_gene_dict = read_gtf(stringtie_file)
stringtie_non_gene_dict, stringtie_gene_line_dict, stringie_tx_to_gene_dict, stringtie_tx_dict = read_gtf(stringtie_file)
# add intron features to the stringtie_non_gene_dict
stringtie_non_gene_dict = add_intron_features(stringtie_non_gene_dict)

Expand All @@ -391,8 +445,10 @@ def main():
braker_gtf = merge_features(braker_non_gene_dict, stringtie_non_gene_dict, final_matching_tx)
# compute UTR features
braker_gtf = compute_utr_features(braker_gtf)
# fix gene and transcript coordinates
braker_gene_line_dict, braker_tx_dict = fix_feature_coordinates(braker_gtf, braker_gene_line_dict, braker_tx_to_gene_dict, braker_tx_dict)
# print the updated braker_gtf
print_gtf(braker_gtf, braker_gene_dict)
print_gtf(braker_gtf, braker_gene_line_dict, braker_tx_to_gene_dict, braker_tx_dict)

if __name__ == "__main__":
main()

0 comments on commit 78d736b

Please sign in to comment.