From 049f147f22f6add4d2d094825c46b0919bbfa329 Mon Sep 17 00:00:00 2001 From: "David W.H. Swenson" Date: Thu, 2 Mar 2017 17:38:17 +0100 Subject: [PATCH 01/22] Add concurrence stuff (still needs improvements) --- contact_map/__init__.py | 2 +- contact_map/concurrence.py | 84 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 1 deletion(-) create mode 100644 contact_map/concurrence.py diff --git a/contact_map/__init__.py b/contact_map/__init__.py index 3c16be9..5d09603 100644 --- a/contact_map/__init__.py +++ b/contact_map/__init__.py @@ -3,4 +3,4 @@ MinimumDistanceCounter ) -# import concurrence +import concurrence diff --git a/contact_map/concurrence.py b/contact_map/concurrence.py new file mode 100644 index 0000000..2a7eacd --- /dev/null +++ b/contact_map/concurrence.py @@ -0,0 +1,84 @@ +import itertools +import mdtraj as md +import numpy as np + +class Concurrence(object): + def __init__(self, values, labels=None): + self.values = values + self.labels = labels + + @property + def lifetimes(self): + pass + + def set_labels(self, labels): + self.labels = labels + + def __getitem__(self, label): + idx = self.labels.index(label) + return self.values[idx] + + def coincidence(self, label_list): + this_list = np.asarray(self[label_list[0]]) + coincidence_list = this_list + norm_sq = sum(this_list) + for label in label_list[1:]: + this_list = np.asarray(self[label]) + coincidence_list &= this_list + norm_sq *= sum(this_list) + + return sum(coincidence_list) / np.sqrt(norm_sq) + + + +class AtomContactConcurrence(Concurrence): + def __init__(self, trajectory, atom_contacts, cutoff=0.45): + atom_pairs = [[contact[0][0].index, contact[0][1].index] + for contact in atom_contacts] + labels = [str(contact[0]) for contact in atom_contacts] + distances = md.compute_distances(trajectory, atom_pairs=atom_pairs) + vector_f = np.vectorize(lambda d: d < cutoff) + values = zip(*vector_f(distances)) + super(AtomContactConcurrence, self).__init__(values=values, + labels=labels) + +class ResidueContactConcurrence(Concurrence): + def __init__(self, trajectory, residue_contacts, cutoff=0.45, + select="and not symbol == 'H'"): + residue_pairs = [[contact[0][0], contact[0][1]] + for contact in residue_contacts] + labels = [str(contact[0]) for contact in residue_contacts] + values = [] + for res_A, res_B in residue_pairs: + atoms_A = trajectory.topology.select("resid " + str(res_A.index) + + " " + select) + atoms_B = trajectory.topology.select("resid " + str(res_B.index) + + " " + select) + atom_pairs = itertools.product(atoms_A, atoms_B) + distances = md.compute_distances(trajectory, + atom_pairs=atom_pairs) + min_dists = [min(dists) for dists in distances] + values.append(map(lambda d: d < cutoff, min_dists)) + + super(ResidueContactConcurrence, self).__init__(values=values, + labels=labels) + +def plot_concurrence(concurrence, labels=None, x_values=None): + import matplotlib.pyplot as plt + if x_values is None: + x_values = range(len(concurrence.values[0])) + if labels is None: + if concurrence.labels is not None: + labels = concurrence.labels + else: + labels = [str(i) for i in range(len(values))] + + y_val = -1.0 + for label, val_set in zip(labels, concurrence.values): + x_vals = [x for (x, y) in zip(x_values, val_set) if y] + plt.plot(x_vals, [y_val] * len(x_vals), '.', markersize=1, label=label) + y_val -= 1.0 + + plt.ylim(ymax=0.0) + plt.xlim(xmin=min(x_values), xmax=max(x_values)) + plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) From c0a278c034a388d5648e1b9c86a9f7ca663f5d09 Mon Sep 17 00:00:00 2001 From: "David W.H. Swenson" Date: Sun, 5 Nov 2017 21:12:43 +0100 Subject: [PATCH 02/22] Fix mixed-in tabs (I think) --- contact_map/concurrence.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/contact_map/concurrence.py b/contact_map/concurrence.py index 2a7eacd..bbe8342 100644 --- a/contact_map/concurrence.py +++ b/contact_map/concurrence.py @@ -33,12 +33,12 @@ def coincidence(self, label_list): class AtomContactConcurrence(Concurrence): def __init__(self, trajectory, atom_contacts, cutoff=0.45): - atom_pairs = [[contact[0][0].index, contact[0][1].index] - for contact in atom_contacts] - labels = [str(contact[0]) for contact in atom_contacts] - distances = md.compute_distances(trajectory, atom_pairs=atom_pairs) - vector_f = np.vectorize(lambda d: d < cutoff) - values = zip(*vector_f(distances)) + atom_pairs = [[contact[0][0].index, contact[0][1].index] + for contact in atom_contacts] + labels = [str(contact[0]) for contact in atom_contacts] + distances = md.compute_distances(trajectory, atom_pairs=atom_pairs) + vector_f = np.vectorize(lambda d: d < cutoff) + values = zip(*vector_f(distances)) super(AtomContactConcurrence, self).__init__(values=values, labels=labels) From d2150160133a73fbc1914651932064f48c783bac Mon Sep 17 00:00:00 2001 From: "David W.H. Swenson" Date: Sun, 5 Nov 2017 21:15:56 +0100 Subject: [PATCH 03/22] refactor select_residue as lambda fcn --- contact_map/concurrence.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/contact_map/concurrence.py b/contact_map/concurrence.py index bbe8342..c691fec 100644 --- a/contact_map/concurrence.py +++ b/contact_map/concurrence.py @@ -49,11 +49,12 @@ def __init__(self, trajectory, residue_contacts, cutoff=0.45, for contact in residue_contacts] labels = [str(contact[0]) for contact in residue_contacts] values = [] + select_residue = lambda idx: trajectory.topology.select( + "resid " + str(idx) + " " + select + ) for res_A, res_B in residue_pairs: - atoms_A = trajectory.topology.select("resid " + str(res_A.index) - + " " + select) - atoms_B = trajectory.topology.select("resid " + str(res_B.index) - + " " + select) + atoms_A = select_residue(res_A.index) + atoms_A = select_residue(res_B.index) atom_pairs = itertools.product(atoms_A, atoms_B) distances = md.compute_distances(trajectory, atom_pairs=atom_pairs) From 16bb06bd2ca3cdbee387025688533063e797edc2 Mon Sep 17 00:00:00 2001 From: "David W.H. Swenson" Date: Sun, 5 Nov 2017 21:52:21 +0100 Subject: [PATCH 04/22] Refactor plot_concurrence into object + function --- contact_map/concurrence.py | 70 ++++++++++++++++++++++++++++---------- 1 file changed, 52 insertions(+), 18 deletions(-) diff --git a/contact_map/concurrence.py b/contact_map/concurrence.py index c691fec..1e74f05 100644 --- a/contact_map/concurrence.py +++ b/contact_map/concurrence.py @@ -2,6 +2,14 @@ import mdtraj as md import numpy as np +try: + import matplotlib.pyplot as plt +except ImportError: + HAS_MATPLOTLIB = False +else: + HAS_MATPLOTLIB = True + + class Concurrence(object): def __init__(self, values, labels=None): self.values = values @@ -64,22 +72,48 @@ def __init__(self, trajectory, residue_contacts, cutoff=0.45, super(ResidueContactConcurrence, self).__init__(values=values, labels=labels) + +class ConcurrencePlotter(object): + def __init__(self, concurrence=None, labels=None, x_values=None): + self.concurrence = concurrence + self.labels = self.get_concurrence_labels(concurrence, labels) + self.x_values = self.get_x_values(x_values) + + def get_concurrence_labels(concurrence, labels): + if labels is None: + if concurrence and concurrence.labels is not None: + labels = concurrence.labels + else: + labels = [str(i) for i in range(len(values))] + return labels + + def get_x_values(x_values): + if x_values is None: + x_values = range(len(concurrence.values[0])) + return x_values + + def plot(concurrence=None): + if not HAS_MATPLOTLIB: + raise ImportError("matplotlib not installed") + if concurrence is None: + concurrence = self.concurrence + labels = self.get_concurrence_labels(concurrence=concurrence) + x_values = self.get_x_values() + + y_val = -1.0 + for label, val_set in zip(labels, concurrence.values): + x_vals = [x for (x, y) in zip(x_values, val_set) if y] + plt.plot(x_vals, [y_val] * len(x_vals), '.', markersize=1, + label=label) + y_val -= 1.0 + + plt.ylim(ymax=0.0) + plt.xlim(xmin=min(x_values), xmax=max(x_values)) + plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) + + def plot_concurrence(concurrence, labels=None, x_values=None): - import matplotlib.pyplot as plt - if x_values is None: - x_values = range(len(concurrence.values[0])) - if labels is None: - if concurrence.labels is not None: - labels = concurrence.labels - else: - labels = [str(i) for i in range(len(values))] - - y_val = -1.0 - for label, val_set in zip(labels, concurrence.values): - x_vals = [x for (x, y) in zip(x_values, val_set) if y] - plt.plot(x_vals, [y_val] * len(x_vals), '.', markersize=1, label=label) - y_val -= 1.0 - - plt.ylim(ymax=0.0) - plt.xlim(xmin=min(x_values), xmax=max(x_values)) - plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) + """ + Convenience function for concurrence plots. + """ + ConcurrencePlotter(concurrence, labels, x_values).plot() From 5d9a440fa915e3b0cbb77dd9269330f32a0f8fa8 Mon Sep 17 00:00:00 2001 From: "David W.H. Swenson" Date: Mon, 6 Nov 2017 13:36:44 +0100 Subject: [PATCH 05/22] Python 3 import fix --- contact_map/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/contact_map/__init__.py b/contact_map/__init__.py index 1b61899..a5ce8f9 100644 --- a/contact_map/__init__.py +++ b/contact_map/__init__.py @@ -11,4 +11,4 @@ from .min_dist import NearestAtoms, MinimumDistanceCounter -import concurrence +import .concurrence From 1ab3d5169ed1c273820478c067b8f9b999db4baf Mon Sep 17 00:00:00 2001 From: "David W.H. Swenson" Date: Mon, 6 Nov 2017 13:38:03 +0100 Subject: [PATCH 06/22] fix stupid error --- contact_map/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/contact_map/__init__.py b/contact_map/__init__.py index a5ce8f9..a4292ab 100644 --- a/contact_map/__init__.py +++ b/contact_map/__init__.py @@ -11,4 +11,4 @@ from .min_dist import NearestAtoms, MinimumDistanceCounter -import .concurrence +from . import concurrence From bac1610f166620963166094cd2dd1681d0250b47 Mon Sep 17 00:00:00 2001 From: "David W.H. Swenson" Date: Fri, 24 Nov 2017 17:02:00 +0100 Subject: [PATCH 07/22] Fixes to concurrence plotting --- contact_map/concurrence.py | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/contact_map/concurrence.py b/contact_map/concurrence.py index 1e74f05..7165df9 100644 --- a/contact_map/concurrence.py +++ b/contact_map/concurrence.py @@ -77,9 +77,10 @@ class ConcurrencePlotter(object): def __init__(self, concurrence=None, labels=None, x_values=None): self.concurrence = concurrence self.labels = self.get_concurrence_labels(concurrence, labels) - self.x_values = self.get_x_values(x_values) + self.x_values = x_values - def get_concurrence_labels(concurrence, labels): + @staticmethod + def get_concurrence_labels(concurrence, labels=None): if labels is None: if concurrence and concurrence.labels is not None: labels = concurrence.labels @@ -87,33 +88,43 @@ def get_concurrence_labels(concurrence, labels): labels = [str(i) for i in range(len(values))] return labels - def get_x_values(x_values): + @property + def x_values(self): + x_values = self._x_values if x_values is None: - x_values = range(len(concurrence.values[0])) + x_values = range(len(self.concurrence.values[0])) return x_values - def plot(concurrence=None): + @x_values.setter + def x_values(self, x_values): + self._x_values = x_values + + def plot(self, concurrence=None): if not HAS_MATPLOTLIB: raise ImportError("matplotlib not installed") if concurrence is None: concurrence = self.concurrence labels = self.get_concurrence_labels(concurrence=concurrence) - x_values = self.get_x_values() + x_values = self.x_values + + fig = plt.figure(1) + ax = fig.add_subplot(111) y_val = -1.0 for label, val_set in zip(labels, concurrence.values): x_vals = [x for (x, y) in zip(x_values, val_set) if y] - plt.plot(x_vals, [y_val] * len(x_vals), '.', markersize=1, - label=label) + ax.plot(x_vals, [y_val] * len(x_vals), '.', markersize=1, + label=label) y_val -= 1.0 - plt.ylim(ymax=0.0) - plt.xlim(xmin=min(x_values), xmax=max(x_values)) - plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) + ax.set_ylim(top=0.0) + ax.set_xlim(left=min(x_values), right=max(x_values)) + lgd = ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) + return (fig, ax, lgd) def plot_concurrence(concurrence, labels=None, x_values=None): """ Convenience function for concurrence plots. """ - ConcurrencePlotter(concurrence, labels, x_values).plot() + return ConcurrencePlotter(concurrence, labels, x_values).plot() From 1fef87af897666d67068bb9abbcdb98068772629 Mon Sep 17 00:00:00 2001 From: "David W.H. Swenson" Date: Tue, 1 May 2018 15:57:39 +0200 Subject: [PATCH 08/22] fix outdated default selection --- contact_map/concurrence.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/contact_map/concurrence.py b/contact_map/concurrence.py index 7165df9..111ff43 100644 --- a/contact_map/concurrence.py +++ b/contact_map/concurrence.py @@ -52,7 +52,7 @@ def __init__(self, trajectory, atom_contacts, cutoff=0.45): class ResidueContactConcurrence(Concurrence): def __init__(self, trajectory, residue_contacts, cutoff=0.45, - select="and not symbol == 'H'"): + select="and symbol != 'H'"): residue_pairs = [[contact[0][0], contact[0][1]] for contact in residue_contacts] labels = [str(contact[0]) for contact in residue_contacts] From 7275242b3b4313caa3bc36ae8f7fa363dc09e5bb Mon Sep 17 00:00:00 2001 From: "David W.H. Swenson" Date: Tue, 29 May 2018 15:13:50 +0200 Subject: [PATCH 09/22] Correct typo in residue contact concurrence --- contact_map/concurrence.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/contact_map/concurrence.py b/contact_map/concurrence.py index 111ff43..cf143a1 100644 --- a/contact_map/concurrence.py +++ b/contact_map/concurrence.py @@ -62,7 +62,7 @@ def __init__(self, trajectory, residue_contacts, cutoff=0.45, ) for res_A, res_B in residue_pairs: atoms_A = select_residue(res_A.index) - atoms_A = select_residue(res_B.index) + atoms_B = select_residue(res_B.index) atom_pairs = itertools.product(atoms_A, atoms_B) distances = md.compute_distances(trajectory, atom_pairs=atom_pairs) From e8747473ff4fa50cdd0e2f6d5407af37065e2904 Mon Sep 17 00:00:00 2001 From: "David W.H. Swenson" Date: Wed, 4 Jul 2018 13:00:43 +0200 Subject: [PATCH 10/22] Start tests for concurrence --- contact_map/concurrence.py | 30 ++++----- contact_map/tests/concurrence.pdb | 67 ++++++++++++++++++++ contact_map/tests/test_concurrence.py | 91 +++++++++++++++++++++++++++ 3 files changed, 174 insertions(+), 14 deletions(-) create mode 100644 contact_map/tests/concurrence.pdb create mode 100644 contact_map/tests/test_concurrence.py diff --git a/contact_map/concurrence.py b/contact_map/concurrence.py index cf143a1..cf23722 100644 --- a/contact_map/concurrence.py +++ b/contact_map/concurrence.py @@ -15,9 +15,9 @@ def __init__(self, values, labels=None): self.values = values self.labels = labels - @property - def lifetimes(self): - pass + # @property + # def lifetimes(self): + # pass def set_labels(self, labels): self.labels = labels @@ -26,17 +26,19 @@ def __getitem__(self, label): idx = self.labels.index(label) return self.values[idx] - def coincidence(self, label_list): - this_list = np.asarray(self[label_list[0]]) - coincidence_list = this_list - norm_sq = sum(this_list) - for label in label_list[1:]: - this_list = np.asarray(self[label]) - coincidence_list &= this_list - norm_sq *= sum(this_list) - - return sum(coincidence_list) / np.sqrt(norm_sq) + # temporarily removed until we find a good metric here; this metric did + # not seem optimzal and I stopped using it, so remove from code before + # release (can add back in later) + # def coincidence(self, label_list): + # this_list = np.asarray(self[label_list[0]]) + # coincidence_list = this_list + # norm_sq = sum(this_list) + # for label in label_list[1:]: + # this_list = np.asarray(self[label]) + # coincidence_list &= this_list + # norm_sq *= sum(this_list) + # return sum(coincidence_list) / np.sqrt(norm_sq) class AtomContactConcurrence(Concurrence): @@ -100,7 +102,7 @@ def x_values(self, x_values): self._x_values = x_values def plot(self, concurrence=None): - if not HAS_MATPLOTLIB: + if not HAS_MATPLOTLIB: # pragma: no cover raise ImportError("matplotlib not installed") if concurrence is None: concurrence = self.concurrence diff --git a/contact_map/tests/concurrence.pdb b/contact_map/tests/concurrence.pdb new file mode 100644 index 0000000..b32f329 --- /dev/null +++ b/contact_map/tests/concurrence.pdb @@ -0,0 +1,67 @@ +REMARK 225 HAND-WRITTEN TEST TRAJECTORY +REMARK 225 COMMENTS HERE COUNT FROM 1 (CODE COUNTS FROM 0, PDB FROM 1) +REMARK 225 SYSTEM CONSISTS OF TWO CHAINS: A TWO-RESIDUE "PROTEIN" (CHAIN A) +REMARK 225 AND A "LIGAND" (CHAIN B). +REMARK 225 IN FRAMES 1, 2 AND 5, THE PROTEIN IS IN THE SAME CONFIGURATION +REMARK 225 * FRAME 1 +REMARK 225 * FRAME 2 +REMARK 225 * FRAME 3 +REMARK 225 * FRAME 4 +REMARK 225 * FRAME 5 +REMARK 225 +CRYST1 25.000 25.000 25.000 90.00 90.00 90.00 P 1 1 +HETATM 1 C1 AAA A 1 0.250 2.250 0.000 1.00 0.00 C +HETATM 2 C2 AAA A 1 0.250 1.750 0.000 1.00 0.00 C +HETATM 3 H AAA A 1 0.250 2.750 0.000 1.00 0.00 H +HETATM 4 C1 BBB A 2 0.250 0.750 0.000 1.00 0.00 C +HETATM 5 C2 BBB A 2 0.250 1.250 0.000 1.00 0.00 C +HETATM 6 H BBB A 2 0.250 0.250 0.000 1.00 0.00 H +HETATM 7 C1 LLL B 3 0.750 1.750 0.000 1.00 0.00 C +HETATM 8 C2 LLL B 3 0.750 1.250 0.000 1.00 0.00 C +HETATM 9 H LLL B 3 1.250 1.750 0.000 1.00 0.00 H +ENDMDL +CRYST1 25.000 25.000 25.000 90.00 90.00 90.00 P 1 1 +HETATM 1 C1 AAA A 1 0.250 2.250 0.000 1.00 0.00 C +HETATM 2 C2 AAA A 1 0.250 1.750 0.000 1.00 0.00 C +HETATM 3 H AAA A 1 0.250 2.750 0.000 1.00 0.00 H +HETATM 4 C1 BBB A 2 0.250 0.750 0.000 1.00 0.00 C +HETATM 5 C2 BBB A 2 0.250 1.250 0.000 1.00 0.00 C +HETATM 6 H BBB A 2 0.250 0.250 0.000 1.00 0.00 H +HETATM 7 C1 LLL B 3 0.750 2.250 0.000 1.00 0.00 C +HETATM 8 C2 LLL B 3 0.750 1.750 0.000 1.00 0.00 C +HETATM 9 H LLL B 3 1.250 2.250 0.000 1.00 0.00 H +ENDMDL +CRYST1 25.000 25.000 25.000 90.00 90.00 90.00 P 1 1 +HETATM 1 C1 AAA A 1 0.250 2.250 0.000 1.00 0.00 C +HETATM 2 C2 AAA A 1 0.250 1.750 0.000 1.00 0.00 C +HETATM 3 H AAA A 1 0.750 2.250 0.000 1.00 0.00 H +HETATM 4 C1 BBB A 2 0.250 0.750 0.000 1.00 0.00 C +HETATM 5 C2 BBB A 2 0.250 1.250 0.000 1.00 0.00 C +HETATM 6 H BBB A 2 0.250 0.250 0.000 1.00 0.00 H +HETATM 7 C1 LLL B 3 1.250 2.250 0.000 1.00 0.00 C +HETATM 8 C2 LLL B 3 1.750 1.750 0.000 1.00 0.00 C +HETATM 9 H LLL B 3 1.750 2.250 0.000 1.00 0.00 H +ENDMDL +CRYST1 25.000 25.000 25.000 90.00 90.00 90.00 P 1 1 +HETATM 1 C1 AAA A 1 0.250 2.250 0.000 1.00 0.00 C +HETATM 2 C2 AAA A 1 0.250 1.750 0.000 1.00 0.00 C +HETATM 3 H AAA A 1 0.750 2.250 0.000 1.00 0.00 H +HETATM 4 C1 BBB A 2 1.250 1.250 0.000 1.00 0.00 C +HETATM 5 C2 BBB A 2 0.750 1.250 0.000 1.00 0.00 C +HETATM 6 H BBB A 2 1.750 1.250 0.000 1.00 0.00 H +HETATM 7 C1 LLL B 3 1.250 2.250 0.000 1.00 0.00 C +HETATM 8 C2 LLL B 3 1.750 1.750 0.000 1.00 0.00 C +HETATM 9 H LLL B 3 1.750 2.250 0.000 1.00 0.00 H +ENDMDL +CRYST1 25.000 25.000 25.000 90.00 90.00 90.00 P 1 1 +HETATM 1 C1 AAA A 1 0.250 2.250 0.000 1.00 0.00 C +HETATM 2 C2 AAA A 1 0.250 1.750 0.000 1.00 0.00 C +HETATM 3 H AAA A 1 0.250 2.750 0.000 1.00 0.00 H +HETATM 4 C1 BBB A 2 0.250 0.750 0.000 1.00 0.00 C +HETATM 5 C2 BBB A 2 0.250 1.250 0.000 1.00 0.00 C +HETATM 6 H BBB A 2 0.250 0.250 0.000 1.00 0.00 H +HETATM 7 C1 LLL B 3 1.250 2.250 0.000 1.00 0.00 C +HETATM 8 C2 LLL B 3 1.250 1.750 0.000 1.00 0.00 C +HETATM 9 H LLL B 3 1.250 2.750 0.000 1.00 0.00 H +ENDMDL +END diff --git a/contact_map/tests/test_concurrence.py b/contact_map/tests/test_concurrence.py new file mode 100644 index 0000000..a3bcefa --- /dev/null +++ b/contact_map/tests/test_concurrence.py @@ -0,0 +1,91 @@ +from .utils import * + +from contact_map.concurrence import * +from contact_map import ContactFrequency + +def setup_module(): + global traj, contacts + traj = md.load(find_testfile("concurrence.pdb")) + query = traj.topology.select("resSeq 3") + haystack = traj.topology.select("resSeq 1 to 2") + # note that this includes *all* atoms + contacts = ContactFrequency(traj, query, haystack, cutoff=0.051, + n_neighbors_ignored=0) + +class TestAtomContactConcurrence(object): + def setup(self): + pass + + def test_default_labels(self): + pass + + def test_get_items(self): + pass + + def test_values(self): + pass + +class TestResidueContactConcurrence(object): + def setup(self): + self.heavy_contact_concurrence = ResidueContactConcurrence( + trajectory=traj, + residue_contacts=contacts.residue_contacts.most_common(), + cutoff=0.051 + ) + + self.all_contact_concurrence = ResidueContactConcurrence( + trajectory=traj, + residue_contacts=contacts.residue_contacts.most_common(), + cutoff=0.051, + select="" + ) + + @pytest.mark.parametrize('conc_type', ('heavy', 'all')) + def test_default_labels(self, conc_type): + concurrence = {'heavy': self.heavy_contact_concurrence, + 'all': self.all_contact_concurrence}[conc_type] + residue_labels = ['[AAA1, LLL3]', '[BBB2, LLL3]', + '[LLL3, AAA1]', '[LLL3, BBB2]'] + + assert len(concurrence.labels) == 2 + for label in concurrence.labels: + assert label in residue_labels + + def test_get_items(self): + pass + + @pytest.mark.parametrize('conc_type', ('heavy', 'all')) + def test_getitem(self, conc_type): + concurrence = {'heavy': self.heavy_contact_concurrence, + 'all': self.all_contact_concurrence}[conc_type] + label_to_pair = {'[AAA1, LLL3]': 'AL', + '[LLL3, AAA1]': 'AL', + '[BBB2, LLL3]': 'BL', + '[LLL3, BBB2]': 'BL'} + pair_to_expected = { + 'heavy': { + 'AL': [True, True, False, False, False], + 'BL': [True, False, False, False, False] + }, + 'all': { + 'AL': [True, True, True, True, False], + 'BL': [True, False, False, True, False] + } + }[conc_type] + for label in concurrence.labels: + values = concurrence[label] + pair = label_to_pair[label] + expected_values = pair_to_expected[pair] + assert values == expected_values + + +class TestConcurrencePlotter(object): + def setup(self): + pass + + def test_x_values(self): + pass + + def test_plot(self): + # SMOKE TEST + pass From fa4783809437d20504946632e62977bd7d95ea71 Mon Sep 17 00:00:00 2001 From: "David W.H. Swenson" Date: Wed, 4 Jul 2018 16:32:24 +0200 Subject: [PATCH 11/22] Tests on concurrences --- contact_map/concurrence.py | 4 +- contact_map/tests/concurrence.pdb | 8 +- contact_map/tests/test_concurrence.py | 123 ++++++++++++++++++-------- 3 files changed, 90 insertions(+), 45 deletions(-) diff --git a/contact_map/concurrence.py b/contact_map/concurrence.py index cf23722..1dc5490 100644 --- a/contact_map/concurrence.py +++ b/contact_map/concurrence.py @@ -27,7 +27,7 @@ def __getitem__(self, label): return self.values[idx] # temporarily removed until we find a good metric here; this metric did - # not seem optimzal and I stopped using it, so remove from code before + # not seem optimal and I stopped using it, so remove from code before # release (can add back in later) # def coincidence(self, label_list): # this_list = np.asarray(self[label_list[0]]) @@ -48,7 +48,7 @@ def __init__(self, trajectory, atom_contacts, cutoff=0.45): labels = [str(contact[0]) for contact in atom_contacts] distances = md.compute_distances(trajectory, atom_pairs=atom_pairs) vector_f = np.vectorize(lambda d: d < cutoff) - values = zip(*vector_f(distances)) + values = list(map(list, zip(*vector_f(distances)))) super(AtomContactConcurrence, self).__init__(values=values, labels=labels) diff --git a/contact_map/tests/concurrence.pdb b/contact_map/tests/concurrence.pdb index b32f329..e13a682 100644 --- a/contact_map/tests/concurrence.pdb +++ b/contact_map/tests/concurrence.pdb @@ -38,9 +38,9 @@ HETATM 3 H AAA A 1 0.750 2.250 0.000 1.00 0.00 H HETATM 4 C1 BBB A 2 0.250 0.750 0.000 1.00 0.00 C HETATM 5 C2 BBB A 2 0.250 1.250 0.000 1.00 0.00 C HETATM 6 H BBB A 2 0.250 0.250 0.000 1.00 0.00 H -HETATM 7 C1 LLL B 3 1.250 2.250 0.000 1.00 0.00 C +HETATM 7 C1 LLL B 3 1.750 2.250 0.000 1.00 0.00 C HETATM 8 C2 LLL B 3 1.750 1.750 0.000 1.00 0.00 C -HETATM 9 H LLL B 3 1.750 2.250 0.000 1.00 0.00 H +HETATM 9 H LLL B 3 1.250 2.250 0.000 1.00 0.00 H ENDMDL CRYST1 25.000 25.000 25.000 90.00 90.00 90.00 P 1 1 HETATM 1 C1 AAA A 1 0.250 2.250 0.000 1.00 0.00 C @@ -49,9 +49,9 @@ HETATM 3 H AAA A 1 0.750 2.250 0.000 1.00 0.00 H HETATM 4 C1 BBB A 2 1.250 1.250 0.000 1.00 0.00 C HETATM 5 C2 BBB A 2 0.750 1.250 0.000 1.00 0.00 C HETATM 6 H BBB A 2 1.750 1.250 0.000 1.00 0.00 H -HETATM 7 C1 LLL B 3 1.250 2.250 0.000 1.00 0.00 C +HETATM 7 C1 LLL B 3 1.750 2.250 0.000 1.00 0.00 C HETATM 8 C2 LLL B 3 1.750 1.750 0.000 1.00 0.00 C -HETATM 9 H LLL B 3 1.750 2.250 0.000 1.00 0.00 H +HETATM 9 H LLL B 3 1.250 2.250 0.000 1.00 0.00 H ENDMDL CRYST1 25.000 25.000 25.000 90.00 90.00 90.00 P 1 1 HETATM 1 C1 AAA A 1 0.250 2.250 0.000 1.00 0.00 C diff --git a/contact_map/tests/test_concurrence.py b/contact_map/tests/test_concurrence.py index a3bcefa..41a4f84 100644 --- a/contact_map/tests/test_concurrence.py +++ b/contact_map/tests/test_concurrence.py @@ -12,57 +12,85 @@ def setup_module(): contacts = ContactFrequency(traj, query, haystack, cutoff=0.051, n_neighbors_ignored=0) -class TestAtomContactConcurrence(object): +class ContactConcurrenceTester(object): + def _test_default_labels(self, concurrence): + assert len(concurrence.labels) == len(self.labels) / 2 + for label in concurrence.labels: + assert label in self.labels + + def _test_set_labels(self, concurrence, expected): + new_labels = [self.label_to_pair[label] + for label in concurrence.labels] + concurrence.set_labels(new_labels) + for label in new_labels: + assert concurrence[label] == expected[label] + + def _test_getitem(self, concurrence, pair_to_expected): + for label in concurrence.labels: + values = concurrence[label] + pair = self.label_to_pair[label] + expected_values = pair_to_expected[pair] + assert values == expected_values + +class TestAtomContactConcurrence(ContactConcurrenceTester): def setup(self): - pass + self.concurrence = AtomContactConcurrence( + trajectory=traj, + atom_contacts=contacts.atom_contacts.most_common(), + cutoff=0.051 + ) + # dupes each direction until we have better way to handle frozensets + self.label_to_pair = {'[AAA1-H, LLL3-H]': 'AH-LH', + '[LLL3-H, AAA1-H]': 'AH-LH', + '[AAA1-C1, LLL3-C1]': 'AC1-LC1', + '[LLL3-C1, AAA1-C1]': 'AC1-LC1', + '[BBB2-H, LLL3-C2]': 'BH-LC2', + '[LLL3-C2, BBB2-H]': 'BH-LC2', + '[AAA1-C2, LLL3-C2]': 'AC2-LC2', + '[LLL3-C2, AAA1-C2]': 'AC2-LC2', + '[AAA1-C2, LLL3-C1]': 'AC2-LC1', + '[LLL3-C1, AAA1-C2]': 'AC2-LC1', + '[BBB2-C2, LLL3-C2]': 'BC2-LC2', + '[LLL3-C2, BBB2-C2]': 'BC2-LC2'} + self.labels = list(self.label_to_pair.keys()) + self.pair_to_expected = { + 'AH-LH': [False, False, True, True, False], + 'AC1-LC1': [False, True, False, False, False], + 'BH-LC2': [False, False, False, True, False], + 'AC2-LC2': [False, True, False, False, False], + 'AC2-LC1': [True, False, False, False, False], + 'BC2-LC2': [True, False, False, False, False] + } def test_default_labels(self): - pass + self._test_default_labels(self.concurrence) - def test_get_items(self): - pass + def test_getitem(self): + self._test_getitem(self.concurrence, self.pair_to_expected) - def test_values(self): - pass + def test_set_labels(self): + self._test_set_labels(self.concurrence, self.pair_to_expected) -class TestResidueContactConcurrence(object): + +class TestResidueContactConcurrence(ContactConcurrenceTester): def setup(self): self.heavy_contact_concurrence = ResidueContactConcurrence( trajectory=traj, residue_contacts=contacts.residue_contacts.most_common(), cutoff=0.051 ) - self.all_contact_concurrence = ResidueContactConcurrence( trajectory=traj, residue_contacts=contacts.residue_contacts.most_common(), cutoff=0.051, select="" ) - - @pytest.mark.parametrize('conc_type', ('heavy', 'all')) - def test_default_labels(self, conc_type): - concurrence = {'heavy': self.heavy_contact_concurrence, - 'all': self.all_contact_concurrence}[conc_type] - residue_labels = ['[AAA1, LLL3]', '[BBB2, LLL3]', - '[LLL3, AAA1]', '[LLL3, BBB2]'] - - assert len(concurrence.labels) == 2 - for label in concurrence.labels: - assert label in residue_labels - - def test_get_items(self): - pass - - @pytest.mark.parametrize('conc_type', ('heavy', 'all')) - def test_getitem(self, conc_type): - concurrence = {'heavy': self.heavy_contact_concurrence, - 'all': self.all_contact_concurrence}[conc_type] - label_to_pair = {'[AAA1, LLL3]': 'AL', - '[LLL3, AAA1]': 'AL', - '[BBB2, LLL3]': 'BL', - '[LLL3, BBB2]': 'BL'} - pair_to_expected = { + self.label_to_pair = {'[AAA1, LLL3]': 'AL', + '[LLL3, AAA1]': 'AL', + '[BBB2, LLL3]': 'BL', + '[LLL3, BBB2]': 'BL'} + self.labels = list(self.label_to_pair.keys()) + self.pair_to_expected = { 'heavy': { 'AL': [True, True, False, False, False], 'BL': [True, False, False, False, False] @@ -71,12 +99,29 @@ def test_getitem(self, conc_type): 'AL': [True, True, True, True, False], 'BL': [True, False, False, True, False] } - }[conc_type] - for label in concurrence.labels: - values = concurrence[label] - pair = label_to_pair[label] - expected_values = pair_to_expected[pair] - assert values == expected_values + } + + def _get_concurrence(self, conc_type): + return {'heavy': self.heavy_contact_concurrence, + 'all': self.all_contact_concurrence}[conc_type] + + + @pytest.mark.parametrize('conc_type', ('heavy', 'all')) + def test_default_labels(self, conc_type): + concurrence = self._get_concurrence(conc_type) + self._test_default_labels(concurrence) + + def test_set_labels(self): + # only run this for the heavy atom concurrence + concurrence = self.heavy_contact_concurrence + expected = self.pair_to_expected['heavy'] + self._test_set_labels(concurrence, expected) + + @pytest.mark.parametrize('conc_type', ('heavy', 'all')) + def test_getitem(self, conc_type): + concurrence = self._get_concurrence(conc_type) + pair_to_expected = self.pair_to_expected[conc_type] + self._test_getitem(concurrence, pair_to_expected) class TestConcurrencePlotter(object): From 30ecd46a27530db840de6d52963761873dc390e7 Mon Sep 17 00:00:00 2001 From: "David W.H. Swenson" Date: Wed, 4 Jul 2018 16:51:30 +0200 Subject: [PATCH 12/22] tests for concurrence plotting --- .coveragerc | 1 + contact_map/concurrence.py | 6 ++--- contact_map/tests/test_concurrence.py | 32 +++++++++++++++++++++++---- 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/.coveragerc b/.coveragerc index 279183f..58f7a9a 100644 --- a/.coveragerc +++ b/.coveragerc @@ -8,5 +8,6 @@ omit = */_version.py exclude_lines = pragma: no cover + -no-cov- def __repr__ raise NotImplementedError diff --git a/contact_map/concurrence.py b/contact_map/concurrence.py index 1dc5490..804897a 100644 --- a/contact_map/concurrence.py +++ b/contact_map/concurrence.py @@ -87,14 +87,14 @@ def get_concurrence_labels(concurrence, labels=None): if concurrence and concurrence.labels is not None: labels = concurrence.labels else: - labels = [str(i) for i in range(len(values))] + labels = [str(i) for i in range(len(concurrence.values))] return labels @property def x_values(self): x_values = self._x_values if x_values is None: - x_values = range(len(self.concurrence.values[0])) + x_values = list(range(len(self.concurrence.values[0]))) return x_values @x_values.setter @@ -125,7 +125,7 @@ def plot(self, concurrence=None): return (fig, ax, lgd) -def plot_concurrence(concurrence, labels=None, x_values=None): +def plot_concurrence(concurrence, labels=None, x_values=None): # -no-cov- """ Convenience function for concurrence plots. """ diff --git a/contact_map/tests/test_concurrence.py b/contact_map/tests/test_concurrence.py index 41a4f84..d3e562f 100644 --- a/contact_map/tests/test_concurrence.py +++ b/contact_map/tests/test_concurrence.py @@ -126,11 +126,35 @@ def test_getitem(self, conc_type): class TestConcurrencePlotter(object): def setup(self): - pass + self.concurrence = ResidueContactConcurrence( + trajectory=traj, + residue_contacts=contacts.residue_contacts.most_common(), + cutoff=0.051 + ) + self.plotter = ConcurrencePlotter(self.concurrence) def test_x_values(self): - pass + time_values = [0.3, 0.4, 0.5, 0.6, 0.7] + assert self.plotter.x_values == [0, 1, 2, 3, 4] + self.plotter.x_values = time_values + assert self.plotter.x_values == time_values + + def test_get_concurrence_labels_given(self): + alpha_labels = ['a', 'b'] + labels = self.plotter.get_concurrence_labels(self.concurrence, + labels=alpha_labels) + assert labels == alpha_labels + + def test_get_concurrence_labels_default(self): + labels = self.plotter.get_concurrence_labels(self.concurrence) + assert labels == self.concurrence.labels + + def test_get_concurrence_label_none_in_concurrence(self): + numeric_labels = ['0', '1'] + self.concurrence.labels = None + labels = self.plotter.get_concurrence_labels(self.concurrence) + assert labels == numeric_labels def test_plot(self): - # SMOKE TEST - pass + # SMOKE TEST ONLY + self.plotter.plot() From f31130e6a4fac9d31f98a280022e0f91276a847f Mon Sep 17 00:00:00 2001 From: "David W.H. Swenson" Date: Wed, 4 Jul 2018 17:06:09 +0200 Subject: [PATCH 13/22] minor test fixes --- contact_map/__init__.py | 2 +- contact_map/concurrence.py | 2 +- contact_map/tests/test_concurrence.py | 16 +++++++++++++--- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/contact_map/__init__.py b/contact_map/__init__.py index 350489a..d48eec9 100644 --- a/contact_map/__init__.py +++ b/contact_map/__init__.py @@ -17,4 +17,4 @@ from .dask_runner import DaskContactFrequency -from . import plot_utils \ No newline at end of file +from . import plot_utils diff --git a/contact_map/concurrence.py b/contact_map/concurrence.py index 804897a..7e34a90 100644 --- a/contact_map/concurrence.py +++ b/contact_map/concurrence.py @@ -69,7 +69,7 @@ def __init__(self, trajectory, residue_contacts, cutoff=0.45, distances = md.compute_distances(trajectory, atom_pairs=atom_pairs) min_dists = [min(dists) for dists in distances] - values.append(map(lambda d: d < cutoff, min_dists)) + values.append(list(map(lambda d: d < cutoff, min_dists))) super(ResidueContactConcurrence, self).__init__(values=values, labels=labels) diff --git a/contact_map/tests/test_concurrence.py b/contact_map/tests/test_concurrence.py index d3e562f..1be7538 100644 --- a/contact_map/tests/test_concurrence.py +++ b/contact_map/tests/test_concurrence.py @@ -1,8 +1,15 @@ +# pylint: disable=wildcard-import, missing-docstring, protected-access +# pylint: disable=attribute-defined-outside-init, invalid-name, no-self-use +# pylint: disable=wrong-import-order, unused-wildcard-import + from .utils import * from contact_map.concurrence import * from contact_map import ContactFrequency +# pylint: disable=wildcard-import, missing-docstring, protected-access +# pylint: disable=attribute-defined-outside-init, invalid-name, no-self-use +# pylint: disable=wrong-import-order, unused-wildcard-import def setup_module(): global traj, contacts traj = md.load(find_testfile("concurrence.pdb")) @@ -12,6 +19,7 @@ def setup_module(): contacts = ContactFrequency(traj, query, haystack, cutoff=0.051, n_neighbors_ignored=0) + class ContactConcurrenceTester(object): def _test_default_labels(self, concurrence): assert len(concurrence.labels) == len(self.labels) / 2 @@ -19,7 +27,7 @@ def _test_default_labels(self, concurrence): assert label in self.labels def _test_set_labels(self, concurrence, expected): - new_labels = [self.label_to_pair[label] + new_labels = [self.label_to_pair[label] for label in concurrence.labels] concurrence.set_labels(new_labels) for label in new_labels: @@ -32,6 +40,7 @@ def _test_getitem(self, concurrence, pair_to_expected): expected_values = pair_to_expected[pair] assert values == expected_values + class TestAtomContactConcurrence(ContactConcurrenceTester): def setup(self): self.concurrence = AtomContactConcurrence( @@ -40,7 +49,7 @@ def setup(self): cutoff=0.051 ) # dupes each direction until we have better way to handle frozensets - self.label_to_pair = {'[AAA1-H, LLL3-H]': 'AH-LH', + self.label_to_pair = {'[AAA1-H, LLL3-H]': 'AH-LH', '[LLL3-H, AAA1-H]': 'AH-LH', '[AAA1-C1, LLL3-C1]': 'AC1-LC1', '[LLL3-C1, AAA1-C1]': 'AC1-LC1', @@ -144,7 +153,7 @@ def test_get_concurrence_labels_given(self): labels = self.plotter.get_concurrence_labels(self.concurrence, labels=alpha_labels) assert labels == alpha_labels - + def test_get_concurrence_labels_default(self): labels = self.plotter.get_concurrence_labels(self.concurrence) assert labels == self.concurrence.labels @@ -157,4 +166,5 @@ def test_get_concurrence_label_none_in_concurrence(self): def test_plot(self): # SMOKE TEST ONLY + pytest.importorskip('matplotlib.pyplot') self.plotter.plot() From 9c9987b2120fee78ceffea180adeb4692e433754 Mon Sep 17 00:00:00 2001 From: "David W.H. Swenson" Date: Wed, 4 Jul 2018 17:55:52 +0200 Subject: [PATCH 14/22] docstrings; add in __init__ --- contact_map/__init__.py | 5 +- contact_map/concurrence.py | 114 ++++++++++++++++++++++++++++++++++++- docs/api.rst | 12 ++++ 3 files changed, 129 insertions(+), 2 deletions(-) diff --git a/contact_map/__init__.py b/contact_map/__init__.py index d48eec9..b46f5b6 100644 --- a/contact_map/__init__.py +++ b/contact_map/__init__.py @@ -13,7 +13,10 @@ from .min_dist import NearestAtoms, MinimumDistanceCounter -from . import concurrence +from .concurrence import ( + Concurrence, AtomContactConcurrence, ResidueContactConcurrence, + ConcurrencePlotter, plot_concurrence +) from .dask_runner import DaskContactFrequency diff --git a/contact_map/concurrence.py b/contact_map/concurrence.py index 7e34a90..fa08f30 100644 --- a/contact_map/concurrence.py +++ b/contact_map/concurrence.py @@ -11,6 +11,22 @@ class Concurrence(object): + """Superclass for contact concurrence objects. + + Contact concurrences measure what contacts occur simultaneously in a + trajectory. When defining states, one usually wants to characterize + based on multiple contacts that are made simultaneously; contact + concurrences makes it easier to identify those. + + Parameters + ---------- + values : list of list of bool + the whether a contact is present for each contact pair at each + point in time; outer list is length number of frames, inner list + in length number of (included) contacts + labels : list of string + labels for each contact pair + """ def __init__(self, values, labels=None): self.values = values self.labels = labels @@ -20,6 +36,13 @@ def __init__(self, values, labels=None): # pass def set_labels(self, labels): + """Set the contact labels + + Parameters + ---------- + labels : list of string + labels for each contact pair + """ self.labels = labels def __getitem__(self, label): @@ -42,9 +65,20 @@ def __getitem__(self, label): class AtomContactConcurrence(Concurrence): + """Contact concurrences for atom contacts. + + Parameters + ---------- + trajectory : :class:`mdtraj.Trajectory` + the trajectory to analyze + atom_contacts : list + output from ``contact_map.atom_contacts.most_common()`` + cutoff : float + cutoff, in nm. Should be the same as used in the contact map. + """ def __init__(self, trajectory, atom_contacts, cutoff=0.45): atom_pairs = [[contact[0][0].index, contact[0][1].index] - for contact in atom_contacts] + for contact in atom_contacts] labels = [str(contact[0]) for contact in atom_contacts] distances = md.compute_distances(trajectory, atom_pairs=atom_pairs) vector_f = np.vectorize(lambda d: d < cutoff) @@ -52,7 +86,22 @@ def __init__(self, trajectory, atom_contacts, cutoff=0.45): super(AtomContactConcurrence, self).__init__(values=values, labels=labels) + class ResidueContactConcurrence(Concurrence): + """Contact concurrences for residue contacts. + + Parameters + ---------- + trajectory : :class:`mdtraj.Trajectory` + the trajectory to analyze + residue_contacts : list + output from ``contact_map.residue_contacts.most_common()`` + cutoff : float + cutoff, in nm. Should be the same as used in the contact map. + select : string + additional atom selection string for MDTraj; defaults to "and symbol + != 'H'" + """ def __init__(self, trajectory, residue_contacts, cutoff=0.45, select="and symbol != 'H'"): residue_pairs = [[contact[0][0], contact[0][1]] @@ -76,6 +125,20 @@ def __init__(self, trajectory, residue_contacts, cutoff=0.45, class ConcurrencePlotter(object): + """Plot manager for contact concurrences. + + Parameters + ---------- + concurrence : :class:`.Concurrence` + concurrence to plot; default None allows to override later + labels : list of string + labels for the contact pairs, default None will use concurrence + labels if available, integers if not + x_values : list of numeric + values to use for the time axis; default None uses integers starting + at 0 (can be used to assign the actual simulation time to the + x-axis) + """ def __init__(self, concurrence=None, labels=None, x_values=None): self.concurrence = concurrence self.labels = self.get_concurrence_labels(concurrence, labels) @@ -83,6 +146,26 @@ def __init__(self, concurrence=None, labels=None, x_values=None): @staticmethod def get_concurrence_labels(concurrence, labels=None): + """Extract labels for contact from a concurrence object + + If ``labels`` is given, that is returned. Otherwise, the + ``concurrence`` is checked for labels, and those are used. If those + are also not available, string forms of integers starting with 0 are + returned. + + + Parameters + ---------- + concurrence : :class:`.Concurrence` + concurrence, which may have label information + labels : list of string + labels to use for contacts (optional) + + Returns + ------- + list of string + labels to use for contacts + """ if labels is None: if concurrence and concurrence.labels is not None: labels = concurrence.labels @@ -92,6 +175,7 @@ def get_concurrence_labels(concurrence, labels=None): @property def x_values(self): + """list : values to use for the x-axis (time)""" x_values = self._x_values if x_values is None: x_values = list(range(len(self.concurrence.values[0]))) @@ -102,6 +186,21 @@ def x_values(self, x_values): self._x_values = x_values def plot(self, concurrence=None): + """Contact concurrence plot based on matplotlib + + Parameters + ---------- + concurrence : :class:`.Concurrence` + optional; default None uses ``self.concurrence``; this allows + one to override the use of ``self.concurrence`` + + Returns + ------- + fig : :class:`.matplotlib.Figure` + ax : :class:`.matplotlib.Axes` + lgd: :class:`.matplotlib.legend.Legend` + objects for matplotlib-based plot of contact concurrences + """ if not HAS_MATPLOTLIB: # pragma: no cover raise ImportError("matplotlib not installed") if concurrence is None: @@ -128,5 +227,18 @@ def plot(self, concurrence=None): def plot_concurrence(concurrence, labels=None, x_values=None): # -no-cov- """ Convenience function for concurrence plots. + + Parameters + ---------- + concurrence : :class:`.Concurrence` + concurrence to be plotted + labels: list of string + labels for contacts (optional) + x_values : list of float or list of int + values to use for the x-axis + + See also + -------- + :class:`.ConcurrencePlotter` """ return ConcurrencePlotter(concurrence, labels, x_values).plot() diff --git a/docs/api.rst b/docs/api.rst index 61376fb..57e8502 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -18,6 +18,18 @@ Contact maps ContactFrequency ContactDifference +Contact Concurrences +-------------------- +.. autosummary:: + :toctree: api/generated/ + + Concurrence + AtomContactConcurrence + ResidueContactConcurrence + ConcurrencePlotter + plot_concurrence + + Minimum Distance (and related) ------------------------------ From 5d8184d83e9402f47c9de647f4c315a1ec1ab8ef Mon Sep 17 00:00:00 2001 From: "David W.H. Swenson" Date: Wed, 4 Jul 2018 20:04:33 +0200 Subject: [PATCH 15/22] add a couple to-dos --- contact_map/concurrence.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/contact_map/concurrence.py b/contact_map/concurrence.py index fa08f30..1cc9a6c 100644 --- a/contact_map/concurrence.py +++ b/contact_map/concurrence.py @@ -77,6 +77,7 @@ class AtomContactConcurrence(Concurrence): cutoff, in nm. Should be the same as used in the contact map. """ def __init__(self, trajectory, atom_contacts, cutoff=0.45): + # TODO: the use of atom_contacts as input from most_common is weird atom_pairs = [[contact[0][0].index, contact[0][1].index] for contact in atom_contacts] labels = [str(contact[0]) for contact in atom_contacts] @@ -104,6 +105,7 @@ class ResidueContactConcurrence(Concurrence): """ def __init__(self, trajectory, residue_contacts, cutoff=0.45, select="and symbol != 'H'"): + # TODO: the use of residue_contacts as input from most_common is weird residue_pairs = [[contact[0][0], contact[0][1]] for contact in residue_contacts] labels = [str(contact[0]) for contact in residue_contacts] From b84283eed4137fa61e1cca1ac76711f1150c8175 Mon Sep 17 00:00:00 2001 From: "David W.H. Swenson" Date: Tue, 10 Jul 2018 20:52:58 +0200 Subject: [PATCH 16/22] Update after Sander's review --- contact_map/concurrence.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/contact_map/concurrence.py b/contact_map/concurrence.py index 1cc9a6c..c58d52f 100644 --- a/contact_map/concurrence.py +++ b/contact_map/concurrence.py @@ -22,7 +22,7 @@ class Concurrence(object): ---------- values : list of list of bool the whether a contact is present for each contact pair at each - point in time; outer list is length number of frames, inner list + point in time; inner list is length number of frames, outer list in length number of (included) contacts labels : list of string labels for each contact pair @@ -83,7 +83,10 @@ def __init__(self, trajectory, atom_contacts, cutoff=0.45): labels = [str(contact[0]) for contact in atom_contacts] distances = md.compute_distances(trajectory, atom_pairs=atom_pairs) vector_f = np.vectorize(lambda d: d < cutoff) - values = list(map(list, zip(*vector_f(distances)))) + # distances is (ndarray) shape (n_frames, n_contacts); + # values should be list shape # (n_contacts, n_frames) + value_iter = zip(*vector_f(distances)) # make bool; transpose + values = list(map(list, value_iter)) # convert to list of list super(AtomContactConcurrence, self).__init__(values=values, labels=labels) @@ -120,7 +123,7 @@ def __init__(self, trajectory, residue_contacts, cutoff=0.45, distances = md.compute_distances(trajectory, atom_pairs=atom_pairs) min_dists = [min(dists) for dists in distances] - values.append(list(map(lambda d: d < cutoff, min_dists))) + values.append([d < cutoff for d in min_dists]) super(ResidueContactConcurrence, self).__init__(values=values, labels=labels) From a41bbb1480220633eef5a97a9287e301d5cc5d33 Mon Sep 17 00:00:00 2001 From: "David W.H. Swenson" Date: Wed, 11 Jul 2018 08:49:34 +0200 Subject: [PATCH 17/22] support multiple input types with concurrences --- contact_map/concurrence.py | 46 +++++++++++++++++++++++++++++++++++--- 1 file changed, 43 insertions(+), 3 deletions(-) diff --git a/contact_map/concurrence.py b/contact_map/concurrence.py index c58d52f..ef12c26 100644 --- a/contact_map/concurrence.py +++ b/contact_map/concurrence.py @@ -2,6 +2,8 @@ import mdtraj as md import numpy as np +import contact_map + try: import matplotlib.pyplot as plt except ImportError: @@ -63,6 +65,41 @@ def __getitem__(self, label): # return sum(coincidence_list) / np.sqrt(norm_sq) +def _regularize_contact_input(contact_input, atom_or_res): + """Clean input for concurrence objects. + + The allowed inputs are the :class:`.ContactFrequency`, or the + :class:`.ContactObject` coming from the ``.atom_contacts`` or + ``.residue_contacts`` attribute of the contact frequency, or the list + coming from the ``.most_common()`` method for the + :class:`.ContactObject`. + + Parameters + ---------- + contact_input : many possible types; see method description + input to the contact concurrences + atom_or_res : string + whether to treat this as an atom-based or residue-based contact; + allowed values are "atom", "res", and "residue" + + Returns + ------- + list : + list in the format of ``ContactCount.most_common()`` + """ + if isinstance(contact_input, contact_map.ContactFrequency): + if atom_or_res == "atom": + contact_input = contact_input.atom_contacts.most_common() + elif atom_or_res == "residue" or atom_or_res == "res": + contact_input = contact_input.residue_contacts.most_common() + else: + raise RuntimeError("Bad value for atom_or_res: " + + str(atom_or_res)) + elif isinstance(contact_input, contact_map.ContactCount): + contact_input = contact_input.most_common() + + return contact_input + class AtomContactConcurrence(Concurrence): """Contact concurrences for atom contacts. @@ -78,13 +115,14 @@ class AtomContactConcurrence(Concurrence): """ def __init__(self, trajectory, atom_contacts, cutoff=0.45): # TODO: the use of atom_contacts as input from most_common is weird + atom_contacts = _regularize_contact_input(atom_contacts, "atom") atom_pairs = [[contact[0][0].index, contact[0][1].index] for contact in atom_contacts] labels = [str(contact[0]) for contact in atom_contacts] distances = md.compute_distances(trajectory, atom_pairs=atom_pairs) vector_f = np.vectorize(lambda d: d < cutoff) - # distances is (ndarray) shape (n_frames, n_contacts); - # values should be list shape # (n_contacts, n_frames) + # distances is ndarray shape (n_frames, n_contacts); values should + # be list shape (n_contacts, n_frames) value_iter = zip(*vector_f(distances)) # make bool; transpose values = list(map(list, value_iter)) # convert to list of list super(AtomContactConcurrence, self).__init__(values=values, @@ -109,6 +147,8 @@ class ResidueContactConcurrence(Concurrence): def __init__(self, trajectory, residue_contacts, cutoff=0.45, select="and symbol != 'H'"): # TODO: the use of residue_contacts as input from most_common is weird + residue_contacts = _regularize_contact_input(residue_contacts, + "residue") residue_pairs = [[contact[0][0], contact[0][1]] for contact in residue_contacts] labels = [str(contact[0]) for contact in residue_contacts] @@ -147,7 +187,7 @@ class ConcurrencePlotter(object): def __init__(self, concurrence=None, labels=None, x_values=None): self.concurrence = concurrence self.labels = self.get_concurrence_labels(concurrence, labels) - self.x_values = x_values + self._x_values = x_values @staticmethod def get_concurrence_labels(concurrence, labels=None): From 1d59768b3e4eda45dcde592ad751adb716835e78 Mon Sep 17 00:00:00 2001 From: "David W.H. Swenson" Date: Wed, 11 Jul 2018 11:23:42 +0200 Subject: [PATCH 18/22] Tests for input types; clearer code --- contact_map/concurrence.py | 13 +++++-------- contact_map/tests/test_concurrence.py | 23 +++++++++++++++++++++++ 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/contact_map/concurrence.py b/contact_map/concurrence.py index ef12c26..60b6c63 100644 --- a/contact_map/concurrence.py +++ b/contact_map/concurrence.py @@ -88,9 +88,9 @@ def _regularize_contact_input(contact_input, atom_or_res): list in the format of ``ContactCount.most_common()`` """ if isinstance(contact_input, contact_map.ContactFrequency): - if atom_or_res == "atom": + if atom_or_res in ["atom", "atoms"]: contact_input = contact_input.atom_contacts.most_common() - elif atom_or_res == "residue" or atom_or_res == "res": + elif atom_or_res in ["residue", "residues", "res"]: contact_input = contact_input.residue_contacts.most_common() else: raise RuntimeError("Bad value for atom_or_res: " + @@ -114,17 +114,15 @@ class AtomContactConcurrence(Concurrence): cutoff, in nm. Should be the same as used in the contact map. """ def __init__(self, trajectory, atom_contacts, cutoff=0.45): - # TODO: the use of atom_contacts as input from most_common is weird atom_contacts = _regularize_contact_input(atom_contacts, "atom") atom_pairs = [[contact[0][0].index, contact[0][1].index] for contact in atom_contacts] labels = [str(contact[0]) for contact in atom_contacts] distances = md.compute_distances(trajectory, atom_pairs=atom_pairs) vector_f = np.vectorize(lambda d: d < cutoff) - # distances is ndarray shape (n_frames, n_contacts); values should - # be list shape (n_contacts, n_frames) - value_iter = zip(*vector_f(distances)) # make bool; transpose - values = list(map(list, value_iter)) # convert to list of list + # transpose because distances is ndarray shape (n_frames, + # n_contacts); values should be list shape (n_contacts, n_frames) + values = vector_f(distances).T.tolist() super(AtomContactConcurrence, self).__init__(values=values, labels=labels) @@ -146,7 +144,6 @@ class ResidueContactConcurrence(Concurrence): """ def __init__(self, trajectory, residue_contacts, cutoff=0.45, select="and symbol != 'H'"): - # TODO: the use of residue_contacts as input from most_common is weird residue_contacts = _regularize_contact_input(residue_contacts, "residue") residue_pairs = [[contact[0][0], contact[0][1]] diff --git a/contact_map/tests/test_concurrence.py b/contact_map/tests/test_concurrence.py index 1be7538..0347b5d 100644 --- a/contact_map/tests/test_concurrence.py +++ b/contact_map/tests/test_concurrence.py @@ -20,6 +20,29 @@ def setup_module(): n_neighbors_ignored=0) +@pytest.mark.parametrize("contact_type", ('atoms', 'residues')) +def test_regularize_contact_input(contact_type): + from contact_map.concurrence import _regularize_contact_input \ + as regularize + most_common = { + 'atoms': contacts.atom_contacts.most_common(), + 'residues': contacts.residue_contacts.most_common() + }[contact_type] + contact_count = { + 'atoms': contacts.atom_contacts, + 'residues': contacts.residue_contacts + }[contact_type] + assert regularize(most_common, contact_type) == most_common + assert regularize(contact_count, contact_type) == most_common + assert regularize(contacts, contact_type) == most_common + +def test_regularize_contact_input_bad_type(): + from contact_map.concurrence import _regularize_contact_input \ + as regularize + with pytest.raises(RuntimeError): + regularize(contacts, "foo") + + class ContactConcurrenceTester(object): def _test_default_labels(self, concurrence): assert len(concurrence.labels) == len(self.labels) / 2 From 64d3deb9df1710654b136ee8d237b57a9e4628f7 Mon Sep 17 00:00:00 2001 From: "David W.H. Swenson" Date: Tue, 17 Jul 2018 17:45:12 +0200 Subject: [PATCH 19/22] Add ContactsDict and .contacts attrib --- contact_map/concurrence.py | 12 +++------ contact_map/contact_map.py | 39 +++++++++++++++++++++++++++ contact_map/tests/test_contact_map.py | 28 +++++++++++++++++++ 3 files changed, 70 insertions(+), 9 deletions(-) diff --git a/contact_map/concurrence.py b/contact_map/concurrence.py index 60b6c63..aa1cb63 100644 --- a/contact_map/concurrence.py +++ b/contact_map/concurrence.py @@ -87,15 +87,9 @@ def _regularize_contact_input(contact_input, atom_or_res): list : list in the format of ``ContactCount.most_common()`` """ - if isinstance(contact_input, contact_map.ContactFrequency): - if atom_or_res in ["atom", "atoms"]: - contact_input = contact_input.atom_contacts.most_common() - elif atom_or_res in ["residue", "residues", "res"]: - contact_input = contact_input.residue_contacts.most_common() - else: - raise RuntimeError("Bad value for atom_or_res: " + - str(atom_or_res)) - elif isinstance(contact_input, contact_map.ContactCount): + if isinstance(contact_input, contact_map.ContactObject): + contact_input = contact_input.contacts[atom_or_res] + if isinstance(contact_input, contact_map.ContactCount): contact_input = contact_input.most_common() return contact_input diff --git a/contact_map/contact_map.py b/contact_map/contact_map.py index c0b66a1..269a9e3 100644 --- a/contact_map/contact_map.py +++ b/contact_map/contact_map.py @@ -42,6 +42,7 @@ def residue_neighborhood(residue, n=1): # good, and it only gets run once per residue return [idx for idx in neighborhood if idx in chain] + def _residue_and_index(residue, topology): res = residue try: @@ -52,6 +53,39 @@ def _residue_and_index(residue, topology): return (res, res_idx) +class ContactsDict(object): + """Dict-like object giving access to atom or residue contacts. + + In some algorithmic situations, either the atom_contacts or the + residue_contacts might be used. Rather than use lots of if-statements, + or build an actual dictionary with the associated time cost of + generating both, this class provides an object that allows dict-like + access to either the atom or residue contacts. + + Atom-based contacts (``contact.atom_contacts``) can be accessed with as + ``contact_dict['atom']`` or ``contact_dict['atoms']``. Residue-based + contacts can be accessed with the keys ``'residue'``, ``'residues'``, or + ``'res'``. + + Parameters + ---------- + contacts : :class:`.ContactObject` + contact object with fundamental data + """ + def __init__(self, contacts): + self.contacts = contacts + + def __getitem__(self, atom_or_res): + if atom_or_res in ["atom", "atoms"]: + contacts = self.contacts.atom_contacts + elif atom_or_res in ["residue", "residues", "res"]: + contacts = self.contacts.residue_contacts + else: + raise RuntimeError("Bad value for atom_or_res: " + + str(atom_or_res)) + return contacts + + class ContactObject(object): """ Generic object for contact map related analysis. Effectively abstract. @@ -75,6 +109,11 @@ def __init__(self, topology, query, haystack, cutoff, n_neighbors_ignored): self._atom_idx_to_residue_idx = {atom.index: atom.residue.index for atom in self.topology.atoms} + @property + def contacts(self): + """:class:`.ContactsDict` : contact dict for these contacts""" + return ContactsDict(self) + def __hash__(self): return hash((self.cutoff, self.n_neighbors_ignored, frozenset(self._query), frozenset(self._haystack), diff --git a/contact_map/tests/test_contact_map.py b/contact_map/tests/test_contact_map.py index 3a3eece..6157907 100644 --- a/contact_map/tests/test_contact_map.py +++ b/contact_map/tests/test_contact_map.py @@ -238,6 +238,15 @@ def test_saving(self, idx): assert m.atom_contacts.counter == m2.atom_contacts.counter os.remove(test_file) + def test_contacts_dict(self, idx): + m = self.maps[idx] + assert (m.atom_contacts.counter == m.contacts['atom'].counter + == m.contacts['atoms'].counter) + assert (m.residue_contacts.counter == m.contacts['res'].counter + == m.contacts['residue'].counter + == m.contacts['residues'].counter) + + # TODO: add tests for ContactObject._check_consistency @@ -277,6 +286,14 @@ def test_counters(self): } assert residue_contacts.counter == expected_residue_contacts + def test_contacts_dict(self): + m = self.map + assert (m.atom_contacts.counter == m.contacts['atom'].counter + == m.contacts['atoms'].counter) + assert (m.residue_contacts.counter == m.contacts['res'].counter + == m.contacts['residue'].counter + == m.contacts['residues'].counter) + def test_check_compatibility_true(self): map2 = ContactFrequency(trajectory=traj[0:2], cutoff=0.075, @@ -546,6 +563,17 @@ def test_diff_frame_frame(self): for (k, v) in expected_residues_1.items()} assert diff_2.residue_contacts.counter == expected_residues_2 + def test_contacts_dict(self): + ttraj = ContactFrequency(traj[0:4], cutoff=0.075, + n_neighbors_ignored=0) + frame = ContactMap(traj[4], cutoff=0.075, n_neighbors_ignored=0) + m = ttraj - frame + assert (m.atom_contacts.counter == m.contacts['atom'].counter + == m.contacts['atoms'].counter) + assert (m.residue_contacts.counter == m.contacts['res'].counter + == m.contacts['residue'].counter + == m.contacts['residues'].counter) + def test_diff_traj_traj(self): traj_1 = ContactFrequency(trajectory=traj[0:2], cutoff=0.075, From 626456a93510dc4fbf08910d8d2d21c88373617a Mon Sep 17 00:00:00 2001 From: "David W.H. Swenson" Date: Tue, 17 Jul 2018 17:56:10 +0200 Subject: [PATCH 20/22] clean up repetitive code --- contact_map/tests/test_contact_map.py | 32 +++++++++++---------------- 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/contact_map/tests/test_contact_map.py b/contact_map/tests/test_contact_map.py index 6157907..7ef23c6 100644 --- a/contact_map/tests/test_contact_map.py +++ b/contact_map/tests/test_contact_map.py @@ -79,6 +79,16 @@ def _contact_object_compare(m, m2): if hasattr(m, '_residue_contacts') or hasattr(m2, '_residue_contacts'): assert m._residue_contacts == m2._residue_contacts + +def _check_contacts_dict_names(contact_object): + aliases = { + contact_object.residue_contacts: ['residue', 'residues', 'res'], + contact_object.atom_contacts: ['atom', 'atoms'] + } + for (contacts, names) in aliases.items(): + for name in names: + assert contacts.counter == contact_object.contacts[name].counter + def test_residue_neighborhood(): top = traj.topology residues = list(top.residues) @@ -239,13 +249,7 @@ def test_saving(self, idx): os.remove(test_file) def test_contacts_dict(self, idx): - m = self.maps[idx] - assert (m.atom_contacts.counter == m.contacts['atom'].counter - == m.contacts['atoms'].counter) - assert (m.residue_contacts.counter == m.contacts['res'].counter - == m.contacts['residue'].counter - == m.contacts['residues'].counter) - + _check_contacts_dict_names(self.maps[idx]) # TODO: add tests for ContactObject._check_consistency @@ -287,12 +291,7 @@ def test_counters(self): assert residue_contacts.counter == expected_residue_contacts def test_contacts_dict(self): - m = self.map - assert (m.atom_contacts.counter == m.contacts['atom'].counter - == m.contacts['atoms'].counter) - assert (m.residue_contacts.counter == m.contacts['res'].counter - == m.contacts['residue'].counter - == m.contacts['residues'].counter) + _check_contacts_dict_names(self.map) def test_check_compatibility_true(self): map2 = ContactFrequency(trajectory=traj[0:2], @@ -567,12 +566,7 @@ def test_contacts_dict(self): ttraj = ContactFrequency(traj[0:4], cutoff=0.075, n_neighbors_ignored=0) frame = ContactMap(traj[4], cutoff=0.075, n_neighbors_ignored=0) - m = ttraj - frame - assert (m.atom_contacts.counter == m.contacts['atom'].counter - == m.contacts['atoms'].counter) - assert (m.residue_contacts.counter == m.contacts['res'].counter - == m.contacts['residue'].counter - == m.contacts['residues'].counter) + _check_contacts_dict_names(ttraj - frame) def test_diff_traj_traj(self): traj_1 = ContactFrequency(trajectory=traj[0:2], From 91e3092470b9e79dde862148b1e66044c72fe62f Mon Sep 17 00:00:00 2001 From: "David W.H. Swenson" Date: Tue, 17 Jul 2018 18:05:42 +0200 Subject: [PATCH 21/22] Fix ContactObject imports? --- contact_map/concurrence.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/contact_map/concurrence.py b/contact_map/concurrence.py index aa1cb63..a03ec5b 100644 --- a/contact_map/concurrence.py +++ b/contact_map/concurrence.py @@ -3,6 +3,7 @@ import numpy as np import contact_map +from .contact_map import ContactObject try: import matplotlib.pyplot as plt @@ -87,7 +88,7 @@ def _regularize_contact_input(contact_input, atom_or_res): list : list in the format of ``ContactCount.most_common()`` """ - if isinstance(contact_input, contact_map.ContactObject): + if isinstance(contact_input, ContactObject): contact_input = contact_input.contacts[atom_or_res] if isinstance(contact_input, contact_map.ContactCount): contact_input = contact_input.most_common() From 24f9928dea6f430cb474985c86a9a4b43af50bbb Mon Sep 17 00:00:00 2001 From: "David W.H. Swenson" Date: Thu, 26 Jul 2018 22:23:54 +0200 Subject: [PATCH 22/22] Add extra blank line for readability --- contact_map/concurrence.py | 1 + 1 file changed, 1 insertion(+) diff --git a/contact_map/concurrence.py b/contact_map/concurrence.py index a03ec5b..c456bdb 100644 --- a/contact_map/concurrence.py +++ b/contact_map/concurrence.py @@ -90,6 +90,7 @@ def _regularize_contact_input(contact_input, atom_or_res): """ if isinstance(contact_input, ContactObject): contact_input = contact_input.contacts[atom_or_res] + if isinstance(contact_input, contact_map.ContactCount): contact_input = contact_input.most_common()