Skip to content

Commit

Permalink
Merge pull request #223 from gbrammer/nircam-hotpix
Browse files Browse the repository at this point in the history
NIRCam hot pixel flagging
  • Loading branch information
gbrammer authored May 22, 2024
2 parents 1ed492f + 0455e6c commit a1d229d
Show file tree
Hide file tree
Showing 3 changed files with 253 additions and 5 deletions.
4 changes: 3 additions & 1 deletion grizli/aws/visit_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2309,7 +2309,9 @@ def cutout_mosaic(rootname='gds', product='{rootname}-{f}', ra=53.1615666, dec=-
snowblind_kwargs=snowblind_kwargs,
weight_type=weight_type,
rnoise_percentile=rnoise_percentile,
calc_wcsmap=calc_wcsmap)
calc_wcsmap=calc_wcsmap,
**kwargs,
)

outsci, outwht, header, flist, wcs_tab = _

Expand Down
217 changes: 216 additions & 1 deletion grizli/jwst_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,21 @@
# PP file WCS
DO_PURE_PARALLEL_WCS = True

JWST_DQ_FLAGS = [
"DO_NOT_USE",
"OTHER_BAD_PIXEL",
"UNRELIABLE_SLOPE",
"UNRELIABLE_BIAS",
"NO_SAT_CHECK",
"NO_GAIN_VALUE",
"HOT",
"WARM",
"DEAD",
"RC",
"LOW_QE",
]


def set_crds_context(fits_file=None, override_environ=False, verbose=True):
"""
Set CRDS_CONTEXT
Expand Down Expand Up @@ -196,6 +211,41 @@ def set_quiet_logging(level=QUIET_LEVEL):
pass


def get_jwst_dq_bit(dq_flags=JWST_DQ_FLAGS, verbose=False):
"""
Get a combined bit from JWST DQ flags
Parameters
----------
dq_flags : list
List of flag names
verbose : bool
Messaging
Returns
-------
dq_flag : int
Combined bit flag
"""
try:
import jwst.datamodels
except:
msg = f"get_jwst_dq_bits: import jwst.datamodels failed"
utils.log_comment(utils.LOGFILE, msg, verbose=verbose)
return 1

dq_flag = 1
for _bp in dq_flags:
dq_flag |= jwst.datamodels.dqflags.pixel[_bp]

msg = f"get_jwst_dq_bits: {'+'.join(dq_flags)} = {dq_flag}"
utils.log_comment(utils.LOGFILE, msg, verbose=verbose)

return dq_flag


def hdu_to_imagemodel(in_hdu):
"""
Workaround for initializing a `jwst.datamodels.ImageModel` from a
Expand Down Expand Up @@ -3492,4 +3542,169 @@ def time_corr_photom_copy(param, t):
amplitude, tau, t0 = param["amplitude"], param["tau"], param["t0"]
corr = amplitude * np.exp(-(t - t0)/tau)

return corr
return corr


def flag_nircam_hot_pixels(data='jw01837039001_02201_00001_nrcblong_rate.fits', hot_threshold=7, max_filter_size=3, hot_filter_sn_max=5, plus_sn_min=4, corner_sn_max=3, jwst_dq_flags=JWST_DQ_FLAGS, verbose=True):
"""
Flag isolated hot pixels and "plusses" around known bad pixels
Parameters
----------
data : str, `~astropy.io.fits.HDUList`
NIRCam image filename or open HDU
hot_threshold : float
S/N threshold for central hot pixel
max_filter_size : int
Size of the local maximum filter where the central pixel is zeroed out
hot_filter_sn_max : float
Maximum allowed S/N of the local maximum excluding the central pixel
plus_sn_min : float
Minimum S/N of the pixels in a "plus" around known bad pixels
corner_sn_max : float
Maximum S/N of the corners around known bad pixels
jwst_dq_flags : list
List of JWST flag names
verbose : bool
Messaging
Returns
-------
sn : array-like
S/N array derived from ``file``
dq : array-like, int
Flagged pixels where ``hot = HOT`` and ``plus = WARM``
count : int
Number of flagged pixels
Examples
--------
.. plot::
:include-source:
import numpy as np
import matplotlib.pyplot as plt
import astropy.io.fits as pyfits
from grizli.jwst_utils import flag_nircam_hot_pixels
signal = np.zeros((48,48), dtype=np.float32)
# hot
signal[16,16] = 10
# plus
for off in [-1,1]:
signal[32+off, 32] = 10
signal[32, 32+off] = 7
err = np.ones_like(signal)
np.random.seed(1)
noise = np.random.normal(size=signal.shape)*err
dq = np.zeros(signal.shape, dtype=int)
dq[32,32] = 2048 # HOT
header = pyfits.Header()
header['MDRIZSKY'] = 0.
hdul = pyfits.HDUList([
pyfits.ImageHDU(data=signal+noise, name='SCI', header=header),
pyfits.ImageHDU(data=err, name='ERR'),
pyfits.ImageHDU(data=dq, name='DQ'),
])
sn, dq_flag, count = flag_nircam_hot_pixels(hdul)
fig, axes = plt.subplots(1,2,figsize=(8,4), sharex=True, sharey=True)
axes[0].imshow(signal + noise, vmin=-2, vmax=9, cmap='gray')
axes[0].set_xlabel('Simulated data')
axes[1].imshow(dq_flag, cmap='magma')
axes[1].set_xlabel('Flagged pixels')
for ax in axes:
ax.set_xticklabels([])
ax.set_yticklabels([])
fig.tight_layout(pad=1)
plt.show()
"""
import scipy.ndimage as nd
from jwst.datamodels.mask import pixel as pixel_codes

