Dev: Using Bio.AlignIO in updated GREMLIN class
niklases committed Jun 24, 2024
1 parent 897ef1e commit 700bb60
Showing 4 changed files with 77 additions and 71 deletions.
19 changes: 16 additions & 3 deletions pypef/dca/gremlin_inference.py
@@ -72,15 +72,19 @@ def __init__(
optimize=True,
gap_cutoff=0.5,
eff_cutoff=0.8,
opt_iter=100
opt_iter=100,
max_msa_seqs: int | None = 10000,
):
self.char_alphabet = char_alphabet
self.gap_cutoff = gap_cutoff
self.eff_cutoff = eff_cutoff
self.opt_iter = opt_iter
self.max_msa_seqs = max_msa_seqs
self.states = len(self.char_alphabet)
print('self.states', self.states)
self.seqs, _, _ = get_sequences_from_file(alignment)
self.msa_ori = self.get_msa_ori()
print(f'MSA shape: {np.shape(self.msa_ori)}')
self.n_col_ori = self.msa_ori.shape[1]
if wt_seq is not None:
self.wt_seq = wt_seq
@@ -92,14 +96,17 @@ def __init__(
raise SystemError("Length of (provided) wild-type sequence does not match "
"number of MSA columns, i.e., common MSA sequence length.")
self.msa_trimmed, self.v_idx, self.w_idx, self.w_rel_idx, self.gaps = self.filt_gaps(self.msa_ori)
print(f'OLD: {np.shape(self.msa_trimmed)}, {np.shape(self.v_idx)}, {np.shape(self.w_idx)}, {np.shape(self.w_rel_idx)}, {np.shape(self.gaps)}')
self.msa_weights = self.get_eff_msa_weights(self.msa_trimmed)
self.n_eff = np.sum(self.msa_weights)
self.n_row = self.msa_trimmed.shape[0]
self.n_col = self.msa_trimmed.shape[1]
self.v_ini, self.w_ini, self.aa_counts = self.initialize_v_w(remove_gap_entries=False)
print(f'OLD INI SHAPES: {np.shape(self.v_ini)}, {np.shape(self.w_ini)}, {np.shape(self.aa_counts)}')
self.optimize = optimize
if self.optimize:
self.v_opt, self.w_opt = self.run_opt_tf()
print(f'OLD OPT: {np.shape(self.v_opt)}, {np.shape(self.w_opt)}')
self.x_wt = self.collect_encoded_sequences(np.atleast_1d(self.wt_seq))

def a2n_dict(self):
@@ -141,18 +148,24 @@ def get_v_idx_w_idx(self):
def get_msa_ori(self):
"""converts list of sequences to msa"""
msa_ori = []
for seq in self.seqs:
msa_ori.append([self.aa2int(aa.upper()) for aa in seq])
for i, seq in enumerate(self.seqs):
if i < self.max_msa_seqs:
msa_ori.append([self.aa2int(aa.upper()) for aa in seq])
else:
print(f'Reached max. number of MSA sequences ({self.max_msa_seqs})...')
break
msa_ori = np.array(msa_ori)
return msa_ori

def filt_gaps(self, msa_ori):
"""filters alignment to remove gappy positions"""
print('old inner:', np.shape(msa_ori))
tmp = (msa_ori == self.states - 1).astype(float)
non_gaps = np.where(np.sum(tmp.T, -1).T / msa_ori.shape[0] < self.gap_cutoff)[0]

gaps = np.where(np.sum(tmp.T, -1).T / msa_ori.shape[0] >= self.gap_cutoff)[0]
logger.info(f'Gap positions (removed from MSA; 0-indexed):\n{gaps}')
print(f'Gap positions (removed from MSA; 0-indexed):\n{gaps}')
ncol_trimmed = len(non_gaps)
v_idx = non_gaps
w_idx = v_idx[np.stack(np.triu_indices(ncol_trimmed, 1), -1)]
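For reference, `filt_gaps` above keeps only the columns whose gap fraction stays below `gap_cutoff`, with the gap state encoded as the last alphabet index, `states - 1`. A minimal NumPy sketch of the same computation on a toy integer-encoded MSA (equivalent to the `np.sum(tmp.T, -1).T / msa_ori.shape[0]` expression in the diff):

```python
import numpy as np

# Toy integer-encoded MSA: 4 sequences x 5 columns, 21 alphabet states,
# with the last state (20) encoding a gap, as in GREMLIN's char_alphabet.
states = 21
gap_cutoff = 0.5
msa = np.array([
    [0, 20, 3, 4, 20],
    [0, 20, 3, 5,  6],
    [1,  2, 3, 4,  6],
    [0, 20, 3, 4,  6],
])

is_gap = (msa == states - 1).astype(float)     # 1.0 wherever a gap sits
gap_frac = is_gap.sum(axis=0) / msa.shape[0]   # per-column gap fraction
non_gaps = np.where(gap_frac < gap_cutoff)[0]  # columns to keep
gaps = np.where(gap_frac >= gap_cutoff)[0]     # columns to drop
msa_trimmed = msa[:, non_gaps]

print(gaps)               # [1] for this toy MSA (column 1 is 75 % gaps)
print(msa_trimmed.shape)  # (4, 4)
```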
82 changes: 32 additions & 50 deletions pypef/dca/new_gremlin_inference.py
@@ -53,15 +53,14 @@
import pickle
import numpy as np
import matplotlib.pyplot as plt
from Bio import AlignIO
from scipy.spatial.distance import pdist, squareform
from scipy.special import logsumexp
from scipy.stats import boxcox
import pandas as pd
import tensorflow as tf
tf.get_logger().setLevel('DEBUG')

from pypef.utils.variant_data import get_sequences_from_file


class GREMLIN:
"""
@@ -74,18 +73,18 @@ class GREMLIN:
def __init__(
self,
alignment: str | PathLike,
char_alphabet: str = "ACDEFGHIKLMNPQRSTVWY-X",
char_alphabet: str = "ARNDCQEGHILKMFPSTWYV-",
wt_seq=None,
offset=0,
optimize=True,
gap_cutoff=1.0,
eff_cutoff=1.0,
gap_cutoff=0.5,
eff_cutoff=0.8,
opt_iter=100,
max_msa_seqs: int | None = 10000,
seqs=None
):
self.char_alphabet = char_alphabet
self.allowed_chars = "ABCDEFGHIKLMNPQRSTVWYX-."
self.allowed_chars = "ARNDCQEGHILKMFPSTWYV-"
self.allowed_chars += self.allowed_chars.lower()
self.offset = offset
self.gap_cutoff = gap_cutoff
@@ -96,13 +95,16 @@ def __init__(
else:
self.max_msa_seqs = max_msa_seqs
self.states = len(self.char_alphabet)
print('self.states', self.states)
print('Loading MSA...')
if seqs is None:
self.seqs, _ = self.get_sequences_from_msa(alignment)
self.seqs, self.seq_ids = self.get_sequences_from_msa(alignment)
else:
self.seqs = seqs
self.seq_ids = np.array([n for n in range(len(self.seqs))])
print(f'Found {len(self.seqs)} sequences in the MSA...')
self.msa_ori = self.get_msa_ori()
print(f'MSA shape: {np.shape(self.msa_ori)}')
self.n_col_ori = self.msa_ori.shape[1]
if wt_seq is not None:
self.wt_seq = wt_seq
@@ -117,20 +119,24 @@
f"i.e., common MSA sequence length.")
print('Filtering gaps...')
self.msa_trimmed, self.v_idx, self.w_idx, self.w_rel_idx, self.gaps = self.filt_gaps(self.msa_ori)
print(f'NEW: {np.shape(self.msa_trimmed)}, {np.shape(self.v_idx)}, {np.shape(self.w_idx)}, {np.shape(self.w_rel_idx)}, {np.shape(self.gaps)}')
print('Getting effective sequence weights...')
self.msa_weights = self.get_eff_msa_weights(self.msa_trimmed)
self.n_eff = np.sum(self.msa_weights)
self.n_row = self.msa_trimmed.shape[0]
self.n_col = self.msa_trimmed.shape[1]
print('Initializing v and W terms based on MSA frequencies...')
self.v_ini, self.w_ini, self.aa_counts = self.initialize_v_w(remove_gap_entries=False)
print(f'NEW INI SHAPES: {np.shape(self.v_ini)}, {np.shape(self.w_ini)}, {np.shape(self.aa_counts)}')
self.aa_freqs = self.aa_counts / self.n_row
self.optimize = optimize
if self.optimize:
self.v_opt_with_gaps, self.w_opt_with_gaps = self.run_opt_tf()
print(f'NEW OPT: {np.shape(self.v_opt_with_gaps)}, {np.shape(self.w_opt_with_gaps)}')
no_gap_states = self.states - 1
self.v_opt = self.v_opt_with_gaps[:, :no_gap_states]
self.w_opt = self.w_opt_with_gaps[:, :no_gap_states, :, :no_gap_states]
print(f'NEW OPT: {np.shape(self.v_opt)}, {np.shape(self.w_opt)}')
self.x_wt = self.collect_encoded_sequences(np.atleast_1d(self.wt_seq))

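The body of `get_eff_msa_weights` is not part of this diff; given the `pdist`/`squareform` imports above, it presumably applies the usual GREMLIN-style reweighting, where each sequence gets weight one over the number of sequences at least `eff_cutoff` identical to it. A hedged sketch of that idea (an assumption, not the pypef implementation):

```python
import numpy as np
from scipy.spatial.distance import pdist, squareform

def eff_msa_weights(msa: np.ndarray, eff_cutoff: float = 0.8) -> np.ndarray:
    """One-over-cluster-size sequence weights (sketch of the assumed reweighting)."""
    # Pairwise fraction of identical positions between integer-encoded sequences.
    identity = 1.0 - squareform(pdist(msa, metric="hamming"))
    # Cluster size = number of sequences at least eff_cutoff identical (incl. self).
    cluster_sizes = np.sum(identity >= eff_cutoff, axis=1)
    return 1.0 / cluster_sizes

# n_eff is then the sum of the weights, as in __init__:
# n_eff = np.sum(eff_msa_weights(msa_trimmed, 0.8))
```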
def get_sequences_from_msa(self, msa_file: str):
@@ -148,25 +154,16 @@ def get_sequences_from_msa(self, msa_file: str):
Path to MSA in FASTA or A2M format.
"""
sequences = []
names_of_seqs = []
with open(msa_file, 'r') as f:
words = ""
for line in f:
if line.startswith('>'):
if words != "":
sequences.append(words[self.offset:])
words = line.split('>')
names_of_seqs.append(words[1].strip())
words = ""
elif line.startswith('#'):
pass # are comments
else:
line = line.strip()
words += line
if words != "":
sequences.append(words[self.offset:])
assert len(sequences) == len(names_of_seqs), f"{len(sequences)}, {len(names_of_seqs)}"
return np.array(sequences), np.array(names_of_seqs)
seq_ids = []
alignment = AlignIO.read(msa_file, "fasta")
print("Alignment length %i" % alignment.get_alignment_length())
for record in alignment:
#print(record.seq + " " + record.id)
sequences.append(str(record.seq))
seq_ids.append(str(record.id))
assert len(sequences) == len(seq_ids), f"{len(sequences)}, {len(seq_ids)}"
print("SSSS", sequences[0])
return np.array(sequences), np.array(seq_ids)

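`Bio.AlignIO.read` is standard Biopython; as a quick standalone check of the new loading path (the file name below is hypothetical):

```python
from Bio import AlignIO

# Hypothetical FASTA-formatted MSA file.
alignment = AlignIO.read("example_msa.fasta", "fasta")
print(f"{len(alignment)} sequences, alignment length {alignment.get_alignment_length()}")
for record in alignment[:3]:  # first three records
    print(record.id, str(record.seq)[:40])
```

`AlignIO.read` should also raise a `ValueError` for records of unequal length, which makes the explicit inhomogeneous-shape check in the old `get_msa_ori` redundant.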
def a2n_dict(self):
"""convert alphabet to numerical integer values, e.g.:
@@ -211,38 +208,21 @@ def get_msa_ori(self):
Converts list of sequences to MSA.
Also checks for unknown amino acid characters and removes those sequences from the MSA.
"""
n_skipped = 0
msa_ori = []
for i, seq in enumerate(self.seqs):
skip = False
for aa in seq:
if aa not in self.allowed_chars:
if n_skipped == 0:
f"The input file(s) (MSA or train/test sets) contain(s) "
f"unknown protein sequence characters "
f"(e.g.: \"{aa}\" in sequence {i + 1}). "
f"Will remove those sequences from MSA!"
skip = True
if skip:
n_skipped += 1
if i <= self.max_msa_seqs and not skip:
for i, (seq, seq_id) in enumerate(zip(self.seqs, self.seq_ids)):
if i < self.max_msa_seqs:
msa_ori.append([self.aa2int(aa.upper()) for aa in seq])
elif i >= self.max_msa_seqs:
print(f'Reached the number of maximal MSA sequences ({self.max_msa_seqs}), '
'skipping the rest of the MSA sequences...')
else:
print(f'Reached max. number of MSA sequences ({self.max_msa_seqs})...')
break
try:
msa_ori = np.array(msa_ori)
except ValueError:
raise ValueError("The provided MSA seems to have inhomogeneous "
"shape, i.e., unequal sequence length.")
msa_ori = np.array(msa_ori)
return msa_ori

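The `aa2int` encoding used above maps each character of `char_alphabet` to its index, so the gap character `-` (last in the alphabet) becomes `states - 1`, which is what `filt_gaps` below compares against. A minimal sketch of the assumed mapping; the unknown-character fallback to the gap state is an assumption:

```python
char_alphabet = "ARNDCQEGHILKMFPSTWYV-"
states = len(char_alphabet)  # 21

# Presumed a2n mapping: character -> index in the alphabet.
a2n = {aa: i for i, aa in enumerate(char_alphabet)}

def aa2int(aa: str) -> int:
    """Map a residue to its integer state; unknown characters fall back to the gap state."""
    return a2n.get(aa, states - 1)

print(aa2int("A"), aa2int("-"), aa2int("X"))  # 0 20 20
```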
def filt_gaps(self, msa_ori):
"""filters alignment to remove gappy positions"""
print('new inner:', np.shape(msa_ori))
tmp = (msa_ori == self.states - 1).astype(float)
non_gaps = np.where(np.sum(tmp.T, -1).T / msa_ori.shape[0] < self.gap_cutoff)[0]

gaps = np.where(np.sum(tmp.T, -1).T / msa_ori.shape[0] >= self.gap_cutoff)[0]
print(f'Gap positions (removed from MSA; 0-indexed):\n{gaps}')
ncol_trimmed = len(non_gaps)
@@ -444,7 +424,9 @@ def feed(feed_all=False):
# save the v and w parameters of the MRF
v_opt = sess.run(v)
w_opt = sess.run(w)
return v_opt, w_opt
no_gap_states = self.states - 1
return v_opt[:, :no_gap_states], w_opt[:, :no_gap_states, :, :no_gap_states]
#return v_opt, w_opt

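Dropping the last alphabet state from the returned `v_opt` and `w_opt` removes the gap parameters. A quick shape check with dummy arrays, assuming the usual GREMLIN parameter shapes `v: (n_col, states)` and `w: (n_col, states, n_col, states)`:

```python
import numpy as np

n_col, states = 100, 21
v_opt = np.zeros((n_col, states))
w_opt = np.zeros((n_col, states, n_col, states))

no_gap_states = states - 1
v_trim = v_opt[:, :no_gap_states]
w_trim = w_opt[:, :no_gap_states, :, :no_gap_states]

print(v_trim.shape)  # (100, 20)
print(w_trim.shape)  # (100, 20, 100, 20)
```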

def initialize_v_w(self, remove_gap_entries=True):
9 changes: 5 additions & 4 deletions pypef/utils/variant_data.py
@@ -139,10 +139,11 @@ def get_sequences_from_file(
if any(not c in line for c in allowed_chars):
for c in line:
if c not in allowed_chars:
raise SystemError(
f"The input file(s) (MSA or train/test sets) contain(s) unknown protein sequence characters "
f"(e.g.: \"{c}\"). Note that an MSA has to be provided in FASTA or A2M format (or formatted as "
F"pure linebreak-separated sequences).")
pass
#raise SystemError(
# f"The input file(s) (MSA or train/test sets) contain(s) unknown protein sequence characters "
# f"(e.g.: \"{c}\"). Note that an MSA has to be provided in FASTA or A2M format (or formatted as "
# F"pure linebreak-separated sequences).")
words += line
except IndexError:
raise IndexError("Sequences in input file(s) likely "
38 changes: 24 additions & 14 deletions scripts/ProteinGym_runs/run_performance_tests_proteingym_data.py
@@ -4,6 +4,7 @@
import sys
import json
import pandas as pd
import numpy as np
from scipy.stats import spearmanr
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
print(sys.path)
@@ -23,11 +24,14 @@


for i, (dset_key, dset_paths) in enumerate(mut_data.items()):
#if i >= 0:
#try:
print(i, '\n===============================================================')
print('\n', i, '\n===============================================================')
print('#'*60 + ' OLD MODEL ' + '#'*60)
csv_path = dset_paths['CSV_path']
msa_path = dset_paths['MSA_path']
wt_seq = dset_paths['WT_sequence']
print(msa_path)
variant_fitness_data = pd.read_csv(csv_path, sep=',')
variants = variant_fitness_data['mutant']
fitnesses = variant_fitness_data['DMS_score']
@@ -36,36 +40,42 @@
variants_split.append(variant.split('/'))
variants, fitnesses, sequences = get_seqs_from_var_name(wt_seq, variants_split, fitnesses)
####
with open(msa_path, 'r') as fh:
cnt = 0
for line in fh:
cnt += 1
if cnt > 100000:
print('Too big MSA, continuing...')
if len(wt_seq) > 800:
print('Sequence length over 800, continuing...')
continue
gremlin_old = GREMLIN_OLD(alignment=msa_path, wt_seq=wt_seq)
#with open(msa_path, 'r') as fh:
# cnt = 0
# for line in fh:
# cnt += 1
#if cnt > 100000:
# print('Too big MSA, continuing...')
# continue
gremlin_old = GREMLIN_OLD(alignment=msa_path, wt_seq=wt_seq, max_msa_seqs=10000)
gaps = gremlin_old.gaps
variants, sequences, fitnesses = remove_gap_pos(gaps, variants, sequences, fitnesses)
x_dca = gremlin_old.collect_encoded_sequences(sequences)
print(f'N Variants remaining after excluding non-DCA-encodable positions: {len(x_dca)}')
x_wt = gremlin_old.x_wt
# Statistical model performance
y_pred = get_delta_e_statistical_model(x_dca, x_wt)
print(f'Statistical DCA model performance on all datapoints; Spearman\'s rho: {spearmanr(fitnesses, y_pred)[0]:.3f}')
y_pred_old = get_delta_e_statistical_model(x_dca, x_wt)
print(f'Statistical DCA model performance on all datapoints; Spearman\'s rho: {spearmanr(fitnesses, y_pred_old)[0]:.3f}')
# Split double and higher substituted variants to multiple single substitutions separated by '/'
assert len(x_dca) == len(fitnesses) == len(variants) == len(sequences)
print('#'*60 + ' NEW MODEL ' + '#'*60)
####
gremlin_new = GREMLIN_NEW(alignment=msa_path, wt_seq=wt_seq, max_msa_seqs=None)
gremlin_new = GREMLIN_NEW(alignment=msa_path, wt_seq=wt_seq, max_msa_seqs=10000)
gaps = gremlin_new.gaps
variants, sequences, fitnesses = remove_gap_pos(gaps, variants, sequences, fitnesses)
#variants, sequences, fitnesses = remove_gap_pos(gaps, variants, sequences, fitnesses)
x_dca = gremlin_new.collect_encoded_sequences(sequences)
print(f'N Variants remaining after excluding non-DCA-encodable positions: {len(x_dca)}')
x_wt = gremlin_new.x_wt
# Statistical model performance
y_pred = get_delta_e_statistical_model(x_dca, x_wt)
print(f'Statistical DCA model performance on all datapoints; Spearman\'s rho: {spearmanr(fitnesses, y_pred)[0]:.3f}')
y_pred_new = get_delta_e_statistical_model(x_dca, x_wt)
print(f'Statistical DCA model performance on all datapoints; Spearman\'s rho: {spearmanr(fitnesses, y_pred_new)[0]:.3f}')
# Split double and higher substituted variants to multiple single substitutions separated by '/'
assert len(x_dca) == len(fitnesses) == len(variants) == len(sequences)
np.testing.assert_almost_equal(spearmanr(fitnesses, y_pred_old)[0], spearmanr(fitnesses, y_pred_new)[0], decimal=3)
#except SystemError: # Check MSAs
# continue
#exit()

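`get_delta_e_statistical_model` is imported from pypef and its body is not shown in this diff. Conceptually, the unsupervised DCA score is the summed coupling/field contribution of a variant encoding minus that of the wild type; a minimal sketch under that assumption (not the pypef implementation):

```python
import numpy as np

def delta_e_statistical_model(x_variants: np.ndarray, x_wt: np.ndarray) -> np.ndarray:
    """Sketch of an unsupervised DCA score: per-variant sum of encoded terms
    minus the wild-type sum (assumed behaviour)."""
    return np.sum(x_variants, axis=1) - np.sum(x_wt)

# Example with dummy encodings: 3 variants, 50 encoded features each.
x_variants = np.random.rand(3, 50)
x_wt = np.random.rand(1, 50)
print(delta_e_statistical_model(x_variants, x_wt))
```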