Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add subscan operations to preprocess #1028

Merged
merged 10 commits into from
Dec 4, 2024
1 change: 1 addition & 0 deletions docs/preprocess.rst
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ Flagging and Products
.. autoclass:: sotodlib.preprocess.processes.FlagTurnarounds
.. autoclass:: sotodlib.preprocess.processes.DarkDets
.. autoclass:: sotodlib.preprocess.processes.SourceFlags
.. autoclass:: sotodlib.preprocess.processes.GetStats

HWP Related
:::::::::::
Expand Down
19 changes: 16 additions & 3 deletions sotodlib/preprocess/pcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .. import core
from so3g.proj import Ranges, RangesMatrix
from scipy.sparse import csr_array
from matplotlib import pyplot as plt

class _Preprocess(object):
"""The base class for Preprocessing modules which defines the required
Expand Down Expand Up @@ -270,16 +271,27 @@ def _expand(new, full, wrap_valid=True):
continue
out.wrap_new( k, new._assignments[k], cls=_zeros_cls(v))
oidx=[]; nidx=[]
for a in new._assignments[k]:
for ii, a in enumerate(new._assignments[k]):
if a == 'dets':
oidx.append(fs_dets)
nidx.append(ns_dets)
elif a == 'samps':
oidx.append(fs_samps)
nidx.append(ns_samps)
else:
oidx.append(slice(None))
nidx.append(slice(None))
if (ii == 0) and isinstance(out[k], RangesMatrix): # Treat like dets
# _ranges_matrix_match expects oidx[0] and nidx[0] to be list(inds), not slice.
# Unknown axes treated as dets if first entry, else like samps. Added to support (subscans, samps) RangesMatrix.
if a in full._axes:
_, fs, ns = full[a].intersection(new[a], return_slices=True)
else:
fs = range(new[a].count)
ns = range(new[a].count)
oidx.append(fs)
nidx.append(ns)
else: # Treat like samps
oidx.append(slice(None))
nidx.append(slice(None))
oidx = tuple(oidx)
nidx = tuple(nidx)
if isinstance(out[k], RangesMatrix):
Expand Down Expand Up @@ -456,6 +468,7 @@ def run(self, aman, proc_aman=None, select=True, sim=False, update_plot=False):
update_full_aman( proc_aman, full, self.wrap_valid)
if update_plot:
process.plot(aman, proc_aman, filename=os.path.join(self.plot_dir, '{ctime}/{obsid}', f'{step+1}_{{name}}.png'))
plt.close()
if select:
process.select(aman, proc_aman)
proc_aman.restrict('dets', aman.dets.vals)
Expand Down
38 changes: 38 additions & 0 deletions sotodlib/preprocess/preprocess_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,44 @@ def plot_trending_flags(aman, trend_aman, filename='./trending_flags.png'):
os.makedirs(head_tail[0], exist_ok=True)
plt.savefig(filename)

def plot_signal(aman, signal=None, xx=None, signal_name="signal", x_name="timestamps", plot_ds_factor=50, plot_ds_factor_dets=None, xlim=None, alpha=0.2, yscale='linear', y_unit=None, filename="./signal.png"):
from operator import attrgetter
if plot_ds_factor_dets is None:
plot_ds_factor_dets = plot_ds_factor
if signal is None:
signal = attrgetter(signal_name)(aman)
if xx is None:
xx = attrgetter(x_name)(aman)
yy = signal[::plot_ds_factor_dets, 1::plot_ds_factor].copy() # (dets, samps); (dets, nusamps); (dets, nusamps, subscans)
xx = xx[1::plot_ds_factor].copy() # (samps); (nusamps)
if x_name == "timestamps":
xx -= xx[0]
if yy.ndim > 2: # Flatten subscan axis into dets
yy = yy.swapaxes(1,2).reshape(-1, yy.shape[1])

if xlim is not None:
xinds = np.logical_and(xx >= xlim[0], xx <= xlim[1])
xx = xx[xinds]
yy = yy[:,xinds]

fig, ax = plt.subplots(1, 1, figsize=(6.4, 4.8))
ax.plot(xx, yy.T, color='k', alpha=0.2)
ax.set_yscale(yscale)
if "freqs" in x_name:
ax.set_xlabel("freq [Hz]")
else:
ax.set_xlabel(f"{x_name} [s]")
y_unit = "" if y_unit is None else f" [{y_unit}]"
ax.set_ylabel(f"{signal_name.replace('.Pxx', '')}{y_unit}")
plt.suptitle(f"{aman.obs_info.obs_id}, dT = {np.ptp(aman.timestamps)/60:.1f} min")
plt.tight_layout()
head_tail = os.path.split(filename)
os.makedirs(head_tail[0], exist_ok=True)
plt.savefig(filename)

def plot_psd(aman, signal=None, xx=None, signal_name="psd.Pxx", x_name="psd.freqs", plot_ds_factor=4, plot_ds_factor_dets=20, xlim=None, alpha=0.2, yscale='log', y_unit=None, filename="./psd.png"):
return plot_signal(aman, signal, xx, signal_name, x_name, plot_ds_factor, plot_ds_factor_dets, xlim, alpha, yscale, y_unit, filename)

def plot_signal_diff(aman, flag_aman, flag_type="glitches", flag_threshold=10, plot_ds_factor=50, filename="./glitch_signal_diff.png"):
"""
Function for plotting the difference in signal before and after cuts from either glitches or jumps.
Expand Down
142 changes: 124 additions & 18 deletions sotodlib/preprocess/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ class GlitchDetection(_FracFlaggedMixIn, _Preprocess):
buffer: 10
hp_fc: 1
n_sig: 10
subscan: False
save: True
plot:
plot_ds_factor: 50
Expand Down Expand Up @@ -340,6 +341,7 @@ def plot(self, aman, proc_aman, filename):
plot_ds_factor=self.plot_cfgs.get("plot_ds_factor", 50), filename=filename.replace('{name}', f'{ufm}_jump_signal_diff'))
plot_flag_stats(aman, proc_aman[name], flag_type='jumps', filename=filename.replace('{name}', f'{ufm}_jumps_stats'))


