Skip to content

Commit

Permalink
Update (#168)
Browse files Browse the repository at this point in the history
  • Loading branch information
mufeili authored Jan 30, 2022
1 parent 0000e6e commit 8fd3dba
Showing 1 changed file with 19 additions and 12 deletions.
31 changes: 19 additions & 12 deletions python/dgllife/data/pdbbind.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class PDBBind(object):
load_binding_pocket : bool
Whether to load binding pockets or full proteins. Default to True.
remove_coreset_from_refinedset: bool
Whether to remove the core set from the refined set when training on the refined set and testing on the core set.
Whether to remove the core set from the refined set when training on the refined set and testing on the core set.
Default to True.
sanitize : bool
Whether sanitization is performed in initializing RDKit molecule instances. See
Expand Down Expand Up @@ -88,7 +88,7 @@ class PDBBind(object):
Defaults to None, in which case the PDBBind dataset is downloaded from the DGL database.
Set this argument to the local path of a customized dataset, which should follow the directory structure and naming format of PDBBind v2015.
"""
def __init__(self, subset, pdb_version='v2015', load_binding_pocket=True, remove_coreset_from_refinedset=True, sanitize=False,
def __init__(self, subset, pdb_version='v2015', load_binding_pocket=True, remove_coreset_from_refinedset=True, sanitize=False,
calc_charges=False, remove_hs=False, use_conformation=True,
construct_graph_and_featurize=ACNN_graph_construction_and_featurization,
zero_padding=True, num_processes=None, local_path=None):
Expand All @@ -101,9 +101,9 @@ def __init__(self, subset, pdb_version='v2015', load_binding_pocket=True, remove
# Prepare for Refined, Agglomerative Sequence Split and Agglomerative Structure Split
if pdb_version == 'v2007' and not local_path:
merged_df = self.df.merge(self.agg_split, on='PDB_code')
self.agg_sequence_split = [list(merged_df.loc[merged_df['sequence']==target_set, 'PDB_code'].index)
self.agg_sequence_split = [list(merged_df.loc[merged_df['sequence']==target_set, 'PDB_code'].index)
for target_set in ['train', 'valid', 'test']]
self.agg_structure_split = [list(merged_df.loc[merged_df['structure']==target_set, 'PDB_code'].index)
self.agg_structure_split = [list(merged_df.loc[merged_df['structure']==target_set, 'PDB_code'].index)
for target_set in ['train', 'valid', 'test']]

def _read_data_files(self, pdb_version, subset, load_binding_pocket, remove_coreset_from_refinedset, local_path):
Expand Down Expand Up @@ -132,7 +132,7 @@ def _read_data_files(self, pdb_version, subset, load_binding_pocket, remove_core
extracted_data_path = root_dir_path + '/pdbbind_v2007'
download(_get_dgl_url(self._url), path=data_path, overwrite=False)
extract_archive(data_path, extracted_data_path, overwrite=False)
extracted_data_path += '/home/ubuntu' # extra layer
extracted_data_path += '/home/ubuntu' # extra layer

# DataFrame containing the pdbbind_2007_agglomerative_split.txt
self.agg_split = pd.read_csv(extracted_data_path + '/v2007/pdbbind_2007_agglomerative_split.txt')
Expand Down Expand Up @@ -171,8 +171,6 @@ def _read_data_files(self, pdb_version, subset, load_binding_pocket, remove_core
self.df = pd.DataFrame(contents, columns=(
'PDB_code', 'resolution', 'release_year',
'-logKd/Ki', 'Kd/Ki', 'cluster_ID'))

pdbs = self.df['PDB_code'].tolist()

# remove core set from refined set if using refined
if remove_coreset_from_refinedset and subset == 'refined':
Expand All @@ -183,12 +181,21 @@ def _read_data_files(self, pdb_version, subset, load_binding_pocket, remove_core
elif pdb_version == 'v2007':
core_path = extracted_data_path + '/v2007/INDEX.2007.core.data'

core_pdbs = []
with open(core_path,'r') as f:
for line in f:
fields = line.strip().split()
if fields[0] != "#" and fields[0] in pdbs:
pdbs.remove(fields[0])

if fields[0] != "#":
core_pdbs.append(fields[0])

non_core_ids = []
for i in range(len(self.df)):
if self.df['PDB_code'][i] not in core_pdbs:
non_core_ids.append(i)
self.df = self.df.iloc[non_core_ids]

pdbs = self.df['PDB_code'].tolist()

if local_path:
pdb_path = local_path
else:
Expand Down Expand Up @@ -317,7 +324,7 @@ def _preprocess(self, load_binding_pocket,
max_num_ligand_atoms = None
max_num_protein_atoms = None

construct_graph_and_featurize = partial(construct_graph_and_featurize,
construct_graph_and_featurize = partial(construct_graph_and_featurize,
max_num_ligand_atoms=max_num_ligand_atoms,
max_num_protein_atoms=max_num_protein_atoms)

Expand All @@ -326,7 +333,7 @@ def _preprocess(self, load_binding_pocket,

# construct graphs with multiprocessing
pool = multiprocessing.Pool(processes=num_processes)
self.graphs = pool.starmap(construct_graph_and_featurize,
self.graphs = pool.starmap(construct_graph_and_featurize,
zip(self.ligand_mols, self.protein_mols,
self.ligand_coordinates, self.protein_coordinates))
print(f'Done constructing {len(self.graphs)} graphs.')
Expand Down

0 comments on commit 8fd3dba

Please sign in to comment.