Skip to content

Commit

Permalink
Make linting pass
Browse files Browse the repository at this point in the history
  • Loading branch information
maabuu committed Nov 21, 2023
1 parent 5e76aa0 commit bbecfaa
Show file tree
Hide file tree
Showing 8 changed files with 41 additions and 35 deletions.
19 changes: 11 additions & 8 deletions posebusters/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def bust(
"""PoseBusters: Plausibility checks for generated molecule poses."""
if table is None and len(mol_pred) == 0:
raise ValueError("Provide either MOLS_PRED or TABLE.")
elif table is not None:

if table is not None:
# run on table
file_paths = pd.read_csv(table, index_col=None)
mode = _select_mode(config, file_paths.columns.tolist())
Expand Down Expand Up @@ -115,26 +116,28 @@ def _parse_args(args):
def _format_results(df: pd.DataFrame, outfmt: str = "short", no_header: bool = False, index: int = 0) -> str:
if outfmt == "long":
return create_long_output(df)
elif outfmt == "csv":

if outfmt == "csv":
header = (not no_header) and (index == 0)
df.index.names = ["file", "molecule"]
df.columns = [c.lower().replace(" ", "_") for c in df.columns]
return df.to_csv(index=True, header=header)
elif outfmt == "short":

if outfmt == "short":
return create_short_output(df)
else:
raise ValueError(f"Unknown output format {outfmt}")

raise ValueError(f"Unknown output format {outfmt}")


def _select_mode(config, columns: Iterable[str]) -> str | dict[str, Any]:
# decide on mode to run

# load config if provided
if type(config) == Path:
return dict(safe_load(open(config)))
if isinstance(config, Path):
return dict(safe_load(open(config, encoding="utf-8")))

# forward string if config provide
if type(config) == str:
if isinstance(config, str):
return str(config)

# select mode based on inputs
Expand Down
10 changes: 5 additions & 5 deletions posebusters/modules/energy_ratio.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,28 +56,28 @@ def check_energy_ratio(
mol_pred = assert_sanity(mol_pred)
AddHs(mol_pred, addCoords=True)
except Exception as e:
logger.warning(f"Failed to prepare molecule: {e}")
logger.warning("Failed to prepare molecule: %s", e)
return _empty_results

try:
inchi = get_inchi(mol_pred, inchi_strict=inchi_strict)
except InchiReadWriteError as e:
logger.warning(f"Molecule does not sanitize: {e.args[1]}")
logger.warning("Molecule does not sanitize: %s", e.args[1])
return _empty_results
except Exception as e:
logger.warning(f"Molecule does not sanitize: {e}")
logger.warning("Molecule does not sanitize: %s", e)
return _empty_results

try:
conf_energy = get_conf_energy(mol_pred)
except Exception as e:
logger.warning(f"Failed to calculate conformation energy for {inchi}: {e}")
logger.warning("Failed to calculate conformation energy for %s: %s", inchi, e)
conf_energy = np.nan

try:
avg_energy = float(get_average_energy(inchi, ensemble_number_conformations))
except Exception as e:
logger.warning(f"Failed to calculate ensemble conformation energy for {inchi}: {e}")
logger.warning("Failed to calculate ensemble conformation energy for %s: %s", inchi, e)
avg_energy = np.nan

pred_factor = conf_energy / avg_energy
Expand Down
3 changes: 1 addition & 2 deletions posebusters/modules/rmsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,5 +135,4 @@ def _call_rdkit_rmsd(mol_probe: Mol, mol_ref: Mol, conf_id_probe: int, conf_id_r
def _rmsd(mol_probe: Mol, mol_ref: Mol, conf_id_probe: int, conf_id_ref: int, kabsch: bool = False, **params):
if kabsch is True:
return GetBestRMS(prbMol=mol_probe, refMol=mol_ref, prbId=conf_id_probe, refId=conf_id_ref, **params)
else:
return CalcRMS(prbMol=mol_probe, refMol=mol_ref, prbId=conf_id_probe, refId=conf_id_ref, **params)
return CalcRMS(prbMol=mol_probe, refMol=mol_ref, prbId=conf_id_probe, refId=conf_id_ref, **params)
21 changes: 14 additions & 7 deletions posebusters/posebusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,26 @@
class PoseBusters:
"""Class to run all tests on a set of molecules."""

file_paths: pd.DataFrame
module_name: list
module_func: list
module_args: list
fname: list

def __init__(self, config: str | dict[str, Any] = "redock", top_n: int | None = None):
"""Initialize PoseBusters object."""
self.module_func: list # dict[str, Callable]
self.module_args: list # dict[str, set[str]]

if isinstance(config, str) and config in {"dock", "redock", "mol"}:
logger.info(f"Using default configuration for mode {config}.")
self.config = safe_load(open(Path(__file__).parent / "config" / f"{config}.yml"))
logger.info("Using default configuration for mode %s.", config)
with open(Path(__file__).parent / "config" / f"{config}.yml", encoding="utf-8") as config_file:
self.config = safe_load(config_file)
elif isinstance(config, dict):
logger.info("Using configuration dictionary provided by user.")
self.config = config
else:
logger.error(f"Configuration {config} not valid. Provide either 'dock', 'redock', 'mol' or a dictionary.")
logger.error("Configuration %s not valid. Provide either 'dock', 'redock', 'mol' or a dictionary.", config)
assert len(set(self.config.get("tests", {}).keys()) - set(module_dict.keys())) == 0

self.config["top_n"] = self.config.get("top_n", top_n)
Expand Down Expand Up @@ -175,16 +182,16 @@ def _initialize_modules(self) -> None:
self.fname.append(module["function"])
self.module_func.append(partial(function, **parameters))
self.module_args.append(module_args)
pass

@staticmethod
def _get_name(mol: Mol, i: int) -> str:
if mol is None:
return f"invalid_mol_at_pos_{i}"
elif not mol.HasProp("_Name") or mol.GetProp("_Name") == "":

if not mol.HasProp("_Name") or mol.GetProp("_Name") == "":
return f"mol_at_pos_{i}"
else:
return mol.GetProp("_Name")

return mol.GetProp("_Name")


def _dataframe_from_output(results_dict, config, full_report: bool = False) -> pd.DataFrame:
Expand Down
2 changes: 1 addition & 1 deletion posebusters/tools/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def _value_map(x):
if type(x) == bool:
if isinstance(x, bool):
return ". " if x else "Fail"
return x

Expand Down
2 changes: 1 addition & 1 deletion posebusters/tools/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def safe_load_mol(path: Path, load_all: bool = False, **load_params) -> Mol | No
mol = _load_mol(path, load_all=load_all, **load_params)
return mol
except Exception as exception:
logger.warning(f"Could not load molecule from {path} with error: {exception}")
logger.warning("Could not load molecule from %s with error: %s", path, exception)
return None


Expand Down
4 changes: 2 additions & 2 deletions posebusters/tools/molecules.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,9 @@ def _renumber_mol1_to_match_mol2(mol1: Mol, mol2: Mol) -> Mol:
n_atoms = mol1.GetNumAtoms()
n_different = n_atoms - sum(is_identical_idx)
if all(is_identical_idx):
logger.info("All {n_atoms} atoms are already in the same order")
logger.info("All %d atoms are already in the same order", n_atoms)
else:
logger.info(f"Swapping {n_different} out of {n_atoms} indices")
logger.info("Swapping %d out of %d indices", n_different, n_atoms)
mol1 = RenumberAtoms(mol1, [m[1] for m in atom_map])
return mol1

Expand Down
15 changes: 6 additions & 9 deletions posebusters/tools/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,27 +22,24 @@ def bae(val: float, lb: float, ub: float) -> float:
"""Calculate out of bounds absolute error."""
if val < lb:
return ae(val, lb)
elif val > ub:
if val > ub:
return ae(val, ub)
else:
return 0.0
return 0.0


def bpe(val: float, lb: float, ub: float) -> float:
"""Calculate out of bounds percentage error."""
if val < lb:
return pe(val, lb)
elif val > ub:
if val > ub:
return pe(val, ub)
else:
return 0.0
return 0.0


def bape(val: float, lb: float, ub: float) -> float:
"""Calculate out of bounds absolute percentage error."""
if val < lb:
return ape(val, lb)
elif val > ub:
if val > ub:
return ape(val, ub)
else:
return 0.0
return 0.0

0 comments on commit bbecfaa

Please sign in to comment.