Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NEW: TrialIndexer #422

Merged
merged 3 commits into from
Jan 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion syncopy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@
from .preproc import *

# Register session
__session__ = datatype.base_data.SessionLogger()
__session__ = datatype.util.SessionLogger()

# Override default traceback (differentiate b/w Jupyter/iPython and regular Python)
from .shared.errors import SPYExceptionHandler
Expand Down
2 changes: 2 additions & 0 deletions syncopy/datatype/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .methods.selectdata import *
from .methods.show import *
from .methods.copy import *
from .util import *

# Populate local __all__ namespace
__all__ = []
Expand All @@ -25,3 +26,4 @@
__all__.extend(methods.selectdata.__all__)
__all__.extend(methods.show.__all__)
__all__.extend(methods.copy.__all__)
__all__.extend(util.__all__)
175 changes: 21 additions & 154 deletions syncopy/datatype/base_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,16 @@
import sys
import os
from abc import ABC, abstractmethod
from datetime import datetime
from hashlib import blake2b
from itertools import islice
from functools import reduce
from inspect import signature
import shutil
import numpy as np
import h5py
import scipy as sp

# Local imports
import syncopy as spy
from .util import TrialIndexer
from .methods.arithmetic import _process_operator
from .methods.selectdata import selectdata
from .methods.show import show
Expand Down Expand Up @@ -744,8 +742,13 @@ def _t0(self):
def trials(self):
"""list-like array of trials"""

return Indexer(map(self._get_trial, range(self.sampleinfo.shape[0])),
self.sampleinfo.shape[0]) if self.sampleinfo is not None else None
if self.sampleinfo is not None:
trial_ids = list(range(self.sampleinfo.shape[0]))
# this is cheap as it just initializes a list-like object
# with no real data and/or computation!
return TrialIndexer(self, trial_ids)
else:
return None

@property
def trialinfo(self):
Expand Down Expand Up @@ -1161,147 +1164,6 @@ def __init__(self, filename=None, dimord=None, mode="r+", **kwargs):
self._version = __version__


class Indexer:
    """
    Thin wrapper that makes an iterable subscriptable.

    Parameters
    ----------
    iterobj : iterable
        Object yielding the elements to be indexed (e.g., a ``map`` object).
    iterlen : int
        Number of elements `iterobj` is expected to yield.

    Notes
    -----
    Elements are fetched with :func:`itertools.islice`. If `iterobj` is a
    one-shot iterator, every lookup consumes it: positions of a subsequent
    lookup on the same instance are counted from wherever the previous
    lookup stopped. Re-iterable inputs (lists, tuples) are not affected.
    """

    __slots__ = ["_iterobj", "_iterlen"]

    def __init__(self, iterobj, iterlen):
        """
        Make an iterable object subscriptable using itertools magic
        """
        self._iterobj = iterobj
        self._iterlen = iterlen

    def __iter__(self):
        # `iter` hands back iterators unchanged and additionally supports
        # plain sequences, for which returning the object itself is illegal
        return iter(self._iterobj)

    def __getitem__(self, idx):
        """
        Fetch one or several elements.

        Parameters
        ----------
        idx : int_like scalar, slice, list or :class:`numpy.ndarray`
            Scalars return the element itself; slices, lists and arrays
            return the selected elements combined via :func:`numpy.hstack`.

        Raises
        ------
        SPYValueError
            If a slice exceeds the bounds ``[0, len(self)]``
        SPYTypeError
            If `idx` is of an unsupported type
        IndexError
            If a requested list/array position could not be fetched
        """
        if np.issubdtype(type(idx), np.number):
            scalar_parser(
                idx, varname="idx", ntype="int_like", lims=[0, self._iterlen - 1]
            )
            # int-like floats pass the parser but are rejected by `islice`
            pos = int(idx)
            return next(islice(self._iterobj, pos, pos + 1))

        if isinstance(idx, slice):
            start = 0 if idx.start is None else idx.start
            stop = self._iterlen if idx.stop is None else idx.stop
            index = slice(start, stop, idx.step)
            if not (0 <= index.start < self._iterlen) or not (
                0 < index.stop <= self._iterlen
            ):
                err = "value between {lb:s} and {ub:s}"
                raise SPYValueError(
                    err.format(lb="0", ub=str(self._iterlen)),
                    varname="idx",
                    actual=str(index),
                )
            # `np.hstack` requires a sequence; a bare iterator is rejected
            return np.hstack(
                list(islice(self._iterobj, index.start, index.stop, index.step))
            )

        if isinstance(idx, (list, np.ndarray)):
            # NOTE(review): upper limit aligned with the scalar branch, which
            # passes `iterlen - 1` - assumes `lims` is inclusive, confirm
            array_parser(
                idx,
                varname="idx",
                ntype="int_like",
                hasnan=False,
                hasinf=False,
                lims=[0, self._iterlen - 1],
                dims=1,
            )
            positions = [int(ix) for ix in idx]
            wanted = set(positions)
            picked = {}
            if positions:
                # Single pass: repeated `islice` calls on a one-shot iterator
                # would count every position from the previous stop
                stream = islice(self._iterobj, max(positions) + 1)
                for pos, item in enumerate(stream):
                    if pos in wanted:
                        picked[pos] = item
            missing = wanted.difference(picked)
            if missing:
                raise IndexError(
                    "cannot fetch element(s) {} from iterable".format(sorted(missing))
                )
            return np.hstack([picked[pos] for pos in positions])

        raise SPYTypeError(idx, varname="idx", expected="int_like or slice")

    def __len__(self):
        return self._iterlen

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        return "{} element iterable".format(self._iterlen)


