-
Notifications
You must be signed in to change notification settings - Fork 32
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Properly Convert References in Compound Datasets (#529)
* move DataStub into @ directory
* update gitignore to allow datastub dir
* minor refactor of load_mat_style
* Major refactor to io.parseCompound
  - retab function
  - function no longer returns table due to ambiguous implications from vector direction.
  - expand variable names and remove useless functions
* Major refactor for io.parseReference
  - Retab function
  - expand variable and loop names
* Support compound and reference types for DataStub

  Also support reference types nested in compound structs
* retab verify container equal
* Loosen equality checks with compound types

  Since we allow multiple types of inputs for compound data types we need to allow for multiple output validation for them too.

---------

Co-authored-by: Lawrence Niu <[email protected]>
- Loading branch information
1 parent
b7583a3
commit d8c7d71
Showing
8 changed files
with
603 additions
and
574 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,65 +1,56 @@ | ||
function data = parseCompound(did, data) | ||
%did is the dataset_id for the containing dataset | ||
%data should be a scalar struct with fields as columns | ||
if isempty(data) | ||
return; | ||
end | ||
tid = H5D.get_type(did); | ||
ncol = H5T.get_nmembers(tid); | ||
subtids = cell(1, ncol); | ||
ref_i = false(1, ncol); | ||
char_i = false(1, ncol); | ||
bool_i = false(1,ncol); | ||
for i = 1:ncol | ||
subtid = H5T.get_member_type(tid, i-1); | ||
subtids{i} = subtid; | ||
switch H5T.get_member_class(tid, i-1) | ||
case H5ML.get_constant_value('H5T_REFERENCE') | ||
ref_i(i) = true; | ||
case H5ML.get_constant_value('H5T_STRING') | ||
%if not variable len (which would make it a cell array) | ||
%then mark for transpose | ||
char_i(i) = ~H5T.is_variable_str(subtid); | ||
case H5ML.get_constant_value('H5T_ENUM') | ||
bool_i(i) = io.isBool(subtid); | ||
otherwise | ||
%do nothing | ||
function data = parseCompound(datasetId, data) | ||
%datasetId is the dataset_id for the containing dataset | |
%data should be a scalar struct with fields as columns | ||
if isempty(data) | ||
return; | ||
end | ||
end | ||
|
||
fields = fieldnames(data); | ||
if any(ref_i) | ||
%resolve references by column | ||
reftids = subtids(ref_i); | ||
refFields = fields(ref_i); | ||
for j=1:length(refFields) | ||
rpname = refFields{j}; | ||
refdata = data.(rpname); | ||
reflist = cell(size(refdata, 2), 1); | ||
for k=1:size(refdata, 2) | ||
r = refdata(:,k); | ||
reflist{k} = io.parseReference(did, reftids{j}, r); | ||
typeId = H5D.get_type(datasetId); | ||
numFields = H5T.get_nmembers(typeId); | ||
subTypeId = cell(1, numFields); | ||
isReferenceType = false(1, numFields); | ||
isCharacterType = false(1, numFields); | ||
isLogicalType = false(1,numFields); | ||
for iField = 1:numFields | ||
fieldTypeId = H5T.get_member_type(typeId, iField-1); | ||
subTypeId{iField} = fieldTypeId; | ||
switch H5T.get_member_class(typeId, iField-1) | ||
case H5ML.get_constant_value('H5T_REFERENCE') | ||
isReferenceType(iField) = true; | ||
case H5ML.get_constant_value('H5T_STRING') | ||
%if not variable len (which would make it a cell array) | ||
%then mark for transpose | ||
isCharacterType(iField) = ~H5T.is_variable_str(fieldTypeId); | ||
case H5ML.get_constant_value('H5T_ENUM') | ||
isLogicalType(iField) = io.isBool(fieldTypeId); | ||
otherwise | ||
%do nothing | ||
end | ||
data.(rpname) = [reflist{:}] .'; | ||
end | ||
end | ||
|
||
if any(char_i) | ||
%transpose character arrays because they are column-ordered | ||
%when read | ||
charFields = fields(char_i); | ||
for j=1:length(charFields) | ||
cpname = charFields{j}; | ||
data.(cpname) = data.(cpname) .'; | ||
fieldName = fieldnames(data); | ||
|
||
% resolve references by column | ||
referenceTypeId = subTypeId(isReferenceType); | ||
referenceFieldName = fieldName(isReferenceType); | ||
for iFieldName = 1:length(referenceFieldName) | ||
name = referenceFieldName{iFieldName}; | ||
rawReference = data.(name); | ||
rawTypeId = referenceTypeId{iFieldName}; | ||
data.(name) = io.parseReference(datasetId, rawTypeId, rawReference); | ||
end | ||
end | ||
|
||
if any(bool_i) | ||
% convert column data to proper logical arrays/matrices | ||
for f=fields{bool_i} | ||
data.(f) = strcmp('TRUE', data.(f)); | ||
% transpose character arrays because they are column-ordered | ||
% when read | ||
characterFieldName = fieldName(isCharacterType); | ||
for iFieldName = 1:length(characterFieldName) | ||
name = characterFieldName{iFieldName}; | ||
data.(name) = data.(name) .'; | ||
end | ||
end | ||
|
||
data = struct2table(data); | ||
% convert column data to proper logical arrays/matrices | ||
logicalFieldName = fieldName(isLogicalType); | ||
for iFieldName = 1:length(logicalFieldName) | ||
name = logicalFieldName{iFieldName}; | ||
data.(name) = strcmp('TRUE', data.(name)); | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,61 +1,63 @@ | ||
function refobj = parseReference(did, tid, data) | ||
szref = size(data); | ||
%first dimension is always the raw buffer size | ||
szref = szref(2:end); | ||
if isscalar(szref) | ||
szref = [szref 1]; | ||
end | ||
numref = prod(szref); | ||
if H5T.equal(tid, 'H5T_STD_REF_OBJ') | ||
reftype = H5ML.get_constant_value('H5R_OBJECT'); | ||
else | ||
reftype = H5ML.get_constant_value('H5R_DATASET_REGION'); | ||
end | ||
for i=1:numref | ||
refobj(i) = parseSingleRef(did, reftype, data(:,i)); | ||
end | ||
refobj = reshape(refobj, szref); | ||
function Reference = parseReference(datasetId, typeId, data) | ||
referenceSize = size(data); | ||
%first dimension is always the raw buffer size | ||
referenceSize = referenceSize(2:end); | ||
if isscalar(referenceSize) | ||
referenceSize = [referenceSize 1]; | ||
end | ||
totalNumReferences = prod(referenceSize); | ||
if H5T.equal(typeId, 'H5T_STD_REF_OBJ') | ||
referenceType = H5ML.get_constant_value('H5R_OBJECT'); | ||
else | ||
referenceType = H5ML.get_constant_value('H5R_DATASET_REGION'); | ||
end | ||
for iReference = 1:totalNumReferences | ||
Reference(iReference) = parseSingleReference(datasetId, referenceType, data(:,iReference)); | ||
end | ||
Reference = reshape(Reference, referenceSize); | ||
end | ||
|
||
function refobj = parseSingleRef(did, reftype, data) | ||
target = H5R.get_name(did, reftype, data); | ||
function Reference = parseSingleReference(datasetId, referenceType, data) | ||
target = H5R.get_name(datasetId, referenceType, data); | ||
|
||
%% H5R_OBJECT | ||
if reftype == H5ML.get_constant_value('H5R_OBJECT') | ||
refobj = types.untyped.ObjectView(target); | ||
return; | ||
end | ||
%% H5R_OBJECT | ||
if referenceType == H5ML.get_constant_value('H5R_OBJECT') | ||
Reference = types.untyped.ObjectView(target); | ||
return; | ||
end | ||
|
||
%% H5R_DATASET_REGION | ||
if isempty(target) | ||
refobj = types.untyped.RegionView(target); | ||
return; | ||
end | ||
sid = H5R.get_region(did, reftype, data); | ||
%% H5R_DATASET_REGION | ||
if isempty(target) | ||
Reference = types.untyped.RegionView(target); | ||
return; | ||
end | ||
spaceId = H5R.get_region(datasetId, referenceType, data); | ||
|
||
if H5ML.get_constant_value('H5S_SEL_HYPERSLABS') ~= H5S.get_select_type(sid) | ||
warning('NWB:ParseReference:UnsupportedSelectionType',... | ||
['MatNWB does not support space selections other than hyperslab mode. '... | ||
'Ignoring other selections.']); | ||
end | ||
if H5ML.get_constant_value('H5S_SEL_HYPERSLABS') ~= H5S.get_select_type(spaceId) | ||
warning('NWB:ParseReference:UnsupportedSelectionType',... | ||
['MatNWB does not support space selections other than hyperslab mode. '... | ||
'Ignoring other selections.']); | ||
end | ||
|
||
blocklist = flipud(H5S.get_select_hyper_blocklist(sid, 0, H5S.get_select_hyper_nblocks(sid))); | ||
% Returns an (m x 2n) array, where m is the number of dimensions (or rank) of the dataspace. | ||
% The 2n rows of Result contain the list of blocks. The first row contains the start | ||
% coordinates of the first block, followed by the next row which contains the opposite | ||
% corner coordinates, followed by the next row which contains the start coordinates of the | ||
% second block,etc. | ||
selections = cell(size(blocklist, 1), 1); | ||
for i = 1:length(selections) | ||
prevSel = selections{i}; | ||
blockDim = mat2cell(blocklist(i,:), 1, ones(1, (size(blocklist, 2) / 2)) + 1); | ||
for j = 1:length(blockDim) | ||
block = blockDim{j}; | ||
blockDim{j} = (block(1):block(2))+1; | ||
numHyperBlocks = H5S.get_select_hyper_nblocks(spaceId); | ||
selectionBlock = flipud(H5S.get_select_hyper_blocklist(spaceId, 0, numHyperBlocks)); | ||
% Returns an (m x 2n) array, where m is the number of dimensions (or rank) of the dataspace. | ||
% The 2n rows of Result contain the list of blocks. The first row contains the start | ||
% coordinates of the first block, followed by the next row which contains the opposite | ||
% corner coordinates, followed by the next row which contains the start coordinates of the | ||
% second block,etc. | ||
selections = cell(size(selectionBlock, 1), 1); | ||
selectionCellSize = {1, (ones(1, (size(selectionBlock, 2) / 2)) + 1)}; | ||
for iSelection = 1:length(selections) | ||
previousSelection = selections{iSelection}; | ||
blockDimension = mat2cell(selectionBlock(iSelection,:), selectionCellSize{:}); | ||
for iDimension = 1:length(blockDimension) | ||
block = blockDimension{iDimension}; | ||
blockDimension{iDimension} = (block(1):block(2))+1; | ||
end | ||
selections{iSelection} = [previousSelection cell2mat(blockDimension)]; | ||
end | ||
selections{i} = [prevSel cell2mat(blockDim)]; | ||
end | ||
|
||
H5S.close(sid); | ||
refobj = types.untyped.RegionView(target, selections{:}); | ||
H5S.close(spaceId); | ||
Reference = types.untyped.RegionView(target, selections{:}); | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,66 +1,78 @@ | ||
function verifyContainerEqual(testCase, actual, expected, ignoreList) | ||
if nargin < 4 | ||
ignoreList = {}; | ||
end | ||
assert(iscellstr(ignoreList),... | ||
'NWB:Test:InvalidIgnoreList',... | ||
['Ignore List must be a cell array of character arrays indicating props that should be '... | ||
'ignored.']); | ||
testCase.verifyEqual(class(actual), class(expected)); | ||
props = setdiff(properties(actual), ignoreList); | ||
for i = 1:numel(props) | ||
prop = props{i}; | ||
|
||
actualValue = actual.(prop); | ||
expectedValue = expected.(prop); | ||
failureMessage = ['Values for property ''' prop ''' are not equal']; | ||
|
||
if isa(actualValue, 'types.untyped.DataStub') | ||
actualValue = actualValue.load(); | ||
if nargin < 4 | ||
ignoreList = {}; | ||
end | ||
|
||
if startsWith(class(expectedValue), 'types.') && ~startsWith(class(expectedValue), 'types.untyped') | ||
tests.util.verifyContainerEqual(testCase, actualValue, expectedValue); | ||
elseif isa(expectedValue, 'types.untyped.Set') | ||
tests.util.verifySetEqual(testCase, actualValue, expectedValue, failureMessage); | ||
elseif ischar(expectedValue) | ||
testCase.verifyEqual(char(actualValue), expectedValue, failureMessage); | ||
elseif isa(expectedValue, 'types.untyped.ObjectView') || isa(expectedValue, 'types.untyped.SoftLink') | ||
testCase.verifyEqual(actualValue.path, expectedValue.path, failureMessage); | ||
elseif isa(expectedValue, 'types.untyped.RegionView') | ||
testCase.verifyEqual(actualValue.path, expectedValue.path, failureMessage); | ||
testCase.verifyEqual(actualValue.region, expectedValue.region, failureMessage); | ||
elseif isa(expectedValue, 'types.untyped.Anon') | ||
testCase.verifyEqual(actualValue.name, expectedValue.name, failureMessage); | ||
tests.util.verifyContainerEqual(testCase, actualValue.value, expectedValue.value); | ||
elseif isdatetime(expectedValue)... | ||
|| (iscell(expectedValue) && all(cellfun('isclass', expectedValue, 'datetime'))) | ||
% linux MATLAB doesn't appear to properly compare datetimes whereas | |
% Windows MATLAB does. This is a workaround to get tests to work | ||
% while getting close enough to exact date representation. | ||
actualValue = types.util.checkDtype(prop, 'datetime', actualValue); | ||
if ~iscell(expectedValue) | ||
expectedValue = num2cell(expectedValue); | ||
assert(iscellstr(ignoreList),... | ||
'NWB:Test:InvalidIgnoreList',... | ||
['Ignore List must be a cell array of character arrays indicating props that should be '... | ||
'ignored.']); | ||
testCase.verifyEqual(class(actual), class(expected)); | ||
props = setdiff(properties(actual), ignoreList); | ||
for iProperty = 1:numel(props) | ||
prop = props{iProperty}; | ||
|
||
actualValue = actual.(prop); | ||
expectedValue = expected.(prop); | ||
failureMessage = ['Values for property ''' prop ''' are not equal']; | ||
|
||
if isa(actualValue, 'types.untyped.DataStub') | ||
actualValue = actualValue.load(); | ||
end | ||
if ~iscell(actualValue) | ||
actualValue = num2cell(actualValue); | ||
|
||
if startsWith(class(expectedValue), 'types.') && ~startsWith(class(expectedValue), 'types.untyped') | ||
tests.util.verifyContainerEqual(testCase, actualValue, expectedValue); | ||
elseif isa(expectedValue, 'types.untyped.Set') | ||
tests.util.verifySetEqual(testCase, actualValue, expectedValue, failureMessage); | ||
elseif ischar(expectedValue) | ||
testCase.verifyEqual(char(actualValue), expectedValue, failureMessage); | ||
elseif isa(expectedValue, 'types.untyped.ObjectView') || isa(expectedValue, 'types.untyped.SoftLink') | ||
testCase.verifyEqual(actualValue.path, expectedValue.path, failureMessage); | ||
elseif isa(expectedValue, 'types.untyped.RegionView') | ||
testCase.verifyEqual(actualValue.path, expectedValue.path, failureMessage); | ||
testCase.verifyEqual(actualValue.region, expectedValue.region, failureMessage); | ||
elseif isa(expectedValue, 'types.untyped.Anon') | ||
testCase.verifyEqual(actualValue.name, expectedValue.name, failureMessage); | ||
tests.util.verifyContainerEqual(testCase, actualValue.value, expectedValue.value); | ||
elseif isdatetime(expectedValue)... | ||
|| (iscell(expectedValue) && all(cellfun('isclass', expectedValue, 'datetime'))) | ||
% linux MATLAB doesn't appear to properly compare datetimes whereas | |
% Windows MATLAB does. This is a workaround to get tests to work | ||
% while getting close enough to exact date representation. | ||
actualValue = types.util.checkDtype(prop, 'datetime', actualValue); | ||
if ~iscell(expectedValue) | ||
expectedValue = num2cell(expectedValue); | ||
end | ||
if ~iscell(actualValue) | ||
actualValue = num2cell(actualValue); | ||
end | ||
for iDates = 1:length(expectedValue) | ||
% ignore microseconds as linux datetime has some strange error | ||
% even when datetime doesn't change in Windows. | ||
ActualDate = actualValue{iDates}; | ||
ExpectedDate = expectedValue{iDates}; | ||
ExpectedUpperBound = ExpectedDate + milliseconds(1); | ||
ExpectedLowerBound = ExpectedDate - milliseconds(1); | ||
testCase.verifyTrue(isbetween(ActualDate, ExpectedLowerBound, ExpectedUpperBound) ... | ||
, failureMessage); | ||
end | ||
elseif startsWith(class(expectedValue), 'int') | ||
testCase.verifyEqual(int64(actualValue), int64(expectedValue), failureMessage); | ||
elseif startsWith(class(expectedValue), 'uint') | ||
testCase.verifyEqual(uint64(actualValue), uint64(expectedValue), failureMessage); | ||
elseif isstruct(expectedValue) || istable(expectedValue) | ||
if istable(expectedValue) | ||
fieldNames = expectedValue.Properties.VariableNames; | ||
else | ||
fieldNames = fieldnames(expectedValue); | ||
end | ||
fieldNames = convertStringsToChars(fieldNames); | ||
testCase.verifyTrue(isstruct(actualValue) || istable(actualValue), failureMessage); | ||
for iField = 1:length(fieldNames) | ||
name = fieldNames{iField}; | ||
testCase.verifyEqual(actualValue.(name), expectedValue.(name), failureMessage); | ||
end | ||
else | ||
testCase.verifyEqual(actualValue, expectedValue, failureMessage); | ||
end | ||
for iDates = 1:length(expectedValue) | ||
% ignore microseconds as linux datetime has some strange error | ||
% even when datetime doesn't change in Windows. | ||
ActualDate = actualValue{iDates}; | ||
ExpectedDate = expectedValue{iDates}; | ||
ExpectedUpperBound = ExpectedDate + milliseconds(1); | ||
ExpectedLowerBound = ExpectedDate - milliseconds(1); | ||
testCase.verifyTrue(isbetween(ActualDate, ExpectedLowerBound, ExpectedUpperBound) ... | ||
, failureMessage); | ||
end | ||
elseif startsWith(class(expectedValue), 'int') | ||
testCase.verifyEqual(int64(actualValue), int64(expectedValue), failureMessage); | ||
elseif startsWith(class(expectedValue), 'uint') | ||
testCase.verifyEqual(uint64(actualValue), uint64(expectedValue), failureMessage); | ||
else | ||
testCase.verifyEqual(actualValue, expectedValue, failureMessage); | ||
end | ||
end | ||
end |
Oops, something went wrong.