class PSDCalc(_Preprocess):
""" Calculate the PSD of the data and add it to the Preprocessing AxisManager under the
"psd" field.
Expand All @@ -353,6 +355,7 @@ class PSDCalc(_Preprocess):
"psd_cfgs": # optional, kwargs to scipy.welch
"nperseg": 1024
"wrap_name": "psd" # optional
"subscan": False
"save": True

.. autofunction:: sotodlib.tod_ops.fft_ops.calc_psd
Expand All @@ -368,21 +371,105 @@ def __init__(self, step_cfgs):
def calc_and_save(self, aman, proc_aman):
freqs, Pxx = tod_ops.fft_ops.calc_psd(aman, signal=aman[self.signal],
**self.calc_cfgs)
fft_aman = core.AxisManager(
aman.dets,
core.OffsetAxis("nusamps",len(freqs))
)

fft_aman = core.AxisManager(aman.dets,
core.OffsetAxis("nusamps", len(freqs)))
pxx_axis_map = [(0, "dets"), (1, "nusamps")]
if self.calc_cfgs.get('subscan', False):
fft_aman.wrap("Pxx_ss", Pxx, pxx_axis_map+[(2, aman.subscans)])
Pxx = np.nanmean(Pxx, axis=-1) # Mean of subscans

fft_aman.wrap("freqs", freqs, [(0,"nusamps")])
fft_aman.wrap("Pxx", Pxx, [(0,"dets"), (1,"nusamps")])
fft_aman.wrap("Pxx", Pxx, pxx_axis_map)

self.save(proc_aman, fft_aman)

def save(self, proc_aman, fft_aman):
if not(self.save_cfgs is None):
proc_aman.wrap(self.wrap, fft_aman)
def plot(self, aman, proc_aman, filename):
if self.plot_cfgs is None:
return
if self.plot_cfgs:
from .preprocess_plot import plot_psd

filename = filename.replace('{ctime}', f'{str(aman.timestamps[0])[:5]}')
filename = filename.replace('{obsid}', aman.obs_info.obs_id)
det = aman.dets.vals[0]
ufm = det.split('_')[2]
filename = filename.replace('{name}', f'{ufm}_{self.wrap}')

plot_psd(aman, signal=attrgetter(f"{self.wrap}.Pxx")(proc_aman),
xx=attrgetter(f"{self.wrap}.freqs")(proc_aman), filename=filename, **self.plot_cfgs)


class GetStats(_Preprocess):
""" Get basic statistics from a TOD or its power spectrum.

