Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API to load detector data with mask #518

Merged
merged 9 commits into from
Jun 4, 2024
4 changes: 4 additions & 0 deletions docs/agipd_lpd_data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ DSSC and JUNGFRAU, pulling together the separate modules into a single array.
arranged along the first axis. So ``det['image.data'].ndarray()`` will
load all the selected data as a NumPy array.

.. automethod:: masked_data

.. automethod:: get_array

.. automethod:: get_dask_array
Expand Down Expand Up @@ -58,6 +60,8 @@ DSSC and JUNGFRAU, pulling together the separate modules into a single array.
arranged along the first axis. So ``jf['data.adc'].ndarray()`` will
load all the selected data as a NumPy array.

.. automethod:: masked_data

.. automethod:: get_array

.. automethod:: get_dask_array
Expand Down
146 changes: 131 additions & 15 deletions extra_data/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"""
import logging
import re
from collections.abc import Iterable
from copy import copy
from warnings import warn

Expand Down Expand Up @@ -122,6 +123,7 @@
_source_re = re.compile(r'(?P<detname>.+)/DET/(\d+)CH')
# Override in subclass
_main_data_key = '' # Key to use for checking data counts match
_mask_data_key = ''
_frames_per_entry = 1 # Override if separate pulse dimension in files
_modnos_start_at = 0 # Override if module numbers start at 1 (JUNGFRAU)
module_shape = (0, 0)
Expand Down Expand Up @@ -176,6 +178,40 @@
def __getitem__(self, item):
return MultimodKeyData(self, item)

def masked_data(self, key=None, *, mask_bits=None, masked_value=np.nan):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since bit logic is not generally intuitive for pragmatic programmers, allowing to pass bad pixel bits by iterable may be nice here as well.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense. I guess we should also expose the bad pixel flags somewhere - I don't think we have those at the moment. Maybe they belong in extra.calibration ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely, EXtra was the intended target for some time now. extra.calibration makes sense given it's used in calibration data.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to record here, I opened European-XFEL/EXtra#172 to do that.

"""Combine corrected data with the mask in the files

This provides an interface similar to ``det['data.adc']``, but masking
out pixels with the mask from the correction pipeline.

Parameters
----------

key: str
The data key to look at, by default the main data key of the detector
(e.g. 'data.adc').
mask_bits: int or list of ints
Reasons to exclude pixels, as a bitmask or a list of integers.
By default, all types of bad pixel are masked out.
masked_value: int, float
The replacement value to use for masked data. By default this is NaN.
"""
key = key or self._main_data_key
self[self._mask_data_key] # Check that the mask is there
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would maybe raise a more explicit error message here if that happens.

if isinstance(mask_bits, Iterable):
mask_bits = self._combine_bitfield(mask_bits)
return DetectorMaskedKeyData(
self, key, mask_key=self._mask_data_key,
mask_bits=mask_bits, masked_value=masked_value
)

@staticmethod
def _combine_bitfield(ints):
res = 0
for i in ints:
res |= i
return res

@classmethod
def _find_detector_name(cls, data):
detector_names = set()
Expand Down Expand Up @@ -471,6 +507,7 @@
"""
n_modules = 16
_main_data_key = 'image.data'
_mask_data_key = 'image.mask'

def __init__(self, data: DataCollection, detector_name=None, modules=None,
*, min_modules=1):
Expand All @@ -481,6 +518,34 @@
return XtdfImageMultimodKeyData(self, item)
return super().__getitem__(item)

def masked_data(self, key=None, *, mask_bits=None, masked_value=np.nan):
"""Combine corrected data with the mask in the files

This provides an interface similar to ``det['image.data']``, but masking
out pixels with the mask from the correction pipeline.

Parameters
----------

key: str
The data key to look at, by default the main data key of the detector
(e.g. 'image.data').
mask_bits: int or list of ints
Reasons to exclude pixels, as a bitmask or a list of integers.
By default, all types of bad pixel are masked out.
masked_value: int, float
The replacement value to use for masked data. By default this is NaN.
"""
key = key or self._main_data_key
assert key.startswith('image.')
self[self._mask_data_key] # Check that the mask is there
if isinstance(mask_bits, Iterable):
mask_bits = self._combine_bitfield(mask_bits)
return XtdfMaskedKeyData(
self, key, mask_key=self._mask_data_key,
mask_bits=mask_bits, masked_value=masked_value
)

