Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tickets/DM-41840: Implement improved detection algorithms #7

Merged
merged 7 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[build-system]
requires = [
"setuptools<65",
"setuptools",
"lsst-versions >= 1.3.0",
"wheel",
"pybind11 >= 2.5.0",
Expand Down
2 changes: 1 addition & 1 deletion python/lsst/scarlet/lite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
except ImportError:
pass

from . import initialization, io, measure, models, operators, utils
from . import initialization, io, measure, models, operators, utils, wavelet
from .fft import *
from .image import *
from .observation import *
Expand Down
2 changes: 1 addition & 1 deletion python/lsst/scarlet/lite/bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def from_data(x: np.ndarray, threshold: float = 0) -> Box:
nonzero = np.where(sel)
bounds = []
for dim in range(len(x.shape)):
bounds.append((nonzero[dim].min(), nonzero[dim].max() + 1))
bounds.append((int(nonzero[dim].min()), int(nonzero[dim].max() + 1)))
else:
bounds = [(0, 0)] * len(x.shape)
return Box.from_bounds(*bounds)
Expand Down
18 changes: 14 additions & 4 deletions python/lsst/scarlet/lite/blend.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,25 @@ def get_model(self, convolve: bool = False, use_flux: bool = False) -> Image:
return self.observation.convolve(model)
return model

def _grad_log_likelihood(self) -> Image:
"""Gradient of the likelihood wrt the unconvolved model"""
def _grad_log_likelihood(self) -> tuple[Image, np.ndarray]:
"""Gradient of the likelihood wrt the unconvolved model

Returns
-------
result:
The gradient of the likelihood wrt the model
model_data:
The convol model data used to calculate the gradient.
This can be useful for debugging but is not used in
production.
"""
model = self.get_model(convolve=True)
# Update the loss
self.loss.append(self.observation.log_likelihood(model))
# Calculate the gradient wrt the model d(logL)/d(model)
result = self.observation.weights * (model - self.observation.images)
result = self.observation.convolve(result, grad=True)
return result
return result, model.data

@property
def log_likelihood(self) -> float:
Expand Down Expand Up @@ -244,7 +254,7 @@ def fit(
# Update each component given the current gradient
for component in self.components:
overlap = component.bbox & self.bbox
component.update(self.it, grad_log_likelihood[overlap].data)
component.update(self.it, grad_log_likelihood[0][overlap].data)
# Check to see if any components need to be resized
if do_resize:
component.resize(self.bbox)
Expand Down
136 changes: 119 additions & 17 deletions python/lsst/scarlet/lite/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,20 @@
from __future__ import annotations

import logging
from typing import Sequence, cast
from typing import Sequence

import numpy as np
from lsst.scarlet.lite.detect_pybind11 import Footprint # type: ignore
from lsst.scarlet.lite.detect_pybind11 import Footprint, get_footprints # type: ignore

from .bbox import Box
from .bbox import Box, overlapped_slices
from .image import Image
from .utils import continue_class
from .wavelet import get_multiresolution_support, starlet_transform
from .wavelet import (
get_multiresolution_support,
get_starlet_scales,
multiband_starlet_reconstruction,
starlet_transform,
)

logger = logging.getLogger("scarlet.detect")

Expand Down Expand Up @@ -111,30 +116,34 @@ def union(self, other: Footprint) -> Image | None:
return footprint1 | footprint2


def footprints_to_image(footprints: Sequence[Footprint], shape: tuple[int, int]) -> Image:
def footprints_to_image(footprints: Sequence[Footprint], bbox: Box) -> Image:
"""Convert a set of scarlet footprints to a pixelized image.

Parameters
----------
footprints:
The footprints to convert into an image.
shape:
The shape of the image that is created from the footprints.
box:
The full box of the image that will contain the footprints.

Returns
-------
result:
The image created from the footprints.
"""
result = Image.from_box(Box(shape), dtype=int)
result = Image.from_box(bbox, dtype=int)
for k, footprint in enumerate(footprints):
bbox = bounds_to_bbox(footprint.bounds)
fp_image = Image(footprint.data, yx0=cast(tuple[int, int], bbox.origin))
result = result + fp_image * (k + 1)
slices = overlapped_slices(result.bbox, footprint.bbox)
result.data[slices[0]] += footprint.data[slices[1]] * (k + 1)
return result


def get_wavelets(images: np.ndarray, variance: np.ndarray, scales: int | None = None) -> np.ndarray:
def get_wavelets(
images: np.ndarray,
variance: np.ndarray,
scales: int | None = None,
generation: int = 2,
) -> np.ndarray:
"""Calculate wavelet coefficents given a set of images and their variances

Parameters
Expand All @@ -157,9 +166,10 @@ def get_wavelets(images: np.ndarray, variance: np.ndarray, scales: int | None =
"""
sigma = np.median(np.sqrt(variance), axis=(1, 2))
# Create the wavelet coefficients for the significant pixels
coeffs = []
scales = get_starlet_scales(images[0].shape, scales)
coeffs = np.empty((scales + 1,) + images.shape, dtype=images.dtype)
for b, image in enumerate(images):
_coeffs = starlet_transform(image, scales=scales)
_coeffs = starlet_transform(image, scales=scales, generation=generation)
support = get_multiresolution_support(
image=image,
starlets=_coeffs,
Expand All @@ -168,8 +178,8 @@ def get_wavelets(images: np.ndarray, variance: np.ndarray, scales: int | None =
epsilon=1e-1,
max_iter=20,
)
coeffs.append((support * _coeffs).astype(images.dtype))
return np.array(coeffs)
coeffs[:, b] = (support.support * _coeffs).astype(images.dtype)
return coeffs


def get_detect_wavelets(images: np.ndarray, variance: np.ndarray, scales: int = 3) -> np.ndarray:
Expand Down Expand Up @@ -206,4 +216,96 @@ def get_detect_wavelets(images: np.ndarray, variance: np.ndarray, scales: int =
epsilon=1e-1,
max_iter=20,
)
return (support * _coeffs).astype(images.dtype)
return (support.support * _coeffs).astype(images.dtype)


def detect_footprints(
images: np.ndarray,
variance: np.ndarray,
scales: int = 2,
generation: int = 2,
origin: tuple[int, int] | None = None,
min_separation: float = 4,
min_area: int = 4,
peak_thresh: float = 5,
footprint_thresh: float = 5,
find_peaks: bool = True,
remove_high_freq: bool = True,
min_pixel_detect: int = 1,
) -> list[Footprint]:
"""Detect footprints in an image

Parameters
----------
images:
The array of images with shape `(bands, Ny, Nx)` for which to
calculate wavelet coefficients.
variance:
An array of variances with the same shape as `images`.
scales:
The maximum number of wavelet scales to use.
If `remove_high_freq` is `False`, then this argument is ignored.
generation:
The generation of the starlet transform to use.
If `remove_high_freq` is `False`, then this argument is ignored.
origin:
The location (y, x) of the lower corner of the image.
min_separation:
The minimum separation between peaks in pixels.
min_area:
The minimum area of a footprint in pixels.
peak_thresh:
The threshold for peak detection.
footprint_thresh:
The threshold for footprint detection.
find_peaks:
If `True`, then detect peaks in the detection image,
otherwise only the footprints are returned.
remove_high_freq:
If `True`, then remove high frequency wavelet coefficients
before detecting peaks.
min_pixel_detect:
The minimum number of bands that must be above the
detection threshold for a pixel to be included in a footprint.
"""

if origin is None:
origin = (0, 0)
if remove_high_freq:
# Build the wavelet coefficients
wavelets = get_wavelets(
images,
variance,
scales=scales,
generation=generation,
)
# Remove the high frequency wavelets.
# This has the effect of preventing high frequency noise
# from interfering with the detection of peak positions.
wavelets[0] = 0
# Reconstruct the image from the remaining wavelet coefficients
_images = multiband_starlet_reconstruction(
wavelets,
generation=generation,
)
else:
_images = images
# Build a SNR weighted detection image
sigma = np.median(np.sqrt(variance), axis=(1, 2)) / 2
detection = np.sum(_images / sigma[:, None, None], axis=0)
if min_pixel_detect > 1:
mask = np.sum(images > 0, axis=0) >= min_pixel_detect
detection[~mask] = 0
# Detect peaks on the detection image
footprints = get_footprints(
detection,
min_separation,
min_area,
peak_thresh,
footprint_thresh,
find_peaks,
origin[0],
origin[1],
)

return footprints
Loading
Loading