if isinstance(data, str):
is_open = True
rate = pyfits.open(data)
else:
rate = data
is_open = False

bits = get_jwst_dq_bit(jwst_dq_flags)

mask = (rate['DQ'].data & bits > 0) | (rate['ERR'].data <= 0)
mask |= (rate['SCI'].data < -3*rate['ERR'].data) | (~np.isfinite(rate['SCI'].data))

if 'MDRIZSKY' in rate['SCI'].header:
bkg = rate['SCI'].header['MDRIZSKY']
else:
bkg = np.nanmedian(rate['SCI'].data[~mask])

indat = rate['SCI'].data - bkg

if 'BKG' in rate:
indat -= rate['BKG'].data

indat[mask] = 0.
sn = indat / rate['ERR'].data
sn[mask] = 0

##########
# Isolated hot pixels
footprint = np.ones((max_filter_size, max_filter_size), dtype=bool)
footprint[(max_filter_size - 1) //2, (max_filter_size - 1) //2] = False

snmax = nd.maximum_filter(sn, footprint=footprint)

hi = sn > hot_threshold
hot = hi & (snmax < hot_filter_sn_max)

###########
# Plus mask
plus = np.array([[0,1,0], [1,0,1], [0,1,0]]) > 0
corner = (~plus)
corner[1,1] = False

sn_up = sn*1
sn_up[mask] = 1000

dplus = nd.minimum_filter(sn, footprint=plus)

dcorner = nd.maximum_filter(sn, footprint=corner)

plusses = (dplus > plus_sn_min) & (dcorner < corner_sn_max)
plusses &= (rate['DQ'].data & bits > 0)

plus_mask = nd.binary_dilation(plusses, structure=plus)

dq = (hot * pixel_codes['HOT']) | (plus_mask * pixel_codes['WARM'])

msg = f'flag_nircam_hot_pixels : hot={hot.sum()} plus={plus_mask.sum()}'
utils.log_comment(utils.LOGFILE, msg, verbose=verbose)

if is_open:
rate.close()

return sn, dq, (dq > 0).sum()
37 changes: 34 additions & 3 deletions grizli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from sregion import SRegion, patch_from_polygon

from . import GRIZLI_PATH
from .jwst_utils import JWST_DQ_FLAGS

KMS = u.km/u.s
FLAMBDA_CGS = u.erg/u.s/u.cm**2/u.angstrom
Expand Down Expand Up @@ -5852,7 +5853,10 @@ def drizzle_from_visit(visit, output=None, pixfrac=1., kernel='point',
calc_wcsmap=False,
niriss_ghost_kwargs={},
snowblind_kwargs=None,
get_dbmask=True):
jwst_dq_flags=JWST_DQ_FLAGS,
nircam_hot_pixel_kwargs={},
get_dbmask=True,
**kwargs):
"""
Make drizzle mosaic from exposures in a visit dictionary
Expand Down Expand Up @@ -5921,7 +5925,15 @@ def drizzle_from_visit(visit, output=None, pixfrac=1., kernel='point',
snowblind_kwargs : dict
Arguments to pass to `~grizli.utils.jwst_snowblind_mask` if `snowblind` hasn't
already been run on JWST exposures
jwst_dq_flags : list
List of JWST flag names to include in the bad pixel mask. To ignore, set to
``None``
nircam_hot_pixel_kwargs : dict
Keyword arguments for `grizli.jwst_utils.flag_nircam_hot_pixels`. Set to
``None`` to disable and use the static bad pixel tables.
Returns
-------
outsci : array-like
Expand All @@ -5947,6 +5959,7 @@ def drizzle_from_visit(visit, output=None, pixfrac=1., kernel='point',

from .prep import apply_region_mask_from_db
from .version import __version__ as grizli__version
from .jwst_utils import get_jwst_dq_bit, flag_nircam_hot_pixels

bucket_name = None

Expand Down Expand Up @@ -6077,6 +6090,20 @@ def drizzle_from_visit(visit, output=None, pixfrac=1., kernel='point',

#bpdata = 0
_inst = flt[0].header['INSTRUME']

# Directly flag hot pixels rather than use mask
if (_inst in ['NIRCAM']) & (nircam_hot_pixel_kwargs is not None):
_sn, dq_flag, _count = flag_nircam_hot_pixels(
flt,
**nircam_hot_pixel_kwargs
)
if (_count > 0) & (_count < 8192):
bpdata = (dq_flag > 0)*1024
extra_wfc3ir_badpix = False
else:
msg = f' flag_nircam_hot_pixels: {_count} out of range'
log_comment(LOGFILE, msg, verbose=verbose)

if (extra_wfc3ir_badpix) & (_inst in ['NIRCAM','NIRISS']):
_det = flt[0].header['DETECTOR']
bpfiles = [os.path.join(os.path.dirname(__file__),
Expand Down Expand Up @@ -6276,7 +6303,11 @@ def drizzle_from_visit(visit, output=None, pixfrac=1., kernel='point',

# JWST: just 1,1024,4096 bits
if flt[0].header['TELESCOP'] in ['JWST']:
dq = flt[('DQ', ext)].data & (1+1024+4096)
bad_bits = (1 | 1024 | 4096)
if jwst_dq_flags is not None:
bad_bits |= get_jwst_dq_bit(jwst_dq_flags, verbose=verbose)

dq = flt[('DQ', ext)].data & bad_bits
dq |= bpdata.astype(dq.dtype)

# Clipping threshold for BKG extensions, global at top
Expand Down

0 comments on commit a1d229d

Please sign in to comment.