Skip to content

Commit

Permalink
Multiresolution flag (#215)
Browse files Browse the repository at this point in the history
* Added target measures to custom tutorial

* Fix nbval issues

* Added multiresolution flag
  • Loading branch information
thuiop authored Jul 30, 2021
1 parent 5d8da24 commit aaff948
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 14 deletions.
3 changes: 2 additions & 1 deletion btk/draw_blends.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def __init__(
self.surveys = surveys
else:
raise TypeError("surveys must be a Survey object or a list of Survey objects.")
self.is_multiresolution = len(self.surveys) > 1

self.stamp_size = stamp_size
self.add_noise = add_noise
Expand Down Expand Up @@ -268,7 +269,7 @@ def __next__(self):
format="ascii",
overwrite=True,
)
if len(self.surveys) > 1:
if self.is_multiresolution:
output = {
"blend_images": blend_images,
"isolated_images": isolated_images,
Expand Down
24 changes: 18 additions & 6 deletions btk/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ def add_pixel_columns(catalog, wcs):
return catalog_t


def basic_measure(batch, idx, channels_last=False, surveys=None, **kwargs):
def basic_measure(
batch, idx, channels_last=False, surveys=None, is_multiresolution=False, **kwargs
):
"""Return centers detected with skimage.feature.peak_local_max.
NOTE: If this function is used with the multiresolution feature,
Expand All @@ -106,7 +108,7 @@ def basic_measure(batch, idx, channels_last=False, surveys=None, **kwargs):
channel_indx = 0 if not channels_last else -1

# multiresolution
if isinstance(batch["blend_images"], dict):
if is_multiresolution:
if surveys is None:
raise ValueError("surveys are required in order to use the MR feature.")
surveys = kwargs.get("surveys", None)
Expand All @@ -129,7 +131,15 @@ def basic_measure(batch, idx, channels_last=False, surveys=None, **kwargs):
return {"catalog": catalog}


def sep_measure(batch, idx, channels_last=False, surveys=None, sigma_noise=1.5, **kwargs):
def sep_measure(
batch,
idx,
channels_last=False,
surveys=None,
sigma_noise=1.5,
is_multiresolution=False,
**kwargs,
):
"""Return detection, segmentation and deblending information with SEP.
NOTE: If this function is used with the multiresolution feature,
Expand All @@ -148,7 +158,7 @@ def sep_measure(batch, idx, channels_last=False, surveys=None, sigma_noise=1.5,
channel_indx = 0 if not channels_last else -1

# multiresolution
if isinstance(batch["blend_images"], dict):
if is_multiresolution:
if surveys is None:
raise ValueError("surveys are required in order to use the MR feature.")
survey_name = surveys[0].name
Expand Down Expand Up @@ -184,7 +194,7 @@ def sep_measure(batch, idx, channels_last=False, surveys=None, sigma_noise=1.5,
t["ra"], t["dec"] = wcs.pixel_to_world_values(catalog["x"], catalog["y"])

# If multiresolution, return only the catalog
if isinstance(batch["blend_images"], dict):
if is_multiresolution:
return {"catalog": t}
else:
return {
Expand Down Expand Up @@ -241,6 +251,7 @@ def __init__(
self.batch_size = self.draw_blend_generator.batch_size
self.channels_last = self.draw_blend_generator.channels_last
self.surveys = self.draw_blend_generator.surveys
self.is_multiresolution = self.draw_blend_generator.is_multiresolution
self.verbose = verbose
self.save_path = save_path

Expand All @@ -249,6 +260,7 @@ def __init__(
for m in self.measure_kwargs:
m["channels_last"] = self.channels_last
m["surveys"] = self.surveys
m["is_multiresolution"] = self.is_multiresolution

def __iter__(self):
"""Return iterator which is the object itself."""
Expand Down Expand Up @@ -356,7 +368,7 @@ def __next__(self):
)
# If multiresolution, we reverse the order between the survey name and
# the index of the blend
if isinstance(blend_output["blend_list"], dict):
if self.is_multiresolution:
survey_keys = list(blend_output["blend_list"].keys())
# We duplicate the catalog for each survey to get the pixel coordinates
catalogs_temp = {}
Expand Down
3 changes: 2 additions & 1 deletion btk/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,7 @@ def __init__(
self.f_distance = f_distance
self.distance_threshold_match = distance_threshold_match
self.verbose = verbose
self.is_multiresolution = self.measure_generator.is_multiresolution

def __next__(self):
"""Returns metric results calculated on one batch."""
Expand All @@ -732,7 +733,7 @@ def __next__(self):

metrics_results = {}
for meas_func in measure_results["catalog"].keys():
if isinstance(blend_results["isolated_images"], dict):
if self.is_multiresolution:
metrics_results_f = {}
for i, surv in enumerate(blend_results["isolated_images"].keys()):
additional_params = {
Expand Down
10 changes: 4 additions & 6 deletions notebooks/scarlet-measure.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -88,30 +88,28 @@
},
"outputs": [],
"source": [
"def scarlet_measure(batch,idx,channels_last=False,**kwargs):\n",
"def scarlet_measure(batch,idx,channels_last=False, is_multiresolution=False,**kwargs):\n",
" \"\"\"Measure function for SCARLET\n",
" \"\"\"\n",
" sigma_noise = kwargs.get(\"sigma_noise\", 1.5) \n",
" surveys = kwargs.get(\"surveys\", None)\n",
" \n",
" #Fist we carry out the detection, using SExtractor (sep being the python implementation)\n",
" # We need to differentiate between the multiresolution and the regular case\n",
" if isinstance(batch[\"blend_images\"], dict):\n",
" if is_multiresolution:\n",
" survey_name = surveys[0].name\n",
" image = batch[\"blend_images\"][survey_name][idx]\n",
" # Put the image in the channels first format if not already the case\n",
" image = np.moveaxis(image,-1,0) if channels_last else image\n",
" coadd = np.mean(image, axis=0)\n",
" wcs_ref = batch[\"wcs\"][survey_name]\n",
" psf = np.array([p.drawImage(galsim.Image(image.shape[1],image.shape[2]),scale=surveys[0].pixel_scale).array for p in batch[\"psf\"][survey_name]])\n",
" multiresolution = True\n",
" else:\n",
" image = batch[\"blend_images\"][idx]\n",
" image = np.moveaxis(image,-1,0) if channels_last else image\n",
" coadd = np.mean(image, axis=0)\n",
" psf = np.array([p.drawImage(galsim.Image(image.shape[1],image.shape[2]),scale=surveys[0].pixel_scale).array for p in batch[\"psf\"]])\n",
" wcs_ref = batch[\"wcs\"]\n",
" multiresolution = False\n",
" stamp_size = coadd.shape[0]\n",
" \n",
" bkg = sep.Background(coadd)\n",
Expand All @@ -123,7 +121,7 @@
" if len(catalog) == 0:\n",
" t = astropy.table.Table()\n",
" t[\"ra\"], t[\"dec\"] = wcs_ref.pixel_to_world_values(catalog[\"x\"], catalog[\"y\"])\n",
" if multiresolution:\n",
" if is_multiresolution:\n",
" return {\"catalog\":t,\"segmentation\":None,\"deblended_images\":{s.name: np.array([np.zeros((len(s.filters),batch[\"blend_images\"][s.name][idx].shape[1],\n",
" batch[\"blend_images\"][s.name][idx].shape[1]))]) for s in surveys}}\n",
" else:\n",
Expand Down Expand Up @@ -200,7 +198,7 @@
" except AssertionError: #If the fitting fails\n",
" t = astropy.table.Table()\n",
" t[\"ra\"], t[\"dec\"] = wcs_ref.pixel_to_world_values(catalog[\"x\"], catalog[\"y\"])\n",
" if multiresolution:\n",
" if is_multiresolution:\n",
" deblended_images={s.name: np.array([np.zeros((len(s.filters),batch[\"blend_images\"][s.name][idx].shape[1],\n",
" batch[\"blend_images\"][s.name][idx].shape[1])) for c in catalog]) for s in surveys}\n",
" else:\n",
Expand Down

0 comments on commit aaff948

Please sign in to comment.