From 55ec673ff54dd37c968962e9761fb42431cf7689 Mon Sep 17 00:00:00 2001 From: Ismael Mendoza Date: Tue, 2 Apr 2024 14:42:41 -0400 Subject: [PATCH] catch errors early if deblended catalog is too large, to avoid unhelpful errors for users --- btk/deblend.py | 48 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 33 insertions(+), 15 deletions(-) diff --git a/btk/deblend.py b/btk/deblend.py index 7c876ea5..2d714bb3 100644 --- a/btk/deblend.py +++ b/btk/deblend.py @@ -185,7 +185,7 @@ class PeakLocalMax(Deblender): """This class detects centroids with `skimage.feature.peak_local_max`. The function performs detection and deblending of the sources based on the provided - band index. If use_mean feature is used, then the measurement function is using + band index. If use_mean feature is used, then the Deblender will use the average of all the bands. """ @@ -197,14 +197,14 @@ def __init__( use_mean: bool = False, use_band: Optional[int] = None, ) -> None: - """Initializes measurement class. Exactly one of 'use_mean' or 'use_band' must be specified. + """Initializes Deblender class. Exactly one of 'use_mean' or 'use_band' must be specified. Args: max_n_sources: See parent class. threshold_scale: Minimum intensity of peaks. min_distance: Minimum distance in pixels between two peaks. - use_mean: Flag to use the band average for the measurement. - use_band: Integer index of the band to use for the measurement. + use_mean: Flag to use the band average for deblending. + use_band: Integer index of the band to use for deblending """ super().__init__(max_n_sources) self.min_distance = min_distance @@ -218,7 +218,7 @@ def __init__( self.use_band = use_band def deblend(self, ii: int, blend_batch: BlendBatch) -> DeblendExample: - """Performs measurement on the ii-th example from the batch.""" + """Performs deblending on the ii-th example from the batch.""" blend_image = blend_batch.blend_images[ii] image = np.mean(blend_image, axis=0) if self.use_mean else blend_image[self.use_band] @@ -240,6 +240,12 @@ def deblend(self, ii: int, blend_batch: BlendBatch) -> DeblendExample: catalog["ra"], catalog["dec"] = ra, dec catalog["x_peak"], catalog["y_peak"] = x, y + if len(catalog) > self.max_n_sources: + raise ValueError( + "`PeakLocalMax` detected more sources than `max_n_sources`. Consider increasing" + "`threshold_scale` or `max_n_sources`." + ) + return DeblendExample(self.max_n_sources, catalog) @@ -261,7 +267,7 @@ def __init__( use_mean: bool = False, use_band: Optional[int] = None, ) -> None: - """Initializes measurement class. Exactly one of 'use_mean' or 'use_band' must be specified. + """Initializes Deblender class. Exactly one of 'use_mean' or 'use_band' must be specified. Args: max_n_sources: See parent class. @@ -270,8 +276,8 @@ def __init__( will be `thresh * err[j, i]` where `err` is set to the global rms of the background measured by SEP. min_area: Minimum number of pixels required for an object. Default is 5. - use_mean: Flag to use the band average for the measurement - use_band: Integer index of the band to use for the measurement + use_mean: Flag to use the band average for deblending. + use_band: Integer index of the band to use for deblending. """ super().__init__(max_n_sources) if use_band is None and not use_mean: @@ -284,7 +290,7 @@ def __init__( self.min_area = min_area def deblend(self, ii: int, blend_batch: BlendBatch) -> DeblendExample: - """Performs measurement on the i-th example from the batch.""" + """Performs deblending on the i-th example from the batch.""" # get a 1-channel input for sep blend_image = blend_batch.blend_images[ii] image = np.mean(blend_image, axis=0) if self.use_mean else blend_image[self.use_band] @@ -299,6 +305,12 @@ def deblend(self, ii: int, blend_batch: BlendBatch) -> DeblendExample: minarea=self.min_area, ) + if len(catalog) > self.max_n_sources: + raise ValueError( + "SEP predicted more sources than `max_n_sources`. Consider increasing `thresh`" + " or `max_n_sources`." + ) + segmentation_exp = np.zeros((self.max_n_sources, *image.shape), dtype=bool) deblended_images = np.zeros((self.max_n_sources, *image.shape), dtype=image.dtype) n_objects = len(catalog) @@ -339,7 +351,7 @@ class SepMultiband(Deblender): """ def __init__(self, max_n_sources: int, matching_threshold: float = 1.0, thresh: float = 1.5): - """Initialize the SepMultiband measurement function. + """Initialize the SepMultiband Deblender. Args: max_n_sources: See parent class. @@ -351,7 +363,7 @@ def __init__(self, max_n_sources: int, matching_threshold: float = 1.0, thresh: self.thresh = thresh def deblend(self, ii: int, blend_batch: BlendBatch) -> DeblendExample: - """Performs measurement on the ii-th example from the batch.""" + """Performs deblending on the ii-th example from the batch.""" # run source extractor on the first band wcs = blend_batch.wcs image = blend_batch.blend_images[ii] @@ -361,6 +373,12 @@ def deblend(self, ii: int, blend_batch: BlendBatch) -> DeblendExample: ra_coordinates *= 3600 dec_coordinates *= 3600 + if len(catalog) > self.max_n_sources: + raise ValueError( + "SEP predicted more sources than `max_n_sources`. Consider increasing `thresh`" + " or `max_n_sources`." + ) + # iterate over remaining bands and match predictions using KdTree for band in range(1, image.shape[0]): # run source extractor @@ -449,7 +467,7 @@ def __init__( def deblend( self, ii: int, blend_batch: BlendBatch, reference_catalogs: Table = None ) -> DeblendExample: - """Performs measurement on the ii-th example from the batch. + """Performs deblending on the ii-th example from the batch. Args: ii: The index of the example in the batch. @@ -556,14 +574,14 @@ def __init__( njobs: int = 1, verbose: bool = False, ): - """Initialize measurement generator. + """Initialize deblender generator. Args: deblenders: Deblender or a list of Deblender that will be used on the outputs of the draw_blend_generator. draw_blend_generator: Instance of subclasses of `DrawBlendsGenerator`. njobs: The number of parallel processes to run [Default: 1]. - verbose: Whether to print information about measurement. + verbose: Whether to print information about deblending. """ self.deblenders = self._validate_deblenders(deblenders) self.deblender_names = self._get_unique_deblender_names() @@ -615,7 +633,7 @@ def _get_unique_deblender_names(self) -> List[str]: return deblender_names def __next__(self) -> Tuple[BlendBatch, Dict[str, DeblendBatch]]: - """Return measurement results on a single batch from the draw_blend_generator. + """Return deblending results on a single batch from the draw_blend_generator. Returns: blend_batch: draw_blend_generator output from its `__next__` method.