class SessionLogger:
    """
    Register the running session on disk.

    On construction the package-wide temporary storage directory
    `__storage__` is created if necessary, a warning is printed if the
    files directly inside it exceed `__storagelimit__` (in GB) and a session
    log file named after `__sessionid__` is written into it. The log file is
    removed again when the object is finalized.

    Raises
    ------
    IOError
        If the storage directory cannot be created or the session file
        cannot be written
    """

    __slots__ = ["sessionfile", "_rm"]

    def __init__(self):

        # Create package-wide tmp directory if not already present
        if not os.path.exists(__storage__):
            try:
                os.mkdir(__storage__)
            except FileExistsError:
                # another process created the directory in the meantime
                pass
            except Exception as exc:
                err = (
                    "Syncopy core: cannot create temporary storage directory {}. "
                    + "Original error message below\n{}"
                )
                raise IOError(err.format(__storage__, str(exc))) from exc

        # Check for upper bound of temp directory size
        with os.scandir(__storage__) as scan:
            st_size = 0.0
            st_fles = 0
            for fle in scan:
                try:
                    st_size += fle.stat().st_size / 1024 ** 3
                    st_fles += 1
                # this catches a cleanup by another process
                except FileNotFoundError:
                    continue

        if st_size > __storagelimit__:
            msg = (
                "\nSyncopy <core> WARNING: Temporary storage folder {tmpdir:s} "
                + "contains {nfs:d} files taking up a total of {sze:4.2f} GB on disk. \n"
                + "Consider running `spy.cleanup()` to free up disk space."
            )
            print(msg.format(tmpdir=__storage__, nfs=st_fles, sze=st_size))

        # If we made it to this point, (attempt to) write the session file
        sess_log = "{user:s}@{host:s}: <{time:s}> started session {sess:s}"
        self.sessionfile = os.path.join(
            __storage__, "session_{}_log.id".format(__sessionid__)
        )
        try:
            with open(self.sessionfile, "w") as fid:
                fid.write(
                    sess_log.format(
                        user=getpass.getuser(),
                        host=socket.gethostname(),
                        time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                        sess=__sessionid__,
                    )
                )
        except Exception as exc:
            err = "Syncopy core: cannot access {}. Original error message below\n{}"
            raise IOError(err.format(self.sessionfile, str(exc))) from exc

        # Workaround to prevent Python from garbage-collecting ``os.unlink``
        self._rm = os.unlink

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        return "Session {}".format(__sessionid__)

    def __del__(self):
        # `__del__` also runs if `__init__` raised before the slots were
        # assigned; in that case there is nothing to clean up
        session_file = getattr(self, "sessionfile", None)
        remove = getattr(self, "_rm", None)
        if session_file is None or remove is None:
            return
        try:
            remove(session_file)
        except FileNotFoundError:
            pass


class FauxTrial:
"""
Stand-in mockup of NumPy arrays representing trial data
Expand Down Expand Up @@ -1617,23 +1479,28 @@ def trial_ids(self, dataselect):
raise SPYValueError(legal=lgl, varname=vname, actual=act)
else:
trials = trlList
self._trial_ids = list(trials) # ensure `trials` is a list cf. #180
self._trial_ids = list(trials) # ensure `trials` is a list cf. #180

@property
def trials(self):
"""
Returns an Indexer indexing single trial arrays respecting the selection
Indices are RELATIVE with respect to existing trial selections:
Returns an iterable indexing single trial arrays respecting the selection
Indices are ABSOLUTE with respect to existing trial selections:

>>> selection.trials[2]
>>> selection.trials[11]

indexes the 3rd trial of `selection.trial_ids`
indexes the 11th trial of the original dataset, if and only if
trial number 11 is part of the selection.

Selections must be "simple": ordered and without repetitions
"""

return Indexer(map(self._get_trial, self.trial_ids),
len(self.trial_ids)) if self.trial_ids is not None else None
if self.sampleinfo is not None:
# this is cheap as it just initializes a list-like object
# with no real data and/or computations!
return TrialIndexer(self, self.trial_ids)
else:
return None

def create_get_trial(self, data):
""" Closure to allow emulation of BaseData._get_trial"""
Expand Down Expand Up @@ -2338,7 +2205,7 @@ def __str__(self):
attr,
"s" if not attr.endswith("s") else "",
)
elif isinstance(val, (list, Indexer)):
elif isinstance(val, (list, TrialIndexer)):
ppdict[attr] = "{0:d} {1:s}{2:s}, ".format(
len(val), attr, "s" if not attr.endswith("s") else ""
)
Expand Down
12 changes: 1 addition & 11 deletions syncopy/datatype/discrete_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


# Local imports
from .base_data import BaseData, Indexer, FauxTrial
from .base_data import BaseData, FauxTrial
from .methods.definetrial import definetrial
from syncopy.shared.parsers import scalar_parser, array_parser
from syncopy.shared.errors import SPYValueError
Expand Down Expand Up @@ -164,16 +164,6 @@ def trialid(self, trlid):
raise exc
self._trialid = np.array(trlid, dtype=int)

@property
def trials(self):
    """list-like([sample x (>=2)] :class:`numpy.ndarray`) : trial slices of :attr:`data` property

    Returns `None` if no trial ids have been assigned. Otherwise the unique
    non-negative entries of :attr:`trialid` are mapped through
    `self._get_trial` and wrapped in an `Indexer`.
    """
    # Guard clause: without trial ids there is nothing to index
    if self.trialid is None:
        return None

    # Negative ids are excluded from the set of trials
    nonneg_ids = self.trialid[self.trialid >= 0]
    unique_ids = np.unique(nonneg_ids)
    trial_iter = map(self._get_trial, unique_ids)
    return Indexer(trial_iter, unique_ids.size)

@property
def trialtime(self):
"""list(:class:`numpy.ndarray`): trigger-relative sample times in s"""
Expand Down
Loading