diff --git a/pypef/dca/gremlin_inference.py b/pypef/dca/gremlin_inference.py
index b715f60..cdeeaec 100644
--- a/pypef/dca/gremlin_inference.py
+++ b/pypef/dca/gremlin_inference.py
@@ -72,15 +72,19 @@ def __init__(
         optimize=True,
         gap_cutoff=0.5,
         eff_cutoff=0.8,
-        opt_iter=100
+        opt_iter=100,
+        max_msa_seqs: int | None = 10000,
     ):
         self.char_alphabet = char_alphabet
         self.gap_cutoff = gap_cutoff
         self.eff_cutoff = eff_cutoff
         self.opt_iter = opt_iter
+        self.max_msa_seqs = max_msa_seqs
         self.states = len(self.char_alphabet)
+        print('self.states', self.states)
         self.seqs, _, _ = get_sequences_from_file(alignment)
         self.msa_ori = self.get_msa_ori()
+        print(f'MSA shape: {np.shape(self.msa_ori)}')
         self.n_col_ori = self.msa_ori.shape[1]
         if wt_seq is not None:
             self.wt_seq = wt_seq
@@ -92,14 +96,17 @@ def __init__(
             raise SystemError("Length of (provided) wild-type sequence does not match "
                               "number of MSA columns, i.e., common MSA sequence length.")
         self.msa_trimmed, self.v_idx, self.w_idx, self.w_rel_idx, self.gaps = self.filt_gaps(self.msa_ori)
+        print(f'OLD: {np.shape(self.msa_trimmed)}, {np.shape(self.v_idx)}, {np.shape(self.w_idx)}, {np.shape(self.w_rel_idx)}, {np.shape(self.gaps)}')
         self.msa_weights = self.get_eff_msa_weights(self.msa_trimmed)
         self.n_eff = np.sum(self.msa_weights)
         self.n_row = self.msa_trimmed.shape[0]
         self.n_col = self.msa_trimmed.shape[1]
         self.v_ini, self.w_ini, self.aa_counts = self.initialize_v_w(remove_gap_entries=False)
+        print(f'OLD INI SHAPES: {np.shape(self.v_ini)}, {np.shape(self.w_ini)}, {np.shape(self.aa_counts)}')
         self.optimize = optimize
         if self.optimize:
             self.v_opt, self.w_opt = self.run_opt_tf()
+            print(f'OLD OPT: {np.shape(self.v_opt)}, {np.shape(self.w_opt)}')
         self.x_wt = self.collect_encoded_sequences(np.atleast_1d(self.wt_seq))
 
     def a2n_dict(self):
@@ -141,18 +148,24 @@ def get_v_idx_w_idx(self):
 
     def get_msa_ori(self):
         """converts list of sequences to msa"""
         msa_ori = []
-        for seq in self.seqs:
-            msa_ori.append([self.aa2int(aa.upper()) for aa in seq])
+        for i, seq in enumerate(self.seqs):
+            if i < self.max_msa_seqs:
+                msa_ori.append([self.aa2int(aa.upper()) for aa in seq])
+            else:
+                print(f'Reached max. number of MSA sequences ({self.max_msa_seqs})...')
+                break
         msa_ori = np.array(msa_ori)
         return msa_ori
 
     def filt_gaps(self, msa_ori):
         """filters alignment to remove gappy positions"""
+        print('old inner:', np.shape(msa_ori))
         tmp = (msa_ori == self.states - 1).astype(float)
         non_gaps = np.where(np.sum(tmp.T, -1).T / msa_ori.shape[0] < self.gap_cutoff)[0]
         gaps = np.where(np.sum(tmp.T, -1).T / msa_ori.shape[0] >= self.gap_cutoff)[0]
         logger.info(f'Gap positions (removed from MSA; 0-indexed):\n{gaps}')
+        print(f'Gap positions (removed from MSA; 0-indexed):\n{gaps}')
         ncol_trimmed = len(non_gaps)
         v_idx = non_gaps
         w_idx = v_idx[np.stack(np.triu_indices(ncol_trimmed, 1), -1)]
diff --git a/pypef/dca/new_gremlin_inference.py b/pypef/dca/new_gremlin_inference.py
index 56eb0c7..9f4aa66 100644
--- a/pypef/dca/new_gremlin_inference.py
+++ b/pypef/dca/new_gremlin_inference.py
@@ -53,6 +53,7 @@
 import pickle
 import numpy as np
 import matplotlib.pyplot as plt
+from Bio import AlignIO
 from scipy.spatial.distance import pdist, squareform
 from scipy.special import logsumexp
 from scipy.stats import boxcox
@@ -60,8 +61,6 @@
 import tensorflow as tf
 tf.get_logger().setLevel('DEBUG')
 
-from pypef.utils.variant_data import get_sequences_from_file
-
 
 class GREMLIN:
     """
@@ -74,18 +73,18 @@
     def __init__(
             self,
             alignment: str | PathLike,
-            char_alphabet: str = "ACDEFGHIKLMNPQRSTVWY-X",
+            char_alphabet: str = "ARNDCQEGHILKMFPSTWYV-",
             wt_seq=None,
             offset=0,
             optimize=True,
-            gap_cutoff=1.0,
-            eff_cutoff=1.0,
+            gap_cutoff=0.5,
+            eff_cutoff=0.8,
             opt_iter=100,
             max_msa_seqs: int | None = 10000,
             seqs=None
     ):
         self.char_alphabet = char_alphabet
-        self.allowed_chars = "ABCDEFGHIKLMNPQRSTVWYX-."
+        self.allowed_chars = "ARNDCQEGHILKMFPSTWYV-"
         self.allowed_chars += self.allowed_chars.lower()
         self.offset = offset
         self.gap_cutoff = gap_cutoff
@@ -96,13 +95,16 @@ def __init__(
         else:
             self.max_msa_seqs = max_msa_seqs
         self.states = len(self.char_alphabet)
+        print('self.states', self.states)
         print('Loading MSA...')
         if seqs is None:
-            self.seqs, _ = self.get_sequences_from_msa(alignment)
+            self.seqs, self.seq_ids = self.get_sequences_from_msa(alignment)
         else:
            self.seqs = seqs
+            self.seq_ids = np.array([n for n in range(len(self.seqs))])
         print(f'Found {len(self.seqs)} sequences in the MSA...')
         self.msa_ori = self.get_msa_ori()
+        print(f'MSA shape: {np.shape(self.msa_ori)}')
         self.n_col_ori = self.msa_ori.shape[1]
         if wt_seq is not None:
             self.wt_seq = wt_seq
@@ -117,6 +119,7 @@ def __init__(
                               f"i.e., common MSA sequence length.")
         print('Filtering gaps...')
         self.msa_trimmed, self.v_idx, self.w_idx, self.w_rel_idx, self.gaps = self.filt_gaps(self.msa_ori)
+        print(f'NEW: {np.shape(self.msa_trimmed)}, {np.shape(self.v_idx)}, {np.shape(self.w_idx)}, {np.shape(self.w_rel_idx)}, {np.shape(self.gaps)}')
         print('Getting effective sequence weights...')
         self.msa_weights = self.get_eff_msa_weights(self.msa_trimmed)
         self.n_eff = np.sum(self.msa_weights)
@@ -124,13 +127,16 @@ def __init__(
         self.n_col = self.msa_trimmed.shape[1]
         print('Initializing v and W terms based on MSA frequencies...')
         self.v_ini, self.w_ini, self.aa_counts = self.initialize_v_w(remove_gap_entries=False)
+        print(f'NEW INI SHAPES: {np.shape(self.v_ini)}, {np.shape(self.w_ini)}, {np.shape(self.aa_counts)}')
         self.aa_freqs = self.aa_counts / self.n_row
         self.optimize = optimize
         if self.optimize:
             self.v_opt_with_gaps, self.w_opt_with_gaps = self.run_opt_tf()
+            print(f'NEW OPT: {np.shape(self.v_opt_with_gaps)}, {np.shape(self.w_opt_with_gaps)}')
             no_gap_states = self.states - 1
             self.v_opt = self.v_opt_with_gaps[:, :no_gap_states]
             self.w_opt = self.w_opt_with_gaps[:, :no_gap_states, :, :no_gap_states]
+            print(f'NEW OPT: {np.shape(self.v_opt)}, {np.shape(self.w_opt)}')
         self.x_wt = self.collect_encoded_sequences(np.atleast_1d(self.wt_seq))
 
     def get_sequences_from_msa(self, msa_file: str):
@@ -148,25 +154,16 @@
             Path to MSA in FASTA or A2M format.
         """
         sequences = []
-        names_of_seqs = []
-        with open(msa_file, 'r') as f:
-            words = ""
-            for line in f:
-                if line.startswith('>'):
-                    if words != "":
-                        sequences.append(words[self.offset:])
-                    words = line.split('>')
-                    names_of_seqs.append(words[1].strip())
-                    words = ""
-                elif line.startswith('#'):
-                    pass  # are comments
-                else:
-                    line = line.strip()
-                    words += line
-        if words != "":
-            sequences.append(words[self.offset:])
-        assert len(sequences) == len(names_of_seqs), f"{len(sequences)}, {len(names_of_seqs)}"
-        return np.array(sequences), np.array(names_of_seqs)
+        seq_ids = []
+        alignment = AlignIO.read(open(msa_file), "fasta")
+        print("Alignment length %i" % alignment.get_alignment_length())
+        for record in alignment:
+            #print(record.seq + " " + record.id)
+            sequences.append(str(record.seq))
+            seq_ids.append(str(record.id))
+        assert len(sequences) == len(seq_ids), f"{len(sequences)}, {len(seq_ids)}"
+        print("SSSS", sequences[0])
+        return np.array(sequences), np.array(seq_ids)
 
     def a2n_dict(self):
         """convert alphabet to numerical integer values, e.g.:
@@ -211,38 +208,21 @@ def get_msa_ori(self):
         Converts list of sequences to MSA.
         Also checks for unknown amino acid characters
         and removes those sequences from the MSA.
         """
-        n_skipped = 0
         msa_ori = []
-        for i, seq in enumerate(self.seqs):
-            skip = False
-            for aa in seq:
-                if aa not in self.allowed_chars:
-                    if n_skipped == 0:
-                        f"The input file(s) (MSA or train/test sets) contain(s) "
-                        f"unknown protein sequence characters "
-                        f"(e.g.: \"{aa}\" in sequence {i + 1}). "
-                        f"Will remove those sequences from MSA!"
-                    skip = True
-            if skip:
-                n_skipped += 1
-            if i <= self.max_msa_seqs and not skip:
+        for i, (seq, seq_id) in enumerate(zip(self.seqs, self.seq_ids)):
+            if i < self.max_msa_seqs:
                 msa_ori.append([self.aa2int(aa.upper()) for aa in seq])
-            elif i >= self.max_msa_seqs:
-                print(f'Reached the number of maximal MSA sequences ({self.max_msa_seqs}), '
-                      'skipping the rest of the MSA sequences...')
+            else:
+                print(f'Reached max. number of MSA sequences ({self.max_msa_seqs})...')
                 break
-        try:
-            msa_ori = np.array(msa_ori)
-        except ValueError:
-            raise ValueError("The provided MSA seems to have inhomogeneous "
-                             "shape, i.e., unequal sequence length.")
+        msa_ori = np.array(msa_ori)
         return msa_ori
 
     def filt_gaps(self, msa_ori):
         """filters alignment to remove gappy positions"""
+        print('new inner:', np.shape(msa_ori))
         tmp = (msa_ori == self.states - 1).astype(float)
         non_gaps = np.where(np.sum(tmp.T, -1).T / msa_ori.shape[0] < self.gap_cutoff)[0]
-        gaps = np.where(np.sum(tmp.T, -1).T / msa_ori.shape[0] >= self.gap_cutoff)[0]
         print(f'Gap positions (removed from MSA; 0-indexed):\n{gaps}')
         ncol_trimmed = len(non_gaps)
@@ -444,7 +424,9 @@ def feed(feed_all=False):
             # save the v and w parameters of the MRF
             v_opt = sess.run(v)
             w_opt = sess.run(w)
-        return v_opt, w_opt
+        no_gap_states = self.states - 1
+        return v_opt[:, :no_gap_states], w_opt[:, :no_gap_states, :, :no_gap_states]
+        #return v_opt, w_opt
 
 
     def initialize_v_w(self, remove_gap_entries=True):
diff --git a/pypef/utils/variant_data.py b/pypef/utils/variant_data.py
index a6c726b..fc2d858 100644
--- a/pypef/utils/variant_data.py
+++ b/pypef/utils/variant_data.py
@@ -139,10 +139,11 @@ def get_sequences_from_file(
                 if any(not c in line for c in allowed_chars):
                     for c in line:
                         if c not in allowed_chars:
-                            raise SystemError(
-                                f"The input file(s) (MSA or train/test sets) contain(s) unknown protein sequence characters "
-                                f"(e.g.: \"{c}\"). Note that an MSA has to be provided in FASTA or A2M format (or formatted as "
-                                F"pure linebreak-separated sequences).")
+                            pass
+                            #raise SystemError(
+                            #    f"The input file(s) (MSA or train/test sets) contain(s) unknown protein sequence characters "
+                            #    f"(e.g.: \"{c}\"). Note that an MSA has to be provided in FASTA or A2M format (or formatted as "
+                            #    F"pure linebreak-separated sequences).")
                 words += line
             except IndexError:
                 raise IndexError("Sequences in input file(s) likely "
diff --git a/scripts/ProteinGym_runs/run_performance_tests_proteingym_data.py b/scripts/ProteinGym_runs/run_performance_tests_proteingym_data.py
index d3d16f4..8648058 100644
--- a/scripts/ProteinGym_runs/run_performance_tests_proteingym_data.py
+++ b/scripts/ProteinGym_runs/run_performance_tests_proteingym_data.py
@@ -4,6 +4,7 @@
 import sys
 import json
 import pandas as pd
+import numpy as np
 from scipy.stats import spearmanr
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
 print(sys.path)
@@ -23,11 +24,14 @@
 for i, (dset_key, dset_paths) in enumerate(mut_data.items()):
+    #if i >= 0:
     #try:
-    print(i, '\n===============================================================')
+    print('\n', i, '\n===============================================================')
+    print('#'*60 + ' OLD MODEL ' + '#'*60)
     csv_path = dset_paths['CSV_path']
     msa_path = dset_paths['MSA_path']
     wt_seq = dset_paths['WT_sequence']
+    print(msa_path)
     variant_fitness_data = pd.read_csv(csv_path, sep=',')
     variants = variant_fitness_data['mutant']
     fitnesses = variant_fitness_data['DMS_score']
@@ -36,36 +40,42 @@
         variants_split.append(variant.split('/'))
     variants, fitnesses, sequences = get_seqs_from_var_name(wt_seq, variants_split, fitnesses)
     ####
-    with open(msa_path, 'r') as fh:
-        cnt = 0
-        for line in fh:
-            cnt += 1
-    if cnt > 100000:
-        print('Too big MSA, continuing...')
+    if len(wt_seq) > 800:
+        print('Sequence length over 800, continuing...')
         continue
-    gremlin_old = GREMLIN_OLD(alignment=msa_path, wt_seq=wt_seq)
+    #with open(msa_path, 'r') as fh:
+    #    cnt = 0
+    #    for line in fh:
+    #        cnt += 1
+    #if cnt > 100000:
+    #    print('Too big MSA, continuing...')
+    #    continue
+    gremlin_old = GREMLIN_OLD(alignment=msa_path, wt_seq=wt_seq, max_msa_seqs=10000)
     gaps = gremlin_old.gaps
     variants, sequences, fitnesses = remove_gap_pos(gaps, variants, sequences, fitnesses)
     x_dca = gremlin_old.collect_encoded_sequences(sequences)
     print(f'N Variants remaining after excluding non-DCA-encodable positions: {len(x_dca)}')
     x_wt = gremlin_old.x_wt
     # Statistical model performance
-    y_pred = get_delta_e_statistical_model(x_dca, x_wt)
-    print(f'Statistical DCA model performance on all datapoints; Spearman\'s rho: {spearmanr(fitnesses, y_pred)[0]:.3f}')
+    y_pred_old = get_delta_e_statistical_model(x_dca, x_wt)
+    print(f'Statistical DCA model performance on all datapoints; Spearman\'s rho: {spearmanr(fitnesses, y_pred_old)[0]:.3f}')
     # Split double and higher substituted variants to multiple single substitutions separated by '/'
     assert len(x_dca) == len(fitnesses) == len(variants) == len(sequences)
+    print('#'*60 + ' NEW MODEL ' + '#'*60)
     ####
-    gremlin_new = GREMLIN_NEW(alignment=msa_path, wt_seq=wt_seq, max_msa_seqs=None)
+    gremlin_new = GREMLIN_NEW(alignment=msa_path, wt_seq=wt_seq, max_msa_seqs=10000)
     gaps = gremlin_new.gaps
-    variants, sequences, fitnesses = remove_gap_pos(gaps, variants, sequences, fitnesses)
+    #variants, sequences, fitnesses = remove_gap_pos(gaps, variants, sequences, fitnesses)
     x_dca = gremlin_new.collect_encoded_sequences(sequences)
     print(f'N Variants remaining after excluding non-DCA-encodable positions: {len(x_dca)}')
     x_wt = gremlin_new.x_wt
     # Statistical model performance
-    y_pred = get_delta_e_statistical_model(x_dca, x_wt)
-    print(f'Statistical DCA model performance on all datapoints; Spearman\'s rho: {spearmanr(fitnesses, y_pred)[0]:.3f}')
+    y_pred_new = get_delta_e_statistical_model(x_dca, x_wt)
+    print(f'Statistical DCA model performance on all datapoints; Spearman\'s rho: {spearmanr(fitnesses, y_pred_new)[0]:.3f}')
     # Split double and higher substituted variants to multiple single substitutions separated by '/'
     assert len(x_dca) == len(fitnesses) == len(variants) == len(sequences)
+    np.testing.assert_almost_equal(spearmanr(fitnesses, y_pred_old)[0], spearmanr(fitnesses, y_pred_new)[0], decimal=3)
     #except SystemError:  # Check MSAs
     #    continue
+    #exit()