Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add resolve_intramol_clashes #1413

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion tests/test_minimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,25 @@

import numpy as np
import pytest
from jax import grad
from rdkit import Chem
from rdkit.Chem import AllChem

from timemachine.constants import MAX_FORCE_NORM
from timemachine.datasets import fetch_freesolv
from timemachine.fe.free_energy import HostConfig
from timemachine.fe.model_utils import get_vacuum_val_and_grad_fn
from timemachine.fe.utils import get_romol_conf, read_sdf, read_sdf_mols_by_name
from timemachine.ff import Forcefield
from timemachine.ff.handlers import openmm_deserializer
from timemachine.md import builders, minimizer
from timemachine.md.barostat.utils import compute_box_volume
from timemachine.md.minimizer import equilibrate_host_barker, make_host_du_dx_fxn
from timemachine.md.minimizer import (
equilibrate_host_barker,
make_host_du_dx_fxn,
make_intramol_softened_fxn,
resolve_intramol_clashes,
)
from timemachine.potentials import NonbondedPairList
from timemachine.potentials.jax_utils import distance_on_pairs

Expand Down Expand Up @@ -388,3 +396,36 @@ def test_minimizer_failure_toy_system():
assert not np.isclose(initial_distance, final_distance, atol=2e-4)
assert initial_force_norms > np.linalg.norm(du_dx(minimized_coords))
minimizer.check_force_norm(minimized_coords)


def test_resolve_intramol_clashes():
"""start from a conformer where |force(x.astype(f32))| is +inf"""

mol_dict = {mol.GetProp("_Name"): mol for mol in fetch_freesolv()}
mol = mol_dict["mobley_2850833"]
ff = Forcefield.load_default()

# put a pair of atoms nearly on top of each other
np.random.seed(0)
conf = mol.GetConformer(0)
conf.SetAtomPosition(14, conf.GetAtomPosition(7) + np.random.randn(3) * 0.01)

x0 = get_romol_conf(mol)
U_fxn = make_intramol_softened_fxn(mol, ff)

def force_norm(x, lam):
return np.max(np.linalg.norm(grad(U_fxn)(x, lam), axis=1))

assert np.isposinf(force_norm(x0.astype(np.float32), 0.0)), "oops, test isn't strong enough"
assert force_norm(x0, 1.0) < force_norm(x0, 0.0), "oops, lam isn't doing enough"

x1 = resolve_intramol_clashes(mol, ff, in_place=False)

assert force_norm(x1, 0.0) < MAX_FORCE_NORM, "oops, minimization didn't achieve its goal"
np.testing.assert_equal(get_romol_conf(mol), x0, "oops, in_place=False updated mol in-place")
assert np.linalg.norm(x1 - x0, axis=1).max() < 1.0, "oops, minimization moved things too much"

x2 = resolve_intramol_clashes(mol, ff, in_place=True)

np.testing.assert_equal(x2, x1, "oops, minimization wasn't deterministic")
np.testing.assert_allclose(get_romol_conf(mol), x1, err_msg="oops, in_place=True didn't update mol in-place")
123 changes: 122 additions & 1 deletion timemachine/md/minimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
import jax.numpy as jnp
import numpy as np
import scipy.optimize
from jax import grad, jit, value_and_grad
from numpy.typing import NDArray
from rdkit import Chem