# Several methods below are overridden in LPD1M for parallel gain mode

@staticmethod
Expand Down Expand Up @@ -746,7 +811,6 @@
return res



class MultimodKeyData:
def __init__(self, det: MultimodDetectorBase, key):
self.det = det
Expand All @@ -755,6 +819,9 @@
m: det.data[s, key] for (m, s) in det.modno_to_source.items()
}

def _init_kwargs(self): # Extended in subclasses
return dict(det=self.det, key=self.key)

@property
def train_ids(self):
return self.det.train_ids
Expand Down Expand Up @@ -799,9 +866,11 @@
def dtype(self):
return self._eg_keydata.dtype

# For select_trains() & split_trains() to work correctly with subclasses
def _with_selected_det(self, det_selected):
# Overridden for XtdfImageMultimodKeyData to preserve pulse selection
return MultimodKeyData(det_selected, self.key)
kw = self._init_kwargs()
kw.update(det=det_selected)
return type(self)(**kw)

def select_trains(self, trains):
return self._with_selected_det(self.det.select_trains(trains))
Expand Down Expand Up @@ -831,13 +900,16 @@
)
return out

def xarray(self, *, fill_value=None, roi=(), astype=None):
def _wrap_xarray(self, arr):

Check notice

Code scanning / CodeQL

Mismatch between signature and use of an overridden method Note

Overridden method signature does not match
call
, where it is passed too many arguments. Overriding method
method XtdfImageMultimodKeyData._wrap_xarray
matches the call.
Overridden method signature does not match
call
, where it is passed too many arguments. Overriding method
method XtdfImageMultimodKeyData._wrap_xarray
matches the call.
from xarray import DataArray
arr = self.ndarray(fill_value=fill_value, roi=roi, astype=astype)

coords = {'module': self.modules, 'trainId': self.train_id_coordinates()}
return DataArray(arr, dims=self.dimensions, coords=coords)

def xarray(self, *, fill_value=None, roi=(), astype=None):
arr = self.ndarray(fill_value=fill_value, roi=roi, astype=astype)
return self._wrap_xarray(arr)

def dask_array(self, *, labelled=False, fill_value=None, astype=None):
from dask.delayed import delayed
from dask.array import concatenate, from_delayed
Expand All @@ -854,9 +926,7 @@
) for c in split], axis=1)

if labelled:
from xarray import DataArray
coords = {'module': self.modules, 'trainId': self.train_id_coordinates()}
return DataArray(arr, dims=self.dimensions, coords=coords)
return self._wrap_xarray(arr)

return arr

Expand All @@ -880,6 +950,44 @@
return out


class DetectorMaskedKeyData(MultimodKeyData):
def __init__(self, *args, mask_key, mask_bits, masked_value, **kwargs):
super().__init__(*args, **kwargs)
self._mask_key = mask_key
self._mask_bits = mask_bits
self._masked_value = masked_value

def __repr__(self):
return f"<Masked {self.key!r} detector data for {len(self.modules)} modules>"

def _init_kwargs(self):
kw = super()._init_kwargs()
kw.update(
mask_key=self._mask_key,
mask_bits=self._mask_bits,
masked_value=self._masked_value,
)
return kw

def _load_mask(self, module_gaps):
"""Load the mask & convert to boolean (True for bad pixels)"""
mask_data = self.det[self._mask_key].ndarray(module_gaps=module_gaps)
if self._mask_bits is None:
return mask_data != 0 # Skip extra temporary array from &
else:
return (mask_data & self._mask_bits) != 0

def ndarray(self, *, module_gaps=False, **kwargs):
"""Load data into a NumPy array & apply the mask"""
# Load mask first: it shrinks from 4 bytes/px to 1, so peak memory use
# is lower than loading it after the data
mask = self._load_mask(module_gaps=module_gaps)

