Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #490: make rescale_imagehdu more robust against dimension mismatches #503

Merged
merged 2 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 15 additions & 25 deletions scopesim/optics/image_plane_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,19 +487,22 @@ def rescale_imagehdu(imagehdu: fits.ImageHDU, pixel_scale: float | u.Quantity,
primary_wcs = WCS(imagehdu.header, key=wcs_suffix[0])

# make sure that units are correct and zoom factor is positive
# The length of the zoom factor will be determined by imagehdu.data,
# which might differ from the dimension of primary_wcs. Here, pick
# the spatial dimensions only.
pixel_scale = pixel_scale << u.Unit(primary_wcs.wcs.cunit[0])
zoom = np.abs(primary_wcs.wcs.cdelt / pixel_scale.value)
zoom = np.abs(primary_wcs.wcs.cdelt[:2] / pixel_scale.value)

if len(imagehdu.data.shape) == 3:
zoom = np.append(zoom, [1.]) # wavelength dimension unscaled if present

logger.debug("zoom factor: %s", zoom)

if primary_wcs.naxis == 3:
# zoom = np.append(zoom, [1])
zoom[2] = 1.
if primary_wcs.naxis != imagehdu.data.ndim:
# FIXME: this happens often - shouldn't WCSs be trimmed down before? (OC)
logger.warning("imagehdu.data.ndim is %d, but primary_wcs.naxis with "
"key %s is %d, both should be equal.",
imagehdu.data.ndim, wcs_suffix, primary_wcs.naxis)
zoom = zoom[:2]

logger.debug("zoom %s", zoom)
"key %s is %d, both should be equal.",
imagehdu.data.ndim, wcs_suffix, primary_wcs.naxis)

