Skip to content

Commit

Permalink
Pick out selection applying function (#826)
Browse files Browse the repository at this point in the history
* Pick out selection applying function

* Happy strict flake8
  • Loading branch information
dachengx authored Apr 24, 2024
1 parent 67a7831 commit b9f4641
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 17 deletions.
2 changes: 1 addition & 1 deletion strax/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -1545,7 +1545,7 @@ def _update_progress_bar(pbar, t_start, t_end, n_chunks, chunk_end, nbytes):
pbar.mbs.append((nbytes / 1e6) / seconds_per_chunk)
mbs = np.mean(pbar.mbs)
if mbs < 1:
rate = f"{mbs*1000:.1f} kB/s"
rate = f"{mbs * 1000:.1f} kB/s"
else:
rate = f"{mbs:.1f} MB/s"
postfix = f"#{n_chunks} ({seconds_per_chunk:.2f} s). {rate}"
Expand Down
3 changes: 2 additions & 1 deletion strax/plugins/overlap_window_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ def do_compute(self, chunk_i=None, **kwargs):
chunk_starts_are_equal = len(unique_starts) == 1
if chunk_starts_are_equal:
self.log.debug(
f"Success after {try_counter}. Extra time = {cache_inputs_beyond-prev_split} ns"
f"Success after {try_counter}. "
f"Extra time = {cache_inputs_beyond - prev_split} ns"
)
break
else:
Expand Down
2 changes: 1 addition & 1 deletion strax/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def discarder(source):
self.mailboxes = dict(self.mailboxes)
self.log.debug(
f"Created the following mailboxes: {self.mailboxes} with the "
f"following threads: {[(d, m._threads) for d,m in self.mailboxes.items()]}"
f"following threads: {[(d, m._threads) for d, m in self.mailboxes.items()]}"
)

def iter(self):
Expand Down
32 changes: 21 additions & 11 deletions strax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ def multi_run(
tasks_done += 1
_run_id = futures.pop(f)
log.debug(
f"Done with run_id: {_run_id} and {len(run_id_numpy)-tasks_done} are left."
f"Done with run_id: {_run_id} and {len(run_id_numpy) - tasks_done} are left."
)
pbar.update(1)
if f.exception() is not None:
Expand Down Expand Up @@ -630,7 +630,7 @@ def multi_run(
pbar.close()
if ignore_errors and len(failures):
log.warning(
f"Failures for {len(failures)/len(run_ids):.0%} of runs. Failed for: {failures}"
f"Failures for {len(failures) / len(run_ids):.0%} of runs. Failed for: {failures}"
)
return final_result

Expand Down Expand Up @@ -670,6 +670,24 @@ def iter_chunk_meta(md):
yield c


@export
def parse_selection(x, selection):
    """Parse a selection into a mask that can be used to filter data.

    :param x: Data to select from. A callable selection receives it as-is; for string
        selections it must expose ``dtype.names`` and support field access by name.
    :param selection: Query string, sequence of strings, or simple function to apply.
    :return: Boolean indicating the selected items.

    """
    if callable(selection):
        return selection(x)

    if isinstance(selection, (list, tuple)):
        # Parenthesize every condition so that "&" cannot rebind operators inside one.
        selection = " & ".join(f"({condition})" for condition in selection)

    # Expose each field of x as a variable that the query string can reference.
    columns = {name: x[name] for name in x.dtype.names}
    return numexpr.evaluate(selection, local_dict=columns)


@export
def apply_selection(
x,
Expand Down Expand Up @@ -722,15 +740,7 @@ def apply_selection(
selection = selection_str

if selection:
if hasattr(selection, "__call__"):
mask = selection(x)
else:
if isinstance(selection, (list, tuple)):
selection = " & ".join(f"({x})" for x in selection)

mask = numexpr.evaluate(selection, local_dict={fn: x[fn] for fn in x.dtype.names})

x = x[mask]
x = x[parse_selection(x, selection)]

if keep_columns:
keep_columns = strax.to_str_tuple(keep_columns)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_mailbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_reader(source):
assert hasattr(test_reader, "got")
assert test_reader.got == list(range(10))
mb.cleanup()
threads = [f"{t.name} is dead: {True^t.is_alive()}" for t in threading.enumerate()]
threads = [f"{t.name} is dead: {True ^ t.is_alive()}" for t in threading.enumerate()]
assert (
len(threads) == n_threads_start
), f"Not all threads died. \n Threads running are:{threads}"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_peak_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def retrun_1(x):
assert strax.endtime(r_buffer[-1]) - r_buffer["time"].min() > magic_overflow_time
r = r_buffer.copy()
del r_buffer
print(f"Array is {r.nbytes/1e6} MB, good luck")
print(f"Array is {r.nbytes / 1e6} MB, good luck")

# Do peak finding!
print(f"Find hits")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_superruns.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def test_rechnunking_and_loading(self):
chunks = [chunk for chunk in self.context.get_iter("_superrun_test_rechunking", "records")]
assert len(chunks) > 1, (
"Number of chunks should be larger 1. "
f"{chunks[0].target_size_mb, chunks[0].nbytes/10**6}"
f"{chunks[0].target_size_mb, chunks[0].nbytes / 10**6}"
)
assert np.all(rr_superrun["time"] == rr_subruns["time"])

Expand Down

0 comments on commit b9f4641

Please sign in to comment.