Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CutList #827

Merged
merged 2 commits into from
Apr 24, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions strax/plugins/cut_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import strax
from .plugin import Plugin, SaveWhen
from .merge_only_plugin import MergeOnlyPlugin

export, __all__ = strax.exporter()

Expand Down Expand Up @@ -65,3 +66,69 @@ def compute(self, **kwargs):
def cut_by(self, **kwargs):
# This should be provided by the user making a CutPlugin
raise NotImplementedError()


@export
class CutList(MergeOnlyPlugin):
"""Base class that merges all existing cuts into a single array which can be loaded by the
analysts."""

__version__ = "0.0.0"

save_when = SaveWhen.TARGET
cuts = ()
# need to declare depends_on here to satisfy strax
# https://github.com/AxFoundation/strax/blob/df18c9cef38ea1cee9737d56b1bea078ebb246a9/strax/plugin.py#L99
depends_on = ()
_depends_on = ()

def infer_dtype(self):
dtype = super().infer_dtype()
dtype += [
(
(
f"Boolean AND of all cuts in {self.accumulated_cuts_string}",
self.accumulated_cuts_string,
),
np.bool_,
)
]
return dtype

def compute(self, **kwargs):
cuts = super().compute(**kwargs)
cuts_joint = np.zeros(len(cuts), self.dtype)
strax.copy_to_buffer(
cuts, cuts_joint, f"_copy_cuts_{strax.deterministic_hash(self.depends_on)}"
)
cuts_joint[self.accumulated_cuts_string] = get_accumulated_bool(cuts)
return cuts_joint

@property # type: ignore
def depends_on(self): # noqa
if not len(self._depends_on):
deps = []
for c in self.cuts:
deps.extend(strax.to_str_tuple(c.provides))
self._depends_on = tuple(deps)
return self._depends_on

@depends_on.setter
def depends_on(self, str_or_tuple):
self._depends_on = strax.to_str_tuple(str_or_tuple)


@export
def get_accumulated_bool(array):
"""Computes accumulated boolean over all cuts.

:param array: Array containing merged cuts.

"""
fields = array.dtype.names
fields = np.array([f for f in fields if f not in ("time", "endtime")])

res = np.ones(len(array), np.bool_)
for field in fields:
res &= array[field]
return res