Skip to content

Commit

Permalink
fit results clustering speed up, a bit; see #44
Browse files Browse the repository at this point in the history
  • Loading branch information
RodenLuo committed Dec 9, 2024
1 parent d3af4d2 commit 17f459a
Showing 1 changed file with 37 additions and 6 deletions.
43 changes: 37 additions & 6 deletions src/DiffAtomComp.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,32 @@
# Ignore PDBConstructionWarning for unrecognized 'END' record
warnings.filterwarnings("ignore", message="Ignoring unrecognized record 'END'", category=PDBConstructionWarning)

from chimerax.geometry.bins import Binned_Transforms

class DiffFit_Binned_Transforms(Binned_Transforms):
def __init__(self, angle, translation, center=(0, 0, 0), bfactor=2):
super().__init__(angle, translation, center, bfactor)

def one_close_transform(self, tf):

a, x, y, z = c = self.bin_point(tf)
clist = self.bins.close_objects(c, self.spacing)
if len(clist) == 0:
return None

itf = tf.inverse()
d2max = self.translation * self.translation
for ctf in clist:
cx, cy, cz = ctf * self.center
dx, dy, dz = x - cx, y - cy, z - cz
d2 = dx * dx + dy * dy + dz * dz
if d2 <= d2max:
dtf = ctf * itf
a = dtf.rotation_angle()
if a < self.angle:
return ctf

return None

def interpolate_coords(coords, inter_folds, inter_kind='quadratic'):
"""Interpolate backbone coordinates."""
Expand Down Expand Up @@ -126,7 +152,6 @@ def cluster_and_sort_sqd_fast(e_sqd_log, mol_centers, shift_tolerance: float = 3
"""
mol_center = np.array([0.0, 0.0, 0.0])

from chimerax.geometry import bins
from chimerax.geometry import Place

N_mol, N_record, N_iter, N_metric = e_sqd_log.shape
Expand Down Expand Up @@ -210,21 +235,27 @@ def cluster_and_sort_sqd_fast(e_sqd_log, mol_centers, shift_tolerance: float = 3
log_file.write(f"Convert to matrix time: {datetime.now() - timer_start}\n")
timer_start = datetime.now()

b = bins.Binned_Transforms(angle_tolerance * pi / 180, shift_tolerance, mol_center, bfactor=2)
b = DiffFit_Binned_Transforms(angle_tolerance * pi / 180, shift_tolerance, mol_center, bfactor=1)
mol_transform_label = []
unique_id = 0
T_ID_dict = {}
for i in range(len(mol_shift)):
ptf = T[i]
close = b.close_transforms(ptf)
if len(close) == 0:
coord = [c / s for c, s in zip(b.bin_point(ptf), b.bins.bin_size)]
cbin = b.bins.close_bins(coord, (0, 0, 0, 0))[0]
close = None
if cbin in b.bins.bins:
close = b.bins.bins[cbin][0][1]
else:
close = b.one_close_transform(ptf)
if close is None:
b.add_transform(ptf)
mol_transform_label.append(unique_id)
T_ID_dict[id(ptf)] = unique_id
unique_id = unique_id + 1
else:
mol_transform_label.append(T_ID_dict[id(close[0])])
T_ID_dict[id(ptf)] = T_ID_dict[id(close[0])]
mol_transform_label.append(T_ID_dict[id(close)])
T_ID_dict[id(ptf)] = T_ID_dict[id(close)]

if save_log and (i + 1) % 10000 == 0:
with open(log_path, "a") as log_file:
Expand Down

0 comments on commit 17f459a

Please sign in to comment.