Skip to content

Commit

Permalink
# max fits/mol filter in the View tab
Browse files Browse the repository at this point in the history
  • Loading branch information
RodenLuo committed Dec 17, 2024
1 parent 323a08d commit 8317641
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 12 deletions.
10 changes: 8 additions & 2 deletions src/DiffAtomComp.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def cluster_and_sort_sqd_fast(e_sqd_log, shift_tolerance: float = 3.0, angle_tol
in_contour_threshold: float = 0.5,
save_log=False,
log_path="",
max_fits=10000,
max_clusters=100):
"""
Cluster the fitting results in sqd table by thresholding on shift and quaternion
Expand Down Expand Up @@ -191,8 +192,13 @@ def cluster_and_sort_sqd_fast(e_sqd_log, shift_tolerance: float = 3.0, angle_tol
filtered_indices = np.where(in_contour_mask) # Get the indices of the filtered rows
filtered_array = sqd_highest_corr_np_mol[filtered_indices]

fit_res_filtered.append(filtered_array)
fit_res_filtered_indices.append(filtered_indices[0])
sorted_indices = np.argsort(filtered_array[:, sort_column_idx])[::-1]
top_indices = sorted_indices[:min(max_fits, len(filtered_array))]
filtered_array_top = filtered_array[top_indices]
filtered_indices_top = filtered_indices[0][top_indices]

fit_res_filtered.append(filtered_array_top)
fit_res_filtered_indices.append(filtered_indices_top)

if save_log:
with open(log_path, "a") as log_file:
Expand Down
34 changes: 24 additions & 10 deletions src/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def __init__(self):
self.clustering_shift_tolerance : float = 3.0
self.clustering_angle_tolerance : float = 6.0

self.max_fits: float = 10000
self.max_clusters: float = 100
self.clustering_in_contour_threshold: float = 0.2

Expand Down Expand Up @@ -350,6 +351,7 @@ def load_settings(self):
self.dataset_folder.setText(self.settings.view_output_directory)

# clustering
self.max_fits.setValue(self.settings.max_fits)
self.max_clusters.setValue(self.settings.max_clusters)
self.clustering_in_contour_threshold.setValue(self.settings.clustering_in_contour_threshold)
self.clustering_angle_tolerance.setValue(self.settings.clustering_angle_tolerance)
Expand Down Expand Up @@ -386,6 +388,7 @@ def store_settings(self):
self.settings.view_output_directory = self.dataset_folder.text()

# clustering
self.settings.max_fits = self.max_fits.value()
self.settings.max_clusters = self.max_clusters.value()
self.settings.clustering_in_contour_threshold = self.clustering_in_contour_threshold.value()
self.settings.clustering_angle_tolerance = self.clustering_angle_tolerance.value()
Expand Down Expand Up @@ -1013,16 +1016,6 @@ def build_view_ui(self, layout):
layout.addWidget(self.dataset_folder_select, row, 2)
row = row + 1

max_clusters_label = QLabel()
max_clusters_label.setText("# max clusters/mol:")
self.max_clusters = QSpinBox()
self.max_clusters.setMinimum(1)
self.max_clusters.setMaximum(100000)
self.max_clusters.valueChanged.connect(lambda: self.store_settings())
layout.addWidget(max_clusters_label, row, 0)
layout.addWidget(self.max_clusters, row, 1, 1, 2)
row = row + 1

clustering_in_contour_threshold_label = QLabel()
clustering_in_contour_threshold_label.setText("In contour threshold:")
self.clustering_in_contour_threshold = QDoubleSpinBox()
Expand All @@ -1034,6 +1027,26 @@ def build_view_ui(self, layout):
layout.addWidget(self.clustering_in_contour_threshold, row, 1, 1, 2)
row = row + 1

max_fits_label = QLabel()
max_fits_label.setText("# max fits/mol:")
self.max_fits = QSpinBox()
self.max_fits.setMinimum(1)
self.max_fits.setMaximum(2**31 - 1)
self.max_fits.valueChanged.connect(lambda: self.store_settings())
layout.addWidget(max_fits_label, row, 0)
layout.addWidget(self.max_fits, row, 1, 1, 2)
row = row + 1

max_clusters_label = QLabel()
max_clusters_label.setText("# max clusters/mol:")
self.max_clusters = QSpinBox()
self.max_clusters.setMinimum(1)
self.max_clusters.setMaximum(100000)
self.max_clusters.valueChanged.connect(lambda: self.store_settings())
layout.addWidget(max_clusters_label, row, 0)
layout.addWidget(self.max_clusters, row, 1, 1, 2)
row = row + 1

clustering_shift_tolerance_label = QLabel()
clustering_shift_tolerance_label.setText("Clustering - Shift Tolerance:")
self.clustering_shift_tolerance = QDoubleSpinBox()
Expand Down Expand Up @@ -1456,6 +1469,7 @@ def show_results(self, e_sqd_log, mol_num_atoms, mol_paths,
in_contour_threshold=self.settings.clustering_in_contour_threshold,
save_log=save_log,
log_path=log_path,
max_fits=self.settings.max_fits,
max_clusters=self.settings.max_clusters)
if save_log:
with open(log_path, "a") as log_file:
Expand Down

0 comments on commit 8317641

Please sign in to comment.