Skip to content

Commit

Permalink
Merge pull request #2375 from paulromano/tabular-mean-fix
Browse files Browse the repository at this point in the history
Don't call normalize inside Tabular.mean
  • Loading branch information
pshriwise authored Feb 4, 2023
2 parents 5c672f5 + dbbad68 commit 6218bec
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
11 changes: 7 additions & 4 deletions openmc/stats/univariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,7 +857,6 @@ def mean(self):
'or linear-linear interpolation.')
if self.interpolation == 'linear-linear':
mean = 0.0
self.normalize()
for i in range(1, len(self.x)):
y_min = self.p[i-1]
y_max = self.p[i]
Expand All @@ -872,9 +871,13 @@ def mean(self):
mean += exp_val

elif self.interpolation == 'histogram':
mean = 0.5 * (self.x[:-1] + self.x[1:])
mean *= np.diff(self.cdf())
mean = sum(mean)
x_l = self.x[:-1]
x_r = self.x[1:]
p_l = self.p[:-1]
mean = (0.5 * (x_l + x_r) * (x_r - x_l) * p_l).sum()

# Normalize for when integral of distribution is not 1
mean /= self.integral()

return mean

Expand Down
11 changes: 7 additions & 4 deletions tests/unit_tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def test_watt():

def test_tabular():
x = np.array([0.0, 5.0, 7.0])
p = np.array([0.1, 0.2, 0.05])
p = np.array([10.0, 20.0, 5.0])
d = openmc.stats.Tabular(x, p, 'linear-linear')
elem = d.to_xml_element('distribution')

Expand All @@ -178,19 +178,22 @@ def test_tabular():

# test linear-linear sampling
d = openmc.stats.Tabular(x, p)

n_samples = 100_000
samples = d.sample(n_samples)
assert_sample_mean(samples, d.mean())

# test histogram sampling
d = openmc.stats.Tabular(x, p, interpolation='histogram')
# test linear-linear normalization
d.normalize()
assert d.integral() == pytest.approx(1.0)

# test histogram sampling
d = openmc.stats.Tabular(x, p, interpolation='histogram')
samples = d.sample(n_samples)
assert_sample_mean(samples, d.mean())

d.normalize()
assert d.integral() == pytest.approx(1.0)


def test_legendre():
# Pu239 elastic scattering at 100 keV
Expand Down

0 comments on commit 6218bec

Please sign in to comment.