diff --git a/strax/processing/peak_building.py b/strax/processing/peak_building.py
index 569f04db9..4c9498d87 100644
--- a/strax/processing/peak_building.py
+++ b/strax/processing/peak_building.py
@@ -209,7 +209,7 @@ def sum_waveform(peaks, records, adc_to_pe):
                 p['time'] // dt, n_p)
 
             max_in_record = r['data'][r_start:r_end].max() * multiplier
-            p['saturated_channel'][ch] |= int(max_in_record >= r['baseline'])
+            p['saturated_channel'][ch] |= np.int8(max_in_record >= r['baseline'])
 
             bl_fpart = r['baseline'] % 1
             # TODO: check numba does casting correctly here!
diff --git a/strax/processing/peak_splitting.py b/strax/processing/peak_splitting.py
index 3fe88f251..117111d6b 100644
--- a/strax/processing/peak_splitting.py
+++ b/strax/processing/peak_splitting.py
@@ -16,7 +16,7 @@ def split_peaks(peaks, records, to_pe, algorithm='local_minimum',
     :param records: Records from which peaks were built
     :param to_pe: ADC to PE conversion factor array (of n_channels)
     :param algorithm: 'local_minimum' or 'natural_breaks'.
-    :param data_type: 'peaks' or 'hitlets'. Specifies whether to use 
+    :param data_type: 'peaks' or 'hitlets'. Specifies whether to use
     sum_wavefrom or get_hitlets_data to compute the waveform of the new
     split peaks/hitlets.
     :param result_dtype: dtype of the result.
@@ -25,7 +25,7 @@ def split_peaks(peaks, records, to_pe, algorithm='local_minimum',
     """
     splitter = dict(local_minimum=LocalMinimumSplitter,
                     natural_breaks=NaturalBreaksSplitter)[algorithm]()
-    
+
     if data_type == 'hitlets':
         # This is only needed once.
        _, next_ri = strax.record_links(records)
@@ -62,7 +62,7 @@ class PeakSplitter:
     find_split_args_defaults: tuple
 
     def __call__(self, peaks, records, to_pe, data_type,
-                 next_ri=None, do_iterations=1, min_area=0, **kwargs): 
+                 next_ri=None, do_iterations=1, min_area=0, **kwargs):
         if not len(records) or not len(peaks) or not do_iterations:
             return peaks
 
@@ -86,24 +86,12 @@ def __call__(self, peaks, records, to_pe, data_type,
 
         is_split = np.zeros(len(peaks), dtype=np.bool_)
 
-        # data_kind specific_outputs:
-        if data_type == 'peaks':
-            @numba.njit
-            def specific_output(r, p, split_i, bonus_output):
-                if split_i == NO_MORE_SPLITS:
-                    p['max_goodness_of_split'] = bonus_output
-                    # although the iteration will end anyway afterwards:
-                    r['max_gap'] = -1  # Too lazy to compute this
+        split_function = {'peaks': self._split_peaks,
+                          'hitlets': self._split_hitlets}
+        if data_type not in split_function:
+            raise ValueError(f'Data_type "{data_type}" is not supported.')
 
-        elif data_type == 'hitlets':
-            @numba.njit
-            def specific_output(r, p, split_i, bonus_output):
-                if split_i == NO_MORE_SPLITS:
-                    return
-                r['record_i'] = p['record_i']
-        else:
-            raise TypeError(f'Unknown data_type. "{data_type}" is not supported.')
-        new_peaks = self._split_peaks(
+        new_peaks = split_function[data_type](
             # Numba doesn't like self as argument, but it's ok with functions...
             split_finder=self.find_split_points,
             peaks=peaks,
@@ -111,17 +99,15 @@ def specific_output(r, p, split_i, bonus_output):
             orig_dt=records[0]['dt'],
             min_area=min_area,
             args_options=tuple(args_options),
-            specific_output=specific_output,
             result_dtype=peaks.dtype)
-        
+
         if is_split.sum() != 0:
             # Found new peaks: compute basic properties
             if data_type == 'peaks':
                 strax.sum_waveform(new_peaks, records, to_pe)
             elif data_type == 'hitlets':
+                # Add record fields here
                 strax.update_new_hitlets(new_peaks, records, next_ri, to_pe)
-            else:
-                raise ValueError(f'Data_type "{data_type}" is not supported.')
 
             strax.compute_widths(new_peaks)
 
@@ -138,12 +124,14 @@ def specific_output(r, p, split_i, bonus_output):
     @strax.growing_result(dtype=strax.peak_dtype(), chunk_size=int(1e4))
     @numba.jit(nopython=True, nogil=True)
     def _split_peaks(split_finder, peaks, orig_dt, is_split, min_area,
-                     specific_output, args_options,
+                     args_options,
                      _result_buffer=None, result_dtype=None):
         """Loop over peaks, pass waveforms to algorithm, construct
         new peaks if and where a split occurs.
         """
-        # TODO NEEDS TESTS!
+        # NB: code very similar to _split_hitlets see
+        # github.com/AxFoundation/strax/pull/309 for more info. Keep in mind
+        # that changing one function should also be reflected in the other.
         new_peaks = _result_buffer
         offset = 0
 
@@ -155,18 +143,13 @@ def _split_peaks(split_finder, peaks, orig_dt, is_split, min_area,
             w = p['data'][:p['length']]
             for split_i, bonus_output in split_finder(
                     w, p['dt'], p_i, *args_options):
-
-                # This is a bit odd here. Due tp the specific_outputs we have to get r
-                # although we may not need it at all, but I do not see any nice way around
-                # this.
-                r = new_peaks[offset]
-                specific_output(r, p, split_i, bonus_output)
                 if split_i == NO_MORE_SPLITS:
-                    # No idea if this if-statement can be integrated into
-                    # specific return
+                    p['max_goodness_of_split'] = bonus_output
+                    # although the iteration will end anyway afterwards:
                     continue
 
                 is_split[p_i] = True
+                r = new_peaks[offset]
                 r['time'] = p['time'] + prev_split_i * p['dt']
                 r['channel'] = p['channel']
                 # Set the dt to the original (lowest) dt first;
@@ -174,7 +157,7 @@ def _split_peaks(split_finder, peaks, orig_dt, is_split, min_area,
                 # is computed
                 r['dt'] = orig_dt
                 r['length'] = (split_i - prev_split_i) * p['dt'] / orig_dt
-                
+                r['max_gap'] = -1  # Too lazy to compute this
                 if r['length'] <= 0:
                     print(p['data'])
                     print(prev_split_i, split_i)
@@ -189,6 +172,58 @@
 
         yield offset
 
+    @staticmethod
+    @strax.growing_result(dtype=strax.hitlet_dtype(), chunk_size=int(1e4))
+    @numba.jit(nopython=True, nogil=True)
+    def _split_hitlets(split_finder, peaks, orig_dt, is_split, min_area,
+                       args_options,
+                       _result_buffer=None, result_dtype=None):
+        """Loop over hits, pass waveforms to algorithm, construct
+        new hits if and where a split occurs.
+        """
+        # TODO NEEDS TESTS!
+        # NB: code very similar to _split_peaks see
+        # github.com/AxFoundation/strax/pull/309 for more info. Keep in mind
+        # that changing one function should also be reflected in the other.
+        new_hits = _result_buffer
+        offset = 0
+
+        for h_i, h in enumerate(peaks):
+            if h['area'] < min_area:
+                continue
+
+            prev_split_i = 0
+            w = h['data'][:h['length']]
+            for split_i, bonus_output in split_finder(
+                    w, h['dt'], h_i, *args_options):
+                if split_i == NO_MORE_SPLITS:
+                    continue
+
+                is_split[h_i] = True
+                r = new_hits[offset]
+                r['time'] = h['time'] + prev_split_i * h['dt']
+                r['channel'] = h['channel']
+                # Hitlet specific
+                r['record_i'] = h['record_i']
+                # Set the dt to the original (lowest) dt first;
+                # this may change when the sum waveform of the new peak
+                # is computed
+                r['dt'] = orig_dt
+                r['length'] = (split_i - prev_split_i) * h['dt'] / orig_dt
+                if r['length'] <= 0:
+                    print(h['data'])
+                    print(prev_split_i, split_i)
+                    raise ValueError("Attempt to create invalid hitlet!")
+
+                offset += 1
+                if offset == len(new_hits):
+                    yield offset
+                    offset = 0
+
+                prev_split_i = split_i
+
+        yield offset
+
     @staticmethod
     def find_split_points(w, dt, peak_i, *args_options):
         """This function is overwritten by LocalMinimumSplitter or LocalMinimumSplitter
@@ -267,7 +302,7 @@ class NaturalBreaksSplitter(PeakSplitter):
     close as we can get to it given the peaks sampling) on either side.
     """
     find_split_args_defaults = (
-        ('threshold', None), # will be a numpy array of len(peaks)
+        ('threshold', None),  # will be a numpy array of len(peaks)
        ('normalize', False),
         ('split_low', False),
         ('filter_wing_width', 0))
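
A note for reviewers on the `__call__` change: the old per-call `@numba.njit` closure (`specific_output`) is replaced by two precompiled staticmethods selected from a plain dict, because numba can compile free functions and staticmethods but not bound methods taking `self` (hence the in-code comment "Numba doesn't like self as argument"). Below is a minimal, self-contained sketch of that dispatch pattern; `Splitter`, `_work_peaks` and `_work_hitlets` are toy stand-ins for illustration, not the strax API.

```python
import numba
import numpy as np

class Splitter:
    # Numba-compiled workers cannot take `self`, so they are staticmethods
    # picked out of a plain dict, mirroring the __call__ change above.
    @staticmethod
    @numba.njit
    def _work_peaks(w):
        return w.sum()    # toy stand-in for _split_peaks

    @staticmethod
    @numba.njit
    def _work_hitlets(w):
        return w.max()    # toy stand-in for _split_hitlets

    def __call__(self, w, data_type):
        work = {'peaks': self._work_peaks,
                'hitlets': self._work_hitlets}
        if data_type not in work:
            raise ValueError(f'Data_type "{data_type}" is not supported.')
        return work[data_type](w)

print(Splitter()(np.arange(5.), 'peaks'))    # 10.0
print(Splitter()(np.arange(5.), 'hitlets'))  # 4.0
```

Besides dropping the awkward closure, this means the workers are compiled once and reused, rather than a fresh njit function object being created on every `__call__`.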
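Both `_split_peaks` and the new `_split_hitlets` follow the `strax.growing_result` protocol visible in the patch: the decorated generator writes rows into the injected `_result_buffer` and yields the number of rows filled whenever the buffer runs full, plus once more at the end to flush the remainder, and the decorator stitches the chunks into one array. A simplified stand-in of that protocol, assuming only numpy (`growing_result` and `fill` here are toy re-implementations, not the strax code):

```python
import numpy as np

def growing_result(dtype, chunk_size):
    # Toy stand-in for strax.growing_result; see strax.utils for the real one.
    def decorator(f):
        def wrapper(*args, result_dtype=None, **kwargs):
            buf = np.zeros(chunk_size,
                           dtype=result_dtype if result_dtype is not None else dtype)
            chunks = []
            # The generator yields how many rows of `buf` it has filled.
            for n in f(*args, _result_buffer=buf, **kwargs):
                chunks.append(buf[:n].copy())
            return np.concatenate(chunks) if chunks else buf[:0]
        return wrapper
    return decorator

@growing_result(dtype=np.int64, chunk_size=4)
def fill(n, _result_buffer=None):
    offset = 0
    for i in range(n):
        _result_buffer[offset] = i
        offset += 1
        if offset == len(_result_buffer):
            yield offset   # buffer full: hand the chunk to the decorator
            offset = 0
    yield offset           # flush whatever is left

print(fill(10))  # -> [0 1 2 3 4 5 6 7 8 9]
```

This is why both split functions end with a bare `yield offset`: it flushes the partially filled buffer, and forgetting it (in either function) would silently drop the last chunk, which is what the "changing one function should also be reflected in the other" comment guards against.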