-
Notifications
You must be signed in to change notification settings - Fork 14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
API to load detector data with mask #518
Changes from 8 commits
4f5aa00
d89c969
af0a279
1f7fb93
137ed9a
f336dc4
3251393
31459bc
259f2c3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,7 @@ | |
""" | ||
import logging | ||
import re | ||
from collections.abc import Iterable | ||
from copy import copy | ||
from warnings import warn | ||
|
||
|
@@ -122,6 +123,7 @@ | |
_source_re = re.compile(r'(?P<detname>.+)/DET/(\d+)CH') | ||
# Override in subclass | ||
_main_data_key = '' # Key to use for checking data counts match | ||
_mask_data_key = '' | ||
_frames_per_entry = 1 # Override if separate pulse dimension in files | ||
_modnos_start_at = 0 # Override if module numbers start at 1 (JUNGFRAU) | ||
module_shape = (0, 0) | ||
|
@@ -176,6 +178,40 @@ | |
def __getitem__(self, item):
    """Return multi-module key data for *item* (e.g. ``det['data.adc']``)."""
    keydata = MultimodKeyData(self, item)
    return keydata
|
||
def masked_data(self, key=None, *, mask_bits=None, masked_value=np.nan):
    """Return corrected detector data combined with the mask in the files

    The returned object offers an interface similar to ``det['data.adc']``,
    with pixels flagged by the correction pipeline's mask masked out.

    Parameters
    ----------

    key: str
      The data key to read. Defaults to the main data key of the detector
      (e.g. 'data.adc').
    mask_bits: int or list of ints
      Reasons to exclude pixels, as a bitmask or a list of integers.
      By default, all types of bad pixel are masked out.
    masked_value: int, float
      The replacement value to use for masked data. By default this is NaN.
    """
    mask_key = self._mask_data_key
    self[mask_key]  # Fail early if the mask is not available

    if not key:
        key = self._main_data_key

    # A list of flags is folded into one integer bitmask
    if isinstance(mask_bits, Iterable):
        mask_bits = self._combine_bitfield(mask_bits)

    return DetectorMaskedKeyData(
        self,
        key,
        mask_key=mask_key,
        mask_bits=mask_bits,
        masked_value=masked_value,
    )
|
||
@staticmethod | ||
def _combine_bitfield(ints): | ||
res = 0 | ||
for i in ints: | ||
res |= i | ||
return res | ||
|
||
@classmethod | ||
def _find_detector_name(cls, data): | ||
detector_names = set() | ||
|
@@ -471,6 +507,7 @@ | |
""" | ||
n_modules = 16 | ||
_main_data_key = 'image.data' | ||
_mask_data_key = 'image.mask' | ||
|
||
def __init__(self, data: DataCollection, detector_name=None, modules=None, | ||
*, min_modules=1): | ||
|
@@ -481,6 +518,34 @@ | |
return XtdfImageMultimodKeyData(self, item) | ||
return super().__getitem__(item) | ||
|
||
def masked_data(self, key=None, *, mask_bits=None, masked_value=np.nan):
    """Combine corrected data with the mask in the files

    This provides an interface similar to ``det['image.data']``, but masking
    out pixels with the mask from the correction pipeline.

    Parameters
    ----------

    key: str
      The data key to look at, by default the main data key of the detector
      (e.g. 'image.data'). Must be in the 'image.' group.
    mask_bits: int or list of ints
      Reasons to exclude pixels, as a bitmask or a list of integers.
      By default, all types of bad pixel are masked out.
    masked_value: int, float
      The replacement value to use for masked data. By default this is NaN.

    Raises
    ------

    ValueError
      If *key* does not start with 'image.'.
    """
    key = key or self._main_data_key
    # Explicit check instead of ``assert``: assertions are stripped when
    # Python runs with -O, which would let a non-image key slip through.
    if not key.startswith('image.'):
        raise ValueError(
            f"masked_data() only works with 'image.*' keys, got {key!r}"
        )
    self[self._mask_data_key]  # Check that the mask is there

    # A list of flags is folded into one integer bitmask
    if isinstance(mask_bits, Iterable):
        mask_bits = self._combine_bitfield(mask_bits)
    return XtdfMaskedKeyData(
        self, key, mask_key=self._mask_data_key,
        mask_bits=mask_bits, masked_value=masked_value
    )
