Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: Serialize cfg entries #407

Merged
merged 5 commits into from
Jan 5, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 62 additions & 8 deletions syncopy/shared/tools.py
Original file line number Diff line number Diff line change
@@ -5,12 +5,14 @@

# Builtin/3rd party package imports
import numpy as np
from numbers import Number
from copy import deepcopy
import inspect
import json

# Local imports
from syncopy.shared.errors import SPYValueError, SPYWarning, SPYTypeError
from syncopy.shared.parsers import sequence_parser

__all__ = ["StructDict", "get_defaults"]

@@ -88,7 +90,6 @@ def __deepcopy__(self, memo):
return result



class SerializableDict(dict):

"""
@@ -111,18 +112,55 @@ def is_json(self, key, value):
json.dumps(value)
except TypeError:
lgl = "serializable data type, e.g. floats, lists, tuples, ... "
raise SPYTypeError(value, f"value for key '{key}'", lgl)
raise SPYTypeError(value, f"value {value} for key '{key}'", lgl)
try:
json.dumps(key)
except TypeError:
lgl = "serializable data type, e.g. floats, lists, tuples, ... "
raise SPYTypeError(value, f"key '{key}'", lgl)


def get_frontend_cfg(defaults, lcls, kwargs):
def _serialize_number(item, collapse_whole=False):
    """
    Convert a single (numpy) number into the matching builtin Python type.

    Parameters
    ----------
    item : object
        Candidate value; anything that is not a bool, integer or float
        is handed back untouched.
    collapse_whole : bool
        If `True`, floats without fractional part (e.g. ``10.0``) are
        returned as ``int``.

    Returns
    -------
    object
        ``bool``, ``int`` or ``float`` for numeric input, otherwise ``item``.
    """

    # bools have to be checked first, as `bool` is a subclass of `int`
    if isinstance(item, (bool, np.bool_)):
        return bool(item)

    # explicit type checks instead of `hasattr(item, 'is_integer')`:
    # newer Python/NumPy versions define `is_integer` on integers as well
    if isinstance(item, (int, np.integer)):
        return int(item)

    if isinstance(item, (float, np.floating)):
        item = float(item)
        if collapse_whole and item.is_integer():
            return int(item)
        return item

    # strings, None, complex numbers, nested containers, ...
    return item


def _serialize_value(value):
    """
    Helper to serialize 1-level deep sequences (lists, arrays, ranges) or
    single numbers/strings as ``value``s.

    Main task is to get rid of numpy data types which are not
    serializable (e.g. np.int64).

    Parameters
    ----------
    value : object
        Array, range, list or scalar to convert. Everything else
        (strings, `None`, tuples, dicts, ...) is returned unchanged.

    Returns
    -------
    object
        Arrays and ranges are returned as lists. Integers and floats
        inside lists become builtin ``int``/``float`` element by element,
        so empty and mixed-type lists are fine. Scalar integers become
        ``int``, scalar floats become ``float``, or ``int`` if they
        have no fractional part.
    """

    if isinstance(value, np.ndarray):
        # `tolist` already yields builtin Python scalars
        value = value.tolist()
    elif isinstance(value, range):
        value = list(value)

    if isinstance(value, list):
        # element-wise conversion: no peeking at `value[0]`, which breaks
        # for empty sequences and mis-casts mixed int/float lists
        return [_serialize_number(item) for item in value]

    # singleton/non-sequence type entries
    return _serialize_number(value, collapse_whole=True)


def get_frontend_cfg(defaults, lcls, kwargs):
"""
Assemble serializable cfg dict to allow direct replay of frontend calls

Most parsing is done in the respective frontends, the config values
should be straightforward to serialize.

Parameters
----------
@@ -147,13 +185,29 @@ def get_frontend_cfg(defaults, lcls, kwargs):

# create new cfg dict
new_cfg = StructDict()

for par_name in defaults:
# check only needed for injected kwargs like `parallel`
# check only set parameters
if par_name in lcls:
new_cfg[par_name] = lcls[par_name]
# attach additional kwargs (like select)
value = _serialize_value(lcls[par_name])
new_cfg[par_name] = value

