diff --git a/grizli/aws/visit_processor.py b/grizli/aws/visit_processor.py index 32ec9639..8aefa236 100755 --- a/grizli/aws/visit_processor.py +++ b/grizli/aws/visit_processor.py @@ -2309,7 +2309,9 @@ def cutout_mosaic(rootname='gds', product='{rootname}-{f}', ra=53.1615666, dec=- snowblind_kwargs=snowblind_kwargs, weight_type=weight_type, rnoise_percentile=rnoise_percentile, - calc_wcsmap=calc_wcsmap) + calc_wcsmap=calc_wcsmap, + **kwargs, + ) outsci, outwht, header, flist, wcs_tab = _ diff --git a/grizli/jwst_utils.py b/grizli/jwst_utils.py index fc22ac09..b4e071ab 100644 --- a/grizli/jwst_utils.py +++ b/grizli/jwst_utils.py @@ -26,6 +26,21 @@ # PP file WCS DO_PURE_PARALLEL_WCS = True +JWST_DQ_FLAGS = [ + "DO_NOT_USE", + "OTHER_BAD_PIXEL", + "UNRELIABLE_SLOPE", + "UNRELIABLE_BIAS", + "NO_SAT_CHECK", + "NO_GAIN_VALUE", + "HOT", + "WARM", + "DEAD", + "RC", + "LOW_QE", +] + + def set_crds_context(fits_file=None, override_environ=False, verbose=True): """ Set CRDS_CONTEXT @@ -196,6 +211,41 @@ def set_quiet_logging(level=QUIET_LEVEL): pass +def get_jwst_dq_bit(dq_flags=JWST_DQ_FLAGS, verbose=False): + """ + Get a combined bit from JWST DQ flags + + Parameters + ---------- + dq_flags : list + List of flag names + + verbose : bool + Messaging + + Returns + ------- + dq_flag : int + Combined bit flag + + """ + try: + import jwst.datamodels + except: + msg = f"get_jwst_dq_bits: import jwst.datamodels failed" + utils.log_comment(utils.LOGFILE, msg, verbose=verbose) + return 1 + + dq_flag = 1 + for _bp in dq_flags: + dq_flag |= jwst.datamodels.dqflags.pixel[_bp] + + msg = f"get_jwst_dq_bits: {'+'.join(dq_flags)} = {dq_flag}" + utils.log_comment(utils.LOGFILE, msg, verbose=verbose) + + return dq_flag + + def hdu_to_imagemodel(in_hdu): """ Workaround for initializing a `jwst.datamodels.ImageModel` from a @@ -3492,4 +3542,169 @@ def time_corr_photom_copy(param, t): amplitude, tau, t0 = param["amplitude"], param["tau"], param["t0"] corr = amplitude * np.exp(-(t - t0)/tau) - return corr \ No newline at end of file + return corr + + +def flag_nircam_hot_pixels(data='jw01837039001_02201_00001_nrcblong_rate.fits', hot_threshold=7, max_filter_size=3, hot_filter_sn_max=5, plus_sn_min=4, corner_sn_max=3, jwst_dq_flags=JWST_DQ_FLAGS, verbose=True): + """ + Flag isolated hot pixels and "plusses" around known bad pixels + + Parameters + ---------- + data : str, `~astropy.io.fits.HDUList` + NIRCam image filename or open HDU + + hot_threshold : float + S/N threshold for central hot pixel + + max_filter_size : int + Size of the local maximum filter where the central pixel is zeroed out + + hot_filter_sn_max : float + Maximum allowed S/N of the local maximum excluding the central pixel + + plus_sn_min : float + Minimum S/N of the pixels in a "plus" around known bad pixels + + corner_sn_max : float + Maximum S/N of the corners around known bad pixels + + jwst_dq_flags : list + List of JWST flag names + + verbose : bool + Messaging + + Returns + ------- + sn : array-like + S/N array derived from ``file`` + + dq : array-like, int + Flagged pixels where ``hot = HOT`` and ``plus = WARM`` + + count : int + Number of flagged pixels + + Examples + -------- + + .. plot:: + :include-source: + + import numpy as np + import matplotlib.pyplot as plt + import astropy.io.fits as pyfits + + from grizli.jwst_utils import flag_nircam_hot_pixels + + signal = np.zeros((48,48), dtype=np.float32) + + # hot + signal[16,16] = 10 + + # plus + for off in [-1,1]: + signal[32+off, 32] = 10 + signal[32, 32+off] = 7 + + err = np.ones_like(signal) + np.random.seed(1) + noise = np.random.normal(size=signal.shape)*err + + dq = np.zeros(signal.shape, dtype=int) + dq[32,32] = 2048 # HOT + + header = pyfits.Header() + header['MDRIZSKY'] = 0. + + hdul = pyfits.HDUList([ + pyfits.ImageHDU(data=signal+noise, name='SCI', header=header), + pyfits.ImageHDU(data=err, name='ERR'), + pyfits.ImageHDU(data=dq, name='DQ'), + ]) + + sn, dq_flag, count = flag_nircam_hot_pixels(hdul) + + fig, axes = plt.subplots(1,2,figsize=(8,4), sharex=True, sharey=True) + + axes[0].imshow(signal + noise, vmin=-2, vmax=9, cmap='gray') + axes[0].set_xlabel('Simulated data') + axes[1].imshow(dq_flag, cmap='magma') + axes[1].set_xlabel('Flagged pixels') + + for ax in axes: + ax.set_xticklabels([]) + ax.set_yticklabels([]) + + fig.tight_layout(pad=1) + + plt.show() + + """ + import scipy.ndimage as nd + from jwst.datamodels.mask import pixel as pixel_codes + + if isinstance(data, str): + is_open = True + rate = pyfits.open(data) + else: + rate = data + is_open = False + + bits = get_jwst_dq_bit(jwst_dq_flags) + + mask = (rate['DQ'].data & bits > 0) | (rate['ERR'].data <= 0) + mask |= (rate['SCI'].data < -3*rate['ERR'].data) | (~np.isfinite(rate['SCI'].data)) + + if 'MDRIZSKY' in rate['SCI'].header: + bkg = rate['SCI'].header['MDRIZSKY'] + else: + bkg = np.nanmedian(rate['SCI'].data[~mask]) + + indat = rate['SCI'].data - bkg + + if 'BKG' in rate: + indat -= rate['BKG'].data + + indat[mask] = 0. + sn = indat / rate['ERR'].data + sn[mask] = 0 + + ########## + # Isolated hot pixels + footprint = np.ones((max_filter_size, max_filter_size), dtype=bool) + footprint[(max_filter_size - 1) //2, (max_filter_size - 1) //2] = False + + snmax = nd.maximum_filter(sn, footprint=footprint) + + hi = sn > hot_threshold + hot = hi & (snmax < hot_filter_sn_max) + + ########### + # Plus mask + plus = np.array([[0,1,0], [1,0,1], [0,1,0]]) > 0 + corner = (~plus) + corner[1,1] = False + + sn_up = sn*1 + sn_up[mask] = 1000 + + dplus = nd.minimum_filter(sn, footprint=plus) + + dcorner = nd.maximum_filter(sn, footprint=corner) + + plusses = (dplus > plus_sn_min) & (dcorner < corner_sn_max) + plusses &= (rate['DQ'].data & bits > 0) + + plus_mask = nd.binary_dilation(plusses, structure=plus) + + dq = (hot * pixel_codes['HOT']) | (plus_mask * pixel_codes['WARM']) + + msg = f'flag_nircam_hot_pixels : hot={hot.sum()} plus={plus_mask.sum()}' + utils.log_comment(utils.LOGFILE, msg, verbose=verbose) + + if is_open: + rate.close() + + return sn, dq, (dq > 0).sum() diff --git a/grizli/utils.py b/grizli/utils.py index 3e014ef6..fdcf498c 100644 --- a/grizli/utils.py +++ b/grizli/utils.py @@ -21,6 +21,7 @@ from sregion import SRegion, patch_from_polygon from . import GRIZLI_PATH +from .jwst_utils import JWST_DQ_FLAGS KMS = u.km/u.s FLAMBDA_CGS = u.erg/u.s/u.cm**2/u.angstrom @@ -5852,7 +5853,10 @@ def drizzle_from_visit(visit, output=None, pixfrac=1., kernel='point', calc_wcsmap=False, niriss_ghost_kwargs={}, snowblind_kwargs=None, - get_dbmask=True): + jwst_dq_flags=JWST_DQ_FLAGS, + nircam_hot_pixel_kwargs={}, + get_dbmask=True, + **kwargs): """ Make drizzle mosaic from exposures in a visit dictionary @@ -5921,7 +5925,15 @@ def drizzle_from_visit(visit, output=None, pixfrac=1., kernel='point', snowblind_kwargs : dict Arguments to pass to `~grizli.utils.jwst_snowblind_mask` if `snowblind` hasn't already been run on JWST exposures - + + jwst_dq_flags : list + List of JWST flag names to include in the bad pixel mask. To ignore, set to + ``None`` + + nircam_hot_pixel_kwargs : dict + Keyword arguments for `grizli.jwst_utils.flag_nircam_hot_pixels`. Set to + ``None`` to disable and use the static bad pixel tables. + Returns ------- outsci : array-like @@ -5947,6 +5959,7 @@ def drizzle_from_visit(visit, output=None, pixfrac=1., kernel='point', from .prep import apply_region_mask_from_db from .version import __version__ as grizli__version + from .jwst_utils import get_jwst_dq_bit, flag_nircam_hot_pixels bucket_name = None @@ -6077,6 +6090,20 @@ def drizzle_from_visit(visit, output=None, pixfrac=1., kernel='point', #bpdata = 0 _inst = flt[0].header['INSTRUME'] + + # Directly flag hot pixels rather than use mask + if (_inst in ['NIRCAM']) & (nircam_hot_pixel_kwargs is not None): + _sn, dq_flag, _count = flag_nircam_hot_pixels( + flt, + **nircam_hot_pixel_kwargs + ) + if (_count > 0) & (_count < 8192): + bpdata = (dq_flag > 0)*1024 + extra_wfc3ir_badpix = False + else: + msg = f' flag_nircam_hot_pixels: {_count} out of range' + log_comment(LOGFILE, msg, verbose=verbose) + if (extra_wfc3ir_badpix) & (_inst in ['NIRCAM','NIRISS']): _det = flt[0].header['DETECTOR'] bpfiles = [os.path.join(os.path.dirname(__file__), @@ -6276,7 +6303,11 @@ def drizzle_from_visit(visit, output=None, pixfrac=1., kernel='point', # JWST: just 1,1024,4096 bits if flt[0].header['TELESCOP'] in ['JWST']: - dq = flt[('DQ', ext)].data & (1+1024+4096) + bad_bits = (1 | 1024 | 4096) + if jwst_dq_flags is not None: + bad_bits |= get_jwst_dq_bit(jwst_dq_flags, verbose=verbose) + + dq = flt[('DQ', ext)].data & bad_bits dq |= bpdata.astype(dq.dtype) # Clipping threshold for BKG extensions, global at top