|
||
# Several methods below are overridden in LPD1M for parallel gain mode | ||
|
||
@staticmethod | ||
|
@@ -746,7 +811,6 @@ | |
return res | ||
|
||
|
||
|
||
class MultimodKeyData: | ||
def __init__(self, det: MultimodDetectorBase, key): | ||
self.det = det | ||
|
@@ -755,6 +819,9 @@ | |
m: det.data[s, key] for (m, s) in det.modno_to_source.items() | ||
} | ||
|
||
def _init_kwargs(self): # Extended in subclasses | ||
return dict(det=self.det, key=self.key) | ||
|
||
@property
def train_ids(self):
    """Train IDs, as reported by the detector object (``self.det.train_ids``)."""
    detector = self.det
    return detector.train_ids
|
@@ -799,9 +866,11 @@ | |
def dtype(self): | ||
return self._eg_keydata.dtype | ||
|
||
# For select_trains() & split_trains() to work correctly with subclasses | ||
def _with_selected_det(self, det_selected): | ||
# Overridden for XtdfImageMultimodKeyData to preserve pulse selection | ||
return MultimodKeyData(det_selected, self.key) | ||
kw = self._init_kwargs() | ||
kw.update(det=det_selected) | ||
return type(self)(**kw) | ||
|
||
def select_trains(self, trains):
    """Select *trains* on the detector and rebuild this object around the result."""
    det_selected = self.det.select_trains(trains)
    return self._with_selected_det(det_selected)
|
@@ -831,13 +900,16 @@ | |
) | ||
return out | ||
|
||
def _wrap_xarray(self, arr):
    """Label an already-loaded array as an xarray DataArray.

    Uses this object's ``dimensions`` as dimension names and attaches
    coordinates for the 'module' and 'trainId' dimensions.
    """
    # NOTE(review): CodeQL code scanning flagged this method ("Mismatch
    # between signature and use of an overridden method"), pointing at
    # XtdfImageMultimodKeyData._wrap_xarray - confirm the override accepts
    # the same call.
    from xarray import DataArray

    coords = {'module': self.modules, 'trainId': self.train_id_coordinates()}
    return DataArray(arr, dims=self.dimensions, coords=coords)

def xarray(self, *, fill_value=None, roi=(), astype=None):
    """Load the data with ``ndarray()`` and return it as a labelled DataArray.

    *fill_value*, *roi* and *astype* are passed through to ``ndarray()``.
    """
    arr = self.ndarray(fill_value=fill_value, roi=roi, astype=astype)
    return self._wrap_xarray(arr)
|
||
def dask_array(self, *, labelled=False, fill_value=None, astype=None): | ||
from dask.delayed import delayed | ||
from dask.array import concatenate, from_delayed | ||
|
@@ -854,9 +926,7 @@ | |
) for c in split], axis=1) | ||
|
||
if labelled: | ||
from xarray import DataArray | ||
coords = {'module': self.modules, 'trainId': self.train_id_coordinates()} | ||
return DataArray(arr, dims=self.dimensions, coords=coords) | ||
return self._wrap_xarray(arr) | ||
|
||
return arr | ||
|
||
|
@@ -880,6 +950,44 @@ | |
return out | ||
|
||
|
||
class DetectorMaskedKeyData(MultimodKeyData):
    """Multi-module detector data for one key, with masked pixels replaced.

    Behaves like the parent class, except that ``ndarray()`` overwrites the
    pixels selected by the mask key with ``masked_value``.
    """

    def __init__(self, *args, mask_key, mask_bits, masked_value, **kwargs):
        # Remaining arguments are forwarded untouched, so this class can sit
        # in front of other key data classes in an inheritance chain.
        super().__init__(*args, **kwargs)
        self._mask_key = mask_key
        self._mask_bits = mask_bits
        self._masked_value = masked_value

    def __repr__(self):
        return f"<Masked {self.key!r} detector data for {len(self.modules)} modules>"

    def _init_kwargs(self):
        """Constructor keyword arguments, including the mask settings."""
        kw = super()._init_kwargs()
        kw.update(
            mask_key=self._mask_key,
            mask_bits=self._mask_bits,
            masked_value=self._masked_value,
        )
        return kw

    def _load_mask(self, module_gaps, roi=()):
        """Load the mask & convert to boolean (True for bad pixels)

        *roi* is passed on to the mask's ``ndarray()`` so the mask covers the
        same region as the data it will be applied to.
        """
        mask_data = self.det[self._mask_key].ndarray(
            module_gaps=module_gaps, roi=roi
        )
        if self._mask_bits is None:
            return mask_data != 0  # Skip extra temporary array from &
        else:
            return (mask_data & self._mask_bits) != 0

    def ndarray(self, *, module_gaps=False, **kwargs):
        """Load data into a NumPy array & apply the mask

        Keyword arguments other than *module_gaps* go to the parent class's
        ``ndarray()``. If *roi* is among them, it is applied to the mask as
        well, so that mask and data have matching shapes.
        """
        # Load mask first: it shrinks from 4 bytes/px to 1, so peak memory use
        # is lower than loading it after the data
        mask = self._load_mask(
            module_gaps=module_gaps, roi=kwargs.get('roi', ())
        )

        data = super().ndarray(module_gaps=module_gaps, **kwargs)
        data[mask] = self._masked_value
        return data
