diff --git a/python/dgllife/data/pdbbind.py b/python/dgllife/data/pdbbind.py index 39e402bd..cb92c68b 100644 --- a/python/dgllife/data/pdbbind.py +++ b/python/dgllife/data/pdbbind.py @@ -56,7 +56,7 @@ class PDBBind(object): load_binding_pocket : bool Whether to load binding pockets or full proteins. Default to True. remove_coreset_from_refinedset: bool - Whether to remove core set from refined set when training with refined set and test with core set. + Whether to remove core set from refined set when training with refined set and test with core set. Default to True. sanitize : bool Whether sanitization is performed in initializing RDKit molecule instances. See @@ -88,7 +88,7 @@ class PDBBind(object): Default None, and PDBBind dataset will be downloaded from DGL database. Specify this argument to a local path of customized dataset, which should follow the structure and the naming format of PDBBind v2015. """ - def __init__(self, subset, pdb_version='v2015', load_binding_pocket=True, remove_coreset_from_refinedset=True, sanitize=False, + def __init__(self, subset, pdb_version='v2015', load_binding_pocket=True, remove_coreset_from_refinedset=True, sanitize=False, calc_charges=False, remove_hs=False, use_conformation=True, construct_graph_and_featurize=ACNN_graph_construction_and_featurization, zero_padding=True, num_processes=None, local_path=None): @@ -101,9 +101,9 @@ def __init__(self, subset, pdb_version='v2015', load_binding_pocket=True, remove # Prepare for Refined, Agglomerative Sequence Split and Agglomerative Structure Split if pdb_version == 'v2007' and not local_path: merged_df = self.df.merge(self.agg_split, on='PDB_code') - self.agg_sequence_split = [list(merged_df.loc[merged_df['sequence']==target_set, 'PDB_code'].index) + self.agg_sequence_split = [list(merged_df.loc[merged_df['sequence']==target_set, 'PDB_code'].index) for target_set in ['train', 'valid', 'test']] - self.agg_structure_split = [list(merged_df.loc[merged_df['structure']==target_set, 'PDB_code'].index) + self.agg_structure_split = [list(merged_df.loc[merged_df['structure']==target_set, 'PDB_code'].index) for target_set in ['train', 'valid', 'test']] def _read_data_files(self, pdb_version, subset, load_binding_pocket, remove_coreset_from_refinedset, local_path): @@ -132,7 +132,7 @@ def _read_data_files(self, pdb_version, subset, load_binding_pocket, remove_core extracted_data_path = root_dir_path + '/pdbbind_v2007' download(_get_dgl_url(self._url), path=data_path, overwrite=False) extract_archive(data_path, extracted_data_path, overwrite=False) - extracted_data_path += '/home/ubuntu' # extra layer + extracted_data_path += '/home/ubuntu' # extra layer # DataFrame containing the pdbbind_2007_agglomerative_split.txt self.agg_split = pd.read_csv(extracted_data_path + '/v2007/pdbbind_2007_agglomerative_split.txt') @@ -171,8 +171,6 @@ def _read_data_files(self, pdb_version, subset, load_binding_pocket, remove_core self.df = pd.DataFrame(contents, columns=( 'PDB_code', 'resolution', 'release_year', '-logKd/Ki', 'Kd/Ki', 'cluster_ID')) - - pdbs = self.df['PDB_code'].tolist() # remove core set from refined set if using refined if remove_coreset_from_refinedset and subset == 'refined': @@ -183,12 +181,21 @@ def _read_data_files(self, pdb_version, subset, load_binding_pocket, remove_core elif pdb_version == 'v2007': core_path = extracted_data_path + '/v2007/INDEX.2007.core.data' + core_pdbs = [] with open(core_path,'r') as f: for line in f: fields = line.strip().split() - if fields[0] != "#" and fields[0] in pdbs: - pdbs.remove(fields[0]) - + if fields[0] != "#": + core_pdbs.append(fields[0]) + + non_core_ids = [] + for i in range(len(self.df)): + if self.df['PDB_code'][i] not in core_pdbs: + non_core_ids.append(i) + self.df = self.df.iloc[non_core_ids] + + pdbs = self.df['PDB_code'].tolist() + if local_path: pdb_path = local_path else: @@ -317,7 +324,7 @@ def _preprocess(self, load_binding_pocket, max_num_ligand_atoms = None max_num_protein_atoms = None - construct_graph_and_featurize = partial(construct_graph_and_featurize, + construct_graph_and_featurize = partial(construct_graph_and_featurize, max_num_ligand_atoms=max_num_ligand_atoms, max_num_protein_atoms=max_num_protein_atoms) @@ -326,7 +333,7 @@ def _preprocess(self, load_binding_pocket, # construct graphs with multiprocessing pool = multiprocessing.Pool(processes=num_processes) - self.graphs = pool.starmap(construct_graph_and_featurize, + self.graphs = pool.starmap(construct_graph_and_featurize, zip(self.ligand_mols, self.protein_mols, self.ligand_coordinates, self.protein_coordinates)) print(f'Done constructing {len(self.graphs)} graphs.')