Fix memory leak peaksplitting (#309)
* Remove hitlet induced memory leak
JoranAngevaare authored Aug 31, 2020
1 parent 7e5ee60 commit e94b1c7
Showing 2 changed files with 72 additions and 37 deletions.
2 changes: 1 addition & 1 deletion strax/processing/peak_building.py
@@ -209,7 +209,7 @@ def sum_waveform(peaks, records, adc_to_pe):
p['time'] // dt, n_p)

max_in_record = r['data'][r_start:r_end].max() * multiplier
-                p['saturated_channel'][ch] |= int(max_in_record >= r['baseline'])
+                p['saturated_channel'][ch] |= np.int8(max_in_record >= r['baseline'])

bl_fpart = r['baseline'] % 1
# TODO: check numba does casting correctly here!
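The one-line change above swaps a Python `int(...)` cast for `np.int8(...)`, matching the 8-bit dtype of the `saturated_channel` field: inside numba-compiled code, `int(...)` yields a 64-bit integer, while the explicit narrow cast keeps the `|=` update within the field's own type (the TODO above hints at exactly this casting concern). A minimal, self-contained sketch of the pattern; the dtype below is illustrative, not strax's real peak dtype:

```python
import numba
import numpy as np

# Illustrative dtype: a per-channel int8 saturation flag, loosely
# mirroring the saturated_channel field of strax's peak dtype.
toy_dtype = np.dtype([('saturated_channel', np.int8, 4)])

@numba.njit
def flag_saturation(peaks, i, ch, max_in_record, baseline):
    p = peaks[i]
    # np.int8(...) keeps the comparison result in the field's own
    # 8-bit dtype instead of promoting to a 64-bit integer first.
    p['saturated_channel'][ch] |= np.int8(max_in_record >= baseline)

peaks = np.zeros(1, dtype=toy_dtype)
flag_saturation(peaks, 0, 2, 110.0, 100.0)
print(peaks[0]['saturated_channel'])  # -> [0 0 1 0]
```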
107 changes: 71 additions & 36 deletions strax/processing/peak_splitting.py
@@ -16,7 +16,7 @@ def split_peaks(peaks, records, to_pe, algorithm='local_minimum',
:param records: Records from which peaks were built
:param to_pe: ADC to PE conversion factor array (of n_channels)
:param algorithm: 'local_minimum' or 'natural_breaks'.
-    :param data_type: 'peaks' or 'hitlets'. Specifies whether to use
+    :param data_type: 'peaks' or 'hitlets'. Specifies whether to use
        sum_waveform or get_hitlets_data to compute the waveform of
the new split peaks/hitlets.
:param result_dtype: dtype of the result.
@@ -25,7 +25,7 @@ def split_peaks(peaks, records, to_pe, algorithm='local_minimum',
"""
splitter = dict(local_minimum=LocalMinimumSplitter,
natural_breaks=NaturalBreaksSplitter)[algorithm]()

if data_type == 'hitlets':
# This is only needed once.
_, next_ri = strax.record_links(records)
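Splitting hitlets requires rebuilding their waveforms from the raw records, which means walking from each record to the next fragment of the same pulse. Computing those links is comparatively expensive, so the refactor does it once here instead of inside the splitter. A hedged sketch of `strax.record_links` (the toy field values are chosen only so the snippet runs; real records come from the DAQ reader):

```python
import numpy as np
import strax

# Two toy records forming one two-fragment pulse.
records = np.zeros(2, dtype=strax.record_dtype(2))
records['dt'] = 10
records['length'] = 2
records['pulse_length'] = 4
records['record_i'] = [0, 1]
records['time'] = [0, 20]

# record_links returns two index arrays: for each record, the index
# of the previous / next record in the same pulse (-1 if none).
prev_ri, next_ri = strax.record_links(records)
print(next_ri)  # fragment 0 should point at fragment 1: [1, -1]
```

`next_ri` is then handed through to `strax.update_new_hitlets`, which follows the links to collect all fragments of a pulse when it recomputes the data of the freshly split hitlets.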
@@ -62,7 +62,7 @@ class PeakSplitter:
find_split_args_defaults: tuple

def __call__(self, peaks, records, to_pe, data_type,
-                 next_ri=None, do_iterations=1, min_area=0, **kwargs):
+                 next_ri=None, do_iterations=1, min_area=0, **kwargs):
if not len(records) or not len(peaks) or not do_iterations:
return peaks

@@ -86,42 +86,28 @@ def __call__(self, peaks, records, to_pe, data_type,

is_split = np.zeros(len(peaks), dtype=np.bool_)

-        # data_kind specific_outputs:
-        if data_type == 'peaks':
-            @numba.njit
-            def specific_output(r, p, split_i, bonus_output):
-                if split_i == NO_MORE_SPLITS:
-                    p['max_goodness_of_split'] = bonus_output
-                    # although the iteration will end anyway afterwards:
-                    r['max_gap'] = -1  # Too lazy to compute this
-        elif data_type == 'hitlets':
-            @numba.njit
-            def specific_output(r, p, split_i, bonus_output):
-                if split_i == NO_MORE_SPLITS:
-                    return
-                r['record_i'] = p['record_i']
-        else:
-            raise TypeError(f'Unknown data_type. "{data_type}" is not supported.')
+        split_function = {'peaks': self._split_peaks,
+                          'hitlets': self._split_hitlets}
+        if data_type not in split_function:
+            raise ValueError(f'Data_type "{data_type}" is not supported.')
+
-        new_peaks = self._split_peaks(
+        new_peaks = split_function[data_type](
# Numba doesn't like self as argument, but it's ok with functions...
split_finder=self.find_split_points,
peaks=peaks,
is_split=is_split,
orig_dt=records[0]['dt'],
min_area=min_area,
args_options=tuple(args_options),
-            specific_output=specific_output,
result_dtype=peaks.dtype)

if is_split.sum() != 0:
# Found new peaks: compute basic properties
if data_type == 'peaks':
strax.sum_waveform(new_peaks, records, to_pe)
elif data_type == 'hitlets':
# Add record fields here
strax.update_new_hitlets(new_peaks, records, next_ri, to_pe)
-            else:
-                raise ValueError(f'Data_type "{data_type}" is not supported.')

strax.compute_widths(new_peaks)
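This block is the heart of the fix. The old `__call__` defined a fresh `@numba.njit` closure (`specific_output`) on every invocation and passed it into the jitted splitter; numba compiles a new specialization for every distinct function object it sees and caches it for the life of the process, so calling the splitter repeatedly (for example once per processed chunk) grew memory without bound. The rewrite selects between two statically compiled methods through a plain dict, so everything is compiled exactly once. A minimal sketch of the leaky versus fixed pattern (function names are hypothetical, not strax API):

```python
import numba
import numpy as np

@numba.njit
def _loop(inner, data):
    # Jitted driver that calls a jitted callback, as _split_peaks
    # does with split_finder (and formerly with specific_output).
    total = 0
    for x in data:
        total += inner(x)
    return total

def process_leaky(data):
    # Leaky pattern: a brand-new njit function object per call forces
    # a fresh _loop specialization each time, cached indefinitely.
    @numba.njit
    def inner(x):
        return x + 1
    return _loop(inner, data)

@numba.njit
def _inner_peaks(x):
    return x + 1

@numba.njit
def _inner_hitlets(x):
    return x + 2

_dispatch = {'peaks': _inner_peaks, 'hitlets': _inner_hitlets}

def process_fixed(data, data_type='peaks'):
    # Fixed pattern: both callbacks are compiled once and reused,
    # mirroring the split_function dict above.
    return _loop(_dispatch[data_type], data)

data = np.arange(10)
print(process_fixed(data), process_fixed(data, 'hitlets'))
```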

@@ -138,12 +124,14 @@ def specific_output(r, p, split_i, bonus_output):
@strax.growing_result(dtype=strax.peak_dtype(), chunk_size=int(1e4))
@numba.jit(nopython=True, nogil=True)
def _split_peaks(split_finder, peaks, orig_dt, is_split, min_area,
-                     specific_output, args_options,
+                     args_options,
_result_buffer=None, result_dtype=None):
"""Loop over peaks, pass waveforms to algorithm, construct
new peaks if and where a split occurs.
"""
-        # TODO NEEDS TESTS!
+        # NB: code very similar to _split_hitlets see
+        # github.com/AxFoundation/strax/pull/309 for more info. Keep in mind
+        # that changing one function should also be reflected in the other.
new_peaks = _result_buffer
offset = 0

@@ -155,26 +143,21 @@ def _split_peaks(split_finder, peaks, orig_dt, is_split, min_area,
w = p['data'][:p['length']]
            for split_i, bonus_output in split_finder(
                    w, p['dt'], p_i, *args_options):
-
-                # This is a bit odd here. Due tp the specific_outputs we have to get r
-                # although we may not need it at all, but I do not see any nice way around
-                # this.
-                r = new_peaks[offset]
-                specific_output(r, p, split_i, bonus_output)
                if split_i == NO_MORE_SPLITS:
-                    # No idea if this if-statement can be integrated into
-                    # specific return
+                    p['max_goodness_of_split'] = bonus_output
+                    # although the iteration will end anyway afterwards:
                    continue

                is_split[p_i] = True
+                r = new_peaks[offset]
                r['time'] = p['time'] + prev_split_i * p['dt']
                r['channel'] = p['channel']
                # Set the dt to the original (lowest) dt first;
                # this may change when the sum waveform of the new peak
                # is computed
                r['dt'] = orig_dt
                r['length'] = (split_i - prev_split_i) * p['dt'] / orig_dt
-
+                r['max_gap'] = -1  # Too lazy to compute this
if r['length'] <= 0:
print(p['data'])
print(prev_split_i, split_i)
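Both splitter loops rely on the `strax.growing_result` decorator seen above: the jitted generator fills the fixed-size `_result_buffer`, yields the fill count whenever the buffer runs full, and yields once more at the end, while the decorator stitches the flushed chunks into a single array. A toy sketch of the same pattern, assuming `growing_result` accepts a plain scalar dtype:

```python
import numba
import numpy as np
import strax

@strax.growing_result(dtype=np.int64, chunk_size=4)
@numba.jit(nopython=True, nogil=True)
def squares(n, _result_buffer=None):
    buffer = _result_buffer
    offset = 0
    for i in range(n):
        buffer[offset] = i * i
        offset += 1
        if offset == len(buffer):
            yield offset  # buffer full: flush this chunk
            offset = 0
    yield offset  # flush the final, partially filled chunk

print(squares(10))  # -> [ 0  1  4  9 16 25 36 49 64 81]
```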
@@ -189,6 +172,58 @@ def _split_peaks(split_finder, peaks, orig_dt, is_split, min_area,

yield offset

+    @staticmethod
+    @strax.growing_result(dtype=strax.hitlet_dtype(), chunk_size=int(1e4))
+    @numba.jit(nopython=True, nogil=True)
+    def _split_hitlets(split_finder, peaks, orig_dt, is_split, min_area,
+                       args_options,
+                       _result_buffer=None, result_dtype=None):
+        """Loop over hits, pass waveforms to algorithm, construct
+        new hits if and where a split occurs.
+        """
+        # TODO NEEDS TESTS!
+        # NB: code very similar to _split_peaks see
+        # github.com/AxFoundation/strax/pull/309 for more info. Keep in mind
+        # that changing one function should also be reflected in the other.
+        new_hits = _result_buffer
+        offset = 0
+
+        for h_i, h in enumerate(peaks):
+            if h['area'] < min_area:
+                continue
+
+            prev_split_i = 0
+            w = h['data'][:h['length']]
+            for split_i, bonus_output in split_finder(
+                    w, h['dt'], h_i, *args_options):
+                if split_i == NO_MORE_SPLITS:
+                    continue
+
+                is_split[h_i] = True
+                r = new_hits[offset]
+                r['time'] = h['time'] + prev_split_i * h['dt']
+                r['channel'] = h['channel']
+                # Hitlet specific
+                r['record_i'] = h['record_i']
+                # Set the dt to the original (lowest) dt first;
+                # this may change when the sum waveform of the new peak
+                # is computed
+                r['dt'] = orig_dt
+                r['length'] = (split_i - prev_split_i) * h['dt'] / orig_dt
+                if r['length'] <= 0:
+                    print(h['data'])
+                    print(prev_split_i, split_i)
+                    raise ValueError("Attempt to create invalid hitlet!")
+
+                offset += 1
+                if offset == len(new_hits):
+                    yield offset
+                    offset = 0
+
+                prev_split_i = split_i
+
+        yield offset
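`_split_hitlets` above is deliberately a near-verbatim copy of `_split_peaks`: keeping two monomorphic jitted functions is what avoids handing per-call helpers to numba (the source of the leak), at the cost of the duplication the NB comments warn about. The substantive differences are the result dtype and the hitlet-only `record_i` field. A quick, hedged way to inspect the field difference (exact field sets depend on the strax version):

```python
import numpy as np
import strax

peak_fields = set(np.dtype(strax.peak_dtype()).names)
hitlet_fields = set(np.dtype(strax.hitlet_dtype()).names)

# Expect record_i on the hitlet side, and peak-only fields such as
# max_goodness_of_split on the peak side.
print(sorted(hitlet_fields - peak_fields))
print(sorted(peak_fields - hitlet_fields))
```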

@staticmethod
def find_split_points(w, dt, peak_i, *args_options):
"""This function is overwritten by LocalMinimumSplitter or LocalMinimumSplitter
@@ -267,7 +302,7 @@ class NaturalBreaksSplitter(PeakSplitter):
close as we can get to it given the peaks sampling) on either side.
"""
find_split_args_defaults = (
-        ('threshold', None),  # will be a numpy array of len(peaks)
+        ('threshold', None),  # will be a numpy array of len(peaks)
('normalize', False),
('split_low', False),
('filter_wing_width', 0))
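`find_split_args_defaults` declares the keyword arguments that `__call__` gathers and forwards to `find_split_points`, so callers tune the algorithm through the `**kwargs` of `split_peaks`; `threshold` has no usable default and is documented as one entry per peak. A hedged call sketch (values illustrative; `peaks`, `records` and `to_pe` assumed to exist from an earlier processing step):

```python
import numpy as np
import strax

new_peaks = strax.split_peaks(
    peaks, records, to_pe,
    algorithm='natural_breaks',
    data_type='peaks',
    threshold=np.full(len(peaks), 0.4),  # illustrative per-peak value
    normalize=False,
    split_low=False,
    filter_wing_width=0)
```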