|
||
|
||
class XtdfImageMultimodKeyData(MultimodKeyData): | ||
_sel_frames_cached = None | ||
det: XtdfDetectorBase | ||
|
@@ -890,6 +998,11 @@ | |
entry_shape = self._eg_keydata.entry_shape | ||
self._extraneous_dim = (len(entry_shape) >= 1) and (entry_shape[0] == 1) | ||
|
||
def _init_kwargs(self):
    """Constructor keyword arguments, including the current pulse selection."""
    return dict(super()._init_kwargs(), pulse_sel=self._pulse_sel)
|
||
@property
def ndim(self):
    """Number of dimensions, one fewer when ``_extraneous_dim`` is set."""
    n_dims = super().ndim
    if self._extraneous_dim:
        n_dims = n_dims - 1
    return n_dims
|
@@ -958,14 +1071,10 @@ | |
entry_dims = [f'dim_{i}' for i in range(ndim_inner)] | ||
return ['module', 'train_pulse'] + entry_dims | ||
|
||
# Used for .select_trains() and .split_trains() | ||
def _with_selected_det(self, det_selected): | ||
return XtdfImageMultimodKeyData(det_selected, self.key, self._pulse_sel) | ||
|
||
def select_pulses(self, pulses):
    """Return an object like this one, restricted to the given pulses.

    *pulses* is validated with ``_check_pulse_selection()``. The result is
    built with ``type(self)`` from ``_init_kwargs()``, so subclasses keep
    their own type and extra constructor arguments.
    """
    kw = self._init_kwargs()
    kw.update(pulse_sel=_check_pulse_selection(pulses))
    return type(self)(**kw)
|
||
@property | ||
def _sel_frames(self): | ||
|
@@ -1102,7 +1211,13 @@ | |
|
||
return arr | ||
|
||
|
||
class XtdfMaskedKeyData(DetectorMaskedKeyData, XtdfImageMultimodKeyData): | ||
Check failure - Code scanning / CodeQL: Missing call to `__init__` during object initialization (Error)
Class XtdfMaskedKeyData may not be initialized properly as method XtdfImageMultimodKeyData.__init__ is not called from its __init__ method.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
I believe this is a false positive caused by CodeQL not fully understanding `super()`.
I usually avoid multiple inheritance because even if it works, it's a pain to think about stuff like this. In this instance, I started writing the masked KeyData classes as wrappers around the existing classes (has-a relationships rather than is-a), but it was going to be several times more code to provide the same methods and attributes again. So I decided using inheritance was the lesser of two evils. |
||
# Created from xtdf_det.masked_data() | ||
pass | ||
|
||
|
||
class FramesFileWriter(FileWriter): | ||
"""Write selected detector frames in European XFEL HDF5 format""" | ||
def __init__(self, path, data, inc_tp_ids): | ||
super().__init__(path, data) | ||
|
@@ -1625,6 +1740,7 @@ | |
r'(MODULE_|RECEIVER-|JNGFR)(?P<modno>\d+)' | ||
) | ||
_main_data_key = 'data.adc' | ||
_mask_data_key = 'data.mask' | ||
philsmt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
_modnos_start_at = 1 | ||
module_shape = (512, 1024) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since bit logic is not generally intuitive for pragmatic programmers, allowing the bad pixel bits to be passed as an iterable may be nice here as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That makes sense. I guess we should also expose the bad pixel flags somewhere - I don't think we have those at the moment. Maybe they belong in `extra.calibration`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Definitely, EXtra was the intended target for some time now. `extra.calibration` makes sense given it's used in calibration data.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to record here, I opened European-XFEL/EXtra#172 to do that.