data = super().ndarray(module_gaps=module_gaps, **kwargs)
data[mask] = self._masked_value
return data


class XtdfImageMultimodKeyData(MultimodKeyData):
_sel_frames_cached = None
det: XtdfDetectorBase
Expand All @@ -890,6 +998,11 @@
entry_shape = self._eg_keydata.entry_shape
self._extraneous_dim = (len(entry_shape) >= 1) and (entry_shape[0] == 1)

def _init_kwargs(self):
kw = super()._init_kwargs()
kw.update(pulse_sel=self._pulse_sel)
return kw

@property
def ndim(self):
return super().ndim - (1 if self._extraneous_dim else 0)
Expand Down Expand Up @@ -958,14 +1071,10 @@
entry_dims = [f'dim_{i}' for i in range(ndim_inner)]
return ['module', 'train_pulse'] + entry_dims

# Used for .select_trains() and .split_trains()
def _with_selected_det(self, det_selected):
return XtdfImageMultimodKeyData(det_selected, self.key, self._pulse_sel)

def select_pulses(self, pulses):
pulses = _check_pulse_selection(pulses)

return XtdfImageMultimodKeyData(self.det, self.key, pulses)
kw = self._init_kwargs()
kw.update(pulse_sel=_check_pulse_selection(pulses))
return type(self)(**kw)

@property
def _sel_frames(self):
Expand Down Expand Up @@ -1102,7 +1211,13 @@

return arr


class XtdfMaskedKeyData(DetectorMaskedKeyData, XtdfImageMultimodKeyData):

Check failure

Code scanning / CodeQL

Missing call to `__init__` during object initialization Error

Class XtdfMaskedKeyData may not be initialized properly as
method XtdfImageMultimodKeyData.__init__
is not called from its
__init__ method
.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this is a false positive caused by CodeQL not fully understanding super() with multiple inheritance. The __init__ method is inherited from DetectorMaskedKeyData, which does call super().__init__(...). That calls the next __init__ in the method resolution order - not just the parent class - so when it's creating an XtdfMaskedKeyData, it will find XtdfImageMultimodKeyData.__init__().

In [2]: XtdfMaskedKeyData.__mro__
Out[2]: 
(extra_data.components.XtdfMaskedKeyData,
 extra_data.components.DetectorMaskedKeyData,
 extra_data.components.XtdfImageMultimodKeyData,
 extra_data.components.MultimodKeyData,
 object)

I usually avoid multiple inheritance because even if it works, it's a pain to think about stuff like this. In this instance, I started writing the masked KeyData classes as wrappers around the existing classes (has-a relationships rather than is-a), but it was going to be several times more code to provide the same methods and attributes again. So I decided using inheritance was the lesser of two evils.

# Created from xtdf_det.masked_data()
pass


class FramesFileWriter(FileWriter):
"""Write selected detector frames in European XFEL HDF5 format"""
def __init__(self, path, data, inc_tp_ids):
super().__init__(path, data)
Expand Down Expand Up @@ -1625,6 +1740,7 @@
r'(MODULE_|RECEIVER-|JNGFR)(?P<modno>\d+)'
)
_main_data_key = 'data.adc'
_mask_data_key = 'data.mask'
philsmt marked this conversation as resolved.
Show resolved Hide resolved
_modnos_start_at = 1
module_shape = (512, 1024)

Expand Down
10 changes: 10 additions & 0 deletions extra_data/tests/make_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,12 @@ def make_reduced_spb_run(dir_path, raw=True, rng=None, format_version='0.5'):
frames_per_train=frame_counts)
], ntrains=64, chunksize=32, format_version=format_version)

if modno == 9 and not raw:
# For testing masked_data
with h5py.File(path, 'a') as f:
mask_ds = f['INSTRUMENT/SPB_DET_AGIPD1M-1/DET/9CH0:xtdf/image/mask']
mask_ds[0, 0, :32] = np.arange(32)