# 'select' only allowed dictionary parameter within kwargs
# we can 'pop' here as selection got digested beforehand by @unwrap_select
sdict = kwargs.pop('select', None)
if sdict is not None:
# serialized selection dict
ser_sdict = dict()
for sel_key in sdict:
ser_sdict[sel_key] = _serialize_value(sdict[sel_key])
new_cfg['select'] = ser_sdict

# should only be 'parallel' and 'chan_per_worker'
for key in kwargs:
new_cfg[key] = kwargs[key]
new_cfg[key] = _serialize_value(kwargs[key])

# use instantiation for a final check
SerializableDict(new_cfg)

return new_cfg

4 changes: 2 additions & 2 deletions syncopy/tests/test_cfg.py
Original file line number Diff line number Diff line change
@@ -17,11 +17,11 @@
from syncopy.shared.tools import StructDict


availableFrontend_cfgs = {'freqanalysis': {'method': 'mtmconvol', 't_ftimwin': 0.1},
availableFrontend_cfgs = {'freqanalysis': {'method': 'mtmconvol', 't_ftimwin': 0.1, 'foi': np.arange(1,60)},
'preprocessing': {'freq': 10, 'filter_class': 'firws', 'filter_type': 'hp'},
'resampledata': {'resamplefs': 125, 'lpfreq': 60},
'connectivityanalysis': {'method': 'coh', 'tapsmofrq': 5},
'selectdata': {'trials': [1, 7, 3], 'channel': [2, 0]}
'selectdata': {'trials': np.array([1, 7, 3]), 'channel': [np.int64(2), 0]}
}


2 changes: 0 additions & 2 deletions syncopy/tests/test_continuousdata.py
Original file line number Diff line number Diff line change
@@ -45,7 +45,6 @@
["channel03", "channel01", "channel01", "channel02"], # string selection w/repetition + unordered
[4, 2, 2, 5, 5], # repetition + unordered
range(5, 8), # narrow range
slice(-2, None), # negative-start slice
"channel02", # str selection
1 # scalar selection
]
@@ -66,7 +65,6 @@
0, # scalar selection
[0, 1, 1, 2, 3], # preserve repetition, don't convert to slice
range(2, 5), # narrow range
slice(0, 5, 2), # slice w/non-unitary step-size
]
timeSelections = list(zip(["latency"] * len(latencySelections), latencySelections))
freqSelections = list(zip(["frequency"] * len(frequencySelections), frequencySelections))
60 changes: 7 additions & 53 deletions syncopy/tests/test_selectdata.py
Original file line number Diff line number Diff line change
@@ -57,14 +57,8 @@ class TestSelector():
[0, 0, 1, 2, 3], # preserve repetition, don't convert to slice
range(0, 3),
range(5, 8),
slice(None),
None,
"all",
slice(0, 5),
slice(7, None),
slice(2, 8),
slice(0, 10, 2),
slice(-2, None),
[0, 1, 2, 3], # contiguous list...
[2, 3, 5]), # non-contiguous list...
"result": ([2, 0],
@@ -79,49 +73,29 @@ class TestSelector():
slice(5, 8, 1),
slice(None, None, 1),
slice(None, None, 1),
slice(None, None, 1),
slice(0, 5, 1),
slice(7, None, 1),
slice(2, 8, 1),
slice(0, 10, 2),
slice(-2, None, 1),
slice(0, 4, 1), # ...gets converted to slice
[2, 3, 5]), # stays as is
"invalid": (["channel200", "channel400"],
["invalid"],
tuple("wrongtype"),
"notall",
range(0, 100),
slice(80, None),
slice(-20, None),
slice(-15, -2),
slice(5, 1),
[40, 60, 80]),
"errors": (SPYValueError,
SPYValueError,
SPYTypeError,
SPYValueError,
SPYValueError,
SPYValueError,
SPYValueError,
SPYValueError,
SPYValueError,
SPYValueError)}

