Skip to content

Commit

Permalink
avoid special cases in returning measure_results and metrics_results
Browse files Browse the repository at this point in the history
  • Loading branch information
ismael-mendoza committed Apr 28, 2021
1 parent c40d03c commit 08a3b28
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 28 deletions.
2 changes: 0 additions & 2 deletions btk/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,5 @@ def __next__(self):
measure_output[j][i][key] for j in range(len(measure_output))
]
measure_results[f.__name__] = measure_dic
if len(self.measure_functions) == 1:
measure_results = measure_results[self.measure_functions[0].__name__]

return blend_output, measure_results
31 changes: 7 additions & 24 deletions btk/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,38 +681,21 @@ def __next__(self):
noise_threshold = self.noise_threshold_factor * np.sqrt(
get_mean_sky_level(survey, survey.filters[self.meas_band_num])
)
if "catalog" not in measure_results.keys():
meas_func = measure_results.keys()
metrics_results = {}
for f in meas_func:
metrics_results_f = compute_metrics(
blend_results["blend_images"],
blend_results["isolated_images"],
blend_results["blend_list"],
measure_results[f]["catalog"],
measure_results[f]["segmentation"],
measure_results[f]["deblended_images"],
self.use_metrics,
noise_threshold,
self.meas_band_num,
target_meas,
channels_last=self.measure_generator.channels_last,
)
metrics_results[f] = metrics_results_f

else:
metrics_results = compute_metrics(
metrics_results = {}
for meas_func in measure_results:
metrics_results_f = compute_metrics(
blend_results["blend_images"],
blend_results["isolated_images"],
blend_results["blend_list"],
measure_results["catalog"],
measure_results["segmentation"],
measure_results["deblended_images"],
measure_results[meas_func]["catalog"],
measure_results[meas_func]["segmentation"],
measure_results[meas_func]["deblended_images"],
self.use_metrics,
noise_threshold,
self.meas_band_num,
target_meas,
channels_last=self.measure_generator.channels_last,
)
metrics_results[meas_func] = metrics_results_f

return blend_results, measure_results, metrics_results
2 changes: 2 additions & 0 deletions tests/test_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def compare_sep():
"""Test detection with sep"""
meas_generator = get_meas_generator(btk.measure.sep_measure)
_, results = next(meas_generator)
results = list(results.values())[0] # extract single element from dict.
x_peak, y_peak = (
results["catalog"][0]["x_peak"].item(),
results["catalog"][0]["y_peak"].item(),
Expand All @@ -64,6 +65,7 @@ def compare_sep_multiprocessing():
"""Test detection with sep"""
meas_generator = get_meas_generator(btk.measure.sep_measure, cpus=4)
_, results = next(meas_generator)
results = list(results.values())[0] # extract single element from dict.
x_peak, y_peak = (
results["catalog"][0]["x_peak"].item(),
results["catalog"][0]["y_peak"].item(),
Expand Down
5 changes: 3 additions & 2 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ def get_metrics_generator(meas_function, cpus=1):

@patch("btk.plot_utils.plt.show")
def test_sep_metrics(mock_show):
meas_generator = get_metrics_generator(btk.measure.sep_measure)
_, _, results = next(meas_generator)
metrics_generator = get_metrics_generator(btk.measure.sep_measure)
_, _, results = next(metrics_generator)
results = list(results.values())[0]
gal_summary = results["galaxy_summary"][
results["galaxy_summary"]["detected"] == True # noqa: E712
]
Expand Down

0 comments on commit 08a3b28

Please sign in to comment.