Skip to content

Commit

Permalink
add default jitter axes to default config, and add tests for that. Al…
Browse files Browse the repository at this point in the history
…so, add **kwds arguments to add_jitter
  • Loading branch information
rettigl committed Oct 9, 2023
1 parent 8f49f40 commit 503e9f2
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
2 changes: 2 additions & 0 deletions sed/config/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ dataframe:
tof_binning: 1
# binning factor used for the adc coordinate (2^(adc_binning-1))
adc_binning: 1
# list of columns to apply jitter to
jitter_cols: ["@x_column", "@y_column", "@tof_column"]

energy:
# Number of bins to use for energy calibration traces
Expand Down
14 changes: 8 additions & 6 deletions sed/core/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1180,23 +1180,25 @@ def calibrate_delay_axis(
else:
print(self._dataframe)

def add_jitter(self, cols: Sequence[str] = None):
def add_jitter(self, cols: List[str] = None, **kwds):
"""Add jitter to the selected dataframe columns.
Args:
cols (Sequence[str], optional): The colums onto which to apply jitter.
cols (List[str], optional): The colums onto which to apply jitter.
Defaults to config["dataframe"]["jitter_cols"].
**kwds: keyword arguments passed to apply_jitter
"""
if cols is None:
cols = self._config["dataframe"].get(
"jitter_cols",
self._dataframe.columns,
) # jitter all columns
cols = self._config["dataframe"]["jitter_cols"]
for loc, col in enumerate(cols):
if col.startswith("@"):
cols[loc] = self._config["dataframe"].get(col.strip("@"))

self._dataframe = self._dataframe.map_partitions(
apply_jitter,
cols=cols,
cols_jittered=cols,
**kwds,
)
metadata = []
for col in cols:
Expand Down
4 changes: 4 additions & 0 deletions tests/test_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,11 +581,15 @@ def test_add_jitter():
system_config={},
)
res1 = processor.dataframe["X"].compute()
res1a = processor.dataframe["ADC"].compute()
processor.add_jitter()
res2 = processor.dataframe["X"].compute()
res2a = processor.dataframe["ADC"].compute()
np.testing.assert_allclose(res1, np.round(res1))
np.testing.assert_allclose(res1, np.round(res2))
assert (res1 != res2).all()
# test that jittering is not applied on ADC column
np.testing.assert_allclose(res1a, res2a)


def test_event_histogram():
Expand Down

0 comments on commit 503e9f2

Please sign in to comment.