Skip to content

Commit

Permalink
explainability vis: use heavy atoms instead of all atoms
Browse files Browse the repository at this point in the history
  • Loading branch information
JochenSiegWork committed Nov 15, 2024
1 parent 00e4b72 commit 667df3a
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions molpipeline/explainability/visualization/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def _add_gaussians_for_atoms(
ValueGrid
ValueGrid object with added functions.
"""
for i, _ in enumerate(mol.GetAtoms()):
for i in range(mol.GetNumHeavyAtoms()):
if atom_weights[i] == 0:
continue
pos = conf.GetAtomPosition(i)
Expand Down Expand Up @@ -233,7 +233,7 @@ def make_sum_of_gaussians_grid(
"""
# assign default values and convert to numpy array
if atom_weights is None:
atom_weights = np.zeros(len(mol.GetAtoms()))
atom_weights = np.zeros(mol.GetNumHeavyAtoms())
elif not isinstance(atom_weights, np.ndarray):
atom_weights = np.array(atom_weights)

Expand All @@ -243,8 +243,10 @@ def make_sum_of_gaussians_grid(
bond_weights = np.array(bond_weights)

# validate input
if not len(atom_weights) == len(mol.GetAtoms()):
raise ValueError("len(atom_weights) is not equal to number of bonds in mol")
if not len(atom_weights) == mol.GetNumHeavyAtoms():
raise ValueError(
"len(atom_weights) is not equal to number of heavy atoms in mol"
)

if not len(bond_weights) == len(mol.GetBonds()):
raise ValueError("len(bond_weights) is not equal to number of bonds in mol")
Expand Down

0 comments on commit 667df3a

Please sign in to comment.