Skip to content

Commit

Permalink
fix LeastSquares.visualize for models that accept parameter array (#968)
Browse files Browse the repository at this point in the history
Closes #966 

`LeastSquares.model` can now be used to uniformly call models which
accept parameter values and parameter arrays in the same way, as already
indicated by the docstring.
  • Loading branch information
HDembinski authored Feb 8, 2024
1 parent 4945271 commit dce1f63
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 9 deletions.
14 changes: 11 additions & 3 deletions src/iminuit/cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -1830,7 +1830,7 @@ def prediction(self, args: Sequence[float]) -> Tuple[NDArray, NDArray]:
Returns
-------
y, yerr : NDArray, np.array
y, yerr : NDArray, NDArray
Template prediction and its standard deviation, based on the statistical
uncertainty of the template only.
"""
Expand Down Expand Up @@ -2104,7 +2104,12 @@ def yerror(self, value):
@property
def model(self):
"""Get model of the form y = f(x, par0, [par1, ...])."""
return self._model
if len(self._parameters) == 1:
return lambda x, *args: (
self._model(x, args) if len(args) > 1 else self._model(x, *args)
)
else:
return self._model

@property
def loss(self):
Expand Down Expand Up @@ -2198,7 +2203,9 @@ def __init__(
def _ndata(self):
return len(self._masked)

def visualize(self, args: ArrayLike, model_points: Union[int, Sequence[float]] = 0):
def visualize(
self, args: ArrayLike, model_points: Union[int, Sequence[float]] = 0
) -> Tuple[Tuple[NDArray, NDArray, NDArray], Tuple[NDArray, NDArray]]:
"""
Visualize data and model agreement (requires matplotlib).
Expand Down Expand Up @@ -2233,6 +2240,7 @@ def visualize(self, args: ArrayLike, model_points: Union[int, Sequence[float]] =
else:
xm, ym = _smart_sampling(lambda x: self.model(x, *args), x[0], x[-1])
plt.plot(xm, ym)
return (x, y, ye), (xm, ym)

def prediction(self, args: Sequence[float]) -> NDArray:
"""
Expand Down
33 changes: 27 additions & 6 deletions tests/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -1139,13 +1139,14 @@ def test_LeastSquares_mask_2():

def test_LeastSquares_properties():
def model(x, a):
return a
return x + 2 * a

c = LeastSquares(1, 2, 3, model)
assert_equal(c.x, [1])
assert_equal(c.y, [2])
assert_equal(c.yerror, [3])
assert c.model is model
assert c.model(1, 1) == model(1, 1)
assert c.model(2, 3) == model(2, 3)
with pytest.raises(AttributeError):
c.model = model
with pytest.raises(ValueError):
Expand All @@ -1162,14 +1163,34 @@ def test_LeastSquares_visualize():
c = LeastSquares([1, 2], [2, 3], 0.1, line)

# auto-sampling
c.visualize((1, 2))
(x, y, ye), (xm, ym) = c.visualize((1, 2))
assert_equal(x, (1, 2))
assert_equal(y, (2, 3))
assert_equal(ye, 0.1)
assert len(xm) < 10
# linear spacing
c.visualize((1, 2), model_points=10)
(x, y, ye), (xm, ym) = c.visualize((1, 2), model_points=10)
assert len(xm) == 10
assert_allclose(xm[1:] - xm[:-1], xm[1] - xm[0])
# trigger use of log-spacing
c = LeastSquares([1, 10, 100], [2, 3, 4], 0.1, line)
c.visualize((1, 2), model_points=10)
(x, y, ye), (xm, ym) = c.visualize((1, 2), model_points=10)
assert len(xm) == 10
assert_allclose(xm[1:] / xm[:-1], xm[1] / xm[0])
# manual spacing
c.visualize((1, 2), model_points=np.linspace(1, 100))
(x, y, ye), (xm, ym) = c.visualize((1, 2), model_points=np.linspace(1, 100))
assert_equal(xm, np.linspace(1, 100))


def test_LeastSquares_visualize_par_array():
pytest.importorskip("matplotlib")

def line(x, par):
return par[0] + par[1] * x

c = LeastSquares([1, 2], [2, 3], 0.1, line)

c.visualize((1, 2))


def test_LeastSquares_visualize_2D():
Expand Down

0 comments on commit dce1f63

Please sign in to comment.