Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NIRCam hot pixel flagging #223

Merged
merged 4 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion grizli/aws/visit_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2309,7 +2309,9 @@ def cutout_mosaic(rootname='gds', product='{rootname}-{f}', ra=53.1615666, dec=-
snowblind_kwargs=snowblind_kwargs,
weight_type=weight_type,
rnoise_percentile=rnoise_percentile,
calc_wcsmap=calc_wcsmap)
calc_wcsmap=calc_wcsmap,
**kwargs,
)

outsci, outwht, header, flist, wcs_tab = _

Expand Down
217 changes: 216 additions & 1 deletion grizli/jwst_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,21 @@
# PP file WCS
DO_PURE_PARALLEL_WCS = True

JWST_DQ_FLAGS = [
"DO_NOT_USE",
"OTHER_BAD_PIXEL",
"UNRELIABLE_SLOPE",
"UNRELIABLE_BIAS",
"NO_SAT_CHECK",
"NO_GAIN_VALUE",
"HOT",
"WARM",
"DEAD",
"RC",
"LOW_QE",
]


def set_crds_context(fits_file=None, override_environ=False, verbose=True):
"""
Set CRDS_CONTEXT
Expand Down Expand Up @@ -196,6 +211,41 @@ def set_quiet_logging(level=QUIET_LEVEL):
pass


def get_jwst_dq_bit(dq_flags=JWST_DQ_FLAGS, verbose=False):
"""
Get a combined bit from JWST DQ flags

Parameters
----------
dq_flags : list
List of flag names

verbose : bool
Messaging

Returns
-------
dq_flag : int
Combined bit flag

"""
try:
import jwst.datamodels
except:
msg = f"get_jwst_dq_bits: import jwst.datamodels failed"
utils.log_comment(utils.LOGFILE, msg, verbose=verbose)
return 1

dq_flag = 1
for _bp in dq_flags:
dq_flag |= jwst.datamodels.dqflags.pixel[_bp]

msg = f"get_jwst_dq_bits: {'+'.join(dq_flags)} = {dq_flag}"
utils.log_comment(utils.LOGFILE, msg, verbose=verbose)

return dq_flag


def hdu_to_imagemodel(in_hdu):
"""
Workaround for initializing a `jwst.datamodels.ImageModel` from a
Expand Down Expand Up @@ -3492,4 +3542,169 @@ def time_corr_photom_copy(param, t):
amplitude, tau, t0 = param["amplitude"], param["tau"], param["t0"]
corr = amplitude * np.exp(-(t - t0)/tau)

return corr
return corr


def flag_nircam_hot_pixels(data='jw01837039001_02201_00001_nrcblong_rate.fits', hot_threshold=7, max_filter_size=3, hot_filter_sn_max=5, plus_sn_min=4, corner_sn_max=3, jwst_dq_flags=JWST_DQ_FLAGS, verbose=True):
"""
Flag isolated hot pixels and "plusses" around known bad pixels

Parameters
----------
data : str, `~astropy.io.fits.HDUList`
NIRCam image filename or open HDU

hot_threshold : float
S/N threshold for central hot pixel

max_filter_size : int
Size of the local maximum filter where the central pixel is zeroed out

hot_filter_sn_max : float
Maximum allowed S/N of the local maximum excluding the central pixel

plus_sn_min : float
Minimum S/N of the pixels in a "plus" around known bad pixels

corner_sn_max : float
Maximum S/N of the corners around known bad pixels

jwst_dq_flags : list
List of JWST flag names

verbose : bool
Messaging

Returns
-------
sn : array-like
S/N array derived from ``file``

dq : array-like, int
Flagged pixels where ``hot = HOT`` and ``plus = WARM``

count : int
Number of flagged pixels

Examples
--------

.. plot::
:include-source:

import numpy as np
import matplotlib.pyplot as plt
import astropy.io.fits as pyfits

from grizli.jwst_utils import flag_nircam_hot_pixels

signal = np.zeros((48,48), dtype=np.float32)

# hot
signal[16,16] = 10

# plus
for off in [-1,1]:
signal[32+off, 32] = 10
signal[32, 32+off] = 7

err = np.ones_like(signal)
np.random.seed(1)
noise = np.random.normal(size=signal.shape)*err

dq = np.zeros(signal.shape, dtype=int)
dq[32,32] = 2048 # HOT

header = pyfits.Header()
header['MDRIZSKY'] = 0.

hdul = pyfits.HDUList([
pyfits.ImageHDU(data=signal+noise, name='SCI', header=header),
pyfits.ImageHDU(data=err, name='ERR'),
pyfits.ImageHDU(data=dq, name='DQ'),
])

sn, dq_flag, count = flag_nircam_hot_pixels(hdul)

fig, axes = plt.subplots(1,2,figsize=(8,4), sharex=True, sharey=True)

axes[0].imshow(signal + noise, vmin=-2, vmax=9, cmap='gray')
axes[0].set_xlabel('Simulated data')
axes[1].imshow(dq_flag, cmap='magma')
axes[1].set_xlabel('Flagged pixels')

for ax in axes:
ax.set_xticklabels([])
ax.set_yticklabels([])

fig.tight_layout(pad=1)

plt.show()

"""
import scipy.ndimage as nd
from jwst.datamodels.mask import pixel as pixel_codes

if isinstance(data, str):
is_open = True
rate = pyfits.open(data)
else:
rate = data
is_open = False