Example config block:

- name : "tod_stats"
signal: "signal" # optional
wrap: "tod_stats" # optional
calc:
stat_names: ["median", "std"]
split_subscans: False # optional
psd_mask: # optional, for cutting a power spectrum in frequency
freqs: "psd.freqs"
low_f: 1
high_f: 10
save: True

"""
name = "tod_stats"
def __init__(self, step_cfgs):
self.signal = step_cfgs.get('signal', 'signal')
self.wrap = step_cfgs.get('wrap', 'tod_stats')

super().__init__(step_cfgs)

def calc_and_save(self, aman, proc_aman):
if self.calc_cfgs.get('psd_mask') is not None:
mask_dict = self.calc_cfgs.get('psd_mask')
_f = attrgetter(mask_dict['freqs'])
try:
freqs = _f(aman)
except KeyError:
freqs = _f(proc_aman)
low_f, high_f = mask_dict['low_f'], mask_dict['high_f']
fmask = np.all([freqs >= low_f, freqs <= high_f], axis=0)
self.calc_cfgs['mask'] = fmask
del self.calc_cfgs['psd_mask']

_f = attrgetter(self.signal)
try:
signal = _f(aman)
except KeyError:
signal = _f(proc_aman)
stats_aman = tod_ops.flags.get_stats(aman, signal, **self.calc_cfgs)
self.save(proc_aman, stats_aman)

def save(self, proc_aman, stats_aman):
if not(self.save_cfgs is None):
proc_aman.wrap(self.wrap, stats_aman)

def plot(self, aman, proc_aman, filename):
if self.plot_cfgs is None:
return
if self.plot_cfgs:
from .preprocess_plot import plot_signal

filename = filename.replace('{ctime}', f'{str(aman.timestamps[0])[:5]}')
filename = filename.replace('{obsid}', aman.obs_info.obs_id)
det = aman.dets.vals[0]
ufm = det.split('_')[2]
filename = filename.replace('{name}', f'{ufm}_{self.signal}')

plot_signal(aman, signal_name=self.signal, x_name="timestamps", filename=filename, **self.plot_cfgs)

class Noise(_Preprocess):
"""Estimate the white noise levels in the data. Assumes the PSD has been
wrapped into the preprocessing AxisManager. All calculation configs goes to `calc_wn`.
wrapped into the preprocessing AxisManager. All calculation configs goes to `calc_wn`.

Saves the results into the "noise" field of proc_aman.

Expand All @@ -391,6 +478,8 @@ class Noise(_Preprocess):
Example config block::

- name: "noise"
fit: False
subscan: False
calc:
low_f: 5
high_f: 10
Expand All @@ -408,28 +497,36 @@ class Noise(_Preprocess):
def __init__(self, step_cfgs):
self.psd = step_cfgs.get('psd', 'psd')
self.fit = step_cfgs.get('fit', False)
self.subscan = step_cfgs.get('subscan', False)

super().__init__(step_cfgs)

def calc_and_save(self, aman, proc_aman):
if self.psd not in proc_aman:
raise ValueError("PSD is not saved in Preprocessing AxisManager")
psd = proc_aman[self.psd]

pxx = psd.Pxx_ss if self.subscan else psd.Pxx

if self.calc_cfgs is None:
self.calc_cfgs = {}