selectDict["taper"] = {"valid": ([4, 2, 3],
[4, 2, 2, 3], # repetition
[0, 1, 1, 2, 3], # preserve repetition, don't convert to slice
range(0, 3),
range(2, 5),
slice(None),
None,
"all",
0, # scalar
slice(0, 5),
slice(3, None),
slice(2, 4),
slice(0, 5, 2),
slice(-2, None),
[0, 1, 2, 3], # contiguous list...
[1, 3, 4]), # non-contiguous list...
"result": ([4, 2, 3],
@@ -131,13 +105,7 @@ class TestSelector():
slice(2, 5, 1),
slice(None, None, 1),
slice(None, None, 1),
slice(None, None, 1),
[0],
slice(0, 5, 1),
slice(3, None, 1),
slice(2, 4, 1),
slice(0, 5, 2),
slice(-2, None, 1),
slice(0, 4, 1), # ...gets converted to slice
[1, 3, 4]), # stays as is
"invalid": (["taper_typo", "channel400"],
@@ -168,16 +136,9 @@ class TestSelector():
[0, 0, 2, 3], # preserve repetition, don't convert to slice
range(0, 3),
range(2, 5),
slice(None),
None,
"all",
"unit3", # string -> scalar
4, # scalar
slice(0, 5),
slice(3, None),
slice(2, 4),
slice(0, 5, 2),
slice(-2, None),
[0, 1, 2, 3], # contiguous list...
[1, 3, 4]), # non-contiguous list...
"invalid": (["unit7", "unit77"],
@@ -205,14 +166,8 @@ class TestSelector():
[0, 0, 1, 2], # preserve repetition, don't convert to slice
range(0, 2),
range(1, 2),
slice(None),
None,
"all",
1, # scalar
slice(0, 2),
slice(1, None),
slice(0, 1),
slice(-1, None),
[0, 1]), # contiguous list...
"invalid": (["eventid", "eventid"],
tuple("wrongtype"),
@@ -237,23 +192,21 @@ class TestSelector():
tuple("wrongtype"),
"notall",
range(0, 10),
slice(0, 5),
[np.nan, 1],
[0.5, 1.5 , 2.0], # more than 2 components
[2.0, 1.5]), # lower bound > upper bound
"errors": (SPYValueError,
SPYTypeError,
SPYValueError,
SPYTypeError,
SPYTypeError,
SPYValueError,
SPYValueError,
SPYValueError,
SPYValueError)}
selectDict["frequency"] = {"invalid": (["notnumeric", "stillnotnumeric"],
tuple("wrongtype"),
"notall",
range(0, 10),
slice(0, 5),
[np.nan, 1],
[-1, 2], # lower limit out of bounds
[2, 900], # upper limit out of bounds
@@ -263,7 +216,7 @@ class TestSelector():
SPYTypeError,
SPYValueError,
SPYTypeError,
SPYTypeError,
SPYValueError,
SPYValueError,
SPYValueError,
SPYValueError,
@@ -347,6 +300,8 @@ def test_general(self):

# alternate (expensive) way to get by-trial selection indices
result = []
print(prop)
print(selects, selection)
for trial in discrete.trials:
if selects[0] is None:
res = slice(0, trial.shape[0], 1)
@@ -359,6 +314,8 @@ def test_general(self):
if steps.min() == steps.max() == 1:
res = slice(res[0], res[-1] + 1, 1)
result.append(res)
print(result)
print()
allResults.append(result)

self.selectDict[prop]["result"] = tuple(allResults)
@@ -505,9 +462,6 @@ def test_general(self):
else:
idx = [slice(None)] * len(dummy.dimord)
idx[dummy.dimord.index(prop)] = solution
print(idx, solution, prop)
print(np.array(dummy.data)[tuple(idx)])
print(selected.data[()])
assert np.array_equal(np.array(dummy.data)[tuple(idx)],
selected.data)
assert np.array_equal(getattr(selected, prop),
@@ -531,7 +485,7 @@ def test_general(self):
else:
# ensure objects that don't have `time` props complain properly
with pytest.raises(SPYValueError):
Selector(dummy, {"latency": [0]})
Selector(dummy, {"latency": [-.5]})

# ensure invalid `frequency` specifications trigger expected errors
if hasattr(dummy, "freq"):