Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Properly Convert References in Compound Datasets #529

Merged
merged 8 commits into from
Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 47 additions & 56 deletions +io/parseCompound.m
Original file line number Diff line number Diff line change
@@ -1,65 +1,56 @@
function data = parseCompound(did, data)
%did is the dataset_id for the containing dataset
%data should be a scalar struct with fields as columns
if isempty(data)
return;
end
tid = H5D.get_type(did);
ncol = H5T.get_nmembers(tid);
subtids = cell(1, ncol);
ref_i = false(1, ncol);
char_i = false(1, ncol);
bool_i = false(1,ncol);
for i = 1:ncol
subtid = H5T.get_member_type(tid, i-1);
subtids{i} = subtid;
switch H5T.get_member_class(tid, i-1)
case H5ML.get_constant_value('H5T_REFERENCE')
ref_i(i) = true;
case H5ML.get_constant_value('H5T_STRING')
%if not variable len (which would make it a cell array)
%then mark for transpose
char_i(i) = ~H5T.is_variable_str(subtid);
case H5ML.get_constant_value('H5T_ENUM')
bool_i(i) = io.isBool(subtid);
otherwise
%do nothing
function data = parseCompound(datasetId, data)
%datasetId is the dataset_id for the containing dataset
%data should be a scalar struct with fields as columns
if isempty(data)
return;
end
end

fields = fieldnames(data);
if any(ref_i)
%resolve references by column
reftids = subtids(ref_i);
refFields = fields(ref_i);
for j=1:length(refFields)
rpname = refFields{j};
refdata = data.(rpname);
reflist = cell(size(refdata, 2), 1);
for k=1:size(refdata, 2)
r = refdata(:,k);
reflist{k} = io.parseReference(did, reftids{j}, r);
typeId = H5D.get_type(datasetId);
numFields = H5T.get_nmembers(typeId);
subTypeId = cell(1, numFields);
isReferenceType = false(1, numFields);
isCharacterType = false(1, numFields);
isLogicalType = false(1,numFields);
for iField = 1:numFields
fieldTypeId = H5T.get_member_type(typeId, iField-1);
subTypeId{iField} = fieldTypeId;
switch H5T.get_member_class(typeId, iField-1)
case H5ML.get_constant_value('H5T_REFERENCE')
isReferenceType(iField) = true;
case H5ML.get_constant_value('H5T_STRING')
%if not variable len (which would make it a cell array)
%then mark for transpose
isCharacterType(iField) = ~H5T.is_variable_str(fieldTypeId);
case H5ML.get_constant_value('H5T_ENUM')
isLogicalType(iField) = io.isBool(fieldTypeId);
otherwise
%do nothing
end
data.(rpname) = [reflist{:}] .';
end
end

if any(char_i)
%transpose character arrays because they are column-ordered
%when read
charFields = fields(char_i);
for j=1:length(charFields)
cpname = charFields{j};
data.(cpname) = data.(cpname) .';
fieldName = fieldnames(data);

% resolve references by column
referenceTypeId = subTypeId(isReferenceType);
referenceFieldName = fieldName(isReferenceType);
for iFieldName = 1:length(referenceFieldName)
name = referenceFieldName{iFieldName};
rawReference = data.(name);
rawTypeId = referenceTypeId{iFieldName};
data.(name) = io.parseReference(datasetId, rawTypeId, rawReference);
end
end

if any(bool_i)
% convert column data to proper logical arrays/matrices
for f=fields{bool_i}
data.(f) = strcmp('TRUE', data.(f));
% transpose character arrays because they are column-ordered
% when read
characterFieldName = fieldName(isCharacterType);
for iFieldName = 1:length(characterFieldName)
name = characterFieldName{iFieldName};
data.(name) = data.(name) .';
end
end

