Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sep measure function reworked #338

Merged
merged 9 commits into from
Aug 3, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 107 additions & 9 deletions btk/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def measure_function(batch, idx, **kwargs):
import astropy.table
import numpy as np
import sep
from astropy import units
from astropy.coordinates import SkyCoord
from skimage.feature import peak_local_max

from btk.multiprocess import multiprocess
Expand Down Expand Up @@ -135,18 +137,107 @@ def basic_measure(
return {"catalog": catalog}


def sep_measure(
def sep_multiband_measure(
batch,
idx,
channels_last=False,
surveys=None,
matching_threshold=1.0,
sigma_noise=1.5,
is_multiresolution=False,
**kwargs,
):
"""Return detection, segmentation and deblending information with SEP.
"""Returns centers detected with source extractor by combining predictions in different bands.
ismael-mendoza marked this conversation as resolved.
Show resolved Hide resolved

NOTE: If this function is used with the multiresolution feature,
measurements will be carried on the first survey.
ismael-mendoza marked this conversation as resolved.
Show resolved Hide resolved

Args:
batch (dict): Output of DrawBlendsGenerator object's `__next__` method.
idx (int): Index number of blend scene in the batch to preform
measurement on.
sigma_noise (float): Sigma threshold for detection against noise.
matching_threshold (float): Match centers of objects that are closer than
this threshold to a single prediction.
ismael-mendoza marked this conversation as resolved.
Show resolved Hide resolved

Returns:
dict containing catalog with entries corresponding to measured peaks.
"""
channel_indx = 0 if not channels_last else -1
# multiresolution
if is_multiresolution:
if surveys is None:
raise ValueError("surveys are required in order to use the MR feature.")
survey_name = surveys[0].name
image = batch["blend_images"][survey_name][idx]
wcs = batch["wcs"][survey_name]

# single-survey
else:
image = batch["blend_images"][idx]
wcs = batch["wcs"]

# run source extractor on the first band
band_image = image[0] if channel_indx == 0 else image[:, :, 0]
bkg = sep.Background(band_image)
catalog = sep.extract(band_image, sigma_noise, err=bkg.globalrms, segmentation_map=False)

# convert predictions to arcseconds
ra_coordinates, dec_coordinates = wcs.pixel_to_world_values(catalog["x"], catalog["y"])
ra_coordinates *= 3600
dec_coordinates *= 3600

# iterate over remaining bands and match predictions using KdTree
for band in range(1, image.shape[channel_indx]):
# run source extractor
band_image = image[band] if channel_indx == 0 else image[:, :, band]
bkg = sep.Background(band_image)
catalog = sep.extract(band_image, sigma_noise, err=bkg.globalrms, segmentation_map=False)

# convert predictions to arcseconds
ra_detections, dec_detections = wcs.pixel_to_world_values(catalog["x"], catalog["y"])
ra_detections *= 3600
dec_detections *= 3600

# convert to sky coordinates
c1 = SkyCoord(ra=ra_detections * units.arcsec, dec=dec_detections * units.arcsec)
c2 = SkyCoord(ra=ra_coordinates * units.arcsec, dec=dec_coordinates * units.arcsec)

# add new coordinates
if len(c1) > 0 and len(c2) > 0:
idx, d2d, d3d = c1.match_to_catalog_sky(c2)
ismael-mendoza marked this conversation as resolved.
Show resolved Hide resolved
d2d = d2d.arcsec

ra_coordinates = np.concatenate(
[ra_coordinates, ra_detections[d2d > matching_threshold]]
)
dec_coordinates = np.concatenate(
[dec_coordinates, dec_detections[d2d > matching_threshold]]
)
ismael-mendoza marked this conversation as resolved.
Show resolved Hide resolved
else:
ra_coordinates = np.concatenate([ra_coordinates, ra_detections])
dec_coordinates = np.concatenate([dec_coordinates, dec_detections])

# Wrap in the astropy table
t = astropy.table.Table()
t["ra"] = ra_coordinates
t["dec"] = dec_coordinates

return {"catalog": t}


def sep_singleband_measure(
batch,
idx,
meas_band_num=3,
channels_last=False,
surveys=None,
sigma_noise=1.5,
is_multiresolution=False,
**kwargs,
):
"""Return detection, segmentation and deblending information running SEP on a single band.

For each potentially multi-band image, an average over the bands is taken before measurement.
NOTE: If this function is used with the multiresolution feature,
measurements will be carried on the first survey, and deblended images
or segmentations will not be returned.
Expand All @@ -155,6 +246,7 @@ def sep_measure(
batch (dict): Output of DrawBlendsGenerator object's `__next__` method.
idx (int): Index number of blend scene in the batch to preform
measurement on.
meas_band_num (int) – Indicates the index of band to use fo the measurement
sigma_noise (float): Sigma threshold for detection against noise.

Returns:
Expand All @@ -168,21 +260,22 @@ def sep_measure(
raise ValueError("surveys are required in order to use the MR feature.")
survey_name = surveys[0].name
image = batch["blend_images"][survey_name][idx]
avg_image = np.mean(image, axis=channel_indx)
wcs = batch["wcs"][survey_name]

# single-survey
else:
image = batch["blend_images"][idx]
avg_image = np.mean(image, axis=channel_indx)
wcs = batch["wcs"]

stamp_size = avg_image.shape[0]
bkg = sep.Background(avg_image)
# run source extractor
band_image = image[meas_band_num] if channel_indx == 0 else image[:, :, meas_band_num]
stamp_size = band_image.shape[0]
bkg = sep.Background(band_image)
catalog, segmentation = sep.extract(
avg_image, sigma_noise, err=bkg.globalrms, segmentation_map=True
band_image, sigma_noise, err=bkg.globalrms, segmentation_map=True
)

# reshape segmentation map
n_objects = len(catalog)
segmentation_exp = np.zeros((n_objects, stamp_size, stamp_size), dtype=bool)
deblended_images = np.zeros((n_objects, *image.shape), dtype=image.dtype)
Expand All @@ -195,6 +288,7 @@ def sep_measure(
seg_i_reshaped = np.moveaxis(seg_i_reshaped, 0, np.argmin(image.shape))
deblended_images[i] = image * seg_i_reshaped

# wrap results in astropy table
t = astropy.table.Table()
t["ra"], t["dec"] = wcs.pixel_to_world_values(catalog["x"], catalog["y"])
t["ra"] *= 3600
Expand Down Expand Up @@ -430,4 +524,8 @@ def __next__(self):
return blend_output, measure_results


available_measure_functions = {"basic": basic_measure, "sep": sep_measure}
available_measure_functions = {
"basic": basic_measure,
"sep_singleband_measure": sep_singleband_measure,
"sep_multiband_measure": sep_multiband_measure,
}
116 changes: 82 additions & 34 deletions tests/test_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,42 @@

from btk.catalog import CatsimCatalog
from btk.draw_blends import CatsimGenerator
from btk.measure import MeasureGenerator, basic_measure, sep_measure
from btk.measure import (
MeasureGenerator,
basic_measure,
sep_multiband_measure,
sep_singleband_measure,
)
from btk.sampling_functions import DefaultSampling
from btk.survey import get_surveys

TEST_SEED = 0


def get_meas_results(meas_function, cpus=1, measure_kwargs=None):
"""Returns draw generator with group sampling function"""

"""Runs a measurement function on a set of data -- returns targets and predicitons"""
catalog_name = data_dir / "sample_input_catalog.fits"
stamp_size = 24
survey = get_surveys("LSST")
shifts = [[-0.3, 1.2]]
indexes = [[1]]
shifts = [
ismael-mendoza marked this conversation as resolved.
Show resolved Hide resolved
[-0.3, 1.2],
[-1.6, -1.7],
[-1.1, -2.1],
[1.4, 1.8],
[-1.8, -0.8],
[-0.6, 2.2],
[-2.0, -0.7],
[-2.2, 1.9],
[1.1, -1.5],
[0.1, -2.3],
[-2.3, 1.9],
[0.4, -1.9],
[2.0, -2.0],
[2.0, 0.1],
[0.2, 2.4],
[-1.8, -2.0],
]
indexes = [[4], [5], [9], [1], [9], [2], [0], [2], [3], [8], [0], [7], [10], [2], [0], [10]]
catalog = CatsimCatalog.from_file(catalog_name)
draw_blend_generator = CatsimGenerator(
catalog,
Expand All @@ -26,41 +47,54 @@ def get_meas_results(meas_function, cpus=1, measure_kwargs=None):
shifts=shifts,
indexes=indexes,
stamp_size=stamp_size,
batch_size=1,
seed=TEST_SEED,
)
meas_generator = MeasureGenerator(
meas_function, draw_blend_generator, cpus=cpus, measure_kwargs=measure_kwargs
)
blend_results, results = next(meas_generator)
wcs = blend_results["wcs"]
x, y = wcs.world_to_pixel_values(shifts[0][0] / 3600, shifts[0][1] / 3600)
target = np.array([x.item(), y.item()])
target = np.array(
[[blend["x_peak"].item(), blend["y_peak"].item()] for blend in blend_results["blend_list"]]
)
return target, results


def compare_sep():
"""Test detection with sep"""
target, results = get_meas_results(sep_measure, measure_kwargs=[{"sigma_noise": 2.0}])
x_peak, y_peak = (
results["catalog"]["sep_measure"][0]["x_peak"].item(),
results["catalog"]["sep_measure"][0]["y_peak"].item(),
"""Test sep detection using single band and multiband"""
target, results = get_meas_results(
[sep_singleband_measure, sep_multiband_measure], measure_kwargs=[{"sigma_noise": 1.5}]
)
detected_centers = np.array([x_peak, y_peak])
dist = np.max(np.abs(detected_centers - target))
np.testing.assert_array_less(dist, 1.0)
detected_sources = {"sep_singleband_measure": 0, "sep_multiband_measure": 0}
for meas_function in detected_sources.keys():
for i, blend in enumerate(results["catalog"][meas_function]):
if len(blend) > 0:
detected_centers = np.array([blend[0]["x_peak"].item(), blend[0]["y_peak"].item()])
dist = np.max(np.abs(detected_centers - target[i]))
np.testing.assert_array_less(dist, 1.5)
detected_sources[meas_function] += 1

assert detected_sources["sep_multiband_measure"] >= 0.5 * len(target)
assert detected_sources["sep_singleband_measure"] >= 0.5 * len(target)
assert detected_sources["sep_multiband_measure"] >= detected_sources["sep_singleband_measure"]


def compare_sep_multiprocessing():
"""Test detection with sep"""
target, results = get_meas_results(sep_measure, measure_kwargs=[{"sigma_noise": 2.0}], cpus=4)
x_peak, y_peak = (
results["catalog"]["sep_measure"][0]["x_peak"].item(),
results["catalog"]["sep_measure"][0]["y_peak"].item(),
"""Test sep dettection using single band and multiband with multiprocessing"""
target, results = get_meas_results(
[sep_singleband_measure, sep_multiband_measure], measure_kwargs=[{"sigma_noise": 1.5}]
)
detected_centers = np.array([x_peak, y_peak])
dist = np.max(np.abs(detected_centers - target))
np.testing.assert_array_less(dist, 1.0)
detected_sources = {"sep_singleband_measure": 0, "sep_multiband_measure": 0}
for meas_function in detected_sources.keys():
ismael-mendoza marked this conversation as resolved.
Show resolved Hide resolved
for i, blend in enumerate(results["catalog"][meas_function]):
if len(blend) > 0:
detected_centers = np.array([blend[0]["x_peak"].item(), blend[0]["y_peak"].item()])
dist = np.max(np.abs(detected_centers - target[i]))
np.testing.assert_array_less(dist, 1.5)
detected_sources[meas_function] += 1

assert detected_sources["sep_multiband_measure"] >= 0.5 * len(target)
assert detected_sources["sep_singleband_measure"] >= 0.5 * len(target)
assert detected_sources["sep_multiband_measure"] >= detected_sources["sep_singleband_measure"]


def test_algorithms():
Expand All @@ -71,14 +105,28 @@ def test_algorithms():


def test_measure_kwargs():
"""Test detection with sep"""
"""Test measure kwargs parameters for sep"""
target, results = get_meas_results(
sep_measure, measure_kwargs=[{"sigma_noise": 2.0}, {"sigma_noise": 3.0}]
[sep_singleband_measure, sep_multiband_measure],
measure_kwargs=[{"sigma_noise": 1.5}, {"sigma_noise": 2.0}],
)
x_peak, y_peak = (
results["catalog"]["sep_measure0"][0]["x_peak"].item(),
results["catalog"]["sep_measure0"][0]["y_peak"].item(),
)
detected_centers = np.array([x_peak, y_peak])
dist = np.max(np.abs(detected_centers - target))
np.testing.assert_array_less(dist, 1.0)
detected_sources = {}
for meas_function in [
"sep_singleband_measure0",
"sep_singleband_measure1",
"sep_multiband_measure0",
"sep_multiband_measure1",
]:
assert meas_function in results["catalog"].keys()
for i, blend in enumerate(results["catalog"][meas_function]):
if len(blend) > 0:
detected_centers = np.array([blend[0]["x_peak"].item(), blend[0]["y_peak"].item()])
dist = np.max(np.abs(detected_centers - target[i]))
np.testing.assert_array_less(dist, 1.5)
if meas_function in detected_sources.keys():
detected_sources[meas_function] += 1
else:
detected_sources[meas_function] = 1

assert detected_sources["sep_multiband_measure0"] >= detected_sources["sep_singleband_measure0"]
assert detected_sources["sep_multiband_measure1"] >= detected_sources["sep_singleband_measure1"]
21 changes: 11 additions & 10 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

import matplotlib.pyplot as plt
import numpy as np
from conftest import data_dir

import btk.plot_utils as plot_utils
from btk.catalog import CatsimCatalog
from btk.draw_blends import CatsimGenerator
from btk.measure import MeasureGenerator, sep_measure
from btk.measure import MeasureGenerator, sep_singleband_measure
from btk.metrics import (
MetricsGenerator,
auc,
Expand All @@ -25,7 +26,7 @@ def get_metrics_generator(
measure_kwargs=None,
):
"""Returns draw generator with group sampling function"""
catalog_name = "data/sample_input_catalog.fits"
catalog_name = data_dir / "sample_input_catalog.fits"
stamp_size = 24
survey = get_surveys("LSST")
shifts = [
Expand Down Expand Up @@ -61,9 +62,9 @@ def get_metrics_generator(

@patch("btk.plot_utils.plt.show")
def test_sep_metrics(mock_show):
metrics_generator = get_metrics_generator(sep_measure)
metrics_generator = get_metrics_generator(sep_singleband_measure)
blend_results, meas_results, metrics_results = next(metrics_generator)
gal_summary = metrics_results["galaxy_summary"]["sep_measure"]
gal_summary = metrics_results["galaxy_summary"]["sep_singleband_measure"]
gal_summary = gal_summary[gal_summary["detected"] == True] # noqa: E712
msr = gal_summary["msr"]
dist = gal_summary["distance_closest_galaxy"]
Expand All @@ -88,9 +89,9 @@ def test_sep_metrics(mock_show):
blend_results["blend_images"],
blend_results["isolated_images"],
blend_results["blend_list"],
meas_results["catalog"]["sep_measure"],
meas_results["deblended_images"]["sep_measure"],
metrics_results["matches"]["sep_measure"],
meas_results["catalog"]["sep_singleband_measure"],
meas_results["deblended_images"]["sep_singleband_measure"],
metrics_results["matches"]["sep_singleband_measure"],
indexes=list(range(5)),
band_indices=[1, 2, 3],
)
Expand All @@ -101,11 +102,11 @@ def test_sep_metrics(mock_show):
def test_measure_kwargs(mock_show):
"""Test detection with sep"""
metrics_generator = get_metrics_generator(
sep_measure, measure_kwargs=[{"sigma_noise": 2.0}, {"sigma_noise": 3.0}]
sep_singleband_measure, measure_kwargs=[{"sigma_noise": 2.0}, {"sigma_noise": 3.0}]
)
_, _, results = next(metrics_generator)
average_precision = auc(results, "sep_measure", 2, plot=True)
assert average_precision == 0.3125
average_precision = auc(results, "sep_singleband_measure", 2, plot=True)
assert average_precision == 0.25


def test_detection_eff_matrix():
Expand Down
Loading