Skip to content

Commit

Permalink
ENH: added keyword to Dataset.from_file and propagated so saved datas…
Browse files Browse the repository at this point in the history
…ets with old subject names can be loaded with new subject name
  • Loading branch information
marklescroart authored and TomDLT committed Sep 16, 2022
1 parent b134110 commit 9f0c828
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 21 deletions.
9 changes: 6 additions & 3 deletions cortex/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __dir__(self):
return list(self.__dict__.keys()) + list(self.views.keys())

@classmethod
def from_file(cls, filename):
def from_file(cls, filename, subject=None):
ds = cls()
ds.h5 = h5py.File(filename, 'r')

Expand All @@ -76,14 +76,17 @@ def from_file(cls, filename):
if name in ("data", "subjects", "views"):
continue
try:
ds.views[name] = _from_hdf_data(ds.h5, name)
ds.views[name] = _from_hdf_data(ds.h5, name, subject=subject)
except KeyError:
print('No metadata found for "%s", skipping...'%name)

#load up the views generated by pycortex
for name, node in ds.h5['views'].items():
try:
ds.views[name] = Dataview.from_hdf(node)
ds.views[name] = Dataview.from_hdf(node, subject=subject)
except FileNotFoundError:
print("Could not load file; old subject name? Try using `subject` kwarg to specify a current pycortex subject")
raise
except Exception:
import traceback
traceback.print_exc()
Expand Down
38 changes: 20 additions & 18 deletions cortex/dataset/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,24 @@ def normalize(data):
else:
raise TypeError("Invalid input for Dataview")

def _from_hdf_data(h5, name, xfmname=None, **kwargs):
def _from_hdf_data(h5, name, xfmname=None, subject=None, **kwargs):
"""Decodes a __hash named node from an HDF file into the
constituent Vertex or Volume object"""
dnode = h5.get("/data/%s"%name)
if dnode is None:
dnode = h5.get(name)

attrs = {k: u(v) for (k, v) in dnode.attrs.items()}
subj = attrs['subject']
if subject is None:
subject = attrs['subject']
#support old style xfmname saving as attribute
if xfmname is None and 'xfmname' in attrs:
xfmname = attrs['xfmname']
mask = None
if 'mask' in attrs:
if attrs['mask'].startswith("__"):
mask = h5['/subjects/%s/transforms/%s/masks/%s'%(attrs['subject'], xfmname, attrs['mask'])].value
mask = h5['/subjects/%s/transforms/%s/masks/%s' %
(attrs['subject'], xfmname, attrs['mask'])].value
else:
mask = attrs['mask']

Expand All @@ -52,37 +54,37 @@ def _from_hdf_data(h5, name, xfmname=None, **kwargs):
alpha = dnode[..., 3]

if xfmname is None:
return VertexRGB(dnode[...,0], dnode[...,1], dnode[...,2], subj,
return VertexRGB(dnode[...,0], dnode[...,1], dnode[...,2], subject,
alpha=alpha, **kwargs)

return VolumeRGB(dnode[...,0], dnode[...,1], dnode[...,2], subj, xfmname,
return VolumeRGB(dnode[...,0], dnode[...,1], dnode[...,2], subject, xfmname,
alpha=alpha, mask=mask, **kwargs)

if xfmname is None:
return Vertex(dnode, subj, **kwargs)
return Vertex(dnode, subject, **kwargs)

return Volume(dnode, subj, xfmname, mask=mask, **kwargs)
return Volume(dnode, subject, xfmname, mask=mask, **kwargs)


def _from_hdf_view(h5, data, xfmname=None, vmin=None, vmax=None, **kwargs):
def _from_hdf_view(h5, data, xfmname=None, vmin=None, vmax=None, subject=None, **kwargs):

if isinstance(data, string_types):
return _from_hdf_data(h5, data, xfmname=xfmname, vmin=vmin, vmax=vmax, **kwargs)
return _from_hdf_data(h5, data, xfmname=xfmname, vmin=vmin, vmax=vmax, subject=subject, **kwargs)

if len(data) == 2:
dim1 = _from_hdf_data(h5, data[0], xfmname=xfmname[0])
dim2 = _from_hdf_data(h5, data[1], xfmname=xfmname[1])
dim1 = _from_hdf_data(h5, data[0], xfmname=xfmname[0], subject=subject)
dim2 = _from_hdf_data(h5, data[1], xfmname=xfmname[1], subject=subject)
cls = Vertex2D if isinstance(dim1, Vertex) else Volume2D
return cls(dim1, dim2, vmin=vmin[0], vmin2=vmin[1],
vmax=vmax[0], vmax2=vmax[1], **kwargs)
vmax=vmax[0], vmax2=vmax[1], subject=subject, **kwargs)
elif len(data) == 4:
red, green, blue = [_from_hdf_data(h5, d, xfmname=xfmname) for d in data[:3]]
red, green, blue = [_from_hdf_data(h5, d, xfmname=xfmname, subject=subject) for d in data[:3]]
alpha = None
if data[3] is not None:
alpha = _from_hdf_data(h5, data[3], xfmname=xfmname)
alpha = _from_hdf_data(h5, data[3], xfmname=xfmname, subject=subject)

cls = VertexRGB if isinstance(red, Vertex) else VolumeRGB
return cls(red, green, blue, alpha=alpha, **kwargs)
return cls(red, green, blue, alpha=alpha, subject=subject, **kwargs)
else:
raise ValueError("Invalid Dataview specification")

Expand Down Expand Up @@ -140,7 +142,7 @@ def to_json(self, simple=False):
return sdict

@staticmethod
def from_hdf(node):
def from_hdf(node, subject=None):
data = json.loads(u(node[0]))
desc = node[1]
try:
Expand All @@ -166,9 +168,9 @@ def from_hdf(node):
if len(data) == 1:
xfm = None if xfmname is None else xfmname[0]
return _from_hdf_view(node.file, data[0], xfmname=xfm, cmap=cmap[0], description=desc,
vmin=vmin[0], vmax=vmax[0], state=state, **attrs)
vmin=vmin[0], vmax=vmax[0], state=state, subject=subject, **attrs)
else:
views = [_from_hdf_view(node.file, d, xfmname=x) for d, x in zip(data, xfname)]
views = [_from_hdf_view(node.file, d, xfmname=x, subject=subject) for d, x in zip(data, xfmname)]
raise NotImplementedError

def _write_hdf(self, h5, name="data", data=None, xfmname=None):
Expand Down

0 comments on commit 9f0c828

Please sign in to comment.