From 533425bbd16afa3ff9d001acc58693af62f42e59 Mon Sep 17 00:00:00 2001 From: KatharineShapcott <65502584+KatharineShapcott@users.noreply.github.com> Date: Fri, 30 Dec 2022 10:13:26 +0100 Subject: [PATCH 001/135] FIX: Replace very inefficient discrete _get_trial On my data x20 speedup --- syncopy/datatype/discrete_data.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/syncopy/datatype/discrete_data.py b/syncopy/datatype/discrete_data.py index 376c0fbb5..83c88b98b 100644 --- a/syncopy/datatype/discrete_data.py +++ b/syncopy/datatype/discrete_data.py @@ -168,7 +168,8 @@ def trialid(self, trlid): def trials(self): """list-like([sample x (>=2)] :class:`numpy.ndarray`) : trial slices of :attr:`data` property""" if self.trialid is not None: - valid_trls = np.unique(self.trialid[self.trialid >= 0]) + valid_trls = np.unique(self.trialid) + valid_trls = valid_trls[valid_trls >= 0] return Indexer(map(self._get_trial, valid_trls), valid_trls.size) else: @@ -184,7 +185,12 @@ def trialtime(self): # Helper function that grabs a single trial def _get_trial(self, trialno): - return self._data[self.trialid == trialno, :] + this_trl = self.trialid == trialno + if not np.any(this_trl): + return self._data[None, :] + st = this_trl.argmax() + end = len(this_trl) - this_trl[st:][::-1].argmax() - 1 + return self._data[st:end, :][this_trl[st:end],:] # Helper function that spawns a `FauxTrial` object given actual trial information def _preview_trial(self, trialno): From 8c1db01e94c7f23c5abf5091c763498ed51c95cb Mon Sep 17 00:00:00 2001 From: KatharineShapcott <65502584+KatharineShapcott@users.noreply.github.com> Date: Fri, 30 Dec 2022 10:26:54 +0100 Subject: [PATCH 002/135] FIX: end of slice incremented by 1 --- syncopy/datatype/discrete_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/syncopy/datatype/discrete_data.py b/syncopy/datatype/discrete_data.py index 83c88b98b..65b600df5 100644 --- a/syncopy/datatype/discrete_data.py +++ b/syncopy/datatype/discrete_data.py @@ -189,7 +189,7 @@ def _get_trial(self, trialno): if not np.any(this_trl): return self._data[None, :] st = this_trl.argmax() - end = len(this_trl) - this_trl[st:][::-1].argmax() - 1 + end = len(this_trl) - this_trl[st:][::-1].argmax() return self._data[st:end, :][this_trl[st:end],:] # Helper function that spawns a `FauxTrial` object given actual trial information From f28eead144b5096805899daeb2c6118796135b87 Mon Sep 17 00:00:00 2001 From: KatharineShapcott <65502584+KatharineShapcott@users.noreply.github.com> Date: Mon, 2 Jan 2023 15:05:41 +0100 Subject: [PATCH 003/135] FIX: return empty array NOT all data --- syncopy/datatype/discrete_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/syncopy/datatype/discrete_data.py b/syncopy/datatype/discrete_data.py index 65b600df5..a8c6c4c7c 100644 --- a/syncopy/datatype/discrete_data.py +++ b/syncopy/datatype/discrete_data.py @@ -187,7 +187,7 @@ def trialtime(self): def _get_trial(self, trialno): this_trl = self.trialid == trialno if not np.any(this_trl): - return self._data[None, :] + return self._data[0:0, :] st = this_trl.argmax() end = len(this_trl) - this_trl[st:][::-1].argmax() return self._data[st:end, :][this_trl[st:end],:] From f2c280c37bfc3150d35c408a04886ef734e4f7a7 Mon Sep 17 00:00:00 2001 From: KatharineShapcott <65502584+KatharineShapcott@users.noreply.github.com> Date: Mon, 2 Jan 2023 17:22:05 +0100 Subject: [PATCH 004/135] CHG: remove unique from sample --- syncopy/datatype/discrete_data.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/syncopy/datatype/discrete_data.py b/syncopy/datatype/discrete_data.py index a8c6c4c7c..79afc29aa 100644 --- a/syncopy/datatype/discrete_data.py +++ b/syncopy/datatype/discrete_data.py @@ -115,7 +115,7 @@ def sample(self): """Indices of all recorded samples""" if self.data is None: return None - return np.unique(self.data[:, self.dimord.index("sample")]) + return self.data[:, self.dimord.index("sample")] @property def samplerate(self): From 1cd9b03df783ce0bbf24cebd96cab327ed4f6b9d Mon Sep 17 00:00:00 2001 From: KatharineShapcott <65502584+KatharineShapcott@users.noreply.github.com> Date: Wed, 4 Jan 2023 12:55:25 +0100 Subject: [PATCH 005/135] CHG: new _trialslice property for DiscreteData --- syncopy/datatype/methods/definetrial.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/syncopy/datatype/methods/definetrial.py b/syncopy/datatype/methods/definetrial.py index a2a23c27d..c268a46e1 100644 --- a/syncopy/datatype/methods/definetrial.py +++ b/syncopy/datatype/methods/definetrial.py @@ -336,18 +336,18 @@ def definetrial(obj, trialdefinition=None, pre=None, post=None, start=None, # Compute trial-IDs by matching data samples with provided trial-bounds samples = tgt.data[:, tgt.dimord.index("sample")] if np.size(samples) > 0: - starts = tgt.sampleinfo[:, 0] - ends = tgt.sampleinfo[:, 1] - startids = np.searchsorted(starts, samples, side="right") - endids = np.searchsorted(ends, samples, side="left") - mask = startids == endids - startids -= 1 - # Samples not belonging into any trial get a trial-ID of -1 - startids[mask] = int(startids.min() <= 0) * (-1) - tgt.trialid = startids + idx = np.searchsorted(samples, tgt.sampleinfo.ravel()) + idx = idx.reshape(tgt.sampleinfo.shape) + + tgt._trialslice = [slice(st,end) for st,end in idx] + tgt.trialid = np.full((samples.shape), -1, dtype=int) + for itrl, itrl_slice in enumerate(tgt._trialslice): + tgt.trialid[itrl_slice] = itrl + # no data - empty object, can happen due to a selection else: tgt.trialid = None + tgt._trialslice = None tgt._trialdefinition = None # Write log entry From 321c2d3c917351a77d4b822dd244a9fe39d972b2 Mon Sep 17 00:00:00 2001 From: KatharineShapcott <65502584+KatharineShapcott@users.noreply.github.com> Date: Wed, 4 Jan 2023 13:07:31 +0100 Subject: [PATCH 006/135] CHG: Update DiscreteData to use _trialslice Removed custom .trials property --- syncopy/datatype/discrete_data.py | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/syncopy/datatype/discrete_data.py b/syncopy/datatype/discrete_data.py index 79afc29aa..73e307363 100644 --- a/syncopy/datatype/discrete_data.py +++ b/syncopy/datatype/discrete_data.py @@ -164,17 +164,6 @@ def trialid(self, trlid): raise exc self._trialid = np.array(trlid, dtype=int) - @property - def trials(self): - """list-like([sample x (>=2)] :class:`numpy.ndarray`) : trial slices of :attr:`data` property""" - if self.trialid is not None: - valid_trls = np.unique(self.trialid) - valid_trls = valid_trls[valid_trls >= 0] - return Indexer(map(self._get_trial, valid_trls), - valid_trls.size) - else: - return None - @property def trialtime(self): """list(:class:`numpy.ndarray`): trigger-relative sample times in s""" @@ -185,12 +174,7 @@ def trialtime(self): # Helper function that grabs a single trial def _get_trial(self, trialno): - this_trl = self.trialid == trialno - if not np.any(this_trl): - return self._data[0:0, :] - st = this_trl.argmax() - end = len(this_trl) - this_trl[st:][::-1].argmax() - 
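(The slicing trick that PATCH 005 builds on is worth spelling out. Because the "sample" column of a DiscreteData object is sorted, np.searchsorted against the trial bounds in sampleinfo yields one contiguous slice per trial, replacing a per-trial boolean-mask scan over all rows — which is where the reported x20 speedup comes from. A self-contained sketch with made-up toy numbers; the names samples/sampleinfo mirror the patch:

import numpy as np

samples = np.array([0, 1, 3, 5, 8, 9, 12])          # sorted "sample" column of the data
sampleinfo = np.array([[0, 4], [4, 10], [10, 13]])  # per-trial [start, end) sample bounds

idx = np.searchsorted(samples, sampleinfo.ravel()).reshape(sampleinfo.shape)
trialslice = [slice(st, end) for st, end in idx]    # one contiguous slice per trial

trialid = np.full(samples.shape, -1, dtype=int)     # rows outside every trial keep -1
for itrl, sl in enumerate(trialslice):
    trialid[sl] = itrl
# trialid is now array([0, 0, 0, 1, 1, 1, 2])
)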
return self._data[st:end, :][this_trl[st:end],:] # Helper function that spawns a `FauxTrial` object given actual trial information def _preview_trial(self, trialno): From f6a6cf3d084c6f3016a6e9c13658075ef6427bee Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Wed, 4 Jan 2023 13:45:51 +0100 Subject: [PATCH 007/135] use logger in SPYWarning --- syncopy/shared/errors.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/syncopy/shared/errors.py b/syncopy/shared/errors.py index c0597033d..66d3f124a 100644 --- a/syncopy/shared/errors.py +++ b/syncopy/shared/errors.py @@ -6,6 +6,7 @@ # Builtin/3rd party package imports import sys import traceback +import logging from collections import OrderedDict # Local imports @@ -311,7 +312,8 @@ def SPYWarning(msg, caller=None): if caller is None: caller = sys._getframe().f_back.f_code.co_name PrintMsg = "{coloron:s}{bold:s}Syncopy{caller:s} WARNING: {msg:s}{coloroff:s}" - print(PrintMsg.format(coloron=warnCol, + logger = logging.getLogger(syncopy.shared.errors.loggername) + logger.warning(PrintMsg.format(coloron=warnCol, bold=boldEm, caller=" <" + caller + ">" if len(caller) else caller, msg=msg, From f7772954a928dba1369f0d4fe00e2f5e7d2a438b Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Wed, 4 Jan 2023 13:56:03 +0100 Subject: [PATCH 008/135] NEW: set logger name --- syncopy/shared/errors.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/syncopy/shared/errors.py b/syncopy/shared/errors.py index 66d3f124a..da66e1292 100644 --- a/syncopy/shared/errors.py +++ b/syncopy/shared/errors.py @@ -18,6 +18,8 @@ __all__ = [] +loggername = "syncopy" + class SPYError(Exception): """ Base class for SynCoPy errors @@ -190,7 +192,8 @@ def SPYExceptionHandler(*excargs, **exckwargs): cols.Normal if isipy else "") # Show generated message and leave (or kick-off debugging in Jupyer/iPython if %pdb is on) - print(emsg) + logger = logging.getLogger(syncopy.shared.errors.loggername) + logger.error(emsg) if isipy: if ipy.call_pdb: ipy.InteractiveTB.debugger() From 6464c91ef20a6e92f1220c97a8467287bfbe1595 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Thu, 5 Jan 2023 10:26:03 +0100 Subject: [PATCH 009/135] FIX: add import --- syncopy/shared/errors.py | 1 + 1 file changed, 1 insertion(+) diff --git a/syncopy/shared/errors.py b/syncopy/shared/errors.py index da66e1292..0afa7a9a7 100644 --- a/syncopy/shared/errors.py +++ b/syncopy/shared/errors.py @@ -11,6 +11,7 @@ # Local imports from syncopy import __tbcount__ +import syncopy # Custom definition of bold ANSI for formatting errors/warnings in iPython/Jupyter ansiBold = "\033[1m" From 686f5173cbd01f7bc7ad21fa916e36e4f0c86cc3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Thu, 5 Jan 2023 10:44:34 +0100 Subject: [PATCH 010/135] CHG: disable test using a slice --- syncopy/tests/test_discretedata.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/syncopy/tests/test_discretedata.py b/syncopy/tests/test_discretedata.py index 050d89f94..6928b0933 100644 --- a/syncopy/tests/test_discretedata.py +++ b/syncopy/tests/test_discretedata.py @@ -529,8 +529,8 @@ def test_ed_dataselection(self): eventidSelections = [ [0, 0, 1], # preserve repetition, don't convert to slice - range(0, 2), # narrow range - slice(-2, None) # negative-start slice + range(0, 2)#, # narrow range + #slice(-2, None) # negative-start slice ] latencySelections = [ @@ -556,7 +556,7 @@ def test_ed_dataselection(self): cfg 
= StructDict(kwdict) # data selection via class-method + `Selector` instance for indexing selected = obj.selectdata(**kwdict) - obj.selectdata(**kwdict, inplace=True) + obj.selectdata(**kwdict, inplace=True) selector = obj.selection tk = 0 for trialno in selector.trial_ids: From d8fbb8cdd1bb347a95bf905bbc7c2321afe6d697 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Thu, 5 Jan 2023 11:40:19 +0100 Subject: [PATCH 011/135] FIX: remove slice test --- syncopy/tests/test_discretedata.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/syncopy/tests/test_discretedata.py b/syncopy/tests/test_discretedata.py index 6928b0933..afec39967 100644 --- a/syncopy/tests/test_discretedata.py +++ b/syncopy/tests/test_discretedata.py @@ -529,8 +529,7 @@ def test_ed_dataselection(self): eventidSelections = [ [0, 0, 1], # preserve repetition, don't convert to slice - range(0, 2)#, # narrow range - #slice(-2, None) # negative-start slice + range(0, 2) ] latencySelections = [ From 481a00e7dea4f0e91b6a43304f8e0c50a3d96d39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Thu, 5 Jan 2023 11:44:28 +0100 Subject: [PATCH 012/135] CHG: replace all print calls with logger calls in errors.py --- syncopy/shared/errors.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/syncopy/shared/errors.py b/syncopy/shared/errors.py index 0afa7a9a7..c05d951f0 100644 --- a/syncopy/shared/errors.py +++ b/syncopy/shared/errors.py @@ -274,7 +274,8 @@ def SPYExceptionHandler(*excargs, **exckwargs): # Show generated message and get outta here - print(emsg) + logger = logging.getLogger(syncopy.shared.errors.loggername) + logger.critical(emsg) # Kick-start debugging in case %pdb is enabled in Jupyter/iPython if isipy: @@ -358,7 +359,8 @@ def SPYInfo(msg, caller=None): if caller is None: caller = sys._getframe().f_back.f_code.co_name PrintMsg = "{coloron:s}{bold:s}Syncopy{caller:s} INFO: {msg:s}{coloroff:s}" - print(PrintMsg.format(coloron=infoCol, + logger = logging.getLogger(syncopy.shared.errors.loggername) + logger.info(PrintMsg.format(coloron=infoCol, bold=boldEm, caller=" <" + caller + ">" if len(caller) else caller, msg=msg, From 98ca825648c7250d322768f1551936f1f77d3107 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Thu, 5 Jan 2023 12:33:00 +0100 Subject: [PATCH 013/135] FIX: minor, fix typo --- syncopy/nwanalysis/csd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/syncopy/nwanalysis/csd.py b/syncopy/nwanalysis/csd.py index 905c87c01..789bed822 100644 --- a/syncopy/nwanalysis/csd.py +++ b/syncopy/nwanalysis/csd.py @@ -37,7 +37,7 @@ def csd(trl_dat, This is NOT the same as what is commonly referred to as "cross spectral density" as there is no (time) averaging!! Multi-tapering alone is not necessarily sufficient to get enough - statitstical power for a robust csd estimate. Yet for completeness + statistical power for a robust csd estimate. Yet for completeness and testing the option ``norm = True`` returns a single-trial coherence estimate for ``taper = "dpss"``. 
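(To make the effect of PATCHES 007-012 concrete: once the warning/info/error helpers write to the named "syncopy" logger instead of calling print(), a single logger object controls what actually reaches the user. A standalone sketch of that behaviour, using only the stdlib logging module, no Syncopy required:

import logging

logger = logging.getLogger("syncopy")        # the loggername introduced in PATCH 008
logger.addHandler(logging.StreamHandler())   # emit to the console
logger.setLevel(logging.WARNING)             # default threshold

logger.warning("Syncopy <caller> WARNING: visible at the default level")
logger.info("suppressed until the level is lowered to INFO or below")
)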
From 857fd83a565b06495c904ab63b6c321c3e04780f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Thu, 5 Jan 2023 12:33:22 +0100 Subject: [PATCH 014/135] FIX: prevent teardown method from being called in parallel test --- syncopy/tests/test_welch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/syncopy/tests/test_welch.py b/syncopy/tests/test_welch.py index 906bd59f7..f41bb9ca4 100644 --- a/syncopy/tests/test_welch.py +++ b/syncopy/tests/test_welch.py @@ -319,7 +319,7 @@ def test_parallel(self, testcluster=None): plt.ioff() client = dd.Client(testcluster) all_tests = [attr for attr in self.__dir__() - if (inspect.ismethod(getattr(self, attr)) and 'parallel' not in attr)] + if (inspect.ismethod(getattr(self, attr)) and 'parallel' not in attr and attr.startswith('test'))] for test in all_tests: test_method = getattr(self, test) From dc183f6b5ce63be3155e8ff483199a90d0ee80a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Thu, 5 Jan 2023 13:08:21 +0100 Subject: [PATCH 015/135] NEW: add logging documentation to dev docs --- doc/source/developer/developers.rst | 3 ++- doc/source/developer/logging.rst | 11 +++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) create mode 100644 doc/source/developer/logging.rst diff --git a/doc/source/developer/developers.rst b/doc/source/developer/developers.rst index 70a41edbc..9a35cf7ad 100644 --- a/doc/source/developer/developers.rst +++ b/doc/source/developer/developers.rst @@ -3,7 +3,7 @@ Syncopy Developer Guide *********************** The following information is meant for advanced users with an understanding of -class hierarchies that want to extend and/or modify Syncopy's base functionality. +class hierarchies that want to extend and/or modify Syncopy's base functionality. .. toctree:: :glob: @@ -13,4 +13,5 @@ class hierarchies that want to extend and/or modify Syncopy's base functionality io tools compute_kernels + logging developer_api diff --git a/doc/source/developer/logging.rst b/doc/source/developer/logging.rst new file mode 100644 index 000000000..08adb9294 --- /dev/null +++ b/doc/source/developer/logging.rst @@ -0,0 +1,11 @@ +.. _syncopy-logging: + +Controlling Logging in Syncopy +=============================== + +Syncopy uses the `Python logging module `_ for logging, and logs to a logger named `'syncopy'`. + +To adapt the logging behaviour of Syncopy, one can configure the logger as explained in the documentation for the logging module. + + +The default log level is `'WARNING'`. To change the log level, one can either use the logging API (see above), or set the environment varible `'SYNCOPY_LOGLEVEL'` to one of the values supported by the logging module, e.g., 'CRITICAL', 'WARNING', 'INFO', or 'DEBUG'. See the `official docs `_ for details. From 25d7403d81a6a818a7f08ae06f14e1603440e676 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Thu, 5 Jan 2023 13:09:47 +0100 Subject: [PATCH 016/135] CHG: fix typo --- doc/source/developer/logging.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/source/developer/logging.rst b/doc/source/developer/logging.rst index 08adb9294..c9ffaeb24 100644 --- a/doc/source/developer/logging.rst +++ b/doc/source/developer/logging.rst @@ -8,4 +8,4 @@ Syncopy uses the `Python logging module `_ for details. +The default log level is `'WARNING'`. 
To change the log level, one can either use the logging API (see above), or set the environment variable `'SYNCOPY_LOGLEVEL'` to one of the values supported by the logging module, e.g., 'CRITICAL', 'WARNING', 'INFO', or 'DEBUG'. See the `official docs `_ for details on the supported log levels. From 4ce4e1ce015ebf6075e7a55fc587ecb88359e9fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Thu, 5 Jan 2023 13:24:12 +0100 Subject: [PATCH 017/135] CHG: extend logging documentation --- doc/source/developer/logging.rst | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/doc/source/developer/logging.rst b/doc/source/developer/logging.rst index c9ffaeb24..b65f80cc0 100644 --- a/doc/source/developer/logging.rst +++ b/doc/source/developer/logging.rst @@ -3,9 +3,24 @@ Controlling Logging in Syncopy =============================== -Syncopy uses the `Python logging module `_ for logging, and logs to a logger named `'syncopy'`. +Syncopy uses the `Python logging module `_ for logging, and logs to a logger named `'syncopy'` which is handled by the console. -To adapt the logging behaviour of Syncopy, one can configure the logger as explained in the documentation for the logging module. +To adapt the logging behaviour of Syncopy, one can configure the logger as explained in the documentation for the logging module. E.g.: .. code-block:: python import syncopy import logging # Get the logger used by syncopy logger = logging.getLogger('syncopy') # Change the log level: logger.setLevel(logging.DEBUG) # Make it log to a file instead of the console: fh = logging.FileHandler('syncopy_log_within_my_app.log') logger.addHandler(fh) The default log level is `'WARNING'`. To change the log level, you can either use the logging API in your application code as explained above, or set the environment variable `'SPYLOGLEVEL'` to one of the values supported by the logging module, e.g., 'CRITICAL', 'WARNING', 'INFO', or 'DEBUG'. See the `official docs `_ for details on the supported log levels. From ccbff5c7795a76ab8dcd48f12b2b0a9a34f85a84 Mon Sep 17 00:00:00 2001 From: Katharine Shapcott Date: Thu, 5 Jan 2023 17:54:59 +0100 Subject: [PATCH 018/135] FIX: no data returns empty array --- syncopy/datatype/discrete_data.py | 3 +++ syncopy/datatype/methods/definetrial.py | 21 +++++++-------------- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/syncopy/datatype/discrete_data.py b/syncopy/datatype/discrete_data.py index 73e307363..641241e2c 100644 --- a/syncopy/datatype/discrete_data.py +++ b/syncopy/datatype/discrete_data.py @@ -156,6 +156,9 @@ def trialid(self, trlid): print("SyNCoPy core - trialid: Cannot assign `trialid` without data. 
" + "Please assing data first") return + if (self.data.shape[0] == 0) and (trlid.shape[0] == 0): + self._trialid = np.array(trlid, dtype=int) + return scount = np.nanmax(self.data[:, self.dimord.index("sample")]) try: array_parser(trlid, varname="trialid", dims=(self.data.shape[0],), diff --git a/syncopy/datatype/methods/definetrial.py b/syncopy/datatype/methods/definetrial.py index c268a46e1..5ba449294 100644 --- a/syncopy/datatype/methods/definetrial.py +++ b/syncopy/datatype/methods/definetrial.py @@ -335,20 +335,13 @@ def definetrial(obj, trialdefinition=None, pre=None, post=None, start=None, # Compute trial-IDs by matching data samples with provided trial-bounds samples = tgt.data[:, tgt.dimord.index("sample")] - if np.size(samples) > 0: - idx = np.searchsorted(samples, tgt.sampleinfo.ravel()) - idx = idx.reshape(tgt.sampleinfo.shape) - - tgt._trialslice = [slice(st,end) for st,end in idx] - tgt.trialid = np.full((samples.shape), -1, dtype=int) - for itrl, itrl_slice in enumerate(tgt._trialslice): - tgt.trialid[itrl_slice] = itrl - - # no data - empty object, can happen due to a selection - else: - tgt.trialid = None - tgt._trialslice = None - tgt._trialdefinition = None + idx = np.searchsorted(samples, tgt.sampleinfo.ravel()) + idx = idx.reshape(tgt.sampleinfo.shape) + + tgt._trialslice = [slice(st,end) for st,end in idx] + tgt.trialid = np.full((samples.shape), -1, dtype=int) + for itrl, itrl_slice in enumerate(tgt._trialslice): + tgt.trialid[itrl_slice] = itrl # Write log entry if ref == tgt: From b730e92a83846a2822ced860461d84fc1e0e26be Mon Sep 17 00:00:00 2001 From: tensionhead Date: Thu, 5 Jan 2023 20:50:43 +0100 Subject: [PATCH 019/135] CHG: revert sample property Changes to be committed: modified: syncopy/datatype/discrete_data.py --- syncopy/datatype/discrete_data.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/syncopy/datatype/discrete_data.py b/syncopy/datatype/discrete_data.py index 641241e2c..fc71a4871 100644 --- a/syncopy/datatype/discrete_data.py +++ b/syncopy/datatype/discrete_data.py @@ -115,7 +115,9 @@ def sample(self): """Indices of all recorded samples""" if self.data is None: return None - return self.data[:, self.dimord.index("sample")] + # return self.data[:, self.dimord.index("sample")] + # there should be only one event per sample number?! 
+ return np.unique(self.data[:, self.dimord.index("sample")]) @property def samplerate(self): From 40b9348ecc0d511204c058e6e709cb0afb339871 Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Fri, 6 Jan 2023 09:27:08 +0100 Subject: [PATCH 020/135] set default log level --- syncopy/shared/errors.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/syncopy/shared/errors.py b/syncopy/shared/errors.py index da66e1292..a10e0e01e 100644 --- a/syncopy/shared/errors.py +++ b/syncopy/shared/errors.py @@ -4,6 +4,7 @@ # # Builtin/3rd party package imports +import os import sys import traceback import logging @@ -19,6 +20,7 @@ loggername = "syncopy" +default_loglevel = os.getenv("SYNCOPY_LOGLEVEL", "WARNING") # The logging threshold, one of 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL' class SPYError(Exception): """ From 49a8241699e2f376f892b2e0beba3fef3343576c Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Fri, 6 Jan 2023 09:56:21 +0100 Subject: [PATCH 021/135] NEW: allow silencing of module import message --- syncopy/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/syncopy/__init__.py b/syncopy/__init__.py index 0719e3fb0..426eb3ea2 100644 --- a/syncopy/__init__.py +++ b/syncopy/__init__.py @@ -46,7 +46,9 @@ try: dd.get_client() except ValueError: - print(msg) + silence_file = os.path.expanduser("~/.spy_silentstartup") + if os.getenv("SYNCOPY_SILENTSTARTUP") is None and not os.path.isfile(silence_file): + print(msg) # Set up sensible printing options for NumPy arrays np.set_printoptions(suppress=True, precision=4, linewidth=80) From 8c9cbf263445149c3a582a0e19b2f71975246343 Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Fri, 6 Jan 2023 09:56:44 +0100 Subject: [PATCH 022/135] NEW: better checking of default log level --- syncopy/shared/errors.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/syncopy/shared/errors.py b/syncopy/shared/errors.py index a13ea0202..4b97c973a 100644 --- a/syncopy/shared/errors.py +++ b/syncopy/shared/errors.py @@ -8,6 +8,7 @@ import sys import traceback import logging +import warnings from collections import OrderedDict # Local imports @@ -19,9 +20,19 @@ __all__ = [] +def _get_default_loglevel(): + """Return the default loglevel, which is 'WARNING' unless set in the env var 'SYNCOPY_LOGLEVEL'.""" + loglevel = os.getenv("SYNCOPY_LOGLEVEL", "WARNING") + numeric_level = getattr(logging, loglevel.upper(), None) + if not isinstance(numeric_level, int): # An invalid string was set as the env variable, default to WARNING. + warnings.warn("Invalid log level set in environment variable 'SYNCOPY_LOGLEVEL', ignoring and using WARNING instead. Hint: Set one of 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'.") + loglevel = "WARNING" + return loglevel + + +loggername = "syncopy" # Since this is a library, we should not use the root logger (see Python logging docs). +default_loglevel = _get_default_loglevel() # The logging threshold, one of 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'. 
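(The `getattr(logging, loglevel.upper(), None)` check in `_get_default_loglevel` above works because the logging module exposes every level name as an integer module attribute. A quick standalone check, with "bogus" as an example of an invalid value:

import logging

for name in ("debug", "WARNING", "bogus"):
    print(name, "->", getattr(logging, name.upper(), None))
# prints: debug -> 10, WARNING -> 30, bogus -> None
)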
-loggername = "syncopy" -default_loglevel = os.getenv("SYNCOPY_LOGLEVEL", "WARNING") # The logging threshold, one of 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL' class SPYError(Exception): """ From 2f98f241f55f7273e27d64e0ceef47d59500b671 Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Fri, 6 Jan 2023 10:38:13 +0100 Subject: [PATCH 023/135] NEW: add logdir, move temp storage folder --- syncopy/__init__.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/syncopy/__init__.py b/syncopy/__init__.py index 426eb3ea2..949e0ad78 100644 --- a/syncopy/__init__.py +++ b/syncopy/__init__.py @@ -46,8 +46,8 @@ try: dd.get_client() except ValueError: - silence_file = os.path.expanduser("~/.spy_silentstartup") - if os.getenv("SYNCOPY_SILENTSTARTUP") is None and not os.path.isfile(silence_file): + silence_file = os.path.expanduser("~/.spy/silentstartup") + if os.getenv("SPYSILENTSTARTUP") is None and not os.path.isfile(silence_file): print(msg) # Set up sensible printing options for NumPy arrays np.set_printoptions(suppress=True, precision=4, linewidth=80) @@ -96,9 +96,17 @@ __storage__ = os.path.abspath(os.path.expanduser(os.environ["SPYTMPDIR"])) else: if os.path.exists(csHome): - __storage__ = os.path.join(csHome, ".spy") + __storage__ = os.path.join(csHome, ".spy", "tmp_storage") else: - __storage__ = os.path.join(os.path.expanduser("~"), ".spy") + __storage__ = os.path.join(os.path.expanduser("~"), ".spy", "tmp_storage") + +if os.environ.get("SPYLOGDIR"): + __logdir__ = os.path.abspath(os.path.expanduser(os.environ["SPYLOGDIR"])) +else: + if os.path.exists(csHome): + __logdir__ = os.path.join(csHome, ".spy", "logs") + else: + __logdir__ = os.path.join(os.path.expanduser("~"), ".spy", "logs") # Set upper bound for temp directory size (in GB) __storagelimit__ = 10 From 7172970494e7fb97cd8daf33c5ac0f8af85ed21c Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Fri, 6 Jan 2023 11:14:06 +0100 Subject: [PATCH 024/135] FIX: recursively create dir --- syncopy/datatype/base_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/syncopy/datatype/base_data.py b/syncopy/datatype/base_data.py index ba9fc4c9a..1d038306d 100644 --- a/syncopy/datatype/base_data.py +++ b/syncopy/datatype/base_data.py @@ -1239,7 +1239,7 @@ def __init__(self): # Create package-wide tmp directory if not already present if not os.path.exists(__storage__): try: - os.mkdir(__storage__) + os.mkdirs(__storage__) except Exception as exc: err = ( "Syncopy core: cannot create temporary storage directory {}. " From 404006c6147980807d89d024a003378b0d83e0e Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Fri, 6 Jan 2023 11:16:10 +0100 Subject: [PATCH 025/135] FIX: fix makedirs usage --- syncopy/datatype/base_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/syncopy/datatype/base_data.py b/syncopy/datatype/base_data.py index 1d038306d..db581efb1 100644 --- a/syncopy/datatype/base_data.py +++ b/syncopy/datatype/base_data.py @@ -1239,7 +1239,7 @@ def __init__(self): # Create package-wide tmp directory if not already present if not os.path.exists(__storage__): try: - os.mkdirs(__storage__) + os.makedirs(__storage__) except Exception as exc: err = ( "Syncopy core: cannot create temporary storage directory {}. 
" From c8063e92cb58e81e862a134b7af99ab1d1429f5a Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Fri, 6 Jan 2023 11:16:49 +0100 Subject: [PATCH 026/135] FIX: allow existing dir --- syncopy/datatype/base_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/syncopy/datatype/base_data.py b/syncopy/datatype/base_data.py index db581efb1..2f22f57eb 100644 --- a/syncopy/datatype/base_data.py +++ b/syncopy/datatype/base_data.py @@ -1239,7 +1239,7 @@ def __init__(self): # Create package-wide tmp directory if not already present if not os.path.exists(__storage__): try: - os.makedirs(__storage__) + os.makedirs(__storage__, exist_ok=True) except Exception as exc: err = ( "Syncopy core: cannot create temporary storage directory {}. " From a475af7ac14bc9d53b8fc44bb5cc862caca18d70 Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Fri, 6 Jan 2023 12:07:31 +0100 Subject: [PATCH 027/135] NEW: provide different loggers for parallel and seq parts of code --- doc/source/developer/logging.rst | 2 +- syncopy/__init__.py | 20 ++++++++++++++++++++ syncopy/shared/errors.py | 24 +++++------------------- syncopy/shared/log.py | 27 +++++++++++++++++++++++++++ 4 files changed, 53 insertions(+), 20 deletions(-) create mode 100644 syncopy/shared/log.py diff --git a/doc/source/developer/logging.rst b/doc/source/developer/logging.rst index b65f80cc0..871022eca 100644 --- a/doc/source/developer/logging.rst +++ b/doc/source/developer/logging.rst @@ -23,4 +23,4 @@ To adapt the logging behaviour of Syncopy, one can configure the logger as expla -The default log level is `'WARNING'`. To change the log level, you can either use the logging API in your application code as explained above, or set the environment variable `'SYNCOPY_LOGLEVEL'` to one of the values supported by the logging module, e.g., 'CRITICAL', 'WARNING', 'INFO', or 'DEBUG'. See the `official docs `_ for details on the supported log levels. +The default log level is for the Syncopy logger is `'WARNING'`. To change the log level, you can either use the logging API in your application code as explained above, or set the environment variable `'SPYLOGLEVEL'` to one of the values supported by the logging module, e.g., 'CRITICAL', 'WARNING', 'INFO', or 'DEBUG'. See the `official docs `_ for details on the supported log levels. diff --git a/syncopy/__init__.py b/syncopy/__init__.py index 949e0ad78..a900b372f 100644 --- a/syncopy/__init__.py +++ b/syncopy/__init__.py @@ -11,6 +11,8 @@ import getpass import numpy as np from hashlib import blake2b, sha1 +import logging +import warnings from importlib.metadata import version, PackageNotFoundError import dask.distributed as dd @@ -100,6 +102,7 @@ else: __storage__ = os.path.join(os.path.expanduser("~"), ".spy", "tmp_storage") +# Setup logging. if os.environ.get("SPYLOGDIR"): __logdir__ = os.path.abspath(os.path.expanduser(os.environ["SPYLOGDIR"])) else: @@ -108,6 +111,23 @@ else: __logdir__ = os.path.join(os.path.expanduser("~"), ".spy", "logs") +loglevel = os.getenv("SPYLOGLEVEL", "WARNING") +numeric_level = getattr(logging, loglevel.upper(), None) +if not isinstance(numeric_level, int): # An invalid string was set as the env variable, default to WARNING. + warnings.warn("Invalid log level set in environment variable 'SPYLOGLEVEL', ignoring and using WARNING instead. Hint: Set one of 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'.") + loglevel = "WARNING" + +spy_logger = logging.getLogger('syncopy') +spy_logger.setLevel(loglevel) + +# Log to per-host files in parallel code by default. 
+host = socket.gethostname() +spy_parallel_logger = logging.getLogger("syncopy_" + host) + +fh = logging.FileHandler(os.path.join(__logdir__, f'syncopy_{host}.log')) +spy_parallel_logger.addHandler(fh) + + # Set upper bound for temp directory size (in GB) __storagelimit__ = 10 diff --git a/syncopy/shared/errors.py b/syncopy/shared/errors.py index 4b97c973a..0fc4a9d49 100644 --- a/syncopy/shared/errors.py +++ b/syncopy/shared/errors.py @@ -4,15 +4,14 @@ # # Builtin/3rd party package imports -import os import sys import traceback import logging -import warnings from collections import OrderedDict # Local imports from syncopy import __tbcount__ +from syncopy.shared.log import get_logger import syncopy # Custom definition of bold ANSI for formatting errors/warnings in iPython/Jupyter @@ -20,19 +19,6 @@ __all__ = [] -def _get_default_loglevel(): - """Return the default loglevel, which is 'WARNING' unless set in the env var 'SYNCOPY_LOGLEVEL'.""" - loglevel = os.getenv("SYNCOPY_LOGLEVEL", "WARNING") - numeric_level = getattr(logging, loglevel.upper(), None) - if not isinstance(numeric_level, int): # An invalid string was set as the env variable, default to WARNING. - warnings.warn("Invalid log level set in environment variable 'SYNCOPY_LOGLEVEL', ignoring and using WARNING instead. Hint: Set one of 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'.") - loglevel = "WARNING" - return loglevel - - -loggername = "syncopy" # Since this is a library, we should not use the root logger (see Python logging docs). -default_loglevel = _get_default_loglevel() # The logging threshold, one of 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'. - class SPYError(Exception): """ @@ -206,7 +192,7 @@ def SPYExceptionHandler(*excargs, **exckwargs): cols.Normal if isipy else "") # Show generated message and leave (or kick-off debugging in Jupyer/iPython if %pdb is on) - logger = logging.getLogger(syncopy.shared.errors.loggername) + logger = get_logger() logger.error(emsg) if isipy: if ipy.call_pdb: @@ -287,7 +273,7 @@ def SPYExceptionHandler(*excargs, **exckwargs): # Show generated message and get outta here - logger = logging.getLogger(syncopy.shared.errors.loggername) + logger = get_logger() logger.critical(emsg) # Kick-start debugging in case %pdb is enabled in Jupyter/iPython @@ -330,7 +316,7 @@ def SPYWarning(msg, caller=None): if caller is None: caller = sys._getframe().f_back.f_code.co_name PrintMsg = "{coloron:s}{bold:s}Syncopy{caller:s} WARNING: {msg:s}{coloroff:s}" - logger = logging.getLogger(syncopy.shared.errors.loggername) + logger = get_logger() logger.warning(PrintMsg.format(coloron=warnCol, bold=boldEm, caller=" <" + caller + ">" if len(caller) else caller, @@ -372,7 +358,7 @@ def SPYInfo(msg, caller=None): if caller is None: caller = sys._getframe().f_back.f_code.co_name PrintMsg = "{coloron:s}{bold:s}Syncopy{caller:s} INFO: {msg:s}{coloroff:s}" - logger = logging.getLogger(syncopy.shared.errors.loggername) + logger = get_logger() logger.info(PrintMsg.format(coloron=infoCol, bold=boldEm, caller=" <" + caller + ">" if len(caller) else caller, diff --git a/syncopy/shared/log.py b/syncopy/shared/log.py new file mode 100644 index 000000000..ab3d0e98d --- /dev/null +++ b/syncopy/shared/log.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +# +# Logging functions for Syncopy. +# +# Note: The logging setup is done in the top-level `__init.py__` file. + +import logging +import socket + + +loggername = "syncopy" # Since this is a library, we should not use the root logger (see Python logging docs). 
+ +def get_logger(): + """Get the syncopy root logger. + + Logs to console by default. To be used in everything that runs on the local computer.""" + return logging.getLogger(loggername) + +def get_parallel_logger(): + """ + Get a logger for stuff that is run in parallel. + + Logs to a machine-specific file in the SPYLOGDIR by default. To be used in computational routines. + """ + host = socket.gethostname() + return logging.getLogger(loggername + "_" + host) + From b04141c7e5e63d2667bfcd6b0326cf2712303f91 Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Fri, 6 Jan 2023 12:09:11 +0100 Subject: [PATCH 028/135] CHG: also use loglevel for parallel logger --- syncopy/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/syncopy/__init__.py b/syncopy/__init__.py index a900b372f..345f94939 100644 --- a/syncopy/__init__.py +++ b/syncopy/__init__.py @@ -126,6 +126,7 @@ fh = logging.FileHandler(os.path.join(__logdir__, f'syncopy_{host}.log')) spy_parallel_logger.addHandler(fh) +spy_parallel_logger.setLevel(loglevel) # Set upper bound for temp directory size (in GB) From b0c337cd4f430cc0b6fecd3c0fd387f0e3c73566 Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Fri, 6 Jan 2023 12:15:52 +0100 Subject: [PATCH 029/135] FIX: create logdir unless exists --- syncopy/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/syncopy/__init__.py b/syncopy/__init__.py index 345f94939..64f4702b1 100644 --- a/syncopy/__init__.py +++ b/syncopy/__init__.py @@ -111,6 +111,9 @@ else: __logdir__ = os.path.join(os.path.expanduser("~"), ".spy", "logs") +if not os.path.exists(__logdir__): + os.makedirs(__logdir__, exist_ok=True) + loglevel = os.getenv("SPYLOGLEVEL", "WARNING") numeric_level = getattr(logging, loglevel.upper(), None) if not isinstance(numeric_level, int): # An invalid string was set as the env variable, default to WARNING. From 0b7fb894ba262343441e7a60e191657c7e635277 Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Fri, 6 Jan 2023 12:35:41 +0100 Subject: [PATCH 030/135] NEW: add function to delete all syncopy log files in log dir --- syncopy/shared/log.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/syncopy/shared/log.py b/syncopy/shared/log.py index ab3d0e98d..cdf1b3acd 100644 --- a/syncopy/shared/log.py +++ b/syncopy/shared/log.py @@ -4,8 +4,11 @@ # # Note: The logging setup is done in the top-level `__init.py__` file. +import os import logging import socket +import syncopy +import warnings loggername = "syncopy" # Since this is a library, we should not use the root logger (see Python logging docs). @@ -25,3 +28,24 @@ def get_parallel_logger(): host = socket.gethostname() return logging.getLogger(loggername + "_" + host) + +def delete_all_logfiles(silent=True): + """Delete all '.log' files in the Syncopy logging directory. + + The log directory that will be emptied is `syncopy.__logdir__`. 
+ """ + logdir = syncopy.__logdir__ + num_deleted = 0 + if os.path.isdir(logdir): + filelist = [ f for f in os.listdir(logdir) if f.endswith(".log") ] + for f in filelist: + logfile = os.path.join(logdir, f) + try: + os.remove(logfile) + num_deleted += 1 + except Exception as ex: + warnings.warn(f"Could not delete log file '{logfile}': {str(ex)}") + if not silent: + print(f"Deleted {num_deleted} log files from directory '{logdir}'.") + + From b970f3fa0652d7e04cbe5965e52aa4432d2cd682 Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Fri, 6 Jan 2023 12:39:33 +0100 Subject: [PATCH 031/135] CHG: minor, add info on default append mode for logfiles --- syncopy/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/syncopy/__init__.py b/syncopy/__init__.py index 64f4702b1..d28322681 100644 --- a/syncopy/__init__.py +++ b/syncopy/__init__.py @@ -127,7 +127,7 @@ host = socket.gethostname() spy_parallel_logger = logging.getLogger("syncopy_" + host) -fh = logging.FileHandler(os.path.join(__logdir__, f'syncopy_{host}.log')) +fh = logging.FileHandler(os.path.join(__logdir__, f'syncopy_{host}.log')) # The default mode is 'append'. spy_parallel_logger.addHandler(fh) spy_parallel_logger.setLevel(loglevel) From 4d264a143b1624232d80f7e151bd77fed885136b Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Fri, 6 Jan 2023 13:38:52 +0100 Subject: [PATCH 032/135] NEW: add a SPYDebug method --- syncopy/shared/errors.py | 40 +++++++++++++++++++++++++++++++++++----- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/syncopy/shared/errors.py b/syncopy/shared/errors.py index 0fc4a9d49..af7749bc6 100644 --- a/syncopy/shared/errors.py +++ b/syncopy/shared/errors.py @@ -132,7 +132,9 @@ def __str__(self): def SPYExceptionHandler(*excargs, **exckwargs): """ - Docstring coming soon(ish)... + Syncopy custom ExceptionHandler. + + Prints formatted and colored messages and stack traces, and starts debugging if `%pdb` is enabled in Jupyter/iPython. """ # Depending on the number of input arguments, we're either in Jupyter/iPython @@ -284,7 +286,10 @@ def SPYExceptionHandler(*excargs, **exckwargs): def SPYWarning(msg, caller=None): """ - Standardized Syncopy warning message + Log a standardized Syncopy warning message. + + .. note:: + Depending on the currently active log level, this may or may not produce any output. Parameters ---------- @@ -324,9 +329,12 @@ def SPYWarning(msg, caller=None): coloroff=normCol)) -def SPYInfo(msg, caller=None): +def SPYInfo(msg, caller=None, tag="INFO"): """ - Standardized Syncopy info message + Log a standardized Syncopy info message. + + .. note:: + Depending on the currently active log level, this may or may not produce any output. Parameters ---------- @@ -357,7 +365,7 @@ def SPYInfo(msg, caller=None): # Plug together message string and print it if caller is None: caller = sys._getframe().f_back.f_code.co_name - PrintMsg = "{coloron:s}{bold:s}Syncopy{caller:s} INFO: {msg:s}{coloroff:s}" + PrintMsg = "{coloron:s}{bold:s}Syncopy{caller:s} {tag}: {msg:s}{coloroff:s}" logger = get_logger() logger.info(PrintMsg.format(coloron=infoCol, bold=boldEm, @@ -365,3 +373,25 @@ def SPYInfo(msg, caller=None): msg=msg, coloroff=normCol)) +def SPYDebug(msg, caller=None): + """ + Log a standardized Syncopy debug message. + + .. note:: + Depending on the currently active log level, this may or may not produce any output. + + Parameters + ---------- + msg : str + Debug message to be printed + caller : None or str + Issuer of debug message. 
If `None`, name of calling method is + automatically fetched and pre-pended to `msg`. + + Returns + ------- + Nothing : None + """ + if caller is None: + caller = sys._getframe().f_back.f_code.co_name + SPYInfo(msg, caller=caller, tag="DEBUG") From 0ed14dccb393af3466b73d7d40369b5a87d9dbb8 Mon Sep 17 00:00:00 2001 From: kajal5888 Date: Fri, 6 Jan 2023 14:30:56 +0100 Subject: [PATCH 033/135] Testing --- syncopy/nwanalysis/connectivity_analysis.py | 1 + 1 file changed, 1 insertion(+) diff --git a/syncopy/nwanalysis/connectivity_analysis.py b/syncopy/nwanalysis/connectivity_analysis.py index d33cc7217..98ab2b302 100644 --- a/syncopy/nwanalysis/connectivity_analysis.py +++ b/syncopy/nwanalysis/connectivity_analysis.py @@ -5,6 +5,7 @@ # Builtin/3rd party package imports import numpy as np +'Testing' # Syncopy imports from syncopy.shared.parsers import data_parser, scalar_parser From f8096ea73c871d44f758c7fd6b248eba1e9349da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Fri, 6 Jan 2023 15:45:55 +0100 Subject: [PATCH 034/135] CHG: log hostname for remote logging --- syncopy/__init__.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/syncopy/__init__.py b/syncopy/__init__.py index d28322681..de3c96f85 100644 --- a/syncopy/__init__.py +++ b/syncopy/__init__.py @@ -7,12 +7,13 @@ import os import sys import subprocess -import socket import getpass +import socket import numpy as np from hashlib import blake2b, sha1 import logging import warnings +import platform from importlib.metadata import version, PackageNotFoundError import dask.distributed as dd @@ -120,16 +121,32 @@ warnings.warn("Invalid log level set in environment variable 'SPYLOGLEVEL', ignoring and using WARNING instead. Hint: Set one of 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'.") loglevel = "WARNING" +# The logger for local/sequential stuff -- goes to terminal. spy_logger = logging.getLogger('syncopy') -spy_logger.setLevel(loglevel) +fmt = logging.Formatter('%(asctime)s - %(levelname)s: %(message)s') +sh = logging.StreamHandler() +sh.setLevel(loglevel) +sh.setFormatter(fmt) +spy_logger.addHandler(sh) # Log to per-host files in parallel code by default. -host = socket.gethostname() +host = platform.node() spy_parallel_logger = logging.getLogger("syncopy_" + host) -fh = logging.FileHandler(os.path.join(__logdir__, f'syncopy_{host}.log')) # The default mode is 'append'. +class HostnameFilter(logging.Filter): + hostname = platform.node() + + def filter(self, record): + record.hostname = HostnameFilter.hostname + return True + +logfile = os.path.join(__logdir__, f'syncopy_{host}.log') +fh = logging.FileHandler(logfile) # The default mode is 'append'. 
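(The `HostnameFilter` introduced just above uses a somewhat non-obvious `logging` feature: a Filter may mutate each record before it is formatted, which is what makes the custom `%(hostname)s` field in the Formatter legal. The mechanism in isolation, independent of Syncopy:

import logging, platform

class HostFilter(logging.Filter):
    def filter(self, record):
        record.hostname = platform.node()  # inject a custom attribute into the record
        return True                        # True means: do not drop the record

h = logging.StreamHandler()
h.addFilter(HostFilter())
h.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(hostname)s: %(message)s"))
)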
+fh.addFilter(HostnameFilter()) +fh.setLevel(loglevel) +fmt_with_hostname = logging.Formatter('%(asctime)s - %(levelname)s - %(hostname)s: %(message)s') +fh.setFormatter(fmt_with_hostname) spy_parallel_logger.addHandler(fh) -spy_parallel_logger.setLevel(loglevel) # Set upper bound for temp directory size (in GB) From 6fec1cce46e5d0e154076c0855f2fd1ca95250d5 Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Mon, 9 Jan 2023 11:29:27 +0100 Subject: [PATCH 035/135] NEW: add parallel logging functions --- syncopy/shared/errors.py | 14 +++++++++++++- syncopy/shared/log.py | 1 + 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/syncopy/shared/errors.py b/syncopy/shared/errors.py index af7749bc6..70296b370 100644 --- a/syncopy/shared/errors.py +++ b/syncopy/shared/errors.py @@ -11,7 +11,7 @@ # Local imports from syncopy import __tbcount__ -from syncopy.shared.log import get_logger +from syncopy.shared.log import get_logger, get_parallel_logger, loglevels import syncopy # Custom definition of bold ANSI for formatting errors/warnings in iPython/Jupyter @@ -329,6 +329,18 @@ def SPYWarning(msg, caller=None): coloroff=normCol)) +def SPYParallelLog(msg, loglevel="INFO", caller=None): + numeric_level = getattr(logging, loglevel.upper(), None) + if not isinstance(numeric_level, int): # Invalid string was set. + raise SPYValueError(legal=f"one of: {loglevels}", varname="loglevel", actual=loglevel) + if caller is None: + caller = sys._getframe().f_back.f_code.co_name + PrintMsg = "{caller:s} {msg:s}" + logger = get_parallel_logger() + logger.info(PrintMsg.format(caller=" <" + caller + ">" if len(caller) else caller, + msg=msg)) + + def SPYInfo(msg, caller=None, tag="INFO"): """ Log a standardized Syncopy info message. diff --git a/syncopy/shared/log.py b/syncopy/shared/log.py index cdf1b3acd..8fcb63dc5 100644 --- a/syncopy/shared/log.py +++ b/syncopy/shared/log.py @@ -12,6 +12,7 @@ loggername = "syncopy" # Since this is a library, we should not use the root logger (see Python logging docs). +loglevels = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] def get_logger(): """Get the syncopy root logger. From 435f418b108f8463bb3f3848049a42c9b4c40279 Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Mon, 9 Jan 2023 11:33:57 +0100 Subject: [PATCH 036/135] NEW: add docstring for parallel logging function --- syncopy/shared/errors.py | 4 ++++ syncopy/shared/log.py | 1 + 2 files changed, 5 insertions(+) diff --git a/syncopy/shared/errors.py b/syncopy/shared/errors.py index 70296b370..827f07807 100644 --- a/syncopy/shared/errors.py +++ b/syncopy/shared/errors.py @@ -330,6 +330,10 @@ def SPYWarning(msg, caller=None): def SPYParallelLog(msg, loglevel="INFO", caller=None): + """Log a message in parallel code run via slurm. + + This uses the parallel logger and one file per machine. + """ numeric_level = getattr(logging, loglevel.upper(), None) if not isinstance(numeric_level, int): # Invalid string was set. raise SPYValueError(legal=f"one of: {loglevels}", varname="loglevel", actual=loglevel) diff --git a/syncopy/shared/log.py b/syncopy/shared/log.py index 8fcb63dc5..eae53d4d9 100644 --- a/syncopy/shared/log.py +++ b/syncopy/shared/log.py @@ -20,6 +20,7 @@ def get_logger(): Logs to console by default. To be used in everything that runs on the local computer.""" return logging.getLogger(loggername) + def get_parallel_logger(): """ Get a logger for stuff that is run in parallel. 
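(Between these two patches, a brief usage sketch of the helpers just added — module path as in the diffs above; where output lands depends on the handlers configured in `__init__.py`. Note also that `SPYParallelLog` picks the log method dynamically via `getattr(logger, loglevel.lower())`:

from syncopy.shared.log import get_logger, get_parallel_logger

get_logger().info("sequential code -> console handler")
get_parallel_logger().info("parallel code -> per-host logfile")

# the dynamic dispatch used internally, shown in isolation:
logger = get_parallel_logger()
getattr(logger, "warning")("equivalent to logger.warning(...)")
)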
From c04512ff0e8e165dde0246560bcda848bb6f6b19 Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Mon, 9 Jan 2023 12:09:00 +0100 Subject: [PATCH 037/135] NEW: use parallel logger in FOOOF code. --- syncopy/shared/errors.py | 3 ++- syncopy/specest/compRoutines.py | 11 ++++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/syncopy/shared/errors.py b/syncopy/shared/errors.py index 827f07807..168287254 100644 --- a/syncopy/shared/errors.py +++ b/syncopy/shared/errors.py @@ -341,7 +341,8 @@ def SPYParallelLog(msg, loglevel="INFO", caller=None): caller = sys._getframe().f_back.f_code.co_name PrintMsg = "{caller:s} {msg:s}" logger = get_parallel_logger() - logger.info(PrintMsg.format(caller=" <" + caller + ">" if len(caller) else caller, + logfunc = getattr(logger, loglevel.lower()) + logfunc(PrintMsg.format(caller=" <" + caller + ">" if len(caller) else caller, msg=msg)) diff --git a/syncopy/specest/compRoutines.py b/syncopy/specest/compRoutines.py index a20e79fb4..95c88258b 100644 --- a/syncopy/specest/compRoutines.py +++ b/syncopy/specest/compRoutines.py @@ -34,7 +34,7 @@ # Local imports -from syncopy.shared.errors import SPYValueError, SPYWarning +from syncopy.shared.errors import SPYValueError, SPYWarning, SPYParallelLog from syncopy.shared.tools import best_match from syncopy.shared.computational_routine import ComputationalRoutine, propagate_properties from syncopy.shared.kwarg_decorators import process_io @@ -946,6 +946,8 @@ def fooofspy_cF(trl_dat, foi=None, timeAxis=0, if noCompute: return outShape, spectralDTypes['pow'] + + # Call actual fooof method res, metadata = fooofspy(trl_dat[0, 0, :, :], in_freqs=fooof_settings['in_freqs'], freq_range=fooof_settings['freq_range'], out_type=output, fooof_opt=method_kwargs) @@ -995,6 +997,8 @@ class FooofSpy(ComputationalRoutine): # To attach metadata to the output of the CF def process_metadata(self, data, out): + SPYParallelLog("Fetching FOOOF output metadata from file '{out.filename}'.", loglevel="DEBUG") + # General-purpose loading of metadata. mdata = metadata_from_hdf5_file(out.filename) @@ -1002,10 +1006,14 @@ def process_metadata(self, data, out): # made in the call to `freqanalysis`, because the mtmfft run before will have # consumed them. So the trial indices are always relative. + SPYParallelLog("Decoding FOOOF output metadata from HDF5 datastructures.", loglevel="DEBUG") + # Backend-specific post-processing. May or may not be needed, depending on what # you need to do in the cF to fit the return values into hdf5. 
out.metadata = metadata_nest(FooofSpy.decode_metadata_fooof_alltrials_from_hdf5(mdata)) + SPYParallelLog(f"Copying recording information to output syncopy data instance.", loglevel="DEBUG") + # Some index gymnastics to get trial begin/end "samples" if data.selection is not None: chanSec = data.selection.channel @@ -1058,6 +1066,7 @@ def decode_metadata_fooof_alltrials_from_hdf5(metadata_fooof_hdf5): label, trial_idx, call_idx = decode_unique_md_label(unique_attr_label) if label == "n_peaks": n_peaks = v + SPYParallelLog(f"FOOOF detected {n_peaks} peaks in data of trial {trial_idx} call {call_idx}.", loglevel="DEBUG") gaussian_params_out = list() peak_params_out = list() start_idx = 0 From 75bb8ec5bfd4f0156424d94036c953a9af64c496 Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Mon, 9 Jan 2023 14:10:05 +0100 Subject: [PATCH 038/135] NEW: add logging func --- syncopy/shared/errors.py | 16 ++++++++++++++++ syncopy/specest/freqanalysis.py | 4 +++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/syncopy/shared/errors.py b/syncopy/shared/errors.py index 168287254..b2edbf920 100644 --- a/syncopy/shared/errors.py +++ b/syncopy/shared/errors.py @@ -345,6 +345,22 @@ def SPYParallelLog(msg, loglevel="INFO", caller=None): logfunc(PrintMsg.format(caller=" <" + caller + ">" if len(caller) else caller, msg=msg)) +def SPYLog(msg, loglevel="INFO", caller=None): + """Log a message in sequential code. + + This uses the standard logger that logs to console by default. + """ + numeric_level = getattr(logging, loglevel.upper(), None) + if not isinstance(numeric_level, int): # Invalid string was set. + raise SPYValueError(legal=f"one of: {loglevels}", varname="loglevel", actual=loglevel) + if caller is None: + caller = sys._getframe().f_back.f_code.co_name + PrintMsg = "{caller:s} {msg:s}" + logger = get_logger() + logfunc = getattr(logger, loglevel.lower()) + logfunc(PrintMsg.format(caller=" <" + caller + ">" if len(caller) else caller, + msg=msg)) + def SPYInfo(msg, caller=None, tag="INFO"): """ diff --git a/syncopy/specest/freqanalysis.py b/syncopy/specest/freqanalysis.py index a1188ff80..c45aa0d44 100644 --- a/syncopy/specest/freqanalysis.py +++ b/syncopy/specest/freqanalysis.py @@ -10,7 +10,7 @@ from syncopy.shared.parsers import data_parser, scalar_parser, array_parser from syncopy.shared.tools import get_defaults, get_frontend_cfg from syncopy.datatype import SpectralData -from syncopy.shared.errors import SPYValueError, SPYTypeError, SPYWarning, SPYInfo +from syncopy.shared.errors import SPYValueError, SPYTypeError, SPYWarning, SPYInfo, SPYLog from syncopy.shared.kwarg_decorators import (unwrap_cfg, unwrap_select, detect_parallel_client) from syncopy.shared.tools import best_match @@ -455,6 +455,8 @@ def freqanalysis(data, method='mtmfft', output='pow', # to prepare/sanitize `toi` # -------------------------------- + + if method in ["mtmconvol", "wavelet", "superlet", "welch"]: # Get start/end timing info respecting potential in-place selection From 6e8ae642e40e18ecde61a76c8da8b590dc896c73 Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Mon, 9 Jan 2023 14:16:41 +0100 Subject: [PATCH 039/135] CHG: add DEBUG print in specest --- syncopy/specest/freqanalysis.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/syncopy/specest/freqanalysis.py b/syncopy/specest/freqanalysis.py index c45aa0d44..e1f4519c8 100644 --- a/syncopy/specest/freqanalysis.py +++ b/syncopy/specest/freqanalysis.py @@ -450,6 +450,8 @@ def freqanalysis(data, method='mtmfft', output='pow', "polyremoval": polyremoval, "pad": 
pad} + SPYLog(f"Running specest method '{method}'.", loglevel="DEBUG") + # -------------------------------- # 1st: Check time-frequency inputs # to prepare/sanitize `toi` From d560a1a03f2d721affa7fa6bfca79cc427b58a55 Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Mon, 9 Jan 2023 14:33:16 +0100 Subject: [PATCH 040/135] FIX: fix setting of log level --- syncopy/__init__.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/syncopy/__init__.py b/syncopy/__init__.py index de3c96f85..19d9a76ef 100644 --- a/syncopy/__init__.py +++ b/syncopy/__init__.py @@ -7,6 +7,7 @@ import os import sys import subprocess +import datetime import getpass import socket import numpy as np @@ -124,10 +125,12 @@ # The logger for local/sequential stuff -- goes to terminal. spy_logger = logging.getLogger('syncopy') fmt = logging.Formatter('%(asctime)s - %(levelname)s: %(message)s') -sh = logging.StreamHandler() -sh.setLevel(loglevel) +sh = logging.StreamHandler(sys.stdout) sh.setFormatter(fmt) spy_logger.addHandler(sh) +spy_logger.setLevel(loglevel) +spy_logger.debug(f"Starting Syncopy session at {datetime.datetime.now().astimezone().isoformat()}.") +spy_logger.info(f"Syncopy log level set to: {loglevel}.") # Log to per-host files in parallel code by default. host = platform.node() @@ -143,7 +146,7 @@ def filter(self, record): logfile = os.path.join(__logdir__, f'syncopy_{host}.log') fh = logging.FileHandler(logfile) # The default mode is 'append'. fh.addFilter(HostnameFilter()) -fh.setLevel(loglevel) +spy_parallel_logger.setLevel(loglevel) fmt_with_hostname = logging.Formatter('%(asctime)s - %(levelname)s - %(hostname)s: %(message)s') fh.setFormatter(fmt_with_hostname) spy_parallel_logger.addHandler(fh) From dc2d6e0e10d224529c10cc01b387aa26c3cf40db Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Mon, 9 Jan 2023 14:35:46 +0100 Subject: [PATCH 041/135] FIX: fix debug message to use fstring --- syncopy/specest/compRoutines.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/syncopy/specest/compRoutines.py b/syncopy/specest/compRoutines.py index 95c88258b..2ba778086 100644 --- a/syncopy/specest/compRoutines.py +++ b/syncopy/specest/compRoutines.py @@ -997,7 +997,7 @@ class FooofSpy(ComputationalRoutine): # To attach metadata to the output of the CF def process_metadata(self, data, out): - SPYParallelLog("Fetching FOOOF output metadata from file '{out.filename}'.", loglevel="DEBUG") + SPYParallelLog(f"Fetching FOOOF output metadata from file '{out.filename}'.", loglevel="DEBUG") # General-purpose loading of metadata. mdata = metadata_from_hdf5_file(out.filename) @@ -1006,13 +1006,13 @@ def process_metadata(self, data, out): # made in the call to `freqanalysis`, because the mtmfft run before will have # consumed them. So the trial indices are always relative. - SPYParallelLog("Decoding FOOOF output metadata from HDF5 datastructures.", loglevel="DEBUG") + SPYParallelLog(f"Decoding FOOOF output metadata from HDF5 datastructures.", loglevel="DEBUG") # Backend-specific post-processing. May or may not be needed, depending on what # you need to do in the cF to fit the return values into hdf5. 
out.metadata = metadata_nest(FooofSpy.decode_metadata_fooof_alltrials_from_hdf5(mdata)) - SPYParallelLog("Copying recording information to output syncopy data instance.", loglevel="DEBUG") + SPYParallelLog(f"Copying recording information to output syncopy data instance.", loglevel="DEBUG") # Some index gymnastics to get trial begin/end "samples" if data.selection is not None: From c682fe112625e6e1bffe073676a2773801a8bd0c Mon Sep 17 00:00:00 2001 From: tensionhead Date: Mon, 9 Jan 2023 15:47:11 +0100 Subject: [PATCH 042/135] WIP: Reorganize channel property for SpikeData Changes to be committed: modified: syncopy/datatype/discrete_data.py --- syncopy/datatype/discrete_data.py | 88 ++++++++++++++++++++++--------- 1 file changed, 62 insertions(+), 26 deletions(-) diff --git a/syncopy/datatype/discrete_data.py b/syncopy/datatype/discrete_data.py index 376c0fbb5..ebf2ea7c9 100644 --- a/syncopy/datatype/discrete_data.py +++ b/syncopy/datatype/discrete_data.py @@ -51,6 +51,7 @@ def data(self): @data.setter def data(self, inData): + # this comes from BaseData self._set_dataset_property(inData, "data") def __str__(self): @@ -334,42 +335,74 @@ class SpikeData(DiscreteData): _stackingDimLabel = "sample" _selectionKeyWords = DiscreteData._selectionKeyWords + ('channel', 'unit',) + @property + def data(self): + """ + HDF5 dataset representing discrete spike data. + + Trials are concatenated along the time axis. + """ + + if getattr(self._data, "id", None) is not None: + if self._data.id.valid == 0: + lgl = "open HDF5 file" + act = "backing HDF5 file {} has been closed" + raise SPYValueError(legal=lgl, actual=act.format(self.filename), + varname="data") + return self._data + + @data.setter + def data(self, inData): + # this comes from BaseData + self._set_dataset_property(inData, "data") + + # set the default channel labels + self.channel = self._get_default_channel() + @property def channel(self): """ :class:`numpy.ndarray` : list of original channel names for each unit""" - # if data exists but no user-defined channel labels, create them on the fly - if self._channel is None and self._data is not None: - channelNumbers = np.unique(self.data[:, self.dimord.index("channel")]) - return np.array(["channel" + str(int(i + 1)).zfill(len(str(channelNumbers.max() + 1))) - for i in channelNumbers]) return self._channel @channel.setter def channel(self, chan): - if chan is None: - self._channel = None + + if chan is None and self.data is not None: + raise SPYValueError("Cannot set `channel` to `None` with existing data.") + elif self.data is None and chan is not None: + raise SPYValueError("Cannot assign `channel` without data. " + + "Please assign data first") + else: + # chan was None + self._channel = chan return - if self.data is None: - raise SPYValueError("Syncopy: Cannot assign `channels` without data. 
" + - "Please assign data first") + + nChan = np.max(self.data[:, self.dimord.index("channel")]) + 1 try: - array_parser(chan, varname="channel", ntype="str") + array_parser(chan, varname="channel", ntype="str", dims=(nChan, )) except Exception as exc: raise exc - # Remove duplicate entries from channel array but preserve original order - # (e.g., `[2, 0, 0, 1]` -> `[2, 0, 1`); allows for complex subset-selections - _, idx = np.unique(chan, return_index=True) - chan = np.array(chan)[np.sort(idx)] - nchan = np.unique(self.data[:, self.dimord.index("channel")]).size - if chan.size != nchan: - lgl = "channel label array of length {0:d}".format(nchan) - act = "array of length {0:d}".format(chan.size) - raise SPYValueError(legal=lgl, varname="channel", actual=act) - self._channel = chan + def _get_default_channel(self): + + """ + Creates the default channel labels + """ + + if self.data is not None: + # channel entries in self.data are 0-based + nChan = np.max(self.data[:, self.dimord.index("channel")]) + channel_arr = np.arange(nChan + 1) + channel_labels = np.array(["channel" + str(int(i + 1)).zfill(len(str(nChan)) + 1) + for i in channel_arr]) + else: + channel_labels = None + + return channel_labels + @property def unit(self): """ :class:`numpy.ndarray(str)` : unit names""" @@ -484,6 +517,9 @@ def __init__(self, """ + # instance attribute to allow modification + self._hdfFileAttributeProperties = DiscreteData._hdfFileAttributeProperties + ("channel", "unit") + self._unit = None self._channel = None @@ -493,12 +529,12 @@ def __init__(self, trialdefinition=trialdefinition, samplerate=samplerate, dimord=dimord) - - # instance attribute to allow modification - self._hdfFileAttributeProperties = DiscreteData._hdfFileAttributeProperties + ("channel",) - self.channel = channel - self.unit = unit + # use the setters, data is already attached + if channel is not None: + self.channel = channel + if unit is not None: + self.unit = unit class EventData(DiscreteData): From ddf439bb67c65efefdaf0ce4548a18c98dbea4e3 Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Mon, 9 Jan 2023 16:00:36 +0100 Subject: [PATCH 043/135] NEW: use parallel logger ater init to test it --- syncopy/__init__.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/syncopy/__init__.py b/syncopy/__init__.py index 19d9a76ef..37246562e 100644 --- a/syncopy/__init__.py +++ b/syncopy/__init__.py @@ -133,8 +133,10 @@ spy_logger.info(f"Syncopy log level set to: {loglevel}.") # Log to per-host files in parallel code by default. +# Note that this setup handles only the logger of the current host. 
host = platform.node() -spy_parallel_logger = logging.getLogger("syncopy_" + host) +parallel_logger_name = "syncopy_" + host +spy_parallel_logger = logging.getLogger(parallel_logger_name) class HostnameFilter(logging.Filter): hostname = platform.node() @@ -150,6 +152,7 @@ def filter(self, record): fmt_with_hostname = logging.Formatter('%(asctime)s - %(levelname)s - %(hostname)s: %(message)s') fh.setFormatter(fmt_with_hostname) spy_parallel_logger.addHandler(fh) +spy_parallel_logger.info(f"Syncopy parallel logger '{parallel_logger_name}' setup to log to file '{logfile}' at level {loglevel}.") # Set upper bound for temp directory size (in GB) From 602566d5e9341bdf970ccd3aac552b0d7482ee8b Mon Sep 17 00:00:00 2001 From: tensionhead Date: Mon, 9 Jan 2023 16:06:13 +0100 Subject: [PATCH 044/135] CHG: Remove property computation for unit Changes to be committed: modified: syncopy/datatype/discrete_data.py --- syncopy/datatype/discrete_data.py | 45 ++++++++++++++++++++----------- 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/syncopy/datatype/discrete_data.py b/syncopy/datatype/discrete_data.py index ebf2ea7c9..0749e8775 100644 --- a/syncopy/datatype/discrete_data.py +++ b/syncopy/datatype/discrete_data.py @@ -356,9 +356,6 @@ def data(self, inData): # this comes from BaseData self._set_dataset_property(inData, "data") - # set the default channel labels - self.channel = self._get_default_channel() - @property def channel(self): """ :class:`numpy.ndarray` : list of original channel names for each unit""" @@ -373,8 +370,7 @@ def channel(self, chan): elif self.data is None and chan is not None: raise SPYValueError("Cannot assign `channel` without data. " + "Please assign data first") - else: - # chan was None + elif chan is None: self._channel = chan return @@ -398,29 +394,28 @@ def _get_default_channel(self): channel_arr = np.arange(nChan + 1) channel_labels = np.array(["channel" + str(int(i + 1)).zfill(len(str(nChan)) + 1) for i in channel_arr]) - else: - channel_labels = None + return channel_labels - return channel_labels + else: + return None @property def unit(self): """ :class:`numpy.ndarray(str)` : unit names""" - if self.data is not None and self._unit is None: - unitIndices = np.unique(self.data[:, self.dimord.index("unit")]) - return np.array(["unit" + str(int(i)).zfill(len(str(unitIndices.max()))) - for i in unitIndices]) + return self._unit @unit.setter def unit(self, unit): - if unit is None: - self._unit = None - return - if self.data is None: + if unit is None and self.data is not None: + raise SPYValueError("Cannot set `unit` to `None` with existing data.") + elif self.data is None and unit is not None: raise SPYValueError("Syncopy - SpikeData - unit: Cannot assign `unit` without data. 
" + "Please assign data first") + elif unit is None: + self._unit = None + return nunit = np.unique(self.data[:, self.dimord.index("unit")]).size try: @@ -429,6 +424,19 @@ def unit(self, unit): raise exc self._unit = np.array(unit) + def _get_default_unit(self): + + """ + Creates the default unit labels + """ + + if self.data is not None: + unitIndices = np.unique(self.data[:, self.dimord.index("unit")]) + return np.array(["unit" + str(int(i)).zfill(len(str(unitIndices.max()))) + for i in unitIndices]) + else: + return None + # Helper function that extracts by-trial unit-indices def _get_unit(self, trials, units=None): """ @@ -533,8 +541,13 @@ def __init__(self, # use the setters, data is already attached if channel is not None: self.channel = channel + else: + self.channel = self._get_default_channel() + if unit is not None: self.unit = unit + else: + self.unit = self._get_default_unit() class EventData(DiscreteData): From 8d578904666ac0507faac53a7384db512cb627a5 Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Mon, 9 Jan 2023 16:10:44 +0100 Subject: [PATCH 045/135] NEW: allow different log levels for local and parallel loggers --- syncopy/__init__.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/syncopy/__init__.py b/syncopy/__init__.py index 37246562e..234100ec8 100644 --- a/syncopy/__init__.py +++ b/syncopy/__init__.py @@ -134,6 +134,11 @@ # Log to per-host files in parallel code by default. # Note that this setup handles only the logger of the current host. +parloglevel = os.getenv("SPYPARLOGLEVEL", loglevel) +numeric_level = getattr(logging, parloglevel.upper(), None) +if not isinstance(numeric_level, int): # An invalid string was set as the env variable, use default. + warnings.warn("Invalid log level set in environment variable 'SPYPARLOGLEVEL', ignoring and using WARNING instead. Hint: Set one of 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'.") + parloglevel = "WARNING" host = platform.node() parallel_logger_name = "syncopy_" + host spy_parallel_logger = logging.getLogger(parallel_logger_name) @@ -148,7 +153,7 @@ def filter(self, record): logfile = os.path.join(__logdir__, f'syncopy_{host}.log') fh = logging.FileHandler(logfile) # The default mode is 'append'. 
fh.addFilter(HostnameFilter()) -spy_parallel_logger.setLevel(loglevel) +spy_parallel_logger.setLevel(parloglevel) fmt_with_hostname = logging.Formatter('%(asctime)s - %(levelname)s - %(hostname)s: %(message)s') fh.setFormatter(fmt_with_hostname) spy_parallel_logger.addHandler(fh) From 89421c38361e5f17d8effe599c2e900c7d56ab07 Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Mon, 9 Jan 2023 16:19:26 +0100 Subject: [PATCH 046/135] NEW: use logger in backend func --- syncopy/specest/fooofspy.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/syncopy/specest/fooofspy.py b/syncopy/specest/fooofspy.py index 94211bc5d..9e85324c6 100644 --- a/syncopy/specest/fooofspy.py +++ b/syncopy/specest/fooofspy.py @@ -8,6 +8,8 @@ # Builtin/3rd party package imports import numpy as np from fooof import FOOOF +import logging +import platform # Constants available_fooof_out_types = ['fooof', 'fooof_aperiodic', 'fooof_peaks'] @@ -94,6 +96,9 @@ def fooofspy(data_arr, in_freqs, freq_range=None, if in_freqs is None: raise ValueError('infreqs: The input frequencies are required and must not be None.') + logger = logging.getLogger("syncopy_" + platform.node()) + logger.debug(f"Running FOOOF backend function on data chunk with shape {data_arr.shape}.") + invalid_fooof_opts = [i for i in fooof_opt.keys() if i not in available_fooof_options] if invalid_fooof_opts: raise ValueError("fooof_opt: invalid keys: '{inv}', allowed keys are: '{lgl}'.".format(inv=invalid_fooof_opts, lgl=fooof_opt.keys())) From fda7338bbc7aa63b87575d65fc249da19691867a Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Mon, 9 Jan 2023 16:28:19 +0100 Subject: [PATCH 047/135] FIX: add tag in SPYInfo --- syncopy/shared/errors.py | 1 + 1 file changed, 1 insertion(+) diff --git a/syncopy/shared/errors.py b/syncopy/shared/errors.py index b2edbf920..dfa3ef53d 100644 --- a/syncopy/shared/errors.py +++ b/syncopy/shared/errors.py @@ -403,6 +403,7 @@ def SPYInfo(msg, caller=None, tag="INFO"): logger.info(PrintMsg.format(coloron=infoCol, bold=boldEm, caller=" <" + caller + ">" if len(caller) else caller, + tag=tag, msg=msg, coloroff=normCol)) From fe63ee2d27b6ba491aac744cc94bf6db7418fcb4 Mon Sep 17 00:00:00 2001 From: tensionhead Date: Mon, 9 Jan 2023 16:40:32 +0100 Subject: [PATCH 048/135] CHG: Allow original labels for selections - it's probably confusing to have non-existent channels after a selection Changes to be committed: modified: syncopy/datatype/discrete_data.py --- syncopy/datatype/discrete_data.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/syncopy/datatype/discrete_data.py b/syncopy/datatype/discrete_data.py index 0749e8775..7fde317b1 100644 --- a/syncopy/datatype/discrete_data.py +++ b/syncopy/datatype/discrete_data.py @@ -366,19 +366,20 @@ def channel(self): def channel(self, chan): if chan is None and self.data is not None: - raise SPYValueError("Cannot set `channel` to `None` with existing data.") + raise SPYValueError("channel labels, cannot set `channel` to `None` with existing data.") elif self.data is None and chan is not None: - raise SPYValueError("Cannot assign `channel` without data. " + + raise SPYValueError(f"non-empty SpikeData", "cannot assign `channel` without data. 
" + "Please assign data first") elif chan is None: self._channel = chan return - nChan = np.max(self.data[:, self.dimord.index("channel")]) + 1 - try: - array_parser(chan, varname="channel", ntype="str", dims=(nChan, )) - except Exception as exc: - raise exc + # we need at least as many labels as there are distinct channels + nChan_min = np.unique(self.data[:, self.dimord.index("channel")]).size + + if nChan_min > len(chan): + raise SPYValueError(f"at least {nChan_min} labels") + array_parser(chan, varname="channel", ntype="str") self._channel = chan @@ -418,10 +419,8 @@ def unit(self, unit): return nunit = np.unique(self.data[:, self.dimord.index("unit")]).size - try: - array_parser(unit, varname="unit", ntype="str", dims=(nunit,)) - except Exception as exc: - raise exc + array_parser(unit, varname="unit", ntype="str", dims=(nunit,)) + self._unit = np.array(unit) def _get_default_unit(self): From b256df67edf90f9318b719bd564bd305fd410f69 Mon Sep 17 00:00:00 2001 From: kajal5888 Date: Mon, 9 Jan 2023 16:53:21 +0100 Subject: [PATCH 049/135] CSD estimation implemented as the output using coherence method --- syncopy/nwanalysis/connectivity_analysis.py | 72 +++++++++------------ 1 file changed, 31 insertions(+), 41 deletions(-) diff --git a/syncopy/nwanalysis/connectivity_analysis.py b/syncopy/nwanalysis/connectivity_analysis.py index 98ab2b302..af270ed7e 100644 --- a/syncopy/nwanalysis/connectivity_analysis.py +++ b/syncopy/nwanalysis/connectivity_analysis.py @@ -4,19 +4,8 @@ # # Builtin/3rd party package imports -import numpy as np -'Testing' - -# Syncopy imports -from syncopy.shared.parsers import data_parser, scalar_parser -from syncopy.shared.tools import get_defaults, best_match, get_frontend_cfg -from syncopy.datatype import CrossSpectralData, AnalogData, SpectralData -from syncopy.shared.errors import ( - SPYValueError, - SPYWarning, - SPYInfo) -from syncopy.shared.kwarg_decorators import (unwrap_cfg, unwrap_select, - detect_parallel_client) +from syncopy.nwanalysis.AV_compRoutines import NormalizeCrossSpectra, NormalizeCrossCov, GrangerCausality +from syncopy.nwanalysis.ST_compRoutines import CrossSpectra, CrossCovariance, SpectralDyadicProduct from syncopy.shared.input_processors import ( process_taper, process_foi, @@ -24,13 +13,21 @@ check_effective_parameters, check_passed_kwargs ) - -from syncopy.nwanalysis.ST_compRoutines import CrossSpectra, CrossCovariance, SpectralDyadicProduct -from syncopy.nwanalysis.AV_compRoutines import NormalizeCrossSpectra, NormalizeCrossCov, GrangerCausality +from syncopy.shared.kwarg_decorators import (unwrap_cfg, unwrap_select, + detect_parallel_client) +from syncopy.shared.errors import ( + SPYValueError, + SPYWarning, + SPYInfo) +from syncopy.datatype import CrossSpectralData, AnalogData, SpectralData +from syncopy.shared.tools import get_defaults, best_match, get_frontend_cfg +from syncopy.shared.parsers import data_parser, scalar_parser +import numpy as np +# Syncopy imports availableMethods = ("coh", "corr", "granger") -coh_outputs = {"abs", "pow", "complex", "fourier", "angle", "real", "imag"} +coh_outputs = {"abs", "pow", "complex", "fourier", "angle", "real", "imag", "csd"} @unwrap_cfg @@ -40,7 +37,6 @@ def connectivityanalysis(data, method="coh", keeptrials=False, output="abs", foi=None, foilim=None, pad='maxperlen', polyremoval=0, tapsmofrq=None, nTaper=None, taper="hann", taper_opt=None, **kwargs): - """ Perform connectivity analysis of Syncopy :class:`~syncopy.SpectralData` OR directly :class:`~syncopy.AnalogData` objects @@ -202,7 
+198,6 @@ def connectivityanalysis(data, method="coh", keeptrials=False, output="abs", act = f"{data.__class__.__name__}" raise SPYValueError(lgl, 'data', act) timeAxis = data.dimord.index("time") - # Get everything of interest in local namespace defaults = get_defaults(connectivityanalysis) lcls = locals() @@ -255,7 +250,6 @@ def connectivityanalysis(data, method="coh", keeptrials=False, output="abs", # --- method specific processing --- if method == 'corr': - if not isinstance(data, AnalogData): lgl = f"AnalogData instance as input for method {method}" actual = f"{data.__class__.__name__}" @@ -309,7 +303,6 @@ def connectivityanalysis(data, method="coh", keeptrials=False, output="abs", polyremoval, log_dict, timeAxis) # SpectralData input elif isinstance(data, SpectralData): - # cross-spectra need complex input spectra if data.data.dtype != np.complex64 and data.data.dtype != np.complex128: lgl = "complex valued spectra, set `output='fourier` in spy.freqanalysis!" @@ -335,7 +328,6 @@ def connectivityanalysis(data, method="coh", keeptrials=False, output="abs", # --- Set up of computation of trial-averaged CSDs is complete --- if method == 'coh': - if output not in coh_outputs: lgl = f"one of {coh_outputs}" raise SPYValueError(lgl, varname="output", actual=output) @@ -345,7 +337,6 @@ def connectivityanalysis(data, method="coh", keeptrials=False, output="abs", av_compRoutine = NormalizeCrossSpectra(output=output) if method == 'granger': - # spectral analysis only possible with AnalogData besides = ['tapsmofrq'] if isinstance(data, AnalogData) else None check_effective_parameters(GrangerCausality, defaults, lcls, besides=besides) @@ -357,7 +348,6 @@ def connectivityanalysis(data, method="coh", keeptrials=False, output="abs", cond_max=1e4 ) - # ------------------------------------------------- # Call the chosen single trial ComputationalRoutine # ------------------------------------------------- @@ -379,33 +369,33 @@ def connectivityanalysis(data, method="coh", keeptrials=False, output="abs", # ---------------------------------------------------------------------------------- # Sanitize output and call the chosen ComputationalRoutine on the averaged ST output # ---------------------------------------------------------------------------------- - - out = CrossSpectralData(dimord=st_dimord) - - # now take the trial average from the single trial CR as input - av_compRoutine.initialize(st_out, out._stackingDim, chan_per_worker=None) - av_compRoutine.pre_check() # make sure we got a trial_average - av_compRoutine.compute(st_out, out, parallel=kwargs.get("parallel"), - log_dict=log_dict) - - # attach potential older cfg's from the input - # to support chained frontend calls.. - out.cfg.update(data.cfg) - # attach frontend parameters for replay - out.cfg.update({'connectivityanalysis': new_cfg}) - return out + if output == 'csd': + st_out.cfg.update(data.cfg) + st_out.cfg.update({'cross_spectral': new_cfg}) + return st_out + else: + out = CrossSpectralData(dimord=st_dimord) + # now take the trial average from the single trial CR as input + av_compRoutine.initialize(st_out, out._stackingDim, chan_per_worker=None) + av_compRoutine.pre_check() # make sure we got a trial_average + av_compRoutine.compute(st_out, out, parallel=kwargs.get("parallel"), + log_dict=log_dict) + # attach potential older cfg's from the input + # to support chained frontend calls.. 
+        out.cfg.update(data.cfg)
+        # attach frontend parameters for replay
+        out.cfg.update({'connectivityanalysis': new_cfg})
+        return out
 
 
 def cross_spectra(data, method, nSamples,
                   foi, foilim, tapsmofrq, nTaper, taper, taper_opt,
                   polyremoval, log_dict, timeAxis):
-
     '''
     Calculates the single trial cross-spectral densities
     from AnalogData
     '''
-
     # --- Basic foi sanitization ---
 
     foi, foilim = process_foi(foi, foilim, data.samplerate)

From f1c075506b26e9431ae99a4cb5f8431b99991599 Mon Sep 17 00:00:00 2001
From: kajal5888
Date: Tue, 10 Jan 2023 12:11:53 +0100
Subject: [PATCH 050/135] CSD estimation implemented as a method of
 connectivity, cfg output corrected for methods other than coh

---
 syncopy/nwanalysis/connectivity_analysis.py | 34 ++++++++++++++-------
 1 file changed, 23 insertions(+), 11 deletions(-)

diff --git a/syncopy/nwanalysis/connectivity_analysis.py b/syncopy/nwanalysis/connectivity_analysis.py
index af270ed7e..d9ac948fd 100644
--- a/syncopy/nwanalysis/connectivity_analysis.py
+++ b/syncopy/nwanalysis/connectivity_analysis.py
@@ -26,8 +26,8 @@
 # Syncopy imports
 
-availableMethods = ("coh", "corr", "granger")
-coh_outputs = {"abs", "pow", "complex", "fourier", "angle", "real", "imag", "csd"}
+availableMethods = ("coh", "corr", "granger", "csd")
+connectivity_outputs = {"abs", "pow", "complex", "fourier", "angle", "real", "imag"}
 
 
 @unwrap_cfg
@@ -69,6 +69,18 @@ def connectivityanalysis(data, method="coh", keeptrials=False, output="abs",
         * **nTaper** : (optional) number of orthogonal tapers for slepian tapers
         * **pad**: either pad to an absolute length in seconds or set to `'nextpow2'`
 
+    "csd" : (Multi-) tapered cross spectral density estimate
+        Computes the normalized cross spectral densities between all channel combinations
+
+        output : complex spectrum
+
+        **Spectral analysis** (input is :class:`~syncopy.AnalogData`):
+
+        * **taper** : one of :data:`~syncopy.shared.const_def.availableTapers`
+        * **tapsmofrq** : spectral smoothing box for slepian tapers (in Hz)
+        * **nTaper** : (optional) number of orthogonal tapers for slepian tapers
+        * **pad**: either pad to an absolute length in seconds or set to `'nextpow2'`
+
     "corr" : Cross-correlations
         Computes the one sided (positive lags) cross-correlations
         between all channel combinations of :class:`~syncopy.AnalogData`.
@@ -104,7 +116,7 @@ def connectivityanalysis(data, method="coh", keeptrials=False, output="abs",
         A non-empty Syncopy :class:`~syncopy.SpectralData` or
        :class:`~syncopy.AnalogData` object
     method : str
-        Connectivity estimation method, one of ``'coh'`, 'corr', 'granger'``
+        Connectivity estimation method, one of ``'coh'`, 'corr', 'granger', 'csd'``
     output : str
         Relevant for cross-spectral density estimation (``method='coh'``)
         Use ``'pow'`` for absolute squared coherence, ``'abs'`` for absolute value of coherence
@@ -277,8 +289,7 @@ def connectivityanalysis(data, method="coh", keeptrials=False, output="abs",
         # hard coded as class attribute
         st_dimord = CrossCovariance.dimord
 
-    elif method in ['coh', 'granger']:
-
+    elif method in ['coh', 'granger', 'csd']:
         nTrials = len(data.trials)
         if nTrials == 1:
             lgl = "multi-trial input data, spectral connectivity measures critically depend on trial averaging!"
@@ -327,9 +338,9 @@ def connectivityanalysis(data, method="coh", keeptrials=False, output="abs",
 
     # --- Set up of computation of trial-averaged CSDs is complete ---
 
-    if method == 'coh':
-        if output not in coh_outputs:
-            lgl = f"one of {coh_outputs}"
+    if method in ('coh', 'csd'):
+        if output not in connectivity_outputs:
+            lgl = f"one of {connectivity_outputs}"
             raise SPYValueError(lgl, varname="output", actual=output)
         log_dict['output'] = output
 
@@ -369,7 +380,8 @@ def connectivityanalysis(data, method="coh", keeptrials=False, output="abs",
     # ----------------------------------------------------------------------------------
     # Sanitize output and call the chosen ComputationalRoutine on the averaged ST output
     # ----------------------------------------------------------------------------------
-    if output == 'csd':
+    if method == 'csd':
+        new_cfg.update({'output': st_out.data.dtype.name})
         st_out.cfg.update(data.cfg)
         st_out.cfg.update({'cross_spectral': new_cfg})
         return st_out
@@ -384,6 +396,7 @@ def connectivityanalysis(data, method="coh", keeptrials=False, output="abs",
         # to support chained frontend calls..
         out.cfg.update(data.cfg)
         # attach frontend parameters for replay
+        new_cfg.update({'output': out.data.dtype.name if method != 'coh' else output})
         out.cfg.update({'connectivityanalysis': new_cfg})
         return out
 
@@ -415,8 +428,7 @@ def cross_spectra(data, method, nSamples,
         msg = "Multi-channel Granger analysis can be numerically unstable, it is recommended to have at least 10 times the number of trials compared to the number of channels. Try calculating in sub-groups of fewer channels!"
         SPYWarning(msg)
 
-    if method in ['coh', 'granger']:
-
+    if method in ['coh', 'granger', 'csd']:
         # --- set up computation of the single trial CSDs ---
 
         # Construct array of maximally attainable frequencies
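For illustration, a minimal sketch of the new 'csd' code path introduced above (the toy data and parameter values are placeholders):

    import numpy as np
    import syncopy as spy

    # toy input: 20 trials of white noise, 2 channels each
    rng = np.random.default_rng(42)
    adata = spy.AnalogData([rng.standard_normal((1000, 2)) for _ in range(20)], samplerate=500)

    # complex cross-spectral densities (trial-averaged with the default keeptrials=False)
    csd = spy.connectivityanalysis(adata, method='csd')

    # normalized coherence for comparison
    coh = spy.connectivityanalysis(adata, method='coh', output='pow')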
From 2f2e964b0c87008170da18a0c143a0b1f2582eb0 Mon Sep 17 00:00:00 2001
From: Tim Schaefer
Date: Tue, 10 Jan 2023 12:45:41 +0100
Subject: [PATCH 051/135] NEW: work on logging doc

---
 doc/source/developer/logging.rst | 27 +++++++++++++++++++++++----
 1 file changed, 23 insertions(+), 4 deletions(-)

diff --git a/doc/source/developer/logging.rst b/doc/source/developer/logging.rst
index 871022eca..a5e0b56b0 100644
--- a/doc/source/developer/logging.rst
+++ b/doc/source/developer/logging.rst
@@ -3,9 +3,16 @@
 Controlling Logging in Syncopy
 ===============================
 
-Syncopy uses the `Python logging module <https://docs.python.org/3/library/logging.html>`_ for logging, and logs to a logger named `'syncopy'` which is handled by the console.
+Syncopy uses the `Python logging module <https://docs.python.org/3/library/logging.html>`_ for logging. It uses two different loggers:
+one for code that runs on the local machine, and another one for logging the parallelized code that
+is run by the remote workers in a high performance computing (HPC) cluster environment.
 
-To adapt the logging behaviour of Syncopy, one can configure the logger as explained in the documentation for the logging module. E.g.:
+Logging code that runs locally
+-------------------------------
+
+For all code that is run on the local machine, Syncopy logs to a logger named `'syncopy'` which is handled by the console.
+
+To adapt the local logging behaviour of Syncopy, one can configure the logger as explained in the documentation for the logging module, e.g., in your application that uses Syncopy:
 
 ..
code-block:: python @@ -17,10 +24,22 @@ To adapt the logging behaviour of Syncopy, one can configure the logger as expla # Change the log level: logger.setLevel(logging.DEBUG) - # Make it log to a file instead of the console: + # Make it log to a file: fh = logging.FileHandler('syncopy_log_within_my_app.log') logger.addHandler(fh) + # The rest of your application code goes here. + + +Logging code that potentially runs remotely +-------------------------------------------- + +The parallel code that performs the heavy lifting on the Syncopy data will be executed on remote machines (cluster nodes) when Syncopy is run in an HPC environment. Therefore, +special handling is required for these parts of the code, and we need to log to one log file per remote machine to avoid race conditions and + + +Log levels +----------- -The default log level is for the Syncopy logger is `'WARNING'`. To change the log level, you can either use the logging API in your application code as explained above, or set the environment variable `'SPYLOGLEVEL'` to one of the values supported by the logging module, e.g., 'CRITICAL', 'WARNING', 'INFO', or 'DEBUG'. See the `official docs `_ for details on the supported log levels. +The default log level is for the Syncopy logger is `'logging.WARNING'` (from now on referred to as `'WARNING'`). This means that you will not see any Syncopy messages below that threshold, i.e., messages printed with log levels `'DEBUG'` and `'INFO'`. To change the log level, you can either use the logging API in your application code as explained above, or set the environment variable `'SPYLOGLEVEL'` to one of the values supported by the logging module, e.g., 'CRITICAL', 'WARNING', 'INFO', or 'DEBUG'. See the `official docs of the logging module `_ for details on the supported log levels. From 8774df33b27e6f658e84adbfdb040f3e74a557ed Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Tue, 10 Jan 2023 12:46:34 +0100 Subject: [PATCH 052/135] NEW: explain where users can find logdir value --- syncopy/shared/log.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/syncopy/shared/log.py b/syncopy/shared/log.py index eae53d4d9..8042fb295 100644 --- a/syncopy/shared/log.py +++ b/syncopy/shared/log.py @@ -26,6 +26,8 @@ def get_parallel_logger(): Get a logger for stuff that is run in parallel. Logs to a machine-specific file in the SPYLOGDIR by default. To be used in computational routines. + + The log directory used is `syncopy.__logdir__`. It can be changed by setting the environment variable SPYLOGDIR before running an application that uses Syncopy. """ host = socket.gethostname() return logging.getLogger(loggername + "_" + host) From 83455e70a94369c900e1688033c0766ac6565751 Mon Sep 17 00:00:00 2001 From: kajal5888 Date: Tue, 10 Jan 2023 14:28:42 +0100 Subject: [PATCH 053/135] CSD keep trials option implemented --- syncopy/nwanalysis/connectivity_analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/syncopy/nwanalysis/connectivity_analysis.py b/syncopy/nwanalysis/connectivity_analysis.py index d9ac948fd..31672c080 100644 --- a/syncopy/nwanalysis/connectivity_analysis.py +++ b/syncopy/nwanalysis/connectivity_analysis.py @@ -296,7 +296,7 @@ def connectivityanalysis(data, method="coh", keeptrials=False, output="abs", act = "only one trial" raise SPYValueError(lgl, 'data', act) - if keeptrials is not False: + if keeptrials is not False and method in ('coh', 'granger'): lgl = "False, trial averaging needed for 'coh' and 'granger'!" 
         act = keeptrials
         raise SPYValueError(lgl, varname="keeptrials", actual=act)

From 0c72cfe934a211d8683c43d730908e2eb740ffa0 Mon Sep 17 00:00:00 2001
From: tensionhead
Date: Tue, 10 Jan 2023 16:45:58 +0100
Subject: [PATCH 054/135] CHG: Further streamline channel/unit lookup

- now also selections are very fast

Changes to be committed:
	modified:   syncopy/datatype/discrete_data.py

---
 syncopy/datatype/discrete_data.py | 110 +++++++++++++++++-------------
 1 file changed, 64 insertions(+), 46 deletions(-)

diff --git a/syncopy/datatype/discrete_data.py b/syncopy/datatype/discrete_data.py
index 7fde317b1..c3ab8e737 100644
--- a/syncopy/datatype/discrete_data.py
+++ b/syncopy/datatype/discrete_data.py
@@ -14,7 +14,7 @@
 from .base_data import BaseData, Indexer, FauxTrial
 from .methods.definetrial import definetrial
 from syncopy.shared.parsers import scalar_parser, array_parser
-from syncopy.shared.errors import SPYValueError
+from syncopy.shared.errors import SPYValueError, SPYError
 from syncopy.shared.tools import best_match
 
 __all__ = ["SpikeData", "EventData"]
@@ -154,8 +154,8 @@ def trialid(self, trlid):
             return
 
         if self.data is None:
-            print("SyNCoPy core - trialid: Cannot assign `trialid` without data. " +
-                  "Please assing data first")
+            SPYError("SyNCoPy core - trialid: Cannot assign `trialid` without data. " +
+                     "Please assign data first")
             return
         scount = np.nanmax(self.data[:, self.dimord.index("sample")])
         try:
@@ -335,26 +335,14 @@ class SpikeData(DiscreteData):
     _stackingDimLabel = "sample"
     _selectionKeyWords = DiscreteData._selectionKeyWords + ('channel', 'unit',)
 
-    @property
-    def data(self):
-        """
-        HDF5 dataset representing discrete spike data.
+    def _compute_unique(self):
 
-        Trials are concatenated along the time axis.
-        """
+        if self.data is None:
+            return
 
-        if getattr(self._data, "id", None) is not None:
-            if self._data.id.valid == 0:
-                lgl = "open HDF5 file"
-                act = "backing HDF5 file {} has been closed"
-                raise SPYValueError(legal=lgl, actual=act.format(self.filename),
-                                    varname="data")
-            return self._data
-
-    @data.setter
-    def data(self, inData):
-        # this comes from BaseData
-        self._set_dataset_property(inData, "data")
+        # this is costly
+        self.channel_idx = np.unique(self.data[:, self.dimord.index("channel")])
+        self.unit_idx = np.unique(self.data[:, self.dimord.index("unit")])
 
     @property
     def channel(self):
@@ -364,26 +352,32 @@ def channel(self):
 
     @channel.setter
     def channel(self, chan):
-
-        if chan is None and self.data is not None:
-            raise SPYValueError("channel labels, cannot set `channel` to `None` with existing data.")
-        elif self.data is None and chan is not None:
-            raise SPYValueError(f"non-empty SpikeData", "cannot assign `channel` without data. " +
-                                "Please assign data first")
-        elif chan is None:
+        if self.data is None:
+            if chan is not None:
+                raise SPYValueError(f"non-empty SpikeData", "cannot assign `channel` without data. " +
+                                    "Please assign data first")
+            # empty labels for empty data is fine
+            self._channel = chan
+            return
+
+        # there is data
+        else:
+            if chan is None:
+                raise SPYValueError("channel labels, cannot set `channel` to `None` with existing data.")
+
+            # we have data and new labels
+            if self.channel_idx is None:
+                self._compute_unique()
 
-        # we need at least as many labels as there are distinct channels
-        nChan_min = np.unique(self.data[:, self.dimord.index("channel")]).size
+        # we need as many labels as there are distinct channels
+        nChan = self.channel_idx.size
 
-        if nChan_min > len(chan):
-            raise SPYValueError(f"at least {nChan_min} labels")
-        array_parser(chan, varname="channel", ntype="str")
+        if nChan != len(chan):
+            raise SPYValueError(f"exactly {nChan} channel labels")
+        array_parser(chan, varname="channel", ntype="str", dims=(nChan, ))
         self._channel = chan
 
-    def _get_default_channel(self):
+    def _default_channel_labels(self):
 
         """
         Creates the default channel labels
         """
 
         if self.data is not None:
             # channel entries in self.data are 0-based
-            nChan = np.max(self.data[:, self.dimord.index("channel")])
-            channel_arr = np.arange(nChan + 1)
-            channel_labels = np.array(["channel" + str(int(i + 1)).zfill(len(str(nChan)) + 1)
-                                       for i in channel_arr])
+            chan_max = self.channel_idx.max()
+            channel_labels = np.array(["channel" + str(int(i + 1)).zfill(len(str(chan_max)) + 1)
+                                       for i in self.channel_idx])
+            return channel_labels
 
-        return channel_labels
+        else:
+            return None
 
     @property
     def unit(self):
@@ -408,6 +401,22 @@ def unit(self):
 
     @unit.setter
     def unit(self, unit):
+        if self.data is None:
+            if unit is not None:
+                raise SPYValueError(f"non-empty SpikeData", "cannot assign `unit` without data. " +
+                                    "Please assign data first")
+            # empty labels for empty data is fine
+            self._unit = unit
+            return
+
+        # there is data
+        else:
+            if unit is None:
+                raise SPYValueError("unit labels, cannot set `unit` to `None` with existing data.")
+
+            # we have data and new labels
+            if self.unit_idx is None:
+                self._compute_unique()
 
         if unit is None and self.data is not None:
             raise SPYValueError("Cannot set `unit` to `None` with existing data.")
@@ -418,21 +427,23 @@ def unit(self, unit):
             self._unit = None
             return
 
-        nunit = np.unique(self.data[:, self.dimord.index("unit")]).size
+        nunit = self.unit_idx.size
+        if nunit != len(unit):
+            raise SPYValueError(f"exactly {nunit} unit labels")
         array_parser(unit, varname="unit", ntype="str", dims=(nunit,))
+
         self._unit = np.array(unit)
 
-    def _get_default_unit(self):
+    def _default_unit_labels(self):
 
         """
         Creates the default unit labels
         """
 
         if self.data is not None:
-            unitIndices = np.unique(self.data[:, self.dimord.index("unit")])
-            return np.array(["unit" + str(int(i)).zfill(len(str(unitIndices.max())))
-                             for i in unitIndices])
+            unit_max = self.unit_idx.max()
+            return np.array(["unit" + str(int(i)).zfill(len(str(unit_max)))
+                             for i in self.unit_idx])
         else:
             return None
 
     # Helper function that extracts by-trial unit-indices
     def _get_unit(self, trials, units=None):
         """
@@ -528,7 +539,9 @@ def __init__(self,
         self._hdfFileAttributeProperties = DiscreteData._hdfFileAttributeProperties + ("channel", "unit")
 
         self._unit = None
+        self.unit_idx = None
         self._channel = None
+        self.channel_idx = None
 
         # Call parent initializer
         super().__init__(data=data,
@@ -537,16 +550,21 @@ def __init__(self,
                          samplerate=samplerate,
                          dimord=dimord)
 
+        # for fast lookup and labels
+        self._compute_unique()
+
+        # use the setters to assign initial labels,
         if channel is not None:
+            # this rightfully fails for empty data
             self.channel = channel
         else:
-            self.channel = self._get_default_channel()
+            # sets 
to None if no data + self.channel = self._default_channel_labels() if unit is not None: self.unit = unit else: - self.unit = self._get_default_unit() + self.unit = self._default_unit_labels() class EventData(DiscreteData): From 048db30eb035cd30e6666e4185fb40f9187d8374 Mon Sep 17 00:00:00 2001 From: tensionhead Date: Tue, 10 Jan 2023 17:18:53 +0100 Subject: [PATCH 055/135] FIX: channel property must be array Changes to be committed: modified: syncopy/datatype/discrete_data.py --- syncopy/datatype/discrete_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/syncopy/datatype/discrete_data.py b/syncopy/datatype/discrete_data.py index c3ab8e737..234d83b00 100644 --- a/syncopy/datatype/discrete_data.py +++ b/syncopy/datatype/discrete_data.py @@ -375,7 +375,7 @@ def channel(self, chan): if nChan != len(chan): raise SPYValueError(f"exactly {nChan} channel labels") array_parser(chan, varname="channel", ntype="str", dims=(nChan, )) - self._channel = chan + self._channel = np.array(chan) def _default_channel_labels(self): From 6427b35fe0e4fa7e185e92adcc1fa6f270775c4a Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Wed, 11 Jan 2023 11:08:24 +0100 Subject: [PATCH 056/135] NEW: add more debug log messages in backend funcs --- syncopy/specest/mtmconvol.py | 5 +++++ syncopy/specest/mtmfft.py | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/syncopy/specest/mtmconvol.py b/syncopy/specest/mtmconvol.py index 923666782..46bc6bc2f 100644 --- a/syncopy/specest/mtmconvol.py +++ b/syncopy/specest/mtmconvol.py @@ -5,6 +5,8 @@ # Builtin/3rd party package imports import numpy as np +import logging +import platform from scipy import signal # local imports @@ -114,6 +116,9 @@ def mtmconvol(data_arr, samplerate, nperseg, noverlap=None, taper="hann", # Short time Fourier transforms (nTime x nTapers x nFreq x nChannels) ftr = np.zeros((nTime, windows.shape[0], nFreq, nChannels), dtype='complex64') + logger = logging.getLogger("syncopy_" + platform.node()) + logger.debug(f"Running mtmconvol on {len(windows)} windows, data chunk has {nSamples} samples and {nChannels} channels.") + for taperIdx, win in enumerate(windows): # ftr has shape (nFreq, nChannels, nTime) pxx, _, _ = stft(data_arr, samplerate, window=win, diff --git a/syncopy/specest/mtmfft.py b/syncopy/specest/mtmfft.py index c729d809f..27490e0bd 100644 --- a/syncopy/specest/mtmfft.py +++ b/syncopy/specest/mtmfft.py @@ -6,6 +6,8 @@ # Builtin/3rd party package imports import numpy as np from scipy import signal +import logging +import platform # local imports from ._norm_spec import _norm_spec, _norm_taper @@ -95,6 +97,9 @@ def mtmfft(data_arr, # Fourier transforms (nTapers x nFreq x nChannels) ftr = np.zeros((windows.shape[0], nFreq, nChannels), dtype='complex64') + logger = logging.getLogger("syncopy_" + platform.node()) + logger.debug(f"Running mtmfft on {len(windows)} windows, data chunk has {nSamples} samples and {nChannels} channels.") + for taperIdx, win in enumerate(windows): win = np.tile(win, (nChannels, 1)).T win *= data_arr From aba3965778789920dcab77049ce32e1996a1c018 Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Wed, 11 Jan 2023 13:20:02 +0100 Subject: [PATCH 057/135] NEW: add more debug information --- syncopy/specest/stft.py | 6 ++++++ syncopy/specest/superlet.py | 11 +++++++++-- syncopy/specest/wavelet.py | 5 +++++ 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/syncopy/specest/stft.py b/syncopy/specest/stft.py index 93eb00ce4..c589af876 100644 --- a/syncopy/specest/stft.py +++ 
b/syncopy/specest/stft.py @@ -6,6 +6,8 @@ # Builtin/3rd party package imports import numpy as np import scipy.signal as sci_sig +import logging +import platform # local imports from ._norm_spec import _norm_spec @@ -132,6 +134,10 @@ def stft(dat, # Apply window by multiplication dat = dat * window + logger = logging.getLogger("syncopy_" + platform.node()) + pad_status = "with padding" if padded else "without padding" + logger.debug(f"Running short time Fourier transform {pad_status}, detrend={detrend} and overlap of {noverlap}.") + times = np.arange(nperseg / 2, dat.shape[-1] - nperseg / 2 + 1, nperseg - noverlap) / fs if boundary is not None: diff --git a/syncopy/specest/superlet.py b/syncopy/specest/superlet.py index 1b69061be..ea15bf097 100644 --- a/syncopy/specest/superlet.py +++ b/syncopy/specest/superlet.py @@ -7,6 +7,8 @@ # Builtin/3rd party package imports import numpy as np +import logging +import platform from scipy.signal import fftconvolve @@ -24,7 +26,7 @@ def superlet( Performs Superlet Transform (SLT) according to Moca et al. [1]_ Both multiplicative SLT and fractional adaptive SLT are available. The former is recommended for a narrow frequency band of interest, - whereas the is better suited for the analysis of a broad range + whereas the latter is better suited for the analysis of a broad range of frequencies. A superlet (SL) is a set of Morlet wavelets with increasing number @@ -61,7 +63,7 @@ def superlet( than 3 increase `order_min` as to never have less than 3 cycles in a wavelet! adaptive : bool - Wether to perform multiplicative SLT or fractional adaptive SLT. + Whether to perform fractional adaptive SLT or multiplicative SLT. If set to True, the order of the wavelet set will increase linearly with the frequencies of interest from `order_min` to `order_max`. 
If set to False the same SL will be used for @@ -80,10 +82,13 @@ def superlet( """ + logger = logging.getLogger("syncopy_" + platform.node()) # adaptive SLT if adaptive: + logger.debug(f"Running fractional adaptive superlet transform with order_min={order_min}, order_max={order_max} and c_1={c_1} on data with shape {data_arr.shape}.") + gmean_spec = FASLT(data_arr, samplerate, scales, @@ -94,6 +99,8 @@ def superlet( # multiplicative SLT else: + logger.debug(f"Running multiplicative superlet transform with order_min={order_min}, order_max={order_max} and c_1={c_1} on data with shape {data_arr.shape}.") + gmean_spec = multiplicativeSLT(data_arr, samplerate, scales, diff --git a/syncopy/specest/wavelet.py b/syncopy/specest/wavelet.py index 8f14f5837..a2a20f40d 100644 --- a/syncopy/specest/wavelet.py +++ b/syncopy/specest/wavelet.py @@ -5,6 +5,8 @@ # Builtin/3rd party package imports import numpy as np +import logging +import platform # Local imports from syncopy.specest.wavelets import cwt @@ -37,6 +39,9 @@ def wavelet(data_arr, samplerate, scales, wavelet): Shape is (len(scales),) + data_arr.shape """ + logger = logging.getLogger("syncopy_" + platform.node()) + logger.debug(f"Running wavelet transform on data with shape {data_arr.shape} and samplerate {samplerate}.") + spec = cwt(data_arr, wavelet=wavelet, widths=scales, dt=1 / samplerate, axis=0) return spec From 3f5a3ea72ffb5cdc248c8f199807f1807293f8e5 Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Wed, 11 Jan 2023 13:51:11 +0100 Subject: [PATCH 058/135] NEW: add more debug messages in backend funcs --- syncopy/statistics/psth.py | 5 +++++ syncopy/statistics/summary_stats.py | 10 +++++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/syncopy/statistics/psth.py b/syncopy/statistics/psth.py index 93d93839e..e70399eaa 100644 --- a/syncopy/statistics/psth.py +++ b/syncopy/statistics/psth.py @@ -1,4 +1,6 @@ import numpy as np +import logging +import platform from scipy.stats import iqr @@ -62,6 +64,9 @@ def psth(trl_dat, channels = trl_dat[:, 1] units = trl_dat[:, 2] + logger = logging.getLogger("syncopy_" + platform.node()) + logger.debug(f"Computing peristimulus time histogram (PSTH) on data with {samples.size} samples, {channels.size} channels, {units.size} units and samplerate {samplerate}.") + # get relative spike times for all events in trial times = _calc_time(samples, trl_start, onset, samplerate) diff --git a/syncopy/statistics/summary_stats.py b/syncopy/statistics/summary_stats.py index ae4ab6c5d..db2e4234f 100644 --- a/syncopy/statistics/summary_stats.py +++ b/syncopy/statistics/summary_stats.py @@ -6,6 +6,8 @@ # Builtin/3rd party package imports import numpy as np +import logging +import platform # Local imports # from .selectdata import _get_selection_size @@ -74,7 +76,7 @@ def std(spy_data, dim, keeptrials=True, **kwargs): Must be present in the ``spy_data`` object, e.g. 'channel' or 'trials' keeptrials : bool - Set to ``False`` to trigger additional trial averagin + Set to ``False`` to trigger additional trial averaging. Has no effect if ``dim='trials'``. 
Returns @@ -206,6 +208,9 @@ def itc(spec_data, **kwargs): act = "real valued spectral data" raise SPYValueError(lgl, 'spec_data', act) + logger = logging.getLogger("syncopy_" + platform.node()) + logger.debug(f"Computing intertrial coherence on data chunk with shape {spec_data.shape}.") + # takes care of remaining checks res = _trial_statistics(spec_data, operation='itc') @@ -260,6 +265,9 @@ def _statistics(spy_data, operation, dim, keeptrials=True, **kwargs): 'dim': dim, 'keeptrials': keeptrials} + logger = logging.getLogger("syncopy_" + platform.node()) + logger.debug(f"Computing descriptive statistic {operation} on input from {spy_data.filename} along dimension {dim}, keeptrials={keeptrials}.") + # If no active selection is present, create a "fake" all-to-all selection # to harmonize processing down the road (and attach `_cleanup` attribute for later removal) if spy_data.selection is None: From 2c711f86197a04dd05f38b77bb96b55f2aa14d1b Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Wed, 11 Jan 2023 14:10:49 +0100 Subject: [PATCH 059/135] NEW: add more debug messages in backend funcs --- syncopy/preproc/compRoutines.py | 18 +++++++++++++++++- syncopy/statistics/spike_psth.py | 4 +++- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/syncopy/preproc/compRoutines.py b/syncopy/preproc/compRoutines.py index a596317bb..3cc66f84d 100644 --- a/syncopy/preproc/compRoutines.py +++ b/syncopy/preproc/compRoutines.py @@ -7,6 +7,7 @@ # Builtin/3rd party package imports import numpy as np import scipy.signal as sci +import logging, platform from inspect import signature # syncopy imports @@ -391,6 +392,9 @@ def hilbert_cF(dat, output='abs', timeAxis=0, noCompute=False, chunkShape=None): if noCompute: return outShape, fmt + logger = logging.getLogger("syncopy_" + platform.node()) + logger.debug(f"Computing Hilbert transformation on data chunk with shape {dat.shape} along axis 0.") + trafo = sci.hilbert(dat, axis=0) return spectralConversions[output](trafo) @@ -473,6 +477,9 @@ def downsample_cF(dat, outShape[0] = int(np.ceil(dat.shape[0] / skipped)) return tuple(outShape), dat.dtype + logger = logging.getLogger("syncopy_" + platform.node()) + logger.debug(f"Downsampling data chunk with shape {dat.shape} from samplerate {samplerate} to {new_samplerate}.") + resampled = downsample(dat, samplerate, new_samplerate) return resampled @@ -583,6 +590,9 @@ def resample_cF(dat, new_nSamples = int(np.ceil(nSamples * fs_ratio)) return (new_nSamples, dat.shape[1]), dat.dtype + logger = logging.getLogger("syncopy_" + platform.node()) + logger.debug(f"Resampling data chunk with shape {dat.shape} from samplerate {samplerate} to {new_samplerate} with lpfreq={lpfreq}, order={order}.") + resampled = resample(dat, samplerate, new_samplerate, @@ -637,7 +647,7 @@ def detrending_cF(dat, polyremoval=None, timeAxis=0, noCompute=False, chunkShape """ Simple cF to wire SciPy's `detrend` to our CRs, - supported are constant and linear detrending + supported are constant and linear detrending. 
    Parameters
    ----------
@@ -688,6 +698,9 @@
     if noCompute:
         return outShape, np.float32
 
+    logger = logging.getLogger("syncopy_" + platform.node())
+    logger.debug(f"Detrending data chunk with shape {dat.shape} with polyremoval={polyremoval}.")
+
     # detrend
     if polyremoval == 0:
         dat = sci.detrend(dat, type='constant', axis=0, overwrite_data=True)
@@ -782,6 +795,9 @@ def standardize_cF(dat, polyremoval=None, timeAxis=0, noCompute=False, chunkShap
     elif polyremoval == 1:
         dat = sci.detrend(dat, type='linear', axis=0, overwrite_data=True)
 
+    logger = logging.getLogger("syncopy_" + platform.node())
+    logger.debug(f"Standardizing data chunk with shape {dat.shape} (prior polyremoval was {polyremoval}).")
+
     # standardize
     dat = (dat - np.mean(dat, axis=0)) / np.std(dat, axis=0)
 
diff --git a/syncopy/statistics/spike_psth.py b/syncopy/statistics/spike_psth.py
index f6b6a986f..04b8ee797 100644
--- a/syncopy/statistics/spike_psth.py
+++ b/syncopy/statistics/spike_psth.py
@@ -5,6 +5,8 @@
 
 import numpy as np
 from copy import deepcopy
+import logging
+import platform
 
 # Syncopy imports
 import syncopy as spy
@@ -159,7 +161,7 @@ def spike_psth(data,
 
     # apply the updated selection
     data.selectdata(select, inplace=True)
-
+    
     # now redefine local variables
     trl_def = data.selection.trialdefinition
     sinfo = data.selection.trialdefinition[:, :2]

From 8356bf3b5dc02fcc77dc16c87fbc6133248cbf5b Mon Sep 17 00:00:00 2001
From: Tim Schaefer
Date: Wed, 11 Jan 2023 14:27:31 +0100
Subject: [PATCH 060/135] FIX: fix shape output for SpectralData instance

---
 syncopy/statistics/summary_stats.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/syncopy/statistics/summary_stats.py b/syncopy/statistics/summary_stats.py
index db2e4234f..a193248ff 100644
--- a/syncopy/statistics/summary_stats.py
+++ b/syncopy/statistics/summary_stats.py
@@ -209,7 +209,7 @@ def itc(spec_data, **kwargs):
         raise SPYValueError(lgl, 'spec_data', act)
 
     logger = logging.getLogger("syncopy_" + platform.node())
-    logger.debug(f"Computing intertrial coherence on data chunk with shape {spec_data.shape}.")
+    logger.debug(f"Computing intertrial coherence on SpectralData instance with shape {spec_data.data.shape}.")
 
     # takes care of remaining checks
     res = _trial_statistics(spec_data, operation='itc')

From 46389d52eec5203badcfc939962783f568b01088 Mon Sep 17 00:00:00 2001
From: Tim Schaefer
Date: Wed, 11 Jan 2023 14:37:28 +0100
Subject: [PATCH 061/135] CHG: also log local stuff to a logfile, in addition
 to console

---
 syncopy/__init__.py | 20 +++++++++++++-------
 1 file changed, 13 insertions(+), 7 deletions(-)

diff --git a/syncopy/__init__.py b/syncopy/__init__.py
index 234100ec8..bc4244f01 100644
--- a/syncopy/__init__.py
+++ b/syncopy/__init__.py
@@ -122,12 +122,18 @@
     warnings.warn("Invalid log level set in environment variable 'SPYLOGLEVEL', ignoring and using WARNING instead. Hint: Set one of 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'.")
     loglevel = "WARNING"
 
-# The logger for local/sequential stuff -- goes to terminal.
+# The logger for local/sequential stuff -- goes to terminal and to a file.
 spy_logger = logging.getLogger('syncopy')
 fmt = logging.Formatter('%(asctime)s - %(levelname)s: %(message)s')
 sh = logging.StreamHandler(sys.stdout)
 sh.setFormatter(fmt)
 spy_logger.addHandler(sh)
+
+logfile = os.path.join(__logdir__, f'syncopy.log')
+fh = logging.FileHandler(logfile)  # The default mode is 'append'. 
+spy_logger.addHandler(fh) + + spy_logger.setLevel(loglevel) spy_logger.debug(f"Starting Syncopy session at {datetime.datetime.now().astimezone().isoformat()}.") spy_logger.info(f"Syncopy log level set to: {loglevel}.") @@ -150,14 +156,14 @@ def filter(self, record): record.hostname = HostnameFilter.hostname return True -logfile = os.path.join(__logdir__, f'syncopy_{host}.log') -fh = logging.FileHandler(logfile) # The default mode is 'append'. -fh.addFilter(HostnameFilter()) +logfile_par = os.path.join(__logdir__, f'syncopy_{host}.log') +fhp = logging.FileHandler(logfile_par) # The default mode is 'append'. +fhp.addFilter(HostnameFilter()) spy_parallel_logger.setLevel(parloglevel) fmt_with_hostname = logging.Formatter('%(asctime)s - %(levelname)s - %(hostname)s: %(message)s') -fh.setFormatter(fmt_with_hostname) -spy_parallel_logger.addHandler(fh) -spy_parallel_logger.info(f"Syncopy parallel logger '{parallel_logger_name}' setup to log to file '{logfile}' at level {loglevel}.") +fhp.setFormatter(fmt_with_hostname) +spy_parallel_logger.addHandler(fhp) +spy_parallel_logger.info(f"Syncopy parallel logger '{parallel_logger_name}' setup to log to file '{logfile_par}' at level {loglevel}.") # Set upper bound for temp directory size (in GB) From 9a2a4f2f169dce7da514c1d36336ba919f4dfe42 Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Wed, 11 Jan 2023 14:39:49 +0100 Subject: [PATCH 062/135] NEW: setup logging of uncaught exceptions --- syncopy/__init__.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/syncopy/__init__.py b/syncopy/__init__.py index bc4244f01..6e8ee540a 100644 --- a/syncopy/__init__.py +++ b/syncopy/__init__.py @@ -165,6 +165,15 @@ def filter(self, record): spy_parallel_logger.addHandler(fhp) spy_parallel_logger.info(f"Syncopy parallel logger '{parallel_logger_name}' setup to log to file '{logfile_par}' at level {loglevel}.") +## Setup global handler to log uncaught exceptions: +def handle_exception(exc_type, exc_value, exc_traceback): + if issubclass(exc_type, KeyboardInterrupt): + sys.__excepthook__(exc_type, exc_value, exc_traceback) + return + spy_parallel_logger.critical("Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback)) + +sys.excepthook = handle_exception + # Set upper bound for temp directory size (in GB) __storagelimit__ = 10 From 77b1852eabc722e94c3374aa4a6f134bf8d59994 Mon Sep 17 00:00:00 2001 From: tensionhead Date: Wed, 11 Jan 2023 15:25:52 +0100 Subject: [PATCH 063/135] FIX: coherence test --- syncopy/tests/test_connectivity.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/syncopy/tests/test_connectivity.py b/syncopy/tests/test_connectivity.py index 9afc1580e..4732d0359 100644 --- a/syncopy/tests/test_connectivity.py +++ b/syncopy/tests/test_connectivity.py @@ -15,7 +15,7 @@ import syncopy as spy from syncopy import AnalogData, SpectralData -import syncopy.nwanalysis.connectivity_analysis as ca +from syncopy.nwanalysis.connectivity_analysis import connectivity_outputs from syncopy import connectivityanalysis as cafunc import syncopy.tests.synth_data as synth_data import syncopy.tests.helpers as helpers @@ -500,7 +500,7 @@ def test_coh_polyremoval(self): def test_coh_outputs(self): - for output in ca.coh_outputs: + for output in connectivity_outputs: coh = cafunc(self.data, method='coh', output=output) From 9a61c46382854c1f72e0a47cd26c8290d310bd27 Mon Sep 17 00:00:00 2001 From: tensionhead Date: Wed, 11 Jan 2023 16:17:56 +0100 Subject: [PATCH 064/135] Update CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 
insertion(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 79a2007d3..29f9222d9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@ All notable changes to this project will be documented in this file.
 
 ### Fixed
 - fix bug #394 'Copying a spy.StructDict returns a dict'.
+- serializable `.cfg` #392
 
 ## [2022.12]

From 6a445a993daaa0d2f6b19fa8c3992213ebbb2b44 Mon Sep 17 00:00:00 2001
From: tensionhead
Date: Wed, 11 Jan 2023 16:32:53 +0100
Subject: [PATCH 065/135] FIX: Test issue #257

- either initialize with `data=None`, or `data.size != 0`

Changes to be committed:
	modified:   syncopy/datatype/discrete_data.py
	modified:   syncopy/tests/test_discretedata.py

---
 syncopy/datatype/discrete_data.py  | 4 +++-
 syncopy/tests/test_discretedata.py | 4 ++--
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/syncopy/datatype/discrete_data.py b/syncopy/datatype/discrete_data.py
index 234d83b00..465b25018 100644
--- a/syncopy/datatype/discrete_data.py
+++ b/syncopy/datatype/discrete_data.py
@@ -318,7 +318,9 @@ def __init__(self, data=None, samplerate=None, trialid=None, **kwargs):
 
             # Fill in dimensional info
             definetrial(self, kwargs.get("trialdefinition"))
-
+        elif self.data.size == 0:
+            # initialization with empty data not allowed
+            raise SPYValueError("non empty data set", 'data')
 
 class SpikeData(DiscreteData):
     """Spike times of multi- and/or single units
diff --git a/syncopy/tests/test_discretedata.py b/syncopy/tests/test_discretedata.py
index 9a668f464..6c8e307cd 100644
--- a/syncopy/tests/test_discretedata.py
+++ b/syncopy/tests/test_discretedata.py
@@ -55,8 +55,8 @@ def test_empty(self):
 
     def test_issue_257_fixed_no_error_for_empty_data(self):
         """This tests that the data object is created without throwing an error, see #257."""
-        data = SpikeData(np.column_stack(([],[],[])), dimord = ['sample', 'channel', 'unit'], samplerate = 30000)
-        assert data.dimord == ["sample", "channel", "unit"]
+        with pytest.raises(SPYValueError, match='non empty'):
+            data = SpikeData(np.column_stack(([],[],[])), dimord = ['sample', 'channel', 'unit'], samplerate = 30000)
 
     def test_nparray(self):
         dummy = SpikeData(self.data)

From 0ae0718c664708be0de4bfe481a4447dd66951d2 Mon Sep 17 00:00:00 2001
From: tensionhead
Date: Wed, 11 Jan 2023 16:36:20 +0100
Subject: [PATCH 066/135] FIX: catch empty data

Changes to be committed:
	modified:   syncopy/datatype/discrete_data.py

---
 syncopy/datatype/discrete_data.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/syncopy/datatype/discrete_data.py b/syncopy/datatype/discrete_data.py
index 465b25018..75f0d2a73 100644
--- a/syncopy/datatype/discrete_data.py
+++ b/syncopy/datatype/discrete_data.py
@@ -310,17 +310,17 @@ def __init__(self, data=None, samplerate=None, trialid=None, **kwargs):
         # Call initializer
         super().__init__(data=data, **kwargs)
 
-        if self.data is not None and self.data.size != 0:
+        if self.data is not None:
+
+            if self.data.size == 0:
+                # initialization with empty data not allowed
+                raise SPYValueError("non empty data set", 'data')
 
             # In case of manual data allocation (reading routine would leave a
             # mark in `cfg`), fill in missing info
             if self.sampleinfo is None:
-
                 # Fill in dimensional info
                 definetrial(self, kwargs.get("trialdefinition"))
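For illustration, a minimal sketch of the SpikeData construction and labelling rules established by the patches above (the toy array and label names are placeholders):

    import numpy as np
    import syncopy as spy

    # toy spike data, columns ordered as (sample, channel, unit)
    arr = np.array([[0, 0, 0], [10, 1, 0], [20, 0, 1], [30, 1, 1]])
    spk = spy.SpikeData(arr, dimord=['sample', 'channel', 'unit'], samplerate=30000)

    # labels must match the number of distinct channels/units present in the data
    spk.channel = np.array(['chanA', 'chanB'])
    spk.unit = np.array(['unitA', 'unitB'])

    # empty data is now rejected at construction time with a SPYValueError:
    # spy.SpikeData(np.column_stack(([], [], [])), dimord=['sample', 'channel', 'unit'], samplerate=30000)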
=?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Wed, 11 Jan 2023 16:54:33 +0100 Subject: [PATCH 067/135] NEW: update logging dev docs --- doc/source/developer/logging.rst | 44 ++++++++++++++++++++++++++------ 1 file changed, 36 insertions(+), 8 deletions(-) diff --git a/doc/source/developer/logging.rst b/doc/source/developer/logging.rst index a5e0b56b0..c5bd33a73 100644 --- a/doc/source/developer/logging.rst +++ b/doc/source/developer/logging.rst @@ -7,10 +7,32 @@ Syncopy uses the `Python logging module <https://docs.python.org/3/library/logging.html>`_ for details on the supported log levels. + + +Log file location +----------------- + +All Syncopy log files are saved in a configurable directory which we refer to as `SPYLOGDIR`. By default, `SPYLOGDIR` is set to the directory `.spy/logs/` in your home directory (accessible as `~/.spy/logs/` under Linux and Mac OS), and it can be adapted by setting the environment variable `SPYLOGDIR` before running your application. + +E.g., if your Python script using Syncopy is `~/neuro/paperfig1.py`, you can set the log level and log directory on the command line like this in the Bash shell: + +.. code-block:: shell + export SPYLOGDIR=/tmp/spy + export SPYLOGLEVEL=DEBUG + ~/neuro/paperfig1.py + + + + Logging code that runs locally ------------------------------- -For all code that is run on the local machine, Syncopy logs to a logger named `'syncopy'` which is handled by the console. +For all code that is run on the local machine, Syncopy logs to a logger named `'syncopy'` which is handled by both the console and the logfile `'SPYLOGDIR/syncopy.log'`. To adapt the local logging behaviour of Syncopy, one can configure the logger as explained in the documentation for the logging module, e.g., in your application that uses Syncopy: @@ -24,22 +46,28 @@ To adapt the local logging behaviour of Syncopy, one can configure the logger as # Change the log level: logger.setLevel(logging.DEBUG) - # Make it log to a file: - fh = logging.FileHandler('syncopy_log_within_my_app.log') + # Add another handler that logs to a file: + fh = logging.FileHandler('syncopy_debug_log.log') logger.addHandler(fh) + logger.info("My app starts now.") # The rest of your application code goes here. Logging code that potentially runs remotely -------------------------------------------- -The parallel code that performs the heavy lifting on the Syncopy data will be executed on remote machines (cluster nodes) when Syncopy is run in an HPC environment. Therefore, -special handling is required for these parts of the code, and we need to log to one log file per remote machine to avoid race conditions and +The parallel code that performs the heavy lifting on the Syncopy data (i.e., what we call `compute functions`) will be executed on remote machines (cluster nodes) when Syncopy is run in an HPC environment. Therefore, +special handling is required for these parts of the code, and we need to log to one log file per remote machine to avoid race conditions. Here is how to log with the remote logger: -Log levels ----------- -The default log level for the Syncopy logger is `logging.WARNING` (from now on referred to as `'WARNING'`). This means that you will not see any Syncopy messages below that threshold, i.e., messages printed with log levels `'DEBUG'` and `'INFO'`. 
To change the log level, you can either use the logging API in your application code as explained above, or set the environment variable `'SPYLOGLEVEL'` to one of the values supported by the logging module, e.g., 'CRITICAL', 'WARNING', 'INFO', or 'DEBUG'. See the `official docs of the logging module <https://docs.python.org/3/library/logging.html#logging-levels>`_ for details on the supported log levels. From e057290bc73a1347cfa3285e74e62d2cd9944498 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Wed, 11 Jan 2023 17:07:21 +0100 Subject: [PATCH 068/135] NEW: finish logging docs --- doc/source/developer/logging.rst | 12 +++++++----- syncopy/datatype/discrete_data.py | 2 +- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/doc/source/developer/logging.rst b/doc/source/developer/logging.rst index c5bd33a73..a8983b1dc 100644 --- a/doc/source/developer/logging.rst +++ b/doc/source/developer/logging.rst @@ -27,8 +27,6 @@ E.g., if your Python script using Syncopy is `~/neuro/paperfig1.py`, you can set ~/neuro/paperfig1.py - - Logging code that runs locally ------------------------------- @@ -57,17 +55,21 @@ Logging code that potentially runs remotely -------------------------------------------- -The parallel code that performs the heavy lifting on the Syncopy data (i.e., what we call `compute functions`) will be executed on remote machines (cluster nodes) when Syncopy is run in an HPC environment. Therefore, -special handling is required for these parts of the code, and we need to log to one log file per remote machine. Here is how to log with the remote logger: +The parallel code that performs the heavy lifting on the Syncopy data (i.e., what we call `compute functions`) will be executed on remote machines when Syncopy is run in an HPC environment. Therefore, +special handling is required for these parts of the code, and we need to log to one log file per remote machine. +Syncopy automatically configures a suitable logger named `syncopy_<hostname>` on each host, where `<hostname>` is the hostname. Each of these loggers is attached to the respective logfile `'SPYLOGDIR/syncopy_<hostname>.log'`, which ensures that logging works properly even if you log into the same directory on all remote machines (e.g., a home directory that is mounted on all machines via a network file system). +Here is how to log with the remote logger: .. code-block:: python import syncopy import logging, platform + # ... + # In some cF or backend function: par_logger = logging.getLogger("syncopy_" + platform.node()) par_logger.info("Code run on remote machine is being run.") +This is all you need to do. If you want to configure different log levels for the remote logger and the local one, you can set the environment variable `SPYPARLOGLEVEL` in addition to `SPYLOGLEVEL`. diff --git a/syncopy/datatype/discrete_data.py b/syncopy/datatype/discrete_data.py index 376c0fbb5..9a8343dd0 100644 --- a/syncopy/datatype/discrete_data.py +++ b/syncopy/datatype/discrete_data.py @@ -493,7 +493,7 @@ def __init__(self, trialdefinition=trialdefinition, samplerate=samplerate, dimord=dimord) - + # instance attribute to allow modification self._hdfFileAttributeProperties = DiscreteData._hdfFileAttributeProperties + ("channel",)
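The three environment variables from the docs above compose naturally. As a quick illustration (only the variable names `SPYLOGDIR`, `SPYLOGLEVEL` and `SPYPARLOGLEVEL` come from the docs; the chosen values and script path are illustrative):

.. code-block:: shell

    # verbose local logging, but only warnings from the per-host parallel loggers
    export SPYLOGDIR=/tmp/spy
    export SPYLOGLEVEL=DEBUG
    export SPYPARLOGLEVEL=WARNING
    python ~/neuro/paperfig1.py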
From bb65e6ad67d0cb14b8423f19bfb57852e9014a89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Wed, 11 Jan 2023 17:18:22 +0100 Subject: [PATCH 069/135] NEW: add logging tests --- syncopy/tests/test_logging.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 syncopy/tests/test_logging.py diff --git a/syncopy/tests/test_logging.py b/syncopy/tests/test_logging.py new file mode 100644 index 000000000..ef7248c9d --- /dev/null +++ b/syncopy/tests/test_logging.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- +# +# Test logging. +# + +import os + +# Local imports +import syncopy as spy + + +class TestLogging: + + def test_logfile_exists(self): + logfile = os.path.join(spy.__logdir__, "syncopy.log") + assert os.path.isfile(logfile) + + From be648564c5ed4906d9da69408a7f254ddca1c5b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Wed, 11 Jan 2023 17:34:42 +0100 Subject: [PATCH 070/135] NEW: extend logging tests --- syncopy/tests/test_logging.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/syncopy/tests/test_logging.py b/syncopy/tests/test_logging.py index ef7248c9d..2a6222ff6 100644 --- a/syncopy/tests/test_logging.py +++ b/syncopy/tests/test_logging.py @@ -7,6 +7,8 @@ # Local imports import syncopy as spy +from syncopy.shared.log import get_logger +from syncopy.shared.errors import SPYLog class TestLogging: @@ -15,4 +17,27 @@ def test_logfile_exists(self): logfile = os.path.join(spy.__logdir__, "syncopy.log") assert os.path.isfile(logfile) + def test_default_log_level_is_warning(self): + logfile = os.path.join(spy.__logdir__, "syncopy.log") + assert os.path.isfile(logfile) + + num_lines_before = sum(1 for line in open(logfile)) + + # Log something with log levels INFO and DEBUG, which should not affect the logfile. + logger = get_logger() + logger.info("I am adding an INFO level log entry.") + SPYLog("I am adding a DEBUG level log entry.", loglevel="DEBUG") + + num_lines_after_info_debug = sum(1 for line in open(logfile)) + + assert num_lines_before == num_lines_after_info_debug + + # Now log something with log level WARNING + SPYLog("I am adding a WARNING level log entry.", loglevel="WARNING") + + num_lines_after_warning = sum(1 for line in open(logfile)) + assert num_lines_after_info_debug + 1 == num_lines_after_warning + + + +
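The two assertions above rest on standard `logging` threshold semantics rather than anything Syncopy-specific: records below a logger's effective level are dropped before any handler sees them. A minimal self-contained sketch of that behaviour (logger name as used by Syncopy, file name illustrative):

.. code-block:: python

    import logging

    logger = logging.getLogger("syncopy")
    logger.setLevel(logging.WARNING)  # the default level the test expects
    logger.addHandler(logging.FileHandler("syncopy.log"))

    logger.info("dropped, below the WARNING threshold")
    logger.warning("this record reaches the logfile")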
From 7879676189a1f174b6a3ce9918fa9df7fc9000db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Sch=C3=A4fer?= Date: Wed, 11 Jan 2023 18:12:45 +0100 Subject: [PATCH 071/135] CHG: also add timestamp to logfile, not only console --- syncopy/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/syncopy/__init__.py b/syncopy/__init__.py index 6e8ee540a..3a8daf9ff 100644 --- a/syncopy/__init__.py +++ b/syncopy/__init__.py @@ -131,6 +131,7 @@ logfile = os.path.join(__logdir__, f'syncopy.log') fh = logging.FileHandler(logfile) # The default mode is 'append'. +fh.setFormatter(fmt) spy_logger.addHandler(fh) From 2862b9a6d15106f0e201d1d2ece6e40affb20f81 Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Thu, 12 Jan 2023 09:26:10 +0100 Subject: [PATCH 072/135] CHG: use OS-independent path --- syncopy/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/syncopy/__init__.py b/syncopy/__init__.py index 6e8ee540a..56a4dc8c0 100644 --- a/syncopy/__init__.py +++ b/syncopy/__init__.py @@ -50,7 +50,7 @@ try: dd.get_client() except ValueError: - silence_file = os.path.expanduser("~/.spy/silentstartup") + silence_file = os.path.join(os.path.expanduser("~"), ".spy", "silentstartup") if os.getenv("SPYSILENTSTARTUP") is None and not os.path.isfile(silence_file): print(msg) From b7c9e6616f91c1657a469485f422a794eb4735c8 Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Thu, 12 Jan 2023 12:38:04 +0100 Subject: [PATCH 073/135] NEW: add TODO about excep handler --- syncopy/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/syncopy/__init__.py b/syncopy/__init__.py index bdf117710..6447958e0 100644 --- a/syncopy/__init__.py +++ b/syncopy/__init__.py @@ -173,7 +173,7 @@ def handle_exception(exc_type, exc_value, exc_traceback): return spy_parallel_logger.critical("Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback)) -sys.excepthook = handle_exception +sys.excepthook = handle_exception # TODO: this may get overwritten below with SPYExceptionHandler, should log in there.
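The TODO added above points at a real constraint: `sys.excepthook` holds a single callable, so whichever assignment runs last wins. A rough sketch of how the logging hook from patch 062 could chain to an already installed handler instead of replacing it (not part of the patch series; `_prev_hook` stands in for, e.g., `SPYExceptionHandler`):

.. code-block:: python

    import sys
    import logging

    _prev_hook = sys.excepthook  # whatever handler is currently installed

    def logging_excepthook(exc_type, exc_value, exc_traceback):
        # mirror the handler above: let Ctrl-C pass through untouched
        if issubclass(exc_type, KeyboardInterrupt):
            sys.__excepthook__(exc_type, exc_value, exc_traceback)
            return
        logging.getLogger("syncopy").critical(
            "Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback))
        _prev_hook(exc_type, exc_value, exc_traceback)  # delegate to the previous hook

    sys.excepthook = logging_excepthook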
From baa865aac038c6f865d202635557a2be06e5c8f1 Mon Sep 17 00:00:00 2001 From: kajal5888 Date: Thu, 12 Jan 2023 13:27:13 +0100 Subject: [PATCH 074/135] test for output data type added --- syncopy/tests/test_connectivity.py | 47 +++++++++++++++++++++++++++--- 1 file changed, 43 insertions(+), 4 deletions(-) diff --git a/syncopy/tests/test_connectivity.py b/syncopy/tests/test_connectivity.py index 9afc1580e..fed8d3901 100644 --- a/syncopy/tests/test_connectivity.py +++ b/syncopy/tests/test_connectivity.py @@ -213,12 +213,11 @@ def test_gr_selections(self): # test one final selection into a result # obtained via original SpectralData input - selections[0].pop('latency') + selections[0].pop('latency') result_ad = cafunc(self.data, self.cfg, method='granger', select=selections[0]) result_spec = cafunc(self.spec, method='granger', select=selections[0]) assert np.allclose(result_ad.trials[0], result_spec.trials[0], atol=1e-3) - def test_gr_foi(self): try: @@ -441,7 +440,7 @@ def test_coh_selections(self): # test one final selection into a result # obtained via original SpectralData input - selections[0].pop('latency') + selections[0].pop('latency') result_ad = cafunc(self.data, self.cfg, method='coh', select=selections[0]) result_spec = cafunc(self.spec, method='coh', select=selections[0]) assert np.allclose(result_ad.trials[0], result_spec.trials[0], atol=1e-3) @@ -516,6 +515,46 @@ def test_coh_outputs(self): assert np.all(np.imag(coh.trials[0]) == 0) +class TestCSD: + nSamples = 1400 + nChannels = 4 + nTrials = 100 + fs = 1000 + + # -- two harmonics with individual phase diffusion -- + + f1, f2 = 20, 40 + # a lot of phase diffusion (1% per step) in the 20Hz band + s1 = synth_data.phase_diffusion(nTrials, freq=f1, + eps=.01, + nChannels=nChannels, + nSamples=nSamples) + + # little diffusion in the 40Hz band + s2 = synth_data.phase_diffusion(nTrials, freq=f2, + eps=.001, + nChannels=nChannels, + nSamples=nSamples) + + wn = synth_data.white_noise(nTrials, nChannels=nChannels, nSamples=nSamples) + + # superposition + data = s1 + s2 + wn + data.samplerate = fs + time_span = [-1, nSamples / fs - 1] # -1s offset + + # spectral analysis + cfg = spy.StructDict() + cfg.tapsmofrq = 1.5 + cfg.foilim = [5, 60] + cfg.method = 'csd' + + spec = spy.connectivityanalysis(data, cfg) + + def test_data_output_type(self): + assert self.spec.data.dtype.name == 'complex64' + + class TestCorrelation: nChannels = 5 @@ -715,9 +754,9 @@ def plot_corr(res, i, j, label=''): ax.legend() - if __name__ == '__main__': T1 = TestGranger() T2 = TestCoherence() T3 = TestCorrelation() T4 = TestSpectralInput() + T5 = TestCSD() From 33e1338f719986f1e9ae577c450be82a42d11128 Mon Sep 17 00:00:00 2001 From: kajal5888 Date: Thu, 12 Jan 2023 13:35:51 +0100 Subject: [PATCH 075/135] CSD is not the normalized spectrum --- syncopy/nwanalysis/connectivity_analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/syncopy/nwanalysis/connectivity_analysis.py b/syncopy/nwanalysis/connectivity_analysis.py index 31672c080..61555162e 100644 --- a/syncopy/nwanalysis/connectivity_analysis.py +++ b/syncopy/nwanalysis/connectivity_analysis.py @@ -70,7 +70,7 @@ def connectivityanalysis(data, method="coh", keeptrials=False, output="abs", * **pad**: either pad to an absolute length in seconds or set to `'nextpow2'` "csd" : (Multi-) tapered cross spectral density estimate - Computes the normalized cross spectral densities between all channel combinations + Computes the cross spectral estimates between all channel combinations output : complex spectrum From 0c1f98855e4f5eebf8027a249d76acf424082401 Mon Sep 17 00:00:00 2001 From: Katharine Shapcott Date: Thu, 12 Jan 2023 15:40:08 +0100 Subject: [PATCH 076/135] CHG: Use _trialslice for speed in discrete --- syncopy/datatype/discrete_data.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/syncopy/datatype/discrete_data.py b/syncopy/datatype/discrete_data.py index fc71a4871..1cb2190ad 100644 --- a/syncopy/datatype/discrete_data.py +++ b/syncopy/datatype/discrete_data.py @@ -210,12 +210,15 @@ def _preview_trial(self, trialno): syncopy.datatype.base_data.FauxTrial : class definition and further details syncopy.shared.computational_routine.ComputationalRoutine : Syncopy compute engine """ - - trialIdx = np.where(self.trialid == trialno)[0] + trlSlice = self._trialslice[trialno] + trialIdx = np.arange(trlSlice.start, trlSlice.stop) # np.where(self.trialid == trialno)[0] nCol = len(self.dimord) - idx = [trialIdx.tolist(), slice(0, nCol)] + idx = [[], slice(0, nCol)] if self.selection is not None: # selections are harmonized, just take `.time` idx[0] = trialIdx[self.selection.time[self.selection.trial_ids.index(trialno)]].tolist() + else: + idx[0] = trialIdx.tolist() + shp = [len(idx[0]), nCol] return FauxTrial(shp, tuple(idx), self.data.dtype, self.dimord) @@ -257,7 +260,7 @@ def _get_time(self, trials, toi=None, toilim=None): if toilim is not None: allTrials = self.trialtime for trlno in trials: - trlTime = allTrials[self.trialid == trlno] + trlTime = allTrials[self._trialslice[trlno]] _, selTime = best_match(trlTime, toilim, span=True) selTime = selTime.tolist() if len(selTime) > 1 and np.diff(trlTime).min() > 0: @@ -268,11 +271,11 @@ def _get_time(self, trials, toi=None, toilim=None): elif toi is not None: allTrials = self.trialtime for trlno in trials: - trlTime = allTrials[self.trialid == trlno] + trlTime = allTrials[self._trialslice[trlno]] _, arrayIdx = best_match(trlTime, toi) # squash duplicate values then re-add _, xdi =
np.unique(trlTime[arrayIdx], return_index=True) - arrayIdx = arrayIdx[np.sort(xdi)] + arrayIdx = arrayIdx[xdi] # we assume sorted data selTime = [] for t in arrayIdx: selTime += np.where(trlTime[t] == trlTime)[0].tolist() @@ -424,9 +427,8 @@ def _get_unit(self, trials, units=None): """ if units is not None: indices = [] - allUnits = self.data[:, self.dimord.index("unit")] for trlno in trials: - thisTrial = allUnits[self.trialid == trlno] + thisTrial = self.data[self._trialslice[trlno], self.dimord.index("unit")] trialUnits = [] for unit in units: trialUnits += list(np.where(thisTrial == unit)[0]) @@ -550,9 +552,8 @@ def _get_eventid(self, trials, eventids=None): """ if eventids is not None: indices = [] - allEvents = self.data[:, self.dimord.index("eventid")] for trlno in trials: - thisTrial = allEvents[self.trialid == trlno] + thisTrial = self.data[self._trialslice[trlno], self.dimord.index("eventid")] trialEvents = [] for event in eventids: trialEvents += list(np.where(thisTrial == event)[0]) From 3e3e7138aff78bac4acb1e205cc6a68d022648c7 Mon Sep 17 00:00:00 2001 From: tensionhead Date: Thu, 12 Jan 2023 16:36:43 +0100 Subject: [PATCH 077/135] WIP: New ad selection tests Changes to be committed: modified: syncopy/tests/test_selectdata.py --- syncopy/tests/test_selectdata.py | 817 ++++--------------------------- 1 file changed, 101 insertions(+), 716 deletions(-) diff --git a/syncopy/tests/test_selectdata.py b/syncopy/tests/test_selectdata.py index 24217d422..828471b77 100644 --- a/syncopy/tests/test_selectdata.py +++ b/syncopy/tests/test_selectdata.py @@ -8,6 +8,7 @@ import numpy as np import inspect import dask.distributed as dd +from numbers import Number # Local imports import syncopy.datatype as spd @@ -31,727 +32,111 @@ # "time" + "channel" + "trial" `AnalogData` selections etc.) -class TestSelector(): +class Test_AD_Selections(): - # Set up "global" parameters for data objects to be tested (we only test - # equidistant trials here) + # Set up "global" parameters for data objects to be tested nChannels = 10 - nSamples = 30 - nTrials = 5 - lenTrial = int(nSamples / nTrials) - 1 - nFreqs = 15 - nSpikes = 100 + nSamples = 5 # per trial + nTrials = 3 samplerate = 2.0 data = {} trl = {} - # Prepare selector results for valid/invalid selections - selectDict = {} - selectDict["channel"] = {"valid": (["channel03", "channel01"], - ["channel03", "channel01", "channel01", "channel02"], # repetition - ["channel01", "channel01", "channel02", "channel03"], # preserve repetition - "channel03", # string -> scalar - 0, # scalar - [4, 2, 5], - [4, 2, 2, 5, 5], # repetition - [0, 0, 1, 2, 3], # preserve repetition, don't convert to slice - range(0, 3), - range(5, 8), - None, - "all", - [0, 1, 2, 3], # contiguous list... - [2, 3, 5]), # non-contiguous list... - "result": ([2, 0], - [2, 0, 0, 1], - [0, 0, 1, 2], - [2], - [0], - [4, 2, 5], - [4, 2, 2, 5, 5], - [0, 0, 1, 2, 3], - slice(0, 3, 1), - slice(5, 8, 1), - slice(None, None, 1), - slice(None, None, 1), - slice(0, 4, 1), # ...gets converted to slice - [2, 3, 5]), # stays as is - "invalid": (["channel200", "channel400"], - ["invalid"], - tuple("wrongtype"), - "notall", - range(0, 100), - [40, 60, 80]), - "errors": (SPYValueError, - SPYValueError, - SPYTypeError, - SPYValueError, - SPYValueError, - SPYValueError)} - - selectDict["taper"] = {"valid": ([4, 2, 3], - [4, 2, 2, 3], # repetition - [0, 1, 1, 2, 3], # preserve repetition, don't convert to slice - range(0, 3), - range(2, 5), - None, - "all", - 0, # scalar - [0, 1, 2, 3], # contiguous list... 
- [1, 3, 4]), # non-contiguous list... - "result": ([4, 2, 3], - [4, 2, 2, 3], - [0, 1, 1, 2, 3], - slice(0, 3, 1), - slice(2, 5, 1), - slice(None, None, 1), - slice(None, None, 1), - [0], - slice(0, 4, 1), # ...gets converted to slice - [1, 3, 4]), # stays as is - "invalid": (["taper_typo", "channel400"], - tuple("wrongtype"), - "notall", - range(0, 100), - slice(80, None), - slice(-20, None), - slice(-15, -2), - slice(5, 1), - [40, 60, 80]), - "errors": (SPYValueError, - SPYTypeError, - SPYValueError, - SPYValueError, - SPYValueError, - SPYValueError, - SPYValueError, - SPYValueError, - SPYValueError)} - - # only define valid inputs, the expected (trial-dependent) results are computed below - selectDict["unit"] = {"valid": (["unit3", "unit1"], - ["unit3", "unit1", "unit1", "unit2"], # repetition - ["unit1", "unit1", "unit2", "unit3"], # preserve repetition - [4, 2, 3], - [4, 2, 2, 3], # repetition - [0, 0, 2, 3], # preserve repetition, don't convert to slice - range(0, 3), - range(2, 5), - "all", - "unit3", # string -> scalar - 4, # scalar - [0, 1, 2, 3], # contiguous list... - [1, 3, 4]), # non-contiguous list... - "invalid": (["unit7", "unit77"], - tuple("wrongtype"), - "notall", - range(0, 100), - slice(80, None), - slice(-20, None), - slice(-15, -2), - slice(5, 1), - [40, 60, 80]), - "errors": (SPYValueError, - SPYTypeError, - SPYValueError, - SPYValueError, - SPYValueError, - SPYValueError, - SPYValueError, - SPYValueError, - SPYValueError)} - - # only define valid inputs, the expected (trial-dependent) results are computed below - selectDict["eventid"] = {"valid": ([1, 0], - [1, 1, 0], # repetition - [0, 0, 1, 2], # preserve repetition, don't convert to slice - range(0, 2), - range(1, 2), - "all", - 1, # scalar - [0, 1]), # contiguous list... 
- "invalid": (["eventid", "eventid"], - tuple("wrongtype"), - "notall", - range(0, 100), - slice(80, None), - slice(-20, None), - slice(-15, -2), - slice(5, 1), - [40, 60, 80]), - "errors": (SPYValueError, - SPYTypeError, - SPYValueError, - SPYValueError, - SPYValueError, - SPYValueError, - SPYValueError, - SPYValueError, - SPYValueError)} - - selectDict["latency"] = {"invalid": (["notnumeric", "stillnotnumeric"], - tuple("wrongtype"), - "notall", - range(0, 10), - [np.nan, 1], - [0.5, 1.5 , 2.0], # more than 2 components - [2.0, 1.5]), # lower bound > upper bound - "errors": (SPYValueError, - SPYTypeError, - SPYValueError, - SPYTypeError, - SPYValueError, - SPYValueError, - SPYValueError, - SPYValueError)} - selectDict["frequency"] = {"invalid": (["notnumeric", "stillnotnumeric"], - tuple("wrongtype"), - "notall", - range(0, 10), - [np.nan, 1], - [-1, 2], # lower limit out of bounds - [2, 900], # upper limit out of bounds - [2, 7, 6], # more than 2 components - [9, 2]), # lower bound > upper bound - "errors": (SPYValueError, - SPYTypeError, - SPYValueError, - SPYTypeError, - SPYValueError, - SPYValueError, - SPYValueError, - SPYValueError, - SPYValueError, - SPYValueError)} - - # Generate 2D array simulating an AnalogData array - data["AnalogData"] = np.arange(1, nChannels * nSamples + 1).reshape(nSamples, nChannels) - trl["AnalogData"] = np.vstack([np.arange(0, nSamples, nTrials), - np.arange(lenTrial, nSamples + nTrials, nTrials), - np.ones((lenTrial + 1, )), - np.arange(1, lenTrial + 2) * np.pi]).T - - # Generate a 4D array simulating a SpectralData array - data["SpectralData"] = np.arange(1, nChannels * nSamples * nTrials * nFreqs + 1).reshape(nSamples, nTrials, nFreqs, nChannels) - trl["SpectralData"] = trl["AnalogData"] - - # Use a fixed random number generator seed to simulate a 2D SpikeData array - seed = np.random.RandomState(13) - data["SpikeData"] = np.vstack([seed.choice(nSamples, size=nSpikes), - seed.choice(np.arange(0, nChannels), size=nSpikes), - seed.choice(int(nChannels/2), size=nSpikes)]).T - trl["SpikeData"] = trl["AnalogData"] - - # Use a triple-trigger pattern to simulate EventData w/non-uniform trials - data["EventData"] = np.vstack([np.arange(0, nSamples, 1), - np.zeros((int(nSamples), ))]).T - data["EventData"][1::3, 1] = 1 - data["EventData"][2::3, 1] = 2 - trl["EventData"] = trl["AnalogData"] - - # Append customized columns to EventData dataset - data["EventDataDimord"] = np.hstack([data["EventData"], data["EventData"]]) - trl["EventDataDimord"] = trl["AnalogData"] - customEvtDimord = ["sample", "eventid", "custom1", "custom2"] - - # Define data classes to be used in tests below - classes = ["AnalogData", "SpectralData", "SpikeData", "EventData"] - - # test `Selector` constructor w/all data classes - def test_general(self): - - # construct expected results for `DiscreteData` objects defined above - mapDict = {"SpikeData" : "unit", "EventData" : "eventid"} - for dset in ["SpikeData", "EventData", "EventDataDimord"]: - dclass = "".join(dset.partition("Data")[:2]) - prop = mapDict[dclass] - dimord = self.customEvtDimord if dset == "EventDataDimord" else None - discrete = getattr(spd, dclass)(data=self.data[dset], - trialdefinition=self.trl[dclass], - samplerate=self.samplerate, - dimord=dimord) - propIdx = discrete.dimord.index(prop) - - # convert selection from `selectDict` to a usable integer-list - allResults = [] - for selection in self.selectDict[prop]["valid"]: - if isinstance(selection, slice): - if selection.start is selection.stop is None: - selects = 
[None] - else: - selects = list(range(getattr(discrete, prop).size))[selection] - elif isinstance(selection, range): - selects = list(selection) - elif isinstance(selection, str): - if selection == "all": - selects = [None] - else: - selection = [selection] - elif np.issubdtype(type(selection), np.number): - selection = [selection] - - if isinstance(selection, (list, np.ndarray)): - if isinstance(selection[0], str): - avail = getattr(discrete, prop) - else: - avail = np.arange(getattr(discrete, prop).size) - selects = [] - for sel in selection: - selects += list(np.where(avail == sel)[0]) - - # alternate (expensive) way to get by-trial selection indices - result = [] - print(prop) - print(selects, selection) - for trial in discrete.trials: - if selects[0] is None: - res = slice(0, trial.shape[0], 1) - else: - res = [] - for sel in selects: - res += list(np.where(trial[:, propIdx] == sel)[0]) - if len(res) > 1: - steps = np.diff(res) - if steps.min() == steps.max() == 1: - res = slice(res[0], res[-1] + 1, 1) - result.append(res) - print(result) - print() - allResults.append(result) - - self.selectDict[prop]["result"] = tuple(allResults) - - # wrong type of data and/or selector - with pytest.raises(SPYTypeError): - Selector(np.empty((3,)), {}) - with pytest.raises(SPYValueError): - Selector(spd.AnalogData(), {}) - ang = AnalogData(data=self.data["AnalogData"], - trialdefinition=self.trl["AnalogData"], - samplerate=self.samplerate) - with pytest.raises(SPYTypeError): - Selector(ang, ()) - with pytest.raises(SPYValueError): - Selector(ang, {"wrongkey": [1]}) - - # set/clear in-place data selection (both setting and clearing are idempotent, - # i.e., repeated execution must work, hence the double whammy) - ang.selectdata(trials=[3, 1]) - ang.selectdata(trials=[3, 1]) - ang.selectdata(clear=True) - ang.selectdata(clear=True) - with pytest.raises(SPYValueError) as spyval: - ang.selectdata(trials=[3, 1], clear=True) - assert "no data selectors if `clear = True`" in str(spyval.value) - - # show full/squeezed arrays - # for a single trial an array is returned directly - assert len(ang.show(channel=0, trials=0).shape) == 1 - # multiple trials get returned in a list - assert [len(trl.shape) == 2 for trl in ang.show(channel=0, squeeze=False)] - - # test latency returns arrays for single trial and - # lists for multiple trial selections - assert isinstance(ang.show(trials=0, latency=[0.5, 1]), np.ndarray) - assert isinstance(ang.show(trials=[0, 1], latency=[1, 2]), list) - - # test invalid indexing for .show operations - with pytest.raises(SPYValueError) as err: - ang.show(trials=[1, 0]) - assert "expected unique and sorted" in str(err) - - # go through all data-classes defined above - for dset in self.data.keys(): - dclass = "".join(dset.partition("Data")[:2]) - dimord = self.customEvtDimord if dset == "EventDataDimord" else None - dummy = getattr(spd, dclass)(data=self.data[dset], - trialdefinition=self.trl[dclass], - samplerate=self.samplerate, - dimord=dimord) - - # test trial selection - selection = Selector(dummy, {"trials": [3, 1]}) - assert selection.trial_ids == [3, 1] - selected = selectdata(dummy, trials=[3, 1]) - assert np.array_equal(selected.trials[0], dummy.trials[3]) - assert np.array_equal(selected.trials[1], dummy.trials[1]) - assert selected.trialdefinition.shape == (2, 4) - assert np.array_equal(selected.trialdefinition[:, -1], dummy.trialdefinition[[3, 1], -1]) - - # scalar selection - selection = Selector(dummy, {"trials": 2}) - assert selection.trial_ids == [2] - selected = 
selectdata(dummy, trials=2) - assert np.array_equal(selected.trials[0], dummy.trials[2]) - assert selected.trialdefinition.shape == (1, 4) - assert np.array_equal(selected.trialdefinition[:, -1], dummy.trialdefinition[[2], -1]) - - # array selection - selection = Selector(dummy, {"trials": np.array([3, 1])}) - assert selection.trial_ids == [3, 1] - selected = selectdata(dummy, trials=[3, 1]) - assert np.array_equal(selected.trials[0], dummy.trials[3]) - assert np.array_equal(selected.trials[1], dummy.trials[1]) - assert selected.trialdefinition.shape == (2, 4) - assert np.array_equal(selected.trialdefinition[:, -1], dummy.trialdefinition[[3, 1], -1]) - - # select all - for trlSec in [None, "all"]: - selection = Selector(dummy, {"trials": trlSec}) - assert selection.trial_ids == list(range(len(dummy.trials))) - selected = selectdata(dummy, trials=trlSec) - for tk, trl in enumerate(selected.trials): - assert np.array_equal(trl, dummy.trials[tk]) - assert np.array_equal(selected.trialdefinition, dummy.trialdefinition) - - # invalid trials - with pytest.raises(SPYValueError): - Selector(dummy, {"trials": [-1, 9]}) - - # test "simple" property setters handled by `_selection_setter` - for prop in ["channel", "taper", "unit", "eventid"]: - if hasattr(dummy, prop): - expected = self.selectDict[prop]["result"] - for sk, sel in enumerate(self.selectDict[prop]["valid"]): - solution = expected[sk] - if dclass == "SpikeData" and prop == "channel": - if isinstance(solution, slice): - start, stop, step = solution.start, solution.stop, solution.step - if start is None: - start = 0 - elif start < 0: - start = len(dummy.channel) + start - if stop is None: - stop = len(dummy.channel) - elif stop < 0: - stop = len(dummy.channel) + stop - if step not in [None, 1]: - solution = list(range(start, stop))[solution] - else: - solution = slice(start, stop, step) - - # ensure typos in selectino keywords are caught - with pytest.raises(SPYValueError) as spv: - Selector(dummy, {prop + "x": sel}) - assert "expected dict with one or all of the following keys:" in str(spv.value) - - # once we're sure `Selector` works, actually select data - selection = Selector(dummy, {prop : sel}) - assert getattr(selection, prop) == solution - selected = selectdata(dummy, {prop : sel}) - - # process `unit` and `enventid` - if prop in selection._byTrialProps: - propIdx = selected.dimord.index(prop) - propArr = np.unique(selected.data[:, propIdx]).astype(np.intp) - assert set(getattr(selected, prop)) == set(getattr(dummy, prop)[propArr]) - tk = 0 - for trialno in range(len(dummy.trials)): - if solution[trialno]: # do not try to compare empty selections - assert np.array_equal(selected.trials[tk], - dummy.trials[trialno][solution[trialno], :]) - tk += 1 - - # `channel` is a special case for `SpikeData` objects - elif dclass == "SpikeData" and prop == "channel": - chanIdx = selected.dimord.index("channel") - chanArr = np.arange(dummy.channel.size) - assert set(selected.data[:, chanIdx]).issubset(chanArr[solution]) - assert set(selected.channel) == set(dummy.channel[solution]) - - # everything else (that is not a `DiscreteData` child) - else: - idx = [slice(None)] * len(dummy.dimord) - idx[dummy.dimord.index(prop)] = solution - assert np.array_equal(np.array(dummy.data)[tuple(idx)], - selected.data) - assert np.array_equal(getattr(selected, prop), - getattr(dummy, prop)[solution]) - - # ensure invalid selection trigger expected errors - for ik, isel in enumerate(self.selectDict[prop]["invalid"]): - with 
pytest.raises(self.selectDict[prop]["errors"][ik]): - Selector(dummy, {prop : isel}) - else: - - # ensure objects that don't have a `prop` attribute complain - with pytest.raises(SPYValueError): - Selector(dummy, {prop : [0]}) - - # ensure invalid `latency` specifications trigger expected errors - if hasattr(dummy, "time") or hasattr(dummy, "trialtime"): - for ik, isel in enumerate(self.selectDict["latency"]["invalid"]): - with pytest.raises(self.selectDict["latency"]["errors"][ik]): - spy.selectdata(dummy, {"latency": isel}) - else: - # ensure objects that don't have `time` props complain properly - with pytest.raises(SPYValueError): - Selector(dummy, {"latency": [-.5]}) - - # ensure invalid `frequency` specifications trigger expected errors - if hasattr(dummy, "freq"): - for ik, isel in enumerate(self.selectDict['frequency']["invalid"]): - with pytest.raises(self.selectDict['frequency']["errors"][ik]): - Selector(dummy, {'frequency': isel}) - else: - # ensure objects without `freq` property complain properly - with pytest.raises(SPYValueError): - Selector(dummy, {"frequency": [0]}) - - def test_continuous_latency(self): - - # this only works w/the equidistant trials constructed above!!! - selDict = {"latency": (None, # trivial "selection" of entire contents - "all", # trivial "selection" of entire contents - [0.5, 1.5], # regular range - [1.5, 2.0])} # minimal range (just two-time points) - - # all trials have same time-scale: take 1st one as reference - trlTime = (np.arange(0, self.trl["AnalogData"][0, 1] - self.trl["AnalogData"][0, 0]) - + self.trl["AnalogData"][0, 2]) / self.samplerate - - ang = AnalogData(data=self.data["AnalogData"], - trialdefinition=self.trl["AnalogData"], - samplerate=self.samplerate) - angIdx = [slice(None)] * len(ang.dimord) - timeIdx = ang.dimord.index("time") - - # the below check only works for equidistant trials! - for timeSel in selDict['latency']: - sel = Selector(ang, {'latency': timeSel}).time - if timeSel is None or timeSel == "all": - idx = slice(None) - else: - idx = np.intersect1d(np.where(trlTime >= timeSel[0])[0], - np.where(trlTime <= timeSel[1])[0]) - - # check that correct data was selected (all trials identical, just take 1st one) - assert np.array_equal(ang.trials[0][idx, :], - ang.trials[0][sel[0], :]) - if not isinstance(idx, slice) and len(idx) > 1: - timeSteps = np.diff(idx) - if timeSteps.min() == timeSteps.max() == 1: - idx = slice(idx[0], idx[-1] + 1, 1) - result = [idx] * len(ang.trials) - - # check correct format of selector (list -> slice etc.) - assert np.array_equal(result, sel) - - # perform actual data-selection and ensure identity of results - selected = selectdata(ang, {'latency': timeSel}) - for trialno in range(len(ang.trials)): - angIdx[timeIdx] = result[trialno] - assert np.array_equal(selected.trials[trialno], - ang.trials[trialno][tuple(angIdx)]) - - # test `latency` selection w/`SpikeData` and `EventData` - def test_discrete_latency(self): - - selDict = {"latency": (None, # trivial "selection" of entire contents - "all", # trivial "selection" of entire contents - [0.5, 1.5], # regular range - [1.5, 2.0])} # minimal range (just two-time points) - - # the below method of extracting spikes satisfying `latency` only works w/equidistant trials! 
- for dset in ["SpikeData", "EventData", "EventDataDimord"]: - dclass = "".join(dset.partition("Data")[:2]) - dimord = self.customEvtDimord if dset == "EventDataDimord" else None - discrete = getattr(spd, dclass)(data=self.data[dset], - trialdefinition=self.trl[dclass], - samplerate=self.samplerate, - dimord=dimord) - for timeSel in selDict["latency"]: - sel = Selector(discrete, {'latency': timeSel}).time - result = [] - - # compute sel by hand - for trlno in range(len(discrete.trials)): - trlTime = discrete.time[trlno] - if timeSel is None or timeSel == "all": - idx = np.arange(trlTime.size).tolist() - else: - idx = np.intersect1d(np.where(trlTime >= timeSel[0])[0], - np.where(trlTime <= timeSel[1])[0]).tolist() - - # check that correct data was selected - assert np.array_equal(discrete.trials[trlno][idx, :], - discrete.trials[trlno][sel[trlno], :]) - if not isinstance(idx, slice) and len(idx) > 1: - timeSteps = np.diff(idx) - if timeSteps.min() == timeSteps.max() == 1: - idx = slice(idx[0], idx[-1] + 1, 1) - result.append(idx) - - # check correct format of selector (list -> slice etc.) - assert np.array_equal(result, sel) - - # perform actual data-selection and ensure identity of results - selected = selectdata(discrete, {'latency': timeSel}) - assert selected.dimord == discrete.dimord - for trialno in range(len(discrete.trials)): - assert np.array_equal(selected.trials[trialno], - discrete.trials[trialno][result[trialno],:]) - - def test_spectral_frequency(self): - - # this selection only works w/the dummy frequency data constructed above!!! - selDict = {"frequency": (None, # trivial "selection" of entire contents, - "all", # trivial "selection" of entire contents - [2, 11], # regular range - [1, 2], # minimal range (just two-time points) - )} - spc = SpectralData(data=self.data['SpectralData'], - trialdefinition=self.trl['SpectralData'], - samplerate=self.samplerate) - allFreqs = spc.freq - spcIdx = [slice(None)] * len(spc.dimord) - freqIdx = spc.dimord.index("freq") - - for freqSel in selDict["frequency"]: - sel = Selector(spc, {"frequency": freqSel}).freq - if freqSel is None or freqSel == "all": - idx = slice(None) - else: - idx = np.intersect1d(np.where(allFreqs >= freqSel[0])[0], - np.where(allFreqs <= freqSel[1])[0]) - - # check that correct data was selected (all trials identical, just take 1st one) - assert np.array_equal(spc.freq[idx], spc.freq[sel]) - if not isinstance(idx, slice) and len(idx) > 1: - freqSteps = np.diff(idx) - if freqSteps.min() == freqSteps.max() == 1: - idx = slice(idx[0], idx[-1] + 1, 1) - - # check correct format of selector (list -> slice etc.) - assert np.array_equal(idx, sel) - - # perform actual data-selection and ensure identity of results - selected = selectdata(spc, {"frequency": freqSel}) - spcIdx[freqIdx] = idx - assert np.array_equal(selected.freq, spc.freq[sel]) - for trialno in range(len(spc.trials)): - assert np.array_equal(selected.trials[trialno], - spc.trials[trialno][tuple(spcIdx)]) - - def test_selector_trials(self): - - ang = AnalogData(data=self.data["AnalogData"], - trialdefinition=self.trl["AnalogData"], - samplerate=self.samplerate) - - # check original shapes - assert all([trl.shape[1] == self.nChannels for trl in ang.trials]) - assert all([trl.shape[0] == self.lenTrial for trl in ang.trials]) - - # test inplace channel, trial and latency selection - # ang.time[0] = array([0.5, 1. , 1.5, 2. 
, 2.5]) - # this latency selection hence takes the last two samples - select = {'channel': [2, 7, 9], 'trials': [0, 3, 5], 'latency': [1, 2]} - ang.selectdata(**select, inplace=True) - - # now check shapes and number of trials returned by Selector - # checks channel axis - assert all([trl.shape[1] == 3 for trl in ang.selection.trials]) - # checks time axis - assert len(ang.selection.trials) == 3 - - # test for non-existing trials, trial indices are relative here! - select = {'trials': [0, 3, 5]} - ang.selectdata(**select, inplace=True) - assert ang.selection.trial_ids[2] == 5 - # this returns original trial 6 (with index 5) - assert np.array_equal(ang.selection.trials[2], ang.trials[5]) - # we only have 3 trials selected here, so max. relative index is 2 - with pytest.raises(SPYValueError, match='less or equals 2'): - ang.selection.trials[5] - - # Fancy indexing is not allowed so far - select = {'channel': [7, 7, 8]} - ang.selectdata(**select, inplace=True) - with pytest.raises(SPYValueError, match='fancy selection with repetition'): - ang.selection.trials[0] - select = {'channel': [7, 3, 8]} - ang.selectdata(**select, inplace=True) - with pytest.raises(SPYValueError, match='fancy non-ordered selection'): - ang.selection.trials[0] - - def test_parallel(self, testcluster): - # collect all tests of current class and repeat them in parallel - client = dd.Client(testcluster) - all_tests = [attr for attr in self.__dir__() - if (inspect.ismethod(getattr(self, attr)) and attr != "test_parallel")] - for test in all_tests: - getattr(self, test)() - flush_local_cluster(testcluster) - client.close() - - -def _get_mtmfft_cfg_without_selection(): - cfg = StructDict() - cfg.out = "pow" - cfg.method = "mtmfft" - cfg.taper = "hann" - cfg.keeptrials = True - return cfg - -class TestSelectionBug332(): - def test_cF_no_selections(self): - data_len = 501 # length of spectral signal - nTrials = 20 - nChannels = 1 - adt = _get_fooof_signal(nTrials=nTrials, nChannels=nChannels) - assert adt.selection is None - cfg = _get_mtmfft_cfg_without_selection() - assert not 'select' in cfg - out = freqanalysis(cfg, adt) - assert out.data.shape == (nTrials, 1, data_len, nChannels), f"expected shape {(nTrials, 1, data_len, nChannels)} but found out.data.shape={out.data.shape}" - - def test_cF_selection_in_cfg(self): - data_len = 501 # length of spectral signal - nTrials = 20 - nChannels = 1 - adt = _get_fooof_signal(nTrials=nTrials, nChannels=nChannels) - assert adt.selection is None - cfg = _get_mtmfft_cfg_without_selection() - selected_trials = [3, 5, 7] - - cfg.select = { 'trials': selected_trials } # Add selection to cfg. - assert 'select' in cfg - out = freqanalysis(cfg, adt) - assert out.data.shape == (len(selected_trials), 1, data_len, nChannels), f"expected shape {(len(selected_trials), 1, data_len, nChannels)} but found out.data.shape={out.data.shape}" - - def test_cF_inplace_selection_in_data(self): - data_len = 501 # length of spectral signal - nTrials = 20 - nChannels = 1 - adt = _get_fooof_signal(nTrials=nTrials, nChannels=nChannels) - cfg = _get_mtmfft_cfg_without_selection() - assert not 'select' in cfg - selected_trials = [3, 5, 7] - - assert adt.selection is None - spy.selectdata(adt, trials=selected_trials, inplace=True) # Add in-place selection to input data. 
- assert adt.selection is not None - - out = freqanalysis(cfg, adt) - assert out.data.shape == (len(selected_trials), 1, data_len, nChannels), f"expected shape {(len(selected_trials), 1, data_len, nChannels)} but found out.data.shape={out.data.shape}" - - def test_selections_in_both_not_allowed(self): - data_len = 501 # length of spectral signal - nTrials = 20 - nChannels = 1 - adt = _get_fooof_signal(nTrials=nTrials, nChannels=nChannels) - cfg = _get_mtmfft_cfg_without_selection() - selected_trials = [3, 5, 7] - - cfg.select = { 'trials': selected_trials } - spy.selectdata(adt, trials=selected_trials, inplace=True) # Add in-place selection to input data. - - assert adt.selection is not None - assert 'select' in cfg - - with pytest.raises(SPYError, match="Selection found both"): - out = freqanalysis(cfg, adt) - #assert out.data.shape == (len(selected_trials), 1, data_len, nChannels), f"expected shape {(len(selected_trials), 1, data_len, nChannels)} but found out.data.shape={out.data.shape}" + trldef = np.vstack([np.arange(0, nSamples * nTrials, nSamples), + np.arange(0, nSamples * nTrials, nSamples) + nSamples, + np.ones(nTrials) * -1]).T + + # this is a running array with shape: nSamples*nTrials x nChannels + # and with data[i, j] = i+1 + j * nSamples*nTrials + + data = np.arange(1, nTrials * nChannels * nSamples + 1).reshape(nChannels, nSamples * nTrials).T + adata = spy.AnalogData(data=data, samplerate=samplerate, + trialdefinition=trldef) + + + # map selection keywords to selector attributes (holding the idx to access selected data) + map_sel_attr = dict(trials = 'trial_ids', + channel = 'channel', + latency = 'time', + ) + + def test_ad_selection(self): + + """ + Create a simple selection and test the returned data + """ + + selection = {'trials': 1, 'channel': [6, 2], 'latency': [0, 1]} + + # pick the data by hand, latency [0, 1] covers 2nd - 4th sample index + # as time axis is array([-0.5, 0. , 0.5, 1. 
, 1.5]) + + solution = T1.adata.data[self.nSamples : self.nSamples * 2] + solution = np.column_stack([solution[1:4, 6], solution[1:4, 2]]) + res = spy.selectdata(T1.adata, selection) + + assert np.all(solution == res.data) + + + def test_valid_ad(self): + + """ + Instantiate Selector class and check its only attributes (the idx) + """ + + # each selection test is a 2-tuple: (selection kwargs, dict with same kws and the idx "solutions") + valid_selections = [ + ({ + 'channel': ["channel03", "channel01"], + 'latency': [0, 1], + 'trials': np.arange(2)}, + { + # these are the idx used to access the actual data + 'channel': [2, 0], + 'latency': 2 * [slice(1, 4, 1)], + 'trials': [0, 1] + }), + ({ + # with some repetitions + 'channel': [7, 3, 3], + 'trials': [0, 1, 1]}, + { + 'channel': [7, 3, 3], + 'trials': [0, 1, 1] + }) + + ] + + for selection in valid_selections: + # instantiate Selector and check attributes + sel_kwargs, solution = selection + selector_object = Selector(self.adata, sel_kwargs) + for sel_kw in sel_kwargs.keys(): + attr_name = self.map_sel_attr[sel_kw] + assert getattr(selector_object, attr_name) == solution[sel_kw] + + def test_invalid_ad(self): + + # each selection test is a 3-tuple: (selection kwargs, Error, error message sub-string) + invalid_selections = [ + ({'channel': ["channel33", "channel01"]}, + SPYValueError, "existing names or indices"), + ({'channel': "my-non-existing-channel"}, + SPYValueError, "existing names or indices"), + ({'channel': 99}, + SPYValueError, "existing names or indices"), + ({'latency': 1}, SPYTypeError, "expected array_like"), + ({'latency': [0, 10]}, SPYValueError, "at least one trial covering the latency window"), + ({'latency': 'sth'}, SPYValueError, "'maxperiod'"), + ({'trials': [-3]}, SPYValueError, "all array elements to be bound"), + ({'trials': ['1', '6']}, SPYValueError, "expected dtype = numeric") + ] + + for selection in invalid_selections: + sel_kw, error, err_str = selection + with pytest.raises(error, match=err_str): + spy.selectdata(self.adata, sel_kw) + if __name__ == '__main__': - T1 = TestSelector() + T1 = Test_AD_Selections() From 349478a1d7d9d158af25d157ba96497fdfba201e Mon Sep 17 00:00:00 2001 From: tensionhead Date: Thu, 12 Jan 2023 17:42:28 +0100 Subject: [PATCH 078/135] WIP: new spectral data selector tests Changes to be committed: modified: syncopy/tests/test_selectdata.py --- syncopy/tests/test_selectdata.py | 202 ++++++++++++++++++++++--------- 1 file changed, 147 insertions(+), 55 deletions(-) diff --git a/syncopy/tests/test_selectdata.py b/syncopy/tests/test_selectdata.py index 828471b77..1ee0f4e9a 100644 --- a/syncopy/tests/test_selectdata.py +++ b/syncopy/tests/test_selectdata.py @@ -23,120 +23,212 @@ import syncopy as spy -# The procedure here is: -# (1) test if `Selector` instance was constructed correctly (i.e., indexing tuples -# look as expected, ordered list -> slice conversion works etc.) -# (2) test if data was correctly selected from source object (i.e., compare shapes, -# property contents and actual numeric data arrays) -# Multi-selections are not tested here but in the respective class tests (e.g., -# "time" + "channel" + "trial" `AnalogData` selections etc.) 
+# map selection keywords to selector attributes (holding the idx to access selected data) +map_sel_attr = dict(trials = 'trial_ids', + channel = 'channel', + latency = 'time', + taper = 'taper', + frequency = 'freq' + ) -class Test_AD_Selections(): +class TestAnalogSelections: - # Set up "global" parameters for data objects to be tested nChannels = 10 nSamples = 5 # per trial nTrials = 3 samplerate = 2.0 - data = {} - trl = {} trldef = np.vstack([np.arange(0, nSamples * nTrials, nSamples), np.arange(0, nSamples * nTrials, nSamples) + nSamples, np.ones(nTrials) * -1]).T - # this is a running array with shape: nSamples*nTrials x nChannels + # this is an array running from 1 - nChannels * nSamples * nTrials + # with shape: nSamples*nTrials x nChannels # and with data[i, j] = i+1 + j * nSamples*nTrials - data = np.arange(1, nTrials * nChannels * nSamples + 1).reshape(nChannels, nSamples * nTrials).T + adata = spy.AnalogData(data=data, samplerate=samplerate, trialdefinition=trldef) - - # map selection keywords to selector attributes (holding the idx to access selected data) - map_sel_attr = dict(trials = 'trial_ids', - channel = 'channel', - latency = 'time', - ) - def test_ad_selection(self): - """ - Create a simple selection and test the returned data + """ + Create a simple selection and check that the returned data is correct """ selection = {'trials': 1, 'channel': [6, 2], 'latency': [0, 1]} - + res = spy.selectdata(T1.adata, selection) + # pick the data by hand, latency [0, 1] covers 2nd - 4th sample index # as time axis is array([-0.5, 0. , 0.5, 1. , 1.5]) solution = T1.adata.data[self.nSamples : self.nSamples * 2] solution = np.column_stack([solution[1:4, 6], solution[1:4, 2]]) - res = spy.selectdata(T1.adata, selection) assert np.all(solution == res.data) - - def test_valid_ad(self): + + def test_ad_valid(self): """ Instantiate Selector class and check its only attributes (the idx) """ - + # each selection test is a 2-tuple: (selection kwargs, dict with same kws and the idx "solutions") valid_selections = [ - ({ - 'channel': ["channel03", "channel01"], - 'latency': [0, 1], - 'trials': np.arange(2)}, - { + ( + {'channel': ["channel03", "channel01"], + 'latency': [0, 1], + 'trials': np.arange(2)}, # these are the idx used to access the actual data - 'channel': [2, 0], - 'latency': 2 * [slice(1, 4, 1)], - 'trials': [0, 1] - }), - ({ - # with some repetitions - 'channel': [7, 3, 3], - 'trials': [0, 1, 1]}, - { - 'channel': [7, 3, 3], - 'trials': [0, 1, 1] - }) - - ] - + {'channel': [2, 0], + 'latency': 2 * [slice(1, 4, 1)], + 'trials': [0, 1]} + ), + ( + # 2nd selection with some repetitions + {'channel': [7, 3, 3], + 'trials': [0, 1, 1]}, + # 'solutions' + {'channel': [7, 3, 3], + 'trials': [0, 1, 1]} + ) + ] + + for selection in valid_selections: # instantiate Selector and check attributes sel_kwargs, solution = selection selector_object = Selector(self.adata, sel_kwargs) for sel_kw in sel_kwargs.keys(): - attr_name = self.map_sel_attr[sel_kw] + attr_name = map_sel_attr[sel_kw] assert getattr(selector_object, attr_name) == solution[sel_kw] - def test_invalid_ad(self): + def test_ad_invalid(self): # each selection test is a 3-tuple: (selection kwargs, Error, error message sub-string) invalid_selections = [ ({'channel': ["channel33", "channel01"]}, SPYValueError, "existing names or indices"), ({'channel': "my-non-existing-channel"}, - SPYValueError, "existing names or indices"), + SPYValueError, "existing names or indices"), ({'channel': 99}, - SPYValueError, "existing names or indices"), 
+ SPYValueError, "existing names or indices"), ({'latency': 1}, SPYTypeError, "expected array_like"), ({'latency': [0, 10]}, SPYValueError, "at least one trial covering the latency window"), - ({'latency': 'sth'}, SPYValueError, "'maxperiod'"), + ({'latency': 'sth-wrong'}, SPYValueError, "'maxperiod'"), ({'trials': [-3]}, SPYValueError, "all array elements to be bound"), - ({'trials': ['1', '6']}, SPYValueError, "expected dtype = numeric") + ({'trials': ['1', '6']}, SPYValueError, "expected dtype = numeric"), + ({'trials': slice(2)}, SPYTypeError, "expected serializable data type") ] for selection in invalid_selections: sel_kw, error, err_str = selection with pytest.raises(error, match=err_str): spy.selectdata(self.adata, sel_kw) - + +class TestSpectralSelections: + + nChannels = 3 + nSamples = 3 # per trial + nTrials = 3 + nTaper = 2 + nFreqs = 3 + samplerate = 2.0 + + trldef = np.vstack([np.arange(0, nSamples * nTrials, nSamples), + np.arange(0, nSamples * nTrials, nSamples) + nSamples, + np.ones(nTrials) * 2]).T + + # this is an array running from 1 - nChannels * nSamples * nTrials * nFreq * nTaper + data = np.arange(1, nChannels * nSamples * nTrials * nFreqs * nTaper + 1).reshape(nSamples * nTrials, nTaper, nFreqs, nChannels) + sdata = spy.SpectralData(data=data, samplerate=samplerate, + trialdefinition=trldef) + # freq labels + sdata.freq = [20, 40, 60] + + def test_spectral_selection(self): + + """ + Create a simple selection and check that the returned data is correct + """ + + selection = {'trials': 1, + 'channel': [1, 0], + 'latency': [1, 1.5], + 'frequency': [25, 50]} + res = spy.selectdata(T2.sdata, selection) + + # pick the data by hand, dimord is: ['time', 'taper', 'freq', 'channel'] + # latency [1, 1.5] covers 1st - 2nd sample index + # as time axis is array([1., 1.5, 2.]) + # frequency covers only 2nd index (40 Hz) + + # pick trial + solution = T2.sdata.data[self.nSamples : self.nSamples * 2] + # pick channels, frequency and latency and re-stack + solution = np.stack([solution[:2, :, [1], 1], solution[:2, :, [1], 0]], axis=-1) + + assert np.all(solution == res.data) + + def test_spectral_valid(self): + + """ + Instantiate Selector class and check its only attributes (the idx) + test mainly additional dimensions (taper and freq) here + """ + + # each selection test is a 2-tuple: (selection kwargs, dict with same kws and the idx "solutions") + valid_selections = [ + ( + {'frequency': np.array([30, 60]), + 'taper': [1, 0]}, + # the 'solutions' + {'frequency': slice(1, 3, 1), + 'taper': [1, 0]}, + ), + # 2nd selection + ( + {'frequency': 'all', + 'taper': 'taper2', + 'latency': [1.2, 1.7], + 'trials': np.arange(1,3)}, + # the 'solutions' + {'frequency': slice(None), + 'taper': [1], + 'latency': [[1], [1]], + 'trials': [1, 2]}, + ) + ] + + for selection in valid_selections: + # instantiate Selector and check attributes + sel_kwargs, solution = selection + selector_object = Selector(self.sdata, sel_kwargs) + for sel_kw in sel_kwargs.keys(): + attr_name = map_sel_attr[sel_kw] + assert getattr(selector_object, attr_name) == solution[sel_kw] + + def test_spectral_invalid(self): + + # each selection test is a 3-tuple: (selection kwargs, Error, error message sub-string) + invalid_selections = [ + ({'frequency': '40Hz'}, SPYValueError, "'all' or `None` or float or list/array"), + ({'frequency': 4}, SPYValueError, "all array elements to be bounded"), + ({'frequency': slice(None)}, SPYTypeError, "expected serializable data type"), + ({'frequency': range(20,60)}, SPYTypeError, "expected 
array_like"), + ({'frequency': np.arange(20,60)}, SPYValueError, "expected array of shape"), + ({'taper': 'taper13'}, SPYValueError, "existing names or indices"), + ({'taper': [18, 99]}, SPYValueError, "existing names or indices"), + ] + + for selection in invalid_selections: + sel_kw, error, err_str = selection + with pytest.raises(error, match=err_str): + spy.selectdata(self.sdata, sel_kw) + if __name__ == '__main__': - T1 = Test_AD_Selections() + T1 = TestAnalogSelections() + T2 = TestSpectralSelections() From 241226c4e77751af6c6a87773cc7319bf7efb0a0 Mon Sep 17 00:00:00 2001 From: tensionhead Date: Thu, 12 Jan 2023 18:05:21 +0100 Subject: [PATCH 079/135] WIP: csd selection tests Changes to be committed: modified: syncopy/tests/test_selectdata.py Changes to be committed: modified: syncopy/tests/test_selectdata.py --- syncopy/tests/test_selectdata.py | 102 ++++++++++++++++++++++++++++--- 1 file changed, 94 insertions(+), 8 deletions(-) diff --git a/syncopy/tests/test_selectdata.py b/syncopy/tests/test_selectdata.py index 1ee0f4e9a..4c82dd9d8 100644 --- a/syncopy/tests/test_selectdata.py +++ b/syncopy/tests/test_selectdata.py @@ -80,7 +80,7 @@ def test_ad_valid(self): ( {'channel': ["channel03", "channel01"], 'latency': [0, 1], - 'trials': np.arange(2)}, + 'trials': np.arange(2)}, # these are the idx used to access the actual data {'channel': [2, 0], 'latency': 2 * [slice(1, 4, 1)], @@ -160,7 +160,7 @@ def test_spectral_selection(self): 'latency': [1, 1.5], 'frequency': [25, 50]} res = spy.selectdata(T2.sdata, selection) - + # pick the data by hand, dimord is: ['time', 'taper', 'freq', 'channel'] # latency [1, 1.5] covers 1st - 2nd sample index # as time axis is array([1., 1.5, 2.]) @@ -180,7 +180,7 @@ def test_spectral_valid(self): test mainly additional dimensions (taper and freq) here """ - # each selection test is a 2-tuple: (selection kwargs, dict with same kws and the idx "solutions") + # each selection test is a 2-tuple: (selection kwargs, dict with same kws and the idx "solutions") valid_selections = [ ( {'frequency': np.array([30, 60]), @@ -200,9 +200,9 @@ def test_spectral_valid(self): 'taper': [1], 'latency': [[1], [1]], 'trials': [1, 2]}, - ) + ) ] - + for selection in valid_selections: # instantiate Selector and check attributes sel_kwargs, solution = selection @@ -218,17 +218,103 @@ def test_spectral_invalid(self): ({'frequency': '40Hz'}, SPYValueError, "'all' or `None` or float or list/array"), ({'frequency': 4}, SPYValueError, "all array elements to be bounded"), ({'frequency': slice(None)}, SPYTypeError, "expected serializable data type"), - ({'frequency': range(20,60)}, SPYTypeError, "expected array_like"), + ({'frequency': range(20,60)}, SPYTypeError, "expected array_like"), ({'frequency': np.arange(20,60)}, SPYValueError, "expected array of shape"), ({'taper': 'taper13'}, SPYValueError, "existing names or indices"), - ({'taper': [18, 99]}, SPYValueError, "existing names or indices"), + ({'taper': [18, 99]}, SPYValueError, "existing names or indices"), ] for selection in invalid_selections: sel_kw, error, err_str = selection with pytest.raises(error, match=err_str): spy.selectdata(self.sdata, sel_kw) - + +class TestCrossSpectralSelections: + + nChannels = 3 + nSamples = 3 # per trial + nTrials = 3 + nFreqs = 3 + samplerate = 2.0 + + trldef = np.vstack([np.arange(0, nSamples * nTrials, nSamples), + np.arange(0, nSamples * nTrials, nSamples) + nSamples, + np.ones(nTrials) * 2]).T + + # this is an array running from 1 - nChannels * nSamples * nTrials * nFreq * nTaper + data 
= np.arange(1, nChannels**2 * nSamples * nTrials * nFreqs + 1).reshape(nSamples * nTrials, nFreqs, nChannels, nChannels) + csd_data = spy.CrossSpectralData(data=data, samplerate=samplerate) + csd_data.trialdefinition = trldef + + # freq labels + csd_data.freq = [20, 40, 60] + + def test_csd_selection(self): + + """ + Create a simple selection and check that the returned data is correct + """ + + selection = {'trials': 1, + 'channel_i': [0, 1], + 'latency': [1.5, 2], + 'frequency': [25, 60]} + + res = spy.selectdata(self.csd_data, selection) + + # pick the data by hand, dimord is: ['time', 'freq', 'channel_i', 'channel_j'] + # latency [1, 1.5] covers 2nd - 3rd sample index + # as time axis is array([1., 1.5, 2.]) + # frequency covers 2nd and 3rd index (40 and 60Hz) + + # pick trial + solution = self.csd_data.data[self.nSamples : self.nSamples * 2] + # pick channels, frequency and latency and re-stack + solution = solution[1:3, 1:3, :2, :] + assert np.all(solution == res.data) + + + def test_csd_valid(self): + + """ + Instantiate Selector class and check its only attributes (the idx) + test mainly additional dimensions (channel_i, channel_j) here + """ + + # each selection test is a 2-tuple: (selection kwargs, dict with same kws and the idx "solutions") + valid_selections = [ + ( + {'frequency': np.array([30, 60]), + 'taper': [1, 0]}, + # the 'solutions' + {'frequency': slice(1, 3, 1), + 'taper': [1, 0]}, + ), + # 2nd selection + ( + {'frequency': 'all', + 'taper': 'taper2', + 'latency': [1.2, 1.7], + 'trials': np.arange(1,3)}, + # the 'solutions' + {'frequency': slice(None), + 'taper': [1], + 'latency': [[1], [1]], + 'trials': [1, 2]}, + ) + ] + + for selection in valid_selections: + # instantiate Selector and check attributes + sel_kwargs, solution = selection + selector_object = Selector(self.sdata, sel_kwargs) + for sel_kw in sel_kwargs.keys(): + attr_name = map_sel_attr[sel_kw] + assert getattr(selector_object, attr_name) == solution[sel_kw] + + + if __name__ == '__main__': T1 = TestAnalogSelections() T2 = TestSpectralSelections() + T3 = TestCrossSpectralSelections() From 4fba0038b5b1071df4aebf488aab8e6a75a7f520 Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Fri, 13 Jan 2023 10:48:50 +0100 Subject: [PATCH 080/135] CHG: remove extra exception handler, already in place below --- syncopy/__init__.py | 9 --------- syncopy/shared/errors.py | 2 +- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/syncopy/__init__.py b/syncopy/__init__.py index 6447958e0..7124872b6 100644 --- a/syncopy/__init__.py +++ b/syncopy/__init__.py @@ -166,15 +166,6 @@ def filter(self, record): spy_parallel_logger.addHandler(fhp) spy_parallel_logger.info(f"Syncopy parallel logger '{parallel_logger_name}' setup to log to file '{logfile_par}' at level {loglevel}.") -## Setup global handler to log uncaught exceptions: -def handle_exception(exc_type, exc_value, exc_traceback): - if issubclass(exc_type, KeyboardInterrupt): - sys.__excepthook__(exc_type, exc_value, exc_traceback) - return - spy_parallel_logger.critical("Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback)) - -sys.excepthook = handle_exception # TODO: this may get overwritten below with SPYExceptionHandler, should log in there. 
- # Set upper bound for temp directory size (in GB) __storagelimit__ = 10 diff --git a/syncopy/shared/errors.py b/syncopy/shared/errors.py index dfa3ef53d..111633858 100644 --- a/syncopy/shared/errors.py +++ b/syncopy/shared/errors.py @@ -195,7 +195,7 @@ def SPYExceptionHandler(*excargs, **exckwargs): # Show generated message and leave (or kick-off debugging in Jupyer/iPython if %pdb is on) logger = get_logger() - logger.error(emsg) + logger.critical(emsg) if isipy: if ipy.call_pdb: ipy.InteractiveTB.debugger() From 21b6ebd32e0670a3a4b6deb3e0b3b26808f89fb7 Mon Sep 17 00:00:00 2001 From: kajal5888 Date: Fri, 13 Jan 2023 11:29:52 +0100 Subject: [PATCH 081/135] update of cfg removed --- syncopy/nwanalysis/connectivity_analysis.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/syncopy/nwanalysis/connectivity_analysis.py b/syncopy/nwanalysis/connectivity_analysis.py index 61555162e..eba8cdc9d 100644 --- a/syncopy/nwanalysis/connectivity_analysis.py +++ b/syncopy/nwanalysis/connectivity_analysis.py @@ -381,9 +381,9 @@ def connectivityanalysis(data, method="coh", keeptrials=False, output="abs", # Sanitize output and call the chosen ComputationalRoutine on the averaged ST output # ---------------------------------------------------------------------------------- if method == 'csd': - new_cfg.update({'output': st_out.data.dtype.name}) + # new_cfg.update({'output': st_out.data.dtype.name}) st_out.cfg.update(data.cfg) - st_out.cfg.update({'cross_spectral': new_cfg}) + st_out.cfg.update({'connectivityanalysis': new_cfg}) return st_out else: out = CrossSpectralData(dimord=st_dimord) From 897056815cd26868cdbaeb859b890cc4ce2c62df Mon Sep 17 00:00:00 2001 From: kajal5888 Date: Fri, 13 Jan 2023 11:31:08 +0100 Subject: [PATCH 082/135] Spectral input check done, parallel and config test failing- need help --- syncopy/tests/test_connectivity.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/syncopy/tests/test_connectivity.py b/syncopy/tests/test_connectivity.py index 9b4d2fe15..7cfb2a52b 100644 --- a/syncopy/tests/test_connectivity.py +++ b/syncopy/tests/test_connectivity.py @@ -554,6 +554,31 @@ class TestCSD: def test_data_output_type(self): assert self.spec.data.dtype.name == 'complex64' + # @skip_low_mem + # def test_csd_parallel(T, testcluster=None): + + # ppl.ioff() + # client = dd.Client(testcluster) + # # all_tests = [attr for attr in self.__dir__() + # # if (inspect.ismethod(getattr(self, attr)) and 'parallel' not in attr)] + + # # for test in all_tests: + # # test_method = getattr(self, test) + # # test_method() + # client.close() + # ppl.ion() + + # def test_csd_cfg(self): + + # call = lambda cfg: cafunc(self.spec, cfg) + + # run_cfg_test(call, method='csd', + # cfg=get_defaults(cafunc)) + + def test_csd_input(self): + + assert not isinstance(self.spec, SpectralData) + class TestCorrelation: From 8b1ad4e1d8323a8da511b0b394f87caf52a2ccff Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Fri, 13 Jan 2023 12:19:38 +0100 Subject: [PATCH 083/135] CHG: adapt logging test --- syncopy/tests/test_logging.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/syncopy/tests/test_logging.py b/syncopy/tests/test_logging.py index 2a6222ff6..3fb53a3f2 100644 --- a/syncopy/tests/test_logging.py +++ b/syncopy/tests/test_logging.py @@ -18,9 +18,13 @@ def test_logfile_exists(self): assert os.path.isfile(logfile) def test_default_log_level_is_warning(self): + + # Ensure the log level is at default (that user did not change SPYLOGLEVEL on test 
system) + assert os.getenv("SPYLOGLEVEL", "WARNING") == "WARNING" + logfile = os.path.join(spy.__logdir__, "syncopy.log") assert os.path.isfile(logfile) - num_lines_bofore = sum(1 for line in open(logfile)) + num_lines_initial = sum(1 for line in open(logfile)) # The log file gets appended, so it will most likely *not* be empty. # Log something with log level info and DEBUG, which should not affect the logfile. logger = get_logger() @@ -29,13 +33,13 @@ def test_default_log_level_is_warning(self): num_lines_after_info_debug = sum(1 for line in open(logfile)) - assert num_lines_bofore == num_lines_after_info_debug + assert num_lines_initial == num_lines_after_info_debug # Now log something with log level WARNING SPYLog("I am adding a WARNING level log entry.", loglevel="WARNING") num_lines_after_warning = sum(1 for line in open(logfile)) - assert num_lines_after_info_debug + 1 == num_lines_after_warning + assert num_lines_after_warning > num_lines_after_info_debug From 4a05fb20ec99fe39718951eff08ce06733010a35 Mon Sep 17 00:00:00 2001 From: kajal5888 Date: Fri, 13 Jan 2023 13:37:48 +0100 Subject: [PATCH 084/135] uncommented tests --- syncopy/tests/test_connectivity.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/syncopy/tests/test_connectivity.py b/syncopy/tests/test_connectivity.py index 7cfb2a52b..6b9829efd 100644 --- a/syncopy/tests/test_connectivity.py +++ b/syncopy/tests/test_connectivity.py @@ -554,26 +554,26 @@ class TestCSD: def test_data_output_type(self): assert self.spec.data.dtype.name == 'complex64' - # @skip_low_mem - # def test_csd_parallel(T, testcluster=None): + @skip_low_mem + def test_csd_parallel(self, testcluster=None): - # ppl.ioff() - # client = dd.Client(testcluster) - # # all_tests = [attr for attr in self.__dir__() - # # if (inspect.ismethod(getattr(self, attr)) and 'parallel' not in attr)] + ppl.ioff() + client = dd.Client(testcluster) + all_tests = [attr for attr in self.__dir__() + if (inspect.ismethod(getattr(self, attr)) and 'parallel' not in attr)] - # # for test in all_tests: - # # test_method = getattr(self, test) - # # test_method() - # client.close() - # ppl.ion() + for test in all_tests: + test_method = getattr(self, test) + test_method() + client.close() + ppl.ion() - # def test_csd_cfg(self): + def test_csd_cfg(self): - # call = lambda cfg: cafunc(self.spec, cfg) + call = lambda cfg: cafunc(self.spec, cfg) - # run_cfg_test(call, method='csd', - # cfg=get_defaults(cafunc)) + run_cfg_test(call, method='csd', + cfg=get_defaults(cafunc)) def test_csd_input(self): From 4ad1e2f3d706afcf2ab857e7600f3aaf099cb0ba Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Fri, 13 Jan 2023 13:59:08 +0100 Subject: [PATCH 085/135] NEW: disable tqdm unless in TTY --- syncopy/io/load_ft.py | 2 +- syncopy/io/load_nwb.py | 4 ++-- syncopy/io/load_tdt.py | 2 +- syncopy/io/utils.py | 6 +++--- syncopy/shared/computational_routine.py | 2 +- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/syncopy/io/load_ft.py b/syncopy/io/load_ft.py index 479f36994..a94d3523a 100644 --- a/syncopy/io/load_ft.py +++ b/syncopy/io/load_ft.py @@ -288,7 +288,7 @@ def _read_hdf_structure(h5Group, dtype=np.float32, shape=[nTotalSamples, nChannels]) - pbar = tqdm(trl_refs, desc=f"{struct_name} - loading {nTrials} trials") + pbar = tqdm(trl_refs, desc=f"{struct_name} - loading {nTrials} trials", disable=None) SampleCounter = 0 # trial stacking # one swipe per trial diff --git a/syncopy/io/load_nwb.py b/syncopy/io/load_nwb.py index 
1c51b547e..670b69bfe 100644 --- a/syncopy/io/load_nwb.py +++ b/syncopy/io/load_nwb.py @@ -228,7 +228,7 @@ def load_nwb(filename, memuse=3000, container=None): memuse *= 1024**2 # Process analog time series data and convert stuff block by block (if necessary) - pbar = tqdm(angSeries, position=0) + pbar = tqdm(angSeries, position=0, disable=None) for acqValue in pbar: # Show dataset name in progress bar label pbar.set_description("Loading {} from disk".format(acqValue.name)) @@ -261,7 +261,7 @@ def load_nwb(filename, memuse=3000, container=None): rem = int(angDset.shape[0] % nSamp) blockList = [nSamp] * int(angDset.shape[0] // nSamp) + [rem] * int(rem > 0) - for m, M in enumerate(tqdm(blockList, desc=pbarDesc, position=1, leave=False)): + for m, M in enumerate(tqdm(blockList, desc=pbarDesc, position=1, leave=False, disable=None)): st_samp, end_samp = m * nSamp, m * nSamp + M angDset[st_samp : end_samp, :] = acqValue.data[st_samp : end_samp, :] if acqValue.channel_conversion is not None: diff --git a/syncopy/io/load_tdt.py b/syncopy/io/load_tdt.py index 5ba9cc69b..383b35ab5 100644 --- a/syncopy/io/load_tdt.py +++ b/syncopy/io/load_tdt.py @@ -737,7 +737,7 @@ def data_aranging(self, Files, DataInfo_loaded): len(Files), len(idxStartStop), self.chan_in_chunks, hdf_out_path ) ) - for (start, stop) in tqdm(iterable=idxStartStop, desc="chunk", unit="chunk"): + for (start, stop) in tqdm(iterable=idxStartStop, desc="chunk", unit="chunk", disable=None): data = [self.read_data(Files[jj]) for jj in range(start, stop)] data = np.vstack(data).T if start == 0: diff --git a/syncopy/io/utils.py b/syncopy/io/utils.py index 338e3d1bb..64f798e21 100644 --- a/syncopy/io/utils.py +++ b/syncopy/io/utils.py @@ -235,18 +235,18 @@ def cleanup(older_than=24, interactive=True): # Delete all session-remains at once elif choice == "S": - for fls in tqdm(flsList, desc="Deleting session data..."): + for fls in tqdm(flsList, desc="Deleting session data...", disable=None): _rm_session(fls) # Deleate all dangling files at once elif choice == "D": - for dat in tqdm(dangling, desc="Deleting dangling data..."): + for dat in tqdm(dangling, desc="Deleting dangling data...", disable=None): _rm_session([dat]) # Delete everything elif choice == "R": for contents in tqdm(flsList + [[dat] for dat in dangling], - desc="Deleting temporary data..."): + desc="Deleting temporary data...", disable=None): _rm_session(contents) # Don't do anything for now, continue w/dangling data diff --git a/syncopy/shared/computational_routine.py b/syncopy/shared/computational_routine.py index 3cf860c28..5f72f7da4 100644 --- a/syncopy/shared/computational_routine.py +++ b/syncopy/shared/computational_routine.py @@ -930,7 +930,7 @@ def compute_sequential(self, data, out): with h5py.File(out.filename, "r+") as h5fout: target = h5fout[self.outDatasetName] - for nblock in tqdm(range(self.numTrials), bar_format=self.tqdmFormat): + for nblock in tqdm(range(self.numTrials), bar_format=self.tqdmFormat, disable=None): # Extract respective indexing tuples from constructed lists ingrid = self.sourceLayout[nblock] From 296d682a1ea51bf16fbcb907421e097baf4cdbc5 Mon Sep 17 00:00:00 2001 From: tensionhead Date: Fri, 13 Jan 2023 14:19:33 +0100 Subject: [PATCH 086/135] FIX: Preserve trial selection order upon latency selection Changes to be committed: modified: syncopy/datatype/base_data.py modified: syncopy/shared/latency.py modified: syncopy/tests/test_selectdata.py --- syncopy/datatype/base_data.py | 2 +- syncopy/shared/latency.py | 7 +++++-- 
syncopy/tests/test_selectdata.py | 24 +++++++++++++-----------
 3 files changed, 19 insertions(+), 14 deletions(-)

diff --git a/syncopy/datatype/base_data.py b/syncopy/datatype/base_data.py
index ba9fc4c9a..832e73531 100644
--- a/syncopy/datatype/base_data.py
+++ b/syncopy/datatype/base_data.py
@@ -1557,7 +1557,7 @@ def __init__(self, data, select):
         )

         # We first need to know which trials are of interest here (assuming
-        # that any valid input object *must* have a `trials` attribute)
+        # that any valid input object *must* have a `trial_ids` attribute)
         self.trial_ids = (data, select)

         # Now set any possible selection attribute (depending on type of `data`)
diff --git a/syncopy/shared/latency.py b/syncopy/shared/latency.py
index b33546a14..aac10e077 100644
--- a/syncopy/shared/latency.py
+++ b/syncopy/shared/latency.py
@@ -121,7 +121,7 @@ def create_trial_selection(data, window):
     # beginnings and ends of all (selected) trials in trigger-relative time in seconds
     if data.selection is not None:
         trl_starts, trl_ends = data.selection.trialintervals[:, 0], data.selection.trialintervals[:, 1]
-        trl_idx = np.arange(len(data.selection.trials))
+        trl_idx = np.array(data.selection.trial_ids)
     else:
         trl_starts, trl_ends = data.trialintervals[:, 0], data.trialintervals[:, 1]
         trl_idx = np.arange(len(data.trials))
@@ -143,7 +143,10 @@ def create_trial_selection(data, window):
         else:
             sel_ids = np.array(data.selection.trial_ids)[bmask]
         # match fitting trials with selected ones
-        fit_trl_idx = np.intersect1d(data.selection.trial_ids, sel_ids)
+        intersection = np.intersect1d(data.selection.trial_ids, sel_ids)
+        # intersect result is sorted, restore original selection order
+        fit_trl_idx = np.array(
+            [trl_id for trl_id in data.selection.trial_ids if trl_id in intersection])
         numDiscard = len(data.selection.trial_ids) - len(fit_trl_idx)

         if fit_trl_idx.size == 0:
diff --git a/syncopy/tests/test_selectdata.py b/syncopy/tests/test_selectdata.py
index 4c82dd9d8..e63cf3478 100644
--- a/syncopy/tests/test_selectdata.py
+++ b/syncopy/tests/test_selectdata.py
@@ -58,12 +58,12 @@ def test_ad_selection(self):
         """

         selection = {'trials': 1, 'channel': [6, 2], 'latency': [0, 1]}
-        res = spy.selectdata(T1.adata, selection)
+        res = spy.selectdata(self.adata, selection)

         # pick the data by hand, latency [0, 1] covers 2nd - 4th sample index
         # as time axis is array([-0.5, 0. , 0.5, 1. 
, 1.5]) - solution = T1.adata.data[self.nSamples : self.nSamples * 2] + solution = self.adata.data[self.nSamples : self.nSamples * 2] solution = np.column_stack([solution[1:4, 6], solution[1:4, 2]]) assert np.all(solution == res.data) @@ -159,7 +159,7 @@ def test_spectral_selection(self): 'channel': [1, 0], 'latency': [1, 1.5], 'frequency': [25, 50]} - res = spy.selectdata(T2.sdata, selection) + res = spy.selectdata(self.sdata, selection) # pick the data by hand, dimord is: ['time', 'taper', 'freq', 'channel'] # latency [1, 1.5] covers 1st - 2nd sample index @@ -167,7 +167,7 @@ def test_spectral_selection(self): # frequency covers only 2nd index (40 Hz) # pick trial - solution = T2.sdata.data[self.nSamples : self.nSamples * 2] + solution = self.sdata.data[self.nSamples : self.nSamples * 2] # pick channels, frequency and latency and re-stack solution = np.stack([solution[:2, :, [1], 1], solution[:2, :, [1], 0]], axis=-1) @@ -229,6 +229,7 @@ def test_spectral_invalid(self): with pytest.raises(error, match=err_str): spy.selectdata(self.sdata, sel_kw) + class TestCrossSpectralSelections: nChannels = 3 @@ -255,22 +256,23 @@ def test_csd_selection(self): Create a simple selection and check that the returned data is correct """ - selection = {'trials': 1, + selection = {'trials': [1, 0], 'channel_i': [0, 1], 'latency': [1.5, 2], 'frequency': [25, 60]} res = spy.selectdata(self.csd_data, selection) - # pick the data by hand, dimord is: ['time', 'freq', 'channel_i', 'channel_j'] # latency [1, 1.5] covers 2nd - 3rd sample index # as time axis is array([1., 1.5, 2.]) # frequency covers 2nd and 3rd index (40 and 60Hz) - # pick trial - solution = self.csd_data.data[self.nSamples : self.nSamples * 2] - # pick channels, frequency and latency and re-stack - solution = solution[1:3, 1:3, :2, :] + # pick trials + solution = np.concatenate([self.csd_data.data[self.nSamples: self.nSamples * 2], + self.csd_data.data[: self.nSamples]], axis=0) + + # pick channels, frequency and latency + solution = np.concatenate([solution[1:3, 1:3, :2, :], solution[4:6, 1:3, :2, :]]) assert np.all(solution == res.data) @@ -307,7 +309,7 @@ def test_csd_valid(self): for selection in valid_selections: # instantiate Selector and check attributes sel_kwargs, solution = selection - selector_object = Selector(self.sdata, sel_kwargs) + selector_object = Selector(self.csd_data, sel_kwargs) for sel_kw in sel_kwargs.keys(): attr_name = map_sel_attr[sel_kw] assert getattr(selector_object, attr_name) == solution[sel_kw] From dc8ece27f2db0bf1674eefd42db36182d2598bee Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Fri, 13 Jan 2023 15:15:35 +0100 Subject: [PATCH 087/135] CHG: log exceptions to parallel logger, they may come from parallel code --- syncopy/shared/errors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/syncopy/shared/errors.py b/syncopy/shared/errors.py index 111633858..a698fd293 100644 --- a/syncopy/shared/errors.py +++ b/syncopy/shared/errors.py @@ -194,7 +194,7 @@ def SPYExceptionHandler(*excargs, **exckwargs): cols.Normal if isipy else "") # Show generated message and leave (or kick-off debugging in Jupyer/iPython if %pdb is on) - logger = get_logger() + logger = get_parallel_logger() logger.critical(emsg) if isipy: if ipy.call_pdb: From b2ae72eb4784a4f31ac3cdd68edffc29268e09a1 Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Fri, 13 Jan 2023 15:17:48 +0100 Subject: [PATCH 088/135] CHG: log exceptions to parallel logger, they may come from parallel code --- syncopy/shared/errors.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/syncopy/shared/errors.py b/syncopy/shared/errors.py index a698fd293..8e185cf88 100644 --- a/syncopy/shared/errors.py +++ b/syncopy/shared/errors.py @@ -275,7 +275,7 @@ def SPYExceptionHandler(*excargs, **exckwargs): # Show generated message and get outta here - logger = get_logger() + logger = get_parallel_logger() logger.critical(emsg) # Kick-start debugging in case %pdb is enabled in Jupyter/iPython From 5f07f89f6b55da3c0f95f0acb1fa5553485db017 Mon Sep 17 00:00:00 2001 From: kajal5888 Date: Fri, 13 Jan 2023 15:59:39 +0100 Subject: [PATCH 089/135] Test checking --- syncopy/tests/test_connectivity.py | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/syncopy/tests/test_connectivity.py b/syncopy/tests/test_connectivity.py index 6b9829efd..f1e4be96d 100644 --- a/syncopy/tests/test_connectivity.py +++ b/syncopy/tests/test_connectivity.py @@ -547,12 +547,13 @@ class TestCSD: cfg = spy.StructDict() cfg.tapsmofrq = 1.5 cfg.foilim = [5, 60] - cfg.method = 'csd' - spec = spy.connectivityanalysis(data, cfg) + spec = spy.freqanalysis(data, cfg, output='fourier', keeptapers=True) def test_data_output_type(self): - assert self.spec.data.dtype.name == 'complex64' + cross_spec = spy.connectivityanalysis(self.spec, method='csd') + assert np.all(self.spec.freq == cross_spec.freq) + assert cross_spec.data.dtype.name == 'complex64' @skip_low_mem def test_csd_parallel(self, testcluster=None): @@ -572,8 +573,7 @@ def test_csd_cfg(self): call = lambda cfg: cafunc(self.spec, cfg) - run_cfg_test(call, method='csd', - cfg=get_defaults(cafunc)) + run_cfg_test(call, method='csd', cfg=get_defaults(cafunc)) def test_csd_input(self): @@ -731,6 +731,26 @@ def test_corr_polyremoval(self): helpers.run_polyremoval_test(call) +def run_csd_cfg_test(method_call, method, cfg, positivity=True): + + cfg.method = method + if method != 'granger': + cfg.frequency = [0, 70] + # test general tapers with + # additional parameters + cfg.taper = 'kaiser' + cfg.taper_opt = {'beta': 2} + + cfg.output = 'abs' + + result = method_call(cfg) + + # check here just for finiteness and positivity + assert np.all(np.isfinite(result.data)) + if positivity: + assert np.all(result.data[0, ...] 
>= -1e-10) + + def run_cfg_test(method_call, method, cfg, positivity=True): cfg.method = method From 671b76ce698cdef4c25ef2f6d1e9e9e2ccffd16c Mon Sep 17 00:00:00 2001 From: tensionhead Date: Fri, 13 Jan 2023 16:50:24 +0100 Subject: [PATCH 090/135] WIP: CSD and Spike selection tests Changes to be committed: modified: syncopy/datatype/base_data.py modified: syncopy/datatype/methods/selectdata.py modified: syncopy/tests/test_selectdata.py --- syncopy/datatype/base_data.py | 2 +- syncopy/datatype/methods/selectdata.py | 6 +- syncopy/tests/test_selectdata.py | 168 +++++++++++++++++++------ 3 files changed, 138 insertions(+), 38 deletions(-) diff --git a/syncopy/datatype/base_data.py b/syncopy/datatype/base_data.py index 832e73531..6442f59b9 100644 --- a/syncopy/datatype/base_data.py +++ b/syncopy/datatype/base_data.py @@ -2156,7 +2156,7 @@ def _selection_setter(self, data, select, selectkey): "channel_j", ]: if len(idxList) > 1: - err = "Multi-channel-pair selections not supported" + err = "Unordered (low to high) or non-contiguous multi-channel-pair selections not supported" raise NotImplementedError(err) idxList = idxList[0] diff --git a/syncopy/datatype/methods/selectdata.py b/syncopy/datatype/methods/selectdata.py index 2f49cc43a..25ae7a761 100644 --- a/syncopy/datatype/methods/selectdata.py +++ b/syncopy/datatype/methods/selectdata.py @@ -285,10 +285,12 @@ def selectdata(data, "'".join(key + "', " for key in kwargs.keys())[:-2] raise SPYValueError(legal=lgl, varname="selection kwargs", actual=act) - # warn the user for ineffective selection keywords, e.g. 'frequency' for AnalogData + # get out if unsuitable selection keywords given, e.g. 'frequency' for AnalogData for key, value in selectDict.items(): if key not in expected and value is not None: - SPYWarning(f"No {key} selection available for {data.__class__.__name__}") + lgl = f"one of {data.__class__._selectionKeyWords}" + act = f"no `{key}` selection available for {data.__class__.__name__}" + raise SPYValueError(lgl, 'selection arguments', act) # now just keep going with the selection keys relevant for that particular data type selectDict = {key: selectDict[key] for key in data._selectionKeyWords} diff --git a/syncopy/tests/test_selectdata.py b/syncopy/tests/test_selectdata.py index e63cf3478..3f9091400 100644 --- a/syncopy/tests/test_selectdata.py +++ b/syncopy/tests/test_selectdata.py @@ -17,7 +17,6 @@ from syncopy.datatype.base_data import Selector from syncopy.datatype.methods.selectdata import selectdata from syncopy.shared.errors import SPYError, SPYValueError, SPYTypeError -from syncopy.tests.test_specest_fooof import _get_fooof_signal from syncopy.shared.tools import StructDict from syncopy import freqanalysis @@ -28,10 +27,32 @@ channel = 'channel', latency = 'time', taper = 'taper', - frequency = 'freq' + frequency = 'freq', + channel_i = 'channel_i', + channel_j = 'channel_j' ) +class TestGeneral: + + adata = spy.AnalogData(data=np.ones((2, 2)), samplerate=1) + csd_data = spy.CrossSpectralData(data=np.ones((2, 2, 2, 2)), samplerate=1) + + def test_Selector_init(self): + + with pytest.raises(SPYTypeError, match="Wrong type of `data`"): + Selector(np.arange(10), latency=[0, 4]) + + def test_invalid_sel_key(self): + + # AnalogData has no `frequency` + with pytest.raises(SPYValueError, match="no `frequency` selection available"): + spy.selectdata(self.adata, frequency=[1, 10]) + # CrossSpectralData has no `channel` (but channel_i, channel_j) + with pytest.raises(SPYValueError, match="no `channel` selection available"): + 
spy.selectdata(self.csd_data, channel=0) + + class TestAnalogSelections: nChannels = 10 @@ -68,7 +89,6 @@ def test_ad_selection(self): assert np.all(solution == res.data) - def test_ad_valid(self): """ @@ -183,23 +203,23 @@ def test_spectral_valid(self): # each selection test is a 2-tuple: (selection kwargs, dict with same kws and the idx "solutions") valid_selections = [ ( - {'frequency': np.array([30, 60]), - 'taper': [1, 0]}, - # the 'solutions' - {'frequency': slice(1, 3, 1), - 'taper': [1, 0]}, + {'frequency': np.array([30, 60]), + 'taper': [1, 0]}, + # the 'solutions' + {'frequency': slice(1, 3, 1), + 'taper': [1, 0]}, ), # 2nd selection ( - {'frequency': 'all', - 'taper': 'taper2', - 'latency': [1.2, 1.7], - 'trials': np.arange(1,3)}, - # the 'solutions' - {'frequency': slice(None), - 'taper': [1], - 'latency': [[1], [1]], - 'trials': [1, 2]}, + {'frequency': 'all', + 'taper': 'taper2', + 'latency': [1.2, 1.7], + 'trials': np.arange(1,3)}, + # the 'solutions' + {'frequency': slice(None), + 'taper': [1], + 'latency': [[1], [1]], + 'trials': [1, 2]}, ) ] @@ -262,6 +282,7 @@ def test_csd_selection(self): 'frequency': [25, 60]} res = spy.selectdata(self.csd_data, selection) + # pick the data by hand, dimord is: ['time', 'freq', 'channel_i', 'channel_j'] # latency [1, 1.5] covers 2nd - 3rd sample index # as time axis is array([1., 1.5, 2.]) @@ -275,7 +296,6 @@ def test_csd_selection(self): solution = np.concatenate([solution[1:3, 1:3, :2, :], solution[4:6, 1:3, :2, :]]) assert np.all(solution == res.data) - def test_csd_valid(self): """ @@ -286,23 +306,16 @@ def test_csd_valid(self): # each selection test is a 2-tuple: (selection kwargs, dict with same kws and the idx "solutions") valid_selections = [ ( - {'frequency': np.array([30, 60]), - 'taper': [1, 0]}, - # the 'solutions' - {'frequency': slice(1, 3, 1), - 'taper': [1, 0]}, + {'channel_i': [0, 1], 'channel_j': [1, 2], 'latency': [1, 2]}, + # the 'solutions' + {'channel_i': slice(0, 2, 1), 'channel_j': slice(1, 3, 1), + 'latency': 3 * [slice(0, 3, 1)]}, ), # 2nd selection ( - {'frequency': 'all', - 'taper': 'taper2', - 'latency': [1.2, 1.7], - 'trials': np.arange(1,3)}, - # the 'solutions' - {'frequency': slice(None), - 'taper': [1], - 'latency': [[1], [1]], - 'trials': [1, 2]}, + {'channel_i': ['channel2', 'channel3'], 'channel_j': 1}, + # the 'solutions' + {'channel_i': slice(1, 3, 1), 'channel_j': 1}, ) ] @@ -314,9 +327,94 @@ def test_csd_valid(self): attr_name = map_sel_attr[sel_kw] assert getattr(selector_object, attr_name) == solution[sel_kw] + def test_csd_invalid(self): + + # each selection test is a 3-tuple: (selection kwargs, Error, error message sub-string) + invalid_selections = [ + ( + {'channel_i': [0, 2]}, NotImplementedError, + r"Unordered \(low to high\) or non-contiguous multi-channel-pair selections not supported" + ), + ( + {'channel_i': [1, 0]}, NotImplementedError, + r"Unordered \(low to high\) or non-contiguous multi-channel-pair selections not supported" + ), + ( + {'channel_j': ['channel3', 'channel1']}, NotImplementedError, + r"Unordered \(low to high\) or non-contiguous multi-channel-pair selections not supported" + ) + + ] + + for selection in invalid_selections: + sel_kw, error, err_str = selection + with pytest.raises(error, match=err_str): + spy.selectdata(self.csd_data, sel_kw) + + +class TestSpikeSelections: + + nChannels = 10 + nTrials = 5 + samplerate = 2.0 + nSpikes = 20 + T_max = 2 * nSpikes # in samples, not seconds! 
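+    # rows of `data` below are [sample, channel, unit] triplets: column 0 holds
+    # the (sorted) sample indices, column 1 the channel ids and column 2 the
+    # unit ids -- the boolean masks in test_spike_selection index these columns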
+ nSamples = T_max / nTrials + rng = np.random.default_rng(42) + + data = np.vstack([np.sort(rng.choice(range(T_max), size=nSpikes)), + rng.choice(np.arange(0, nChannels), size=nSpikes), + rng.choice(nChannels // 2, size=nSpikes)]).T + + trldef = np.vstack([np.arange(0, T_max, nSamples), + np.arange(0, T_max, nSamples) + nSamples, + np.ones(nTrials) * -2]).T + + spike_data = spy.SpikeData(data=data, samplerate=1, trialdefinition=trldef) + + def test_spike_selection(self): + + """ + Create a typical selection and check that the returned data is correct + """ + + selection = {'trials': [2, 4], + 'channel': [6, 2], + 'unit': [0, 3], + 'latency': [-1, 4]} + res = self.spike_data.selectdata(selection) + + # hand pick selection from the arrays + dat_arr = self.spike_data.data + + # these are trial intervals in sample indices! + trial2 = self.spike_data.trialdefinition[2, :2] + trial4 = self.spike_data.trialdefinition[4, :2] + + # create boolean mask for trials + bm = (dat_arr[:, 0] >= trial2[0]) & (dat_arr[:, 0] <= trial2[1]) + bm = bm | (dat_arr[:, 0] >= trial4[0]) & (dat_arr[:, 0] <= trial4[1]) + + # add channels [6, 2] + bm = bm & ((dat_arr[:, 1] == 6) | (dat_arr[:, 1] == 2)) + + # units [0, 3] + bm = bm & ((dat_arr[:, 2] == 0) | (dat_arr[:, 2] == 3)) + + # latency [-1, 4] + # to index all trials at once + time_vec = np.concatenate([t for t in self.spike_data.time]) + bm = bm & ((time_vec >= -1) & (time_vec <= 4)) + + # finally compare to selection result + assert np.all(dat_arr[bm] == res.data[()]) if __name__ == '__main__': - T1 = TestAnalogSelections() - T2 = TestSpectralSelections() - T3 = TestCrossSpectralSelections() + T1 = TestGeneral() + T2 = TestAnalogSelections() + T3 = TestSpectralSelections() + T4 = TestCrossSpectralSelections() + T5 = TestSpikeSelections() + + sdata = T5.spike_data From 502ed2158ac0b355f893858aea5239bc98f35abb Mon Sep 17 00:00:00 2001 From: tensionhead Date: Fri, 13 Jan 2023 18:03:41 +0100 Subject: [PATCH 091/135] NEW: Selectdata tests - the internal index representation is wild, at least for the DiscreteData types Changes to be committed: modified: syncopy/tests/test_selectdata.py --- syncopy/tests/test_selectdata.py | 184 +++++++++++++++++++++++++++---- 1 file changed, 165 insertions(+), 19 deletions(-) diff --git a/syncopy/tests/test_selectdata.py b/syncopy/tests/test_selectdata.py index 3f9091400..6613fb03a 100644 --- a/syncopy/tests/test_selectdata.py +++ b/syncopy/tests/test_selectdata.py @@ -11,7 +11,6 @@ from numbers import Number # Local imports -import syncopy.datatype as spd from syncopy.tests.misc import flush_local_cluster from syncopy.datatype import AnalogData, SpectralData from syncopy.datatype.base_data import Selector @@ -23,13 +22,15 @@ import syncopy as spy # map selection keywords to selector attributes (holding the idx to access selected data) -map_sel_attr = dict(trials = 'trial_ids', - channel = 'channel', - latency = 'time', - taper = 'taper', - frequency = 'freq', - channel_i = 'channel_i', - channel_j = 'channel_j' +map_sel_attr = dict(trials='trial_ids', + channel='channel', + latency='time', + taper='taper', + frequency='freq', + channel_i='channel_i', + channel_j='channel_j', + unit='unit', + eventid='eventid' ) @@ -41,7 +42,7 @@ class TestGeneral: def test_Selector_init(self): with pytest.raises(SPYTypeError, match="Wrong type of `data`"): - Selector(np.arange(10), latency=[0, 4]) + Selector(np.arange(10), {'latency': [0, 4]}) def test_invalid_sel_key(self): @@ -75,7 +76,7 @@ class TestAnalogSelections: def test_ad_selection(self): 
""" - Create a simple selection and check that the returned data is correct + Create a typical selection and check that the returned data is correct """ selection = {'trials': 1, 'channel': [6, 2], 'latency': [0, 1]} @@ -92,7 +93,7 @@ def test_ad_selection(self): def test_ad_valid(self): """ - Instantiate Selector class and check its only attributes (the idx) + Instantiate Selector class and check only its attributes (the idx) """ # each selection test is a 2-tuple: (selection kwargs, dict with same kws and the idx "solutions") @@ -116,7 +117,6 @@ def test_ad_valid(self): ) ] - for selection in valid_selections: # instantiate Selector and check attributes sel_kwargs, solution = selection @@ -148,6 +148,16 @@ def test_ad_invalid(self): with pytest.raises(error, match=err_str): spy.selectdata(self.adata, sel_kw) + def test_ad_parallel(self, testcluster=None): + # collect all tests of current class and repeat them in parallel + client = dd.Client(testcluster) + all_tests = [attr for attr in self.__dir__() + if (inspect.ismethod(getattr(self, attr)) and "parallel" not in attr)] + for test in all_tests: + getattr(self, test)() + flush_local_cluster(testcluster) + client.close() + class TestSpectralSelections: @@ -172,7 +182,7 @@ class TestSpectralSelections: def test_spectral_selection(self): """ - Create a simple selection and check that the returned data is correct + Create a typical selection and check that the returned data is correct """ selection = {'trials': 1, @@ -196,7 +206,7 @@ def test_spectral_selection(self): def test_spectral_valid(self): """ - Instantiate Selector class and check its only attributes (the idx) + Instantiate Selector class and check only its attributes (the idx) test mainly additional dimensions (taper and freq) here """ @@ -273,7 +283,7 @@ class TestCrossSpectralSelections: def test_csd_selection(self): """ - Create a simple selection and check that the returned data is correct + Create a typical selection and check that the returned data is correct """ selection = {'trials': [1, 0], @@ -299,7 +309,7 @@ def test_csd_selection(self): def test_csd_valid(self): """ - Instantiate Selector class and check its only attributes (the idx) + Instantiate Selector class and check only its attributes (the idx) test mainly additional dimensions (channel_i, channel_j) here """ @@ -391,7 +401,7 @@ def test_spike_selection(self): trial2 = self.spike_data.trialdefinition[2, :2] trial4 = self.spike_data.trialdefinition[4, :2] - # create boolean mask for trials + # create boolean mask for trials [2, 4] bm = (dat_arr[:, 0] >= trial2[0]) & (dat_arr[:, 0] <= trial2[1]) bm = bm | (dat_arr[:, 0] >= trial4[0]) & (dat_arr[:, 0] <= trial4[1]) @@ -409,6 +419,143 @@ def test_spike_selection(self): # finally compare to selection result assert np.all(dat_arr[bm] == res.data[()]) + def test_spike_valid(self): + + """ + Instantiate Selector class and check only its attributes, the idx + used by `_preview_trial` in the end + """ + + # each selection test is a 2-tuple: (selection kwargs, dict with same kws and the idx "solutions") + valid_selections = [ + ( + # units get apparently indexed on a per trial basis + {'trials': np.arange(1, 4), 'channel': ['channel03', 'channel01'], 'unit': [2, 0]}, + {'trials': [1, 2, 3], 'channel': [2, 0], 'unit': [[], [], [1, 5]]}, + ), + # 2nd selection + ( + # time/latency idx can be mixed lists and slices O.0 + # and channel 'all' selections can still be effectively subsets.. 
+ {'trials': [0, 4], 'latency': [0, 3], 'channel': 'all'}, + {'trials': [0, 4], 'latency': [slice(0, 4, 1), [1]], 'channel': [1, 2, 3, 5, 9]}, + ) + ] + + for selection in valid_selections: + # instantiate Selector and check attributes + sel_kwargs, solution = selection + selector_object = Selector(self.spike_data, sel_kwargs) + for sel_kw in sel_kwargs.keys(): + attr_name = map_sel_attr[sel_kw] + assert getattr(selector_object, attr_name) == solution[sel_kw] + + def test_spike_invalid(self): + + # each selection test is a 3-tuple: (selection kwargs, Error, error message sub-string) + invalid_selections = [ + ({'channel': ["channel33", "channel01"]}, SPYValueError, "existing names or indices"), + ({'channel': "my-non-existing-channel"}, SPYValueError, "existing names or indices"), + ({'channel': slice(None)}, SPYTypeError, "expected serializable data type"), + ({'unit': 99}, SPYValueError, "existing names or indices"), + ({'unit': slice(None)}, SPYTypeError, "expected serializable data type"), + ({'latency': [-1, 10]}, SPYValueError, "at least one trial covering the latency window"), + ] + + for selection in invalid_selections: + sel_kw, error, err_str = selection + with pytest.raises(error, match=err_str): + spy.selectdata(self.spike_data, sel_kw) + + +class TestEventSelections: + + """ + This data type probably needs some adjustments.. + """ + + nSamples = 4 + nTrials = 5 + eIDs = [0, 111, 31] # event ids + rng = np.random.default_rng(42) + + trldef = np.vstack([np.arange(0, nSamples * nTrials, nSamples), + np.arange(0, nSamples * nTrials, nSamples) + nSamples, + np.ones(nTrials) * -1]).T + + # Use a triple-trigger pattern to simulate EventData w/non-uniform trials + data = np.vstack([np.arange(0, nSamples * nTrials, 1), + rng.choice(eIDs, size=nSamples * nTrials)]).T + + edata = spy.EventData(data=data, samplerate=1, trialdefinition=trldef) + + def test_event_selection(self): + + # eIDs[1] = 111, a bit funny that here we need an index actually... + selection = {'eventid': 1, 'latency': [0, 1], 'trials': [0, 3]} + res = spy.selectdata(self.edata, selection) + + # hand pick selection from the arrays + dat_arr = self.edata.data + + # these are trial intervals in sample indices! 
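+        # (each trialdefinition row holds [start, stop, offset] in samples,
+        # so [:2] picks the start/stop bounds of the respective trial)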
+ trial0 = self.edata.trialdefinition[0, :2] + trial3 = self.edata.trialdefinition[3, :2] + + # create boolean mask for trials [0, 3] + bm = (dat_arr[:, 0] >= trial0[0]) & (dat_arr[:, 0] <= trial0[1]) + bm = bm | (dat_arr[:, 0] >= trial3[0]) & (dat_arr[:, 0] <= trial3[1]) + + # add eventid eIDs[1] + bm = bm & (dat_arr[:, 1] == 111) + + # latency [0, 1] + # to index all trials at once + time_vec = np.concatenate([t for t in self.edata.time]) + bm = bm & ((time_vec >= 0) & (time_vec <= 1)) + + # finally compare to selection result + assert np.all(dat_arr[bm] == res.data[()]) + + def test_event_valid(self): + """ + Instantiate Selector class and check only its attributes, the idx + used by `_preview_trial` in the end + """ + + # each selection test is a 2-tuple: (selection kwargs, dict with same kws and the idx "solutions") + valid_selections = [ + ( + # eventids get apparently indexed on a per trial basis + {'trials': np.arange(1, 4), 'eventid': [0, 2]}, + {'trials': [1, 2, 3], 'eventid': [[2], slice(0, 2, 1), []]} + ), + ] + + for selection in valid_selections: + # instantiate Selector and check attributes + sel_kwargs, solution = selection + selector_object = Selector(self.edata, sel_kwargs) + for sel_kw in sel_kwargs.keys(): + attr_name = map_sel_attr[sel_kw] + assert getattr(selector_object, attr_name) == solution[sel_kw] + + def test_event_invalid(self): + + """ + eventid seems to be only indexable ([0, 1, 2]) instead of using the actual + numerical values ([0, 111, 31]), this should most likely change in the future.. + """ + # each selection test is a 3-tuple: (selection kwargs, Error, error message sub-string) + invalid_selections = [ + ({'eventid': [111, 31]}, SPYValueError, "existing names or indices"), + ({'eventid': '111'}, SPYValueError, "expected dtype = numeric"), + ] + + for selection in invalid_selections: + sel_kw, error, err_str = selection + with pytest.raises(error, match=err_str): + spy.selectdata(self.edata, sel_kw) if __name__ == '__main__': T1 = TestGeneral() @@ -416,5 +563,4 @@ def test_spike_selection(self): T3 = TestSpectralSelections() T4 = TestCrossSpectralSelections() T5 = TestSpikeSelections() - - sdata = T5.spike_data + T6 = TestEventSelections() From dd30d26fca7cf8ea8e7721cc9bc94214ecf5c046 Mon Sep 17 00:00:00 2001 From: tensionhead Date: Fri, 13 Jan 2023 18:07:53 +0100 Subject: [PATCH 092/135] FIX: Formatting and co. 
Changes to be committed: modified: syncopy/tests/test_selectdata.py --- syncopy/tests/test_selectdata.py | 47 +++++++++++++++----------------- 1 file changed, 22 insertions(+), 25 deletions(-) diff --git a/syncopy/tests/test_selectdata.py b/syncopy/tests/test_selectdata.py index 6613fb03a..c1db9b9dc 100644 --- a/syncopy/tests/test_selectdata.py +++ b/syncopy/tests/test_selectdata.py @@ -8,16 +8,12 @@ import numpy as np import inspect import dask.distributed as dd -from numbers import Number + # Local imports -from syncopy.tests.misc import flush_local_cluster -from syncopy.datatype import AnalogData, SpectralData from syncopy.datatype.base_data import Selector -from syncopy.datatype.methods.selectdata import selectdata -from syncopy.shared.errors import SPYError, SPYValueError, SPYTypeError -from syncopy.shared.tools import StructDict -from syncopy import freqanalysis +from syncopy.shared.errors import SPYValueError, SPYTypeError +from syncopy.tests.misc import flush_local_cluster import syncopy as spy @@ -85,7 +81,7 @@ def test_ad_selection(self): # pick the data by hand, latency [0, 1] covers 2nd - 4th sample index # as time axis is array([-0.5, 0. , 0.5, 1. , 1.5]) - solution = self.adata.data[self.nSamples : self.nSamples * 2] + solution = self.adata.data[self.nSamples:self.nSamples * 2] solution = np.column_stack([solution[1:4, 6], solution[1:4, 2]]) assert np.all(solution == res.data) @@ -99,21 +95,21 @@ def test_ad_valid(self): # each selection test is a 2-tuple: (selection kwargs, dict with same kws and the idx "solutions") valid_selections = [ ( - {'channel': ["channel03", "channel01"], - 'latency': [0, 1], - 'trials': np.arange(2)}, - # these are the idx used to access the actual data - {'channel': [2, 0], - 'latency': 2 * [slice(1, 4, 1)], - 'trials': [0, 1]} + {'channel': ["channel03", "channel01"], + 'latency': [0, 1], + 'trials': np.arange(2)}, + # these are the idx used to access the actual data + {'channel': [2, 0], + 'latency': 2 * [slice(1, 4, 1)], + 'trials': [0, 1]} ), ( - # 2nd selection with some repetitions - {'channel': [7, 3, 3], - 'trials': [0, 1, 1]}, - # 'solutions' - {'channel': [7, 3, 3], - 'trials': [0, 1, 1]} + # 2nd selection with some repetitions + {'channel': [7, 3, 3], + 'trials': [0, 1, 1]}, + # 'solutions' + {'channel': [7, 3, 3], + 'trials': [0, 1, 1]} ) ] @@ -197,7 +193,7 @@ def test_spectral_selection(self): # frequency covers only 2nd index (40 Hz) # pick trial - solution = self.sdata.data[self.nSamples : self.nSamples * 2] + solution = self.sdata.data[self.nSamples:self.nSamples * 2] # pick channels, frequency and latency and re-stack solution = np.stack([solution[:2, :, [1], 1], solution[:2, :, [1], 0]], axis=-1) @@ -224,7 +220,7 @@ def test_spectral_valid(self): {'frequency': 'all', 'taper': 'taper2', 'latency': [1.2, 1.7], - 'trials': np.arange(1,3)}, + 'trials': np.arange(1, 3)}, # the 'solutions' {'frequency': slice(None), 'taper': [1], @@ -248,8 +244,8 @@ def test_spectral_invalid(self): ({'frequency': '40Hz'}, SPYValueError, "'all' or `None` or float or list/array"), ({'frequency': 4}, SPYValueError, "all array elements to be bounded"), ({'frequency': slice(None)}, SPYTypeError, "expected serializable data type"), - ({'frequency': range(20,60)}, SPYTypeError, "expected array_like"), - ({'frequency': np.arange(20,60)}, SPYValueError, "expected array of shape"), + ({'frequency': range(20, 60)}, SPYTypeError, "expected array_like"), + ({'frequency': np.arange(20, 60)}, SPYValueError, "expected array of shape"), ({'taper': 'taper13'}, 
SPYValueError, "existing names or indices"), ({'taper': [18, 99]}, SPYValueError, "existing names or indices"), ] @@ -557,6 +553,7 @@ def test_event_invalid(self): with pytest.raises(error, match=err_str): spy.selectdata(self.edata, sel_kw) + if __name__ == '__main__': T1 = TestGeneral() T2 = TestAnalogSelections() From 2c066a4d6d2e322864fa08c12fbd6dd7ec8acbe8 Mon Sep 17 00:00:00 2001 From: tensionhead Date: Mon, 16 Jan 2023 16:32:07 +0100 Subject: [PATCH 093/135] NEW: TrialIndexer - can be used both as an iterable or via explicit (single element) indexing Changes to be committed: modified: syncopy/__init__.py modified: syncopy/datatype/__init__.py modified: syncopy/datatype/base_data.py modified: syncopy/datatype/discrete_data.py new file: syncopy/datatype/util.py modified: syncopy/statistics/spike_psth.py modified: syncopy/statistics/summary_stats.py modified: syncopy/tests/test_selectdata.py --- syncopy/__init__.py | 2 +- syncopy/datatype/__init__.py | 2 + syncopy/datatype/base_data.py | 175 ++++------------------------ syncopy/datatype/discrete_data.py | 12 +- syncopy/datatype/util.py | 127 ++++++++++++++++++++ syncopy/statistics/spike_psth.py | 1 - syncopy/statistics/summary_stats.py | 5 +- syncopy/tests/test_selectdata.py | 11 +- 8 files changed, 160 insertions(+), 175 deletions(-) create mode 100644 syncopy/datatype/util.py diff --git a/syncopy/__init__.py b/syncopy/__init__.py index 0719e3fb0..d9e1322cf 100644 --- a/syncopy/__init__.py +++ b/syncopy/__init__.py @@ -127,7 +127,7 @@ from .preproc import * # Register session -__session__ = datatype.base_data.SessionLogger() +__session__ = datatype.util.SessionLogger() # Override default traceback (differentiate b/w Jupyter/iPython and regular Python) from .shared.errors import SPYExceptionHandler diff --git a/syncopy/datatype/__init__.py b/syncopy/datatype/__init__.py index f138e8ddd..f490dc5d4 100644 --- a/syncopy/datatype/__init__.py +++ b/syncopy/datatype/__init__.py @@ -13,6 +13,7 @@ from .methods.selectdata import * from .methods.show import * from .methods.copy import * +from .util import * # Populate local __all__ namespace __all__ = [] @@ -25,3 +26,4 @@ __all__.extend(methods.selectdata.__all__) __all__.extend(methods.show.__all__) __all__.extend(methods.copy.__all__) +__all__.extend(util.__all__) diff --git a/syncopy/datatype/base_data.py b/syncopy/datatype/base_data.py index ba9fc4c9a..6f79d246e 100644 --- a/syncopy/datatype/base_data.py +++ b/syncopy/datatype/base_data.py @@ -10,11 +10,8 @@ import sys import os from abc import ABC, abstractmethod -from datetime import datetime from hashlib import blake2b -from itertools import islice from functools import reduce -from inspect import signature import shutil import numpy as np import h5py @@ -22,6 +19,7 @@ # Local imports import syncopy as spy +from .util import TrialIndexer from .methods.arithmetic import _process_operator from .methods.selectdata import selectdata from .methods.show import show @@ -744,8 +742,13 @@ def _t0(self): def trials(self): """list-like array of trials""" - return Indexer(map(self._get_trial, range(self.sampleinfo.shape[0])), - self.sampleinfo.shape[0]) if self.sampleinfo is not None else None + if self.sampleinfo is not None: + trial_ids = list(range(self.sampleinfo.shape[0])) + # this is cheap as it just initializes a list-like object + # with no real data and/or computation! 
+ return TrialIndexer(self, trial_ids) + else: + return None @property def trialinfo(self): @@ -1161,147 +1164,6 @@ def __init__(self, filename=None, dimord=None, mode="r+", **kwargs): self._version = __version__ -class Indexer: - - __slots__ = ["_iterobj", "_iterlen"] - - def __init__(self, iterobj, iterlen): - """ - Make an iterable object subscriptable using itertools magic - """ - self._iterobj = iterobj - self._iterlen = iterlen - - def __iter__(self): - return self._iterobj - - def __getitem__(self, idx): - if np.issubdtype(type(idx), np.number): - try: - scalar_parser( - idx, varname="idx", ntype="int_like", lims=[0, self._iterlen - 1] - ) - except Exception as exc: - raise exc - return next(islice(self._iterobj, idx, idx + 1)) - elif isinstance(idx, slice): - start, stop = idx.start, idx.stop - if idx.start is None: - start = 0 - if idx.stop is None: - stop = self._iterlen - index = slice(start, stop, idx.step) - if not (0 <= index.start < self._iterlen) or not ( - 0 < index.stop <= self._iterlen - ): - err = "value between {lb:s} and {ub:s}" - raise SPYValueError( - err.format(lb="0", ub=str(self._iterlen)), - varname="idx", - actual=str(index), - ) - return np.hstack(islice(self._iterobj, index.start, index.stop, index.step)) - elif isinstance(idx, (list, np.ndarray)): - try: - array_parser( - idx, - varname="idx", - ntype="int_like", - hasnan=False, - hasinf=False, - lims=[0, self._iterlen], - dims=1, - ) - except Exception as exc: - raise exc - return np.hstack( - [next(islice(self._iterobj, int(ix), int(ix + 1))) for ix in idx] - ) - else: - raise SPYTypeError(idx, varname="idx", expected="int_like or slice") - - def __len__(self): - return self._iterlen - - def __repr__(self): - return self.__str__() - - def __str__(self): - return "{} element iterable".format(self._iterlen) - - -class SessionLogger: - - __slots__ = ["sessionfile", "_rm"] - - def __init__(self): - - # Create package-wide tmp directory if not already present - if not os.path.exists(__storage__): - try: - os.mkdir(__storage__) - except Exception as exc: - err = ( - "Syncopy core: cannot create temporary storage directory {}. " - + "Original error message below\n{}" - ) - raise IOError(err.format(__storage__, str(exc))) - - # Check for upper bound of temp directory size - with os.scandir(__storage__) as scan: - st_size = 0.0 - st_fles = 0 - for fle in scan: - try: - st_size += fle.stat().st_size / 1024 ** 3 - st_fles += 1 - # this catches a cleanup by another process - except FileNotFoundError: - continue - - if st_size > __storagelimit__: - msg = ( - "\nSyncopy WARNING: Temporary storage folder {tmpdir:s} " - + "contains {nfs:d} files taking up a total of {sze:4.2f} GB on disk. \n" - + "Consider running `spy.cleanup()` to free up disk space." - ) - print(msg.format(tmpdir=__storage__, nfs=st_fles, sze=st_size)) - - # If we made it to this point, (attempt to) write the session file - sess_log = "{user:s}@{host:s}: <{time:s}> started session {sess:s}" - self.sessionfile = os.path.join( - __storage__, "session_{}_log.id".format(__sessionid__) - ) - try: - with open(self.sessionfile, "w") as fid: - fid.write( - sess_log.format( - user=getpass.getuser(), - host=socket.gethostname(), - time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - sess=__sessionid__, - ) - ) - except Exception as exc: - err = "Syncopy core: cannot access {}. 
Original error message below\n{}" - raise IOError(err.format(self.sessionfile, str(exc))) - - # Workaround to prevent Python from garbage-collecting ``os.unlink`` - self._rm = os.unlink - - def __repr__(self): - return self.__str__() - - def __str__(self): - return "Session {}".format(__sessionid__) - - def __del__(self): - try: - self._rm(self.sessionfile) - except FileNotFoundError: - pass - - class FauxTrial: """ Stand-in mockup of NumPy arrays representing trial data @@ -1617,23 +1479,28 @@ def trial_ids(self, dataselect): raise SPYValueError(legal=lgl, varname=vname, actual=act) else: trials = trlList - self._trial_ids = list(trials) # ensure `trials` is a list cf. #180 + self._trial_ids = list(trials) # ensure `trials` is a list cf. #180 @property def trials(self): """ - Returns an Indexer indexing single trial arrays respecting the selection - Indices are RELATIVE with respect to existing trial selections: + Returns an iterable indexing single trial arrays respecting the selection + Indices are ABSOLUTE with respect to existing trial selections: - >>> selection.trials[2] + >>> selection.trials[11] - indexes the 3rd trial of `selection.trial_ids` + indexes the 11th trial of the original dataset, if and only if + trial number 11 is part of the selection. Selections must be "simple": ordered and without repetitions """ - return Indexer(map(self._get_trial, self.trial_ids), - len(self.trial_ids)) if self.trial_ids is not None else None + if self.sampleinfo is not None: + # this is cheap as it just initializes a list-like object + # with no real data and/or computations! + return TrialIndexer(self, self.trial_ids) + else: + return None def create_get_trial(self, data): """ Closure to allow emulation of BaseData._get_trial""" @@ -2338,7 +2205,7 @@ def __str__(self): attr, "s" if not attr.endswith("s") else "", ) - elif isinstance(val, (list, Indexer)): + elif isinstance(val, (list, TrialIndexer)): ppdict[attr] = "{0:d} {1:s}{2:s}, ".format( len(val), attr, "s" if not attr.endswith("s") else "" ) diff --git a/syncopy/datatype/discrete_data.py b/syncopy/datatype/discrete_data.py index 376c0fbb5..a8fef9c09 100644 --- a/syncopy/datatype/discrete_data.py +++ b/syncopy/datatype/discrete_data.py @@ -11,7 +11,7 @@ # Local imports -from .base_data import BaseData, Indexer, FauxTrial +from .base_data import BaseData, FauxTrial from .methods.definetrial import definetrial from syncopy.shared.parsers import scalar_parser, array_parser from syncopy.shared.errors import SPYValueError @@ -164,16 +164,6 @@ def trialid(self, trlid): raise exc self._trialid = np.array(trlid, dtype=int) - @property - def trials(self): - """list-like([sample x (>=2)] :class:`numpy.ndarray`) : trial slices of :attr:`data` property""" - if self.trialid is not None: - valid_trls = np.unique(self.trialid[self.trialid >= 0]) - return Indexer(map(self._get_trial, valid_trls), - valid_trls.size) - else: - return None - @property def trialtime(self): """list(:class:`numpy.ndarray`): trigger-relative sample times in s""" diff --git a/syncopy/datatype/util.py b/syncopy/datatype/util.py new file mode 100644 index 000000000..711c15061 --- /dev/null +++ b/syncopy/datatype/util.py @@ -0,0 +1,127 @@ +""" +Helpers and tools for Syncopy data classes +""" + +import os +import getpass +import socket +from datetime import datetime + +# Syncopy imports +from syncopy import __storage__, __storagelimit__, __sessionid__ +from syncopy.shared.errors import SPYTypeError, SPYValueError + +__all__ = ['TrialIndexer', 'SessionLogger'] + + +class 
TrialIndexer(list): + + def __init__(self, data_object, idx_list): + """ + Subclass Python's list to obtain an indexable trials iterable. + Relies on the `_get_trial` method of the + respective `data_object`. + + Parameters + ---------- + data_object : Syncopy data class, e.g. AnalogData + + idx_list : list + List of valid trial indices for `_get_trial` + """ + self.data_object = data_object + self.idx_list = idx_list + self._len = len(idx_list) + + def __getitem__(self, trialno): + if trialno not in self.idx_list: + lgl = "index of existing trials" + raise SPYValueError(lgl, "trial index", trialno) + return self.data_object._get_trial(trialno) + + def __iter__(self): + # this generator gets freshly created and exhausted + # only for iterations (iterator protocol) + yield from (self[i] for i in self.idx_list) + + def __len__(self): + return self._len + + def __repr__(self): + return self.__str__() + + def __str__(self): + return "{} element iterable".format(self._len) + + +class SessionLogger: + + __slots__ = ["sessionfile", "_rm"] + + def __init__(self): + + # Create package-wide tmp directory if not already present + if not os.path.exists(__storage__): + try: + os.mkdir(__storage__) + except Exception as exc: + err = ( + "Syncopy core: cannot create temporary storage directory {}. " + + "Original error message below\n{}" + ) + raise IOError(err.format(__storage__, str(exc))) + + # Check for upper bound of temp directory size + with os.scandir(__storage__) as scan: + st_size = 0.0 + st_fles = 0 + for fle in scan: + try: + st_size += fle.stat().st_size / 1024 ** 3 + st_fles += 1 + # this catches a cleanup by another process + except FileNotFoundError: + continue + + if st_size > __storagelimit__: + msg = ( + "\nSyncopy WARNING: Temporary storage folder {tmpdir:s} " + + "contains {nfs:d} files taking up a total of {sze:4.2f} GB on disk. \n" + + "Consider running `spy.cleanup()` to free up disk space." + ) + print(msg.format(tmpdir=__storage__, nfs=st_fles, sze=st_size)) + + # If we made it to this point, (attempt to) write the session file + sess_log = "{user:s}@{host:s}: <{time:s}> started session {sess:s}" + self.sessionfile = os.path.join( + __storage__, "session_{}_log.id".format(__sessionid__) + ) + try: + with open(self.sessionfile, "w") as fid: + fid.write( + sess_log.format( + user=getpass.getuser(), + host=socket.gethostname(), + time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + sess=__sessionid__, + ) + ) + except Exception as exc: + err = "Syncopy core: cannot access {}. 
Original error message below\n{}" + raise IOError(err.format(self.sessionfile, str(exc))) + + # Workaround to prevent Python from garbage-collecting ``os.unlink`` + self._rm = os.unlink + + def __repr__(self): + return self.__str__() + + def __str__(self): + return "Session {}".format(__sessionid__) + + def __del__(self): + try: + self._rm(self.sessionfile) + except FileNotFoundError: + pass + diff --git a/syncopy/statistics/spike_psth.py b/syncopy/statistics/spike_psth.py index f6b6a986f..a9ede5f82 100644 --- a/syncopy/statistics/spike_psth.py +++ b/syncopy/statistics/spike_psth.py @@ -11,7 +11,6 @@ from syncopy.shared.parsers import data_parser, scalar_parser, array_parser from syncopy.shared.tools import get_defaults, get_frontend_cfg from syncopy.datatype import TimeLockData -from syncopy.datatype.base_data import Indexer from syncopy.shared.errors import SPYValueError, SPYTypeError, SPYInfo from syncopy.shared.kwarg_decorators import ( diff --git a/syncopy/statistics/summary_stats.py b/syncopy/statistics/summary_stats.py index ae4ab6c5d..dae8dd85b 100644 --- a/syncopy/statistics/summary_stats.py +++ b/syncopy/statistics/summary_stats.py @@ -274,7 +274,6 @@ def _statistics(spy_data, operation, dim, keeptrials=True, **kwargs): if kwargs.get('parallel'): msg = "Trial statistics can be only computed sequentially, ignoring `parallel` keyword" SPYWarning(msg) - out = _trial_statistics(spy_data, operation) # any other statistic @@ -447,8 +446,10 @@ def _trial_statistics(in_data, operation='mean'): act = f"got {nTrials} trials" raise SPYValueError(lgl, 'in_data', act) + # index 1st selected trial + idx0 = in_data.selection.trial_ids[0] # we always have at least one (all-to-all) trial selection - out_shape = in_data.selection.trials[0].shape + out_shape = in_data.selection.trials[idx0].shape # now look at the other ones for trl in in_data.selection.trials: diff --git a/syncopy/tests/test_selectdata.py b/syncopy/tests/test_selectdata.py index 24217d422..d400a0dfd 100644 --- a/syncopy/tests/test_selectdata.py +++ b/syncopy/tests/test_selectdata.py @@ -654,15 +654,14 @@ def test_selector_trials(self): # checks time axis assert len(ang.selection.trials) == 3 - # test for non-existing trials, trial indices are relative here! + # trial indices are absolute here! select = {'trials': [0, 3, 5]} ang.selectdata(**select, inplace=True) assert ang.selection.trial_ids[2] == 5 - # this returns original trial 6 (with index 5) - assert np.array_equal(ang.selection.trials[2], ang.trials[5]) - # we only have 3 trials selected here, so max. 
relative index is 2
-        with pytest.raises(SPYValueError, match='less or equals 2'):
-            ang.selection.trials[5]
+        assert np.array_equal(ang.selection.trials[5], ang.trials[5])
+        # test access of non-selected trial fails
+        with pytest.raises(SPYValueError, match='expected index of existing trials'):
+            ang.selection.trials[1]

         # Fancy indexing is not allowed so far
         select = {'channel': [7, 7, 8]}

From 29ae020967f9c1d68c0f2600019fba866ffb23f0 Mon Sep 17 00:00:00 2001
From: tensionhead
Date: Mon, 16 Jan 2023 16:57:16 +0100
Subject: [PATCH 094/135] CHG: Add tests and check for single trial index

Changes to be committed:
	modified:   syncopy/datatype/util.py
	modified:   syncopy/tests/test_basedata.py

---
 syncopy/datatype/util.py       |  4 ++++
 syncopy/tests/test_basedata.py | 38 ++++++++++++++++++++++++++++++++++
 2 files changed, 42 insertions(+)

diff --git a/syncopy/datatype/util.py b/syncopy/datatype/util.py
index 711c15061..d652d7482 100644
--- a/syncopy/datatype/util.py
+++ b/syncopy/datatype/util.py
@@ -6,6 +6,7 @@
 import getpass
 import socket
 from datetime import datetime
+from numbers import Number

 # Syncopy imports
 from syncopy import __storage__, __storagelimit__, __sessionid__
@@ -29,11 +30,14 @@ def __init__(self, data_object, idx_list):
         idx_list : list
             List of valid trial indices for `_get_trial`
         """
+
         self.data_object = data_object
         self.idx_list = idx_list
         self._len = len(idx_list)

     def __getitem__(self, trialno):
+        if not isinstance(trialno, Number):
+            raise SPYTypeError(trialno, "trial index", "single number to index a single trial")
         if trialno not in self.idx_list:
             lgl = "index of existing trials"
             raise SPYValueError(lgl, "trial index", trialno)
diff --git a/syncopy/tests/test_basedata.py b/syncopy/tests/test_basedata.py
index 61b2705c9..6e2d0e64d 100644
--- a/syncopy/tests/test_basedata.py
+++ b/syncopy/tests/test_basedata.py
@@ -161,6 +161,44 @@ def test_trialdef(self):
             assert np.array_equal(dummy._t0, self.trl[dclass][:, 2])
             assert np.array_equal(dummy.trialinfo.flatten(), self.trl[dclass][:, 3])

+    def test_trials_property(self):
+
+        # 3 trials, trial index = data values
+        data = AnalogData([i * np.ones((2,2)) for i in range(3)], samplerate=1)
+
+        # single index access
+        assert np.all(data.trials[0] == 0)
+        assert np.all(data.trials[1] == 1)
+        assert np.all(data.trials[2] == 2)
+
+        # iterator
+        all_trials = [trl for trl in data.trials]
+        assert len(all_trials) == 3
+        assert all([np.all(all_trials[i] == i) for i in range(3)])
+
+        # selection
+        data.selectdata(trials=[0, 2], inplace=True)
+        all_selected_trials = [trl for trl in data.selection.trials]
+        assert data.selection.trial_ids == [0, 2]
+        assert len(all_selected_trials) == 2
+        assert all([np.all(data.selection.trials[i] == i) for i in data.selection.trial_ids])
+
+        # check that non-existing trials get caught
+        with pytest.raises(SPYValueError, match='existing trials'):
+            data.trials[999]
+        # selections have absolute trial indices!
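+        # (for illustration: `trials=[0, 2]` was selected above, so indices
+        #  0 and 2 are valid here, while index 1 must raise even though
+        #  trial 1 exists in the unselected data)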
+        with pytest.raises(SPYValueError, match='existing trials'):
+            data.selection.trials[1]
+
+        # check that invalid trial indexing gets caught
+        with pytest.raises(SPYTypeError, match='trial index'):
+            data.trials[range(4)]
+        with pytest.raises(SPYTypeError, match='trial index'):
+            data.trials[2:3]
+        with pytest.raises(SPYTypeError, match='trial index'):
+            data.trials[np.arange(3)]
+        return data
+
     # Test ``_gen_filename`` with `AnalogData` only - method is independent from concrete data object
     def test_filename(self):
         # ensure we're salting sufficiently to create at least `numf`

From e3dbb3e5e5f68b53fb45354005f97996a6f6e197 Mon Sep 17 00:00:00 2001
From: tensionhead
Date: Mon, 16 Jan 2023 17:01:10 +0100
Subject: [PATCH 095/135] FIX: Remove wip return statement

Changes to be committed:
	modified:   syncopy/tests/test_basedata.py

---
 syncopy/tests/test_basedata.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/syncopy/tests/test_basedata.py b/syncopy/tests/test_basedata.py
index 6e2d0e64d..14bdf36df 100644
--- a/syncopy/tests/test_basedata.py
+++ b/syncopy/tests/test_basedata.py
@@ -197,7 +197,6 @@ def test_trials_property(self):
             data.trials[2:3]
         with pytest.raises(SPYTypeError, match='trial index'):
             data.trials[np.arange(3)]
-        return data

     # Test ``_gen_filename`` with `AnalogData` only - method is independent from concrete data object
     def test_filename(self):

From c3529e6d054d3cdb710b0048ef8bf5c9f9b5a5bb Mon Sep 17 00:00:00 2001
From: tensionhead
Date: Mon, 16 Jan 2023 17:26:36 +0100
Subject: [PATCH 096/135] FIX: Use samplerate variable

Changes to be committed:
	modified:   syncopy/tests/test_selectdata.py

---
 syncopy/tests/test_selectdata.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/syncopy/tests/test_selectdata.py b/syncopy/tests/test_selectdata.py
index c1db9b9dc..cc3f510a4 100644
--- a/syncopy/tests/test_selectdata.py
+++ b/syncopy/tests/test_selectdata.py
@@ -376,7 +376,9 @@ class TestSpikeSelections:
                          np.arange(0, T_max, nSamples) + nSamples,
                          np.ones(nTrials) * -2]).T

-    spike_data = spy.SpikeData(data=data, samplerate=1, trialdefinition=trldef)
+    spike_data = spy.SpikeData(data=data,
+                               samplerate=samplerate,
+                               trialdefinition=trldef)

     def test_spike_selection(self):

From 438f513a76616aa8bf7f0583b85f0a8cc18523c5 Mon Sep 17 00:00:00 2001
From: tensionhead
Date: Tue, 17 Jan 2023 10:30:18 +0100
Subject: [PATCH 097/135] FIX: Correct samplerate

Changes to be committed:
	modified:   syncopy/tests/test_selectdata.py

---
 syncopy/tests/test_selectdata.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/syncopy/tests/test_selectdata.py b/syncopy/tests/test_selectdata.py
index cc3f510a4..1fcf74b36 100644
--- a/syncopy/tests/test_selectdata.py
+++ b/syncopy/tests/test_selectdata.py
@@ -362,7 +362,7 @@ class TestSpikeSelections:

     nChannels = 10
     nTrials = 5
-    samplerate = 2.0
+    samplerate = 1.0
     nSpikes = 20
     T_max = 2 * nSpikes  # in samples, not seconds!
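     # (worked out: with nSpikes = 20 this is 40 samples across nTrials = 5 trials)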
nSamples = T_max / nTrials From 233b79f2f44c5940608a58ddb8db07891d0857f4 Mon Sep 17 00:00:00 2001 From: kajal5888 Date: Tue, 17 Jan 2023 10:47:51 +0100 Subject: [PATCH 098/135] Test config added, test for parallel simplified --- syncopy/tests/test_connectivity.py | 85 +++++++++--------------------- 1 file changed, 24 insertions(+), 61 deletions(-) diff --git a/syncopy/tests/test_connectivity.py b/syncopy/tests/test_connectivity.py index f1e4be96d..3979202c8 100644 --- a/syncopy/tests/test_connectivity.py +++ b/syncopy/tests/test_connectivity.py @@ -473,18 +473,8 @@ def test_coh_cfg(self): cfg=get_defaults(cafunc)) @skip_low_mem - def test_coh_parallel(self, testcluster=None): - - ppl.ioff() - client = dd.Client(testcluster) - all_tests = [attr for attr in self.__dir__() - if (inspect.ismethod(getattr(self, attr)) and 'parallel' not in attr)] - - for test in all_tests: - test_method = getattr(self, test) - test_method() - client.close() - ppl.ion() + def test_parallel(self): + check_parallel(TestCoherence()) def test_coh_padding(self): @@ -554,30 +544,21 @@ def test_data_output_type(self): cross_spec = spy.connectivityanalysis(self.spec, method='csd') assert np.all(self.spec.freq == cross_spec.freq) assert cross_spec.data.dtype.name == 'complex64' + assert cross_spec.data.shape != self.spec.data.shape @skip_low_mem - def test_csd_parallel(self, testcluster=None): - - ppl.ioff() - client = dd.Client(testcluster) - all_tests = [attr for attr in self.__dir__() - if (inspect.ismethod(getattr(self, attr)) and 'parallel' not in attr)] - - for test in all_tests: - test_method = getattr(self, test) - test_method() - client.close() - ppl.ion() - - def test_csd_cfg(self): - - call = lambda cfg: cafunc(self.spec, cfg) - - run_cfg_test(call, method='csd', cfg=get_defaults(cafunc)) + def test_parallel(self): + check_parallel(TestCSD()) def test_csd_input(self): + assert isinstance(self.spec, SpectralData) - assert not isinstance(self.spec, SpectralData) + def test_csd_cfg(self): + Method = 'csd' + cross_spec = spy.connectivityanalysis(self.spec, method=Method) + assert len(cross_spec.cfg) == 2 + assert np.all([True for cfg in zip(self.spec.cfg['freqanalysis'], cross_spec.cfg['freqanalysis']) if cfg[0] == cfg[1]]) + assert cross_spec.cfg['connectivityanalysis'].method == Method class TestCorrelation: @@ -712,18 +693,8 @@ def test_corr_cfg(self): cfg=get_defaults(cafunc)) @skip_low_mem - def test_corr_parallel(self, testcluster=None): - - ppl.ioff() - client = dd.Client(testcluster) - all_tests = [attr for attr in self.__dir__() - if (inspect.ismethod(getattr(self, attr)) and 'parallel' not in attr)] - - for test in all_tests: - test_method = getattr(self, test) - test_method() - client.close() - ppl.ion() + def test_parallel(self): + check_parallel(TestCorrelation()) def test_corr_polyremoval(self): @@ -731,24 +702,16 @@ def test_corr_polyremoval(self): helpers.run_polyremoval_test(call) -def run_csd_cfg_test(method_call, method, cfg, positivity=True): - - cfg.method = method - if method != 'granger': - cfg.frequency = [0, 70] - # test general tapers with - # additional parameters - cfg.taper = 'kaiser' - cfg.taper_opt = {'beta': 2} - - cfg.output = 'abs' - - result = method_call(cfg) - - # check here just for finiteness and positivity - assert np.all(np.isfinite(result.data)) - if positivity: - assert np.all(result.data[0, ...] 
>= -1e-10) +def check_parallel(TestClass, testcluster=None): + ppl.ioff() + client = dd.Client(testcluster) + all_tests = [attr for attr in TestClass.__dir__() + if (inspect.ismethod(getattr(TestClass, attr)) and 'parallel' not in attr)] + for test in all_tests: + test_method = getattr(TestClass, test) + test_method() + client.close() + ppl.ion() def run_cfg_test(method_call, method, cfg, positivity=True): From ca8f80dbcaad4ed91c5f181d96bf5316d4687a1f Mon Sep 17 00:00:00 2001 From: KatharineShapcott <65502584+KatharineShapcott@users.noreply.github.com> Date: Tue, 17 Jan 2023 11:01:13 +0100 Subject: [PATCH 099/135] FIX: use samplerate variable for consistency --- syncopy/tests/test_selectdata.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/syncopy/tests/test_selectdata.py b/syncopy/tests/test_selectdata.py index c1dcdaa47..0dd6c8830 100644 --- a/syncopy/tests/test_selectdata.py +++ b/syncopy/tests/test_selectdata.py @@ -475,6 +475,7 @@ class TestEventSelections: nSamples = 4 nTrials = 5 + samplerate = 1.0 eIDs = [0, 111, 31] # event ids rng = np.random.default_rng(42) @@ -486,7 +487,7 @@ class TestEventSelections: data = np.vstack([np.arange(0, nSamples * nTrials, 1), rng.choice(eIDs, size=nSamples * nTrials)]).T - edata = spy.EventData(data=data, samplerate=1, trialdefinition=trldef) + edata = spy.EventData(data=data, samplerate=samplerate, trialdefinition=trldef) def test_event_selection(self): From 0dca548d02e4a4921eca6d4fa52f607a0283fc41 Mon Sep 17 00:00:00 2001 From: KatharineShapcott <65502584+KatharineShapcott@users.noreply.github.com> Date: Tue, 17 Jan 2023 11:35:16 +0100 Subject: [PATCH 100/135] CHG: clarify comment --- syncopy/tests/test_selectdata.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/syncopy/tests/test_selectdata.py b/syncopy/tests/test_selectdata.py index 0dd6c8830..87c389849 100644 --- a/syncopy/tests/test_selectdata.py +++ b/syncopy/tests/test_selectdata.py @@ -81,7 +81,9 @@ def test_ad_selection(self): # pick the data by hand, latency [0, 1] covers 2nd - 4th sample index # as time axis is array([-0.5, 0. , 0.5, 1. , 1.5]) + # pick trial solution = self.adata.data[self.nSamples:self.nSamples * 2] + # pick channels and latency solution = np.column_stack([solution[1:4, 6], solution[1:4, 2]]) assert np.all(solution == res.data) From 0ef4343601e019d1f1925dd6ba7f8d7fa5e6cdfe Mon Sep 17 00:00:00 2001 From: tensionhead Date: Tue, 17 Jan 2023 15:13:10 +0100 Subject: [PATCH 101/135] CHG: Remove list as parent class of TrialIndexer - it is not needed at all, all relevant functionality got implemented from scratch Changes to be committed: modified: syncopy/datatype/util.py --- syncopy/datatype/util.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/syncopy/datatype/util.py b/syncopy/datatype/util.py index d652d7482..4fb88f672 100644 --- a/syncopy/datatype/util.py +++ b/syncopy/datatype/util.py @@ -15,11 +15,12 @@ __all__ = ['TrialIndexer', 'SessionLogger'] -class TrialIndexer(list): +class TrialIndexer: def __init__(self, data_object, idx_list): """ - Subclass Python's list to obtain an indexable trials iterable. + Class to obtain an indexable trials iterable from + an instantiated Syncopy data class `data_object`. Relies on the `_get_trial` method of the respective `data_object`. 
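A minimal usage sketch of this pattern, with `data` as a placeholder for any Syncopy object that provides a `_get_trial` method:

    # data.trials is a TrialIndexer: a lazy, list-like view on the trials
    first = data.trials[0]        # reads only trial 0 from disk
    for trl in data.trials:       # iteration keeps one trial in memory at a time
        print(trl.shape)          # each item is a plain numpy.ndarray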
@@ -36,6 +37,7 @@ def __init__(self, data_object, idx_list): self._len = len(idx_list) def __getitem__(self, trialno): + # single trial access via index operator [] if not isinstance(trialno, Number): raise SPYTypeError(trialno, "trial index", "single number to index a single trial") if trialno not in self.idx_list: @@ -45,7 +47,8 @@ def __getitem__(self, trialno): def __iter__(self): # this generator gets freshly created and exhausted - # only for iterations (iterator protocol) + # for each new iteration, with only 1 trial being in memory + # at any given time yield from (self[i] for i in self.idx_list) def __len__(self): From ff3a3ba261b3329867398031ca21115517932e58 Mon Sep 17 00:00:00 2001 From: tensionhead Date: Tue, 17 Jan 2023 15:24:27 +0100 Subject: [PATCH 102/135] FIX: Arithmetic comparison - weird that this worked in the 1st place.. Changes to be committed: modified: syncopy/tests/test_basedata.py --- syncopy/tests/test_basedata.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/syncopy/tests/test_basedata.py b/syncopy/tests/test_basedata.py index 14bdf36df..ac4de81a7 100644 --- a/syncopy/tests/test_basedata.py +++ b/syncopy/tests/test_basedata.py @@ -278,7 +278,7 @@ def test_arithmetic(self): # Start w/the one operator that does not handle zeros well... with pytest.raises(SPYValueError) as spyval: - dummy / 0 + _ = dummy / 0 assert "expected non-zero scalar for division" in str(spyval.value) # Go through all supported operators and try to sabotage them @@ -374,8 +374,8 @@ def test_arithmetic(self): # Difference in actual numerical data dummy3 = dummy.copy() for dsetName in dummy3._hdfFileDatasetProperties: - getattr(dummy3, dsetName)[0] = 2 * np.pi - assert dummy3 != dummy + getattr(dummy3, dsetName)[0, 0] = -99 + assert dummy3.data != dummy.data del dummy, dummy3, other From 47ea688c5e821291c7a1e5edcfaebbe325cd9752 Mon Sep 17 00:00:00 2001 From: tensionhead Date: Tue, 17 Jan 2023 16:05:21 +0100 Subject: [PATCH 103/135] CHG: Rewrite channel/unit label setting logic - it's tricky due to empty data class inits Changes to be committed: modified: syncopy/datatype/discrete_data.py --- syncopy/datatype/discrete_data.py | 57 +++++++++++++++++++------------ 1 file changed, 36 insertions(+), 21 deletions(-) diff --git a/syncopy/datatype/discrete_data.py b/syncopy/datatype/discrete_data.py index b644bd81a..2bce96ff4 100644 --- a/syncopy/datatype/discrete_data.py +++ b/syncopy/datatype/discrete_data.py @@ -312,6 +312,7 @@ def __init__(self, data=None, samplerate=None, trialid=None, **kwargs): # Fill in dimensional info definetrial(self, kwargs.get("trialdefinition")) + class SpikeData(DiscreteData): """Spike times of multi- and/or single units @@ -328,11 +329,17 @@ class SpikeData(DiscreteData): _selectionKeyWords = DiscreteData._selectionKeyWords + ('channel', 'unit',) def _compute_unique(self): + """ + Use `np.unique` on whole(!) dataset to compute globally + available channel and unit indices only once + """ + # after data was added via selection or loading from file + # this function gets re-triggered if self.data is None: return - # this is costly + # this is costly and loads the entire hdf5 dataset into memory! self.channel_idx = np.unique(self.data[:, self.dimord.index("channel")]) self.unit_idx = np.unique(self.data[:, self.dimord.index("unit")]) @@ -348,24 +355,27 @@ def channel(self, chan): if chan is not None: raise SPYValueError(f"non-empty SpikeData", "cannot assign `channel` without data. 
" + "Please assign data first") - # empy labels for empty data is fine + # No labels for no data is fine self._channel = chan return # there is data - else: - if chan is None: - raise SPYValueError("channel labels, cannot set `channel` to `None` with existing data.") + elif chan is None: + raise SPYValueError("channel labels, cannot set `channel` to `None` with existing data.") - # we have data and new labels - if self.channel_idx is None: - self._compute_unique() + # if we landed here, we have data and new labels + + # in case of selections and/or loading from file + # the constructor was called with data=None, hence + # we have to compute the unique indices here + if self.channel_idx is None: + self._compute_unique() # we need as many labels as there are distinct channels nChan = self.channel_idx.size if nChan != len(chan): - raise SPYValueError(f"exactly {nChan} channel labels") + raise SPYValueError(f"exactly {nChan} channel label(s)") array_parser(chan, varname="channel", ntype="str", dims=(nChan, )) self._channel = np.array(chan) @@ -402,13 +412,14 @@ def unit(self, unit): return # there is data - else: - if unit is None: - raise SPYValueError("unit labels, cannot set `unit` to `None` with existing data.") + elif unit is None: + raise SPYValueError("unit labels, cannot set `unit` to `None` with existing data.") - # we have data and new labels - if self.unit_idx is None: - self._compute_unique() + # in case of selections and/or loading from file + # the constructor was called with data=None, hence + # we have to compute this here + if self.unit_idx is None: + self._compute_unique() if unit is None and self.data is not None: raise SPYValueError("Cannot set `unit` to `None` with existing data.") @@ -421,7 +432,7 @@ def unit(self, unit): nunit = self.unit_idx.size if nunit != len(unit): - raise SPYValueError(f"exactly {nunit} unit labels") + raise SPYValueError(f"exactly {nunit} unit label(s)") array_parser(unit, varname="unit", ntype="str", dims=(nunit,)) self._unit = np.array(unit) @@ -545,17 +556,21 @@ def __init__(self, # for fast lookup and labels self._compute_unique() - # use the setters to assign initial labels, + # constructor gets `data=None` for + # empty inits, selections and loading from file + # can't set any labels in that case if channel is not None: - # this rightfully fails for empty data + # setter raises exception if data=None self.channel = channel - else: - # sets to None if no data + elif data is not None: + # data but no given labels self.channel = self._default_channel_labels() + # same for unit if unit is not None: + # setter raises exception if data=None self.unit = unit - else: + elif data is not None: self.unit = self._default_unit_labels() From d5abb6235012fca754771856860a8371332a0756 Mon Sep 17 00:00:00 2001 From: tensionhead Date: Tue, 17 Jan 2023 16:30:30 +0100 Subject: [PATCH 104/135] CHG: Enforce integer data type Changes to be committed: modified: syncopy/datatype/discrete_data.py --- syncopy/datatype/discrete_data.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/syncopy/datatype/discrete_data.py b/syncopy/datatype/discrete_data.py index 2bce96ff4..fc129ce5d 100644 --- a/syncopy/datatype/discrete_data.py +++ b/syncopy/datatype/discrete_data.py @@ -14,7 +14,7 @@ from .base_data import BaseData, FauxTrial from .methods.definetrial import definetrial from syncopy.shared.parsers import scalar_parser, array_parser -from syncopy.shared.errors import SPYValueError, SPYError +from syncopy.shared.errors import 
SPYValueError, SPYError, SPYTypeError from syncopy.shared.tools import best_match __all__ = ["SpikeData", "EventData"] @@ -51,9 +51,15 @@ def data(self): @data.setter def data(self, inData): + """ Also checks for integer type of data """ # this comes from BaseData self._set_dataset_property(inData, "data") + if inData is not None: + # probably not the most elegant way.. + if not 'int' in str(self.data.dtype): + raise SPYTypeError(self.data.dtype, 'data', "integer like") + def __str__(self): # Get list of print-worthy attributes ppattrs = [attr for attr in self.__dir__() @@ -332,10 +338,14 @@ def _compute_unique(self): """ Use `np.unique` on whole(!) dataset to compute globally available channel and unit indices only once + + This function gets triggered by the constructor + `if data is not None` or latest when channel/unit + labels are assigned with the respective setters. """ # after data was added via selection or loading from file - # this function gets re-triggered + # this function gets re-triggered by the channel/unit setters! if self.data is None: return @@ -445,7 +455,7 @@ def _default_unit_labels(self): if self.data is not None: unit_max = self.unit_idx.max() - return np.array(["unit" + str(int(i)).zfill(len(str(unit_max))) + return np.array(["unit" + str(int(i)).zfill(len(str(unit_max)) + 1) for i in self.unit_idx]) else: return None From 6aecbf353b98e486c5651312926e5fc86ec6154a Mon Sep 17 00:00:00 2001 From: tensionhead Date: Tue, 17 Jan 2023 16:58:15 +0100 Subject: [PATCH 105/135] CHG: Make unit indices also 0-based - to be in sync with channel Changes to be committed: modified: syncopy/datatype/discrete_data.py --- syncopy/datatype/discrete_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/syncopy/datatype/discrete_data.py b/syncopy/datatype/discrete_data.py index fc129ce5d..22eb70b6c 100644 --- a/syncopy/datatype/discrete_data.py +++ b/syncopy/datatype/discrete_data.py @@ -455,7 +455,7 @@ def _default_unit_labels(self): if self.data is not None: unit_max = self.unit_idx.max() - return np.array(["unit" + str(int(i)).zfill(len(str(unit_max)) + 1) + return np.array(["unit" + str(int(i + 1)).zfill(len(str(unit_max)) + 1) for i in self.unit_idx]) else: return None From fa89214add8b0bd15a3e14bd57a0e1a768748488 Mon Sep 17 00:00:00 2001 From: tensionhead Date: Tue, 17 Jan 2023 16:58:46 +0100 Subject: [PATCH 106/135] CHG: Add init tests and remove redundant selection tests Changes to be committed: modified: syncopy/tests/test_discretedata.py --- syncopy/tests/test_discretedata.py | 117 +++++++---------------------- 1 file changed, 27 insertions(+), 90 deletions(-) diff --git a/syncopy/tests/test_discretedata.py b/syncopy/tests/test_discretedata.py index 6c8e307cd..e163b15a0 100644 --- a/syncopy/tests/test_discretedata.py +++ b/syncopy/tests/test_discretedata.py @@ -43,6 +43,26 @@ class TestSpikeData(): num_chn = data[:, 1].max() + 1 num_unt = data[:, 2].max() + 1 + def test_init(self): + + # data and no labels triggers default labels + dummy = SpikeData(data=4 * np.ones((2, 3), dtype=int)) + # labels are 0-based + assert dummy.channel == 'channel05' + assert dummy.unit == 'unit05' + + # data and fitting labels is fine + assert isinstance(SpikeData(data=np.ones((2, 3), dtype=int), channel=['only_channel']), + SpikeData) + + # data and too many labels + with pytest.raises(SPYValueError, match='expected exactly 1 unit'): + _ = SpikeData(data=np.ones((2, 3), dtype=int), unit=['unit1', 'unit2']) + + # no data but labels + with pytest.raises(SPYValueError, 
match='cannot assign `channel` without data'): + _ = SpikeData(channel=['a', 'b', 'c']) + def test_empty(self): dummy = SpikeData() assert len(dummy.cfg) == 0 @@ -54,9 +74,11 @@ def test_empty(self): SpikeData({}) def test_issue_257_fixed_no_error_for_empty_data(self): - """This tests that the data object is created without throwing an error, see #257.""" + """This tests that empty datasets are not allowed""" with pytest.raises(SPYValueError, match='non empty'): - data = SpikeData(np.column_stack(([],[],[])), dimord = ['sample', 'channel', 'unit'], samplerate = 30000) + data = SpikeData(np.column_stack(([],[],[])).astype(int), + dimord=['sample', 'channel', 'unit'], + samplerate=30000) def test_nparray(self): dummy = SpikeData(self.data) @@ -147,92 +169,6 @@ def test_saveload(self): del dummy, dummy2 time.sleep(0.1) - # test data-selection via class method - def test_dataselection(self): - - # Create testing objects (regular and swapped dimords) - dummy = SpikeData(data=self.data, - trialdefinition=self.trl, - samplerate=2.0) - ymmud = SpikeData(data=self.data[:, ::-1], - trialdefinition=self.trl, - samplerate=2.0, - dimord=dummy.dimord[::-1]) - - # selections are chosen so that result is not empty - trialSelections = [ - "all", # enforce below selections in all trials of `dummy` - [3, 1] # minimally unordered - ] - chanSelections = [ - ["channel03", "channel01", "channel01", "channel02"], # string selection w/repetition + unordered - [4, 2, 2, 5, 5], # repetition + unorderd - range(5, 8), # narrow range - ] - latencySelections = [ - [0.5, 2.5], # regular range - [1.0, 2] # recued range - ] - unitSelections = [ - ["unit1", "unit1", "unit2", "unit3"], # preserve repetition - [0, 0, 2, 3], # preserve repetition, don't convert to slice - range(1, 4), # narrow range - ] - - timeSelections = list(zip(["latency"] * len(latencySelections), latencySelections)) - - trialSels = [random.choice(trialSelections)] - chanSels = [random.choice(chanSelections)] - unitSels = [random.choice(unitSelections)] - timeSels = [random.choice(timeSelections)] - - for obj in [dummy, ymmud]: - chanIdx = obj.dimord.index("channel") - unitIdx = obj.dimord.index("unit") - chanArr = np.arange(obj.channel.size) - for trialSel in trialSels: - for chanSel in chanSels: - for unitSel in unitSels: - for timeSel in timeSels: - kwdict = {} - kwdict["trials"] = trialSel - kwdict["channel"] = chanSel - kwdict["unit"] = unitSel - kwdict[timeSel[0]] = timeSel[1] - cfg = StructDict(kwdict) - # data selection via class-method + `Selector` instance for indexing - - selected = obj.selectdata(**kwdict) - obj.selectdata(**kwdict, inplace=True) - selector = obj.selection - tk = 0 - for trialno in selector.trial_ids: - if selector.time[tk]: - assert np.array_equal(obj.trials[trialno][selector.time[tk], :], - selected.trials[tk]) - tk += 1 - assert set(selected.data[:, chanIdx]).issubset(chanArr[selector.channel]) - assert set(selected.channel) == set(obj.channel[selector.channel]) - # only if we got sth - if np.size(selected.unit) > 0: - assert np.array_equal(selected.unit, - obj.unit[np.unique(selected.data[:, unitIdx])]) - cfg.data = obj - # data selection via package function and `cfg`: ensure equality - out = selectdata(cfg) - assert np.array_equal(out.channel, selected.channel) - assert np.array_equal(out.unit, selected.unit) - assert np.array_equal(out.data, selected.data) - - def test_parallel(self, testcluster): - # repeat selected test w/parallel processing engine - client = dd.Client(testcluster) - par_tests = 
["test_dataselection"] - for test in par_tests: - getattr(self, test)() - flush_local_cluster(testcluster) - client.close() - class TestEventData(): @@ -240,7 +176,7 @@ class TestEventData(): nc = 10 ns = 30 data = np.vstack([np.arange(0, ns, 5), - np.zeros((int(ns / 5), ))]).T + np.zeros((int(ns / 5), ))]).T.astype(int) data[1::2, 1] = 1 data2 = data.copy() data2[:, -1] = data[:, 0] @@ -451,7 +387,7 @@ def test_ed_trialsetting(self): # Extend data and provoke an exception due to out of bounds error smp = np.vstack([np.arange(self.ns, int(2.5 * self.ns), 5), - np.zeros((int((1.5 * self.ns) / 5),))]).T + np.zeros((int((1.5 * self.ns) / 5),))]).T.astype(int) smp[1::2, 1] = 1 smp = np.hstack([smp, smp]) data4 = np.vstack([data3, smp]) @@ -581,3 +517,4 @@ def test_ed_parallel(self, testcluster): if __name__ == '__main__': T1 = TestSpikeData() + T2 = TestEventData() From 39991f205a648dfb73d35bc3d706b3fb91832360 Mon Sep 17 00:00:00 2001 From: tensionhead Date: Tue, 17 Jan 2023 17:50:03 +0100 Subject: [PATCH 107/135] FIX: Use integer type data for discrete data tests Changes to be committed: modified: syncopy/tests/test_basedata.py modified: syncopy/tests/test_spyio.py --- syncopy/tests/test_basedata.py | 4 ++-- syncopy/tests/test_spyio.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/syncopy/tests/test_basedata.py b/syncopy/tests/test_basedata.py index ac4de81a7..586fdd75e 100644 --- a/syncopy/tests/test_basedata.py +++ b/syncopy/tests/test_basedata.py @@ -61,12 +61,12 @@ class TestBaseData(): seed = np.random.RandomState(13) data["SpikeData"] = np.vstack([seed.choice(nSamples, size=nSpikes), seed.choice(nChannels, size=nSpikes), - seed.choice(int(nChannels / 2), size=nSpikes)]).T + seed.choice(int(nChannels / 2), size=nSpikes)]).T.astype(int) trl["SpikeData"] = trl["AnalogData"] # Use a simple binary trigger pattern to simulate EventData data["EventData"] = np.vstack([np.arange(0, nSamples, 5), - np.zeros((int(nSamples / 5), ))]).T + np.zeros((int(nSamples / 5), ))]).T.astype(int) data["EventData"][1::2, 1] = 1 trl["EventData"] = trl["AnalogData"] diff --git a/syncopy/tests/test_spyio.py b/syncopy/tests/test_spyio.py index 4d8a8d5f8..8e57ec733 100644 --- a/syncopy/tests/test_spyio.py +++ b/syncopy/tests/test_spyio.py @@ -35,6 +35,7 @@ skip_no_esi = pytest.mark.skipif(not on_esi, reason="ESI fs not available") skip_no_nwb = pytest.mark.skipif(not spy.__nwb__, reason="pynwb not installed") + class TestSpyIO(): # Allocate test-datasets for AnalogData, SpectralData, SpikeData and EventData objects @@ -65,12 +66,12 @@ class TestSpyIO(): seed = np.random.RandomState(13) data["SpikeData"] = np.vstack([seed.choice(ns, size=nd), seed.choice(nc, size=nd), - seed.choice(int(nc / 2), size=nd)]).T + seed.choice(int(nc / 2), size=nd)]).T.astype(int) trl["SpikeData"] = trl["AnalogData"] # Generate bogus trigger timings data["EventData"] = np.vstack([np.arange(0, ns, 5), - np.zeros((int(ns / 5), ))]).T + np.zeros((int(ns / 5), ))]).T.astype(int) data["EventData"][1::2, 1] = 1 trl["EventData"] = trl["AnalogData"] @@ -613,6 +614,7 @@ def test_load_nwb(self): if __name__ == '__main__': + T0 = TestSpyIO() T1 = TestFTImporter() T2 = TestTDTImporter() T3 = TestNWBImporter() From a7d7893ec0a1c9f93b24eaf7f9f07af936eb2d5c Mon Sep 17 00:00:00 2001 From: tensionhead Date: Tue, 17 Jan 2023 18:29:28 +0100 Subject: [PATCH 108/135] CHG: Increase coverage Changes to be committed: modified: syncopy/datatype/discrete_data.py modified: syncopy/tests/test_discretedata.py --- 
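A minimal sketch of the zero-padded label scheme consolidated below, with made-up channel entries and the padding rule taken from `_default_channel_labels`:

    import numpy as np
    channel_idx = np.array([0, 1, 9])          # 0-based entries found in the data
    width = len(str(channel_idx.max())) + 1    # zero-padding width, here 2
    labels = ["channel" + str(int(i + 1)).zfill(width) for i in channel_idx]
    # labels == ['channel01', 'channel02', 'channel10']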
syncopy/datatype/discrete_data.py | 25 +++++++++---------------- syncopy/tests/test_discretedata.py | 10 ++++++++++ 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/syncopy/datatype/discrete_data.py b/syncopy/datatype/discrete_data.py index 22eb70b6c..5829ea017 100644 --- a/syncopy/datatype/discrete_data.py +++ b/syncopy/datatype/discrete_data.py @@ -57,7 +57,7 @@ def data(self, inData): if inData is not None: # probably not the most elegant way.. - if not 'int' in str(self.data.dtype): + if 'int' not in str(self.data.dtype): raise SPYTypeError(self.data.dtype, 'data', "integer like") def __str__(self): @@ -395,15 +395,11 @@ def _default_channel_labels(self): Creates the default channel labels """ - if self.data is not None: - # channel entries in self.data are 0-based - chan_max = self.channel_idx.max() - channel_labels = np.array(["channel" + str(int(i + 1)).zfill(len(str(chan_max)) + 1) - for i in self.channel_idx]) - return channel_labels - - else: - return None + # channel entries in self.data are 0-based + chan_max = self.channel_idx.max() + channel_labels = np.array(["channel" + str(int(i + 1)).zfill(len(str(chan_max)) + 1) + for i in self.channel_idx]) + return channel_labels @property def unit(self): @@ -453,12 +449,9 @@ def _default_unit_labels(self): Creates the default unit labels """ - if self.data is not None: - unit_max = self.unit_idx.max() - return np.array(["unit" + str(int(i + 1)).zfill(len(str(unit_max)) + 1) - for i in self.unit_idx]) - else: - return None + unit_max = self.unit_idx.max() + return np.array(["unit" + str(int(i + 1)).zfill(len(str(unit_max)) + 1) + for i in self.unit_idx]) # Helper function that extracts by-trial unit-indices def _get_unit(self, trials, units=None): diff --git a/syncopy/tests/test_discretedata.py b/syncopy/tests/test_discretedata.py index e163b15a0..9c9be7826 100644 --- a/syncopy/tests/test_discretedata.py +++ b/syncopy/tests/test_discretedata.py @@ -55,6 +55,16 @@ def test_init(self): assert isinstance(SpikeData(data=np.ones((2, 3), dtype=int), channel=['only_channel']), SpikeData) + # --- invalid inits --- + + # non-integer types + with pytest.raises(SPYTypeError, match='expected integer like'): + _ = SpikeData(data=np.ones((2, 3)), unit=['unit1', 'unit2']) + + with pytest.raises(SPYTypeError, match='expected integer like'): + data = np.array([np.nan, 2, np.nan])[:, np.newaxis] + _ = SpikeData(data=data, unit=['unit1', 'unit2']) + # data and too many labels with pytest.raises(SPYValueError, match='expected exactly 1 unit'): _ = SpikeData(data=np.ones((2, 3), dtype=int), unit=['unit1', 'unit2']) From ff23719da54d2aa2c06a4f8792e5c4aee9bce40f Mon Sep 17 00:00:00 2001 From: KatharineShapcott <65502584+KatharineShapcott@users.noreply.github.com> Date: Wed, 18 Jan 2023 07:48:32 +0100 Subject: [PATCH 109/135] FIX: Remove incorrect test --- syncopy/tests/test_discretedata.py | 63 ------------------------------ 1 file changed, 63 deletions(-) diff --git a/syncopy/tests/test_discretedata.py b/syncopy/tests/test_discretedata.py index 9a668f464..9a32eee6b 100644 --- a/syncopy/tests/test_discretedata.py +++ b/syncopy/tests/test_discretedata.py @@ -506,69 +506,6 @@ def test_ed_trialsetting(self): with pytest.raises(SPYValueError): ang_dummy.definetrial(evt_dummy, pre=pre, post=post, trigger=1) - # test data-selection via class method - def test_ed_dataselection(self): - - # Create testing objects (regular and swapped dimords) - dummy = EventData(data=np.hstack([self.data, self.data]), - dimord=self.customDimord, - 
trialdefinition=self.trl, - samplerate=2.0) - ymmud = EventData(data=np.hstack([self.data[:, ::-1], self.data[:, ::-1]]), - trialdefinition=self.trl, - samplerate=2.0, - dimord=dummy.dimord[::-1]) - - # selections are chosen so that result is not empty - trialSelections = [ - "all", # enforce below selections in all trials of `dummy` - [3, 1] # minimally unordered - ] - - eventidSelections = [ - [0, 0, 1], # preserve repetition, don't convert to slice - range(0, 2), # narrow range - ] - - latencySelections = [ - [0.5, 2.5], # regular range - [0.7, 2.] # reduce range - ] - - timeSelections = list(zip(["latency"] * len(latencySelections), latencySelections)) - - trialSels = [random.choice(trialSelections)] - eventidSels = [random.choice(eventidSelections)] - timeSels = [random.choice(timeSelections)] - - for obj in [dummy, ymmud]: - eventidIdx = obj.dimord.index("eventid") - for trialSel in trialSels: - for eventidSel in eventidSels: - for timeSel in timeSels: - kwdict = {} - kwdict["trials"] = trialSel - kwdict["eventid"] = eventidSel - kwdict[timeSel[0]] = timeSel[1] - cfg = StructDict(kwdict) - # data selection via class-method + `Selector` instance for indexing - selected = obj.selectdata(**kwdict) - obj.selectdata(**kwdict, inplace=True) - selector = obj.selection - tk = 0 - for trialno in selector.trial_ids: - if selector.time[tk]: - assert np.array_equal(obj.trials[trialno][selector.time[tk], :], - selected.trials[tk]) - tk += 1 - assert np.array_equal(selected.eventid, - obj.eventid[np.unique(selected.data[:, eventidIdx]).astype(np.intp)]) - cfg.data = obj - # data selection via package function and `cfg`: ensure equality - out = selectdata(cfg) - assert np.array_equal(out.eventid, selected.eventid) - assert np.array_equal(out.data, selected.data) - def test_ed_parallel(self, testcluster): # repeat selected test w/parallel processing engine client = dd.Client(testcluster) From e6bfe1f09f8baf5504585f8c07352867aff5c47f Mon Sep 17 00:00:00 2001 From: KatharineShapcott <65502584+KatharineShapcott@users.noreply.github.com> Date: Wed, 18 Jan 2023 09:54:53 +0100 Subject: [PATCH 110/135] FIX: Minor renaming --- syncopy/datatype/discrete_data.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/syncopy/datatype/discrete_data.py b/syncopy/datatype/discrete_data.py index 5829ea017..e12a63efa 100644 --- a/syncopy/datatype/discrete_data.py +++ b/syncopy/datatype/discrete_data.py @@ -57,7 +57,7 @@ def data(self, inData): if inData is not None: # probably not the most elegant way.. - if 'int' not in str(self.data.dtype): + if np.issubdtype(self.data.dtype, np.integer): raise SPYTypeError(self.data.dtype, 'data', "integer like") def __str__(self): @@ -161,7 +161,7 @@ def trialid(self, trlid): if self.data is None: SPYError("SyNCoPy core - trialid: Cannot assign `trialid` without data. " + - "Please assing data first") + "Please assign data first") return scount = np.nanmax(self.data[:, self.dimord.index("sample")]) try: @@ -334,7 +334,7 @@ class SpikeData(DiscreteData): _stackingDimLabel = "sample" _selectionKeyWords = DiscreteData._selectionKeyWords + ('channel', 'unit',) - def _compute_unique(self): + def _compute_unique_idx(self): """ Use `np.unique` on whole(!) 
dataset to compute globally available channel and unit indices only once @@ -379,7 +379,7 @@ def channel(self, chan): # the constructor was called with data=None, hence # we have to compute the unique indices here if self.channel_idx is None: - self._compute_unique() + self._compute_unique_idx() # we need as many labels as there are distinct channels nChan = self.channel_idx.size @@ -425,7 +425,7 @@ def unit(self, unit): # the constructor was called with data=None, hence # we have to compute this here if self.unit_idx is None: - self._compute_unique() + self._compute_unique_idx() if unit is None and self.data is not None: raise SPYValueError("Cannot set `unit` to `None` with existing data.") @@ -557,7 +557,7 @@ def __init__(self, dimord=dimord) # for fast lookup and labels - self._compute_unique() + self._compute_unique_idx() # constructor gets `data=None` for # empty inits, selections and loading from file From 992e1773945d413186b8540b373e94356debb519 Mon Sep 17 00:00:00 2001 From: KatharineShapcott <65502584+KatharineShapcott@users.noreply.github.com> Date: Wed, 18 Jan 2023 09:56:48 +0100 Subject: [PATCH 111/135] FIX: missing not --- syncopy/datatype/discrete_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/syncopy/datatype/discrete_data.py b/syncopy/datatype/discrete_data.py index e12a63efa..9c7814065 100644 --- a/syncopy/datatype/discrete_data.py +++ b/syncopy/datatype/discrete_data.py @@ -57,7 +57,7 @@ def data(self, inData): if inData is not None: # probably not the most elegant way.. - if np.issubdtype(self.data.dtype, np.integer): + if not np.issubdtype(self.data.dtype, np.integer): raise SPYTypeError(self.data.dtype, 'data', "integer like") def __str__(self): From 2917162c8fc3904643a0aeb97f51d76ba8a3aa91 Mon Sep 17 00:00:00 2001 From: KatharineShapcott <65502584+KatharineShapcott@users.noreply.github.com> Date: Wed, 18 Jan 2023 10:38:19 +0100 Subject: [PATCH 112/135] FIX: Remove old comment --- syncopy/datatype/discrete_data.py | 1 - 1 file changed, 1 deletion(-) diff --git a/syncopy/datatype/discrete_data.py b/syncopy/datatype/discrete_data.py index 9c7814065..425588bde 100644 --- a/syncopy/datatype/discrete_data.py +++ b/syncopy/datatype/discrete_data.py @@ -56,7 +56,6 @@ def data(self, inData): self._set_dataset_property(inData, "data") if inData is not None: - # probably not the most elegant way.. if not np.issubdtype(self.data.dtype, np.integer): raise SPYTypeError(self.data.dtype, 'data', "integer like") From 6d6dce73d793525ffa1b8a98b5040a53ba95a406 Mon Sep 17 00:00:00 2001 From: KatharineShapcott <65502584+KatharineShapcott@users.noreply.github.com> Date: Wed, 18 Jan 2023 11:30:33 +0100 Subject: [PATCH 113/135] FIX: Remove unique from sample --- syncopy/datatype/discrete_data.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/syncopy/datatype/discrete_data.py b/syncopy/datatype/discrete_data.py index 78162a958..9c8e4b53b 100644 --- a/syncopy/datatype/discrete_data.py +++ b/syncopy/datatype/discrete_data.py @@ -115,9 +115,7 @@ def sample(self): """Indices of all recorded samples""" if self.data is None: return None - # return self.data[:, self.dimord.index("sample")] - # there should be only one event per sample number?! 
- return np.unique(self.data[:, self.dimord.index("sample")]) + return self.data[:, self.dimord.index("sample")] @property def samplerate(self): From 4ec7eacc97887d5e553c8a7092f614c138422621 Mon Sep 17 00:00:00 2001 From: tensionhead Date: Wed, 18 Jan 2023 14:12:06 +0100 Subject: [PATCH 114/135] Update Changelog Changes to be committed: modified: CHANGELOG.md --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 29f9222d9..8a9884f51 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,10 +6,11 @@ All notable changes to this project will be documented in this file. ### NEW ### CHANGED +- major performance improvements for DiscreteData #403 #418, #424 ### Fixed - fix bug #394 'Copying a spy.StructDict returns a dict'. -- serializable `.cfg` #392 +- serializable `.cfg` #392 ## [2022.12] From 73a5fe1f56fd3c47d95ffb5abd353affdfd4eee3 Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Wed, 18 Jan 2023 15:55:51 +0100 Subject: [PATCH 115/135] CHG: move logging setup to function --- syncopy/__init__.py | 66 +++------------------------------------ syncopy/shared/log.py | 72 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 62 deletions(-) diff --git a/syncopy/__init__.py b/syncopy/__init__.py index df59a6fec..9a1416062 100644 --- a/syncopy/__init__.py +++ b/syncopy/__init__.py @@ -104,68 +104,6 @@ else: __storage__ = os.path.join(os.path.expanduser("~"), ".spy", "tmp_storage") -# Setup logging. -if os.environ.get("SPYLOGDIR"): - __logdir__ = os.path.abspath(os.path.expanduser(os.environ["SPYLOGDIR"])) -else: - if os.path.exists(csHome): - __logdir__ = os.path.join(csHome, ".spy", "logs") - else: - __logdir__ = os.path.join(os.path.expanduser("~"), ".spy", "logs") - -if not os.path.exists(__logdir__): - os.makedirs(__logdir__, exist_ok=True) - -loglevel = os.getenv("SPYLOGLEVEL", "WARNING") -numeric_level = getattr(logging, loglevel.upper(), None) -if not isinstance(numeric_level, int): # An invalid string was set as the env variable, default to WARNING. - warnings.warn("Invalid log level set in environment variable 'SPYLOGLEVEL', ignoring and using WARNING instead. Hint: Set one of 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'.") - loglevel = "WARNING" - -# The logger for local/sequential stuff -- goes to terminal and to a file. -spy_logger = logging.getLogger('syncopy') -fmt = logging.Formatter('%(asctime)s - %(levelname)s: %(message)s') -sh = logging.StreamHandler(sys.stdout) -sh.setFormatter(fmt) -spy_logger.addHandler(sh) - -logfile = os.path.join(__logdir__, f'syncopy.log') -fh = logging.FileHandler(logfile) # The default mode is 'append'. -fh.setFormatter(fmt) -spy_logger.addHandler(fh) - - -spy_logger.setLevel(loglevel) -spy_logger.debug(f"Starting Syncopy session at {datetime.datetime.now().astimezone().isoformat()}.") -spy_logger.info(f"Syncopy log level set to: {loglevel}.") - -# Log to per-host files in parallel code by default. -# Note that this setup handles only the logger of the current host. -parloglevel = os.getenv("SPYPARLOGLEVEL", loglevel) -numeric_level = getattr(logging, parloglevel.upper(), None) -if not isinstance(numeric_level, int): # An invalid string was set as the env variable, use default. - warnings.warn("Invalid log level set in environment variable 'SPYPARLOGLEVEL', ignoring and using WARNING instead. 
Hint: Set one of 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'.") - parloglevel = "WARNING" -host = platform.node() -parallel_logger_name = "syncopy_" + host -spy_parallel_logger = logging.getLogger(parallel_logger_name) - -class HostnameFilter(logging.Filter): - hostname = platform.node() - - def filter(self, record): - record.hostname = HostnameFilter.hostname - return True - -logfile_par = os.path.join(__logdir__, f'syncopy_{host}.log') -fhp = logging.FileHandler(logfile_par) # The default mode is 'append'. -fhp.addFilter(HostnameFilter()) -spy_parallel_logger.setLevel(parloglevel) -fmt_with_hostname = logging.Formatter('%(asctime)s - %(levelname)s - %(hostname)s: %(message)s') -fhp.setFormatter(fmt_with_hostname) -spy_parallel_logger.addHandler(fhp) -spy_parallel_logger.info(f"Syncopy parallel logger '{parallel_logger_name}' setup to log to file '{logfile_par}' at level {loglevel}.") - # Set upper bound for temp directory size (in GB) __storagelimit__ = 10 @@ -195,6 +133,10 @@ def filter(self, record): from .plotting import * from .preproc import * +from .shared.log import setup_logging +setup_logging() + + # Register session __session__ = datatype.util.SessionLogger() diff --git a/syncopy/shared/log.py b/syncopy/shared/log.py index 8042fb295..bddf9f521 100644 --- a/syncopy/shared/log.py +++ b/syncopy/shared/log.py @@ -5,15 +5,87 @@ # Note: The logging setup is done in the top-level `__init.py__` file. import os +import sys import logging import socket import syncopy import warnings +import datetime +import platform +import getpass +import syncopy loggername = "syncopy" # Since this is a library, we should not use the root logger (see Python logging docs). loglevels = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] +def setup_logging(): + print("setting up logging") + # Setup logging. + csHome = "/cs/home/{}".format(getpass.getuser()) + if os.environ.get("SPYLOGDIR"): + syncopy.__logdir__ = os.path.abspath(os.path.expanduser(os.environ["SPYLOGDIR"])) + else: + if os.path.exists(csHome): + syncopy.__logdir__ = os.path.join(csHome, ".spy", "logs") + else: + syncopy.__logdir__ = os.path.join(os.path.expanduser("~"), ".spy", "logs") + + if not os.path.exists(syncopy.__logdir__): + os.makedirs(syncopy.__logdir__, exist_ok=True) + + loglevel = os.getenv("SPYLOGLEVEL", "WARNING") + numeric_level = getattr(logging, loglevel.upper(), None) + if not isinstance(numeric_level, int): # An invalid string was set as the env variable, default to WARNING. + warnings.warn("Invalid log level set in environment variable 'SPYLOGLEVEL', ignoring and using WARNING instead. Hint: Set one of 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'.") + loglevel = "WARNING" + + # The logger for local/sequential stuff -- goes to terminal and to a file. + spy_logger = logging.getLogger('syncopy') + fmt = logging.Formatter('%(asctime)s - %(levelname)s: %(message)s') + sh = logging.StreamHandler(sys.stdout) + sh.setFormatter(fmt) + spy_logger.addHandler(sh) + + logfile = os.path.join(syncopy.__logdir__, f'syncopy.log') + fh = logging.FileHandler(logfile) # The default mode is 'append'. + fh.setFormatter(fmt) + spy_logger.addHandler(fh) + + + spy_logger.setLevel(loglevel) + spy_logger.debug(f"Starting Syncopy session at {datetime.datetime.now().astimezone().isoformat()}.") + spy_logger.info(f"Syncopy log level set to: {loglevel}.") + + # Log to per-host files in parallel code by default. + # Note that this setup handles only the logger of the current host. 
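+    # (for illustration: SPYPARLOGLEVEL falls back to the sequential level
+    #  chosen above, so e.g. setting SPYPARLOGLEVEL=DEBUG in the environment
+    #  makes only the per-host logs verbose)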
+ parloglevel = os.getenv("SPYPARLOGLEVEL", loglevel) + numeric_level = getattr(logging, parloglevel.upper(), None) + if not isinstance(numeric_level, int): # An invalid string was set as the env variable, use default. + warnings.warn("Invalid log level set in environment variable 'SPYPARLOGLEVEL', ignoring and using WARNING instead. Hint: Set one of 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'.") + parloglevel = "WARNING" + host = platform.node() + parallel_logger_name = "syncopy_" + host + spy_parallel_logger = logging.getLogger(parallel_logger_name) + + class HostnameFilter(logging.Filter): + hostname = platform.node() + + def filter(self, record): + record.hostname = HostnameFilter.hostname + return True + + logfile_par = os.path.join(syncopy.__logdir__, f'syncopy_{host}.log') + fhp = logging.FileHandler(logfile_par) # The default mode is 'append'. + fhp.addFilter(HostnameFilter()) + spy_parallel_logger.setLevel(parloglevel) + fmt_with_hostname = logging.Formatter('%(asctime)s - %(levelname)s - %(hostname)s: %(message)s') + fhp.setFormatter(fmt_with_hostname) + spy_parallel_logger.addHandler(fhp) + spy_parallel_logger.info(f"Syncopy parallel logger '{parallel_logger_name}' setup to log to file '{logfile_par}' at level {loglevel}.") + + + def get_logger(): """Get the syncopy root logger. From 50e87f1c77d45a63f572699783316daf140a0230 Mon Sep 17 00:00:00 2001 From: kajal5888 Date: Thu, 19 Jan 2023 09:49:04 +0100 Subject: [PATCH 116/135] replay test for configuration added --- syncopy/tests/test_connectivity.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/syncopy/tests/test_connectivity.py b/syncopy/tests/test_connectivity.py index 3979202c8..2ff2b7f05 100644 --- a/syncopy/tests/test_connectivity.py +++ b/syncopy/tests/test_connectivity.py @@ -510,6 +510,7 @@ class TestCSD: nChannels = 4 nTrials = 100 fs = 1000 + Method = 'csd' # -- two harmonics with individual phase diffusion -- @@ -553,12 +554,18 @@ def test_parallel(self): def test_csd_input(self): assert isinstance(self.spec, SpectralData) - def test_csd_cfg(self): - Method = 'csd' - cross_spec = spy.connectivityanalysis(self.spec, method=Method) + def test_csd_cfg_replay(self): + cross_spec = spy.connectivityanalysis(self.spec, method=self.Method) assert len(cross_spec.cfg) == 2 assert np.all([True for cfg in zip(self.spec.cfg['freqanalysis'], cross_spec.cfg['freqanalysis']) if cfg[0] == cfg[1]]) - assert cross_spec.cfg['connectivityanalysis'].method == Method + assert cross_spec.cfg['connectivityanalysis'].method == self.Method + + first_cfg = cross_spec.cfg['connectivityanalysis'] + first_res = spy.connectivityanalysis(self.spec, cfg=first_cfg) + replay_res = spy.connectivityanalysis(self.spec, cfg=first_res.cfg) + + assert np.allclose(first_res.data[:], replay_res.data[:]) + assert first_res.cfg == replay_res.cfg class TestCorrelation: From d3d896026a1ac6d624c735700f64777023437114 Mon Sep 17 00:00:00 2001 From: tensionhead Date: Fri, 20 Jan 2023 12:51:09 +0100 Subject: [PATCH 117/135] CHG: Print log path on package init Changes to be committed: modified: syncopy/__init__.py modified: syncopy/shared/log.py --- syncopy/__init__.py | 8 +++++++- syncopy/shared/log.py | 10 ++++++---- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/syncopy/__init__.py b/syncopy/__init__.py index 9a1416062..89384de73 100644 --- a/syncopy/__init__.py +++ b/syncopy/__init__.py @@ -135,7 +135,13 @@ from .shared.log import setup_logging setup_logging() - +# do not spam via worker imports +try: + 
dd.get_client() +except ValueError: + silence_file = os.path.join(os.path.expanduser("~"), ".spy", "silentstartup") + if os.getenv("SPYSILENTSTARTUP") is None and not os.path.isfile(silence_file): + print(f"logging to {__logdir__}\n") # Register session __session__ = datatype.util.SessionLogger() diff --git a/syncopy/shared/log.py b/syncopy/shared/log.py index bddf9f521..65422de7f 100644 --- a/syncopy/shared/log.py +++ b/syncopy/shared/log.py @@ -13,15 +13,17 @@ import datetime import platform import getpass -import syncopy loggername = "syncopy" # Since this is a library, we should not use the root logger (see Python logging docs). loglevels = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] + def setup_logging(): - print("setting up logging") + # Setup logging. + + # default path ONLY relevant for ESI Frankfurt csHome = "/cs/home/{}".format(getpass.getuser()) if os.environ.get("SPYLOGDIR"): syncopy.__logdir__ = os.path.abspath(os.path.expanduser(os.environ["SPYLOGDIR"])) @@ -52,7 +54,6 @@ def setup_logging(): fh.setFormatter(fmt) spy_logger.addHandler(fh) - spy_logger.setLevel(loglevel) spy_logger.debug(f"Starting Syncopy session at {datetime.datetime.now().astimezone().isoformat()}.") spy_logger.info(f"Syncopy log level set to: {loglevel}.") @@ -82,10 +83,11 @@ def filter(self, record): fmt_with_hostname = logging.Formatter('%(asctime)s - %(levelname)s - %(hostname)s: %(message)s') fhp.setFormatter(fmt_with_hostname) spy_parallel_logger.addHandler(fhp) + sh = logging.StreamHandler(sys.stdout) + spy_parallel_logger.addHandler(sh) spy_parallel_logger.info(f"Syncopy parallel logger '{parallel_logger_name}' setup to log to file '{logfile_par}' at level {loglevel}.") - def get_logger(): """Get the syncopy root logger. From 65d83134805273c96bad2809560ef17d996e5570 Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Fri, 20 Jan 2023 15:45:24 +0100 Subject: [PATCH 118/135] CHG: minor change to log dir info message --- syncopy/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/syncopy/__init__.py b/syncopy/__init__.py index 89384de73..841129495 100644 --- a/syncopy/__init__.py +++ b/syncopy/__init__.py @@ -141,7 +141,7 @@ except ValueError: silence_file = os.path.join(os.path.expanduser("~"), ".spy", "silentstartup") if os.getenv("SPYSILENTSTARTUP") is None and not os.path.isfile(silence_file): - print(f"logging to {__logdir__}\n") + print(f"Logging to log directory '{__logdir__}'.\n") # Note the __logdir__ is set in the call to setup_logging above. # Register session __session__ = datatype.util.SessionLogger() From 6310382187fc8b8b8a8a625372ce6bdf1212ef24 Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Fri, 20 Jan 2023 16:07:02 +0100 Subject: [PATCH 119/135] NEW: support setting both tempfir and logdir via SPYDIR env var --- syncopy/__init__.py | 21 +++++++++++++++------ syncopy/shared/log.py | 10 +++------- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/syncopy/__init__.py b/syncopy/__init__.py index 841129495..222c66cf6 100644 --- a/syncopy/__init__.py +++ b/syncopy/__init__.py @@ -96,14 +96,23 @@ # Set package-wide temp directory csHome = "/cs/home/{}".format(getpass.getuser()) +if os.environ.get("SPYDIR"): + spydir = os.path.abspath(os.path.expanduser(os.environ["SPYDIR"])) + if not os.path.exists(spydir): + raise ValueError(f"Environment variable SPYDIR set to non-existent or unreadable directory '{spydir}'. Please unset SPYDIR or create the directory.") +else: + if os.path.exists(csHome): # ESI cluster. 
+ spydir = os.path.join(csHome, ".spy") + else: + spydir = os.path.abspath(os.path.join(os.path.expanduser("~"), ".spy")) + if os.environ.get("SPYTMPDIR"): __storage__ = os.path.abspath(os.path.expanduser(os.environ["SPYTMPDIR"])) else: - if os.path.exists(csHome): - __storage__ = os.path.join(csHome, ".spy", "tmp_storage") - else: - __storage__ = os.path.join(os.path.expanduser("~"), ".spy", "tmp_storage") + __storage__ = os.path.join(spydir, "tmp_storage") +if not os.path.exists(spydir): + os.makedirs(spydir, exist_ok=True) # Set upper bound for temp directory size (in GB) __storagelimit__ = 10 @@ -134,14 +143,14 @@ from .preproc import * from .shared.log import setup_logging -setup_logging() +setup_logging(spydir=spydir) # do not spam via worker imports try: dd.get_client() except ValueError: silence_file = os.path.join(os.path.expanduser("~"), ".spy", "silentstartup") if os.getenv("SPYSILENTSTARTUP") is None and not os.path.isfile(silence_file): - print(f"Logging to log directory '{__logdir__}'.\n") # Note the __logdir__ is set in the call to setup_logging above. + print(f"Logging to log directory '{__logdir__}'.\nTemporary storage directory set to '{__storage__}'.\n") # The __logdir__ is set in the call to setup_logging above. # Register session __session__ = datatype.util.SessionLogger() diff --git a/syncopy/shared/log.py b/syncopy/shared/log.py index 65422de7f..80154fa38 100644 --- a/syncopy/shared/log.py +++ b/syncopy/shared/log.py @@ -19,17 +19,13 @@ loglevels = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] -def setup_logging(): +def setup_logging(spydir=None): - # Setup logging. - - # default path ONLY relevant for ESI Frankfurt - csHome = "/cs/home/{}".format(getpass.getuser()) if os.environ.get("SPYLOGDIR"): syncopy.__logdir__ = os.path.abspath(os.path.expanduser(os.environ["SPYLOGDIR"])) else: - if os.path.exists(csHome): - syncopy.__logdir__ = os.path.join(csHome, ".spy", "logs") + if spydir is not None: + syncopy.__logdir__ = os.path.join(spydir, "logs") else: syncopy.__logdir__ = os.path.join(os.path.expanduser("~"), ".spy", "logs") From d25fc84ac7e1eb3fe1629a2b00753f3949aa070c Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Fri, 20 Jan 2023 16:19:31 +0100 Subject: [PATCH 120/135] CHG: refactor to reduce code duplication --- syncopy/__init__.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/syncopy/__init__.py b/syncopy/__init__.py index 222c66cf6..e30a29684 100644 --- a/syncopy/__init__.py +++ b/syncopy/__init__.py @@ -40,19 +40,23 @@ # --- Greeting --- +def startup_print_once(message): + """Print message once: do not spam message n times during all n worker imports.""" + try: + dd.get_client() + except ValueError: + silence_file = os.path.join(os.path.expanduser("~"), ".spy", "silentstartup") + if os.getenv("SPYSILENTSTARTUP") is None and not os.path.isfile(silence_file): + print(message) + + msg = f""" Syncopy {__version__} See https://syncopy.org for the online documentation. For bug reports etc. 
please send an email to syncopy@esi-frankfurt.de """ -# do not spam via worker imports -try: - dd.get_client() -except ValueError: - silence_file = os.path.join(os.path.expanduser("~"), ".spy", "silentstartup") - if os.getenv("SPYSILENTSTARTUP") is None and not os.path.isfile(silence_file): - print(msg) +startup_print_once(msg) # Set up sensible printing options for NumPy arrays np.set_printoptions(suppress=True, precision=4, linewidth=80) @@ -127,7 +131,7 @@ # Set checksum algorithm to be used __checksum_algorithm__ = sha1 -# Fill up namespace +# Fill namespace from . import ( shared, io, @@ -143,14 +147,8 @@ from .preproc import * from .shared.log import setup_logging -setup_logging(spydir=spydir) -# do not spam via worker imports -try: - dd.get_client() -except ValueError: - silence_file = os.path.join(os.path.expanduser("~"), ".spy", "silentstartup") - if os.getenv("SPYSILENTSTARTUP") is None and not os.path.isfile(silence_file): - print(f"Logging to log directory '{__logdir__}'.\nTemporary storage directory set to '{__storage__}'.\n") # The __logdir__ is set in the call to setup_logging above. +setup_logging(spydir=spydir) # Sets __logdir__. +startup_print_once(f"Logging to log directory '{__logdir__}'.\nTemporary storage directory set to '{__storage__}'.\n") # Register session __session__ = datatype.util.SessionLogger() From ad65b9ad4c42ef83669891bfbdde5e4a2f4aef29 Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Fri, 20 Jan 2023 16:47:26 +0100 Subject: [PATCH 121/135] CHG: more logging in parallel mode, separate format --- syncopy/shared/log.py | 104 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 86 insertions(+), 18 deletions(-) diff --git a/syncopy/shared/log.py b/syncopy/shared/log.py index 80154fa38..dc66cedc2 100644 --- a/syncopy/shared/log.py +++ b/syncopy/shared/log.py @@ -20,6 +20,9 @@ def setup_logging(spydir=None): + """Setup logging on module initialization (in the module root level '__init__.py' file). Should not be called elsewhere.""" + + _addLoggingLevel('IMPORTANT', logging.INFO - 5) # Add a new custom log level named 'IMPORTANT' between DEBUG and INFO. if os.environ.get("SPYLOGDIR"): syncopy.__logdir__ = os.path.abspath(os.path.expanduser(os.environ["SPYLOGDIR"])) @@ -32,22 +35,39 @@ def setup_logging(spydir=None): if not os.path.exists(syncopy.__logdir__): os.makedirs(syncopy.__logdir__, exist_ok=True) - loglevel = os.getenv("SPYLOGLEVEL", "WARNING") + loglevel = os.getenv("SPYLOGLEVEL", "IMPORTANT") numeric_level = getattr(logging, loglevel.upper(), None) - if not isinstance(numeric_level, int): # An invalid string was set as the env variable, default to WARNING. - warnings.warn("Invalid log level set in environment variable 'SPYLOGLEVEL', ignoring and using WARNING instead. Hint: Set one of 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'.") - loglevel = "WARNING" + if not isinstance(numeric_level, int): # An invalid string was set as the env variable, default to IMPORTANT. + warnings.warn("Invalid log level set in environment variable 'SPYLOGLEVEL', ignoring and using IMPORTANT instead. 
Hint: Set one of 'DEBUG', 'IMPORTANT', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'.") + loglevel = "IMPORTANT" + + class HostnameFilter(logging.Filter): + hostname = platform.node() + + def filter(self, record): + record.hostname = HostnameFilter.hostname + return True + + class SessionFilter(logging.Filter): + session = syncopy.__session__ + + def filter(self, record): + record.session = self.session + return True # The logger for local/sequential stuff -- goes to terminal and to a file. spy_logger = logging.getLogger('syncopy') - fmt = logging.Formatter('%(asctime)s - %(levelname)s: %(message)s') + fmt_interactive = logging.Formatter('%(asctime)s - %(levelname)s: %(message)s') # Interactive formatter: no hostname and session info (less clutter on terminal). + fmt_with_hostname = logging.Formatter('%(asctime)s - %(levelname)s - %(hostname)s - %(session)s: %(message)s') # Log file formatter: with hostname and session info. sh = logging.StreamHandler(sys.stdout) - sh.setFormatter(fmt) + sh.setFormatter(fmt_interactive) spy_logger.addHandler(sh) logfile = os.path.join(syncopy.__logdir__, f'syncopy.log') fh = logging.FileHandler(logfile) # The default mode is 'append'. - fh.setFormatter(fmt) + fh.addFilter(HostnameFilter()) + fh.addFilter(SessionFilter()) + fh.setFormatter(fmt_with_hostname) spy_logger.addHandler(fh) spy_logger.setLevel(loglevel) @@ -56,34 +76,82 @@ def setup_logging(spydir=None): # Log to per-host files in parallel code by default. # Note that this setup handles only the logger of the current host. - parloglevel = os.getenv("SPYPARLOGLEVEL", loglevel) + parloglevel = os.getenv("SPYPARLOGLEVEL", "INFO") numeric_level = getattr(logging, parloglevel.upper(), None) if not isinstance(numeric_level, int): # An invalid string was set as the env variable, use default. - warnings.warn("Invalid log level set in environment variable 'SPYPARLOGLEVEL', ignoring and using WARNING instead. Hint: Set one of 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'.") - parloglevel = "WARNING" + warnings.warn("Invalid log level set in environment variable 'SPYPARLOGLEVEL', ignoring and using IMPORTANT instead. Hint: Set one of 'DEBUG', 'IMPORTANT', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'.") + parloglevel = "IMPORTANT" host = platform.node() parallel_logger_name = "syncopy_" + host spy_parallel_logger = logging.getLogger(parallel_logger_name) - class HostnameFilter(logging.Filter): - hostname = platform.node() - - def filter(self, record): - record.hostname = HostnameFilter.hostname - return True - logfile_par = os.path.join(syncopy.__logdir__, f'syncopy_{host}.log') fhp = logging.FileHandler(logfile_par) # The default mode is 'append'. fhp.addFilter(HostnameFilter()) + fhp.addFilter(SessionFilter()) spy_parallel_logger.setLevel(parloglevel) - fmt_with_hostname = logging.Formatter('%(asctime)s - %(levelname)s - %(hostname)s: %(message)s') + fhp.setFormatter(fmt_with_hostname) spy_parallel_logger.addHandler(fhp) sh = logging.StreamHandler(sys.stdout) + sh.setFormatter(fmt_interactive) + spy_parallel_logger.addHandler(sh) spy_parallel_logger.info(f"Syncopy parallel logger '{parallel_logger_name}' setup to log to file '{logfile_par}' at level {loglevel}.") +# See https://stackoverflow.com/questions/2183233/how-to-add-a-custom-loglevel-to-pythons-logging-facility/35804945#35804945 +def _addLoggingLevel(levelName, levelNum, methodName=None): + """ + Comprehensively adds a new logging level to the `logging` module and the + currently configured logging class. 
+
+    `levelName` becomes an attribute of the `logging` module with the value
+    `levelNum`. `methodName` becomes a convenience method for both `logging`
+    itself and the class returned by `logging.getLoggerClass()` (usually just
+    `logging.Logger`). If `methodName` is not specified, `levelName.lower()` is
+    used.
+
+    To avoid accidental clobbering of existing attributes, this method will
+    raise an `AttributeError` if the level name is already an attribute of the
+    `logging` module or if the method name is already present.
+
+    Example
+    -------
+    >>> _addLoggingLevel('TRACE', logging.DEBUG - 5)
+    >>> logging.getLogger(__name__).setLevel("TRACE")
+    >>> logging.getLogger(__name__).trace('that worked')
+    >>> logging.trace('so did this')
+    >>> logging.TRACE
+    5
+
+    """
+    if not methodName:
+        methodName = levelName.lower()
+
+    if hasattr(logging, levelName):
+        raise AttributeError('{} already defined in logging module'.format(levelName))
+    if hasattr(logging, methodName):
+        raise AttributeError('{} already defined in logging module'.format(methodName))
+    if hasattr(logging.getLoggerClass(), methodName):
+        raise AttributeError('{} already defined in logger class'.format(methodName))
+
+    # This method was inspired by the answers to Stack Overflow post
+    # http://stackoverflow.com/q/2183233/2988730, especially
+    # http://stackoverflow.com/a/13638084/2988730
+    def logForLevel(self, message, *args, **kwargs):
+        if self.isEnabledFor(levelNum):
+            self._log(levelNum, message, args, **kwargs)
+
+    def logToRoot(message, *args, **kwargs):
+        logging.log(levelNum, message, *args, **kwargs)
+
+    logging.addLevelName(levelNum, levelName)
+    setattr(logging, levelName, levelNum)
+    setattr(logging.getLoggerClass(), methodName, logForLevel)
+    setattr(logging, methodName, logToRoot)
+
 
 def get_logger():
     """Get the syncopy root logger.
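An aside on the filter mechanism this patch introduces: a `logging.Filter` whose
`filter()` method always returns True is the standard way to inject extra fields
into every record, which a `Formatter` can then reference by name. A minimal,
self-contained sketch of that mechanism (the logger name 'demo' and the message
are made up for illustration):

    import logging
    import platform
    import sys

    class HostnameFilter(logging.Filter):
        hostname = platform.node()

        def filter(self, record):
            # Attach the hostname to every record passing through; returning
            # True keeps the record instead of filtering it out.
            record.hostname = HostnameFilter.hostname
            return True

    logger = logging.getLogger("demo")
    handler = logging.StreamHandler(sys.stdout)
    handler.addFilter(HostnameFilter())
    handler.setFormatter(logging.Formatter('%(hostname)s: %(message)s'))
    logger.addHandler(handler)
    logger.warning("hello")  # prints e.g. "somehost: hello"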
From cc3aa9e86b84369f9703e15b423996a52bafdbd1 Mon Sep 17 00:00:00 2001
From: Tim Schaefer
Date: Fri, 20 Jan 2023 16:57:29 +0100
Subject: [PATCH 122/135] NEW: add new IMPORTANT log level, set as default for loggers.

---
 syncopy/shared/log.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/syncopy/shared/log.py b/syncopy/shared/log.py
index dc66cedc2..043914d6a 100644
--- a/syncopy/shared/log.py
+++ b/syncopy/shared/log.py
@@ -56,7 +56,8 @@ def filter(self, record):
         return True
 
     # The logger for local/sequential stuff -- goes to terminal and to a file.
-    spy_logger = logging.getLogger('syncopy')
+    logger_name = 'syncopy'
+    spy_logger = logging.getLogger(logger_name)
     fmt_interactive = logging.Formatter('%(asctime)s - %(levelname)s: %(message)s') # Interactive formatter: no hostname and session info (less clutter on terminal).
     fmt_with_hostname = logging.Formatter('%(asctime)s - %(levelname)s - %(hostname)s - %(session)s: %(message)s') # Log file formatter: with hostname and session info.
     sh = logging.StreamHandler(sys.stdout)
@@ -71,12 +72,12 @@ def filter(self, record):
     spy_logger.addHandler(fh)
 
     spy_logger.setLevel(loglevel)
-    spy_logger.debug(f"Starting Syncopy session at {datetime.datetime.now().astimezone().isoformat()}.")
-    spy_logger.info(f"Syncopy log level set to: {loglevel}.")
+    spy_logger.info(f"Starting Syncopy session at {datetime.datetime.now().astimezone().isoformat()}.")
+    spy_logger.important(f"Syncopy logger '{logger_name}' setup to log to file '{logfile}' at level {loglevel}.")
 
     # Log to per-host files in parallel code by default.
     # Note that this setup handles only the logger of the current host.
-    parloglevel = os.getenv("SPYPARLOGLEVEL", "INFO")
+    parloglevel = os.getenv("SPYPARLOGLEVEL", "IMPORTANT")
     numeric_level = getattr(logging, parloglevel.upper(), None)
     if not isinstance(numeric_level, int): # An invalid string was set as the env variable, use default.
         warnings.warn("Invalid log level set in environment variable 'SPYPARLOGLEVEL', ignoring and using IMPORTANT instead. Hint: Set one of 'DEBUG', 'IMPORTANT', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'.")
         parloglevel = "IMPORTANT"
@@ -97,7 +98,7 @@ def filter(self, record):
     sh.setFormatter(fmt_interactive)
     spy_parallel_logger.addHandler(sh)
 
-    spy_parallel_logger.info(f"Syncopy parallel logger '{parallel_logger_name}' setup to log to file '{logfile_par}' at level {loglevel}.")
+    spy_parallel_logger.important(f"Syncopy parallel logger '{parallel_logger_name}' setup to log to file '{logfile_par}' at level {parloglevel}.")
 
From 137dffe91142657299a41647cb6de714990435dc Mon Sep 17 00:00:00 2001
From: Tim Schaefer
Date: Fri, 20 Jan 2023 17:14:18 +0100
Subject: [PATCH 123/135] NEW: add convenience function set_loglevel

---
 syncopy/shared/log.py | 21 ++++++++++++++++---
 1 file changed, 18 insertions(+), 3 deletions(-)

diff --git a/syncopy/shared/log.py b/syncopy/shared/log.py
index 043914d6a..75aef5f7a 100644
--- a/syncopy/shared/log.py
+++ b/syncopy/shared/log.py
@@ -56,8 +56,7 @@ def filter(self, record):
         return True
 
     # The logger for local/sequential stuff -- goes to terminal and to a file.
-    logger_name = 'syncopy'
-    spy_logger = logging.getLogger(logger_name)
+    spy_logger = logging.getLogger(loggername)
     fmt_interactive = logging.Formatter('%(asctime)s - %(levelname)s: %(message)s') # Interactive formatter: no hostname and session info (less clutter on terminal).
     fmt_with_hostname = logging.Formatter('%(asctime)s - %(levelname)s - %(hostname)s - %(session)s: %(message)s') # Log file formatter: with hostname and session info.
     sh = logging.StreamHandler(sys.stdout)
@@ -73,7 +72,7 @@ def filter(self, record):
 
     spy_logger.setLevel(loglevel)
     spy_logger.info(f"Starting Syncopy session at {datetime.datetime.now().astimezone().isoformat()}.")
-    spy_logger.important(f"Syncopy logger '{logger_name}' setup to log to file '{logfile}' at level {loglevel}.")
+    spy_logger.important(f"Syncopy logger '{loggername}' setup to log to file '{logfile}' at level {loglevel}.")
 
     # Log to per-host files in parallel code by default.
     # Note that this setup handles only the logger of the current host.
@@ -172,6 +171,22 @@ def get_parallel_logger():
     return logging.getLogger(loggername + "_" + host)
 
 
+def set_loglevel(level, parallel_level=None):
+    """
+    Set log level for the loggers.
+
+    Parameters
+    ----------
+    level: str, one of 'DEBUG', 'IMPORTANT', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'.
+    parallel_level: optional str (same as for 'level' above) or None. If None, the log level of the sequential logger is also used for the parallel logger.
+    """
+    if parallel_level is None:
+        parallel_level = level
+    get_logger().setLevel(level)
+    get_parallel_logger().setLevel(parallel_level)
+
+
+
 def delete_all_logfiles(silent=True):
     """Delete all '.log' files in the Syncopy logging directory.
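A minimal usage sketch for the convenience function added above -- nothing
beyond what this patch itself introduces (`set_loglevel`, `get_logger` and
`get_parallel_logger` in `syncopy.shared.log`) is assumed, and the log
messages are made up:

    # Sketch: adjusting Syncopy log levels at runtime via set_loglevel().
    from syncopy.shared.log import set_loglevel, get_logger, get_parallel_logger

    # One threshold for both the sequential and the per-host parallel logger ...
    set_loglevel("DEBUG")
    get_logger().debug("Now visible on the sequential logger.")

    # ... or separate thresholds for the two loggers.
    set_loglevel("IMPORTANT", parallel_level="DEBUG")
    get_parallel_logger().debug("Still visible on the per-host logger.")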
From a39bfd4ceb6bbc715468122ff2828beddca049ed Mon Sep 17 00:00:00 2001
From: Tim Schaefer
Date: Fri, 20 Jan 2023 17:44:11 +0100
Subject: [PATCH 124/135] FIX: pass session to log init func

---
 syncopy/__init__.py   | 7 ++++---
 syncopy/shared/log.py | 6 ++----
 2 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/syncopy/__init__.py b/syncopy/__init__.py
index e30a29684..4d2c2bb32 100644
--- a/syncopy/__init__.py
+++ b/syncopy/__init__.py
@@ -146,12 +146,13 @@ def startup_print_once(message):
 from .plotting import *
 from .preproc import *
 
+# Register session
+__session__ = datatype.util.SessionLogger()
+
 from .shared.log import setup_logging
-setup_logging(spydir=spydir) # Sets __logdir__.
+setup_logging(spydir=spydir, session=__session__) # Sets __logdir__.
 startup_print_once(f"Logging to log directory '{__logdir__}'.\nTemporary storage directory set to '{__storage__}'.\n")
 
-# Register session
-__session__ = datatype.util.SessionLogger()
 
 # Override default traceback (differentiate b/w Jupyter/iPython and regular Python)
 from .shared.errors import SPYExceptionHandler
diff --git a/syncopy/shared/log.py b/syncopy/shared/log.py
index 75aef5f7a..c180a6cfb 100644
--- a/syncopy/shared/log.py
+++ b/syncopy/shared/log.py
@@ -19,7 +19,7 @@
 loglevels = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']
 
 
-def setup_logging(spydir=None):
+def setup_logging(spydir=None, session=""):
     """Setup logging on module initialization (in the module root level '__init__.py' file). Should not be called elsewhere."""
 
     _addLoggingLevel('IMPORTANT', logging.INFO - 5) # Add a new custom log level named 'IMPORTANT' between DEBUG and INFO.
@@ -49,10 +49,8 @@ def filter(self, record):
         return True
 
     class SessionFilter(logging.Filter):
-        session = syncopy.__session__
-
         def filter(self, record):
-            record.session = self.session
+            record.session = session
             return True
 
From d4adf502b1121475fe37428e12fc6d2b978b9fdc Mon Sep 17 00:00:00 2001
From: Tim Schaefer
Date: Sat, 21 Jan 2023 12:42:41 +0100
Subject: [PATCH 125/135] FIX: change expected log level in test, do not err on custom log level already set

---
 syncopy/shared/log.py         | 3 +++
 syncopy/tests/test_logging.py | 4 ++--
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/syncopy/shared/log.py b/syncopy/shared/log.py
index c180a6cfb..e56f31dfc 100644
--- a/syncopy/shared/log.py
+++ b/syncopy/shared/log.py
@@ -127,6 +127,9 @@ def _addLoggingLevel(levelName, levelNum, methodName=None):
     if not methodName:
         methodName = levelName.lower()
 
+    if hasattr(logging, levelName) and hasattr(logging, methodName) and hasattr(logging.getLoggerClass(), methodName):
+        return # Setup already complete.
+
     if hasattr(logging, levelName):
         raise AttributeError('{} already defined in logging module'.format(levelName))
     if hasattr(logging, methodName):
         raise AttributeError('{} already defined in logging module'.format(methodName))
diff --git a/syncopy/tests/test_logging.py b/syncopy/tests/test_logging.py
index 3fb53a3f2..dff54d4de 100644
--- a/syncopy/tests/test_logging.py
+++ b/syncopy/tests/test_logging.py
@@ -17,10 +17,10 @@ def test_logfile_exists(self):
         logfile = os.path.join(spy.__logdir__, "syncopy.log")
         assert os.path.isfile(logfile)
 
-    def test_default_log_level_is_warning(self):
+    def test_default_log_level_is_important(self):
 
         # Ensure the log level is at default (that user did not change SPYLOGLEVEL on test system)
-        assert os.getenv("SPYLOGLEVEL", "WARNING") == "WARNING"
+        assert os.getenv("SPYLOGLEVEL", "IMPORTANT") == "IMPORTANT"
 
         logfile = os.path.join(spy.__logdir__, "syncopy.log")
         assert os.path.isfile(logfile)
 
From 3f14e67d96229ba34ee0ec424c13da12c7efdd6e Mon Sep 17 00:00:00 2001
From: Tim Schaefer
Date: Sat, 21 Jan 2023 13:15:56 +0100
Subject: [PATCH 126/135] FIX: adapt numeric value of new IMPORTANT level

---
 syncopy/shared/log.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/syncopy/shared/log.py b/syncopy/shared/log.py
index e56f31dfc..0c43e7168 100644
--- a/syncopy/shared/log.py
+++ b/syncopy/shared/log.py
@@ -22,7 +22,7 @@ def setup_logging(spydir=None, session=""):
     """Setup logging on module initialization (in the module root level '__init__.py' file). Should not be called elsewhere."""
 
-    _addLoggingLevel('IMPORTANT', logging.INFO - 5) # Add a new custom log level named 'IMPORTANT' between DEBUG and INFO.
+    _addLoggingLevel('IMPORTANT', logging.WARNING - 5) # Add a new custom log level named 'IMPORTANT' between INFO and WARNING (int value = 25).
 
From 8836bbe27a4da65bbdc273354207a33b47199e8f Mon Sep 17 00:00:00 2001
From: Tim Schaefer
Date: Sat, 21 Jan 2023 13:53:21 +0100
Subject: [PATCH 127/135] FIX: fix typo

---
 syncopy/statistics/summary_stats.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/syncopy/statistics/summary_stats.py b/syncopy/statistics/summary_stats.py
index 90d59077e..070daab8a 100644
--- a/syncopy/statistics/summary_stats.py
+++ b/syncopy/statistics/summary_stats.py
@@ -209,7 +209,7 @@ def itc(spec_data, **kwargs):
         raise SPYValueError(lgl, 'spec_data', act)
 
     logger = logging.getLogger("syncopy_" + platform.node())
-    logger.debug(f"Computing intertrial coherence on SpectralData instancewith shape {spec_data.data.shape}.")
+    logger.debug(f"Computing intertrial coherence on SpectralData instance with shape {spec_data.data.shape}.")
 
     # takes care of remaining checks
     res = _trial_statistics(spec_data, operation='itc')
 
From 4e7b13b924ae9050e3afb731e0827fd2f3117f94 Mon Sep 17 00:00:00 2001
From: Tim Schaefer
Date: Mon, 23 Jan 2023 09:56:35 +0100
Subject: [PATCH 128/135] NEW: add test for par log file

---
 syncopy/tests/test_logging.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/syncopy/tests/test_logging.py b/syncopy/tests/test_logging.py
index dff54d4de..f0622e0a3 100644
--- a/syncopy/tests/test_logging.py
+++ b/syncopy/tests/test_logging.py
@@ -4,6 +4,7 @@
 #
 
 import os
+import platform
 
 # Local imports
 import syncopy as spy
@@ -13,10 +14,14 @@
 
 class TestLogging:
 
-    def test_logfile_exists(self):
+    def test_seq_logfile_exists(self):
         logfile = os.path.join(spy.__logdir__, "syncopy.log")
         assert os.path.isfile(logfile)
 
+    def test_par_logfile_exists(self):
+        par_logfile = os.path.join(spy.__logdir__, f"syncopy_{platform.node()}.log")
+        assert os.path.isfile(par_logfile)
+
     def test_default_log_level_is_important(self):
 
         # Ensure the log level is at default (that user did not change SPYLOGLEVEL on test system)
 
From b65c43a908e3352e78a9417e0e0663736e9182a5 Mon Sep 17 00:00:00 2001
From: Tim Schaefer
Date: Mon, 23 Jan 2023 13:15:01 +0100
Subject: [PATCH 129/135] NEW: test parallel logger

---
 syncopy/tests/test_logging.py | 30 ++++++++++++++++++++++++++++--
 1 file changed, 28 insertions(+), 2 deletions(-)

diff --git a/syncopy/tests/test_logging.py b/syncopy/tests/test_logging.py
index f0622e0a3..62546a964 100644
--- a/syncopy/tests/test_logging.py
+++ b/syncopy/tests/test_logging.py
@@ -8,7 +8,7 @@
 
 # Local imports
 import syncopy as spy
-from syncopy.shared.log import get_logger
+from syncopy.shared.log import get_logger, get_parallel_logger
 from syncopy.shared.errors import SPYLog
 
@@ -23,7 +23,6 @@ def test_par_logfile_exists(self):
         assert os.path.isfile(par_logfile)
 
     def test_default_log_level_is_important(self):
-
         # Ensure the log level is at default (that user did not change SPYLOGLEVEL on test system)
         assert os.getenv("SPYLOGLEVEL", "IMPORTANT") == "IMPORTANT"
 
@@ -46,6 +45,33 @@ def test_default_log_level_is_important(self):
         num_lines_after_warning = sum(1 for line in open(logfile))
         assert num_lines_after_warning > num_lines_after_info_debug
 
+    def test_default_parallel_log_level_is_important(self):
+        # Ensure the log level is at default (that user did not change SPYLOGLEVEL on test system)
+        assert os.getenv("SPYLOGLEVEL", "IMPORTANT") == "IMPORTANT"
+        assert os.getenv("SPYPARLOGLEVEL", "IMPORTANT") == "IMPORTANT"
+
+        par_logfile = os.path.join(spy.__logdir__, f"syncopy_{platform.node()}.log")
+        assert os.path.isfile(par_logfile)
+        num_lines_initial = sum(1 for line in open(par_logfile)) # The log file gets appended to, so it will most likely *not* be empty.
+
+        # Log something with log levels INFO and DEBUG, which should not affect the logfile.
+        par_logger = get_parallel_logger()
+        par_logger.info("I am adding an INFO level log entry.")
+        par_logger.debug("I am adding a IMPORTANT level log entry.")
+
+        num_lines_after_info_debug = sum(1 for line in open(par_logfile))
+
+        assert num_lines_initial == num_lines_after_info_debug
+
+        # Now log something with log levels IMPORTANT and WARNING
+        par_logger.important("I am adding an IMPORTANT level log entry.")
+        par_logger.warning("This is the last warning.")
+
+        num_lines_after_warning = sum(1 for line in open(par_logfile))
+        assert num_lines_after_warning > num_lines_after_info_debug
+
+
+
 
From b3d82d64d67b7bb7f1d798aabe75ec4079ac3e35 Mon Sep 17 00:00:00 2001
From: Tim Schaefer
Date: Mon, 23 Jan 2023 13:16:20 +0100
Subject: [PATCH 130/135] FIX: minor, fix typo in log message in test

---
 syncopy/tests/test_logging.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/syncopy/tests/test_logging.py b/syncopy/tests/test_logging.py
index 62546a964..01f6340ad 100644
--- a/syncopy/tests/test_logging.py
+++ b/syncopy/tests/test_logging.py
@@ -57,7 +57,7 @@ def test_default_parallel_log_level_is_important(self):
 
         # Log something with log levels INFO and DEBUG, which should not affect the logfile.
par_logger = get_parallel_logger() par_logger.info("I am adding an INFO level log entry.") - par_logger.debug("I am adding a IMPORTANT level log entry.") + par_logger.debug("I am adding a DEBUG level log entry.") num_lines_after_info_debug = sum(1 for line in open(par_logfile)) From 059604c9e492a3aa9bb60801591685bee2b67fe4 Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Mon, 23 Jan 2023 15:05:48 +0100 Subject: [PATCH 131/135] CHG: remove SessionLogger --- syncopy/__init__.py | 21 +++++---- syncopy/datatype/util.py | 97 +++++++++++++--------------------------- 2 files changed, 43 insertions(+), 75 deletions(-) diff --git a/syncopy/__init__.py b/syncopy/__init__.py index 4d2c2bb32..b8f6305fc 100644 --- a/syncopy/__init__.py +++ b/syncopy/__init__.py @@ -7,14 +7,10 @@ import os import sys import subprocess -import datetime import getpass import socket import numpy as np from hashlib import blake2b, sha1 -import logging -import warnings -import platform from importlib.metadata import version, PackageNotFoundError import dask.distributed as dd @@ -123,7 +119,6 @@ def startup_print_once(message): # Establish ID and log-file for current session __sessionid__ = blake2b(digest_size=2, salt=os.urandom(blake2b.SALT_SIZE)).hexdigest() -__sessionfile__ = os.path.join(__storage__, "session_{}.id".format(__sessionid__)) # Set max. no. of lines for traceback info shown in prompt __tbcount__ = 5 @@ -146,13 +141,23 @@ def startup_print_once(message): from .plotting import * from .preproc import * -# Register session -__session__ = datatype.util.SessionLogger() +from .datatype.util import setup_storage +storage_tmpdir_size_gb, storage_tmpdir_numfiles = setup_storage() # Creates the storage dir if needed and computes size and number of files in there if any. from .shared.log import setup_logging -setup_logging(spydir=spydir, session=__session__) # Sets __logdir__. +__logdir__ = None # Gets set in setup_logging() call below. +setup_logging(spydir=spydir, session=__sessionid__) # Sets __logdir__. startup_print_once(f"Logging to log directory '{__logdir__}'.\nTemporary storage directory set to '{__storage__}'.\n") +if storage_tmpdir_size_gb > __storagelimit__: + msg = ( + "\nSyncopy WARNING: Temporary storage folder {tmpdir:s} " + + "contains {nfs:d} files taking up a total of {sze:4.2f} GB on disk. \n" + + "Consider running `spy.cleanup()` to free up disk space." + ) + msg_formatted = msg.format(tmpdir=__storage__, nfs=storage_tmpdir_numfiles, sze=storage_tmpdir_size_gb) + startup_print_once(msg_formatted) + # Override default traceback (differentiate b/w Jupyter/iPython and regular Python) from .shared.errors import SPYExceptionHandler diff --git a/syncopy/datatype/util.py b/syncopy/datatype/util.py index 4fb88f672..6411737e2 100644 --- a/syncopy/datatype/util.py +++ b/syncopy/datatype/util.py @@ -3,16 +3,13 @@ """ import os -import getpass -import socket -from datetime import datetime from numbers import Number # Syncopy imports from syncopy import __storage__, __storagelimit__, __sessionid__ from syncopy.shared.errors import SPYTypeError, SPYValueError -__all__ = ['TrialIndexer', 'SessionLogger'] +__all__ = ['TrialIndexer'] class TrialIndexer: @@ -61,74 +58,40 @@ def __str__(self): return "{} element iterable".format(self._len) -class SessionLogger: +def setup_storage(): + """ + Create temporary storage dir and report on its size. - __slots__ = ["sessionfile", "_rm"] + Returns + ------- + storage_size: Size of files in temporary storage directory, in GB. 
+ storage_num_files: Number of files in temporary storage directory. + """ - def __init__(self): - - # Create package-wide tmp directory if not already present - if not os.path.exists(__storage__): - try: - os.mkdir(__storage__) - except Exception as exc: - err = ( - "Syncopy core: cannot create temporary storage directory {}. " - + "Original error message below\n{}" - ) - raise IOError(err.format(__storage__, str(exc))) - - # Check for upper bound of temp directory size - with os.scandir(__storage__) as scan: - st_size = 0.0 - st_fles = 0 - for fle in scan: - try: - st_size += fle.stat().st_size / 1024 ** 3 - st_fles += 1 - # this catches a cleanup by another process - except FileNotFoundError: - continue - - if st_size > __storagelimit__: - msg = ( - "\nSyncopy WARNING: Temporary storage folder {tmpdir:s} " - + "contains {nfs:d} files taking up a total of {sze:4.2f} GB on disk. \n" - + "Consider running `spy.cleanup()` to free up disk space." - ) - print(msg.format(tmpdir=__storage__, nfs=st_fles, sze=st_size)) - - # If we made it to this point, (attempt to) write the session file - sess_log = "{user:s}@{host:s}: <{time:s}> started session {sess:s}" - self.sessionfile = os.path.join( - __storage__, "session_{}_log.id".format(__sessionid__) - ) + # Create package-wide tmp directory if not already present + if not os.path.exists(__storage__): try: - with open(self.sessionfile, "w") as fid: - fid.write( - sess_log.format( - user=getpass.getuser(), - host=socket.gethostname(), - time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - sess=__sessionid__, - ) - ) + os.mkdir(__storage__) except Exception as exc: - err = "Syncopy core: cannot access {}. Original error message below\n{}" - raise IOError(err.format(self.sessionfile, str(exc))) + err = ( + "Syncopy core: cannot create temporary storage directory {}. " + + "Original error message below\n{}" + ) + raise IOError(err.format(__storage__, str(exc))) + + # Check for upper bound of temp directory size + with os.scandir(__storage__) as scan: + storage_size = 0.0 + storage_num_files = 0 + for fle in scan: + try: + storage_size += fle.stat().st_size / 1024 ** 3 + storage_num_files += 1 + # this catches a cleanup by another process + except FileNotFoundError: + continue - # Workaround to prevent Python from garbage-collecting ``os.unlink`` - self._rm = os.unlink + return storage_size, storage_num_files - def __repr__(self): - return self.__str__() - def __str__(self): - return "Session {}".format(__sessionid__) - def __del__(self): - try: - self._rm(self.sessionfile) - except FileNotFoundError: - pass - From bab250cd8c038df53fe964c93be528a0eafb4231 Mon Sep 17 00:00:00 2001 From: Tim Schaefer Date: Mon, 23 Jan 2023 15:26:33 +0100 Subject: [PATCH 132/135] CHG: Update new default log level in docs. --- doc/source/developer/logging.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/source/developer/logging.rst b/doc/source/developer/logging.rst index a8983b1dc..8f056df54 100644 --- a/doc/source/developer/logging.rst +++ b/doc/source/developer/logging.rst @@ -11,7 +11,7 @@ is run by the remote workers in a high performance computing (HPC) cluster envir Log levels ----------- -The default log level is for the Syncopy logger is `'logging.WARNING'` (from now on referred to as `'WARNING'`). This means that you will not see any Syncopy messages below that threshold, i.e., messages printed with log levels `'DEBUG'` and `'INFO'`. 
To change the log level, you can either use the logging API in your application code as explained below, or set the environment variable `'SPYLOGLEVEL'` to one of the values supported by the logging module, e.g., 'CRITICAL', 'WARNING', 'INFO', or 'DEBUG'. See the `official docs of the logging module `_ for details on the supported log levels.
+The default log level for the Syncopy logger is `'logging.IMPORTANT'` (from now on referred to as `'IMPORTANT'`). This means that you will not see any Syncopy messages below that threshold, i.e., messages printed with log levels `'DEBUG'` and `'INFO'`. To change the log level, you can either use the logging API in your application code as explained below, or set the environment variable `'SPYLOGLEVEL'` to one of the values supported by the logging module, e.g., 'CRITICAL', 'WARNING', 'INFO', or 'DEBUG'. See the `official docs of the logging module `_ for details on the supported log levels. Note that IMPORTANT is a custom log level with importance 25, i.e., between INFO and WARNING.
 
 
 Log file location
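A short sketch of the environment-variable route the documentation above
describes -- the variable names are the ones handled in `setup_logging`, and
the chosen levels are just examples:

    # Sketch: pick the Syncopy log levels before the first import,
    # since setup_logging() runs at import time.
    import os
    os.environ["SPYLOGLEVEL"] = "DEBUG"        # sequential logger
    os.environ["SPYPARLOGLEVEL"] = "WARNING"   # per-host parallel loggers

    import syncopy as spy  # logging is configured during this import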
From e9701c8d0a0d8d96e10bb883115d733ba83fedeb Mon Sep 17 00:00:00 2001
From: Tim Schaefer
Date: Mon, 23 Jan 2023 15:45:02 +0100
Subject: [PATCH 133/135] CHG: set log level of logfile locations on startup to DEBUG

---
 syncopy/shared/log.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/syncopy/shared/log.py b/syncopy/shared/log.py
index 0c43e7168..d979a6fbb 100644
--- a/syncopy/shared/log.py
+++ b/syncopy/shared/log.py
@@ -70,7 +70,7 @@ def filter(self, record):
 
     spy_logger.setLevel(loglevel)
     spy_logger.info(f"Starting Syncopy session at {datetime.datetime.now().astimezone().isoformat()}.")
-    spy_logger.important(f"Syncopy logger '{loggername}' setup to log to file '{logfile}' at level {loglevel}.")
+    spy_logger.debug(f"Syncopy logger '{loggername}' setup to log to file '{logfile}' at level {loglevel}.")
 
     # Log to per-host files in parallel code by default.
     # Note that this setup handles only the logger of the current host.
@@ -95,7 +95,7 @@ def filter(self, record):
     sh.setFormatter(fmt_interactive)
     spy_parallel_logger.addHandler(sh)
 
-    spy_parallel_logger.important(f"Syncopy parallel logger '{parallel_logger_name}' setup to log to file '{logfile_par}' at level {parloglevel}.")
+    spy_parallel_logger.debug(f"Syncopy parallel logger '{parallel_logger_name}' setup to log to file '{logfile_par}' at level {parloglevel}.")
 
From 822edd0b9f8b730adc44d4370fb15604611a41cc Mon Sep 17 00:00:00 2001
From: tensionhead
Date: Mon, 23 Jan 2023 16:04:06 +0100
Subject: [PATCH 134/135] CHG: Elevate parallel processing log level

Changes to be committed:
	modified:   syncopy/shared/kwarg_decorators.py
---
 syncopy/shared/kwarg_decorators.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/syncopy/shared/kwarg_decorators.py b/syncopy/shared/kwarg_decorators.py
index 89446a315..c07fd6ef5 100644
--- a/syncopy/shared/kwarg_decorators.py
+++ b/syncopy/shared/kwarg_decorators.py
@@ -11,13 +11,15 @@
 import dask.distributed as dd
 
-# Local imports
+import syncopy as spy
 from syncopy.shared.errors import (SPYTypeError, SPYValueError, SPYError,
                                    SPYWarning, SPYInfo)
 from syncopy.shared.tools import StructDict
 from syncopy.shared.metadata import h5_add_metadata, parse_cF_returns
+
+# Local imports
 from .dask_helpers import check_slurm_available, check_workers_available
-import syncopy as spy
+from .log import get_logger
 
 __all__ = []
 
@@ -463,6 +465,8 @@ def detect_parallel_client(func):
 
     @functools.wraps(func)
     def parallel_client_detector(*args, **kwargs):
+        logger = get_logger()
+
         # Extract `parallel` keyword: if `parallel` is `False`, nothing happens
         parallel = kwargs.get("parallel")
         kill_spawn = False
@@ -478,8 +482,8 @@ def parallel_client_detector(*args, **kwargs):
             try:
                 client = dd.get_client()
                 check_workers_available(client.cluster)
-                msg = f"..attaching to running Dask client:\n{client}"
-                SPYInfo(msg)
+                msg = f"..attaching to running Dask client:\n\t{client}"
+                logger.important(msg)
                 parallel = True
             except ValueError:
                 parallel = False
@@ -493,7 +497,7 @@ def parallel_client_detector(*args, **kwargs):
                 client = dd.get_client()
                 check_workers_available(client.cluster)
                 msg = f"..attaching to running Dask client:\n{client}"
-                SPYInfo(msg)
+                logger.important(msg)
             except ValueError:
                 # we are on a HPC but ACME and Dask client are missing,
                 # LocalCluster still gets created
@@ -515,8 +519,8 @@ def parallel_client_detector(*args, **kwargs):
                 dd.Client(cluster)
                 kill_spawn = True
                 msg = ("No running Dask cluster found, created a local instance:\n"
-                       f"\t {cluster.scheduler}")
-                SPYInfo(msg)
+                       f"\t{cluster.scheduler}")
+                logger.important(msg)
 
         # Add/update `parallel` to/in keyword args
         kwargs["parallel"] = parallel
 
From 16afe42dd4766e18453584d3ee81d166adb57d68 Mon Sep 17 00:00:00 2001
From: tensionhead
Date: Mon, 23 Jan 2023 16:16:33 +0100
Subject: [PATCH 135/135] CHG: Shorten time format, remove milliseconds

Changes to be committed:
	modified:   syncopy/shared/log.py
---
 syncopy/shared/log.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/syncopy/shared/log.py b/syncopy/shared/log.py
index d979a6fbb..eada5d364 100644
--- a/syncopy/shared/log.py
+++ b/syncopy/shared/log.py
@@ -55,8 +55,16 @@ def filter(self, record):
 
     # The logger for local/sequential stuff -- goes to terminal and to a file.
spy_logger = logging.getLogger(loggername) - fmt_interactive = logging.Formatter('%(asctime)s - %(levelname)s: %(message)s') # Interactive formatter: no hostname and session info (less clutter on terminal). - fmt_with_hostname = logging.Formatter('%(asctime)s - %(levelname)s - %(hostname)s - %(session)s: %(message)s') # Log file formatter: with hostname and session info. + + datefmt_interactive = '%H:%M:%S' + datefmt_file = "%Y-%m-%d %H:%M:%S" + + # Interactive formatter: no hostname and session info (less clutter on terminal). + fmt_interactive = logging.Formatter('%(asctime)s - %(levelname)s: %(message)s', datefmt_interactive) + # Log file formatter: with hostname and session info. + fmt_with_hostname = logging.Formatter('%(asctime)s - %(levelname)s - %(hostname)s - %(session)s: %(message)s', + datefmt_file) + sh = logging.StreamHandler(sys.stdout) sh.setFormatter(fmt_interactive) spy_logger.addHandler(sh) @@ -206,5 +214,3 @@ def delete_all_logfiles(silent=True): warnings.warn(f"Could not delete log file '{logfile}': {str(ex)}") if not silent: print(f"Deleted {num_deleted} log files from directory '{logdir}'.") - -
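To illustrate what the shortened date formats of patch 135 change in practice,
here is a small self-contained sketch of the two record formats defined in
`setup_logging` -- the hostname and session id below are made-up placeholders
that the filters would normally inject:

    # Sketch: interactive vs. log file formats, with the millisecond-free
    # date formats introduced above.
    import logging

    fmt_interactive = logging.Formatter('%(asctime)s - %(levelname)s: %(message)s', '%H:%M:%S')
    fmt_file = logging.Formatter('%(asctime)s - %(levelname)s - %(hostname)s - %(session)s: %(message)s',
                                 "%Y-%m-%d %H:%M:%S")

    rec = logging.LogRecord("syncopy", logging.WARNING, __file__, 1, "demo message", None, None)
    rec.hostname, rec.session = "examplehost", "a1b2"  # normally injected by the filters

    print(fmt_interactive.format(rec))  # e.g. 16:16:33 - WARNING: demo message
    print(fmt_file.format(rec))         # e.g. 2023-01-23 16:16:33 - WARNING - examplehost - a1b2: demo message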