Merge branch 'dev' into add-01-tutorial-notebooks
ismael-mendoza authored Apr 30, 2024
2 parents 70c2759 + 8939751 commit 8ac8484
Showing 8 changed files with 307 additions and 168 deletions.
7 changes: 7 additions & 0 deletions btk/deblend.py
@@ -461,6 +461,13 @@ def __init__(
of the argument also see the Scarlet API at:
https://pmelchior.github.io/scarlet/api/scarlet.initialization.html.
Note that as of commit 45187fd, Scarlet raises a `LinAlg` error if two sources land on
the same pixel, a configuration that most of the currently implemented sampling
functions in BTK allow. To work around this, our Deblender implementation automatically
catches this exception and outputs an array of zeroes for the deblended images of the
blend that raised it. See this issue for details:
https://github.com/pmelchior/scarlet/issues/282#issuecomment-2074886534
Args:
max_n_sources: See parent class.
thresh: Multiple of the background RMS used as a flux cutoff for morphology
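The docstring note above describes the workaround only in prose. Below is a minimal sketch of the pattern it refers to, assuming the error in question is numpy's `LinAlgError` (the docstring only says `LinAlg`) and using a hypothetical `deblend_fn` callable rather than BTK's actual internals:

import numpy as np

def deblend_or_zeros(deblend_fn, blend_image, max_n_sources, n_bands, img_size):
    """Run a deblending callable, returning zero-filled images if it fails."""
    try:
        return deblend_fn(blend_image)
    except np.linalg.LinAlgError:
        # Two sources on the same pixel make the initialization singular;
        # return zeroes so every blend in the batch keeps a fixed output shape.
        return np.zeros((max_n_sources, n_bands, img_size, img_size))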
4 changes: 2 additions & 2 deletions btk/measure.py
@@ -61,7 +61,7 @@ def get_blendedness(iso_image: np.ndarray) -> np.ndarray:
Args:
iso_image: Array of shape = (..., N, H, W) corresponding to images of isolated
-            galaxiesi you are calculating blendedness for.
+            galaxies you are calculating blendedness for.
Returns:
Array of size (..., N) corresponding to blendedness values for each individual galaxy.
@@ -70,7 +70,7 @@ def get_blendedness(iso_image: np.ndarray) -> np.ndarray:
num = np.sum(iso_image * iso_image, axis=(-1, -2))
blend = np.sum(iso_image, axis=-3)[..., None, :, :]
denom = np.sum(blend * iso_image, axis=(-1, -2))
-    return 1 - np.divide(num, denom, out=np.zeros_like(num), where=(num != 0))
+    return 1 - np.divide(num, denom, out=np.ones_like(num), where=(num != 0))


def get_snr(iso_image: np.ndarray, sky_level: float) -> np.ndarray:
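The one-word change above (zeros_like to ones_like) matters for padded slots whose isolated image is all zeros. A minimal sketch of that edge case, using the same num/denom construction as get_blendedness:

import numpy as np

# Two "isolated image" slots: one real source, one all-zero padded slot.
iso = np.zeros((2, 4, 4))
iso[0, 1:3, 1:3] = 1.0

num = np.sum(iso * iso, axis=(-1, -2))      # array([4., 0.])
blend = np.sum(iso, axis=-3)[..., None, :, :]
denom = np.sum(blend * iso, axis=(-1, -2))  # array([4., 0.])

# Where num == 0, np.divide leaves the `out` value untouched, so initializing
# `out` to ones yields 1 - 1 = 0 for the empty slot instead of 1 - 0 = 1.
blendedness = 1 - np.divide(num, denom, out=np.ones_like(num), where=(num != 0))
# -> array([0., 0.]): both the unblended source and the empty slot report zero.

This is the behaviour the new test_measure.py below asserts for slots with no galaxy.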
55 changes: 55 additions & 0 deletions tests/test_cosmos.py
@@ -0,0 +1,55 @@
import btk
from btk.survey import Survey

SEED = 0


def test_cosmos_generator(data_dir):
"""Test the pipeline as a whole for a single deblender."""
cosmos_catalog_paths = [
data_dir / "cosmos" / "real_galaxy_catalog_23.5_example.fits",
data_dir / "cosmos" / "real_galaxy_catalog_23.5_example_fits.fits",
]
cosmos_catalog_files = [p.as_posix() for p in cosmos_catalog_paths]
catalog = btk.catalog.CosmosCatalog.from_file(cosmos_catalog_files)

_ = catalog.get_raw_catalog()

survey: Survey = btk.survey.get_surveys("LSST")
fltr = survey.get_filter("r")
assert hasattr(fltr, "psf")

stamp_size = 24.0
max_shift = 1.0
max_n_sources = 2
sampling_function = btk.sampling_functions.DefaultSampling(
max_number=max_n_sources,
min_number=1,
stamp_size=stamp_size,
max_shift=max_shift,
min_mag=20,
max_mag=21,
seed=SEED,
mag_name="MAG",
)

batch_size = 10

draw_generator = btk.draw_blends.CosmosGenerator(
catalog,
sampling_function,
survey,
batch_size=batch_size,
stamp_size=stamp_size,
njobs=1,
add_noise="all",
seed=SEED,
gal_type="real",
)

# generate one batch of blend catalogs and images.
blend_batch = next(draw_generator)
assert len(blend_batch.catalog_list) == batch_size
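    # LSST's pixel scale is 0.2 arcsec per pixel, so a 24 arcsec stamp spans 120 pixels per side.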
assert blend_batch.blend_images.shape == (batch_size, 6, stamp_size / 0.2, stamp_size / 0.2)
iso_shape = (batch_size, max_n_sources, 6, stamp_size / 0.2, stamp_size / 0.2)
assert blend_batch.isolated_images.shape == iso_shape
97 changes: 97 additions & 0 deletions tests/test_deblenders.py
@@ -0,0 +1,97 @@
import numpy as np

import btk
from btk.survey import Survey

SEED = 0


def test_sep(data_dir):
"""Check we always detect single bright objects."""

catalog_file = data_dir / "input_catalog.fits"
catalog = btk.catalog.CatsimCatalog.from_file(catalog_file)
survey: Survey = btk.survey.get_surveys("LSST")

# single bright galaxy
sampling_function = btk.sampling_functions.DefaultSampling(
max_number=1,
min_number=1,
stamp_size=24.0,
max_shift=1.0,
min_mag=0,
max_mag=21,
seed=SEED,
)

assert np.sum((catalog.table["i_ab"] > 0) & (catalog.table["i_ab"] < 21)) > 100

batch_size = 100

draw_generator = btk.draw_blends.CatsimGenerator(
catalog,
sampling_function,
survey,
batch_size=batch_size,
stamp_size=24.0,
njobs=1,
add_noise="all",
seed=SEED,
)

blend_batch = next(draw_generator)
deblender = btk.deblend.SepSingleBand(max_n_sources=1, thresh=3, use_band=2)
deblend_batch = deblender(blend_batch, njobs=1)

matcher = btk.match.PixelHungarianMatcher(pixel_max_sep=5.0)

true_catalog_list = blend_batch.catalog_list
pred_catalog_list = deblend_batch.catalog_list
matching = matcher(true_catalog_list, pred_catalog_list) # matching object
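    # tp, t, p: true-positive, true, and predicted detection counts (as the attribute names suggest)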
tp, t, p = matching.tp, matching.t, matching.p

recall = btk.metrics.detection.Recall(batch_size)
precision = btk.metrics.detection.Precision(batch_size)

assert recall(tp, t, p) > 0.95
assert precision(tp, t, p) > 0.95


def test_scarlet(data_dir):
"""Check scarlet deblender implementation runs without too many failures."""

max_n_sources = 3
stamp_size = 24.0
seed = 0
max_shift = 2.0 # shift is only 2 arcsecs -> 10 pixels, so blends are likely.

catalog = btk.catalog.CatsimCatalog.from_file(data_dir / "input_catalog.fits")
sampling_function = btk.sampling_functions.DefaultSampling(
max_number=max_n_sources,
min_number=max_n_sources, # always 3 sources in every blend.
stamp_size=stamp_size,
max_shift=max_shift,
min_mag=24,
max_mag=25,
seed=seed,
)
LSST = btk.survey.get_surveys("LSST")

batch_size = 10

draw_generator = btk.draw_blends.CatsimGenerator(
catalog,
sampling_function,
LSST,
batch_size=batch_size,
stamp_size=stamp_size,
njobs=1,
add_noise="all",
seed=seed, # use same seed here
)

blend_batch = next(draw_generator)
deblender = btk.deblend.Scarlet(max_n_sources)
deblend_batch = deblender(blend_batch, reference_catalogs=blend_batch.catalog_list)
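    # Per the deblend.py docstring above, a blend that triggers Scarlet's LinAlg error is
    # returned with an empty catalog (and zeroed images), so empty catalogs count as failures.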
n_failures = np.sum([len(cat) == 0 for cat in deblend_batch.catalog_list], axis=0)
assert n_failures <= 3
33 changes: 33 additions & 0 deletions tests/test_matching.py
@@ -0,0 +1,33 @@
from astropy.table import Table

from btk.match import PixelHungarianMatcher


def test_matching():
x1 = [12.0, 31.0]
y1 = [10.0, 30.0]
x2 = [34.0, 12.1, 20.1]
y2 = [33.0, 10.1, 22.0]

t1 = Table()
t1["x_peak"] = x1
t1["y_peak"] = y1

t2 = Table()
t2["x_peak"] = x2
t2["y_peak"] = y2

catalog_list1 = [t1]
catalog_list2 = [t2]

matcher1 = PixelHungarianMatcher(pixel_max_sep=1)

match = matcher1(catalog_list1, catalog_list2)
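    # Only the pair (12, 10) <-> (12.1, 10.1) lies within pixel_max_sep = 1; the other
    # sources are several pixels apart, so exactly one match is expected.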

assert match.n_true == 2
assert match.n_pred == 3
assert match.tp == 1
assert match.fp == 2

assert match.true_matches == [[0]]
assert match.pred_matches == [[1]]
87 changes: 87 additions & 0 deletions tests/test_measure.py
@@ -0,0 +1,87 @@
"""Test measure functions run on simple outputs from generator and deblenders."""

import numpy as np
from galcheat.utilities import mean_sky_level

import btk
from btk.measure import get_aperture_fluxes, get_blendedness, get_ksb_ellipticity, get_snr
from btk.survey import Survey

SEED = 0


def test_measure(data_dir):
catalog_file = data_dir / "input_catalog.fits"
catalog = btk.catalog.CatsimCatalog.from_file(catalog_file)

_ = catalog.get_raw_catalog()

survey: Survey = btk.survey.get_surveys("LSST")
fltr = survey.get_filter("r")
assert hasattr(fltr, "psf")

stamp_size = 24.0
max_shift = 2.0
max_n_sources = 4
sampling_function = btk.sampling_functions.DefaultSampling(
max_number=max_n_sources,
min_number=1,
stamp_size=stamp_size,
max_shift=max_shift,
min_mag=20,
max_mag=21,
seed=SEED,
)

batch_size = 10

draw_generator = btk.draw_blends.CatsimGenerator(
catalog,
sampling_function,
survey,
batch_size=batch_size,
stamp_size=stamp_size,
njobs=1,
add_noise="all",
seed=SEED,
)

batch = next(draw_generator)
sky_level = mean_sky_level(survey, survey.get_filter("r")).to_value("electron")

# combine all centroids
xs_peak = np.zeros((batch_size, max_n_sources))
ys_peak = np.zeros((batch_size, max_n_sources))
for ii, t in enumerate(batch.catalog_list):
n_sources = len(t["x_peak"])
xs_peak[ii, :n_sources] = t["x_peak"].value
ys_peak[ii, :n_sources] = t["y_peak"].value

# aperture photometry
fluxes, fluxerr = get_aperture_fluxes(batch.blend_images[:, 2], xs_peak, ys_peak, 5, sky_level)
assert fluxes.shape == (batch_size, max_n_sources)
assert fluxerr.shape == (batch_size, max_n_sources)

# blendedness
blendedness = get_blendedness(batch.isolated_images[:, :, 2])
assert blendedness.shape == (batch_size, max_n_sources)
assert np.all(np.less_equal(blendedness, 1)) and np.all(np.greater_equal(blendedness, 0.0))

# snr
snr = get_snr(batch.isolated_images[:, :, 2], sky_level)
assert snr.shape == (batch_size, max_n_sources)
assert np.all(np.greater_equal(snr, 0))

# ellipticity
ellips = get_ksb_ellipticity(batch.isolated_images[:, :, 2], batch.psf[2], 0.2)
assert ellips.shape == (batch_size, max_n_sources, 2)

# padded slots with no galaxy should give zero SNR and blendedness, and NaN ellipticities
for ii in range(batch_size):
n_sources = len(batch.catalog_list[ii])
for jj in range(max_n_sources):
if jj >= n_sources:
assert snr[ii, jj] == 0
assert np.all(np.isnan(ellips[ii, jj]))
assert blendedness[ii, jj] == 0
