Skip to content

Commit

Permalink
Merge pull request #224 from gbrammer/set-constants
Browse files Browse the repository at this point in the history
Move some global variables from utils to a new `constants.py` file
  • Loading branch information
gbrammer authored May 22, 2024
2 parents a1d229d + 0525d88 commit c69829d
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 33 deletions.
26 changes: 26 additions & 0 deletions grizli/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import numpy as np
import astropy.units as u

KMS = u.km/u.s
FLAMBDA_CGS = u.erg/u.s/u.cm**2/u.angstrom
FNU_CGS = u.erg/u.s/u.cm**2/u.Hz

# Filter footprints
PLUS_FOOTPRINT = np.array([[0,1,0], [1,0,1], [0,1,0]]) > 0
CORNER_FOOTPRINT = (~PLUS_FOOTPRINT)
CORNER_FOOTPRINT[1,1] = False

JWST_DQ_FLAGS = [
"DO_NOT_USE",
"OTHER_BAD_PIXEL",
"UNRELIABLE_SLOPE",
"UNRELIABLE_BIAS",
"NO_SAT_CHECK",
"NO_GAIN_VALUE",
"HOT",
"WARM",
"DEAD",
"RC",
"LOW_QE",
]

126 changes: 98 additions & 28 deletions grizli/jwst_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
import logging
import traceback

import numpy as np

import astropy.io.fits as pyfits
import astropy.wcs as pywcs

import numpy as np
from . import utils
from . import GRIZLI_PATH

Expand All @@ -26,20 +27,7 @@
# PP file WCS
DO_PURE_PARALLEL_WCS = True

JWST_DQ_FLAGS = [
"DO_NOT_USE",
"OTHER_BAD_PIXEL",
"UNRELIABLE_SLOPE",
"UNRELIABLE_BIAS",
"NO_SAT_CHECK",
"NO_GAIN_VALUE",
"HOT",
"WARM",
"DEAD",
"RC",
"LOW_QE",
]

from .constants import JWST_DQ_FLAGS, PLUS_FOOTPRINT, CORNER_FOOTPRINT

