Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed bug that made --print output inconsistent #112

Merged
merged 6 commits into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 77 additions & 53 deletions rmsd/calculate_rmsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1412,6 +1412,9 @@ def set_coordinates(atoms: ndarray, V: ndarray, title: str = "", decimals: int =
"""
N, D = V.shape

if N != len(atoms):
raise ValueError("Mismatch between expected atoms and coordinate size")

fmt = "{:<2}" + (" {:15." + str(decimals) + "f}") * 3

out = list()
Expand Down Expand Up @@ -1818,6 +1821,14 @@ def parse_arguments(arguments: Optional[Union[str, List[str]]] = None) -> argpar
),
)

parser.add_argument(
"--print-only-rmsd-atoms",
action="store_true",
help=(
"Print only atoms used in finding optimal RMSD calculation (relevant if filtering e.g. Hydrogens)"
),
)

args = parser.parse_args(arguments)

# Check illegal combinations
Expand Down Expand Up @@ -1877,35 +1888,35 @@ def parse_arguments(arguments: Optional[Union[str, List[str]]] = None) -> argpar
return args


def main(args: Optional[List[str]] = None):
def main(args: Optional[List[str]] = None) -> str:

# Parse arguments
settings = parse_arguments(args)

# As default, load the extension as format
# Parse pdb.gz and xyz.gz as pdb and xyz formats
p_all_atoms, p_all = get_coordinates(
p_atoms, p_coord = get_coordinates(
settings.structure_a,
settings.format,
is_gzip=settings.format_is_gzip,
return_atoms_as_int=True,
)

q_all_atoms, q_all = get_coordinates(
q_atoms, q_coord = get_coordinates(
settings.structure_b,
settings.format,
is_gzip=settings.format_is_gzip,
return_atoms_as_int=True,
)

p_size = p_all.shape[0]
q_size = q_all.shape[0]
p_size = p_coord.shape[0]
q_size = q_coord.shape[0]

if not p_size == q_size:
print("error: Structures not same size")
sys.exit()

if np.count_nonzero(p_all_atoms != q_all_atoms) and not settings.reorder:
if np.count_nonzero(p_atoms != q_atoms) and not settings.reorder:
msg = """
error: Atoms are not in the same order.

Expand All @@ -1923,12 +1934,11 @@ def main(args: Optional[List[str]] = None):
# Set local view
p_view: Optional[ndarray] = None
q_view: Optional[ndarray] = None
use_view: bool = True

if settings.ignore_hydrogen:
assert type(p_all_atoms[0]) != str
assert type(q_all_atoms[0]) != str
p_view = np.where(p_all_atoms != 1) # type: ignore
q_view = np.where(q_all_atoms != 1) # type: ignore
p_view = np.where(p_atoms != 1) # type: ignore
q_view = np.where(q_atoms != 1) # type: ignore

elif settings.remove_idx:
index = np.array(list(set(range(p_size)) - set(settings.remove_idx)))
Expand All @@ -1939,26 +1949,27 @@ def main(args: Optional[List[str]] = None):
p_view = settings.add_idx
q_view = settings.add_idx

else:
use_view = False

# Set local view
if p_view is None:
p_coord = copy.deepcopy(p_all)
q_coord = copy.deepcopy(q_all)
p_atoms = copy.deepcopy(p_all_atoms)
q_atoms = copy.deepcopy(q_all_atoms)
if use_view:
p_coord_sub = copy.deepcopy(p_coord[p_view])
q_coord_sub = copy.deepcopy(q_coord[q_view])
p_atoms_sub = copy.deepcopy(p_atoms[p_view])
q_atoms_sub = copy.deepcopy(q_atoms[q_view])

else:
assert p_view is not None
assert q_view is not None
p_coord = copy.deepcopy(p_all[p_view])
q_coord = copy.deepcopy(q_all[q_view])
p_atoms = copy.deepcopy(p_all_atoms[p_view])
q_atoms = copy.deepcopy(q_all_atoms[q_view])
p_coord_sub = copy.deepcopy(p_coord)
q_coord_sub = copy.deepcopy(q_coord)
p_atoms_sub = copy.deepcopy(p_atoms)
q_atoms_sub = copy.deepcopy(q_atoms)

# Recenter to centroid
p_cent = centroid(p_coord)
q_cent = centroid(q_coord)
p_coord -= p_cent
q_coord -= q_cent
p_cent_sub = centroid(p_coord_sub)
q_cent_sub = centroid(q_coord_sub)
p_coord_sub -= p_cent_sub
q_coord_sub -= q_cent_sub

rmsd_method: RmsdCallable
reorder_method: Optional[ReorderCallable]
Expand All @@ -1985,7 +1996,7 @@ def main(args: Optional[List[str]] = None):
reorder_method = reorder_distance

# Save the resulting RMSD
result_rmsd = None
result_rmsd: Optional[float] = None

# Collect changes to be done on q coords
q_swap = None
Expand All @@ -1995,21 +2006,21 @@ def main(args: Optional[List[str]] = None):
if settings.use_reflections:

result_rmsd, q_swap, q_reflection, q_review = check_reflections(
p_atoms,
q_atoms,
p_coord,
q_coord,
p_atoms_sub,
q_atoms_sub,
p_coord_sub,
q_coord_sub,
reorder_method=reorder_method,
rmsd_method=rmsd_method,
)

elif settings.use_reflections_keep_stereo:

result_rmsd, q_swap, q_reflection, q_review = check_reflections(
p_atoms,
q_atoms,
p_coord,
q_coord,
p_atoms_sub,
q_atoms_sub,
p_coord_sub,
q_coord_sub,
reorder_method=reorder_method,
rmsd_method=rmsd_method,
keep_stereo=True,
Expand All @@ -2023,42 +2034,55 @@ def main(args: Optional[List[str]] = None):
# If there is a reorder, then apply before print
if q_review is not None:

q_all_atoms = q_all_atoms[q_review]
q_atoms = q_atoms[q_review]
q_coord = q_coord[q_review]
q_atoms_sub = q_atoms_sub[q_review]
q_coord_sub = q_coord_sub[q_review]

assert all(
p_atoms == q_atoms
p_atoms_sub == q_atoms_sub
), "error: Structure not aligned. Please submit bug report at http://github.com/charnley/rmsd"

# Calculate the RMSD value
if result_rmsd is None:
result_rmsd = rmsd_method(p_coord_sub, q_coord_sub)

# print result
if settings.output:

if q_swap is not None:
q_coord = q_coord[:, q_swap]
q_coord_sub = q_coord_sub[:, q_swap]

if q_reflection is not None:
q_coord = np.dot(q_coord, np.diag(q_reflection))
q_coord_sub = np.dot(q_coord_sub, np.diag(q_reflection))

q_coord -= centroid(q_coord)
U = kabsch(q_coord_sub, p_coord_sub)

# Rotate q coordinates
# TODO Should actually follow rotation method
q_coord = kabsch_rotate(q_coord, p_coord)
if settings.print_only_rmsd_atoms or not use_view:
q_coord_sub = np.dot(q_coord_sub, U)
q_coord_sub += p_cent_sub
return set_coordinates(
q_atoms_sub,
q_coord_sub,
title=f"Rotated '{settings.structure_b}' to match '{settings.structure_a}', with a RMSD of {result_rmsd:.8f}",
)

# center q on p's original coordinates
q_coord += p_cent
# Swap, reflect, rotate and re-center on the full atom and coordinate set
q_coord -= q_cent_sub

# done and done
xyz = set_coordinates(q_all_atoms, q_coord, title=f"{settings.structure_b} - modified")
return xyz
if q_swap is not None:
q_coord = q_coord[:, q_swap]

else:
if q_reflection is not None:
q_coord = np.dot(q_coord, np.diag(q_reflection))

if not result_rmsd:
result_rmsd = rmsd_method(p_coord, q_coord)
q_coord = np.dot(q_coord, U)
q_coord += p_cent_sub
return set_coordinates(
q_atoms,
q_coord,
title=f"Rotated {settings.structure_b} to match {settings.structure_a}, with RMSD of {result_rmsd:.8f}",
)

return result_rmsd
return str(result_rmsd)


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions tests/resources/issue93/b.xyz
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
C -0.2769166422 -0.7099045701 -1.2865831838
C -0.0318894674 0.6967020886 -1.8794695568
C 0.6647089780 -1.0315870523 -0.0954617436
H 0.7080595367 -2.0732382850 -2.5032982174
H -1.3132498520 -0.6901271514 -0.9163939649
C 0.2611763277 -0.4666702466 1.2320987930
C -0.0191751268 -1.0867015822 2.4345274135
C -0.2291714337 1.0136818561 2.8036493128
H -0.8594037813 -2.4661239083 -2.2309783928
H 0.7080595367 -2.0732382850 -2.5032982174
H -1.3132498520 -0.6901271514 -0.9163939649
H 1.6889557061 -0.7206820514 -0.3611372211
H 0.6818980659 -2.1236321622 0.0197593401
H 0.2183924342 1.5973044617 0.7751919144
Expand Down
51 changes: 48 additions & 3 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from context import RESOURCE_PATH, call_main

import rmsd as rmsdlib
from rmsd.calculate_rmsd import get_coordinates_pdb, get_coordinates_xyz, get_coordinates_xyz_lines


def test_print_reflection_reorder() -> None:
Expand Down Expand Up @@ -59,16 +60,19 @@ def test_print_reflection_reorder() -> None:

# Main call print, check rmsd is still the same
# Note, that --print is translating b to a center
args = f"--use-reflections --reorder --print {filename_a} {filename_b}"
stdout = call_main(args.split())
_, coord = rmsdlib.get_coordinates_xyz_lines(stdout)
_args = f"--use-reflections --reorder --print {filename_a} {filename_b}"
_stdout: str = rmsdlib.main(_args.split())
atoms, coord = rmsdlib.get_coordinates_xyz_lines(_stdout.split("\n"), return_atoms_as_int=True)
coord -= rmsdlib.centroid(coord) # fix translation
print(coord)
print(atoms)
print(atoms_b)

rmsd_check1 = rmsdlib.kabsch_rmsd(coord, coord_a)
rmsd_check2 = rmsdlib.rmsd(coord, coord_a)
print(rmsd_check1)
print(rmsd_check2)
print(result_rmsd)
np.testing.assert_almost_equal(rmsd_check2, rmsd_check1)
np.testing.assert_almost_equal(rmsd_check2, result_rmsd)

Expand Down Expand Up @@ -136,3 +140,44 @@ def test_ignore() -> None:
rmsdlib.main(f"{filename_a} {filename_b} --remove-idx 0 5".split())

rmsdlib.main(f"{filename_a} {filename_b} --add-idx 0 1 2 3 4".split())


def test_print_match_no_hydrogen() -> None:
    """Check which atoms ``--print`` emits when hydrogens are filtered out.

    The same structure file is passed as both inputs. Three invocations of
    ``rmsdlib.main`` are checked:

    * ``--no-hydrogen --print`` is expected to print 60 atoms, including "H".
    * ``--print`` alone is expected to print 60 atoms, including "H".
    * ``--no-hydrogen --print --print-only-rmsd-atoms`` is expected to print
      30 atoms and no "H".
    """

    def run_and_parse(command: str):
        # main() returns the printed structure as one string; the xyz line
        # parser is fed that string split on newlines.
        printed = rmsdlib.main(command.split())
        return get_coordinates_xyz_lines(printed.split("\n"))

    # Both inputs point at the same resource file.
    reference = RESOURCE_PATH / "CHEMBL3039407_order.xyz"
    target = RESOURCE_PATH / "CHEMBL3039407_order.xyz"

    # Case 1: hydrogens ignored, full structure printed.
    filtered_cmd = f"--no-hydrogen --print {reference} {target}"
    print(filtered_cmd)
    filtered_atoms, filtered_coord = run_and_parse(filtered_cmd)

    print(filtered_atoms)
    print(len(filtered_atoms))

    assert len(filtered_atoms) == 60
    assert filtered_coord.shape
    assert "H" in filtered_atoms

    # Case 2: no filtering at all.
    plain_atoms, plain_coord = run_and_parse(f"--print {reference} {target}")

    print(plain_atoms)
    print(len(plain_atoms))

    assert len(plain_atoms) == 60
    assert plain_coord.shape
    assert "H" in plain_atoms

    # Case 3: hydrogens ignored and only the atoms used for the RMSD printed.
    subset_atoms, subset_coord = run_and_parse(
        f"--no-hydrogen --print --print-only-rmsd-atoms {reference} {target}"
    )

    print(subset_atoms)
    print(len(subset_atoms))

    assert len(subset_atoms) == 30
    assert subset_coord.shape
    assert "H" not in subset_atoms
24 changes: 16 additions & 8 deletions tests/test_reorder_print.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,32 @@ def test_reorder_print_and_rmsd() -> None:

filename_a = RESOURCE_PATH / "issue93" / "a.xyz"
filename_b = RESOURCE_PATH / "issue93" / "b.xyz"
atoms_a, coord_a = get_coordinates_xyz(filename_a)
atoms_b, coord_b = get_coordinates_xyz(filename_b)

# Get reorder rmsd
args = ["--reorder", f"{filename_a}", f"{filename_b}"]
stdout = call_main(args)
rmsd_ab = float(stdout[-1])
rmsd_ab = float(rmsdlib.main(f"--reorder {filename_a} {filename_b}".split()))
print(rmsd_ab)
assert isinstance(rmsd_ab, float)

# Get printed structure
stdout = call_main(args + ["--print"])
stdout = rmsdlib.main(f"--reorder --print {filename_a} {filename_b}".split())
print(stdout)
atoms_c, coord_c = get_coordinates_xyz_lines(stdout.split("\n"))

atoms_a, coord_a = get_coordinates_xyz(filename_a)
atoms_c, coord_c = get_coordinates_xyz_lines(stdout)
coord_c -= rmsdlib.centroid(coord_c)
coord_a -= rmsdlib.centroid(coord_a)

print(coord_a)
print(atoms_a)
print(atoms_b)
print(atoms_c)

print(coord_a)
print(coord_b)
print(coord_c)
print(atoms_c)

assert (atoms_a == atoms_c).all()
assert (atoms_a != atoms_b).any()

rmsd_ac = rmsdlib.rmsd(coord_a, coord_c)
print(rmsd_ac)
Expand Down
Loading