Changed split peaks into a generated_jit.
WenzDaniel committed Jun 24, 2020
1 parent c8ab35a commit f34d0f3
Showing 1 changed file with 40 additions and 14 deletions.
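
Not part of the commit, for context: the change routes __call__ through a numba.generated_jit wrapper that picks a data_type-specific callback once, at compile time, before handing off to the shared _split_peaks loop. The sketch below shows the generated_jit pattern itself under hypothetical names (first_entry, impl); the decorated body receives numba types rather than values and must return the implementation numba will compile. (generated_jit has since been deprecated in newer numba releases in favour of numba.extending.overload.)

import numba
from numba import types

@numba.generated_jit(nopython=True, nogil=True)
def first_entry(x):
    # This body runs at compile time, once per distinct argument type;
    # 'x' is a numba type object here, e.g. types.Array or types.float64.
    if isinstance(x, types.Array):
        def impl(x):
            return x[0]   # array input: return its first element
    else:
        def impl(x):
            return x      # scalar input: return it unchanged
    return impl           # numba compiles the chosen implementation

Calling first_entry(np.arange(3)) and first_entry(5.0) compiles two different specializations; the branch on the type happens once rather than per call, which is the mechanism the commit uses to keep data_type handling out of the jitted splitting loop.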
54 changes: 40 additions & 14 deletions strax/processing/peak_splitting.py
@@ -77,14 +77,15 @@ def __call__(self, peaks, records, to_pe, data_type, next_ri,
 
         is_split = np.zeros(len(peaks), dtype=np.bool_)
 
-        new_peaks = self._split_peaks(
+        new_peaks = self._generate_split_peaks(
             # Numba doesn't like self as argument, but it's ok with functions...
             split_finder=self.find_split_points,
             peaks=peaks,
             is_split=is_split,
             orig_dt=records[0]['dt'],
             min_area=min_area,
             args_options=tuple(args_options),
+            data_type=data_type,
             result_dtype=peaks.dtype)
 
         if is_split.sum() != 0:
@@ -109,9 +110,40 @@ def __call__(self, peaks, records, to_pe, data_type, next_ri,
 
     @staticmethod
     @strax.growing_result(dtype=strax.peak_dtype(), chunk_size=int(1e4))
-    @numba.jit(nopython=True, nogil=True)
+    @numba.generated_jit(nopython=True, nogil=True, cache=True)
+    def _generate_split_peaks(split_finder, peaks, orig_dt, is_split, min_area,
+                              data_type, args_options,
+                              _result_buffer=None, result_dtype=None):
+
+        if data_type == 'peaks':
+            def specific_return(p, r, split_i, bonus_output):
+                if split_i == NO_MORE_SPLITS:
+                    p['max_goodness_of_split'] = bonus_output
+                    # although the iteration will end anyway afterwards:
+                    return
+                r['max_gap'] = -1  # Too lazy to compute this
+
+        elif data_type == 'hitlets':
+            def specific_return(p, r, split_i, bonus_output):
+                if split_i == NO_MORE_SPLITS:
+                    return
+                r['record_i'] = p['record_i']
+        else:
+            raise ValueError(f'Data_type "{data_type}" is not supported.')
+
+        split_peaks = PeakSplitter._split_peaks
+        def sp(split_finder, peaks, orig_dt, is_split, min_area, data_type,
+               args_options, _result_buffer=None, result_dtype=None):
+            return split_peaks(split_finder, peaks, orig_dt, is_split, min_area,
+                               args_options, specific_return, _result_buffer, result_dtype)
+
+        return sp
+
+
+    @staticmethod
+    @numba.jit(nopython=True, nogil=True, cache=True)
     def _split_peaks(split_finder, peaks, orig_dt, is_split, min_area,
-                     data_type, args_options,
+                     args_options, specific_return,
                      _result_buffer=None, result_dtype=None):
         # TODO NEEDS TESTS!
         new_peaks = _result_buffer
@@ -127,17 +159,11 @@ def _split_peaks(split_finder, peaks, orig_dt, is_split, min_area,
             for split_i, bonus_output in split_finder(
                     w, p['dt'], p_i, *args_options):
 
-                # Data_kind specific fields:
-                if data_type == 'peaks':
-                    if split_i == NO_MORE_SPLITS:
-                        p['max_goodness_of_split'] = bonus_output
-                        # although the iteration will end anyway afterwards:
-                        continue
-                    r['max_gap'] = -1  # Too lazy to compute this
-                elif data_type == 'hitlets':
-                    r['record_i'] = p['record_i']
-                else:
-                    raise ValueError(f'Data_type "{data_type}" is not supported.')
+                specific_return(p, r, split_i, bonus_output)
+                if split_i == NO_MORE_SPLITS:
+                    # Not sure whether this if-statement could be folded
+                    # into specific_return.
+                    continue
 
                 is_split[p_i] = True
                 r = new_peaks[offset]
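Also not part of the commit: stripped of the strax specifics, the design amounts to selecting a small callback up front and passing it into a generic jitted loop, so the per-peak loop never branches on data_type strings. A minimal sketch with hypothetical names (finish_fragment, generic_loop), assuming a numba version that supports passing jitted functions as arguments:

import numba
import numpy as np

@numba.njit(nogil=True)
def finish_fragment(out, i):
    # Stand-in for the data_type-specific bookkeeping
    # (e.g. max_gap for peaks, record_i for hitlets).
    out[i] = -1

@numba.njit(nogil=True)
def generic_loop(out, specific):
    # The hot loop stays agnostic of the data type; 'specific' plays
    # the role that specific_return plays in the commit.
    for i in range(out.size):
        specific(out, i)
    return out

result = generic_loop(np.zeros(5, dtype=np.int64), finish_fragment)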