data = struct2table(data);
% convert column data to proper logical arrays/matrices
logicalFieldName = fieldName(isLogicalType);
for iFieldName = 1:length(logicalFieldName)
name = logicalFieldName{iFieldName};
data.(name) = strcmp('TRUE', data.(name));
end
end
106 changes: 54 additions & 52 deletions +io/parseReference.m
Original file line number Diff line number Diff line change
@@ -1,61 +1,63 @@
function refobj = parseReference(did, tid, data)
szref = size(data);
%first dimension is always the raw buffer size
szref = szref(2:end);
if isscalar(szref)
szref = [szref 1];
end
numref = prod(szref);
if H5T.equal(tid, 'H5T_STD_REF_OBJ')
reftype = H5ML.get_constant_value('H5R_OBJECT');
else
reftype = H5ML.get_constant_value('H5R_DATASET_REGION');
end
for i=1:numref
refobj(i) = parseSingleRef(did, reftype, data(:,i));
end
refobj = reshape(refobj, szref);
function Reference = parseReference(datasetId, typeId, data)
referenceSize = size(data);
%first dimension is always the raw buffer size
referenceSize = referenceSize(2:end);
if isscalar(referenceSize)
referenceSize = [referenceSize 1];
end
totalNumReferences = prod(referenceSize);
if H5T.equal(typeId, 'H5T_STD_REF_OBJ')
referenceType = H5ML.get_constant_value('H5R_OBJECT');
else
referenceType = H5ML.get_constant_value('H5R_DATASET_REGION');
end
for iReference = 1:totalNumReferences
Reference(iReference) = parseSingleReference(datasetId, referenceType, data(:,iReference));
end
Reference = reshape(Reference, referenceSize);
end

function refobj = parseSingleRef(did, reftype, data)
target = H5R.get_name(did, reftype, data);
function Reference = parseSingleReference(datasetId, referenceType, data)
target = H5R.get_name(datasetId, referenceType, data);

%% H5R_OBJECT
if reftype == H5ML.get_constant_value('H5R_OBJECT')
refobj = types.untyped.ObjectView(target);
return;
end
%% H5R_OBJECT
if referenceType == H5ML.get_constant_value('H5R_OBJECT')
Reference = types.untyped.ObjectView(target);
return;
end

%% H5R_DATASET_REGION
if isempty(target)
refobj = types.untyped.RegionView(target);
return;
end
sid = H5R.get_region(did, reftype, data);
%% H5R_DATASET_REGION
if isempty(target)
Reference = types.untyped.RegionView(target);
return;
end
spaceId = H5R.get_region(datasetId, referenceType, data);

if H5ML.get_constant_value('H5S_SEL_HYPERSLABS') ~= H5S.get_select_type(sid)
warning('NWB:ParseReference:UnsupportedSelectionType',...
['MatNWB does not support space selections other than hyperslab mode. '...
'Ignoring other selections.']);
end
if H5ML.get_constant_value('H5S_SEL_HYPERSLABS') ~= H5S.get_select_type(spaceId)
warning('NWB:ParseReference:UnsupportedSelectionType',...
['MatNWB does not support space selections other than hyperslab mode. '...
'Ignoring other selections.']);
end

blocklist = flipud(H5S.get_select_hyper_blocklist(sid, 0, H5S.get_select_hyper_nblocks(sid)));
% H5S.get_select_hyper_blocklist returns an (m x 2n) array, where m is the number of
% dimensions (or rank) of the dataspace and n is the number of blocks.
% The 2n columns contain the list of blocks. The first column contains the start
% coordinates of the first block, followed by the next column which contains the opposite
% corner coordinates, followed by the next column which contains the start coordinates of the
% second block, etc.
selections = cell(size(blocklist, 1), 1);
for i = 1:length(selections)
prevSel = selections{i};
blockDim = mat2cell(blocklist(i,:), 1, ones(1, (size(blocklist, 2) / 2)) + 1);
for j = 1:length(blockDim)
block = blockDim{j};
blockDim{j} = (block(1):block(2))+1;
numHyperBlocks = H5S.get_select_hyper_nblocks(spaceId);
selectionBlock = flipud(H5S.get_select_hyper_blocklist(spaceId, 0, numHyperBlocks));
% H5S.get_select_hyper_blocklist returns an (m x 2n) array, where m is the number of
% dimensions (or rank) of the dataspace and n is the number of blocks.
% The 2n columns contain the list of blocks. The first column contains the start
% coordinates of the first block, followed by the next column which contains the opposite
% corner coordinates, followed by the next column which contains the start coordinates of the
% second block, etc.
selections = cell(size(selectionBlock, 1), 1);
selectionCellSize = {1, (ones(1, (size(selectionBlock, 2) / 2)) + 1)};
for iSelection = 1:length(selections)
previousSelection = selections{iSelection};
blockDimension = mat2cell(selectionBlock(iSelection,:), selectionCellSize{:});
for iDimension = 1:length(blockDimension)
block = blockDimension{iDimension};
blockDimension{iDimension} = (block(1):block(2))+1;
end
selections{iSelection} = [previousSelection cell2mat(blockDimension)];
end
selections{i} = [prevSel cell2mat(blockDim)];
end

