Skip to content

Commit

Permalink
Merge pull request #407 from esi-neuroscience/392-serializable-cfg-en…
Browse files Browse the repository at this point in the history
…tries

FIX: Serialize cfg entries
  • Loading branch information
dfsp-spirit authored Jan 5, 2023
2 parents 13ed76e + 2563103 commit 7ecd1e8
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 65 deletions.
70 changes: 62 additions & 8 deletions syncopy/shared/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@

# Builtin/3rd party package imports
import numpy as np
from numbers import Number
from copy import deepcopy
import inspect
import json

# Local imports
from syncopy.shared.errors import SPYValueError, SPYWarning, SPYTypeError
from syncopy.shared.parsers import sequence_parser

__all__ = ["StructDict", "get_defaults"]

Expand Down Expand Up @@ -88,7 +90,6 @@ def __deepcopy__(self, memo):
return result



class SerializableDict(dict):

"""
Expand All @@ -111,18 +112,55 @@ def is_json(self, key, value):
json.dumps(value)
except TypeError:
lgl = "serializable data type, e.g. floats, lists, tuples, ... "
raise SPYTypeError(value, f"value for key '{key}'", lgl)
raise SPYTypeError(value, f"value {value} for key '{key}'", lgl)
try:
json.dumps(key)
except TypeError:
lgl = "serializable data type, e.g. floats, lists, tuples, ... "
raise SPYTypeError(value, f"key '{key}'", lgl)


def get_frontend_cfg(defaults, lcls, kwargs):
def _serialize_value(value):
"""
Helper to serialize 1-level deep sequences (lists, arrays, ranges) or
single numbers/strings as ``value``s.
Main task is to get rid of numpy data types which are not
serializable (e.i. np.int64).
"""
Assemble cfg dict to allow direct replay of frontend calls

if isinstance(value, np.ndarray):
value = value.tolist()

if isinstance(value, range):
value = list(value)

# unpack the list, if ppl mix types this will go wrong
if isinstance(value, list):
if hasattr(value[0], 'is_integer'):
value = [float(v) for v in value]
# should only be the integers
elif isinstance(value[0], Number) and not isinstance(value[0], bool):
value = [int(v) for v in value]

# singleton/non-sequence type entries
if isinstance(value, Number) and not isinstance(value, bool):
# all floating types have this method
if hasattr(value, 'is_integer'):
# get rid of np.int64 or np.float32
value = int(value) if value.is_integer() else float(value)
else:
value = int(value)

return value


def get_frontend_cfg(defaults, lcls, kwargs):
"""
Assemble serializable cfg dict to allow direct replay of frontend calls
Most parsing is done in the respective frontends, the config values
should be straightforward to serialize.
Parameters
----------
Expand All @@ -147,13 +185,29 @@ def get_frontend_cfg(defaults, lcls, kwargs):

# create new cfg dict
new_cfg = StructDict()

for par_name in defaults:
# check only needed for injected kwargs like `parallel`
# check only set parameters
if par_name in lcls:
new_cfg[par_name] = lcls[par_name]
# attach additional kwargs (like select)
value = _serialize_value(lcls[par_name])
new_cfg[par_name] = value

# 'select' only allowed dictionary parameter within kwargs
# we can 'pop' here as selection got digested beforehand by @unwrap_select
sdict = kwargs.pop('select', None)
if sdict is not None:
# serialized selection dict
ser_sdict = dict()
for sel_key in sdict:
ser_sdict[sel_key] = _serialize_value(sdict[sel_key])
new_cfg['select'] = ser_sdict

# should only be 'parallel' and 'chan_per_worker'
for key in kwargs:
new_cfg[key] = kwargs[key]
new_cfg[key] = _serialize_value(kwargs[key])

# use instantiation for a final check
SerializableDict(new_cfg)

return new_cfg

Expand Down
4 changes: 2 additions & 2 deletions syncopy/tests/test_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
from syncopy.shared.tools import StructDict


availableFrontend_cfgs = {'freqanalysis': {'method': 'mtmconvol', 't_ftimwin': 0.1},
availableFrontend_cfgs = {'freqanalysis': {'method': 'mtmconvol', 't_ftimwin': 0.1, 'foi': np.arange(1,60)},
'preprocessing': {'freq': 10, 'filter_class': 'firws', 'filter_type': 'hp'},
'resampledata': {'resamplefs': 125, 'lpfreq': 60},
'connectivityanalysis': {'method': 'coh', 'tapsmofrq': 5},
'selectdata': {'trials': [1, 7, 3], 'channel': [2, 0]}
'selectdata': {'trials': np.array([1, 7, 3]), 'channel': [np.int64(2), 0]}
}


Expand Down
2 changes: 0 additions & 2 deletions syncopy/tests/test_continuousdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
["channel03", "channel01", "channel01", "channel02"], # string selection w/repetition + unordered
[4, 2, 2, 5, 5], # repetition + unordered
range(5, 8), # narrow range
slice(-2, None), # negative-start slice
"channel02", # str selection
1 # scalar selection
]
Expand All @@ -66,7 +65,6 @@
0, # scalar selection
[0, 1, 1, 2, 3], # preserve repetition, don't convert to slice
range(2, 5), # narrow range
slice(0, 5, 2), # slice w/non-unitary step-size
]
timeSelections = list(zip(["latency"] * len(latencySelections), latencySelections))
freqSelections = list(zip(["frequency"] * len(frequencySelections), frequencySelections))
Expand Down
60 changes: 7 additions & 53 deletions syncopy/tests/test_selectdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,8 @@ class TestSelector():
[0, 0, 1, 2, 3], # preserve repetition, don't convert to slice
range(0, 3),
range(5, 8),
slice(None),
None,
"all",
slice(0, 5),
slice(7, None),
slice(2, 8),
slice(0, 10, 2),
slice(-2, None),
[0, 1, 2, 3], # contiguous list...
[2, 3, 5]), # non-contiguous list...
"result": ([2, 0],
Expand All @@ -79,49 +73,29 @@ class TestSelector():
slice(5, 8, 1),
slice(None, None, 1),
slice(None, None, 1),
slice(None, None, 1),
slice(0, 5, 1),
slice(7, None, 1),
slice(2, 8, 1),
slice(0, 10, 2),
slice(-2, None, 1),
slice(0, 4, 1), # ...gets converted to slice
[2, 3, 5]), # stays as is
"invalid": (["channel200", "channel400"],
["invalid"],
tuple("wrongtype"),
"notall",
range(0, 100),
slice(80, None),
slice(-20, None),
slice(-15, -2),
slice(5, 1),
[40, 60, 80]),
"errors": (SPYValueError,
SPYValueError,
SPYTypeError,
SPYValueError,
SPYValueError,
SPYValueError,
SPYValueError,
SPYValueError,
SPYValueError,
SPYValueError)}

