Skip to content

Commit

Permalink
Update ProteinGym run scripts and the get_seqs_from_var_name function
Browse files Browse the repository at this point in the history
Added optional shift_pos and assert_wt_aa parameters (with default values) to get_seqs_from_var_name()
  • Loading branch information
niklases committed Jul 14, 2024
1 parent af41bdc commit 56f70a1
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 83 deletions.
10 changes: 8 additions & 2 deletions pypef/utils/variant_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,9 @@ def get_sequences_from_file(
def get_seqs_from_var_name(
wt_seq: str,
substitutions: list,
fitness_values: list
fitness_values: list,
shift_pos: int = 0,
assert_wt_aa: bool = False
) -> tuple[list, list, list]:
"""
Similar to function "get_sequences_from_file" but instead of getting
Expand All @@ -192,8 +194,12 @@ def get_seqs_from_var_name(
name = 'WT'
else:
for single_var in var: # single entries of substitution list
position_index = int(str(single_var)[1:-1]) - 1
position_index = int(str(single_var)[1:-1]) - 1 - shift_pos
new_amino_acid = str(single_var)[-1]
if assert_wt_aa: # Assertion only possible for format AaPosAa, e.g. A123C
assert str(single_var)[0] == temp[position_index], f"Input variant: "\
f"{str(single_var)[0]}{position_index}{new_amino_acid}, WT amino "\
f"acid variant {temp[position_index]}{position_index}{new_amino_acid}"
temp[position_index] = new_amino_acid
# checking if multiple entries are inside list
if separation == 0:
Expand Down
39 changes: 21 additions & 18 deletions scripts/ProteinGym_runs/download_proteingym_and_extract_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,38 +58,36 @@ def get_single_or_multi_point_mut_data(csv_description_path, datasets_path=None,
pdbs_path = os.path.join(file_dirname, 'ProteinGym_AF2_structures')
pdbs = os.listdir(pdbs_path)
description_df = pd.read_csv(csv_description_path, sep=',')
i_mps = []
i_s = []
for i, n_mp in enumerate(description_df['DMS_number_multiple_mutants'].to_list()):
if description_df['MSA_start'][i] == 1: # TODO: Else shift WT seq by description_df['MSA_start']]
if n_mp > 0:
if not single:
i_mps.append(i)
if n_mp > 0:
if not single:
i_s.append(i)
else:
if single:
i_s.append(i)
else:
if single:
i_mps.append(i)
else:
pass
mp_description_df = description_df.iloc[i_mps, :]
mp_filenames = mp_description_df['DMS_filename'].to_list()
mp_wt_seqs = mp_description_df['target_seq'].to_list()
pass
target_description_df = description_df.iloc[i_s, :]
target_filenames = target_description_df['DMS_filename'].to_list()
target_wt_seqs = target_description_df['target_seq'].to_list()
target_msa_starts = target_description_df['MSA_start'].to_list()
target_msa_ends = target_description_df['MSA_end'].to_list()
print(f'Searching for CSV files in {datasets_path}...')
csv_paths = [os.path.join(datasets_path, mp_filename) for mp_filename in mp_filenames]
csv_paths = [os.path.join(datasets_path, target_filename) for target_filename in target_filenames]
print(f'Found {len(csv_paths)} {type_str}-point datasets, will check if all are available in datasets folder...')
avail_filenames, avail_csvs, avail_wt_seqs = [], [], []
for i, csv_path in enumerate(csv_paths):
if not os.path.isfile(csv_path):
# Used to be an error in files: CHECK: Likely 'Rocklin' mistake in CSV! Should be Tsuboyama(?)
print(f"Did not find CSV file {csv_path} - will remove it from prediction process!")
else:
avail_csvs.append(csv_path)
avail_wt_seqs.append(mp_wt_seqs[i])
avail_filenames.append(os.path.splitext(mp_filenames[i])[0])
print(csv_paths[0])
avail_wt_seqs.append(target_wt_seqs[i])
avail_filenames.append(os.path.splitext(target_filenames[i])[0])
assert len(avail_wt_seqs) == len(avail_csvs)
print(f'Getting data from {len(avail_csvs)} {type_str}-point mutation DMS CSV files...')
dms_mp_data = {}
for i, csv_path in enumerate(avail_csvs):
#df = pd.read_csv(csv_path, sep=',')
begin = avail_filenames[i].split('_')[0] + '_' + avail_filenames[i].split('_')[1]
msa_path=None
for msa in msas:
Expand All @@ -99,12 +97,17 @@ def get_single_or_multi_point_mut_data(csv_description_path, datasets_path=None,
if pdb.startswith(begin):
pdb_path = os.path.join(pdbs_path, pdb)
if msa_path is None or pdb_path is None:
print(f'Did not find a MSA or a PDB beginning with {begin}, continuing...')
continue
target_msa_start = target_msa_starts[i]
target_msa_end = target_msa_ends[i]
dms_mp_data.update({
avail_filenames[i]: {
'CSV_path': csv_path,
'WT_sequence': avail_wt_seqs[i],
'MSA_path': msa_path,
'MSA_start': target_msa_start,
'MSA_end': target_msa_end,
'PDB_path': pdb_path
}
})
Expand Down
135 changes: 72 additions & 63 deletions scripts/ProteinGym_runs/run_performance_tests_proteingym_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
import psutil
import gc

import logging
logger = logging.getLogger("pypef")
logger.setLevel(logging.INFO)

single_point_mut_data = os.path.abspath(os.path.join(os.path.dirname(__file__), f"single_point_dms_mut_data.json"))
higher_mut_data = os.path.abspath(os.path.join(os.path.dirname(__file__), f"higher_point_dms_mut_data.json"))
Expand All @@ -25,69 +28,75 @@ def plot_performance(mut_data, plot_name, mut_sep=':'):
tested_dsets = []
dset_perfs = []
for i, (dset_key, dset_paths) in enumerate(mut_data.items()):
if i >= 0:
#try:
print(f'\n{i+1}/{len(mut_data.items())}\n===============================================================')
csv_path = dset_paths['CSV_path']
msa_path = dset_paths['MSA_path']
wt_seq = dset_paths['WT_sequence']
print(msa_path)
time.sleep(5)
# Getting % usage of virtual_memory ( 3rd field)
print('RAM memory % used:', psutil.virtual_memory()[2])
# Getting usage of virtual_memory in GB ( 4th field)
print('RAM Used (GB):', psutil.virtual_memory()[3]/1000000000)
variant_fitness_data = pd.read_csv(csv_path, sep=',')
print('N_variant-fitness-tuples:', np.shape(variant_fitness_data)[0])
if np.shape(variant_fitness_data)[0] > 400000:
print('More than 400000 variant-fitness pairs which is a potential OOM error risk, skipping dataset...')
continue
variants = variant_fitness_data['mutant']
fitnesses = variant_fitness_data['DMS_score']
variants_split = []
for variant in variants:
# Split double and higher substituted variants to multiple single substitutions;
# e.g. separated by ':' or '/'
variants_split.append(variant.split(mut_sep))
variants, fitnesses, sequences = get_seqs_from_var_name(wt_seq, variants_split, fitnesses)
# Only model sequences with length of max. 800 amino acids to avoid out of memory errors
print('Sequence length:', len(wt_seq))
if len(wt_seq) > 1000:
print('Sequence length over 1000 which is a potential OOM error risk, skipping dataset...')
continue
gremlin_new = GREMLIN(alignment=msa_path, wt_seq=wt_seq, max_msa_seqs=10000)
#gaps = gremlin_new.gaps
gaps_1_indexed = gremlin_new.gaps_1_indexed
var_pos = [int(v[1:-1]) for variants in variants_split for v in variants]
n_muts = []
for vs in variants_split:
n_muts.append(len(vs))
max_muts = max(n_muts)
c = 0
for vp in var_pos:
if vp in gaps_1_indexed:
c += 1
print(f'N max. (multiple) amino acid substitutions: {max_muts}')
c = c / max_muts
ratio_input_vars_at_gaps = c / len(var_pos)
if c > 0:
print(f'{int(c)} of {len(var_pos)} ({ratio_input_vars_at_gaps * 100:.2f}%) input variants to be predicted are variants with '
f'amino acid substitutions at gap positions (these variants will be predicted/labeled with a fitness of 0.0).')
if ratio_input_vars_at_gaps >= 1.0:
print('100% substitutions at gap sites, skipping dataset...')
continue
#variants, sequences, fitnesses = remove_gap_pos(gaps, variants, sequences, fitnesses)
x_dca = gremlin_new.collect_encoded_sequences(sequences)
x_wt = gremlin_new.x_wt
# Statistical model performance
y_pred_new = get_delta_e_statistical_model(x_dca, x_wt)
print(f'Statistical DCA model performance on all datapoints; Spearman\'s rho: {abs(spearmanr(fitnesses, y_pred_new)[0]):.3f}')
assert len(x_dca) == len(fitnesses) == len(variants) == len(sequences)
#except SystemError: # Check MSAs
# continue
tested_dsets.append(f'{dset_key} ({100.0 - (ratio_input_vars_at_gaps * 100):.2f}%, {max_muts})')
dset_perfs.append(abs(spearmanr(fitnesses, y_pred_new)[0]))
gc.collect() # Potentially GC is needed to free some RAM (deallocated VRAM -> partly stored in RAM?) after run
print(f'\n{i+1}/{len(mut_data.items())}\n===============================================================')
csv_path = dset_paths['CSV_path']
msa_path = dset_paths['MSA_path']
wt_seq = dset_paths['WT_sequence']
msa_start = dset_paths['MSA_start']
msa_end = dset_paths['MSA_end']
print(wt_seq)
wt_seq = wt_seq[msa_start - 1:msa_end]
print('CSV path:', csv_path)
print('MSA path:', msa_path)
print('MSA start:', msa_start)
print('WT sequence (trimmed from MSA start to MSA end):\n' + wt_seq)
time.sleep(5)
# Getting % usage of virtual_memory ( 3rd field)
print('RAM memory % used:', psutil.virtual_memory()[2])
# Getting usage of virtual_memory in GB ( 4th field)
print('RAM Used (GB):', round(psutil.virtual_memory()[3]/1000000000), 3)
variant_fitness_data = pd.read_csv(csv_path, sep=',')
print('N_variant-fitness-tuples:', np.shape(variant_fitness_data)[0])
if np.shape(variant_fitness_data)[0] > 400000:
print('More than 400000 variant-fitness pairs which is a potential OOM error risk, skipping dataset...')
continue
variants = variant_fitness_data['mutant']
fitnesses = variant_fitness_data['DMS_score']
variants_split = []
for variant in variants:
# Split double and higher substituted variants to multiple single substitutions;
# e.g. separated by ':' or '/'
variants_split.append(variant.split(mut_sep))
variants, fitnesses, sequences = get_seqs_from_var_name(
wt_seq, variants_split, fitnesses, shift_pos=msa_start - 1, assert_wt_aa=True)
# Only model sequences with length of max. 800 amino acids to avoid out of memory errors
print('Sequence length:', len(wt_seq))
if len(wt_seq) > 1000:
print('Sequence length over 1000 which is a potential OOM error risk, skipping dataset...')
continue
gremlin_new = GREMLIN(alignment=msa_path, wt_seq=wt_seq, max_msa_seqs=10000)
#gaps = gremlin_new.gaps
gaps_1_indexed = gremlin_new.gaps_1_indexed
var_pos = [int(v[1:-1]) for variants in variants_split for v in variants]
n_muts = []
for vs in variants_split:
n_muts.append(len(vs))
max_muts = max(n_muts)
c = 0
for vp in var_pos:
if vp in gaps_1_indexed:
c += 1
print(f'N max. (multiple) amino acid substitutions: {max_muts}')
c = c / max_muts
ratio_input_vars_at_gaps = c / len(var_pos)
if c > 0:
print(f'{int(c)} of {len(var_pos)} ({ratio_input_vars_at_gaps * 100:.2f}%) input variants to be predicted are variants with '
f'amino acid substitutions at gap positions (these variants will be predicted/labeled with a fitness of 0.0).')
if ratio_input_vars_at_gaps >= 1.0:
print('100% substitutions at gap sites, skipping dataset...')
continue
#variants, sequences, fitnesses = remove_gap_pos(gaps, variants, sequences, fitnesses)
x_dca = gremlin_new.collect_encoded_sequences(sequences)
x_wt = gremlin_new.x_wt
# Statistical model performance
y_pred_new = get_delta_e_statistical_model(x_dca, x_wt)
print(f'Statistical DCA model performance on all datapoints; Spearman\'s rho: {abs(spearmanr(fitnesses, y_pred_new)[0]):.3f}')
assert len(x_dca) == len(fitnesses) == len(variants) == len(sequences)
#except SystemError: # Check MSAs
# continue
tested_dsets.append(f'{dset_key} ({100.0 - (ratio_input_vars_at_gaps * 100):.2f}%, {max_muts})')
dset_perfs.append(abs(spearmanr(fitnesses, y_pred_new)[0]))
gc.collect() # Potentially GC is needed to free some RAM (deallocated VRAM -> partly stored in RAM?) after run
plt.figure(figsize=(26, 12))
plt.plot(range(len(tested_dsets)), dset_perfs, 'o--', markersize=8)
plt.plot(range(len(tested_dsets)), np.full(np.shape(tested_dsets), np.mean(dset_perfs)), 'k--')
Expand Down

0 comments on commit 56f70a1

Please sign in to comment.