Skip to content

Commit

Permalink
Merge pull request #1771 from paulromano/check-source-dtype
Browse files Browse the repository at this point in the history
Check that source_bank datatype matches when reading from file
  • Loading branch information
gridley authored Feb 26, 2021
2 parents e2842cf + bd3ebd9 commit af22d18
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 0 deletions.
22 changes: 22 additions & 0 deletions src/state_point.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -698,13 +698,35 @@ write_source_bank(hid_t group_id, bool surf_source_bank)
H5Tclose(banktype);
}

// Determine member names of a compound HDF5 datatype
std::string dtype_member_names(hid_t dtype_id)
{
int nmembers = H5Tget_nmembers(dtype_id);
std::string names;
for (int i = 0; i < nmembers; i++) {
names = names.append(H5Tget_member_name(dtype_id, i));
if (i < nmembers - 1) names += ", ";
}
return names;
}

void read_source_bank(hid_t group_id, std::vector<Particle::Bank>& sites, bool distribute)
{
hid_t banktype = h5banktype();

// Open the dataset
hid_t dset = H5Dopen(group_id, "source_bank", H5P_DEFAULT);

// Make sure number of members matches
hid_t dtype = H5Dget_type(dset);
auto file_member_names = dtype_member_names(dtype);
auto bank_member_names = dtype_member_names(banktype);
if (file_member_names != bank_member_names) {
fatal_error(fmt::format("Source site attributes in file do not match what is "
"expected for this version of OpenMC. File attributes = ({}). Expected "
"attributes = ({})", file_member_names, bank_member_names));
}

hid_t dspace = H5Dget_space(dset);
hsize_t n_sites;
H5Sget_simple_extent_dims(dspace, &n_sites, nullptr);
Expand Down
33 changes: 33 additions & 0 deletions tests/unit_tests/test_source_file.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from random import random
import subprocess

import h5py
import numpy as np
Expand Down Expand Up @@ -43,3 +44,35 @@ def test_source_file(run_in_tmpdir):
assert np.all(arr['wgt'] == 1.0)
assert np.all(arr['delayed_group'] == 0)
assert np.all(arr['particle'] == 0)


def test_wrong_source_attributes(run_in_tmpdir):
# Create a source file with animal attributes
source_dtype = np.dtype([
('platypus', '<f8'),
('axolotl', '<f8'),
('narwhal', '<i4'),
])
arr = np.array([(1.0, 2.0, 3), (4.0, 5.0, 6), (7.0, 8.0, 9)], dtype=source_dtype)
with h5py.File('animal_source.h5', 'w') as fh:
fh.attrs['filetype'] = np.string_("source")
fh.create_dataset('source_bank', data=arr)

# Create a simple model that uses this lovely animal source
m = openmc.Material()
m.add_nuclide('U235', 0.02)
openmc.Materials([m]).export_to_xml()
s = openmc.Sphere(r=10.0, boundary_type='vacuum')
c = openmc.Cell(fill=m, region=-s)
openmc.Geometry([c]).export_to_xml()
settings = openmc.Settings()
settings.particles = 100
settings.batches = 10
settings.source = openmc.Source(filename='animal_source.h5')
settings.export_to_xml()

# When we run the model, it should error out with a message that includes
# the names of the wrong attributes
with pytest.raises(subprocess.CalledProcessError) as excinfo:
openmc.run()
assert 'platypus, axolotl, narwhal' in excinfo.value.output

0 comments on commit af22d18

Please sign in to comment.