Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a cap on distance for matching and changed error handling on KSB #158

Merged
merged 11 commits into from
May 26, 2021
39 changes: 28 additions & 11 deletions btk/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,14 @@ def meas_ksb_ellipticity(image, additional_params):
gal_image = galsim.Image(image[meas_band_num, :, :])
gal_image.scale = pixel_scale
shear_est = "KSB"
try:
res = galsim.hsm.EstimateShear(gal_image, psf_image, shear_est=shear_est, strict=True)
result = [res.corrected_g1, res.corrected_g2, res.observed_shape.e]
except RuntimeError as e:
print(e)
result = [-10.0, -10.0, -10.0]

res = galsim.hsm.EstimateShear(gal_image, psf_image, shear_est=shear_est, strict=False)
ismael-mendoza marked this conversation as resolved.
Show resolved Hide resolved
result = [res.corrected_g1, res.corrected_g2, res.observed_shape.e]
if res.error_message != "":
print(
f"Shear measurement error : '{res.error_message }'. \
This error may happen for faint galaxies or inaccurate detections."
)
return result


Expand All @@ -121,7 +123,9 @@ def distance_center(true_gal, detected_gal):
)


def get_detection_match(true_table, detected_table, f_distance=distance_center):
def get_detection_match(
true_table, detected_table, f_distance=distance_center, distance_threshold_match=5
):
ismael-mendoza marked this conversation as resolved.
Show resolved Hide resolved
r"""Uses the Hungarian algorithm to find optimal matching between detections and true objects.

The optimal matching is computed based on the following optimization problem:
Expand All @@ -145,6 +149,8 @@ def get_detection_match(true_table, detected_table, f_distance=distance_center):
f_distance (func): Function used to compute the distance between true and detected
galaxies. Takes as arguments the entries corresponding to the two galaxies.
By default the distance is the euclidean distance from center to center.
distance_threshold_match (float): Maximum distance for matching a detected and a
true galaxy.
ismael-mendoza marked this conversation as resolved.
Show resolved Hide resolved

Returns:
match_table (astropy.table.Table): Table where each row corresponds to each true
Expand Down Expand Up @@ -181,8 +187,9 @@ def get_detection_match(true_table, detected_table, f_distance=distance_center):
match_indx = [-1] * len(true_table)
dist_m = [0.0] * len(true_table)
for i, indx in enumerate(true_indx):
match_indx[indx] = detected_indx[i]
dist_m[indx] = dist[indx][detected_indx[i]]
if dist[indx][detected_indx[i]] <= distance_threshold_match:
match_indx[indx] = detected_indx[i]
dist_m[indx] = dist[indx][detected_indx[i]]

match_table["match_detected_id"] = match_indx
match_table["dist"] = dist_m
Expand Down Expand Up @@ -534,6 +541,7 @@ def compute_metrics( # noqa: C901
channels_last=False,
save_path=None,
f_distance=distance_center,
distance_threshold_match=5,
ismael-mendoza marked this conversation as resolved.
Show resolved Hide resolved
):
"""Computes all requested metrics given information in a single batch from measure_generator.

Expand Down Expand Up @@ -572,6 +580,8 @@ def compute_metrics( # noqa: C901
f_distance (func): Function used to compute the distance between true and detected
galaxies. Takes as arguments the entries corresponding to the two galaxies.
By default the distance is the euclidean distance from center to center.
distance_threshold_match (float): Maximum distance for matching a detected and a
true galaxy.
ismael-mendoza marked this conversation as resolved.
Show resolved Hide resolved

Returns:
results (dict) : Contains all the computed metrics. Entries are :
Expand All @@ -590,7 +600,9 @@ def compute_metrics( # noqa: C901
deblended_images = [np.moveaxis(im, -1, 1) for im in deblended_images]
results = {}
matches = [
get_detection_match(blend_list[i], detection_catalogs[i], f_distance)
get_detection_match(
blend_list[i], detection_catalogs[i], f_distance, distance_threshold_match
)
for i in range(len(blend_list))
]
results["matches"] = matches
Expand Down Expand Up @@ -678,6 +690,7 @@ def __init__(
noise_threshold_factor=3,
save_path=None,
f_distance=distance_center,
distance_threshold_match=5,
ismael-mendoza marked this conversation as resolved.
Show resolved Hide resolved
):
"""Initialize metrics generator.

Expand All @@ -697,10 +710,12 @@ def __init__(
correspond to a threshold of 3 sigmas (with sigma the standard deviation of
the noise)
save_path (str): Path to directory where results will be saved. If left
as None, results will not be saved.
as None, results will not be saved.
f_distance (func): Function used to compute the distance between true and detected
galaxies. Takes as arguments the entries corresponding to the two galaxies.
By default the distance is the euclidean distance from center to center.
distance_threshold_match (float): Maximum distance for matching a detected and a
true galaxy.
ismael-mendoza marked this conversation as resolved.
Show resolved Hide resolved
"""
self.measure_generator: MeasureGenerator = measure_generator
self.use_metrics = use_metrics
Expand All @@ -709,6 +724,7 @@ def __init__(
self.noise_threshold_factor = noise_threshold_factor
self.save_path = save_path
self.f_distance = f_distance
self.distance_threshold_match = distance_threshold_match

def __next__(self):
"""Returns metric results calculated on one batch."""
Expand Down Expand Up @@ -744,6 +760,7 @@ def __next__(self):
if self.save_path is not None
else None,
f_distance=self.f_distance,
distance_threshold_match=self.distance_threshold_match,
)
metrics_results[meas_func] = metrics_results_f

Expand Down