Skip to content

Commit

Permalink
switch band order to scarlet band order (#113)
Browse files Browse the repository at this point in the history
* switch band order to scarlet band order

* added tag to control band order

* took into account thomas suggestion to user friendly tag

* intro notebook does not need sys.path
  • Loading branch information
ismael-mendoza authored Mar 12, 2021
1 parent 20d1631 commit e34eb99
Show file tree
Hide file tree
Showing 9 changed files with 50 additions and 44 deletions.
26 changes: 20 additions & 6 deletions btk/draw_blends.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def __init__(
add_noise=True,
shifts=None,
indexes=None,
dim_order="NCHW",
):
"""Initializes the DrawBlendsGenerator class.
Expand All @@ -167,7 +168,11 @@ def __init__(
random shifts. Must be of length batch_size. Must be used
with indexes.
indexes (list): Contains the ids of the galaxies to use in the stamp.
Must be of length batch_size. Must be used with shifts."""
Must be of length batch_size. Must be used with shifts.
dim_order (str): Whether to return images as numpy arrays with the channel
(band) dimension before the pixel dimensions 'NCHW' (default) or
after 'NHWC'.
"""

self.blend_generator = BlendGenerator(
catalog, sampling_function, batch_size, shifts, indexes, verbose
Expand All @@ -191,6 +196,10 @@ def __init__(
self.add_noise = add_noise
self.verbose = verbose

if dim_order not in ("NCHW", "NHWC"):
raise ValueError("dim_order must be either 'NCHW' or 'NHWC'.")
self.dim_order = (0, 1, 2) if dim_order == "NCHW" else (1, 2, 0)

def __iter__(self):
return self

Expand Down Expand Up @@ -238,9 +247,9 @@ def __next__(self):
batch_blend_cat[s.name] = []
batch_obs_cond[s.name] = []
image_shape = (
len(s.filters),
pix_stamp_size,
pix_stamp_size,
len(s.filters),
)
blend_images[s.name] = np.zeros((self.batch_size, *image_shape))
isolated_images[s.name] = np.zeros((self.batch_size, self.max_number, *image_shape))
Expand Down Expand Up @@ -329,16 +338,21 @@ def render_mini_batch(self, blend_list, psf, wcs, survey):
iso_image_multi = np.zeros(
(
self.max_number,
len(survey.filters),
pix_stamp_size,
pix_stamp_size,
len(survey.filters),
)
)
blend_image_multi = np.zeros((pix_stamp_size, pix_stamp_size, len(survey.filters)))
blend_image_multi = np.zeros((len(survey.filters), pix_stamp_size, pix_stamp_size))
for b, filt in enumerate(survey.filters):
single_band_output = self.render_blend(blend, psf[b], filt, survey)
blend_image_multi[:, :, b] = single_band_output[0]
iso_image_multi[:, :, :, b] = single_band_output[1]
blend_image_multi[b, :, :] = single_band_output[0]
iso_image_multi[:, b, :, :] = single_band_output[1]

# transpose if requested.
dim_order = np.array(self.dim_order)
blend_image_multi = blend_image_multi.transpose(dim_order)
iso_image_multi = iso_image_multi.transpose(0, *(dim_order + 1))

outputs.append([blend_image_multi, iso_image_multi, blend])
return outputs
Expand Down
2 changes: 1 addition & 1 deletion btk/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def get_deblended_images(self, data, index):
dict with the centers of sources detected by SEP detection
algorithm.
"""
image = np.mean(data["blend_images"][index], axis=2)
image = np.mean(data["blend_images"][index], axis=0)
peaks = self.get_centers(image)
return {"deblend_image": None, "peaks": peaks}

Expand Down
10 changes: 5 additions & 5 deletions btk/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def plot_blends(blend_images, blend_list, detected_centers=None, limits=None, ba
Args:
blend_images (array_like): Array of blend scene images to plot
[batch, height, width, bands].
[batch, bands, height, width].
blend_list (list) : List of `astropy.table.Table` with entries of true
objects. Length of list must be the batch size.
detected_centers (list, default=`None`): List of `numpy.ndarray` or
Expand Down Expand Up @@ -112,7 +112,7 @@ def plot_blends(blend_images, blend_list, detected_centers=None, limits=None, ba
)
for i in range(batch_size):
num = len(blend_list[i])
images = np.transpose(blend_images[i], axes=(2, 0, 1))
images = blend_images[i]
blend_img_rgb = get_rgb_image(images[band_indices])
_, ax = plt.subplots(1, 3, figsize=(8, 3))
ax[0].imshow(blend_img_rgb)
Expand All @@ -121,7 +121,7 @@ def plot_blends(blend_images, blend_list, detected_centers=None, limits=None, ba
ax[0].set_ylim(limits)
ax[0].set_title("gri bands")
ax[0].axis("off")
ax[1].imshow(np.sum(blend_images[i, :, :, :], axis=2))
ax[1].imshow(np.sum(blend_images[i, :, :, :], axis=0))
ax[1].set_title("Sum")
if limits:
ax[1].set_xlim(limits)
Expand Down Expand Up @@ -194,7 +194,7 @@ def plot_with_isolated(
{len(detected_centers), len(blend_list), len(blend_images)}"
)
for i in range(len(blend_list)):
images = np.transpose(blend_images[i], axes=(2, 0, 1))
images = blend_images[i]
blend_img_rgb = get_rgb_image(
images[band_indices], normalize_with_image=images[band_indices]
)
Expand All @@ -212,7 +212,7 @@ def plot_with_isolated(
num = iso_blend.shape[0]
plt.figure(figsize=(2 * num, 2))
for j in range(num):
iso_images = np.transpose(iso_blend[j], axes=(2, 0, 1))
iso_images = iso_blend[j]
iso_img_rgb = get_rgb_image(
iso_images[band_indices], normalize_with_image=images[band_indices]
)
Expand Down
33 changes: 12 additions & 21 deletions notebooks/intro.ipynb

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ sep = "^1.1.1"
[tool.poetry.dev-dependencies]
Cython = "^0.29.21"
black = "^20.8b1"
flake8 = "^3.8.4"
flake8-absolute-import = "^1.0"
jupyter-sphinx = "^0.3"
mock = "^3.0.5"
Expand Down
6 changes: 3 additions & 3 deletions tests/test_draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def match_isolated_images_default(isolated_images):
test_batch_max = np.array([4772.817, 8506.056, 10329.56, 7636.189, 1245.693, 90.721])
test_batch_mean = 3.1101762559117585
test_batch_std = 90.74182140645624
batch_max = isolated_images.max(axis=0).max(axis=0).max(axis=0).max(axis=0)
batch_max = isolated_images.max(axis=(0, 1, 3, 4))
batch_mean = isolated_images.mean()
batch_std = isolated_images.std()
np.testing.assert_array_almost_equal(
Expand Down Expand Up @@ -109,7 +109,7 @@ def match_blend_images_default(blend_images):
test_batch_max = np.array([5428.147, 8947.227, 11190.504, 8011.935, 1536.116, 191.629])
test_batch_mean = 5.912076135028083
test_batch_std = 403.5577217178115
batch_max = blend_images.max(axis=0).max(axis=0).max(axis=0)
batch_max = blend_images.max(axis=(0, 2, 3))
batch_mean = blend_images.mean()
batch_std = blend_images.std()
np.testing.assert_array_almost_equal(
Expand Down Expand Up @@ -138,7 +138,7 @@ def match_background_noise(blend_images):
default input settings.
"""
test_batch_noise = 129660.6576538086
batch_noise = np.var(blend_images[1, 0:32, 0:32, 3])
batch_noise = np.var(blend_images[1, 3, 0:32, 0:32])
np.testing.assert_almost_equal(
batch_noise,
test_batch_noise,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_group_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_group_sampling():
draw_blend_generator = get_group_sampling_draw_generator()
output = next(draw_blend_generator)
blend_images = output["blend_images"]
batch_max = blend_images.max(axis=0).max(axis=0).max(axis=0)
batch_max = blend_images.max(axis=(0, 2, 3))
batch_mean = blend_images.mean()
batch_std = blend_images.std()
test_batch_max = np.array([17e3, 30e3, 45e3, 43e3, 13e3, 13e2])
Expand Down
4 changes: 2 additions & 2 deletions tests/test_mr.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ def test_multiresolution():

assert "LSST" in draw_output["blend_list"].keys(), "Both surveys get well defined outputs"
assert "HSC" in draw_output["blend_list"].keys(), "Both surveys get well defined outputs"
assert draw_output["blend_images"]["LSST"][0].shape[0] == int(
assert draw_output["blend_images"]["LSST"][0].shape[-1] == int(
24.0 / 0.2
), "LSST survey should have a pixel scale of 0.2"
assert draw_output["blend_images"]["HSC"][0].shape[0] == int(
assert draw_output["blend_images"]["HSC"][0].shape[-1] == int(
24.0 / 0.167
), "HSC survey should have a pixel scale of 0.167"

0 comments on commit e34eb99

Please sign in to comment.