Skip to content

Commit

Permalink
Update PGym test run script
Browse files Browse the repository at this point in the history
  • Loading branch information
niklases committed Jun 26, 2024
1 parent 883da1f commit 02e58f3
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions scripts/ProteinGym_runs/run_performance_tests_proteingym_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from scipy.stats import spearmanr
import matplotlib.pyplot as plt
# Add local PyPEF path if not using pip-installed PyPEF version
#sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
from pypef.dca.gremlin_inference import GREMLIN
from pypef.dca.hybrid_model import get_delta_e_statistical_model, remove_gap_pos
from pypef.utils.variant_data import get_seqs_from_var_name
Expand All @@ -21,14 +21,15 @@ def plot_performance(mut_data, plot_name, mut_sep=':'):
tested_dsets = []
dset_perfs = []
for i, (dset_key, dset_paths) in enumerate(mut_data.items()):
#if i < 3:
if i < 8:
#try:
print('\n', i, '\n===============================================================')
csv_path = dset_paths['CSV_path']
msa_path = dset_paths['MSA_path']
wt_seq = dset_paths['WT_sequence']
print(msa_path)
variant_fitness_data = pd.read_csv(csv_path, sep=',')
print(np.shape(variant_fitness_data))
variants = variant_fitness_data['mutant']
fitnesses = variant_fitness_data['DMS_score']
variants_split = []
Expand Down Expand Up @@ -68,15 +69,16 @@ def plot_performance(mut_data, plot_name, mut_sep=':'):
assert len(x_dca) == len(fitnesses) == len(variants) == len(sequences)
#except SystemError: # Check MSAs
# continue
tested_dsets.append(f'{dset_key} ({100.0 - (ratio_input_vars_at_gaps * 100):.2f}%, '
+ r'$N_\mathrm{AASubs:}$' + f'{max_muts})')
tested_dsets.append(f'{dset_key} ({100.0 - (ratio_input_vars_at_gaps * 100):.2f}%, {max_muts})')
dset_perfs.append(abs(spearmanr(fitnesses, y_pred_new)[0]))
plt.figure(figsize=(20, 12))
plt.plot(range(len(tested_dsets)), dset_perfs, 'o--', markersize=12)
plt.figure(figsize=(26, 12))
plt.plot(range(len(tested_dsets)), dset_perfs, 'o--', markersize=8)
plt.plot(range(len(tested_dsets)), np.full(np.shape(tested_dsets), np.mean(dset_perfs)), 'k--')
plt.text(len(tested_dsets) - 1, np.mean(dset_perfs), f'{np.mean(dset_perfs):.2f}')
plt.xticks(range(len(tested_dsets)), tested_dsets, rotation=45, ha='right')
plt.tight_layout()
plt.ylim(0.0, 1.0)
plt.ylabel(r'|Spearmanr $\rho$|')
plt.savefig(os.path.join(os.path.dirname(__file__), f'{plot_name}.png'), dpi=300)
print('Saved file as ' + os.path.join(os.path.dirname(__file__), f'{plot_name}.png') + '.')

Expand All @@ -86,4 +88,4 @@ def plot_performance(mut_data, plot_name, mut_sep=':'):
with open(higher_mut_data, 'r') as fh:
h_mut_data = json.loads(fh.read())
plot_performance(mut_data=s_mut_data, plot_name='single_point_mut_performance')
plot_performance(mut_data=h_mut_data, plot_name='multi_point_mut_performance')
#plot_performance(mut_data=h_mut_data, plot_name='multi_point_mut_performance')

0 comments on commit 02e58f3

Please sign in to comment.