H5S.close(sid);
refobj = types.untyped.RegionView(target, selections{:});
H5S.close(spaceId);
Reference = types.untyped.RegionView(target, selections{:});
end
132 changes: 72 additions & 60 deletions +tests/+util/verifyContainerEqual.m
Original file line number Diff line number Diff line change
@@ -1,66 +1,78 @@
function verifyContainerEqual(testCase, actual, expected, ignoreList)
if nargin < 4
ignoreList = {};
end
assert(iscellstr(ignoreList),...
'NWB:Test:InvalidIgnoreList',...
['Ignore List must be a cell array of character arrays indicating props that should be '...
'ignored.']);
testCase.verifyEqual(class(actual), class(expected));
props = setdiff(properties(actual), ignoreList);
for i = 1:numel(props)
prop = props{i};

actualValue = actual.(prop);
expectedValue = expected.(prop);
failureMessage = ['Values for property ''' prop ''' are not equal'];

if isa(actualValue, 'types.untyped.DataStub')
actualValue = actualValue.load();
if nargin < 4
ignoreList = {};
end

if startsWith(class(expectedValue), 'types.') && ~startsWith(class(expectedValue), 'types.untyped')
tests.util.verifyContainerEqual(testCase, actualValue, expectedValue);
elseif isa(expectedValue, 'types.untyped.Set')
tests.util.verifySetEqual(testCase, actualValue, expectedValue, failureMessage);
elseif ischar(expectedValue)
testCase.verifyEqual(char(actualValue), expectedValue, failureMessage);
elseif isa(expectedValue, 'types.untyped.ObjectView') || isa(expectedValue, 'types.untyped.SoftLink')
testCase.verifyEqual(actualValue.path, expectedValue.path, failureMessage);
elseif isa(expectedValue, 'types.untyped.RegionView')
testCase.verifyEqual(actualValue.path, expectedValue.path, failureMessage);
testCase.verifyEqual(actualValue.region, expectedValue.region, failureMessage);
elseif isa(expectedValue, 'types.untyped.Anon')
testCase.verifyEqual(actualValue.name, expectedValue.name, failureMessage);
tests.util.verifyContainerEqual(testCase, actualValue.value, expectedValue.value);
elseif isdatetime(expectedValue)...
|| (iscell(expectedValue) && all(cellfun('isclass', expectedValue, 'datetime')))
% linux MATLAB doesn't appear to properly compare datetimes whereas
% Windows MATLAB does. This is a workaround to get tests to work
% while getting close enough to exact date representation.
actualValue = types.util.checkDtype(prop, 'datetime', actualValue);
if ~iscell(expectedValue)
expectedValue = num2cell(expectedValue);
assert(iscellstr(ignoreList),...
'NWB:Test:InvalidIgnoreList',...
['Ignore List must be a cell array of character arrays indicating props that should be '...
'ignored.']);
testCase.verifyEqual(class(actual), class(expected));
props = setdiff(properties(actual), ignoreList);
for iProperty = 1:numel(props)
prop = props{iProperty};

actualValue = actual.(prop);
expectedValue = expected.(prop);
failureMessage = ['Values for property ''' prop ''' are not equal'];

if isa(actualValue, 'types.untyped.DataStub')
actualValue = actualValue.load();
end
if ~iscell(actualValue)
actualValue = num2cell(actualValue);

if startsWith(class(expectedValue), 'types.') && ~startsWith(class(expectedValue), 'types.untyped')
tests.util.verifyContainerEqual(testCase, actualValue, expectedValue);
elseif isa(expectedValue, 'types.untyped.Set')
tests.util.verifySetEqual(testCase, actualValue, expectedValue, failureMessage);
elseif ischar(expectedValue)
testCase.verifyEqual(char(actualValue), expectedValue, failureMessage);
elseif isa(expectedValue, 'types.untyped.ObjectView') || isa(expectedValue, 'types.untyped.SoftLink')
testCase.verifyEqual(actualValue.path, expectedValue.path, failureMessage);
elseif isa(expectedValue, 'types.untyped.RegionView')
testCase.verifyEqual(actualValue.path, expectedValue.path, failureMessage);
testCase.verifyEqual(actualValue.region, expectedValue.region, failureMessage);
elseif isa(expectedValue, 'types.untyped.Anon')
testCase.verifyEqual(actualValue.name, expectedValue.name, failureMessage);
tests.util.verifyContainerEqual(testCase, actualValue.value, expectedValue.value);
elseif isdatetime(expectedValue)...
|| (iscell(expectedValue) && all(cellfun('isclass', expectedValue, 'datetime')))
% linux MATLAB doesn't appear to properly compare datetimes whereas
% Windows MATLAB does. This is a workaround to get tests to work
% while getting close enough to exact date representation.
actualValue = types.util.checkDtype(prop, 'datetime', actualValue);
if ~iscell(expectedValue)
expectedValue = num2cell(expectedValue);
end
if ~iscell(actualValue)
actualValue = num2cell(actualValue);
end
for iDates = 1:length(expectedValue)
% ignore microseconds as linux datetime has some strange error
% even when datetime doesn't change in Windows.
ActualDate = actualValue{iDates};
ExpectedDate = expectedValue{iDates};
ExpectedUpperBound = ExpectedDate + milliseconds(1);
ExpectedLowerBound = ExpectedDate - milliseconds(1);
testCase.verifyTrue(isbetween(ActualDate, ExpectedLowerBound, ExpectedUpperBound) ...
, failureMessage);
end
elseif startsWith(class(expectedValue), 'int')
testCase.verifyEqual(int64(actualValue), int64(expectedValue), failureMessage);
elseif startsWith(class(expectedValue), 'uint')
testCase.verifyEqual(uint64(actualValue), uint64(expectedValue), failureMessage);
elseif isstruct(expectedValue) || istable(expectedValue)
if istable(expectedValue)
fieldNames = expectedValue.Properties.VariableNames;
else
fieldNames = fieldnames(expectedValue);
end
fieldNames = convertStringsToChars(fieldNames);
testCase.verifyTrue(isstruct(actualValue) || istable(actualValue), failureMessage);
for iField = 1:length(fieldNames)
name = fieldNames{iField};
testCase.verifyEqual(actualValue.(name), expectedValue.(name), failureMessage);
end
else
testCase.verifyEqual(actualValue, expectedValue, failureMessage);
end
for iDates = 1:length(expectedValue)
% ignore microseconds as linux datetime has some strange error
% even when datetime doesn't change in Windows.
ActualDate = actualValue{iDates};
ExpectedDate = expectedValue{iDates};
ExpectedUpperBound = ExpectedDate + milliseconds(1);
ExpectedLowerBound = ExpectedDate - milliseconds(1);
testCase.verifyTrue(isbetween(ActualDate, ExpectedLowerBound, ExpectedUpperBound) ...
, failureMessage);
end
elseif startsWith(class(expectedValue), 'int')
testCase.verifyEqual(int64(actualValue), int64(expectedValue), failureMessage);
elseif startsWith(class(expectedValue), 'uint')
testCase.verifyEqual(uint64(actualValue), uint64(expectedValue), failureMessage);
else
testCase.verifyEqual(actualValue, expectedValue, failureMessage);
end
end
end
Loading