Merge pull request #576 from European-XFEL/fix/read-big-data-pulse-sel
Avoid giant memory allocations when reading xtdf data with pulse selection
takluyver authored Nov 27, 2024
2 parents a5a025c + 67fe7d3 commit 68a0fde
Showing 1 changed file with 30 additions and 16 deletions.
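
For context, the change targets the per-module chunk reader used by the multi-module XTDF detector components in extra_data.components when a pulse selection is applied. Below is a minimal sketch of the kind of call that exercises this code path; the proposal/run numbers, the choice of AGIPD1M and the 'image.data' key are illustrative placeholders, not taken from this commit:

import numpy as np
from extra_data import open_run
from extra_data.components import AGIPD1M

# Open a run and wrap the multi-module detector data (placeholder numbers).
run = open_run(proposal=700000, run=1)
agipd = AGIPD1M(run)

# Selecting a subset of pulses per train goes through _read_chunk(), where the
# temporary read buffer used to be allocated with the full dataset shape.
frames = agipd.get_array('image.data', pulses=np.s_[:10])
print(frames.shape)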
46 changes: 30 additions & 16 deletions extra_data/components.py
@@ -1,6 +1,7 @@
 """Interfaces to data from specific instruments
 """
 import logging
+import math
 import re
 from collections.abc import Iterable
 from copy import copy
@@ -312,7 +313,7 @@ def _select_trains(cls, data, mod_data_counts, min_modules):
         return data.select_trains(by_id[train_ids])
 
     @staticmethod
-    def _split_align_chunk(chunk, target_train_ids: np.ndarray):
+    def _split_align_chunk(chunk, target_train_ids: np.ndarray, length_limit=np.inf):
         """
         Split up a source chunk to align with parts of a joined array.
@@ -328,6 +329,9 @@ def _split_align_chunk(chunk, target_train_ids: np.ndarray):
         target_train_ids: numpy.ndarray
             Train ID index for target array to align chunk data to. Train IDs may
             occur more than once in here.
+        length_limit: int
+            Maximum length of slices (stop - start) to yield. Larger slices will
+            be split up into several pieces. Unlimited by default.
         """
         # Expand the list of train IDs to one per frame
         chunk_tids = np.repeat(chunk.train_ids, chunk.counts.astype(np.intp))
@@ -353,14 +357,16 @@ def _split_align_chunk(chunk, target_train_ids: np.ndarray):
             else:
                 n_match = len(chunk_tids)
 
-            # Select the matching data
-            chunk_match_end = chunk_match_start + n_match
-            tgt_end = tgt_start + n_match
-
-            yield slice(tgt_start, tgt_end), slice(chunk_match_start, chunk_match_end)
+            # Split the matched data if needed for length_limit
+            n_batches = max(math.ceil(n_match / length_limit), 1)
+            for i in range(n_batches):
+                start = i * n_match // n_batches
+                stop = (i + 1) * n_match // n_batches
+                yield (slice(tgt_start + start, tgt_start + stop),
+                       slice(chunk_match_start + start, chunk_match_start + stop))
 
             # Prepare remaining data in the chunk for the next match
-            chunk_match_start = chunk_match_end
+            chunk_match_start += n_match
             chunk_tids = chunk_tids[n_match:]
 
     @property
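
The new arithmetic above divides a matched run of n_match frames into n_batches nearly equal pieces, none longer than length_limit. A self-contained sketch of the same splitting logic (a standalone helper for illustration, not part of the library):

import math

def split_batches(n_match, length_limit=float('inf')):
    # Number of pieces needed so that no piece exceeds length_limit.
    n_batches = max(math.ceil(n_match / length_limit), 1)
    for i in range(n_batches):
        start = i * n_match // n_batches
        stop = (i + 1) * n_match // n_batches
        yield start, stop

# 10 matched frames with a limit of 4 -> pieces (0, 3), (3, 6), (6, 10)
print(list(split_batches(10, 4)))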
@@ -1157,8 +1163,15 @@ def _sel_frames(self):
 
     def _read_chunk(self, chunk: DataChunk, mod_out, roi):
         """Read per-pulse data from file into an output array (of 1 module)"""
+        # Limit to 5 GB sections of the dataset at once, so the temporary
+        # arrays used in the workaround below are not too large.
+        nbytes_frame = chunk.dataset.dtype.itemsize
+        for dim in chunk.dataset.shape[1:]:
+            nbytes_frame *= dim
+        frame_limit = 5 * (1024 ** 3) // nbytes_frame
+
         for tgt_slice, chunk_slice in self.det._split_align_chunk(
-            chunk, self.det.train_ids_perframe
+            chunk, self.det.train_ids_perframe, length_limit=frame_limit
         ):
             inc_pulses_chunk = self._sel_frames[tgt_slice]
             if inc_pulses_chunk.sum() == 0:  # No data from this chunk selected
@@ -1176,29 +1189,30 @@ def _read_chunk(self, chunk: DataChunk, mod_out, roi):
             # Except it's fast if you read the data to a matching selection in
             # memory (one weird trick).
             # So as a workaround, this allocates a temporary array of the same
-            # shape as the dataset, reads into it, and then copies the selected
+            # shape as the full chunk, reads into it, and then copies the selected
             # data to the output array. The extra memory copy is not optimal,
             # but it's better than the HDF5 performance issue, at least in some
             # realistic cases.
             # N.B. tmp should only use memory for the data it contains -
             # zeros() uses calloc, so the OS can do virtual memory tricks.
             # Don't change this to zeros_like() !
-            tmp = np.zeros(chunk.dataset.shape, chunk.dataset.dtype)
-            pulse_sel = np.nonzero(inc_pulses_chunk)[0] + chunk_slice.start
-            sel_region = (pulse_sel,) + roi
+            tmp = np.zeros(
+                shape=inc_pulses_chunk.shape + chunk.dataset.shape[1:],
+                dtype=chunk.dataset.dtype
+            )
+            tmp_sel = np.nonzero(inc_pulses_chunk)[0]
+            dataset_sel = tmp_sel + chunk_slice.start
             chunk.dataset.read_direct(
-                tmp, source_sel=sel_region, dest_sel=sel_region,
+                tmp, source_sel=(dataset_sel,) + roi, dest_sel=(tmp_sel,) + roi,
             )
             # Where does this data go in the target array?
             tgt_start_ix = self._sel_frames[:tgt_slice.start].sum()
             tgt_pulse_sel = slice(
                 tgt_start_ix, tgt_start_ix + inc_pulses_chunk.sum()
             )
             # Copy data from temp array to output array
-            tmp_frames_mask = np.zeros(len(tmp), dtype=np.bool_)
-            tmp_frames_mask[pulse_sel] = True
             np.compress(
-                tmp_frames_mask, tmp[np.index_exp[:] + roi],
+                inc_pulses_chunk, tmp[np.index_exp[:] + roi],
                 axis=0, out=mod_out[tgt_pulse_sel]
             )