From 33ee0b4d8bb40508842044ef1e26508a722b6af1 Mon Sep 17 00:00:00 2001 From: fred3m Date: Fri, 20 Sep 2024 11:09:50 -0700 Subject: [PATCH 1/7] Implement improved wavelet detection --- pyproject.toml | 2 +- python/lsst/scarlet/lite/__init__.py | 2 +- python/lsst/scarlet/lite/bbox.py | 2 +- python/lsst/scarlet/lite/blend.py | 6 +- python/lsst/scarlet/lite/detect.py | 135 ++++++++++++-- python/lsst/scarlet/lite/detect_pybind11.cc | 185 ++++++++++++------- python/lsst/scarlet/lite/initialization.py | 57 +++--- python/lsst/scarlet/lite/models/fit_psf.py | 77 +++++--- python/lsst/scarlet/lite/models/free_form.py | 131 ++++++++++++- python/lsst/scarlet/lite/observation.py | 1 - python/lsst/scarlet/lite/parameters.py | 5 +- python/lsst/scarlet/lite/utils.py | 10 +- python/lsst/scarlet/lite/wavelet.py | 18 +- tests/test_detect.py | 65 ++++++- tests/test_io.py | 4 +- tests/test_models.py | 28 +-- tests/test_wavelet.py | 2 +- 17 files changed, 541 insertions(+), 189 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 99dd6532..0950cbad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ requires = [ "lsst-versions >= 1.3.0", "wheel", "pybind11 >= 2.5.0", - "numpy >= 1.18", + "numpy >= 1.18, <2.0.0", "peigen >= 0.0.9", ] build-backend = "setuptools.build_meta" diff --git a/python/lsst/scarlet/lite/__init__.py b/python/lsst/scarlet/lite/__init__.py index eda725ee..135d279d 100644 --- a/python/lsst/scarlet/lite/__init__.py +++ b/python/lsst/scarlet/lite/__init__.py @@ -10,7 +10,7 @@ except ImportError: pass -from . import initialization, io, measure, models, operators, utils +from . import initialization, io, measure, models, operators, utils, wavelet from .fft import * from .image import * from .observation import * diff --git a/python/lsst/scarlet/lite/bbox.py b/python/lsst/scarlet/lite/bbox.py index 63a35746..f8616731 100644 --- a/python/lsst/scarlet/lite/bbox.py +++ b/python/lsst/scarlet/lite/bbox.py @@ -228,7 +228,7 @@ def from_data(x: np.ndarray, threshold: float = 0) -> Box: nonzero = np.where(sel) bounds = [] for dim in range(len(x.shape)): - bounds.append((nonzero[dim].min(), nonzero[dim].max() + 1)) + bounds.append((int(nonzero[dim].min()), int(nonzero[dim].max() + 1))) else: bounds = [(0, 0)] * len(x.shape) return Box.from_bounds(*bounds) diff --git a/python/lsst/scarlet/lite/blend.py b/python/lsst/scarlet/lite/blend.py index 9e9e4537..23355108 100644 --- a/python/lsst/scarlet/lite/blend.py +++ b/python/lsst/scarlet/lite/blend.py @@ -113,7 +113,7 @@ def get_model(self, convolve: bool = False, use_flux: bool = False) -> Image: return self.observation.convolve(model) return model - def _grad_log_likelihood(self) -> Image: + def _grad_log_likelihood(self) -> tuple[Image, np.ndarray]: """Gradient of the likelihood wrt the unconvolved model""" model = self.get_model(convolve=True) # Update the loss @@ -121,7 +121,7 @@ def _grad_log_likelihood(self) -> Image: # Calculate the gradient wrt the model d(logL)/d(model) result = self.observation.weights * (model - self.observation.images) result = self.observation.convolve(result, grad=True) - return result + return result, model.data @property def log_likelihood(self) -> float: @@ -244,7 +244,7 @@ def fit( # Update each component given the current gradient for component in self.components: overlap = component.bbox & self.bbox - component.update(self.it, grad_log_likelihood[overlap].data) + component.update(self.it, grad_log_likelihood[0][overlap].data) # Check to see if any components need to be resized if do_resize: component.resize(self.bbox) diff --git a/python/lsst/scarlet/lite/detect.py b/python/lsst/scarlet/lite/detect.py index a5bac623..6d4f0e6a 100644 --- a/python/lsst/scarlet/lite/detect.py +++ b/python/lsst/scarlet/lite/detect.py @@ -22,15 +22,20 @@ from __future__ import annotations import logging -from typing import Sequence, cast +from typing import Sequence import numpy as np from lsst.scarlet.lite.detect_pybind11 import Footprint # type: ignore -from .bbox import Box +from .bbox import Box, overlapped_slices from .image import Image from .utils import continue_class -from .wavelet import get_multiresolution_support, starlet_transform +from .wavelet import ( + get_multiresolution_support, + get_starlet_scales, + multiband_starlet_reconstruction, + starlet_transform, +) logger = logging.getLogger("scarlet.detect") @@ -111,30 +116,34 @@ def union(self, other: Footprint) -> Image | None: return footprint1 | footprint2 -def footprints_to_image(footprints: Sequence[Footprint], shape: tuple[int, int]) -> Image: +def footprints_to_image(footprints: Sequence[Footprint], bbox: Box) -> Image: """Convert a set of scarlet footprints to a pixelized image. Parameters ---------- footprints: The footprints to convert into an image. - shape: - The shape of the image that is created from the footprints. + box: + The full box of the image that will contain the footprints. Returns ------- result: The image created from the footprints. """ - result = Image.from_box(Box(shape), dtype=int) + result = Image.from_box(bbox, dtype=int) for k, footprint in enumerate(footprints): - bbox = bounds_to_bbox(footprint.bounds) - fp_image = Image(footprint.data, yx0=cast(tuple[int, int], bbox.origin)) - result = result + fp_image * (k + 1) + slices = overlapped_slices(result.bbox, footprint.bbox) + result.data[slices[0]] += footprint.data[slices[1]] * (k + 1) return result -def get_wavelets(images: np.ndarray, variance: np.ndarray, scales: int | None = None) -> np.ndarray: +def get_wavelets( + images: np.ndarray, + variance: np.ndarray, + scales: int | None = None, + generation: int = 2, +) -> np.ndarray: """Calculate wavelet coefficents given a set of images and their variances Parameters @@ -157,9 +166,10 @@ def get_wavelets(images: np.ndarray, variance: np.ndarray, scales: int | None = """ sigma = np.median(np.sqrt(variance), axis=(1, 2)) # Create the wavelet coefficients for the significant pixels - coeffs = [] + scales = get_starlet_scales(images[0].shape, scales) + coeffs = np.empty((scales + 1,) + images.shape, dtype=images.dtype) for b, image in enumerate(images): - _coeffs = starlet_transform(image, scales=scales) + _coeffs = starlet_transform(image, scales=scales, generation=generation) support = get_multiresolution_support( image=image, starlets=_coeffs, @@ -168,8 +178,8 @@ def get_wavelets(images: np.ndarray, variance: np.ndarray, scales: int | None = epsilon=1e-1, max_iter=20, ) - coeffs.append((support * _coeffs).astype(images.dtype)) - return np.array(coeffs) + coeffs[:, b] = (support.support * _coeffs).astype(images.dtype) + return coeffs def get_detect_wavelets(images: np.ndarray, variance: np.ndarray, scales: int = 3) -> np.ndarray: @@ -206,4 +216,97 @@ def get_detect_wavelets(images: np.ndarray, variance: np.ndarray, scales: int = epsilon=1e-1, max_iter=20, ) - return (support * _coeffs).astype(images.dtype) + return (support.support * _coeffs).astype(images.dtype) + + +def detect_footprints( + images: np.ndarray, + variance: np.ndarray, + scales: int = 2, + generation: int = 2, + origin: tuple[int, int] | None = None, + min_separation: float = 4, + min_area: int = 4, + peak_thresh: float = 5, + footprint_thresh: float = 5, + find_peaks: bool = True, + remove_high_freq: bool = True, + min_pixel_detect: int = 1, +) -> list[Footprint]: + """Detect footprints in an image + + Parameters + ---------- + images: + The array of images with shape `(bands, Ny, Nx)` for which to + calculate wavelet coefficients. + variance: + An array of variances with the same shape as `images`. + scales: + The maximum number of wavelet scales to use. + If `remove_high_freq` is `False`, then this argument is ignored. + generation: + The generation of the starlet transform to use. + If `remove_high_freq` is `False`, then this argument is ignored. + origin: + The location (y, x) of the lower corner of the image. + min_separation: + The minimum separation between peaks in pixels. + min_area: + The minimum area of a footprint in pixels. + peak_thresh: + The threshold for peak detection. + footprint_thresh: + The threshold for footprint detection. + find_peaks: + If `True`, then detect peaks in the detection image, + otherwise only the footprints are returned. + remove_high_freq: + If `True`, then remove high frequency wavelet coefficients + before detecting peaks. + min_pixel_detect: + The minimum number of bands that must be above the + detection threshold for a pixel to be included in a footprint. + """ + from lsst.scarlet.lite.detect_pybind11 import get_footprints + + if origin is None: + origin = (0, 0) + if remove_high_freq: + # Build the wavelet coefficients + wavelets = get_wavelets( + images, + variance, + scales=scales, + generation=generation, + ) + # Remove the high frequency wavelets. + # This has the effect of preventing high frequency noise + # from interfering with the detection of peak positions. + wavelets[0] = 0 + # Reconstruct the image from the remaining wavelet coefficients + _images = multiband_starlet_reconstruction( + wavelets, + generation=generation, + ) + else: + _images = images + # Build a SNR weighted detection image + sigma = np.median(np.sqrt(variance), axis=(1, 2)) / 2 + detection = np.sum(_images / sigma[:, None, None], axis=0) + if min_pixel_detect > 1: + mask = np.sum(images > 0, axis=0) >= min_pixel_detect + detection[~mask] = 0 + # Detect peaks on the detection image + footprints = get_footprints( + detection, + min_separation, + min_area, + peak_thresh, + footprint_thresh, + find_peaks, + origin[0], + origin[1], + ) + + return footprints diff --git a/python/lsst/scarlet/lite/detect_pybind11.cc b/python/lsst/scarlet/lite/detect_pybind11.cc index b05d578e..e64f58d4 100644 --- a/python/lsst/scarlet/lite/detect_pybind11.cc +++ b/python/lsst/scarlet/lite/detect_pybind11.cc @@ -4,6 +4,12 @@ #include #include #include +#include +#include +#include +#include // For std::pair +#include +#include namespace py = pybind11; using namespace pybind11::literals; @@ -16,44 +22,53 @@ typedef Eigen::Matrix Mat // located at `i,j` and create the bounding box for the `footprint` in `image`. template void get_connected_pixels( - const int i, - const int j, - Eigen::Ref image, - Eigen::Ref> unchecked, - Eigen::Ref> footprint, - Eigen::Ref> bounds, + const int start_i, + const int start_j, + py::EigenDRef image, + py::EigenDRef unchecked, + py::EigenDRef footprint, + Eigen::Ref bounds, const double thresh=0 ){ - if(not unchecked(i,j)){ - return; - } - unchecked(i,j) = false; + std::stack> stack; + stack.push(std::make_pair(start_i, start_j)); - if(image(i,j) > thresh){ - footprint(i,j) = true; + while (!stack.empty()) { + int i, j; + std::tie(i, j) = stack.top(); + stack.pop(); - if(i < bounds[0]){ - bounds[0] = i; - } else if(i > bounds[1]){ - bounds[1] = i; - } - if(j < bounds[2]){ - bounds[2] = j; - } else if(j > bounds[3]){ - bounds[3] = j; + if (!unchecked(i, j)) { + continue; } + unchecked(i, j) = false; - if(i > 0){ - get_connected_pixels(i-1, j, image, unchecked, footprint, bounds, thresh); - } - if(i < image.rows()-1){ - get_connected_pixels(i+1, j, image, unchecked, footprint, bounds, thresh); - } - if(j > 0){ - get_connected_pixels(i, j-1, image, unchecked, footprint, bounds, thresh); - } - if(j < image.cols()-1){ - get_connected_pixels(i, j+1, image, unchecked, footprint, bounds, thresh); + if (image(i, j) > thresh) { + footprint(i, j) = true; + + if (i < bounds[0]) { + bounds[0] = i; + } else if (i > bounds[1]) { + bounds[1] = i; + } + if (j < bounds[2]) { + bounds[2] = j; + } else if (j > bounds[3]) { + bounds[3] = j; + } + + if (i > 0 && unchecked(i-1, j)) { + stack.push(std::make_pair(i-1, j)); + } + if (i < image.rows() - 1 && unchecked(i+1, j)) { + stack.push(std::make_pair(i+1, j)); + } + if (j > 0 && unchecked(i, j-1)) { + stack.push(std::make_pair(i, j-1)); + } + if (j < image.cols() - 1 && unchecked(i, j+1)) { + stack.push(std::make_pair(i, j+1)); + } } } } @@ -62,20 +77,43 @@ void get_connected_pixels( /// Proximal operator to trim pixels not connected to one of the source centers. template MatrixB get_connected_multipeak( - Eigen::Ref image, - const std::vector> centers, + py::EigenDRef image, + const std::vector>& centers, const double thresh=0 ){ const int height = image.rows(); const int width = image.cols(); - MatrixB unchecked = MatrixB::Ones(height, width); MatrixB footprint = MatrixB::Zero(height, width); + std::queue> pixel_queue; + + // Seed the queue with peaks + for(const auto& center : centers){ + const int y = center[0]; + const int x = center[1]; + if (!footprint(y, x) && image(y, x) > thresh) { + footprint(y, x) = true; + pixel_queue.emplace(y, x); + } + } - for(auto center=begin(centers); center!=end(centers); ++center){ - const int y = (*center)[0]; - const int x = (*center)[1]; - Bounds bounds; bounds << y, y, x, x; - get_connected_pixels(y, x, image, unchecked, footprint, bounds, thresh); + // 4-connectivity offsets + const std::vector> offsets = {{-1, 0}, {1, 0}, {0, -1}, {0, 1}}; + + // Flood fill + while (!pixel_queue.empty()) { + auto [i, j] = pixel_queue.front(); + pixel_queue.pop(); + + for (const auto& [di, dj] : offsets) { + int ni = i + di; + int nj = j + dj; + if (ni >= 0 && ni < height && nj >= 0 && nj < width) { + if (!footprint(ni, nj) && image(ni, nj) > thresh) { + footprint(ni, nj) = true; + pixel_queue.emplace(ni, nj); + } + } + } } return footprint; @@ -128,6 +166,7 @@ template std::vector get_peaks( M& image, const double min_separation, + const double peak_thresh, const int y0, const int x0 ){ @@ -138,6 +177,9 @@ std::vector get_peaks( for(int i=0; i 0 && image(i, j) <= image(i-1, j)){ continue; } @@ -168,33 +210,29 @@ std::vector get_peaks( } } - assert(peaks.size() > 0); + if(peaks.empty()){ + return peaks; + } /// Sort the peaks in the footprint so that the brightest are first std::sort (peaks.begin(), peaks.end(), sortBrightness); // Remove peaks within min_separation double min_separation2 = min_separation * min_separation; - int i = 0; - while (i < peaks.size()-1){ - int j = i+1; - Peak *p1 = &peaks[i]; - while (j < peaks.size()){ + for (size_t i = 0; i < peaks.size() - 1; ++i) { + for (size_t j = i + 1; j < peaks.size();) { + Peak *p1 = &peaks[i]; Peak *p2 = &peaks[j]; - double dy = p1->getY()-p2->getY(); - double dx = p1->getX()-p2->getX(); + double dy = p1->getY() - p2->getY(); + double dx = p1->getX() - p2->getX(); double separation2 = dy*dy + dx*dx; - if(separation2 < min_separation2){ - peaks.erase(peaks.begin()+j); - i--; + if (separation2 < min_separation2) { + peaks.erase(peaks.begin() + j); + } else { + ++j; } - j++; } - i++; } - - assert(peaks.size() > 0); - return peaks; } @@ -226,8 +264,8 @@ class Footprint { template void maskImage( - Eigen::Ref> image, - Eigen::Ref> footprint + py::EigenDRef image, + py::EigenDRef footprint ){ const int height = image.rows(); const int width = image.cols(); @@ -244,11 +282,14 @@ void maskImage( template std::vector get_footprints( - Eigen::Ref image, + py::EigenDRef image, const double min_separation, const int min_area, - const double thresh, - const bool find_peaks=true + const double peak_thresh, + const double footprint_thresh, + const bool find_peaks=true, + const int y0=0, + const int x0=0 ){ const int height = image.rows(); const int width = image.cols(); @@ -260,7 +301,7 @@ std::vector get_footprints( for(int i=0; i min_area){ @@ -274,12 +315,18 @@ std::vector get_footprints( _peaks = get_peaks( patch, min_separation, - bounds[0], - bounds[2] + peak_thresh, + bounds[0] + y0, + bounds[2] + x0 ); } - - footprints.push_back(Footprint(subFootprint, _peaks, bounds)); + // Only add footprints that have at least one peak above the + // minimum peak_thresh. + if(!_peaks.empty() || !find_peaks){ + Bounds trueBounds; trueBounds << bounds[0] + y0, + bounds[1] + y0, bounds[2] + x0, bounds[3] + x0; + footprints.push_back(Footprint(subFootprint, _peaks, trueBounds)); + } } } footprint.block(bounds[0], bounds[2], subHeight, subWidth) = MatrixB::Zero(subHeight, subWidth); @@ -317,10 +364,12 @@ PYBIND11_MODULE(detect_pybind11, mod) { mod.def("get_footprints", &get_footprints, "Create a list of all of the footprints in an image, with their peaks" - "image"_a, "min_separation"_a, "min_area"_a, "thresh"_a, "find_peaks"_a); + "image"_a, "min_separation"_a, "min_area"_a, "peak_thresh"_a, "footprint_thresh"_a, + "find_peaks"_a=true, "y0"_a=0, "x0"_a=0); mod.def("get_footprints", &get_footprints, "Create a list of all of the footprints in an image, with their peaks" - "image"_a, "min_separation"_a, "min_area"_a, "thresh"_a, "find_peaks"_a); + "image"_a, "min_separation"_a, "min_area"_a, "peak_thresh"_a, "footprint_thresh"_a, + "find_peaks"_a=true, "y0"_a=0, "x0"_a=0); py::class_(mod, "Footprint") .def(py::init, Bounds>(), diff --git a/python/lsst/scarlet/lite/initialization.py b/python/lsst/scarlet/lite/initialization.py index 55e9127d..452f3a7f 100644 --- a/python/lsst/scarlet/lite/initialization.py +++ b/python/lsst/scarlet/lite/initialization.py @@ -526,34 +526,37 @@ def init_source(self, center: tuple[int, int]) -> Source | None: # Fit the spectra assuming that all of the flux in the image # is due to both components. This is not true, but for the # vast majority of sources this is a good approximation. - bulge_spectrum, disk_spectrum = multifit_spectra( - self.observation, - [ - Image(bulge_morph, yx0=cast(tuple[int, int], component.bbox.origin)), - Image(disk_morph, yx0=cast(tuple[int, int], component.bbox.origin)), - ], - ) + try: + bulge_spectrum, disk_spectrum = multifit_spectra( + self.observation, + [ + Image(bulge_morph, yx0=cast(tuple[int, int], component.bbox.origin)), + Image(disk_morph, yx0=cast(tuple[int, int], component.bbox.origin)), + ], + ) - components = [ - FactorizedComponent( - self.observation.bands, - bulge_spectrum, - bulge_morph, - component.bbox.copy(), - center, - self.observation.noise_rms, - monotonicity=self.monotonicity, - ), - FactorizedComponent( - self.observation.bands, - disk_spectrum, - disk_morph, - component.bbox.copy(), - center, - self.observation.noise_rms, - monotonicity=self.monotonicity, - ), - ] + components = [ + FactorizedComponent( + self.observation.bands, + bulge_spectrum, + bulge_morph, + component.bbox.copy(), + center, + self.observation.noise_rms, + monotonicity=self.monotonicity, + ), + FactorizedComponent( + self.observation.bands, + disk_spectrum, + disk_morph, + component.bbox.copy(), + center, + self.observation.noise_rms, + monotonicity=self.monotonicity, + ), + ] + except np.linalg.LinAlgError: + components = [component] return Source(components) # type: ignore diff --git a/python/lsst/scarlet/lite/models/fit_psf.py b/python/lsst/scarlet/lite/models/fit_psf.py index a2db1099..9f7b5e24 100644 --- a/python/lsst/scarlet/lite/models/fit_psf.py +++ b/python/lsst/scarlet/lite/models/fit_psf.py @@ -27,7 +27,8 @@ from ..bbox import Box from ..blend import Blend -from ..fft import Fourier, get_fft_shape +from ..fft import Fourier, centered +from ..fft import convolve as fft_convolve from ..image import Image from ..observation import Observation from ..parameters import parameter @@ -48,6 +49,7 @@ def __init__( bands: tuple | None = None, padding: int = 3, convolution_mode: str = "fft", + shape: tuple[int, int] | None = None, ): """Initialize a `FitPsfObservation` @@ -68,18 +70,24 @@ def __init__( self.axes = (-2, -1) - self.fft_shape = get_fft_shape(self.images.data[0], psfs[0], padding, self.axes) + if shape is None: + shape = (41, 41) # Make the DFT of the psf a fittable parameter - self._fitted_kernel = parameter(cast(Fourier, self.diff_kernel).fft(self.fft_shape, self.axes)) + self._fitted_kernel = parameter(cast(Fourier, self.diff_kernel).image) + + def grad_fit_kernel(self, input_grad: np.ndarray, psf: np.ndarray, model: np.ndarray) -> np.ndarray: + grad = cast( + np.ndarray, + fft_convolve( + Fourier(model), + Fourier(input_grad[:, ::-1, ::-1]), + axes=(1, 2), + return_fourier=False, + ), + ) - def grad_fit_kernel( - self, input_grad: np.ndarray, kernel: np.ndarray, model_fft: np.ndarray - ) -> np.ndarray: - # Transform the upstream gradient into k-space - grad_fft = Fourier(input_grad) - _grad_fft = grad_fft.fft(self.fft_shape, self.axes) - return _grad_fft * model_fft + return centered(grad, psf.shape) def prox_kernel(self, kernel: np.ndarray) -> np.ndarray: # No prox for now @@ -91,7 +99,7 @@ def fitted_kernel(self) -> np.ndarray: @property def cached_kernel(self): - return self.fitted_kernel.real - self.fitted_kernel.imag * 1j + return self.fitted_kernel[:, ::-1, ::-1] def convolve(self, image: Image, mode: str | None = None, grad: bool = False) -> Image: """Convolve the model into the observed seeing in each band. @@ -117,16 +125,16 @@ def convolve(self, image: Image, mode: str | None = None, grad: bool = False) -> if mode != "fft" and mode is not None: return super().convolve(image, mode, grad) - fft_image = Fourier(image.data) - fft = fft_image.fft(self.fft_shape, self.axes) - - result = Fourier.from_fft(fft * kernel, self.fft_shape, image.shape, self.axes) - return Image(result.image, bands=image.bands, yx0=image.yx0) + result = fft_convolve( + Fourier(image.data), + Fourier(kernel), + axes=(1, 2), + return_fourier=False, + ) + return Image(cast(np.ndarray, result), bands=image.bands, yx0=image.yx0) - def update(self, it: int, input_grad: np.ndarray, model: Image): - _model = Fourier(model.data[:, ::-1, ::-1]) - model_fft = _model.fft(self.fft_shape, self.axes) - self._fitted_kernel.update(it, input_grad, model_fft) + def update(self, it: int, input_grad: np.ndarray, model: np.ndarray): + self._fitted_kernel.update(it, input_grad, model) def parameterize(self, parameterization: Callable) -> None: """Convert the component parameter arrays into Parameter instances @@ -138,7 +146,7 @@ def parameterize(self, parameterization: Callable) -> None: a `Parameter` in place. It should take a single argument that is the `Component` or `Source` that is to be parameterized. """ - # Update the spectrum and morph in place + # Update the fitted kernel in place parameterization(self) # update the parameters self._fitted_kernel.grad = self.grad_fit_kernel @@ -148,14 +156,15 @@ def parameterize(self, parameterization: Callable) -> None: class FittedPsfBlend(Blend): """A blend that attempts to fit the PSF along with the source models.""" - def _grad_log_likelihood(self) -> Image: + def _grad_log_likelihood(self) -> tuple[Image, np.ndarray]: """Gradient of the likelihood wrt the unconvolved model""" model = self.get_model(convolve=True) # Update the loss self.loss.append(self.observation.log_likelihood(model)) # Calculate the gradient wrt the model d(logL)/d(model) - result = self.observation.weights * (model - self.observation.images) - return result + residual = self.observation.weights * (model - self.observation.images) + + return residual, model.data def fit( self, @@ -182,7 +191,7 @@ def fit( it = self.it while it < max_iter: # Calculate the gradient wrt the on-convolved model - grad_log_likelihood = self._grad_log_likelihood() + grad_log_likelihood, model = self._grad_log_likelihood() _grad_log_likelihood = self.observation.convolve(grad_log_likelihood, grad=True) # Check if resizing needs to be performed in this iteration if resize is not None and self.it > 0 and self.it % resize == 0: @@ -199,7 +208,9 @@ def fit( # Update the PSF cast(FittedPsfObservation, self.observation).update( - it, grad_log_likelihood.data, self.get_model() + self.it, + grad_log_likelihood.data, + model, ) # Stopping criteria it += 1 @@ -207,3 +218,17 @@ def fit( break self.it = it return it, self.loss[-1] + + def parameterize(self, parameterization: Callable): + """Convert the component parameter arrays into Parameter instances + + Parameters + ---------- + parameterization: + A function to use to convert parameters of a given type into + a `Parameter` in place. It should take a single argument that + is the `Component` or `Source` that is to be parameterized. + """ + for source in self.sources: + source.parameterize(parameterization) + cast(FittedPsfObservation, self.observation).parameterize(parameterization) diff --git a/python/lsst/scarlet/lite/models/free_form.py b/python/lsst/scarlet/lite/models/free_form.py index 2c419087..d9e31c89 100644 --- a/python/lsst/scarlet/lite/models/free_form.py +++ b/python/lsst/scarlet/lite/models/free_form.py @@ -19,19 +19,20 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -__all__ = ["FreeFormComponent"] +__all__ = ["FactorizedFreeFormComponent"] -from typing import cast +from typing import Callable, cast import numpy as np from ..bbox import Box -from ..component import FactorizedComponent +from ..component import Component, FactorizedComponent from ..detect import footprints_to_image -from ..parameters import Parameter +from ..image import Image +from ..parameters import Parameter, parameter -class FreeFormComponent(FactorizedComponent): +class FactorizedFreeFormComponent(FactorizedComponent): """Implements a free-form component With no constraints this component is typically either a garbage collector, @@ -109,11 +110,14 @@ def prox_morph(self, morph: np.ndarray) -> np.ndarray: morph[morph < 0] = 0 if self.peaks is not None: - morph = morph * get_connected_multipeak(morph > 0, self.peaks, 0) + footprint = get_connected_multipeak(morph, self.peaks, 0) + morph = morph * footprint if self.min_area > 0: - footprints = get_footprints(morph > 0, 4.0, self.min_area, 0, False) - footprint_image = footprints_to_image(footprints, cast(tuple[int, int], morph.shape)) + footprints = get_footprints(morph, 4.0, self.min_area, 0, 0, False) + bbox = self.bbox.copy() + bbox.origin = (0, 0) + footprint_image = footprints_to_image(footprints, bbox) morph = morph * (footprint_image > 0).data if np.all(morph == 0): @@ -126,10 +130,119 @@ def resize(self, model_box: Box) -> bool: def __str__(self): return ( - f"FreeFormComponent(\n bands={self.bands}\n " + f"FactorizedFreeFormComponent(\n bands={self.bands}\n " f"spectrum={self.spectrum})\n center={self.peak}\n " f"morph_shape={self.morph.shape}" ) def __repr__(self): return self.__str__() + + +class FreeFormComponent(Component): + """Implements a free-form component + + This is a FreeFormComponent that is not factorized into a + spectrum and morphology with no monotonicity constraint. + """ + + def __init__( + self, + bands: tuple, + model: np.ndarray | Parameter, + model_bbox: Box, + bg_thresh: float | None = None, + bg_rms: np.ndarray | None = None, + floor: float = 1e-20, + peaks: list[tuple[int, int]] | None = None, + min_area: float = 0, + ): + if len(bands) != 1: + raise ValueError("MonochromaticDeconvolvedComponent only supports one band") + super().__init__(bands=bands, bbox=model_bbox) + self._model = parameter(model) + self.bg_rms = bg_rms + self.bg_thresh = bg_thresh + self.floor = floor + self.peaks = peaks + self.min_area = min_area + + @property + def model(self) -> np.ndarray: + """The morphological model of the component""" + return self._model.x + + def get_model(self) -> Image: + """Convert the model into an image""" + return Image(self.model, bands=self.bands, yx0=cast(tuple[int, int], self.bbox.origin)) + + @property + def shape(self) -> tuple: + """Shape of the resulting model image""" + return self.model.shape + + def grad_model(self, input_grad: np.ndarray, model: np.ndarray) -> np.ndarray: + """Gradient of the morph wrt. the component model""" + return input_grad + + def prox_model(self, model: np.ndarray) -> np.ndarray: + """Apply a prox-like update to the model""" + from lsst.scarlet.lite.detect_pybind11 import get_connected_multipeak, get_footprints # type: ignore + + if self.bg_thresh is not None and isinstance(self.bg_rms, np.ndarray): + bg_thresh = self.bg_rms * self.bg_thresh + # Enforce background thresholding + model[model < bg_thresh[:, None, None]] = 0 + else: + # enforce positivity + model[model < 0] = 0 + + if self.peaks is not None: + # Remove pixels not connected to one of the peaks + model2d = np.sum(model, axis=0) + footprint = get_connected_multipeak(model2d, self.peaks, 0) + model = model * footprint[None, :, :] + + if self.min_area > 0: + # Remove regions with fewer than min_area connected pixels + model2d = np.sum(model, axis=0) + footprints = get_footprints(model2d, 4.0, self.min_area, 0, 0, False) + bbox = self.bbox.copy() + bbox.origin = (0, 0) + footprint_image = footprints_to_image(footprints, bbox) + model = model * (footprint_image > 0).data[None, :, :] + + if np.all(model == 0): + # If the model is all zeros, set a single pixel to the floor + model[0, 0] = self.floor + + return model + + def resize(self, model_box: Box) -> bool: + return False + + def update(self, it: int, grad_log_likelihood: np.ndarray): + self._model.update(it, grad_log_likelihood) + + def parameterize(self, parameterization: Callable) -> None: + """Convert the component parameter arrays into Parameter instances + + Parameters + ---------- + parameterization: Callable + A function to use to convert parameters of a given type into + a `Parameter` in place. It should take a single argument that + is the `Component` or `Source` that is to be parameterized. + """ + # Update the spectrum and morph in place + parameterization(self) + # update the parameters + self._model.grad = self.grad_model + self._model.prox = self.prox_model + + def __str__(self): + result = f"FreeFormComponent" + return result + + def __repr__(self): + return self.__str__() diff --git a/python/lsst/scarlet/lite/observation.py b/python/lsst/scarlet/lite/observation.py index 1b7ae4a8..02b03dbc 100644 --- a/python/lsst/scarlet/lite/observation.py +++ b/python/lsst/scarlet/lite/observation.py @@ -63,7 +63,6 @@ def get_filter_coords(filter_values: np.ndarray, center: tuple[int, int] | None calculate `coords` on your own.""" raise ValueError(msg) center = tuple([filter_values.shape[0] // 2, filter_values.shape[1] // 2]) # type: ignore - center = cast(tuple[int, int], center) x = np.arange(filter_values.shape[1]) y = np.arange(filter_values.shape[0]) x, y = np.meshgrid(x, y) diff --git a/python/lsst/scarlet/lite/parameters.py b/python/lsst/scarlet/lite/parameters.py index a64fb2c0..5790c4fb 100644 --- a/python/lsst/scarlet/lite/parameters.py +++ b/python/lsst/scarlet/lite/parameters.py @@ -213,7 +213,10 @@ def update(self, it: int, input_grad: np.ndarray, *args): See `Parameter` for the full description. """ - step = self.step / np.sum(args[0] * args[0]) + if len(args) == 0: + step = self.step + else: + step = self.step / np.sum(args[0] * args[0]) _x = self.x _z = self.helpers["z"] diff --git a/python/lsst/scarlet/lite/utils.py b/python/lsst/scarlet/lite/utils.py index 7aeca378..30f4a41c 100644 --- a/python/lsst/scarlet/lite/utils.py +++ b/python/lsst/scarlet/lite/utils.py @@ -48,9 +48,9 @@ def integrated_gaussian_value(x: np.ndarray, sigma: float) -> np.ndarray: gaussian: A Gaussian function integrated over `x` """ - lhs = erfc((0.5 - x) / (sqrt2 * sigma)) - rhs = erfc((2 * x + 1) / (2 * sqrt2 * sigma)) - return np.sqrt(np.pi / 2) * sigma * (1 - lhs + 1 - rhs) + lhs = erfc((x - 0.5) / (sqrt2 * sigma)) + rhs = erfc((x + 0.5) / (sqrt2 * sigma)) + return np.sqrt(np.pi) * 0.5 * sigma * (lhs - rhs) def integrated_circular_gaussian( @@ -87,7 +87,9 @@ def integrated_circular_gaussian( elif y is None: raise ValueError(f"Either X and Y must be specified, or neither must be specified, got {x=} and {y=}") - result = integrated_gaussian_value(x, sigma)[None, :] * integrated_gaussian_value(y, sigma)[:, None] + _x = integrated_gaussian_value(np.abs(x), sigma)[None, :] + _y = integrated_gaussian_value(np.abs(y), sigma)[:, None] + result = _x * _y return result / np.sum(result) diff --git a/python/lsst/scarlet/lite/wavelet.py b/python/lsst/scarlet/lite/wavelet.py index c604a048..daaa7d78 100644 --- a/python/lsst/scarlet/lite/wavelet.py +++ b/python/lsst/scarlet/lite/wavelet.py @@ -27,6 +27,7 @@ "get_multiresolution_support", ] +from dataclasses import dataclass from typing import Callable, Sequence import numpy as np @@ -53,7 +54,7 @@ def bspline_convolve(image: np.ndarray, scale: int) -> np.ndarray: The result of convolving the `image` with the spline. """ # Filter for the scarlet transform. Here bspline - h1d = np.array([1.0 / 16, 1.0 / 4, 3.0 / 8, 1.0 / 4, 1.0 / 16]) + h1d = np.array([1.0 / 16, 1.0 / 4, 3.0 / 8, 1.0 / 4, 1.0 / 16]).astype(image.dtype) j = scale slice0 = slice(None, -(2 ** (j + 1))) @@ -226,13 +227,19 @@ def multiband_starlet_reconstruction( See `starlet_reconstruction` for a description of the remainder of the parameters. """ - scales, bands, width, height = starlets.shape + _, bands, width, height = starlets.shape result = np.zeros((bands, width, height), dtype=starlets.dtype) for band in range(bands): result[band] = starlet_reconstruction(starlets[:, band], generation=generation, convolve2d=convolve2d) return result +@dataclass +class MultiResolutionSupport: + support: np.ndarray + sigma: np.ndarray + + def get_multiresolution_support( image: np.ndarray, starlets: np.ndarray, @@ -241,7 +248,7 @@ def get_multiresolution_support( epsilon: float = 1e-1, max_iter: int = 20, image_type: str = "ground", -) -> np.ndarray: +) -> MultiResolutionSupport: """Calculate the multi-resolution support for a dictionary of starlet coefficients. @@ -302,6 +309,7 @@ def get_multiresolution_support( if np.abs(sigma_i - last_sigma_i) / sigma_i < epsilon: break last_sigma_i = sigma_i + sigma_j = sigma_je else: # Sigma to use for significance at each scale # Initially we use the input `sigma` @@ -322,7 +330,7 @@ def get_multiresolution_support( last_sigma_j = sigma_j # noinspection PyUnboundLocalVariable - return m.astype(int) + return MultiResolutionSupport(support=m.astype(int), sigma=sigma_j) def apply_wavelet_denoising( @@ -380,7 +388,7 @@ def apply_wavelet_denoising( for n in range(max_iter): coeffs = starlet_transform(x) - x = x + starlet_reconstruction(support * (image_coeffs - coeffs)) + x = x + starlet_reconstruction(support.support * (image_coeffs - coeffs)) if positive: x[x < 0] = 0 return x diff --git a/tests/test_detect.py b/tests/test_detect.py index 022ce631..46e5eb67 100644 --- a/tests/test_detect.py +++ b/tests/test_detect.py @@ -23,7 +23,13 @@ import numpy as np from lsst.scarlet.lite import Box, Image -from lsst.scarlet.lite.detect import bounds_to_bbox, footprints_to_image, get_detect_wavelets, get_wavelets +from lsst.scarlet.lite.detect import ( + bounds_to_bbox, + detect_footprints, + footprints_to_image, + get_detect_wavelets, + get_wavelets, +) from lsst.scarlet.lite.detect_pybind11 import ( Footprint, Peak, @@ -117,8 +123,7 @@ def test_connected(self): truth[30:32, 40] = False assert_array_equal(footprint, truth) - def test_get_footprints(self): - footprints = get_footprints(self.image.data, 1, 4, 1e-15, True) + def _footprint_check(self, footprints): self.assertEqual(len(footprints), 3) # The first footprint has a single peak @@ -147,9 +152,49 @@ def test_get_footprints(self): truth = 1 * self.sources[3] + 2 * (self.sources[0] + self.sources[1]) + 3 * self.sources[2] truth.data[truth.data < 1e-15] = 0 - fp_image = footprints_to_image(footprints, truth.shape) + fp_image = footprints_to_image(footprints, truth.bbox) assert_array_equal(fp_image, truth.data) + def test_get_footprints(self): + footprints = get_footprints(self.image.data, 1, 4, 1e-15, 1e-15, True) + self._footprint_check(footprints) + + def test_detect_footprints(self): + # this method doesn't test for accurracy, just that all of the + # different configuration options run without error. + + # There is no variance, so we set it to ones + variance = np.ones(self.image.shape, dtype=self.image.dtype) + + detect_footprints( + self.image.data[None, :, :], + variance[None, :, :], + scales=1, + generation=2, + origin=(0, 0), + min_separation=1, + min_area=4, + peak_thresh=1e-15, + footprint_thresh=1e-15, + find_peaks=True, + remove_high_freq=False, + min_pixel_detect=1, + ) + + detect_footprints( + self.image.data[None, :, :], + variance[None, :, :], + scales=1, + generation=1, + min_separation=1, + min_area=4, + peak_thresh=1e-15, + footprint_thresh=1e-15, + find_peaks=True, + remove_high_freq=True, + min_pixel_detect=2, + ) + def test_bounds_to_bbox(self): bounds = (3, 27, 11, 52) truth = Box((25, 42), (3, 11)) @@ -161,26 +206,28 @@ def test_footprint(self): footprint[footprint < 1e-15] = 0 bounds = [ self.sources[0].bbox.start[0], - self.sources[0].bbox.stop[0], + self.sources[0].bbox.stop[0] - 1, self.sources[0].bbox.start[1], - self.sources[0].bbox.stop[1], + self.sources[0].bbox.stop[1] - 1, ] + print(bounds) peaks = [Peak(self.centers[0][0], self.centers[0][1], self.image.data[self.centers[0]])] footprint1 = Footprint(footprint, peaks, bounds) footprint = self.sources[1].data footprint[footprint < 1e-15] = 0 bounds = [ self.sources[1].bbox.start[0], - self.sources[1].bbox.stop[0], + self.sources[1].bbox.stop[0] - 1, self.sources[1].bbox.start[1], - self.sources[1].bbox.stop[1], + self.sources[1].bbox.stop[1] - 1, ] + print(bounds) peaks = [Peak(self.centers[1][0], self.centers[1][1], self.image.data[self.centers[1]])] footprint2 = Footprint(footprint, peaks, bounds) truth = self.sources[0] + self.sources[1] truth.data[truth.data < 1e-15] = 0 - image = footprints_to_image([footprint1, footprint2], truth.shape) + image = footprints_to_image([footprint1, footprint2], truth.bbox) assert_array_equal(image, truth.data) # Test intersection diff --git a/tests/test_io.py b/tests/test_io.py index 6207de8a..a654b7e2 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -25,7 +25,7 @@ import numpy as np from lsst.scarlet.lite import Blend, Image, Observation, io from lsst.scarlet.lite.initialization import FactorizedChi2Initialization -from lsst.scarlet.lite.models.free_form import FreeFormComponent +from lsst.scarlet.lite.models.free_form import FactorizedFreeFormComponent from lsst.scarlet.lite.operators import Monotonicity from lsst.scarlet.lite.utils import integrated_circular_gaussian from numpy.testing import assert_almost_equal @@ -101,7 +101,7 @@ def test_cube_component(self): blend.sources[i].peak_id = i component = blend.sources[-1].components[-1] # Replace one of the components with a Free-Form component. - blend.sources[-1].components[-1] = FreeFormComponent( + blend.sources[-1].components[-1] = FactorizedFreeFormComponent( bands=self.observation.bands, spectrum=component.spectrum, morph=component.morph, diff --git a/tests/test_models.py b/tests/test_models.py index 328c39b0..19353db5 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -26,19 +26,15 @@ import lsst.scarlet.lite.models as models import numpy as np from lsst.scarlet.lite import Blend, Box, FistaParameter, Image, Observation, Source -from lsst.scarlet.lite.component import ( - Component, - default_adaprox_parameterization, - default_fista_parameterization, -) +from lsst.scarlet.lite.component import Component, FactorizedComponent, default_adaprox_parameterization from lsst.scarlet.lite.initialization import FactorizedChi2Initialization from lsst.scarlet.lite.models import ( CartesianFrame, EllipseFrame, EllipticalParametricComponent, + FactorizedFreeFormComponent, FittedPsfBlend, FittedPsfObservation, - FreeFormComponent, ParametricComponent, ) from lsst.scarlet.lite.operators import Monotonicity @@ -89,7 +85,7 @@ def test_free_form_component(self): # Test with no thresholding (sparsity) sources = [] for i in range(5): - component = FreeFormComponent( + component = FactorizedFreeFormComponent( self.observation.bands, np.ones(5), images[i].copy(), @@ -104,7 +100,7 @@ def test_free_form_component(self): # Test with thresholding (sparsity) sources = [] for i in range(5): - component = FreeFormComponent( + component = FactorizedFreeFormComponent( self.observation.bands, np.ones(5), images[i].copy(), @@ -123,7 +119,7 @@ def test_free_form_component(self): sources = [] peaks = list(np.array([self.data["catalog"]["y"], self.data["catalog"]["x"]]).T.astype(int)) for i in range(5): - component = FreeFormComponent( + component = FactorizedFreeFormComponent( self.observation.bands, np.ones(5), images[i].copy(), @@ -403,12 +399,16 @@ def test_psf_fitting(self): bands=self.data["filters"], ) - def obs_params(cls): - if isinstance(cls, FittedPsfObservation): - cls._fitted_kernel = FistaParameter(cls._fitted_kernel.x, step=1e-2) + def fista_parameterization(component: Component): + if isinstance(component, FactorizedComponent): + component._spectrum = FistaParameter(component.spectrum, step=0.5) + component._morph = FistaParameter(component.morph, step=0.5) + else: + if isinstance(component, FittedPsfObservation): + component._fitted_kernel = FistaParameter(component._fitted_kernel.x, step=1e-2) init = FactorizedChi2Initialization(observation, self.centers, monotonicity=monotonicity) blend = FittedPsfBlend(init.sources, observation).fit_spectra() - blend.parameterize(default_fista_parameterization) - cast(FittedPsfObservation, blend.observation).parameterize(obs_params) + blend.parameterize(fista_parameterization) + assert isinstance(cast(FittedPsfObservation, blend.observation)._fitted_kernel, FistaParameter) blend.fit(12, e_rel=1e-4) diff --git a/tests/test_wavelet.py b/tests/test_wavelet.py index e3113978..4e47565f 100644 --- a/tests/test_wavelet.py +++ b/tests/test_wavelet.py @@ -54,7 +54,7 @@ def test_transform_inverse(self): # Test inverse inverse = starlet_reconstruction(starlets) assert_almost_equal(inverse, image, decimal=5) - self.assertEqual(inverse.dtype, np.float32) + self.assertEqual(inverse.dtype, starlets.dtype) # Test using gen1 starlets starlets = starlet_transform(image, scales=3, generation=1) From b607e236f44d9f2df98e3a4587c597d9b541c823 Mon Sep 17 00:00:00 2001 From: fred3m Date: Mon, 30 Sep 2024 05:37:40 -0700 Subject: [PATCH 2/7] Implement checks in tests --- tests/test_detect.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/tests/test_detect.py b/tests/test_detect.py index 46e5eb67..8a492385 100644 --- a/tests/test_detect.py +++ b/tests/test_detect.py @@ -159,14 +159,22 @@ def test_get_footprints(self): footprints = get_footprints(self.image.data, 1, 4, 1e-15, 1e-15, True) self._footprint_check(footprints) + def _peak_check(self, peaks): + matched_peaks = [] + for center in self.centers: + for peak in peaks: + if peak.y == center[0] and peak.x == center[1]: + matched_peaks.append(peak) + break + self.assertEqual(len(matched_peaks), len(self.centers)) + def test_detect_footprints(self): - # this method doesn't test for accurracy, just that all of the - # different configuration options run without error. + # this method doesn't test for accurracy, since # There is no variance, so we set it to ones variance = np.ones(self.image.shape, dtype=self.image.dtype) - detect_footprints( + footprints = detect_footprints( self.image.data[None, :, :], variance[None, :, :], scales=1, @@ -181,7 +189,11 @@ def test_detect_footprints(self): min_pixel_detect=1, ) - detect_footprints( + self.assertEqual(len(footprints), 3) + peaks = [peak for footprint in footprints for peak in footprint.peaks] + self._peak_check(peaks) + + footprints = detect_footprints( self.image.data[None, :, :], variance[None, :, :], scales=1, @@ -192,9 +204,13 @@ def test_detect_footprints(self): footprint_thresh=1e-15, find_peaks=True, remove_high_freq=True, - min_pixel_detect=2, + min_pixel_detect=1, ) + self.assertEqual(len(footprints), 2) + peaks = [peak for footprint in footprints for peak in footprint.peaks] + self._peak_check(peaks) + def test_bounds_to_bbox(self): bounds = (3, 27, 11, 52) truth = Box((25, 42), (3, 11)) From ad2a53e3d3f8f8b2b8a3687a44b7736aed8e08ff Mon Sep 17 00:00:00 2001 From: fred3m Date: Mon, 28 Oct 2024 09:43:21 -0700 Subject: [PATCH 3/7] Revert change to require numpy < 2.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0950cbad..99dd6532 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ requires = [ "lsst-versions >= 1.3.0", "wheel", "pybind11 >= 2.5.0", - "numpy >= 1.18, <2.0.0", + "numpy >= 1.18", "peigen >= 0.0.9", ] build-backend = "setuptools.build_meta" From 8fd276dd9e5bf34ad0b6f6e981e61d022899cde1 Mon Sep 17 00:00:00 2001 From: fred3m Date: Mon, 28 Oct 2024 11:13:47 -0700 Subject: [PATCH 4/7] Fix incompatibility with numpy >= 2.0 and setuptools<65 --- pyproject.toml | 2 +- setup.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 99dd6532..4f04d529 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [build-system] requires = [ - "setuptools<65", + "setuptools", "lsst-versions >= 1.3.0", "wheel", "pybind11 >= 2.5.0", diff --git a/setup.py b/setup.py index e9218c34..f0aef598 100644 --- a/setup.py +++ b/setup.py @@ -36,7 +36,6 @@ import os # Importing this automatically enables parallelized builds -import numpy.distutils.ccompiler # noqa: F401 from pybind11.setup_helpers import Pybind11Extension, build_ext from setuptools import setup From 63d0086280727d8e06ab1afff084306a611705f3 Mon Sep 17 00:00:00 2001 From: fred3m Date: Tue, 29 Oct 2024 14:09:50 -0700 Subject: [PATCH 5/7] Respond to reviewer comments --- python/lsst/scarlet/lite/blend.py | 12 +++++++++- python/lsst/scarlet/lite/detect.py | 3 +-- python/lsst/scarlet/lite/detect_pybind11.cc | 15 +++++++++++- python/lsst/scarlet/lite/models/fit_psf.py | 24 ++++++++++++++++++++ python/lsst/scarlet/lite/models/free_form.py | 13 ++++------- python/lsst/scarlet/lite/utils.py | 4 ++-- tests/test_detect.py | 15 ++++++------ 7 files changed, 63 insertions(+), 23 deletions(-) diff --git a/python/lsst/scarlet/lite/blend.py b/python/lsst/scarlet/lite/blend.py index 23355108..adf1ecb2 100644 --- a/python/lsst/scarlet/lite/blend.py +++ b/python/lsst/scarlet/lite/blend.py @@ -114,7 +114,17 @@ def get_model(self, convolve: bool = False, use_flux: bool = False) -> Image: return model def _grad_log_likelihood(self) -> tuple[Image, np.ndarray]: - """Gradient of the likelihood wrt the unconvolved model""" + """Gradient of the likelihood wrt the unconvolved model + + Returns + ------- + result: + The gradient of the likelihood wrt the model + model_data: + The convol model data used to calculate the gradient. + This can be useful for debugging but is not used in + production. + """ model = self.get_model(convolve=True) # Update the loss self.loss.append(self.observation.log_likelihood(model)) diff --git a/python/lsst/scarlet/lite/detect.py b/python/lsst/scarlet/lite/detect.py index 6d4f0e6a..9b4e28f4 100644 --- a/python/lsst/scarlet/lite/detect.py +++ b/python/lsst/scarlet/lite/detect.py @@ -25,7 +25,7 @@ from typing import Sequence import numpy as np -from lsst.scarlet.lite.detect_pybind11 import Footprint # type: ignore +from lsst.scarlet.lite.detect_pybind11 import Footprint, get_footprints # type: ignore from .bbox import Box, overlapped_slices from .image import Image @@ -268,7 +268,6 @@ def detect_footprints( The minimum number of bands that must be above the detection threshold for a pixel to be included in a footprint. """ - from lsst.scarlet.lite.detect_pybind11 import get_footprints if origin is None: origin = (0, 0) diff --git a/python/lsst/scarlet/lite/detect_pybind11.cc b/python/lsst/scarlet/lite/detect_pybind11.cc index e64f58d4..5d6b0f79 100644 --- a/python/lsst/scarlet/lite/detect_pybind11.cc +++ b/python/lsst/scarlet/lite/detect_pybind11.cc @@ -279,7 +279,20 @@ void maskImage( } } - +/** + * Get all footprints in an image + * + * @param image: The image to search for footprints + * @param min_separation: The minimum separation (in pixels) between peaks in a footprint + * @param min_area: The minimum area of a footprint in pixels + * @param peak_thresh: The minimum flux of a peak to be detected. + * @param footprint_thresh: The minimum flux of a pixel to be included in a footprint + * @param find_peaks: If True, find peaks in each footprint + * @param y0: The y-coordinate of the top-left corner of the image + * @param x0: The x-coordinate of the top-left corner of the image + * + * @return: A list of Footprints + */ template std::vector get_footprints( py::EigenDRef image, diff --git a/python/lsst/scarlet/lite/models/fit_psf.py b/python/lsst/scarlet/lite/models/fit_psf.py index 9f7b5e24..1a92fe0e 100644 --- a/python/lsst/scarlet/lite/models/fit_psf.py +++ b/python/lsst/scarlet/lite/models/fit_psf.py @@ -77,6 +77,19 @@ def __init__( self._fitted_kernel = parameter(cast(Fourier, self.diff_kernel).image) def grad_fit_kernel(self, input_grad: np.ndarray, psf: np.ndarray, model: np.ndarray) -> np.ndarray: + """Gradient of the loss wrt the PSF + + This is just the cross correlation of the input gradient with the model. + + Parameters + ---------- + input_grad: + The gradient of the loss wrt the model + psf: + The PSF of the model. + model: + The deconvolved model. + """ grad = cast( np.ndarray, fft_convolve( @@ -134,6 +147,17 @@ def convolve(self, image: Image, mode: str | None = None, grad: bool = False) -> return Image(cast(np.ndarray, result), bands=image.bands, yx0=image.yx0) def update(self, it: int, input_grad: np.ndarray, model: np.ndarray): + """Update the PSF given the gradient of the loss + + Parameters + ---------- + it: int + The current iteration + input_grad: np.ndarray + The gradient of the loss wrt the model + model: np.ndarray + The deconvolved model. + """ self._fitted_kernel.update(it, input_grad, model) def parameterize(self, parameterization: Callable) -> None: diff --git a/python/lsst/scarlet/lite/models/free_form.py b/python/lsst/scarlet/lite/models/free_form.py index d9e31c89..24b0e76a 100644 --- a/python/lsst/scarlet/lite/models/free_form.py +++ b/python/lsst/scarlet/lite/models/free_form.py @@ -25,6 +25,8 @@ import numpy as np +from lsst.scarlet.lite.detect_pybind11 import get_connected_multipeak, get_footprints + from ..bbox import Box from ..component import Component, FactorizedComponent from ..detect import footprints_to_image @@ -140,7 +142,7 @@ def __repr__(self): class FreeFormComponent(Component): - """Implements a free-form component + """Implements a component with no spectral or monotonicty constraints This is a FreeFormComponent that is not factorized into a spectrum and morphology with no monotonicity constraint. @@ -158,7 +160,7 @@ def __init__( min_area: float = 0, ): if len(bands) != 1: - raise ValueError("MonochromaticDeconvolvedComponent only supports one band") + raise ValueError(f"{type(self)} only supports one band") super().__init__(bands=bands, bbox=model_bbox) self._model = parameter(model) self.bg_rms = bg_rms @@ -169,26 +171,19 @@ def __init__( @property def model(self) -> np.ndarray: - """The morphological model of the component""" return self._model.x def get_model(self) -> Image: - """Convert the model into an image""" return Image(self.model, bands=self.bands, yx0=cast(tuple[int, int], self.bbox.origin)) @property def shape(self) -> tuple: - """Shape of the resulting model image""" return self.model.shape def grad_model(self, input_grad: np.ndarray, model: np.ndarray) -> np.ndarray: - """Gradient of the morph wrt. the component model""" return input_grad def prox_model(self, model: np.ndarray) -> np.ndarray: - """Apply a prox-like update to the model""" - from lsst.scarlet.lite.detect_pybind11 import get_connected_multipeak, get_footprints # type: ignore - if self.bg_thresh is not None and isinstance(self.bg_rms, np.ndarray): bg_thresh = self.bg_rms * self.bg_thresh # Enforce background thresholding diff --git a/python/lsst/scarlet/lite/utils.py b/python/lsst/scarlet/lite/utils.py index 30f4a41c..dfd8f3a4 100644 --- a/python/lsst/scarlet/lite/utils.py +++ b/python/lsst/scarlet/lite/utils.py @@ -30,7 +30,7 @@ sqrt2 = np.sqrt(2) - +sqrt_pi = np.sqrt(np.pi) def integrated_gaussian_value(x: np.ndarray, sigma: float) -> np.ndarray: """A Gaussian function evaluated at `x` @@ -50,7 +50,7 @@ def integrated_gaussian_value(x: np.ndarray, sigma: float) -> np.ndarray: """ lhs = erfc((x - 0.5) / (sqrt2 * sigma)) rhs = erfc((x + 0.5) / (sqrt2 * sigma)) - return np.sqrt(np.pi) * 0.5 * sigma * (lhs - rhs) + return sqrt_pi * 0.5 * sigma * (lhs - rhs) def integrated_circular_gaussian( diff --git a/tests/test_detect.py b/tests/test_detect.py index 8a492385..fbdc4756 100644 --- a/tests/test_detect.py +++ b/tests/test_detect.py @@ -123,7 +123,7 @@ def test_connected(self): truth[30:32, 40] = False assert_array_equal(footprint, truth) - def _footprint_check(self, footprints): + def _check_footprints(self, footprints): self.assertEqual(len(footprints), 3) # The first footprint has a single peak @@ -157,9 +157,9 @@ def _footprint_check(self, footprints): def test_get_footprints(self): footprints = get_footprints(self.image.data, 1, 4, 1e-15, 1e-15, True) - self._footprint_check(footprints) + self._check_footprints(footprints) - def _peak_check(self, peaks): + def _check_peaks(self, peaks): matched_peaks = [] for center in self.centers: for peak in peaks: @@ -169,9 +169,8 @@ def _peak_check(self, peaks): self.assertEqual(len(matched_peaks), len(self.centers)) def test_detect_footprints(self): - # this method doesn't test for accurracy, since - - # There is no variance, so we set it to ones + # This method doesn't test for accurracy, since + # there is no variance, so we set it to ones. variance = np.ones(self.image.shape, dtype=self.image.dtype) footprints = detect_footprints( @@ -191,7 +190,7 @@ def test_detect_footprints(self): self.assertEqual(len(footprints), 3) peaks = [peak for footprint in footprints for peak in footprint.peaks] - self._peak_check(peaks) + self._check_peaks(peaks) footprints = detect_footprints( self.image.data[None, :, :], @@ -209,7 +208,7 @@ def test_detect_footprints(self): self.assertEqual(len(footprints), 2) peaks = [peak for footprint in footprints for peak in footprint.peaks] - self._peak_check(peaks) + self._check_peaks(peaks) def test_bounds_to_bbox(self): bounds = (3, 27, 11, 52) From 57008adcb5c0e05ed40b6b8654bcfb47359488ca Mon Sep 17 00:00:00 2001 From: fred3m Date: Wed, 30 Oct 2024 05:36:41 -0700 Subject: [PATCH 6/7] Fixes for flake8, black, isort --- python/lsst/scarlet/lite/models/fit_psf.py | 3 ++- python/lsst/scarlet/lite/models/free_form.py | 1 - python/lsst/scarlet/lite/utils.py | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/lsst/scarlet/lite/models/fit_psf.py b/python/lsst/scarlet/lite/models/fit_psf.py index 1a92fe0e..733f6f45 100644 --- a/python/lsst/scarlet/lite/models/fit_psf.py +++ b/python/lsst/scarlet/lite/models/fit_psf.py @@ -79,7 +79,8 @@ def __init__( def grad_fit_kernel(self, input_grad: np.ndarray, psf: np.ndarray, model: np.ndarray) -> np.ndarray: """Gradient of the loss wrt the PSF - This is just the cross correlation of the input gradient with the model. + This is just the cross correlation of the input gradient + with the model. Parameters ---------- diff --git a/python/lsst/scarlet/lite/models/free_form.py b/python/lsst/scarlet/lite/models/free_form.py index 24b0e76a..9c40ae19 100644 --- a/python/lsst/scarlet/lite/models/free_form.py +++ b/python/lsst/scarlet/lite/models/free_form.py @@ -24,7 +24,6 @@ from typing import Callable, cast import numpy as np - from lsst.scarlet.lite.detect_pybind11 import get_connected_multipeak, get_footprints from ..bbox import Box diff --git a/python/lsst/scarlet/lite/utils.py b/python/lsst/scarlet/lite/utils.py index dfd8f3a4..da5ca818 100644 --- a/python/lsst/scarlet/lite/utils.py +++ b/python/lsst/scarlet/lite/utils.py @@ -32,6 +32,7 @@ sqrt2 = np.sqrt(2) sqrt_pi = np.sqrt(np.pi) + def integrated_gaussian_value(x: np.ndarray, sigma: float) -> np.ndarray: """A Gaussian function evaluated at `x` From a843edcc1082596f0897ac79e2a917cee8225ff6 Mon Sep 17 00:00:00 2001 From: fred3m Date: Wed, 30 Oct 2024 05:41:15 -0700 Subject: [PATCH 7/7] Fix mypy error --- python/lsst/scarlet/lite/models/free_form.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/lsst/scarlet/lite/models/free_form.py b/python/lsst/scarlet/lite/models/free_form.py index 9c40ae19..dd37c38e 100644 --- a/python/lsst/scarlet/lite/models/free_form.py +++ b/python/lsst/scarlet/lite/models/free_form.py @@ -24,7 +24,7 @@ from typing import Callable, cast import numpy as np -from lsst.scarlet.lite.detect_pybind11 import get_connected_multipeak, get_footprints +from lsst.scarlet.lite.detect_pybind11 import get_connected_multipeak, get_footprints # type: ignore from ..bbox import Box from ..component import Component, FactorizedComponent