Skip to content

Commit

Permalink
Fix Tabular.from_xml_element for histogram case (#3287)
Browse files Browse the repository at this point in the history
  • Loading branch information
paulromano authored Feb 4, 2025
1 parent 59c398b commit 6e0f156
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 13 deletions.
5 changes: 3 additions & 2 deletions openmc/stats/univariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,8 +1116,9 @@ def from_xml_element(cls, elem: ET.Element):
"""
interpolation = get_text(elem, 'interpolation')
params = [float(x) for x in get_text(elem, 'parameters').split()]
x = params[:len(params)//2]
p = params[len(params)//2:]
m = (len(params) + 1)//2 # +1 for when len(params) is odd
x = params[:m]
p = params[m:]
return cls(x, p, interpolation)

def integral(self):
Expand Down
6 changes: 5 additions & 1 deletion src/distribution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,12 @@ Tabular::Tabular(pugi::xml_node node)
interp_ = Interpolation::histogram;
}

// Read and initialize tabular distribution
// Read and initialize tabular distribution. If number of parameters is odd,
// add an extra zero for the 'p' array.
auto params = get_node_array<double>(node, "parameters");
if (params.size() % 2 != 0) {
params.push_back(0.0);
}
std::size_t n = params.size() / 2;
const double* x = params.data();
const double* p = x + n;
Expand Down
33 changes: 23 additions & 10 deletions tests/unit_tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,19 +195,10 @@ def test_watt():


def test_tabular():
# test linear-linear sampling
x = np.array([0.0, 5.0, 7.0, 10.0])
p = np.array([10.0, 20.0, 5.0, 6.0])
d = openmc.stats.Tabular(x, p, 'linear-linear')
elem = d.to_xml_element('distribution')

d = openmc.stats.Tabular.from_xml_element(elem)
assert all(d.x == x)
assert all(d.p == p)
assert d.interpolation == 'linear-linear'
assert len(d) == len(x)

# test linear-linear sampling
d = openmc.stats.Tabular(x, p)
n_samples = 100_000
samples = d.sample(n_samples)
assert_sample_mean(samples, d.mean())
Expand Down Expand Up @@ -242,6 +233,28 @@ def test_tabular():
d.cdf()


def test_tabular_from_xml():
x = np.array([0.0, 5.0, 7.0, 10.0])
p = np.array([10.0, 20.0, 5.0, 6.0])
d = openmc.stats.Tabular(x, p, 'linear-linear')
elem = d.to_xml_element('distribution')

d = openmc.stats.Tabular.from_xml_element(elem)
assert all(d.x == x)
assert all(d.p == p)
assert d.interpolation == 'linear-linear'
assert len(d) == len(x)

# Make sure XML roundtrip works with len(x) == len(p) + 1
x = np.array([0.0, 5.0, 7.0, 10.0])
p = np.array([10.0, 20.0, 5.0])
d = openmc.stats.Tabular(x, p, 'histogram')
elem = d.to_xml_element('distribution')
d = openmc.stats.Tabular.from_xml_element(elem)
assert all(d.x == x)
assert all(d.p == p)


def test_legendre():
# Pu239 elastic scattering at 100 keV
coeffs = [1.000e+0, 1.536e-1, 1.772e-2, 5.945e-4, 3.497e-5, 1.881e-5]
Expand Down

0 comments on commit 6e0f156

Please sign in to comment.