Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sep measure function reworked #338

Merged
merged 9 commits into from
Aug 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 128 additions & 12 deletions btk/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def measure_function(batch, idx, **kwargs):
import astropy.table
import numpy as np
import sep
from astropy import units
from astropy.coordinates import SkyCoord
from skimage.feature import peak_local_max

from btk.multiprocess import multiprocess
Expand Down Expand Up @@ -135,18 +137,120 @@ def basic_measure(
return {"catalog": catalog}


def sep_measure(
def sep_multiband_measure(
batch,
idx,
channels_last=False,
surveys=None,
matching_threshold=1.0,
sigma_noise=1.5,
is_multiresolution=False,
**kwargs,
):
"""Return detection, segmentation and deblending information with SEP.
"""Returns centers detected with source extractor by combining predictions in different bands.
ismael-mendoza marked this conversation as resolved.
Show resolved Hide resolved

For each band in the input image we run sep for detection and append new detections to a running
list of detected coordinates. In order to avoid repeating detections, we run a KD-Tree algorithm
to calculate the angular distance between each new coordinate and its closest neighbour. Then we
discard those new coordinates that were closer than matching_threshold to any one of already
detected coordinates.

NOTE: If this function is used with the multiresolution feature,
measurements will be carried on the first survey.
ismael-mendoza marked this conversation as resolved.
Show resolved Hide resolved

Args:
batch (dict): Output of DrawBlendsGenerator object's `__next__` method.
idx (int): Index number of blend scene in the batch to preform
measurement on.
sigma_noise (float): Sigma threshold for detection against noise.
matching_threshold (float): Match centers of objects that are closer than
this threshold to a single prediction (in arseconds).

Returns:
dict containing catalog with entries corresponding to measured peaks.
"""
channel_indx = 0 if not channels_last else -1
# multiresolution
if is_multiresolution:
if surveys is None:
raise ValueError("surveys are required in order to use the MR feature.")
survey_name = surveys[0].name
image = batch["blend_images"][survey_name][idx]
wcs = batch["wcs"][survey_name]

# single-survey
else:
image = batch["blend_images"][idx]
wcs = batch["wcs"]

# run source extractor on the first band
band_image = image[0] if channel_indx == 0 else image[:, :, 0]
bkg = sep.Background(band_image)
catalog = sep.extract(band_image, sigma_noise, err=bkg.globalrms, segmentation_map=False)

# convert predictions to arcseconds
ra_coordinates, dec_coordinates = wcs.pixel_to_world_values(catalog["x"], catalog["y"])
ra_coordinates *= 3600
dec_coordinates *= 3600

# iterate over remaining bands and match predictions using KdTree
for band in range(1, image.shape[channel_indx]):
# run source extractor
band_image = image[band] if channel_indx == 0 else image[:, :, band]
bkg = sep.Background(band_image)
catalog = sep.extract(band_image, sigma_noise, err=bkg.globalrms, segmentation_map=False)

# convert predictions to arcseconds
ra_detections, dec_detections = wcs.pixel_to_world_values(catalog["x"], catalog["y"])
ra_detections *= 3600
dec_detections *= 3600

# convert to sky coordinates
c1 = SkyCoord(ra=ra_detections * units.arcsec, dec=dec_detections * units.arcsec)
c2 = SkyCoord(ra=ra_coordinates * units.arcsec, dec=dec_coordinates * units.arcsec)

# merge new detections with the running list of coordinates
if len(c1) > 0 and len(c2) > 0:
# runs KD-tree to get distances to the closest neighbours
idx, distance2d, _ = c1.match_to_catalog_sky(c2)
distance2d = distance2d.arcsec

# add new predictions, masking those that are closer than threshold
ra_coordinates = np.concatenate(
[ra_coordinates, ra_detections[distance2d > matching_threshold]]
)
dec_coordinates = np.concatenate(
[dec_coordinates, dec_detections[distance2d > matching_threshold]]
)
else:
ra_coordinates = np.concatenate([ra_coordinates, ra_detections])
dec_coordinates = np.concatenate([dec_coordinates, dec_detections])

# Wrap in the astropy table
t = astropy.table.Table()
t["ra"] = ra_coordinates
t["dec"] = dec_coordinates

return {"catalog": t}


def sep_singleband_measure(
batch,
idx,
meas_band_num=3,
use_mean=False,
channels_last=False,
surveys=None,
sigma_noise=1.5,
is_multiresolution=False,
**kwargs,
):
"""Return detection, segmentation and deblending information running SEP on a single band.

The function performs detection and deblending of the sources based on the provided
band index. If use_mean feature is used, then the measurement function is using
the average of all the bands.

For each potentially multi-band image, an average over the bands is taken before measurement.
NOTE: If this function is used with the multiresolution feature,
measurements will be carried on the first survey, and deblended images
or segmentations will not be returned.
Expand All @@ -155,6 +259,8 @@ def sep_measure(
batch (dict): Output of DrawBlendsGenerator object's `__next__` method.
idx (int): Index number of blend scene in the batch to preform
measurement on.
meas_band_num (int): Indicates the index of band to use fo the measurement
use_mean (bool): If True, then algorithm uses the average of all the bands
sigma_noise (float): Sigma threshold for detection against noise.

Returns:
Expand All @@ -168,21 +274,26 @@ def sep_measure(
raise ValueError("surveys are required in order to use the MR feature.")
survey_name = surveys[0].name
image = batch["blend_images"][survey_name][idx]
avg_image = np.mean(image, axis=channel_indx)
wcs = batch["wcs"][survey_name]

# single-survey
else:
image = batch["blend_images"][idx]
avg_image = np.mean(image, axis=channel_indx)
wcs = batch["wcs"]

stamp_size = avg_image.shape[0]
bkg = sep.Background(avg_image)
# get a 1-channel input for sep
if use_mean:
band_image = np.mean(image, axis=channel_indx)
else:
band_image = image[meas_band_num] if channel_indx == 0 else image[:, :, meas_band_num]

# run source extractor
stamp_size = band_image.shape[0]
bkg = sep.Background(band_image)
catalog, segmentation = sep.extract(
avg_image, sigma_noise, err=bkg.globalrms, segmentation_map=True
band_image, sigma_noise, err=bkg.globalrms, segmentation_map=True
)

# reshape segmentation map
n_objects = len(catalog)
segmentation_exp = np.zeros((n_objects, stamp_size, stamp_size), dtype=bool)
deblended_images = np.zeros((n_objects, *image.shape), dtype=image.dtype)
Expand All @@ -195,6 +306,7 @@ def sep_measure(
seg_i_reshaped = np.moveaxis(seg_i_reshaped, 0, np.argmin(image.shape))
deblended_images[i] = image * seg_i_reshaped

# wrap results in astropy table
t = astropy.table.Table()
t["ra"], t["dec"] = wcs.pixel_to_world_values(catalog["x"], catalog["y"])
t["ra"] *= 3600
Expand Down Expand Up @@ -409,12 +521,12 @@ def __next__(self):
if segmentation[key_name] is not None:
np.save(
os.path.join(self.save_path, key_name, "segmentation"),
segmentation[key_name],
np.array(segmentation[key_name], dtype=object),
)
if deblended_images[key_name] is not None:
np.save(
os.path.join(self.save_path, key_name, "deblended_images"),
deblended_images[key_name],
np.array(deblended_images[key_name], dtype=object),
)
for j, cat in enumerate(catalog[key_name]):
cat.write(
Expand All @@ -430,4 +542,8 @@ def __next__(self):
return blend_output, measure_results


available_measure_functions = {"basic": basic_measure, "sep": sep_measure}
available_measure_functions = {
"basic": basic_measure,
"sep_singleband_measure": sep_singleband_measure,
"sep_multiband_measure": sep_multiband_measure,
}
40 changes: 20 additions & 20 deletions notebooks/00-intro.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion notebooks/01a-cosmos_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.8.12"
},
"toc": {
"base_numbering": 1,
Expand Down
4 changes: 2 additions & 2 deletions notebooks/01b-scarlet-measure.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@
"source": [
"measure_kwargs=[{\"sigma_noise\": 2.0}]\n",
"meas_generator = btk.measure.MeasureGenerator(\n",
" [btk.measure.sep_measure,scarlet_measure], draw_blend_generator, measure_kwargs=measure_kwargs\n",
" [btk.measure.sep_singleband_measure,scarlet_measure], draw_blend_generator, measure_kwargs=measure_kwargs\n",
")\n",
"metrics_generator = btk.metrics.MetricsGenerator(\n",
" meas_generator,\n",
Expand Down Expand Up @@ -411,7 +411,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.5"
"version": "3.8.12"
},
"toc": {
"base_numbering": 1,
Expand Down
82 changes: 41 additions & 41 deletions notebooks/02a-multi-tutorial.ipynb

Large diffs are not rendered by default.

120 changes: 62 additions & 58 deletions notebooks/02b-custom-tutorial.ipynb

Large diffs are not rendered by default.

28 changes: 15 additions & 13 deletions notebooks/tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -801,13 +801,13 @@
{
"cell_type": "code",
"execution_count": null,
"id": "f61a8392",
"id": "a5b32bf1-1204-41de-a70f-8977831695bf",
"metadata": {},
"outputs": [],
"source": [
"import sep\n",
"def sep_measure(batch, idx, channels_last=False, sigma_noise=1.5, **kwargs):\n",
" \"\"\"Return detection, segmentation and deblending information with SEP.\n",
"def sep_singleband_measure(batch, idx, meas_band_num=2, channels_last=False, sigma_noise=1.5, **kwargs):\n",
" \"\"\"Return detection, segmentation and deblending information running SEP on a single band.\n",
"\n",
" NOTE: This is a simplified version of the actual function that does not \n",
" support multi-resolution. \n",
Expand All @@ -816,27 +816,27 @@
" batch (dict): Output of DrawBlendsGenerator object's `__next__` method.\n",
" idx (int): Index number of blend scene in the batch to preform\n",
" measurement on.\n",
" meas_band_num (int) – Indicates the index of band to use fo the measurement\n",
" sigma_noise (float): Sigma threshold for detection against noise.\n",
"\n",
" Returns:\n",
" dict with the centers of sources detected by SEP detection algorithm.\n",
" \"\"\"\n",
" channel_indx = 0 if not channels_last else -1\n",
"\n",
"\n",
" image = batch[\"blend_images\"][idx]\n",
" coadd = np.mean(image, axis=channel_indx)\n",
" # select band of an image\n",
" band_image = image[meas_band_num] if channel_indx == 0 else image[:, :, meas_band_num]\n",
" wcs = batch[\"wcs\"]\n",
"\n",
" stamp_size = coadd.shape[0]\n",
" bkg = sep.Background(coadd)\n",
" # run source extractor\n",
" stamp_size = band_image.shape[0]\n",
" bkg = sep.Background(band_image)\n",
" catalog, segmentation = sep.extract(\n",
" coadd, sigma_noise, err=bkg.globalrms, segmentation_map=True\n",
" band_image, sigma_noise, err=bkg.globalrms, segmentation_map=True\n",
" )\n",
"\n",
" # reshape segmentation map\n",
" n_objects = len(catalog)\n",
"\n",
" # organizing returned images into numpy arrays\n",
" segmentation_exp = np.zeros((n_objects, stamp_size, stamp_size), dtype=bool)\n",
" deblended_images = np.zeros((n_objects, *image.shape), dtype=image.dtype)\n",
" for i in range(n_objects):\n",
Expand All @@ -848,9 +848,11 @@
" seg_i_reshaped = np.moveaxis(seg_i_reshaped, 0, np.argmin(image.shape))\n",
" deblended_images[i] = image * seg_i_reshaped\n",
"\n",
" # translate from pixel to sky coordinates. \n",
" # wrap results in astropy table\n",
" t = astropy.table.Table()\n",
" t[\"ra\"], t[\"dec\"] = wcs.pixel_to_world_values(catalog[\"x\"], catalog[\"y\"])\n",
" t[\"ra\"] *= 3600\n",
" t[\"dec\"] *= 3600\n",
"\n",
" return {\n",
" \"catalog\": t,\n",
Expand Down Expand Up @@ -1017,7 +1019,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.8.12"
},
"toc": {
"base_numbering": 1,
Expand Down
Loading