Implement improved wavelet detection
fred3m committed Sep 20, 2024
1 parent ae00866 commit 2648153
Showing 17 changed files with 539 additions and 189 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ requires = [
"lsst-versions >= 1.3.0",
"wheel",
"pybind11 >= 2.5.0",
"numpy >= 1.18",
"numpy >= 1.18, <2.0.0",
"peigen >= 0.0.9",
]
build-backend = "setuptools.build_meta"
2 changes: 1 addition & 1 deletion python/lsst/scarlet/lite/__init__.py
@@ -10,7 +10,7 @@
except ImportError:
pass

from . import initialization, io, measure, models, operators, utils
from . import initialization, io, measure, models, operators, utils, wavelet
from .fft import *
from .image import *
from .observation import *
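
Not part of the diff: a minimal sketch of what the added package-level import enables. The `wavelet` submodule is now reachable directly from `lsst.scarlet.lite`; the `starlet_transform` call mirrors the way `detect.py` invokes it further down, and the toy image is purely illustrative.

```python
import numpy as np

from lsst.scarlet.lite import wavelet

# Toy single-band image: a delta function at the center.
image = np.zeros((32, 32), dtype=np.float32)
image[16, 16] = 1.0

# Same call pattern as detect.py uses below.
coeffs = wavelet.starlet_transform(image, scales=3, generation=2)
print(coeffs.shape)  # expected to be (scales + 1, Ny, Nx)
```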
2 changes: 1 addition & 1 deletion python/lsst/scarlet/lite/bbox.py
@@ -228,7 +228,7 @@ def from_data(x: np.ndarray, threshold: float = 0) -> Box:
nonzero = np.where(sel)
bounds = []
for dim in range(len(x.shape)):
bounds.append((nonzero[dim].min(), nonzero[dim].max() + 1))
bounds.append((int(nonzero[dim].min()), int(nonzero[dim].max() + 1)))
else:
bounds = [(0, 0)] * len(x.shape)
return Box.from_bounds(*bounds)
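
Not part of the diff: a small sketch of what the `int(...)` cast changes, assuming `from_data` is the `Box` constructor named in the hunk header. `np.where` returns numpy integer scalars, so without the cast they would leak into the box bounds.

```python
import numpy as np

from lsst.scarlet.lite.bbox import Box

data = np.zeros((10, 10))
data[3:6, 4:8] = 1.0

# With the int(...) cast the stored bounds (and hence the origin)
# are plain Python ints rather than numpy integer scalars.
bbox = Box.from_data(data)
print(bbox.origin)  # expected (3, 4)
```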
6 changes: 3 additions & 3 deletions python/lsst/scarlet/lite/blend.py
@@ -113,15 +113,15 @@ def get_model(self, convolve: bool = False, use_flux: bool = False) -> Image:
return self.observation.convolve(model)
return model

def _grad_log_likelihood(self) -> Image:
def _grad_log_likelihood(self) -> tuple[Image, np.ndarray]:
"""Gradient of the likelihood wrt the unconvolved model"""
model = self.get_model(convolve=True)
# Update the loss
self.loss.append(self.observation.log_likelihood(model))
# Calculate the gradient wrt the model d(logL)/d(model)
result = self.observation.weights * (model - self.observation.images)
result = self.observation.convolve(result, grad=True)
return result
return result, model.data

@property
def log_likelihood(self) -> float:
@@ -244,7 +244,7 @@ def fit(
# Update each component given the current gradient
for component in self.components:
overlap = component.bbox & self.bbox
component.update(self.it, grad_log_likelihood[overlap].data)
component.update(self.it, grad_log_likelihood[0][overlap].data)
# Check to see if any components need to be resized
if do_resize:
component.resize(self.bbox)
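
Not part of the diff: a sketch of the new return contract, assuming an already constructed `Blend`. `_grad_log_likelihood` now returns both the gradient `Image` and the convolved model array, which is why the `fit()` loop above indexes `[0]` for the gradient.

```python
from lsst.scarlet.lite.blend import Blend


def apply_gradient_step(blend: Blend) -> None:
    # _grad_log_likelihood() now returns (gradient Image, model array);
    # the gradient is element [0], matching the update loop in fit().
    grad_log_likelihood = blend._grad_log_likelihood()
    for component in blend.components:
        overlap = component.bbox & blend.bbox
        component.update(blend.it, grad_log_likelihood[0][overlap].data)
```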
135 changes: 119 additions & 16 deletions python/lsst/scarlet/lite/detect.py
@@ -22,15 +22,20 @@
from __future__ import annotations

import logging
from typing import Sequence, cast
from typing import Sequence

import numpy as np
from lsst.scarlet.lite.detect_pybind11 import Footprint # type: ignore

from .bbox import Box
from .bbox import Box, overlapped_slices
from .image import Image
from .utils import continue_class
from .wavelet import get_multiresolution_support, starlet_transform
from .wavelet import (
get_multiresolution_support,
get_starlet_scales,
multiband_starlet_reconstruction,
starlet_transform,
)

logger = logging.getLogger("scarlet.detect")

@@ -111,30 +116,34 @@ def union(self, other: Footprint) -> Image | None:
return footprint1 | footprint2


def footprints_to_image(footprints: Sequence[Footprint], shape: tuple[int, int]) -> Image:
def footprints_to_image(footprints: Sequence[Footprint], bbox: Box) -> Image:
"""Convert a set of scarlet footprints to a pixelized image.
Parameters
----------
footprints:
The footprints to convert into an image.
shape:
The shape of the image that is created from the footprints.
bbox:
The full box of the image that will contain the footprints.
Returns
-------
result:
The image created from the footprints.
"""
result = Image.from_box(Box(shape), dtype=int)
result = Image.from_box(bbox, dtype=int)
for k, footprint in enumerate(footprints):
bbox = bounds_to_bbox(footprint.bounds)
fp_image = Image(footprint.data, yx0=cast(tuple[int, int], bbox.origin))
result = result + fp_image * (k + 1)
slices = overlapped_slices(result.bbox, footprint.bbox)
result.data[slices[0]] += footprint.data[slices[1]] * (k + 1)
return result
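
Not part of the diff: a minimal sketch of the call-site change implied by the new signature. Callers now hand in the full image `Box` instead of a bare shape, so footprints with non-zero origins land in the right place.

```python
from typing import Sequence

from lsst.scarlet.lite.detect import footprints_to_image
from lsst.scarlet.lite.detect_pybind11 import Footprint
from lsst.scarlet.lite.image import Image


def label_footprints(footprints: Sequence[Footprint], model: Image) -> Image:
    # Previously the second argument was a bare (Ny, Nx) shape;
    # it is now the full bounding box of the image.
    return footprints_to_image(footprints, model.bbox)
```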


def get_wavelets(images: np.ndarray, variance: np.ndarray, scales: int | None = None) -> np.ndarray:
def get_wavelets(
images: np.ndarray,
variance: np.ndarray,
scales: int | None = None,
generation: int = 2,
) -> np.ndarray:
"""Calculate wavelet coefficents given a set of images and their variances
Parameters
@@ -157,9 +166,10 @@ def get_wavelets(images: np.ndarray, variance: np.ndarray, scales: int | None =
"""
sigma = np.median(np.sqrt(variance), axis=(1, 2))
# Create the wavelet coefficients for the significant pixels
coeffs = []
scales = get_starlet_scales(images[0].shape, scales)
coeffs = np.empty((scales + 1,) + images.shape, dtype=images.dtype)
for b, image in enumerate(images):
_coeffs = starlet_transform(image, scales=scales)
_coeffs = starlet_transform(image, scales=scales, generation=generation)
support = get_multiresolution_support(
image=image,
starlets=_coeffs,
@@ -168,8 +178,8 @@ def get_wavelets(images: np.ndarray, variance: np.ndarray, scales: int | None =
epsilon=1e-1,
max_iter=20,
)
coeffs.append((support * _coeffs).astype(images.dtype))
return np.array(coeffs)
coeffs[:, b] = (support.support * _coeffs).astype(images.dtype)
return coeffs
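
Not part of the diff: a short sketch of the updated `get_wavelets` on synthetic data. The coefficients are now written into a preallocated cube of shape `(scales + 1, bands, Ny, Nx)`, and the starlet generation is configurable.

```python
import numpy as np

from lsst.scarlet.lite.detect import get_wavelets

rng = np.random.default_rng(0)
images = rng.normal(0, 1, (3, 64, 64)).astype(np.float32)
variance = np.ones_like(images)

coeffs = get_wavelets(images, variance, scales=2, generation=2)
# Preallocated as (scales + 1, bands, Ny, Nx); here that is (3, 3, 64, 64).
print(coeffs.shape)
```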


def get_detect_wavelets(images: np.ndarray, variance: np.ndarray, scales: int = 3) -> np.ndarray:
@@ -206,4 +216,97 @@ def get_detect_wavelets(images: np.ndarray, variance: np.ndarray, scales: int =
epsilon=1e-1,
max_iter=20,
)
return (support * _coeffs).astype(images.dtype)
return (support.support * _coeffs).astype(images.dtype)


def detect_footprints(
images: np.ndarray,
variance: np.ndarray,
scales: int = 2,
generation: int = 2,
origin: tuple[int, int] | None = None,
min_separation: float = 4,
min_area: int = 4,
peak_thresh: float = 5,
footprint_thresh: float = 5,
find_peaks: bool = True,
remove_high_freq: bool = True,
min_pixel_detect: int = 1,
) -> Sequence[Footprint]:
"""Detect footprints in an image
Parameters
----------
images:
The array of images with shape `(bands, Ny, Nx)` for which to
calculate wavelet coefficients.
variance:
An array of variances with the same shape as `images`.
scales:
The maximum number of wavelet scales to use.
If `remove_high_freq` is `False`, then this argument is ignored.
generation:
The generation of the starlet transform to use.
If `remove_high_freq` is `False`, then this argument is ignored.
origin:
The location (y, x) of the lower corner of the image.
min_separation:
The minimum separation between peaks in pixels.
min_area:
The minimum area of a footprint in pixels.
peak_thresh:
The threshold for peak detection.
footprint_thresh:
The threshold for footprint detection.
find_peaks:
If `True`, then detect peaks in the detection image,
otherwise only the footprints are returned.
remove_high_freq:
If `True`, then remove high frequency wavelet coefficients
before detecting peaks.
min_pixel_detect:
The minimum number of bands that must be above the
detection threshold for a pixel to be included in a footprint.
"""
from lsst.scarlet.lite.detect_pybind11 import get_footprints

if origin is None:
origin = (0, 0)
if remove_high_freq:
# Build the wavelet coefficients
wavelets = get_wavelets(
images,
variance,
scales=scales,
generation=generation,
)
# Remove the high frequency wavelets.
# This has the effect of preventing high frequency noise
# from interfering with the detection of peak positions.
wavelets[0] = 0
# Reconstruct the image from the remaining wavelet coefficients
_images = multiband_starlet_reconstruction(
wavelets,
generation=generation,
)
else:
_images = images
# Build a SNR weighted detection image
sigma = np.median(np.sqrt(variance), axis=(1, 2)) / 2
detection = np.sum(_images / sigma[:, None, None], axis=0)
if min_pixel_detect > 1:
mask = np.sum(images > 0, axis=0) >= min_pixel_detect
detection[~mask] = 0
# Detect peaks on the detection image
footprints = get_footprints(
detection,
min_separation,
min_area,
peak_thresh,
footprint_thresh,
find_peaks,
origin[0],
origin[1],
)

return footprints
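
Not part of the diff: an end-to-end sketch of the new entry point on synthetic data. The keyword values simply echo the defaults documented above, and `footprints_to_image` (updated earlier in this file) turns the result into an integer label map.

```python
import numpy as np

from lsst.scarlet.lite.bbox import Box
from lsst.scarlet.lite.detect import detect_footprints, footprints_to_image

# Two-band cube: unit-variance noise plus one bright extended source.
rng = np.random.default_rng(42)
images = rng.normal(0, 1, (2, 64, 64)).astype(np.float32)
images[:, 28:36, 28:36] += 20
variance = np.ones_like(images)

footprints = detect_footprints(
    images,
    variance,
    scales=2,
    generation=2,
    min_separation=4,
    min_area=4,
    peak_thresh=5,
    footprint_thresh=5,
    remove_high_freq=True,
    min_pixel_detect=1,
)

# 0 marks background; pixels of footprint k are labeled k + 1.
label_map = footprints_to_image(footprints, Box(images.shape[1:]))
print(len(footprints), int(label_map.data.max()))
```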