Skip to content

Commit

Permalink
Properly Convert References in Compound Datasets (#529)
Browse files Browse the repository at this point in the history
* move DataStub into @ directory

* update gitignore to allow datastub dir

* minor refactor of load_mat_style

* Major refactor to io.parseCompound

- retab function
- function no longer returns table due to ambiguous implications from
  vector direction.
- expand variable names and remove useless functions

* Major refactor for io.parseReference

- Retab function
- expand variable and loop names

* Support compound and reference types for DataStub

Also support reference types nested in compound structs

* retab verify container equal

* Loosen equality checks with compound types

Since we allow multiple types of inputs for compound data types we need
to allow for multiple output validation for them too.

---------

Co-authored-by: Lawrence Niu <[email protected]>
  • Loading branch information
lawrence-mbf and lawrence-mbf authored Jul 28, 2023
1 parent b7583a3 commit d8c7d71
Show file tree
Hide file tree
Showing 8 changed files with 603 additions and 574 deletions.
103 changes: 47 additions & 56 deletions +io/parseCompound.m
Original file line number Diff line number Diff line change
@@ -1,65 +1,56 @@
% NOTE(review): this span is a rendered diff of +io/parseCompound.m. Lines of the
% pre-change version (did/tid/ncol/... names) and the post-change version
% (datasetId/typeId/numFields/... names) are interleaved without +/- markers.
function data = parseCompound(did, data)
%did is the dataset_id for the containing dataset
%data should be a scalar struct with fields as columns
if isempty(data)
return;
end
tid = H5D.get_type(did);
ncol = H5T.get_nmembers(tid);
subtids = cell(1, ncol);
ref_i = false(1, ncol);
char_i = false(1, ncol);
bool_i = false(1,ncol);
for i = 1:ncol
subtid = H5T.get_member_type(tid, i-1);
subtids{i} = subtid;
switch H5T.get_member_class(tid, i-1)
case H5ML.get_constant_value('H5T_REFERENCE')
ref_i(i) = true;
case H5ML.get_constant_value('H5T_STRING')
%if not variable len (which would make it a cell array)
%then mark for transpose
char_i(i) = ~H5T.is_variable_str(subtid);
case H5ML.get_constant_value('H5T_ENUM')
bool_i(i) = io.isBool(subtid);
otherwise
%do nothing
% --- post-change function header: first parameter renamed to datasetId ---
% Post-processes compound data in place: reference fields are converted via
% io.parseReference, fixed-length string fields are transposed, and enum fields
% flagged by io.isBool are converted to logical. Returns the input unchanged
% when it is empty.
function data = parseCompound(datasetId, data)
%datasetId is the dataset_id for the containing dataset
%data should be a scalar struct with fields as columns
if isempty(data)
return;
end
end

fields = fieldnames(data);
if any(ref_i)
%resolve references by column
reftids = subtids(ref_i);
refFields = fields(ref_i);
for j=1:length(refFields)
rpname = refFields{j};
refdata = data.(rpname);
reflist = cell(size(refdata, 2), 1);
for k=1:size(refdata, 2)
r = refdata(:,k);
reflist{k} = io.parseReference(did, reftids{j}, r);
% Query the dataset's datatype and the number of members it declares.
typeId = H5D.get_type(datasetId);
numFields = H5T.get_nmembers(typeId);
% Per-member bookkeeping: the member's type id plus one flag per category that
% needs post-processing further down.
subTypeId = cell(1, numFields);
isReferenceType = false(1, numFields);
isCharacterType = false(1, numFields);
isLogicalType = false(1,numFields);
for iField = 1:numFields
% the H5T member queries are called with iField-1 (zero-based member index)
fieldTypeId = H5T.get_member_type(typeId, iField-1);
subTypeId{iField} = fieldTypeId;
switch H5T.get_member_class(typeId, iField-1)
case H5ML.get_constant_value('H5T_REFERENCE')
isReferenceType(iField) = true;
case H5ML.get_constant_value('H5T_STRING')
%if not variable len (which would make it a cell array)
%then mark for transpose
isCharacterType(iField) = ~H5T.is_variable_str(fieldTypeId);
case H5ML.get_constant_value('H5T_ENUM')
% only enums that io.isBool accepts are flagged for logical conversion
isLogicalType(iField) = io.isBool(fieldTypeId);
otherwise
%do nothing
end
data.(rpname) = [reflist{:}] .';
end
end

if any(char_i)
%transpose character arrays because they are column-ordered
%when read
charFields = fields(char_i);
for j=1:length(charFields)
cpname = charFields{j};
data.(cpname) = data.(cpname) .';
% The logical masks above are applied to fieldnames(data) below.
% NOTE(review): this assumes the struct fields are in the same order as the
% HDF5 compound members -- confirm against the reader that produces `data`.
fieldName = fieldnames(data);

% resolve references by column
referenceTypeId = subTypeId(isReferenceType);
referenceFieldName = fieldName(isReferenceType);
for iFieldName = 1:length(referenceFieldName)
name = referenceFieldName{iFieldName};
rawReference = data.(name);
rawTypeId = referenceTypeId{iFieldName};
% the whole raw field is handed to io.parseReference in a single call
data.(name) = io.parseReference(datasetId, rawTypeId, rawReference);
end
end

if any(bool_i)
% convert column data to proper logical arrays/matrices
% NOTE(review): pre-change loop; fields{bool_i} expands to a comma-separated
% list rather than a cell array -- presumably one reason the loop was rewritten.
for f=fields{bool_i}
data.(f) = strcmp('TRUE', data.(f));
% transpose character arrays because they are column-ordered
% when read
characterFieldName = fieldName(isCharacterType);
for iFieldName = 1:length(characterFieldName)
name = characterFieldName{iFieldName};
data.(name) = data.(name) .';
end
end

data = struct2table(data);
% convert column data to proper logical arrays/matrices
logicalFieldName = fieldName(isLogicalType);
for iFieldName = 1:length(logicalFieldName)
name = logicalFieldName{iFieldName};
% entries equal to the text 'TRUE' become true, anything else false
data.(name) = strcmp('TRUE', data.(name));
end
end
106 changes: 54 additions & 52 deletions +io/parseReference.m
Original file line number Diff line number Diff line change
@@ -1,61 +1,63 @@
% NOTE(review): pre-change lines of +io/parseReference.m as rendered by the diff;
% the function's closing `end` is not part of this span.
% Converts raw reference data into an array of reference objects, one per
% column of `data`, shaped like `data` without its first dimension.
function refobj = parseReference(did, tid, data)
szref = size(data);
%first dimension is always the raw buffer size
szref = szref(2:end);
% a single remaining dimension N is turned into the column shape [N 1]
if isscalar(szref)
szref = [szref 1];
end
numref = prod(szref);
% anything that is not H5T_STD_REF_OBJ is treated as a dataset region reference
if H5T.equal(tid, 'H5T_STD_REF_OBJ')
reftype = H5ML.get_constant_value('H5R_OBJECT');
else
reftype = H5ML.get_constant_value('H5R_DATASET_REGION');
end
% data(:,i) uses linear column indexing, so trailing dimensions are flattened
% here and restored by the reshape below
for i=1:numref
refobj(i) = parseSingleRef(did, reftype, data(:,i));
end
refobj = reshape(refobj, szref);
function Reference = parseReference(datasetId, typeId, data)
    %PARSEREFERENCE Convert raw HDF5 reference buffers into view objects.
    %   Reference = parseReference(datasetId, typeId, data)
    %
    %   datasetId - identifier of the dataset holding the references; it is
    %               forwarded to parseSingleReference for name/region lookup.
    %   typeId    - HDF5 datatype of the references. Any type that is not equal
    %               to 'H5T_STD_REF_OBJ' is treated as a dataset region
    %               reference.
    %   data      - raw reference data. The first dimension is the raw buffer
    %               of one reference; the remaining dimensions determine the
    %               shape of the output.
    %
    %   Returns an array of the objects produced by parseSingleReference
    %   (types.untyped.ObjectView for object references, otherwise
    %   types.untyped.RegionView) whose size equals size(data) without its
    %   first dimension. A single remaining dimension N yields an N-by-1 array.
    %   When data contains no references, a correctly shaped empty array of
    %   the matching class is returned.
    referenceSize = size(data);
    %first dimension is always the raw buffer size
    referenceSize = referenceSize(2:end);
    if isscalar(referenceSize)
        referenceSize = [referenceSize 1];
    end
    totalNumReferences = prod(referenceSize);

    isObjectReference = H5T.equal(typeId, 'H5T_STD_REF_OBJ');
    if isObjectReference
        referenceType = H5ML.get_constant_value('H5R_OBJECT');
    else
        referenceType = H5ML.get_constant_value('H5R_DATASET_REGION');
    end

    % With zero references the loop below never assigns the output and the
    % final reshape would fail on an undefined variable; return a typed empty
    % array instead.
    if 0 == totalNumReferences
        if isObjectReference
            Reference = types.untyped.ObjectView.empty(referenceSize);
        else
            Reference = types.untyped.RegionView.empty(referenceSize);
        end
        return;
    end

    % data(:,iReference) indexes columns linearly, so any trailing dimensions
    % are flattened here and restored by the reshape afterwards.
    for iReference = 1:totalNumReferences
        Reference(iReference) = parseSingleReference(datasetId, referenceType, data(:,iReference));
    end
    Reference = reshape(Reference, referenceSize);
end

% NOTE(review): rendered diff of the local helper in +io/parseReference.m. The
% pre-change version (parseSingleRef) and the post-change version
% (parseSingleReference) are interleaved without +/- markers.
function refobj = parseSingleRef(did, reftype, data)
target = H5R.get_name(did, reftype, data);
% Resolves one raw reference buffer into a types.untyped.ObjectView (object
% reference) or a types.untyped.RegionView (dataset region reference).
function Reference = parseSingleReference(datasetId, referenceType, data)
% name of the referenced object as reported by H5R.get_name
target = H5R.get_name(datasetId, referenceType, data);

%% H5R_OBJECT
if reftype == H5ML.get_constant_value('H5R_OBJECT')
refobj = types.untyped.ObjectView(target);
return;
end
%% H5R_OBJECT
% object references only need the target name
if referenceType == H5ML.get_constant_value('H5R_OBJECT')
Reference = types.untyped.ObjectView(target);
return;
end

%% H5R_DATASET_REGION
if isempty(target)
refobj = types.untyped.RegionView(target);
return;
end
sid = H5R.get_region(did, reftype, data);
%% H5R_DATASET_REGION
% an empty target name yields a RegionView without any selection; the region
% dataspace is not opened in that case
if isempty(target)
Reference = types.untyped.RegionView(target);
return;
end
spaceId = H5R.get_region(datasetId, referenceType, data);

if H5ML.get_constant_value('H5S_SEL_HYPERSLABS') ~= H5S.get_select_type(sid)
warning('NWB:ParseReference:UnsupportedSelectionType',...
['MatNWB does not support space selections other than hyperslab mode. '...
'Ignoring other selections.']);
end
% only a warning is issued for non-hyperslab selections; execution continues
if H5ML.get_constant_value('H5S_SEL_HYPERSLABS') ~= H5S.get_select_type(spaceId)
warning('NWB:ParseReference:UnsupportedSelectionType',...
['MatNWB does not support space selections other than hyperslab mode. '...
'Ignoring other selections.']);
end

blocklist = flipud(H5S.get_select_hyper_blocklist(sid, 0, H5S.get_select_hyper_nblocks(sid)));
% Returns an (m x 2n) array, where m is the number of dimensions (or rank) of the dataspace.
% The 2n rows of Result contain the list of blocks. The first row contains the start
% coordinates of the first block, followed by the next row which contains the opposite
% corner coordinates, followed by the next row which contains the start coordinates of the
% second block,etc.
selections = cell(size(blocklist, 1), 1);
for i = 1:length(selections)
prevSel = selections{i};
blockDim = mat2cell(blocklist(i,:), 1, ones(1, (size(blocklist, 2) / 2)) + 1);
for j = 1:length(blockDim)
block = blockDim{j};
blockDim{j} = (block(1):block(2))+1;
% fetch every selected hyperslab block, starting at block index 0; flipud
% reverses the row order of the returned list
numHyperBlocks = H5S.get_select_hyper_nblocks(spaceId);
selectionBlock = flipud(H5S.get_select_hyper_blocklist(spaceId, 0, numHyperBlocks));
% Returns an (m x 2n) array, where m is the number of dimensions (or rank) of the dataspace.
% The 2n rows of Result contain the list of blocks. The first row contains the start
% coordinates of the first block, followed by the next row which contains the opposite
% corner coordinates, followed by the next row which contains the start coordinates of the
% second block,etc.
% NOTE(review): the description above says (m x 2n) but then refers to 2n rows;
% confirm the orientation against the H5S.get_select_hyper_blocklist docs.
% one selection entry per row of the flipped block list
selections = cell(size(selectionBlock, 1), 1);
% mat2cell arguments that split a single row into consecutive 1x2 pieces
selectionCellSize = {1, (ones(1, (size(selectionBlock, 2) / 2)) + 1)};
for iSelection = 1:length(selections)
% each cell is visited once, so this is still the empty value created by cell()
previousSelection = selections{iSelection};
blockDimension = mat2cell(selectionBlock(iSelection,:), selectionCellSize{:});
for iDimension = 1:length(blockDimension)
block = blockDimension{iDimension};
% expand the [first last] pair into an inclusive range, shifted by one
% (presumably zero-based HDF5 coordinates to one-based MATLAB indices)
blockDimension{iDimension} = (block(1):block(2))+1;
end
selections{iSelection} = [previousSelection cell2mat(blockDimension)];
end
selections{i} = [prevSel cell2mat(blockDim)];
end

H5S.close(sid);
refobj = types.untyped.RegionView(target, selections{:});
% release the region dataspace before building the result
H5S.close(spaceId);
Reference = types.untyped.RegionView(target, selections{:});
end
132 changes: 72 additions & 60 deletions +tests/+util/verifyContainerEqual.m
Original file line number Diff line number Diff line change
@@ -1,66 +1,78 @@
function verifyContainerEqual(testCase, actual, expected, ignoreList)
% Verifies that `actual` and `expected` have the same class and that every
% property of `actual` not named in `ignoreList` compares equal, using a
% comparison strategy chosen from the class of the expected property value.
%   testCase   - test case object providing verifyEqual / verifyTrue
%   actual     - object under test
%   expected   - reference object
%   ignoreList - (optional) cell array of character arrays naming properties
%                to skip; defaults to {}
% NOTE(review): this span is a rendered diff; the pre-change and post-change
% (re-tabbed, struct/table branch added) lines are interleaved without +/- markers.
if nargin < 4
ignoreList = {};
end
assert(iscellstr(ignoreList),...
'NWB:Test:InvalidIgnoreList',...
['Ignore List must be a cell array of character arrays indicating props that should be '...
'ignored.']);
testCase.verifyEqual(class(actual), class(expected));
props = setdiff(properties(actual), ignoreList);
for i = 1:numel(props)
prop = props{i};

actualValue = actual.(prop);
expectedValue = expected.(prop);
failureMessage = ['Values for property ''' prop ''' are not equal'];

if isa(actualValue, 'types.untyped.DataStub')
actualValue = actualValue.load();
if nargin < 4
ignoreList = {};
end

if startsWith(class(expectedValue), 'types.') && ~startsWith(class(expectedValue), 'types.untyped')
tests.util.verifyContainerEqual(testCase, actualValue, expectedValue);
elseif isa(expectedValue, 'types.untyped.Set')
tests.util.verifySetEqual(testCase, actualValue, expectedValue, failureMessage);
elseif ischar(expectedValue)
testCase.verifyEqual(char(actualValue), expectedValue, failureMessage);
elseif isa(expectedValue, 'types.untyped.ObjectView') || isa(expectedValue, 'types.untyped.SoftLink')
testCase.verifyEqual(actualValue.path, expectedValue.path, failureMessage);
elseif isa(expectedValue, 'types.untyped.RegionView')
testCase.verifyEqual(actualValue.path, expectedValue.path, failureMessage);
testCase.verifyEqual(actualValue.region, expectedValue.region, failureMessage);
elseif isa(expectedValue, 'types.untyped.Anon')
testCase.verifyEqual(actualValue.name, expectedValue.name, failureMessage);
tests.util.verifyContainerEqual(testCase, actualValue.value, expectedValue.value);
elseif isdatetime(expectedValue)...
|| (iscell(expectedValue) && all(cellfun('isclass', expectedValue, 'datetime')))
% linux MATLAB doesn't appear to properly compare datetimes whereas
% Windows MATLAB does. This is a workaround to get tests to work
% while getting close enough to exact date representation.
actualValue = types.util.checkDtype(prop, 'datetime', actualValue);
if ~iscell(expectedValue)
expectedValue = num2cell(expectedValue);
assert(iscellstr(ignoreList),...
'NWB:Test:InvalidIgnoreList',...
['Ignore List must be a cell array of character arrays indicating props that should be '...
'ignored.']);
% the classes must match before any property is compared
testCase.verifyEqual(class(actual), class(expected));
% properties named in ignoreList are excluded from the comparison
props = setdiff(properties(actual), ignoreList);
for iProperty = 1:numel(props)
prop = props{iProperty};

actualValue = actual.(prop);
expectedValue = expected.(prop);
failureMessage = ['Values for property ''' prop ''' are not equal'];

% lazily loaded data is read in full before comparing
if isa(actualValue, 'types.untyped.DataStub')
actualValue = actualValue.load();
end
if ~iscell(actualValue)
actualValue = num2cell(actualValue);

% classes under types.* other than types.untyped.* are compared recursively
if startsWith(class(expectedValue), 'types.') && ~startsWith(class(expectedValue), 'types.untyped')
tests.util.verifyContainerEqual(testCase, actualValue, expectedValue);
elseif isa(expectedValue, 'types.untyped.Set')
tests.util.verifySetEqual(testCase, actualValue, expectedValue, failureMessage);
elseif ischar(expectedValue)
testCase.verifyEqual(char(actualValue), expectedValue, failureMessage);
elseif isa(expectedValue, 'types.untyped.ObjectView') || isa(expectedValue, 'types.untyped.SoftLink')
% links and object views are compared by path only
testCase.verifyEqual(actualValue.path, expectedValue.path, failureMessage);
elseif isa(expectedValue, 'types.untyped.RegionView')
testCase.verifyEqual(actualValue.path, expectedValue.path, failureMessage);
testCase.verifyEqual(actualValue.region, expectedValue.region, failureMessage);
elseif isa(expectedValue, 'types.untyped.Anon')
testCase.verifyEqual(actualValue.name, expectedValue.name, failureMessage);
tests.util.verifyContainerEqual(testCase, actualValue.value, expectedValue.value);
elseif isdatetime(expectedValue)...
|| (iscell(expectedValue) && all(cellfun('isclass', expectedValue, 'datetime')))
% linux MATLAB doesn't appear to properly compare datetimes whereas
% Windows MATLAB does. This is a workaround to get tests to work
% while getting close enough to exact date representation.
actualValue = types.util.checkDtype(prop, 'datetime', actualValue);
% normalise both sides to cell arrays so they can be walked element-wise
if ~iscell(expectedValue)
expectedValue = num2cell(expectedValue);
end
if ~iscell(actualValue)
actualValue = num2cell(actualValue);
end
for iDates = 1:length(expectedValue)
% ignore microseconds as linux datetime has some strange error
% even when datetime doesn't change in Windows.
ActualDate = actualValue{iDates};
ExpectedDate = expectedValue{iDates};
% accept any actual date within one millisecond of the expected date
ExpectedUpperBound = ExpectedDate + milliseconds(1);
ExpectedLowerBound = ExpectedDate - milliseconds(1);
testCase.verifyTrue(isbetween(ActualDate, ExpectedLowerBound, ExpectedUpperBound) ...
, failureMessage);
end
elseif startsWith(class(expectedValue), 'int')
% signed integers are widened to int64 so differing widths still compare equal
testCase.verifyEqual(int64(actualValue), int64(expectedValue), failureMessage);
elseif startsWith(class(expectedValue), 'uint')
% unsigned integers are widened to uint64 for the same reason
testCase.verifyEqual(uint64(actualValue), uint64(expectedValue), failureMessage);
elseif isstruct(expectedValue) || istable(expectedValue)
% compound data may be a struct or a table; compare it field by field using
% the names taken from the expected value
if istable(expectedValue)
fieldNames = expectedValue.Properties.VariableNames;
else
fieldNames = fieldnames(expectedValue);
end
fieldNames = convertStringsToChars(fieldNames);
% the actual value may be either container kind, not necessarily the same one
testCase.verifyTrue(isstruct(actualValue) || istable(actualValue), failureMessage);
for iField = 1:length(fieldNames)
name = fieldNames{iField};
testCase.verifyEqual(actualValue.(name), expectedValue.(name), failureMessage);
end
else
testCase.verifyEqual(actualValue, expectedValue, failureMessage);
end
for iDates = 1:length(expectedValue)
% ignore microseconds as linux datetime has some strange error
% even when datetime doesn't change in Windows.
ActualDate = actualValue{iDates};
ExpectedDate = expectedValue{iDates};
ExpectedUpperBound = ExpectedDate + milliseconds(1);
ExpectedLowerBound = ExpectedDate - milliseconds(1);
testCase.verifyTrue(isbetween(ActualDate, ExpectedLowerBound, ExpectedUpperBound) ...
, failureMessage);
end
elseif startsWith(class(expectedValue), 'int')
testCase.verifyEqual(int64(actualValue), int64(expectedValue), failureMessage);
elseif startsWith(class(expectedValue), 'uint')
testCase.verifyEqual(uint64(actualValue), uint64(expectedValue), failureMessage);
else
testCase.verifyEqual(actualValue, expectedValue, failureMessage);
end
end
end
Loading

0 comments on commit d8c7d71

Please sign in to comment.