write_file(osp.join(dir_path, '{}-R0238-DA01-S00000.h5'.format(prefix)),
[ XGM('SA1_XTD2_XGM/DOOCS/MAIN'),
XGM('SPB_XTD9_XGM/DOOCS/MAIN'),
Expand Down Expand Up @@ -408,6 +414,10 @@ def make_fxe_jungfrau_run(dir_path):
write_file(path, [
JUNGFRAUModule(f'FXE_XAD_JF500K/DET/JNGFR03')
], ntrains=100, chunksize=1, format_version='1.0')
with h5py.File(path, 'a') as f:
# For testing masked_data
mask_ds = f['INSTRUMENT/FXE_XAD_JF500K/DET/JNGFR03:daqOutput/data/mask']
mask_ds[0, 0, 0, :32] = np.arange(32)

write_file(osp.join(dir_path, f'RAW-R0052-JNGFRCTRL00-S00000.h5'), [
JUNGFRAUControl('FXE_XAD_JF1M/DET/CONTROL'),
Expand Down
56 changes: 56 additions & 0 deletions extra_data/tests/test_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,62 @@ def test_jungfraus_first_modno(mock_jungfrau_run, mock_fxe_jungfrau_run):
assert np.all(arr['module'] == [modno])


def test_jungfrau_masked_data(mock_fxe_jungfrau_run):
run = RunDirectory(mock_fxe_jungfrau_run)
jf = JUNGFRAU(run, 'FXE_XAD_JF500K')

# Default options
kd = jf.masked_data().select_trains(np.s_[:1])
arr = kd.ndarray()
assert arr.shape == (1, 1, 16, 512, 1024)
assert arr.dtype == np.float32
line0 = np.zeros(1024, dtype=np.float32)
line0[1:32] = np.nan
np.testing.assert_array_equal(arr[0, 0, 0, 0, :], line0)

# Xarray
xarr = kd.xarray()
assert xarr.dims[:2] == ('module', 'trainId')
np.testing.assert_array_equal(xarr.values[0, 0, 0, 0, :], line0)

# Specify which mask bits to use, & replace masked values with 99
kd = jf.masked_data(mask_bits=1, masked_value=99).select_trains(np.s_[:1])
arr = kd.ndarray()
assert arr.shape == (1, 1, 16, 512, 1024)
line0 = np.zeros(1024, dtype=np.float32)
line0[1:32:2] = 99
np.testing.assert_array_equal(arr[0, 0, 0, 0, :], line0)

# Different field
kd = jf.masked_data('data.gain', masked_value=255).select_trains(np.s_[:1])
arr = kd.ndarray()
assert arr.shape == (1, 1, 16, 512, 1024)
assert arr.dtype == np.uint8
line0 = np.zeros(1024, dtype=np.uint8)
line0[1:32] = 255
np.testing.assert_array_equal(arr[0, 0, 0, 0, :], line0)


def test_xtdf_masked_data(mock_reduced_spb_proc_run):
run = RunDirectory(mock_reduced_spb_proc_run)
agipd = AGIPD1M(run, modules=[8, 9])

kd = agipd.masked_data().select_trains(np.s_[:1])
assert kd.shape == (2, kd.shape[1], 512, 128)
arr = kd.ndarray()
assert arr.shape == kd.shape
assert arr.dtype == np.float32
line0_2mod = np.zeros((2, 128), dtype=np.float32)
line0_2mod[1, 1:32] = np.nan
np.testing.assert_array_equal(arr[:, 0, 0, :], line0_2mod)

kd = agipd.masked_data(mask_bits=[1, 4], masked_value=-1).select_trains(np.s_[:1])
arr = kd.ndarray()
line0_2mod = np.zeros((2, 128), dtype=np.float32)
line0_2mod[1, np.nonzero(np.arange(32) & 5)] = -1
np.testing.assert_array_equal(arr[:, 0, 0, :], line0_2mod)


def test_get_dask_array(mock_fxe_raw_run):
run = RunDirectory(mock_fxe_raw_run)
det = LPD1M(run)
Expand Down
Loading