if all(zoom == 1.):
# Nothing to do
Expand All @@ -525,28 +528,15 @@ def rescale_imagehdu(imagehdu: fits.ImageHDU, pixel_scale: float | u.Quantity,
logger.warning("imagehdu.data.ndim is %d, but wcs.naxis with key "
"%s is %d, both should be equal.",
imagehdu.data.ndim, ww.wcs.alt, ww.naxis)
# TODO: could this be ww = ww.sub(2) instead? or .celestial?
# ww = WCS(imagehdu.header, key=key, naxis=imagehdu.data.ndim)

if any(ctype != "LINEAR" for ctype in ww.wcs.ctype):
logger.warning("Non-linear WCS rescaled using linear procedure.")

new_crpix = (zoom + 1) / 2 + (ww.wcs.crpix - 1) * zoom
#ew_crpix = np.round(new_crpix * 2) / 2 # round to nearest half-pixel
logger.debug("new crpix %s", new_crpix)
ww.wcs.crpix = new_crpix
ww.wcs.crpix[:2] = (zoom[:2] + 1) / 2 + (ww.wcs.crpix[:2] - 1) * zoom[:2]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had some trouble understanding where this equation came from. Maybe we can add the derivation as a comment to the code?

Suggested change
ww.wcs.crpix[:2] = (zoom[:2] + 1) / 2 + (ww.wcs.crpix[:2] - 1) * zoom[:2]
# In the old coordinate system the following holds (assuming linearity):
# VAL = CRVAL + (PIX - CRPIX) * CDELT
# Denoting the zoomed coordinate system with primes we get:
# VAL' = CRVAL' + (PIX' - CRPIX') * CDELT'
# Where by definition CDELT' = CDELT / ZOOM, and CRVAL' = CRVAL.
#
# There is always a fixed point in such a transformation.
# All values refer to the center of the pixels, so it is the
# edge of the first pixel whose VAL is conserved.
# That is, for PIX = 1/2, PIX' = 1/2, and VAL' = VAL.
#
# Filling the above two equations with the values for the
# edge of the first pixel yields:
# VAL = CRVAL + (1/2 - CRPIX) * CDELT
# VAL = CRVAL + (1/2 - CRPIX') * CDELT / ZOOM
# Equating these equations allows us to solve for CRPIX':
# (1/2 - CRPIX) * CDELT = (1/2 - CRPIX') * CDELT / ZOOM
# (1/2 - CRPIX) * ZOOM = 1/2 - CRPIX'
# CRPIX' = 1/2 - (1/2 - CRPIX) * ZOOM
# CRPIX' = 1/2 + (CRPIX - 1/2) * ZOOM
# CRPIX' = 1/2 + CRPIX * ZOOM - ZOOM / 2
# Which is (effectively) the same as the equation below.
ww.wcs.crpix[:2] = (zoom[:2] + 1) / 2 + (ww.wcs.crpix[:2] - 1) * zoom[:2]

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added the derivation in a somewhat shorter form and a more rational form of the equation (your line 554).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I found my derivation a bit too verbose too.

I was wondering whether there is a way to just let astropy.wcs do this for us. But that isn't really relevant now we already have this nicely working.

logger.debug("new crpix %s", ww.wcs.crpix)

# Keep CDELT3 if cube...
new_cdelt = ww.wcs.cdelt[:]
new_cdelt /= zoom
ww.wcs.cdelt = new_cdelt

# TODO: is forcing deg here really the best way?
# FIXME: NO THIS WILL MESS UP IF new_cdelt IS IN ARCSEC!!!!!
# new_cunit = [str(cunit) for cunit in ww.wcs.cunit]
# new_cunit[0] = "mm" if key == "D" else "deg"
# new_cunit[1] = "mm" if key == "D" else "deg"
# ww.wcs.cunit = new_cunit
ww.wcs.cdelt[:2] /= zoom[:2]

imagehdu.header.update(ww.to_header())

Expand Down
27 changes: 27 additions & 0 deletions scopesim/tests/mocks/py_objects/imagehdu_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,30 @@ def _image_hdu_three_wcs():
hdu.header.update(wcs_g.to_header())

return hdu

def _image_hdu_3d_data():
nx, ny = 100, 100
nz = 3

# a 3D WCS
the_wcs0 = wcs.WCS(naxis=3, key="")
the_wcs0.wcs.ctype = ["LINEAR", "LINEAR", "WAVE"]
the_wcs0.wcs.cunit = ["arcsec", "arcsec", "um"]
the_wcs0.wcs.cdelt = [1, 1, 0.1]
the_wcs0.wcs.crval = [0, 0, 2.2]
the_wcs0.wcs.crpix = [(nx + 1) / 2, (ny + 1) / 2, 1]

# a 2D WCS for spatial dimensions
the_wcsd = wcs.WCS(naxis=2, key="D")
the_wcsd.wcs.ctype = ["LINEAR", "LINEAR"]
the_wcsd.wcs.cunit = ["mm", "mm"]
the_wcsd.wcs.cdelt = [1, 1]
the_wcsd.wcs.crval = [0, 0]
the_wcsd.wcs.crpix = [(nx + 1) / 2, (ny + 1) / 2]

image = np.ones((nz, ny, nx))
hdr = the_wcs0.to_header()
hdr.extend(the_wcsd.to_header())
hdu = fits.ImageHDU(data=image, header=hdr)

return hdu
72 changes: 49 additions & 23 deletions scopesim/tests/tests_optics/test_ImagePlane.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,64 @@
"""Tests for ImagePlane and some ImagePlaneUtils"""

# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

from copy import deepcopy

import pytest
from pytest import approx
from copy import deepcopy

import numpy as np
from astropy.io import fits
from astropy import units as u
from astropy.table import Table
from astropy import wcs

import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

import scopesim.optics.image_plane as opt_imp
import scopesim.optics.image_plane_utils as imp_utils

from scopesim.tests.mocks.py_objects.imagehdu_objects import \
_image_hdu_square, _image_hdu_rect, _image_hdu_three_wcs

import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

_image_hdu_square, _image_hdu_rect, _image_hdu_three_wcs,\
_image_hdu_3d_data

PLOTS = False


@pytest.fixture(scope="function")
def image_hdu_rect():
@pytest.fixture(scope="function", name="image_hdu_rect")
def fixture_image_hdu_rect():
return _image_hdu_rect()


@pytest.fixture(scope="function")
def image_hdu_rect_mm():
@pytest.fixture(scope="function", name="image_hdu_rect_mm")
def fixture_image_hdu_rect_mm():
return _image_hdu_rect("D")


@pytest.fixture(scope="function")
def image_hdu_square():
@pytest.fixture(scope="function", name="image_hdu_square")
def fixture_image_hdu_square():
return _image_hdu_square()


@pytest.fixture(scope="function")
def image_hdu_square_mm():
@pytest.fixture(scope="function", name="image_hdu_square_mm")
def fixture_image_hdu_square_mm():
return _image_hdu_square("D")

@pytest.fixture(scope="function")
def image_hdu_three_wcs():

@pytest.fixture(scope="function", name="image_hdu_three_wcs")
def fixture_image_hdu_three_wcs():
return _image_hdu_three_wcs()

@pytest.fixture(scope="function")
def input_table():

@pytest.fixture(scope="function", name="image_hdu_3d_data")
def fixture_image_hdu_3d_data():
return _image_hdu_3d_data()


@pytest.fixture(scope="function", name="input_table")
def fixture_input_table():
x = [-10, -10, 0, 10, 10] * u.arcsec
y = [-10, 10, 0, -10, 10] * u.arcsec
f = [1, 3, 1, 1, 5]
Expand All @@ -54,8 +67,8 @@ def input_table():
return tbl


@pytest.fixture(scope="function")
def input_table_mm():
@pytest.fixture(scope="function", name="input_table_mm")
def fixture_input_table_mm():
x = [-10, -10, 0, 10, 10] * u.mm
y = [-10, 10, 0, -10, 10] * u.mm
f = [1, 3, 1, 1, 5]
Expand Down Expand Up @@ -312,7 +325,7 @@ def test_points_are_added_to_small_canvas(self, input_table):
assert np.sum(canvas_hdu.data) == np.sum(tbl1["flux"])

if PLOTS:
"top left is green, top right is yellow"
# "top left is green, top right is yellow"
plt.imshow(canvas_hdu.data, origin="lower")
plt.show()

Expand All @@ -328,7 +341,7 @@ def test_mm_points_are_added_to_small_canvas(self, input_table_mm):
assert np.sum(canvas_hdu.data) == np.sum(tbl1["flux"])

if PLOTS:
"top left is green, top right is yellow"
# "top left is green, top right is yellow"
plt.imshow(canvas_hdu.data, origin="lower")
plt.show()

Expand Down Expand Up @@ -387,7 +400,7 @@ def test_mm_points_are_added_to_massive_canvas(self, input_table_mm):
if PLOTS:
x, y = imp_utils.val2pix(hdr, 0, 0, "D")
plt.plot(x, y, "ro")
"top left is green, top right is yellow"
# "top left is green, top right is yellow"
plt.imshow(canvas_hdu.data, origin="lower")
plt.show()

Expand Down Expand Up @@ -701,6 +714,19 @@ def test_rescale_works_on_nondefault_wcs(self, image_hdu_three_wcs):
assert new_hdu.header['CDELT1D'] == 20


def test_rescale_works_on_3d_imageplane(self, image_hdu_3d_data):
pixel_scale = 0.274
wcses = wcs.find_all_wcs(image_hdu_3d_data.header)
fact = pixel_scale / wcses[0].wcs.cdelt[0]

new_hdu = imp_utils.rescale_imagehdu(image_hdu_3d_data, pixel_scale)
new_wcses = wcs.find_all_wcs(new_hdu.header)

assert new_wcses[0].wcs.cdelt[0] == pixel_scale
assert new_wcses[0].wcs.cdelt[2] == wcses[0].wcs.cdelt[2]
assert new_wcses[1].wcs.cdelt[1] / fact == approx(wcses[1].wcs.cdelt[1])


###############################################################################
# ..todo: When you have time, reintegrate these tests, There are some good ones

Expand Down