Skip to content

Commit

Permalink
Add some tests for det.masked_data() API
Browse files Browse the repository at this point in the history
  • Loading branch information
takluyver committed May 3, 2024
1 parent 1f7fb93 commit 65219a5
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 0 deletions.
1 change: 1 addition & 0 deletions extra_data/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,7 @@ def ndarray(self, *, module_gaps=False, **kwargs):
# Load mask first: it shrinks from 4 bytes/px to 1, so peak memory use
# is lower than loading it after the data
mask = self._load_mask(module_gaps=module_gaps)
print(mask[0, 0, 0, :35])

data = super().ndarray(module_gaps=module_gaps, **kwargs)
data[mask] = self._masked_value
Expand Down
10 changes: 10 additions & 0 deletions extra_data/tests/make_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,12 @@ def make_reduced_spb_run(dir_path, raw=True, rng=None, format_version='0.5'):
frames_per_train=frame_counts)
], ntrains=64, chunksize=32, format_version=format_version)

if modno == 9 and not raw:
# For testing masked_data
with h5py.File(path, 'a') as f:
mask_ds = f['INSTRUMENT/SPB_DET_AGIPD1M-1/DET/9CH0:xtdf/image/mask']
mask_ds[0, 0, :32] = np.arange(32)

write_file(osp.join(dir_path, '{}-R0238-DA01-S00000.h5'.format(prefix)),
[ XGM('SA1_XTD2_XGM/DOOCS/MAIN'),
XGM('SPB_XTD9_XGM/DOOCS/MAIN'),
Expand Down Expand Up @@ -408,6 +414,10 @@ def make_fxe_jungfrau_run(dir_path):
write_file(path, [
JUNGFRAUModule(f'FXE_XAD_JF500K/DET/JNGFR03')
], ntrains=100, chunksize=1, format_version='1.0')
with h5py.File(path, 'a') as f:
# For testing masked_data
mask_ds = f['INSTRUMENT/FXE_XAD_JF500K/DET/JNGFR03:daqOutput/data/mask']
mask_ds[0, 0, 0, :32] = np.arange(32)

write_file(osp.join(dir_path, f'RAW-R0052-JNGFRCTRL00-S00000.h5'), [
JUNGFRAUControl('FXE_XAD_JF1M/DET/CONTROL'),
Expand Down
53 changes: 53 additions & 0 deletions extra_data/tests/test_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,59 @@ def test_jungfraus_first_modno(mock_jungfrau_run, mock_fxe_jungfrau_run):
assert np.all(arr['module'] == [modno])


def test_jungfrau_masked_data(mock_fxe_jungfrau_run):
run = RunDirectory(mock_fxe_jungfrau_run)
jf = JUNGFRAU(run, 'FXE_XAD_JF500K')

# Default options
kd = jf.masked_data().select_trains(np.s_[:1])
arr = kd.ndarray()
assert arr.shape == (1, 1, 16, 512, 1024)
line0 = np.zeros(1024, dtype=np.float32)
line0[1:32] = np.nan
np.testing.assert_array_equal(arr[0, 0, 0, 0, :], line0, strict=True)

# Xarray
xarr = kd.xarray()
assert xarr.dims[:2] == ('module', 'trainId')
np.testing.assert_array_equal(xarr.values[0, 0, 0, 0, :], line0, strict=True)

# Specify which mask bits to use, & replace masked values with 99
kd = jf.masked_data(mask_bits=1, masked_value=99).select_trains(np.s_[:1])
arr = kd.ndarray()
assert arr.shape == (1, 1, 16, 512, 1024)
line0 = np.zeros(1024, dtype=np.float32)
line0[1:32:2] = 99
np.testing.assert_array_equal(arr[0, 0, 0, 0, :], line0, strict=True)

# Different field
kd = jf.masked_data('data.gain', masked_value=255).select_trains(np.s_[:1])
arr = kd.ndarray()
assert arr.shape == (1, 1, 16, 512, 1024)
line0 = np.zeros(1024, dtype=np.uint8)
line0[1:32] = 255
np.testing.assert_array_equal(arr[0, 0, 0, 0, :], line0, strict=True)


def test_xtdf_masked_data(mock_reduced_spb_proc_run):
run = RunDirectory(mock_reduced_spb_proc_run)
agipd = AGIPD1M(run, modules=[8, 9])

kd = agipd.masked_data().select_trains(np.s_[:1])
assert kd.shape == (2, kd.shape[1], 512, 128)
arr = kd.ndarray()
assert arr.shape == kd.shape
line0_2mod = np.zeros((2, 128), dtype=np.float32)
line0_2mod[1, 1:32] = np.nan
np.testing.assert_array_equal(arr[:, 0, 0, :], line0_2mod, strict=True)

kd = agipd.masked_data(mask_bits=1, masked_value=-1).select_trains(np.s_[:1])
arr = kd.ndarray()
line0_2mod = np.zeros((2, 128), dtype=np.float32)
line0_2mod[1, 1:32:2] = -1
np.testing.assert_array_equal(arr[:, 0, 0, :], line0_2mod, strict=True)


def test_get_dask_array(mock_fxe_raw_run):
run = RunDirectory(mock_fxe_raw_run)
det = LPD1M(run)
Expand Down

0 comments on commit 65219a5

Please sign in to comment.