Skip to content

Commit

Permalink
some improvements + fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
MJ10 committed Oct 5, 2021
1 parent 38b028b commit 6ba40e6
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 11 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@ Additional requirements for active learning experiments:
## Molecule experiments

Additional requirements:
- `pandas rdkit torch_geometric h5py`
- `pandas rdkit torch_geometric h5py ray`
- a few biochemistry programs, see `mols/Programs/README`

For `rdkit` in particular we found it to be easier to install through (mini)conda. [`torch_geometric`](https://github.com/rusty1s/pytorch_geometric) has non-trivial installation instructions.

If you have CUDA 10.1 configured, you can run `pip install -r requirements.txt`. You can also change `requirements.txt` to match your CUDA version. (Replace cu101 to cuXXX, where XXX is your CUDA version).

We compress the 300k molecule dataset for size. To uncompress it, run `cd mols/data/; gunzip docked_mols.h5.gz`.

We omit docking routines since they are part of a separate contribution still to be submitted. These are available on demand, please do reach out to [email protected] or [email protected].
2 changes: 1 addition & 1 deletion mols/Programs/README
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ MGL tools can be downloaded from https://ccsb.scripps.edu/mgltools/downloads/ (b

AutoDock Vina can be downloaded from http://vina.scripps.edu/download.html

OpenBabel can be downloaded from https://github.com/openbabel/openbabel/releases/tag/openbabel-3-1-1 (building required)
OpenBabel can be downloaded from https://github.com/openbabel/openbabel/releases/tag/openbabel-3-1-1 (building required, make sure after the build, the bin folder is in your PATH and lib in LD_LIBRARY_PATH)
45 changes: 37 additions & 8 deletions mols/gflownet_activelearning.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from torch.distributions.categorical import Categorical

from utils import chem

import ray
from mol_mdp_ext import MolMDPExtended, BlockMoleculeDataExtended

import model_atom, model_block, model_fingerprint
Expand Down Expand Up @@ -436,8 +436,28 @@ def stop_everything():
'test_infos': test_infos,
'train_infos': train_infos}

@ray.remote
class _SimDockLet:
def __init__(self, tmp_dir):
self.dock = chem.DockVina_smi(tmp_dir)
self.target_norm = [-8.6, 1.10]

def eval(self, mol, norm=False):
s = "None"
try:
s = Chem.MolToSmiles(mol.mol)
print("docking {}".format(s))
_, r, _ = self.dock.dock(s)
except Exception as e: # Sometimes the prediction fails
print('exception for', s, e)
r = 0
if not norm:
return r
reward = -(r-self.target_norm[0])/self.target_norm[1]
return reward


def sample_and_update_dataset(args, model, proxy_dataset, generator_dataset, docker):
def sample_and_update_dataset(args, model, proxy_dataset, generator_dataset, dock_pool):
# generator_dataset.set_sampling_model(model, docker, sample_prob=args.sample_prob)
# sampler = generator_dataset.start_samplers(8, args.num_samples)
print("Sampling")
Expand Down Expand Up @@ -470,13 +490,20 @@ def sample_and_update_dataset(args, model, proxy_dataset, generator_dataset, doc
# print('skip', mol.blockidxs, mol.jbonds)
continue
# print('here')
score = docker.eval(mol, norm=False)
mol.reward = proxy_dataset.r2r(score)
# score = docker.eval(mol, norm=False)
# mol.reward = proxy_dataset.r2r(score)
# mol.smiles = s
smis.append(mol.smiles)
rews.append(mol.reward)
print(mol.smiles, mol.reward)
# rews.append(mol.reward)
# print(mol.smiles, mol.reward)
sampled_mols.append(mol)

t0 = time.time()
rews = list(dock_pool.map(lambda a, m: a.eval.remote(m), sampled_mols))
t1 = time.time()
print('Docking sim done in {}'.format(t1-t0))
for i in range(len(sampled_mols)):
sampled_mols[i].reward = rews[i]

print("Computing distances")
dists =[]
Expand Down Expand Up @@ -504,7 +531,9 @@ def main(args):
reward_norm = args.reward_norm
rews = []
smis = []
docker = Docker(tmp_dir, cpu_req=args.cpu_req)
actors = [_SimDockLet.remote(tmp_dir)
for i in range(10)]
pool = ray.util.ActorPool(actors)
args.repr_type = proxy_repr_type
args.replay_mode = "dataset"
args.reward_exp = 1
Expand Down Expand Up @@ -548,7 +577,7 @@ def main(args):

print(f"Sampling mols: {i}")
# sample molecule batch for generator and update dataset with docking scores for sampled batch
_proxy_dataset, r, s, batch_metrics = sample_and_update_dataset(args, model, proxy_dataset, gen_model_dataset, docker)
_proxy_dataset, r, s, batch_metrics = sample_and_update_dataset(args, model, proxy_dataset, gen_model_dataset, pool)
print(f"Batch Metrics: dists_mean: {batch_metrics['dists_mean']}, dists_sum: {batch_metrics['dists_sum']}, reward_mean: {batch_metrics['reward_mean']}, reward_max: {batch_metrics['reward_max']}")
rews.append(r)
smis.append(s)
Expand Down
10 changes: 9 additions & 1 deletion mols/utils/chem.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@
"Zr": 40, "Nb": 41, "Mo": 42, "Tc": 43, "Ru": 44, "Rh": 45, "Pd": 46, "Ag": 47, "Cd": 48, "In": 49, "Sn": 50, "Sb": 51,
"Te": 52, "I": 53, "Xe": 54, "Cs": 55, "Ba": 56}


def onehot(arr, num_classes, dtype=np.int):
arr = np.asarray(arr, dtype=np.int)
assert len(arr.shape) ==1, "dims other than 1 not implemented"
onehot_arr = np.zeros(arr.shape + (num_classes,), dtype=dtype)
onehot_arr[np.arange(arr.shape[0]), arr] = 1
return onehot_arr

def mol_from_frag(jun_bonds, frags=None, frag_smis=None, coord=None, optimize=False):
"joins 2 or more fragments into a single molecule"
jun_bonds = np.asarray(jun_bonds)
Expand Down Expand Up @@ -309,4 +317,4 @@ def dock(self, smi, mol_name=None, molgen_conf=20):
os.remove(os.path.join(self.outpath, "mol2", f"{mol_name}.mol2"))
os.remove(os.path.join(self.outpath, "pdbqt", f"{mol_name}.pdbqt"))
os.remove(os.path.join(self.outpath, "docked", f"{mol_name}.pdb"))
return mol_name, dockscore, coord
return mol_name, dockscore, coord
12 changes: 12 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
mkl-fft==1.3.0
mkl-random==1.2.2
mkl-service==2.4.0
numpy
pandas==1.3.3
ray==1.1.0
rdkit-pypi
torch==1.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
torch-geometric==1.6.3
torch-scatter==2.0.5 -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html
torch-sparse==0.6.8 -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html
tqdm==4.36.1

0 comments on commit 6ba40e6

Please sign in to comment.