Skip to content

Commit

Permalink
Cleanup ruff linting errors
Browse files Browse the repository at this point in the history
  • Loading branch information
maabuu committed Feb 20, 2024
1 parent 0173910 commit 045071a
Show file tree
Hide file tree
Showing 11 changed files with 30 additions and 40 deletions.
2 changes: 1 addition & 1 deletion posebusters/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def main():
logger.error(e)


def bust(
def bust( # noqa: PLR0913
mol_pred: list[Path | Mol] = [],
mol_true: Path | Mol | None = None,
mol_cond: Path | Mol | None = None,
Expand Down
6 changes: 3 additions & 3 deletions posebusters/modules/distance_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
}


def check_geometry(
def check_geometry( # noqa: PLR0913, PLR0915
mol_pred: Mol,
threshold_bad_bond_length: float = 0.2,
threshold_clash: float = 0.2,
Expand Down Expand Up @@ -234,8 +234,8 @@ def _two_bonds_to_angle(bond1: tuple[int, int], bond2: tuple[int, int]) -> None
set1 = set(bond1)
set2 = set(bond2)
all_atoms = set1 | set2
# angle requires two bonds to share exactly one atom
if len(all_atoms) != 3:
# angle requires two bonds to share exactly one atom, that is we must have 3 atoms
if len(all_atoms) != 3: # noqa: PLR2004
return None
# find shared atom
shared_atom = set1 & set2
Expand Down
2 changes: 1 addition & 1 deletion posebusters/modules/intermolecular_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
_periodic_table = GetPeriodicTable()


def check_intermolecular_distance(
def check_intermolecular_distance( # noqa: PLR0913
mol_pred: Mol,
mol_cond: Mol,
radius_type: str = "vdw",
Expand Down
2 changes: 1 addition & 1 deletion posebusters/modules/rmsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def check_rmsd(
return {"results": results}


def robust_rmsd(
def robust_rmsd( # noqa: PLR0913
mol_probe: Mol,
mol_ref: Mol,
conf_id_probe: int = -1,
Expand Down
2 changes: 1 addition & 1 deletion posebusters/modules/volume_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
logger = logging.getLogger(__name__)


def check_volume_overlap(
def check_volume_overlap( # noqa: PLR0913
mol_pred: Mol,
mol_cond: Mol,
clash_cutoff: float = 0.05,
Expand Down
10 changes: 5 additions & 5 deletions posebusters/posebusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def _run(self) -> Generator[dict, None, None]:
"""
self._initialize_modules()

for i, paths in self.file_paths.iterrows():
for _, paths in self.file_paths.iterrows():
mol_args = {}
if "mol_cond" in paths and paths["mol_cond"] is not None:
mol_cond_load_params = self.config.get("loading", {}).get("mol_cond", {})
Expand All @@ -151,15 +151,15 @@ def _run(self) -> Generator[dict, None, None]:

for name, fname, func, args in zip(self.module_name, self.fname, self.module_func, self.module_args):
# pick needed arguments for module
args = {k: v for k, v in mol_args.items() if k in args}
args_needed = {k: v for k, v in mol_args.items() if k in args}
# loading takes all inputs
if fname == "loading":
args = {k: args.get(k, None) for k in args}
args_needed = {k: args_needed.get(k, None) for k in args_needed}
# run module when all needed input molecules are valid Mol objects
if fname != "loading" and not all(args.get(m, None) for m in args):
if fname != "loading" and not all(args_needed.get(m, None) for m in args_needed):
module_output: dict[str, Any] = {"results": {}}
else:
module_output = func(**args)
module_output = func(**args_needed)

# save to object
self.results[results_key].extend([(name, k, v) for k, v in module_output["results"].items()])
Expand Down
14 changes: 7 additions & 7 deletions posebusters/tools/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@ def safe_supply_mols(path: Path, load_all=True, sanitize=True, **load_params) ->
supplier = SDMolSupplier(str(path), sanitize=False, strictParsing=True)
i = 0
for mol in supplier:
mol = _process_mol(mol, sanitize=sanitize, **load_params)
mol_clean = _process_mol(mol, sanitize=sanitize, **load_params)
i += 1
if mol is not None:
mol.SetProp("_Path", str(path))
mol.SetProp("_Index", str(i))
yield mol
if mol_clean is not None:
mol_clean.SetProp("_Path", str(path))
mol_clean.SetProp("_Index", str(i))
yield mol_clean


def _load_mol(
def _load_mol( # noqa: PLR0913
path: Path,
load_all=False,
sanitize=False,
Expand Down Expand Up @@ -136,7 +136,7 @@ def _load_and_combine_mols(path: Path, sanitize=True, removeHs=True, strictParsi
return mol


def _process_mol(
def _process_mol( # noqa: PLR0913
mol: Mol | None,
smiles: str | None = None,
cleanup=False,
Expand Down
19 changes: 3 additions & 16 deletions posebusters/tools/molecules.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,10 @@
from rdkit import RDLogger
from rdkit.Chem.AllChem import AssignBondOrdersFromTemplate
from rdkit.Chem.Lipinski import HAcceptorSmarts, HDonorSmarts
from rdkit.Chem.rdchem import (
AtomValenceException,
Bond,
Conformer,
GetPeriodicTable,
Mol,
RWMol,
)
from rdkit.Chem.rdchem import AtomValenceException, Bond, Conformer, GetPeriodicTable, Mol, RWMol
from rdkit.Chem.rdMolAlign import GetBestAlignmentTransform
from rdkit.Chem.rdmolfiles import MolFromSmarts
from rdkit.Chem.rdmolops import (
AddHs,
RemoveHs,
RemoveStereochemistry,
RenumberAtoms,
SanitizeMol,
)
from rdkit.Chem.rdmolops import AddHs, RemoveHs, RemoveStereochemistry, RenumberAtoms, SanitizeMol
from rdkit.Chem.rdMolTransforms import TransformConformer

logger = getLogger(__name__)
Expand Down Expand Up @@ -184,7 +171,7 @@ def _get_atomic_number(atomic_symbol: str):
symbol = "H"
return _periodic_table.GetAtomicNumber(symbol)
except Exception:
print(atomic_symbol)
logger.error("Unknown atomic symbol: %s", atomic_symbol)
return atomic_symbol


Expand Down
2 changes: 1 addition & 1 deletion posebusters/tools/protein.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def get_atom_type_mask(mol: Mol, ignore_types: Iterable[str]) -> list[bool]:
]


def _keep_atom(
def _keep_atom( # noqa: PLR0913, PLR0911
atom: Atom, ignore_h: bool, ignore_protein: bool, ignore_org_cof: bool, ignore_inorg_cof: bool, ignore_water: bool
) -> bool:
"""Whether to keep atom for given ignore flags."""
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ lint.ignore = ["E501"]
[tool.ruff.format]
docstring-code-format = true

[tool.ruff.lint.per-file-ignores]
"tests/*" = ["D", "PLR2004"]

[tool.mypy]
ignore_missing_imports = true

Expand Down
8 changes: 4 additions & 4 deletions tests/test_posebusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,13 @@ def test_bust_mol_rdkit() -> None:
assert df.all(axis=1).values[0]


def test_bust_mols_hydrogen() -> None:
def test_bust_mols_hydrogen(threshold=8) -> None:
posebusters = PoseBusters("mol")
df = posebusters.bust([mol_single_h])
assert df.sum(axis=1).values[0] >= 8 # energy ratio test fails
assert df.sum(axis=1).values[0] >= threshold # energy ratio test fails


def test_bust_mols_consistency() -> None:
def test_bust_mols_consistency(atol=1e-6) -> None:
# check that running the same molecule twice gives the same result

posebusters = PoseBusters("mol")
Expand All @@ -84,4 +84,4 @@ def test_bust_mols_consistency() -> None:
for v1, v2 in zip(result_2.values, result_3.values):
if v1[2] == v2[2] or math.isnan(v1[2]):
continue
assert abs(v1[2] - v2[2]) < 1e-6, f"{v1[0], v1[1]}: {v1[2]} != {v2[2]}"
assert abs(v1[2] - v2[2]) < atol, f"{v1[0], v1[1]}: {v1[2]} != {v2[2]}"

0 comments on commit 045071a

Please sign in to comment.