From fd5fd2b56b93612c3b80848d5a34ea15d129cb33 Mon Sep 17 00:00:00 2001
From: mufeili <mufeili1996@gmail.com>
Date: Thu, 27 Jan 2022 16:11:05 +0800
Subject: [PATCH] Update

---
 python/dgllife/data/pdbbind.py | 31 +++++++++++++++++++------------
 1 file changed, 19 insertions(+), 12 deletions(-)

diff --git a/python/dgllife/data/pdbbind.py b/python/dgllife/data/pdbbind.py
index 39e402bd..cb92c68b 100644
--- a/python/dgllife/data/pdbbind.py
+++ b/python/dgllife/data/pdbbind.py
@@ -56,7 +56,7 @@ class PDBBind(object):
     load_binding_pocket : bool
         Whether to load binding pockets or full proteins. Default to True.
     remove_coreset_from_refinedset: bool
-        Whether to remove core set from refined set when training with refined set and test with core set. 
+        Whether to remove core set from refined set when training with refined set and test with core set.
         Default to True.
     sanitize : bool
         Whether sanitization is performed in initializing RDKit molecule instances. See
@@ -88,7 +88,7 @@ class PDBBind(object):
         Default None, and PDBBind dataset will be downloaded from DGL database.
         Specify this argument to a local path of customized dataset, which should follow the structure and the naming format of PDBBind v2015.
     """
-    def __init__(self, subset, pdb_version='v2015', load_binding_pocket=True, remove_coreset_from_refinedset=True, sanitize=False, 
+    def __init__(self, subset, pdb_version='v2015', load_binding_pocket=True, remove_coreset_from_refinedset=True, sanitize=False,
                  calc_charges=False, remove_hs=False, use_conformation=True,
                  construct_graph_and_featurize=ACNN_graph_construction_and_featurization,
                  zero_padding=True, num_processes=None, local_path=None):
@@ -101,9 +101,9 @@ def __init__(self, subset, pdb_version='v2015', load_binding_pocket=True, remove
         # Prepare for Refined, Agglomerative Sequence Split and Agglomerative Structure Split
         if pdb_version == 'v2007' and not local_path:
             merged_df = self.df.merge(self.agg_split, on='PDB_code')
-            self.agg_sequence_split = [list(merged_df.loc[merged_df['sequence']==target_set, 'PDB_code'].index) 
+            self.agg_sequence_split = [list(merged_df.loc[merged_df['sequence']==target_set, 'PDB_code'].index)
                                        for target_set in ['train', 'valid', 'test']]
-            self.agg_structure_split = [list(merged_df.loc[merged_df['structure']==target_set, 'PDB_code'].index) 
+            self.agg_structure_split = [list(merged_df.loc[merged_df['structure']==target_set, 'PDB_code'].index)
                                         for target_set in ['train', 'valid', 'test']]
 
     def _read_data_files(self, pdb_version, subset, load_binding_pocket, remove_coreset_from_refinedset, local_path):
@@ -132,7 +132,7 @@ def _read_data_files(self, pdb_version, subset, load_binding_pocket, remove_core
             extracted_data_path = root_dir_path + '/pdbbind_v2007'
             download(_get_dgl_url(self._url), path=data_path, overwrite=False)
             extract_archive(data_path, extracted_data_path, overwrite=False)
-            extracted_data_path += '/home/ubuntu' # extra layer 
+            extracted_data_path += '/home/ubuntu' # extra layer
 
             # DataFrame containing the pdbbind_2007_agglomerative_split.txt
             self.agg_split = pd.read_csv(extracted_data_path + '/v2007/pdbbind_2007_agglomerative_split.txt')
@@ -171,8 +171,6 @@ def _read_data_files(self, pdb_version, subset, load_binding_pocket, remove_core
             self.df = pd.DataFrame(contents, columns=(
                 'PDB_code', 'resolution', 'release_year',
                 '-logKd/Ki', 'Kd/Ki', 'cluster_ID'))
-        
-        pdbs = self.df['PDB_code'].tolist()
 
         # remove core set from refined set if using refined
         if remove_coreset_from_refinedset and subset == 'refined':
@@ -183,12 +181,21 @@ def _read_data_files(self, pdb_version, subset, load_binding_pocket, remove_core
             elif pdb_version == 'v2007':
                 core_path = extracted_data_path + '/v2007/INDEX.2007.core.data'
 
+            core_pdbs = []
             with open(core_path,'r') as f:
                 for line in f:
                     fields = line.strip().split()
-                    if fields[0] != "#" and fields[0] in pdbs:
-                        pdbs.remove(fields[0])
-        
+                    if fields[0] != "#":
+                        core_pdbs.append(fields[0])
+
+            non_core_ids = []
+            for i in range(len(self.df)):
+                if self.df['PDB_code'][i] not in core_pdbs:
+                    non_core_ids.append(i)
+            self.df = self.df.iloc[non_core_ids]
+
+        pdbs = self.df['PDB_code'].tolist()
+
         if local_path:
             pdb_path = local_path
         else:
@@ -317,7 +324,7 @@ def _preprocess(self, load_binding_pocket,
             max_num_ligand_atoms = None
             max_num_protein_atoms = None
 
-        construct_graph_and_featurize = partial(construct_graph_and_featurize, 
+        construct_graph_and_featurize = partial(construct_graph_and_featurize,
                             max_num_ligand_atoms=max_num_ligand_atoms,
                             max_num_protein_atoms=max_num_protein_atoms)
 
@@ -326,7 +333,7 @@ def _preprocess(self, load_binding_pocket,
 
         # construct graphs with multiprocessing
         pool = multiprocessing.Pool(processes=num_processes)
-        self.graphs = pool.starmap(construct_graph_and_featurize, 
+        self.graphs = pool.starmap(construct_graph_and_featurize,
                                    zip(self.ligand_mols, self.protein_mols,
                                        self.ligand_coordinates, self.protein_coordinates))
         print(f'Done constructing {len(self.graphs)} graphs.')