def set_crds_context(fits_file=None, override_environ=False, verbose=True):
"""
Expand Down Expand Up @@ -3545,7 +3533,80 @@ def time_corr_photom_copy(param, t):
return corr


def flag_nircam_hot_pixels(data='jw01837039001_02201_00001_nrcblong_rate.fits', hot_threshold=7, max_filter_size=3, hot_filter_sn_max=5, plus_sn_min=4, corner_sn_max=3, jwst_dq_flags=JWST_DQ_FLAGS, verbose=True):
def flag_nirspec_hot_pixels(data='jw02073008001_03101_00002_nrs2_rate.fits', rnoise_percentile=90, rnoise_threshold=16, max_filter_size=3, hot_filter_sn_max=-3, corner_sn_max = -2, jwst_dq_flags=JWST_DQ_FLAGS, dilate_footprint=np.ones((3,3))):
"""
Flag NIRSpec MOS hot pixels
Parameters
----------
rnoise_percentile : float
Percentile of rnoise array for the absolute threshold
rnoise_threshold : float
The absolute ``hot_threshold`` is
``percentile(ERR_RNOISE, rnoise_percentile) * rnoise_threshold``
max_filter_size, hot_filter_sn_max, corner_sn_max, jwst_dq_flags : int, float, float
See `~grizli.jwst_utils.flag_nircam_hot_pixels`
dilate_footprint : array-like
Footprint for binary dilation on the dq mask
Returns
-------
sn : array-like
S/N array derived from ``file``
dq : array-like, int
Flagged pixels where ``hot = HOT`` and ``plus = WARM``
count : int
Number of flagged pixels
"""
import scipy.ndimage as nd
from jwst.datamodels.mask import pixel as pixel_codes

if isinstance(data, str):
is_open = True
rate = pyfits.open(data)
else:
rate = data
is_open = False

bits = get_jwst_dq_bit(jwst_dq_flags)

mask = (rate['DQ'].data & bits > 0) | (rate['ERR'].data <= 0)
mask |= (rate['SCI'].data < -3*rate['ERR'].data) | (~np.isfinite(rate['SCI'].data))

pval = np.nanpercentile(np.sqrt(rate['VAR_RNOISE'].data[mask]), rnoise_percentile)
hot_threshold = pval * rnoise_threshold

sn, dq_flag, count = flag_nircam_hot_pixels(
data=rate,
err_extension='DATA',
hot_threshold=hot_threshold,
max_filter_size=max_filter_size,
hot_filter_sn_max=hot_filter_sn_max,
plus_sn_min=hot_threshold,
corner_sn_max=corner_sn_max,
jwst_dq_flags=jwst_dq_flags,
)

if dilate_footprint is not None:
for flag in ['HOT', 'WARM']:
dq_flag |= nd.binary_dilation(
dq_flag & pixel_codes[flag] > 0,
structure=dilate_footprint
)*pixel_codes[flag]

if is_open:
rate.close()

return sn, dq_flag, count


def flag_nircam_hot_pixels(data='jw01837039001_02201_00001_nrcblong_rate.fits', err_extension='ERR', hot_threshold=7, max_filter_size=3, hot_filter_sn_max=5, plus_sn_min=4, corner_sn_max=3, jwst_dq_flags=JWST_DQ_FLAGS, verbose=True):
"""
Flag isolated hot pixels and "plusses" around known bad pixels
Expand Down Expand Up @@ -3668,7 +3729,13 @@ def flag_nircam_hot_pixels(data='jw01837039001_02201_00001_nrcblong_rate.fits',
indat -= rate['BKG'].data

indat[mask] = 0.
sn = indat / rate['ERR'].data
if err_extension == 'ERR':
sn = indat / rate['ERR'].data
elif err_extension == 'VAR_RNOISE':
sn = indat / np.sqrt(rate['VAR_RNOISE'].data)
else:
sn = indat * 1.

sn[mask] = 0

##########
Expand All @@ -3679,25 +3746,28 @@ def flag_nircam_hot_pixels(data='jw01837039001_02201_00001_nrcblong_rate.fits',
snmax = nd.maximum_filter(sn, footprint=footprint)

hi = sn > hot_threshold
hot = hi & (snmax < hot_filter_sn_max)
if hot_filter_sn_max < 0:
hot = hi & (snmax < sn * -1 / hot_filter_sn_max)
else:
hot = hi & (snmax < hot_filter_sn_max)

###########
# Plus mask
plus = np.array([[0,1,0], [1,0,1], [0,1,0]]) > 0
corner = (~plus)
corner[1,1] = False

# Plus mask
sn_up = sn*1
sn_up[mask] = 1000

dplus = nd.minimum_filter(sn, footprint=plus)
dplus = nd.minimum_filter(sn, footprint=PLUS_FOOTPRINT)

dcorner = nd.maximum_filter(sn, footprint=corner)
dcorner = nd.maximum_filter(sn, footprint=CORNER_FOOTPRINT)

plusses = (dplus > plus_sn_min) & (dcorner < corner_sn_max)
plusses &= (rate['DQ'].data & bits > 0)
if corner_sn_max < 0:
plusses = (dplus > plus_sn_min) & (dcorner < dplus * -1 / corner_sn_max)
else:
plusses = (dplus > plus_sn_min) & (dcorner < corner_sn_max)

plusses &= (rate['DQ'].data & bits > 0) | hot

plus_mask = nd.binary_dilation(plusses, structure=plus)
plus_mask = nd.binary_dilation(plusses, structure=PLUS_FOOTPRINT)

dq = (hot * pixel_codes['HOT']) | (plus_mask * pixel_codes['WARM'])

Expand Down
6 changes: 1 addition & 5 deletions grizli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,7 @@
from sregion import SRegion, patch_from_polygon

from . import GRIZLI_PATH
from .jwst_utils import JWST_DQ_FLAGS

KMS = u.km/u.s
FLAMBDA_CGS = u.erg/u.s/u.cm**2/u.angstrom
FNU_CGS = u.erg/u.s/u.cm**2/u.Hz
from .constants import JWST_DQ_FLAGS, KMS, FLAMBDA_CGS, FNU_CGS

# character to skip clearing line on STDOUT printing
NO_NEWLINE = '\x1b[1A\x1b[1M'
Expand Down

0 comments on commit c69829d

Please sign in to comment.