from timemachine.constants import BOLTZ, DEFAULT_PRESSURE, DEFAULT_TEMP, MAX_FORCE_NORM
from timemachine.constants import BOLTZ, DEFAULT_POSITIONAL_RESTRAINT_K, DEFAULT_PRESSURE, DEFAULT_TEMP, MAX_FORCE_NORM
from timemachine.fe import topology
from timemachine.fe.free_energy import HostConfig
from timemachine.fe.utils import get_romol_conf, set_romol_conf
Expand Down Expand Up @@ -470,6 +471,9 @@ def equilibrate_host_barker(
return x_host


# note: two code paths, one for Sequence[BoundPotential] and one for jax-transformable fxn


def get_val_and_grad_fn(bps: Sequence[BoundPotential], box: NDArray, precision=np.float32):
"""
Convert impls, box into a function that only takes in coords.
Expand Down Expand Up @@ -541,6 +545,30 @@ def val_and_grad_fn_bfgs(x_flattened):
return res.x.reshape(final_shape)


def wrap_for_scipy(f, x0):
"""scipy L-BFGS-B assumes a flat array, and gradient in f64

Parameters
----------
f : jax-transformable scalar-valued fxn
function to be minimized
x0: array
determines shape used for flattening / unflattening

Returns
-------
fun(x_flat) -> (value, gradient)
"""
vg = jit(value_and_grad(f))

def fun(x_flat):
x = x_flat.reshape(x0.shape)
v, g = vg(x)
return float(v), np.array(g, dtype=np.float64).flatten()

return fun


def local_minimize(
x0: NDArray,
box0: NDArray,
Expand Down Expand Up @@ -700,3 +728,96 @@ def replace_conformer_with_minimized(

xs_opt = local_minimize(xs, box, val_and_grad_fn, all_idxs, minimizer_config, verbose=False)
set_romol_conf(mol, xs_opt, conf_id)


def make_intramol_softened_fxn(mol, ff):
"""Construct a potential function U(x, lam)
where lam controls the `w` coordinate associated with each intramol NB pair:

lam = 0 --> w_ij = 0
(No softening applied.)

lam = 1 --> w_ij = 0.75 * sig_ij
(0.75*sig_ij chosen so that, even at r_ij = 0, lennard_jones(sqrt(r_ij^2 + w_ij^2); sig_ij, eps_ij)
will be at most ~(100 * eps_ij).)
"""
bt = topology.BaseTopology(mol, ff)
vacuum_system = bt.setup_end_state()
bps = [vacuum_system.bond, vacuum_system.angle, vacuum_system.torsion, vacuum_system.nonbonded]
potentials = [bp.potential for bp in bps] # TODO: should these also include chiral_atom, chiral_bond ?
params = [bp.params for bp in bps]
summed_potential = SummedPotential(potentials, params)
params_0 = [jnp.array(p) for p in params]
nb_params_0 = jnp.array(params_0[-1])
box = 100000 * np.eye(3)

# scale each w_ij between 0 and 0.75 * sig_ij
lj_sig = nb_params_0[:, 1]
max_ws = 0.75 * lj_sig # TODO: expose the 0.75 parameter?

def make_nb_params(lam):
nb_params = nb_params_0.at[:, -1].set(lam * max_ws)
return nb_params

def U(x, lam):
new_nb_params = make_nb_params(lam)
_params = params_0[:-1] + [new_nb_params]
return summed_potential.call_with_params_list(x, _params, box)

return U


def resolve_intramol_clashes(mol, ff, k=DEFAULT_POSITIONAL_RESTRAINT_K, verbose=True, in_place=True):
"""Minimize energy of mol

Parameters
----------
mol : rdkit romol
ff : forcefield
k : float
force constant for positional restraints
verbose : bool
print messages
in_place : bool
update mol's 0'th conformer in-place

Returns
-------
final_conf: array
energy-minimized version of get_romol_conf(mol, 0)
"""
x0 = get_romol_conf(mol)
box = 100000 * np.eye(3)

U_fxn = make_intramol_softened_fxn(mol, ff)

def force_norm(x, lam):
return jnp.max(jnp.linalg.norm(grad(U_fxn)(x, lam), axis=1))

if verbose:
print("initial force norm: ", force_norm(x0, 0.0))

U_restr = lambda x: harmonic_positional_restraint(get_romol_conf(mol), x, box, k)
U_combined = lambda x, lam: U_fxn(x, lam) + U_restr(x)

# TODO: could replace minimize(x, lam) with something like mcmc_update(x, lam)
def minimize(x, lam):
fun = wrap_for_scipy(lambda x: U_combined(x, lam), x)
result = scipy.optimize.minimize(fun, x.flatten(), jac=True, method="L-BFGS-B")
return result.x.reshape(x.shape)

# TODO: instead of just looping over lam in [1.0, 0.0], might need to introduce more steps adaptively
adjusted_conf = minimize(x0, 1.0)
final_conf = minimize(adjusted_conf, 0.0)

final_force_norm = force_norm(final_conf, 0.0)
if verbose:
print("final force norm: ", final_force_norm)

if final_force_norm > MAX_FORCE_NORM:
raise RuntimeError("final force norm exceeds threshold")

if in_place:
set_romol_conf(mol, final_conf)

return final_conf