bits = get_jwst_dq_bit(jwst_dq_flags)

mask = (rate['DQ'].data & bits > 0) | (rate['ERR'].data <= 0)
mask |= (rate['SCI'].data < -3*rate['ERR'].data) | (~np.isfinite(rate['SCI'].data))

if 'MDRIZSKY' in rate['SCI'].header:
bkg = rate['SCI'].header['MDRIZSKY']
else:
bkg = np.nanmedian(rate['SCI'].data[~mask])

indat = rate['SCI'].data - bkg

if 'BKG' in rate:
indat -= rate['BKG'].data

indat[mask] = 0.
sn = indat / rate['ERR'].data
sn[mask] = 0

##########
# Isolated hot pixels
footprint = np.ones((max_filter_size, max_filter_size), dtype=bool)
footprint[(max_filter_size - 1) //2, (max_filter_size - 1) //2] = False

snmax = nd.maximum_filter(sn, footprint=footprint)

hi = sn > hot_threshold
hot = hi & (snmax < hot_filter_sn_max)

###########
# Plus mask
plus = np.array([[0,1,0], [1,0,1], [0,1,0]]) > 0
corner = (~plus)
corner[1,1] = False

sn_up = sn*1
sn_up[mask] = 1000

dplus = nd.minimum_filter(sn, footprint=plus)

dcorner = nd.maximum_filter(sn, footprint=corner)

plusses = (dplus > plus_sn_min) & (dcorner < corner_sn_max)
plusses &= (rate['DQ'].data & bits > 0)

plus_mask = nd.binary_dilation(plusses, structure=plus)

dq = (hot * pixel_codes['HOT']) | (plus_mask * pixel_codes['WARM'])

msg = f'flag_nircam_hot_pixels : hot={hot.sum()} plus={plus_mask.sum()}'
utils.log_comment(utils.LOGFILE, msg, verbose=verbose)

if is_open:
rate.close()

return sn, dq, (dq > 0).sum()
37 changes: 34 additions & 3 deletions grizli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from sregion import SRegion, patch_from_polygon

from . import GRIZLI_PATH
from .jwst_utils import JWST_DQ_FLAGS

KMS = u.km/u.s
FLAMBDA_CGS = u.erg/u.s/u.cm**2/u.angstrom
Expand Down Expand Up @@ -5852,7 +5853,10 @@ def drizzle_from_visit(visit, output=None, pixfrac=1., kernel='point',
calc_wcsmap=False,
niriss_ghost_kwargs={},
snowblind_kwargs=None,
get_dbmask=True):
jwst_dq_flags=JWST_DQ_FLAGS,
nircam_hot_pixel_kwargs={},
get_dbmask=True,
**kwargs):
"""
Make drizzle mosaic from exposures in a visit dictionary

Expand Down Expand Up @@ -5921,7 +5925,15 @@ def drizzle_from_visit(visit, output=None, pixfrac=1., kernel='point',
snowblind_kwargs : dict
Arguments to pass to `~grizli.utils.jwst_snowblind_mask` if `snowblind` hasn't
already been run on JWST exposures


jwst_dq_flags : list
List of JWST flag names to include in the bad pixel mask. To ignore, set to
``None``

nircam_hot_pixel_kwargs : dict
Keyword arguments for `grizli.jwst_utils.flag_nircam_hot_pixels`. Set to
``None`` to disable and use the static bad pixel tables.

Returns
-------
outsci : array-like
Expand All @@ -5947,6 +5959,7 @@ def drizzle_from_visit(visit, output=None, pixfrac=1., kernel='point',

from .prep import apply_region_mask_from_db
from .version import __version__ as grizli__version
from .jwst_utils import get_jwst_dq_bit, flag_nircam_hot_pixels

bucket_name = None

Expand Down Expand Up @@ -6077,6 +6090,20 @@ def drizzle_from_visit(visit, output=None, pixfrac=1., kernel='point',

#bpdata = 0
_inst = flt[0].header['INSTRUME']

# Directly flag hot pixels rather than use mask
if (_inst in ['NIRCAM']) & (nircam_hot_pixel_kwargs is not None):
_sn, dq_flag, _count = flag_nircam_hot_pixels(
flt,
**nircam_hot_pixel_kwargs
)
if (_count > 0) & (_count < 8192):
bpdata = (dq_flag > 0)*1024
extra_wfc3ir_badpix = False
else:
msg = f' flag_nircam_hot_pixels: {_count} out of range'
log_comment(LOGFILE, msg, verbose=verbose)

if (extra_wfc3ir_badpix) & (_inst in ['NIRCAM','NIRISS']):
_det = flt[0].header['DETECTOR']
bpfiles = [os.path.join(os.path.dirname(__file__),
Expand Down Expand Up @@ -6276,7 +6303,11 @@ def drizzle_from_visit(visit, output=None, pixfrac=1., kernel='point',

# JWST: just 1,1024,4096 bits
if flt[0].header['TELESCOP'] in ['JWST']:
dq = flt[('DQ', ext)].data & (1+1024+4096)
bad_bits = (1 | 1024 | 4096)
if jwst_dq_flags is not None:
bad_bits |= get_jwst_dq_bit(jwst_dq_flags, verbose=verbose)

dq = flt[('DQ', ext)].data & bad_bits
dq |= bpdata.astype(dq.dtype)

# Clipping threshold for BKG extensions, global at top
Expand Down