Switch order of dictionaries for metrics (#178)
thuiop authored Jul 13, 2021
1 parent f40e042 commit 4a4e756
Showing 8 changed files with 138 additions and 92 deletions.
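In short, this commit swaps the two top-level layers of the metrics dictionaries: results are now keyed by metric type first ("detection", "segmentation", "reconstruction", "galaxy_summary", "matches") and by measure function second, with an extra survey layer in the multiresolution case. A minimal sketch of the access-pattern change, using placeholder values and the "sep_measure" key that appears in the tests below:

```python
# Before this commit: measure function first, then metric type.
old_layout = {"sep_measure": {"detection": "...", "galaxy_summary": "..."}}
old_summary = old_layout["sep_measure"]["galaxy_summary"]

# After this commit: metric type first, then measure function.
new_layout = {"detection": {"sep_measure": "..."}, "galaxy_summary": {"sep_measure": "..."}}
new_summary = new_layout["galaxy_summary"]["sep_measure"]
```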
12 changes: 7 additions & 5 deletions btk/metrics.py
@@ -64,6 +64,7 @@

from btk.measure import MeasureGenerator
from btk.survey import get_mean_sky_level
+ from btk.utils import reverse_dictionary_dictionary


def get_blendedness(iso_image, blend_iso_images):
@@ -757,12 +758,13 @@ def __next__(self):
self.meas_band_num[i],
target_meas,
channels_last=self.measure_generator.channels_last,
- save_path=os.path.join(self.save_path, meas_func)
+ save_path=os.path.join(self.save_path, meas_func, surv)
if self.save_path is not None
else None,
f_distance=self.f_distance,
distance_threshold_match=self.distance_threshold_match,
)
+ metrics_results_f = reverse_dictionary_dictionary(metrics_results_f)

else:
additional_params = {
@@ -788,13 +790,14 @@ def __next__(self):
self.meas_band_num,
target_meas,
channels_last=self.measure_generator.channels_last,
- save_path=os.path.join(self.save_path, meas_func)
+ save_path=os.path.join(self.save_path, meas_func, surveys[0].name)
if self.save_path is not None
else None,
f_distance=self.f_distance,
distance_threshold_match=self.distance_threshold_match,
)
metrics_results[meas_func] = metrics_results_f
+ metrics_results = reverse_dictionary_dictionary(metrics_results)

return blend_results, measure_results, metrics_results

@@ -828,9 +831,8 @@ def auc(metrics_results, measure_name, n_meas, plot=False, ax=None):
recalls = []
average_precision = 0
for i in range(n_meas):
- metrics_results_temp = metrics_results[measure_name + str(i)]
- precisions.append(metrics_results_temp["detection"]["precision"])
- recalls.append(metrics_results_temp["detection"]["recall"])
+ precisions.append(metrics_results["detection"][measure_name + str(i)]["precision"])
+ recalls.append(metrics_results["detection"][measure_name + str(i)]["recall"])
order = np.argsort(recalls)
recalls = np.array(recalls)[order]
precisions = np.array(precisions)[order]
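Note that the save path above now gains a per-survey subdirectory. A sketch of the resulting on-disk layout, inferred from load_metrics_results in btk/utils.py further down ("results", "sep_measure" and "Rubin" are example names, not fixed values):

```python
import os

# Hypothetical example: save_path="results", measure function "sep_measure",
# survey "Rubin". Metrics now live one level deeper, under the survey name,
# matching what load_metrics_results(path, measure_name, survey_name) reads back.
metrics_dir = os.path.join("results", "sep_measure", "Rubin")
detection_file = os.path.join(metrics_dir, "detection_metric.npy")
galaxy_summary_file = os.path.join(metrics_dir, "galaxy_summary")
```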
57 changes: 34 additions & 23 deletions btk/plot_utils.py
@@ -486,34 +486,45 @@ def plot_metrics_summary( # noqa: C901
"""
sns.set_context(context)
# Keys corresponding to the measure functions
- keys = list(metrics_results.keys())
+ measure_keys = list(metrics_results["galaxy_summary"].keys())

# We need to handle the multiresolution case
- if "galaxy_summary" not in metrics_results[keys[0]].keys():
- survey_keys = list(metrics_results[keys[0]].keys())
- gal_summary_keys = list(metrics_results[keys[0]][survey_keys[0]]["galaxy_summary"].keys())
+ if isinstance(metrics_results["galaxy_summary"][measure_keys[0]], dict):
+ survey_keys = list(metrics_results["galaxy_summary"][measure_keys[0]].keys())
+ gal_summary_keys = list(
+ metrics_results["galaxy_summary"][measure_keys[0]][survey_keys[0]].keys()
+ )
multiresolution = True
# Limits for widgets
- min_mag = np.min(metrics_results[keys[0]][survey_keys[0]]["galaxy_summary"]["ref_mag"])
- max_mag = np.max(metrics_results[keys[0]][survey_keys[0]]["galaxy_summary"]["ref_mag"])
- min_size = np.min(metrics_results[keys[0]][survey_keys[0]]["galaxy_summary"]["btk_size"])
- max_size = np.max(metrics_results[keys[0]][survey_keys[0]]["galaxy_summary"]["btk_size"])
+ min_mag = np.min(
+ metrics_results["galaxy_summary"][measure_keys[0]][survey_keys[0]]["ref_mag"]
+ )
+ max_mag = np.max(
+ metrics_results["galaxy_summary"][measure_keys[0]][survey_keys[0]]["ref_mag"]
+ )
+ min_size = np.min(
+ metrics_results["galaxy_summary"][measure_keys[0]][survey_keys[0]]["btk_size"]
+ )
+ max_size = np.max(
+ metrics_results["galaxy_summary"][measure_keys[0]][survey_keys[0]]["btk_size"]
+ )
else:
- gal_summary_keys = list(metrics_results[keys[0]]["galaxy_summary"].keys())
+ gal_summary_keys = list(metrics_results["galaxy_summary"][measure_keys[0]].keys())
multiresolution = False
- min_mag = np.min(metrics_results[keys[0]]["galaxy_summary"]["ref_mag"])
- max_mag = np.max(metrics_results[keys[0]]["galaxy_summary"]["ref_mag"])
- min_size = np.min(metrics_results[keys[0]]["galaxy_summary"]["btk_size"])
- max_size = np.max(metrics_results[keys[0]]["galaxy_summary"]["btk_size"])
+ min_mag = np.min(metrics_results["galaxy_summary"][measure_keys[0]]["ref_mag"])
+ max_mag = np.max(metrics_results["galaxy_summary"][measure_keys[0]]["ref_mag"])
+ min_size = np.min(metrics_results["galaxy_summary"][measure_keys[0]]["btk_size"])
+ max_size = np.max(metrics_results["galaxy_summary"][measure_keys[0]]["btk_size"])
plot_keys = ["reconstruction", "segmentation", "eff_matrix"] + target_meas_keys + ["custom"]

if interactive:
layout = widgets.Layout(width="auto")
# Checkboxes for selecting the measure function
measure_functions_dict = {
- key: widgets.Checkbox(description=key, value=False, layout=layout) for key in keys
+ key: widgets.Checkbox(description=key, value=False, layout=layout)
+ for key in measure_keys
}
- measure_functions = [measure_functions_dict[key] for key in keys]
+ measure_functions = [measure_functions_dict[key] for key in measure_keys]
measure_functions_widget = widgets.VBox(measure_functions, description="Measure functions")
# Checkboxes for selecting the survey (if multiresolution)
if multiresolution:
@@ -597,7 +608,7 @@ def draw_plots(value):
custom_y_log = custom_y_widget_log.value
plot_selections = {w.description: w.value for w in plot_selection_widget.children}
else:
- meas_func_names = keys
+ meas_func_names = measure_keys
if multiresolution:
surveys = survey_keys
blendedness_limits = [0, 1]
@@ -619,14 +630,14 @@ def draw_plots(value):
for f_name in meas_func_names:
for s_name in surveys:
couples.append(f_name + "_" + s_name)
- dataframes[f_name + "_" + s_name] = metrics_results[f_name][s_name][
- "galaxy_summary"
+ dataframes[f_name + "_" + s_name] = metrics_results["galaxy_summary"][f_name][
+ s_name
].to_pandas()
concatenated = pd.concat([dataframes[c].assign(measure_function=c) for c in couples])
else:
dataframes = {}
for f_name in meas_func_names:
- dataframes[f_name] = metrics_results[f_name]["galaxy_summary"].to_pandas()
+ dataframes[f_name] = metrics_results["galaxy_summary"][f_name].to_pandas()
concatenated = pd.concat(
[dataframes[f_name].assign(measure_function=f_name) for f_name in meas_func_names]
)
@@ -719,7 +730,7 @@ def draw_plots(value):

mag_low = np.min(concatenated["ref_mag"])
mag_high = np.max(concatenated["ref_mag"])
- for meas_func in keys:
+ for meas_func in measure_keys:
bins = np.linspace(mag_low, mag_high, n_bins_target)
labels = np.digitize(concatenated["ref_mag"], bins)
means = []
@@ -772,10 +783,10 @@ def draw_plots(value):
for i, k in enumerate(meas_func_names):
if multiresolution:
plot_efficiency_matrix(
- metrics_results[k][survey_keys[0]]["detection"]["eff_matrix"], ax=ax[i]
+ metrics_results["detection"][k][survey_keys[0]]["eff_matrix"], ax=ax[i]
)
else:
- plot_efficiency_matrix(metrics_results[k]["detection"]["eff_matrix"], ax=ax[i])
+ plot_efficiency_matrix(metrics_results["detection"][k]["eff_matrix"], ax=ax[i])
ax[i].set_title(k)
if save_path is not None:
plt.savefig(os.path.join(save_path, "efficiency_matrices.png"))
@@ -786,7 +797,7 @@ def draw_plots(value):
blendedness_widget.observe(draw_plots, "value")
magnitude_widget.observe(draw_plots, "value")
size_widget.observe(draw_plots, "value")
- for k in keys:
+ for k in measure_keys:
measure_functions_dict[k].observe(draw_plots, "value")
if multiresolution:
for k in survey_keys:
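The isinstance check above is how plot_metrics_summary now tells the single-survey and multiresolution layouts apart. A small sketch of the two shapes, with placeholder strings standing in for the actual summary tables ("sep_measure", "Rubin" and "HSC" are example keys taken from the tests):

```python
# Single-survey layout: "galaxy_summary" maps measure function -> summary table.
single = {"galaxy_summary": {"sep_measure": "summary_table"}}

# Multiresolution layout: one extra layer keyed by survey name.
multi = {"galaxy_summary": {"sep_measure": {"Rubin": "table_rubin", "HSC": "table_hsc"}}}

# Multiresolution iff the per-measure entry is itself a dict keyed by survey.
isinstance(single["galaxy_summary"]["sep_measure"], dict)  # False
isinstance(multi["galaxy_summary"]["sep_measure"], dict)   # True
```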
44 changes: 39 additions & 5 deletions btk/utils.py
@@ -69,7 +69,7 @@ def load_measure_results(path, measure_name, n_batch):
return measure_results


- def load_metrics_results(path, measure_name):
+ def load_metrics_results(path, measure_name, survey_name):
"""Load results exported from a MetricsGenerator.
Args:
@@ -87,13 +87,14 @@ def load_metrics_results(path, measure_name):
for key in ["detection", "segmentation", "reconstruction"]:
try:
metrics_results[key] = np.load(
- os.path.join(path, measure_name, f"{key}_metric.npy"), allow_pickle=True
+ os.path.join(path, measure_name, survey_name, f"{key}_metric.npy"),
+ allow_pickle=True,
)
except FileNotFoundError:
print(f"No {key} metrics found.")

metrics_results["galaxy_summary"] = Table.read(
- os.path.join(path, measure_name, "galaxy_summary"),
+ os.path.join(path, measure_name, survey_name, "galaxy_summary"),
format="ascii",
)
return metrics_results
@@ -119,7 +120,12 @@ def load_all_results(path, surveys, measure_names, n_batch, n_meas_kwargs=1):
for key in BLEND_RESULT_KEYS:
blend_results[key] = {}
measure_results = {"catalog": {}, "segmentation": {}, "deblended_images": {}}
- metrics_results = {}
+ metrics_results = {
+ "detection": {},
+ "segmentation": {},
+ "reconstruction": {},
+ "galaxy_summary": {},
+ }
for s in surveys:
blend_results_temp = load_blend_results(path, s)
for key in BLEND_RESULT_KEYS:
@@ -131,7 +137,17 @@ def load_all_results(path, surveys, measure_names, n_batch, n_meas_kwargs=1):
meas_results = load_measure_results(path, dir_name, n_batch)
for k in meas_results.keys():
measure_results[k][dir_name] = meas_results[k]
- metrics_results[dir_name] = load_metrics_results(path, dir_name)
+ for k in metrics_results.keys():
+ metrics_results[k][dir_name] = {}
+ if len(surveys) > 1:
+ for s in surveys:
+ metr_results = load_metrics_results(path, dir_name, s)
+ for k in metr_results.keys():
+ metrics_results[k][dir_name][s] = metr_results[k]
+ else:
+ metr_results = load_metrics_results(path, dir_name, surveys[0])
+ for k in metr_results.keys():
+ metrics_results[k][dir_name] = metr_results[k]

return blend_results, measure_results, metrics_results

@@ -155,3 +171,21 @@ def reverse_list_dictionary(to_reverse, keys):
else:
to_reverse = {k: [to_reverse[n][k] for n in range(len(to_reverse))] for k in keys}
return to_reverse


+ def reverse_dictionary_dictionary(to_reverse):
+ """Exchanges two dictionary layers.
+
+ For instance, dic[keyA][key1] will become dic[key1][keyA].
+
+ Args:
+ to_reverse (dict): Dictionary of dictionaries.
+
+ Returns:
+ Reversed dictionary.
+ """
+ first_keys = list(to_reverse.keys())
+ second_keys = list(to_reverse[first_keys[0]].keys())
+ return {
+ s_key: {f_key: to_reverse[f_key][s_key] for f_key in first_keys} for s_key in second_keys
+ }
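A quick usage sketch of the new helper (the keys and values here are made up, purely to show the swap):

```python
from btk.utils import reverse_dictionary_dictionary

# Hypothetical input: results keyed by measure function first, metric second.
nested = {
    "measure_a": {"detection": 0.9, "segmentation": 0.8},
    "measure_b": {"detection": 0.7, "segmentation": 0.6},
}
reverse_dictionary_dictionary(nested)
# -> {"detection": {"measure_a": 0.9, "measure_b": 0.7},
#     "segmentation": {"measure_a": 0.8, "measure_b": 0.6}}
```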
46 changes: 23 additions & 23 deletions notebooks/intro.ipynb

Large diffs are not rendered by default.

51 changes: 24 additions & 27 deletions notebooks/scarlet-measure.ipynb

Large diffs are not rendered by default.

8 changes: 3 additions & 5 deletions tests/test_metrics.py
@@ -58,10 +58,8 @@ def get_metrics_generator(meas_function, cpus=1, f_distance=distance_center, mea
def test_sep_metrics(mock_show):
metrics_generator = get_metrics_generator(sep_measure)
blend_results, meas_results, metrics_results = next(metrics_generator)
- results = list(metrics_results.values())[0]
- gal_summary = results["galaxy_summary"][
- results["galaxy_summary"]["detected"] == True # noqa: E712
- ]
+ gal_summary = metrics_results["galaxy_summary"]["sep_measure"]
+ gal_summary = gal_summary[gal_summary["detected"] == True] # noqa: E712
msr = gal_summary["msr"]
dist = gal_summary["distance_closest_galaxy"]
fig, (ax1, ax2) = plt.subplots(1, 2)
@@ -87,7 +85,7 @@ def test_sep_metrics(mock_show):
blend_results["blend_list"],
meas_results["catalog"]["sep_measure"],
meas_results["deblended_images"]["sep_measure"],
- metrics_results["sep_measure"]["matches"],
+ metrics_results["matches"]["sep_measure"],
indexes=list(range(5)),
band_indices=[1, 2, 3],
)
8 changes: 6 additions & 2 deletions tests/test_mr.py
@@ -55,8 +55,12 @@ def test_multiresolution(mock_show):
assert (
"HSC" in measure_results["catalog"]["sep_measure"].keys()
), "Both surveys get well defined outputs"
- assert "Rubin" in metrics_results["sep_measure"].keys(), "Both surveys get well defined outputs"
- assert "HSC" in metrics_results["sep_measure"].keys(), "Both surveys get well defined outputs"
+ assert (
+ "Rubin" in metrics_results["galaxy_summary"]["sep_measure"].keys()
+ ), "Both surveys get well defined outputs"
+ assert (
+ "HSC" in metrics_results["galaxy_summary"]["sep_measure"].keys()
+ ), "Both surveys get well defined outputs"

plot_metrics_summary(
metrics_results,
4 changes: 2 additions & 2 deletions tests/test_save.py
@@ -50,6 +50,6 @@ def test_save():
measure_results2["deblended_images"]["sep_measure"][0],
)
np.testing.assert_array_equal(
- metrics_results["sep_measure"]["galaxy_summary"]["distance_closest_galaxy"],
- metrics_results2["sep_measure"]["galaxy_summary"]["distance_closest_galaxy"],
+ metrics_results["galaxy_summary"]["sep_measure"]["distance_closest_galaxy"],
+ metrics_results2["galaxy_summary"]["sep_measure"]["distance_closest_galaxy"],
)
