Skip to content

Commit

Permalink
Added possibility to change how distance is computed (#147)
Browse files Browse the repository at this point in the history
* Added possibility to change how distance is computed

* Added test for new feature

* Normalized the computation of the distances

* remove print statement

Co-authored-by: Ismael Mendoza <[email protected]>
  • Loading branch information
thuiop and ismael-mendoza authored May 18, 2021
1 parent b4903ac commit d198d09
Show file tree
Hide file tree
Showing 3 changed files with 206 additions and 9,312 deletions.
46 changes: 38 additions & 8 deletions btk/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,21 @@ def meas_ksb_ellipticity(image, additional_params):
return result


def get_detection_match(true_table, detected_table):
def distance_center(true_gal, detected_gal):
"""Computes distance between the two galaxies given as arguments.
Args:
true_gal (astropy.table.Table): Contains information related to the true galaxy.
detected_gal (astropy.table.Table): Contains information related to the detected galaxy
Returns:
Distance between the two galaxies
"""
return np.hypot(
true_gal["x_peak"] - detected_gal["x_peak"], true_gal["y_peak"] - detected_gal["y_peak"]
)


def get_detection_match(true_table, detected_table, f_distance=distance_center):
r"""Uses the Hungarian algorithm to find optimal matching between detections and true objects.
The optimal matching is computed based on the following optimization problem:
Expand All @@ -126,6 +140,9 @@ def get_detection_match(true_table, detected_table):
the true object parameter values in one blend.
detected_table(astropy.table.Table): Table with entries corresponding
to output of measurement algorithm in one blend.
f_distance (func): Function used to compute the distance between true and detected
galaxies. Takes as arguments the entries corresponding to the two galaxies.
By default the distance is the euclidean distance from center to center.
Returns:
match_table (astropy.table.Table): Table where each row corresponds to each true
Expand All @@ -145,9 +162,13 @@ def get_detection_match(true_table, detected_table):
if "y_peak" not in true_table.colnames:
raise KeyError("Detection table has no column y_peak")
match_table = astropy.table.Table()
t_x = true_table["x_peak"].reshape(-1, 1) - detected_table["x_peak"].reshape(1, -1)
t_y = true_table["y_peak"].reshape(-1, 1) - detected_table["y_peak"].reshape(1, -1)
dist = np.hypot(t_x, t_y) # dist[i][j] = distance between true object i and detected object j.

print(f_distance)
# dist[i][j] = distance between true object i and detected object j.
dist = np.zeros((len(true_table), len(detected_table)))
for i, true_gal in enumerate(true_table):
for j, detected_gal in enumerate(detected_table):
dist[i][j] = f_distance(true_gal, detected_gal)

# solve optimization problem.
# true_table[true_indx[i]] is matched with detected_table[detected_indx[i]]
Expand Down Expand Up @@ -510,6 +531,7 @@ def compute_metrics( # noqa: C901
meas_band_num=0,
target_meas={},
channels_last=False,
f_distance=distance_center,
):
"""Computes all requested metrics given information in a single batch from measure_generator.
Expand Down Expand Up @@ -543,6 +565,9 @@ def compute_metrics( # noqa: C901
be returned for both isolated and deblended images to compare.
channels_last (bool) : Indicates whether the images should be channels first (NCHW)
or channels last (NHWC).
f_distance (func): Function used to compute the distance between true and detected
galaxies. Takes as arguments the entries corresponding to the two galaxies.
By default the distance is the euclidean distance from center to center.
Returns:
results (dict) : Contains all the computed metrics. Entries are :
Expand All @@ -561,7 +586,8 @@ def compute_metrics( # noqa: C901
deblended_images = [np.moveaxis(im, -1, 1) for im in deblended_images]
results = {}
matches = [
get_detection_match(blend_list[i], detection_catalogs[i]) for i in range(len(blend_list))
get_detection_match(blend_list[i], detection_catalogs[i], f_distance)
for i in range(len(blend_list))
]
results["matches"] = matches

Expand Down Expand Up @@ -612,9 +638,7 @@ def compute_metrics( # noqa: C901
if len(blend) > 1:
dists = []
for g in blend:
dx = gal["x_peak"] - g["x_peak"]
dy = gal["y_peak"] - g["y_peak"]
dists.append(np.hypot(dx, dy))
dists.append(f_distance(gal, g))
row["distance_closest_galaxy"] = np.partition(dists, 1)[1]
else:
row["distance_closest_galaxy"] = -1 # placeholder
Expand All @@ -641,6 +665,7 @@ def __init__(
meas_band_num=0,
target_meas={},
noise_threshold_factor=3,
f_distance=distance_center,
):
"""Initialize metrics generator.
Expand All @@ -659,12 +684,16 @@ def __init__(
applied when getting segmentations from true images. A value of 3 would
correspond to a threshold of 3 sigmas (with sigma the standard deviation of
the noise)
f_distance (func): Function used to compute the distance between true and detected
galaxies. Takes as arguments the entries corresponding to the two galaxies.
By default the distance is the euclidean distance from center to center.
"""
self.measure_generator: MeasureGenerator = measure_generator
self.use_metrics = use_metrics
self.meas_band_num = meas_band_num
self.target_meas = target_meas
self.noise_threshold_factor = noise_threshold_factor
self.f_distance = f_distance

def __next__(self):
"""Returns metric results calculated on one batch."""
Expand Down Expand Up @@ -696,6 +725,7 @@ def __next__(self):
self.meas_band_num,
target_meas,
channels_last=self.measure_generator.channels_last,
f_distance=self.f_distance,
)
metrics_results[meas_func] = metrics_results_f

Expand Down
9,467 changes: 164 additions & 9,303 deletions notebooks/intro.ipynb

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import btk.metrics


def get_metrics_generator(meas_function, cpus=1, measure_kwargs=None):
def get_metrics_generator(
meas_function, cpus=1, f_distance=btk.metrics.distance_center, measure_kwargs=None
):
"""Returns draw generator with group sampling function"""

np.random.seed(0)
Expand Down Expand Up @@ -40,6 +42,7 @@ def get_metrics_generator(meas_function, cpus=1, measure_kwargs=None):
meas_generator,
use_metrics=("detection", "segmentation", "reconstruction"),
target_meas={"ellipticity": btk.metrics.meas_ksb_ellipticity},
f_distance=f_distance,
)
return metrics_generator

Expand Down

0 comments on commit d198d09

Please sign in to comment.