selectDict["taper"] = {"valid": ([4, 2, 3],
[4, 2, 2, 3], # repetition
[0, 1, 1, 2, 3], # preserve repetition, don't convert to slice
range(0, 3),
range(2, 5),
slice(None),
None,
"all",
0, # scalar
slice(0, 5),
slice(3, None),
slice(2, 4),
slice(0, 5, 2),
slice(-2, None),
[0, 1, 2, 3], # contiguous list...
[1, 3, 4]), # non-contiguous list...
"result": ([4, 2, 3],
Expand All @@ -131,13 +105,7 @@ class TestSelector():
slice(2, 5, 1),
slice(None, None, 1),
slice(None, None, 1),
slice(None, None, 1),
[0],
slice(0, 5, 1),
slice(3, None, 1),
slice(2, 4, 1),
slice(0, 5, 2),
slice(-2, None, 1),
slice(0, 4, 1), # ...gets converted to slice
[1, 3, 4]), # stays as is
"invalid": (["taper_typo", "channel400"],
Expand Down Expand Up @@ -168,16 +136,9 @@ class TestSelector():
[0, 0, 2, 3], # preserve repetition, don't convert to slice
range(0, 3),
range(2, 5),
slice(None),
None,
"all",
"unit3", # string -> scalar
4, # scalar
slice(0, 5),
slice(3, None),
slice(2, 4),
slice(0, 5, 2),
slice(-2, None),
[0, 1, 2, 3], # contiguous list...
[1, 3, 4]), # non-contiguous list...
"invalid": (["unit7", "unit77"],
Expand Down Expand Up @@ -205,14 +166,8 @@ class TestSelector():
[0, 0, 1, 2], # preserve repetition, don't convert to slice
range(0, 2),
range(1, 2),
slice(None),
None,
"all",
1, # scalar
slice(0, 2),
slice(1, None),
slice(0, 1),
slice(-1, None),
[0, 1]), # contiguous list...
"invalid": (["eventid", "eventid"],
tuple("wrongtype"),
Expand All @@ -237,23 +192,21 @@ class TestSelector():
tuple("wrongtype"),
"notall",
range(0, 10),
slice(0, 5),
[np.nan, 1],
[0.5, 1.5 , 2.0], # more than 2 components
[2.0, 1.5]), # lower bound > upper bound
"errors": (SPYValueError,
SPYTypeError,
SPYValueError,
SPYTypeError,
SPYTypeError,
SPYValueError,
SPYValueError,
SPYValueError,
SPYValueError)}
selectDict["frequency"] = {"invalid": (["notnumeric", "stillnotnumeric"],
tuple("wrongtype"),
"notall",
range(0, 10),
slice(0, 5),
[np.nan, 1],
[-1, 2], # lower limit out of bounds
[2, 900], # upper limit out of bounds
Expand All @@ -263,7 +216,7 @@ class TestSelector():
SPYTypeError,
SPYValueError,
SPYTypeError,
SPYTypeError,
SPYValueError,
SPYValueError,
SPYValueError,
SPYValueError,
Expand Down Expand Up @@ -347,6 +300,8 @@ def test_general(self):

# alternate (expensive) way to get by-trial selection indices
result = []
print(prop)
print(selects, selection)
for trial in discrete.trials:
if selects[0] is None:
res = slice(0, trial.shape[0], 1)
Expand All @@ -359,6 +314,8 @@ def test_general(self):
if steps.min() == steps.max() == 1:
res = slice(res[0], res[-1] + 1, 1)
result.append(res)
print(result)
print()
allResults.append(result)

self.selectDict[prop]["result"] = tuple(allResults)
Expand Down Expand Up @@ -505,9 +462,6 @@ def test_general(self):
else:
idx = [slice(None)] * len(dummy.dimord)
idx[dummy.dimord.index(prop)] = solution
print(idx, solution, prop)
print(np.array(dummy.data)[tuple(idx)])
print(selected.data[()])
assert np.array_equal(np.array(dummy.data)[tuple(idx)],
selected.data)
assert np.array_equal(getattr(selected, prop),
Expand All @@ -531,7 +485,7 @@ def test_general(self):
else:
# ensure objects that don't have `time` props complain properly
with pytest.raises(SPYValueError):
Selector(dummy, {"latency": [0]})
Selector(dummy, {"latency": [-.5]})

# ensure invalid `frequency` specifications trigger expected errors
if hasattr(dummy, "freq"):
Expand Down

0 comments on commit 7ecd1e8

Please sign in to comment.