if self.fit:
calc_aman = tod_ops.fft_ops.fit_noise_model(aman, pxx=psd.Pxx,
if self.calc_cfgs.get('subscan') is None:
self.calc_cfgs['subscan'] = self.subscan
calc_aman = tod_ops.fft_ops.fit_noise_model(aman, pxx=pxx,
f=psd.freqs,
merge_fit=True,
**self.calc_cfgs)
else:
wn = tod_ops.fft_ops.calc_wn(aman, pxx=psd.Pxx,
wn = tod_ops.fft_ops.calc_wn(aman, pxx=pxx,
freqs=psd.freqs,
**self.calc_cfgs)
calc_aman = core.AxisManager(aman.dets)
calc_aman.wrap("white_noise", wn, [(0,"dets")])
if not self.subscan:
calc_aman = core.AxisManager(aman.dets)
calc_aman.wrap("white_noise", wn, [(0,"dets")])
else:
calc_aman = core.AxisManager(aman.dets, aman.subscan_info.subscans)
calc_aman.wrap("white_noise", wn, [(0,"dets"), (1,"subscans")])

self.save(proc_aman, calc_aman)

Expand Down Expand Up @@ -457,10 +554,12 @@ def select(self, meta, proc_aman=None):
self.select_cfgs['name'] = self.select_cfgs.get('name','noise')

if self.fit:
keep = proc_aman[self.select_cfgs['name']].fit[:,1] <= self.select_cfgs["max_noise"]
wn = proc_aman[self.select_cfgs['name']].fit[:,1]
else:
keep = proc_aman[self.select_cfgs['name']].white_noise <= self.select_cfgs["max_noise"]

wn = proc_aman[self.select_cfgs['name']].white_noise
if self.subscan:
wn = np.nanmean(wn, axis=-1) # Mean over subscans
keep = wn <= np.float64(self.select_cfgs["max_noise"])
meta.restrict("dets", meta.dets.vals[keep])
return meta

Expand Down Expand Up @@ -786,6 +885,9 @@ def calc_and_save(self, aman, proc_aman):
calc_aman = core.AxisManager(aman.dets, aman.samps)
calc_aman.wrap('turnarounds', ta, [(0, 'dets'), (1, 'samps')])

if ('merge_subscans' not in self.calc_cfgs) or (self.calc_cfgs['merge_subscans']):
calc_aman.wrap('subscan_info', aman.subscan_info)

self.save(proc_aman, calc_aman)

def save(self, proc_aman, turn_aman):
Expand Down Expand Up @@ -1083,9 +1185,9 @@ class PCARelCal(_Preprocess):
yfac: 1.5
calc_good_medianw: True
lpf:
type: "low_pass_sine2"
type: "sine2"
cutoff: 1
width: 0.1
trans_width: 0.1
trim_samps: 2000
save: True
plot:
Expand All @@ -1102,6 +1204,7 @@ def __init__(self, step_cfgs):
super().__init__(step_cfgs)

def calc_and_save(self, aman, proc_aman):
self.plot_signal = self.signal
if self.calc_cfgs.get("lpf") is not None:
filt = tod_ops.filters.get_lpf(self.calc_cfgs.get("lpf"))
filt_tod = tod_ops.fourier_filter(aman, filt, signal_name='signal')
Expand All @@ -1117,6 +1220,8 @@ def calc_and_save(self, aman, proc_aman):
proc_aman.samps.offset + proc_aman.samps.count - trim))
filt_aman.restrict('samps', (filt_aman.samps.offset + trim,
filt_aman.samps.offset + filt_aman.samps.count - trim))
if self.plot_cfgs:
self.plot_signal = filt_aman[self.signal]

bands = np.unique(aman.det_info.wafer.bandpass)
bands = bands[bands != 'NC']
Expand Down Expand Up @@ -1184,7 +1289,7 @@ def plot(self, aman, proc_aman, filename):
for band in bands:
pca_aman = aman.restrict('dets', aman.dets.vals[proc_aman[self.run_name][f'{band}_idx']], in_place=False)
band_aman = proc_aman[self.run_name].restrict('dets', aman.dets.vals[proc_aman[self.run_name][f'{band}_idx']], in_place=False)
plot_pcabounds(pca_aman, band_aman, filename=filename.replace('{name}', f'{ufm}_{band}_pca'), signal=self.signal, band=band, plot_ds_factor=self.plot_cfgs.get('plot_ds_factor', 20))
plot_pcabounds(pca_aman, band_aman, filename=filename.replace('{name}', f'{ufm}_{band}_pca'), signal=self.plot_signal, band=band, plot_ds_factor=self.plot_cfgs.get('plot_ds_factor', 20))


class PTPFlags(_Preprocess):
Expand Down Expand Up @@ -1384,3 +1489,4 @@ def save(self, proc_aman, split_flg_aman):
_Preprocess.register(DarkDets)
_Preprocess.register(SourceFlags)
_Preprocess.register(HWPAngleModel)
_Preprocess.register(GetStats)
Loading
Loading