Refactor hitlets #436

Merged
14 commits, May 3, 2021
4 changes: 4 additions & 0 deletions strax/dtypes.py
@@ -166,6 +166,10 @@ def hitlet_with_data_dtype(n_samples=2):
dtype = hitlet_dtype()
additional_fields = [(('Hitlet data in PE/sample with ZLE (only the first length samples are filled)', 'data'),
np.float32, n_samples),
(('Dummy field required for splitting',
'max_gap'), np.int32),
(('Maximum interior goodness of split',
'max_goodness_of_split'), np.float32),
]

return dtype + additional_fields
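
For orientation, a minimal sketch (not part of the diff) of what the extended dtype looks like in use; the buffer length and sample count are assumed example values:

import numpy as np
import strax

# Allocate hitlets with a 200-sample waveform buffer; besides the regular
# hitlet fields, the dtype now also carries the two helper fields added above.
hitlets = np.zeros(10, dtype=strax.hitlet_with_data_dtype(n_samples=200))
assert 'max_gap' in hitlets.dtype.names
assert 'max_goodness_of_split' in hitlets.dtype.names
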
69 changes: 45 additions & 24 deletions strax/processing/hitlets.py
@@ -7,12 +7,45 @@
export, __all__ = strax.exporter()

# Hardcoded numbers:
TRIAL_COUNTER_NEIGHBORING_RECORDS = 100 # Trial counter when looking for hitlet data.
NO_FWXM = -42 # Value in case FWXM cannot be found.

# ----------------------
# Hitlet building:
# ----------------------

@export
def create_hitlets_from_hits(hits,
save_outside_hits,
channel_range,
chunk_start=0,
chunk_end=np.inf,):
"""
Function which creates hitlets from a bunch of hits.

:param hits: Hits found in records.
:param save_outside_hits: Tuple with left and right hit extension.
:param channel_range: Detector channel range from the channel map.
:param chunk_start: Start of the chunk. Ensures that no hitlet is
extended beyond the chunk boundaries. You can keep the default
if not used inside of a plugin.
:param chunk_end: End of the chunk. Ensures that no hitlet is
extended beyond the chunk boundaries. You can keep the default
if not used inside of a plugin.

:return: Hitlets with temporary fields (data, max_goodness_of_split...)
"""
# Merge/concatenate overlapping hits within a channel. This is
# important in case hits were split by record boundaries. If we
# accidentally concatenate two PMT signals, we split them again later.
hits = strax.concat_overlapping_hits(hits,
save_outside_hits,
channel_range,
chunk_start,
chunk_end, )
hits = strax.sort_by_time(hits)

hitlets = np.zeros(len(hits), strax.hitlet_dtype())
strax.copy_to_buffer(hits, hitlets, '_refresh_hit_to_hitlets')
return hitlets
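
For illustration, a hedged usage sketch (not part of the diff); the hit extensions and channel range are assumed example values, and the empty hits array only keeps the snippet self-contained:

import numpy as np
import strax

hits = np.zeros(0, dtype=strax.hit_dtype)  # in practice: output of strax.find_hits
hitlets = strax.create_hitlets_from_hits(hits,
                                         save_outside_hits=(1, 1),
                                         channel_range=(0, 1))
assert len(hitlets) == 0  # empty input yields empty hitlets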


@export
def concat_overlapping_hits(hits, extensions, pmt_channels, start, end):
"""
@@ -118,25 +151,7 @@ def _concat_overlapping_hits(hits,


@export
@numba.njit(nogil=True, cache=True)
def refresh_hit_to_hitlets(hits, hitlets):
"""
Function which copies basic hit information into a new hitlet array.
"""
nhits = len(hits)
for ind in range(nhits):
h_new = hitlets[ind]
h_old = hits[ind]

h_new['time'] = h_old['time']
h_new['length'] = h_old['length']
h_new['channel'] = h_old['channel']
h_new['area'] = h_old['area']
h_new['dt'] = h_old['dt']
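
The dedicated copy loop above is replaced by the generic strax.copy_to_buffer helper (see its use in create_hitlets_from_hits above). A small sketch with assumed buffer sizes; copy_to_buffer copies the fields common to both dtypes (time, length, dt, channel, area) under the given underscore-prefixed name:

import numpy as np
import strax

hits = np.zeros(5, dtype=strax.hit_dtype)
hitlets = np.zeros(5, dtype=strax.hitlet_dtype())
strax.copy_to_buffer(hits, hitlets, '_refresh_hit_to_hitlets')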


@export
def get_hitlets_data(hitlets, records, to_pe, min_hitlet_sample=100):
def get_hitlets_data(hitlets, records, to_pe, min_hitlet_sample=200):
"""
Function which searches, for every hitlet in a given chunk, the
corresponding records data. Additionally computes the total area of
@@ -152,6 +167,12 @@ def get_hitlets_data(hitlets, records, to_pe, min_hitlet_sample=100):
:returns: Hitlets including data stored in the "data" field
(if it did not exist before, it will be added).
"""
if len(hitlets) == 0:
return np.zeros(0, dtype=strax.hitlet_with_data_dtype(min_hitlet_sample))

if len(hitlets) > 0 and len(records) == 0:
raise ValueError('Cannot get data for hitlets if records are empty!')

# Numba will not raise any exceptions if to_pe is too short, leading
# to strange bugs.
to_pe_has_wrong_shape = len(to_pe) < hitlets['channel'].max()
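
A brief sketch (assumed example values, mirroring the guards above and the test_inputs_are_empty test below) of the new empty-input behaviour:

import numpy as np
import strax

empty_hitlets = np.zeros(0, dtype=strax.hitlet_with_data_dtype(2))
empty_records = np.zeros(0, dtype=strax.record_dtype(10))

# Empty hitlets simply yield an empty hitlet-with-data array ...
assert len(strax.get_hitlets_data(empty_hitlets, empty_records, np.ones(3000))) == 0
# ... whereas non-empty hitlets combined with empty records raise a ValueError.
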
@@ -296,7 +317,6 @@ def hitlet_properties(hitlets):
h['range_hdr_80p_area'] = resh[1,1]-resh[1,0]



@export
@numba.njit(cache=True, nogil=True)
def get_fwxm(hitlet, fraction=0.5):
@@ -374,6 +394,7 @@ def _get_fwxm_boundary(data, max_val):
return ind, s
return len(data)-1, data[-1]


@export
def conditional_entropy(hitlets, template='flat', square_data=False):
"""
60 changes: 1 addition & 59 deletions strax/processing/peak_splitting.py
@@ -80,12 +80,7 @@ def __call__(self, peaks, records, to_pe, data_type,

is_split = np.zeros(len(peaks), dtype=np.bool_)

split_function = {'peaks': self._split_peaks,
'hitlets': self._split_hitlets}
if data_type not in split_function:
raise ValueError(f'Data_type "{data_type}" is not supported.')

new_peaks = split_function[data_type](
new_peaks = self._split_peaks(
# Numba doesn't like self as argument, but it's ok with functions...
split_finder=self.find_split_points,
peaks=peaks,
@@ -122,9 +117,6 @@ def _split_peaks(split_finder, peaks, orig_dt, is_split, min_area,
"""Loop over peaks, pass waveforms to algorithm, construct
new peaks if and where a split occurs.
"""
# NB: code very similar to _split_hitlets see
# github.com/AxFoundation/strax/pull/309 for more info. Keep in mind
# that changing one function should also be reflected in the other.
new_peaks = _result_buffer
offset = 0

@@ -165,56 +157,6 @@ def _split_peaks(split_finder, peaks, orig_dt, is_split, min_area,

yield offset

@staticmethod
@strax.growing_result(dtype=strax.hitlet_dtype(), chunk_size=int(1e4))
@numba.jit(nopython=True, nogil=True)
def _split_hitlets(split_finder, peaks, orig_dt, is_split, min_area,
args_options,
_result_buffer=None, result_dtype=None):
"""Loop over hits, pass waveforms to algorithm, construct
new hits if and where a split occurs.
"""
# TODO NEEDS TESTS!
# NB: code very similar to _split_peaks see
# github.com/AxFoundation/strax/pull/309 for more info. Keep in mind
# that changing one function should also be reflected in the other.
new_hits = _result_buffer
offset = 0

for h_i, h in enumerate(peaks):
if h['area'] < min_area:
continue

prev_split_i = 0
w = h['data'][:h['length']]
for split_i, bonus_output in split_finder(
w, h['dt'], h_i, *args_options):
if split_i == NO_MORE_SPLITS:
continue

is_split[h_i] = True
r = new_hits[offset]
r['time'] = h['time'] + prev_split_i * h['dt']
r['channel'] = h['channel']
# Set the dt to the original (lowest) dt first;
# this may change when the sum waveform of the new peak
# is computed
r['dt'] = orig_dt
r['length'] = (split_i - prev_split_i) * h['dt'] / orig_dt
if r['length'] <= 0:
print(h['data'])
print(prev_split_i, split_i)
raise ValueError("Attempt to create invalid hitlet!")

offset += 1
if offset == len(new_hits):
yield offset
offset = 0

prev_split_i = split_i

yield offset

@staticmethod
def find_split_points(w, dt, peak_i, *args_options):
"""This function is overwritten by LocalMinimumSplitter or LocalMinimumSplitter
50 changes: 28 additions & 22 deletions tests/test_hitlet.py
@@ -9,9 +9,6 @@
from strax.testutils import fake_hits


# -----------------------
# Concatenated overlapping hits:
# -----------------------
@given(fake_hits,
fake_hits,
st.integers(min_value=0, max_value=10),
@@ -68,26 +65,38 @@ def test_concat_overlapping_hits(hits0, hits1, le, re):
assert np.all(mask < 0), f'Found two hits within {ch} which are touching or overlapping'


# -----------------------------
# Test for get_hitlets_data.
# This test is done with some predefined
# records.
# -----------------------------
def test_create_hits_from_hitlets_empty_hits():
hits = np.zeros(0, dtype=strax.hit_dtype)
hitlets = strax.create_hitlets_from_hits(hits, (1, 1), (0, 1))
assert len(hitlets) == 0, 'Hitlets should be empty'


class TestGetHitletData(unittest.TestCase):

def setUp(self):
self.test_data = [1, 3, 2, 1, 0, 0]
self.test_data_truth = self.test_data[:-2]

def make_records_and_hitlets(self, dummy_records, data_field_length=999999):
def make_records_and_hitlets(self, dummy_records):
records = self._make_fake_records(dummy_records)
hits = strax.find_hits(records, min_amplitude=2)
hits = strax.concat_overlapping_hits(hits, (1, 1), (0, 1), 0, float('inf'))
n_samples = min(np.max(hits['length']), data_field_length)
hitlets = np.zeros(len(hits), strax.hitlet_with_data_dtype(n_samples=n_samples))
strax.refresh_hit_to_hitlets(hits, hitlets)
hitlets = strax.create_hitlets_from_hits(hits, (1, 1), (0, 1), 0, float('inf'))
return records, hitlets

def test_inputs_are_empty(self):
records, hitlets = self.make_records_and_hitlets([[self.test_data]])
hitlets_empty = np.zeros(0, dtype=strax.hitlet_with_data_dtype(2))
records_empty = np.zeros(0, dtype=strax.record_dtype(10))

hitlets_result = strax.get_hitlets_data(hitlets_empty, records, np.ones(3000))
assert len(hitlets_result) == 0, 'get_hitlet_data returned result for empty hitlets'

hitlets_result = strax.get_hitlets_data(hitlets_empty, records_empty, np.ones(3000))
assert len(hitlets_result) == 0, 'get_hitlet_data returned result for empty hitlets'

self.assertRaises(ValueError, strax.get_hitlets_data, hitlets, records_empty, np.ones(3000))


def test_to_pe_wrong_shape(self):
records, hitlets = self.make_records_and_hitlets([[self.test_data]])
hitlets['channel'] = 2000
@@ -115,8 +124,10 @@ def test_get_hitlets_data_without_data_field(self):
self._test_data_is_identical(hitlets, [self.test_data_truth])

def test_to_short_data_field(self):
records, hitlets = self.make_records_and_hitlets([[self.test_data]], 2)
self.assertRaises(ValueError, strax.get_hitlets_data, hitlets, records, np.ones(3000))
records, hitlets = self.make_records_and_hitlets([[self.test_data]])
hitlets_to_short = np.zeros(len(hitlets), dtype=strax.hitlet_with_data_dtype(2))
strax.copy_to_buffer(hitlets, hitlets_to_short, '_refresh_hit_to_hitlet')
self.assertRaises(ValueError, strax.get_hitlets_data, hitlets_to_short, records, np.ones(3000))

def test_get_hitlets_data(self):
dummy_records = [ # Contains Hitlet #:
Expand Down Expand Up @@ -146,7 +157,7 @@ def test_get_hitlets_data(self):
]

records, hitlets = self.make_records_and_hitlets(dummy_records)
strax.get_hitlets_data(hitlets, records, np.array([1, 1]))
hitlets = strax.get_hitlets_data(hitlets, records, np.ones(2))

for i, (a, wf, t) in enumerate(zip(true_area, true_waveform, true_time)):
h = hitlets[i]
@@ -212,11 +223,6 @@ def _count_zle_samples(data):
return i


# -----------------------------
# Test for hitlet_properties.
# This test includes the fwxm and
# refresh_hit_to_hitlets.
# -----------------------------
@st.composite
def hits_n_data(draw, strategy):
hits = draw(strategy)
@@ -293,7 +299,7 @@ def test_hitlet_properties(hits_n_data):
hitlets = np.zeros(len(hits), dtype=strax.hitlet_with_data_dtype(nsamples))
if len(hitlets):
assert hitlets['data'].shape[1] >= 2, 'Data buffer is not at least 2 samples long.'
strax.refresh_hit_to_hitlets(hits, hitlets)
strax.copy_to_buffer(hits, hitlets, '_refresh_hit_to_hitlet_properties_test')

# Testing refresh_hit_to_hitlets for free:
assert len(hits) == len(hitlets), 'Somehow hitlets and hits have different sizes'