diff --git a/src/iminuit/cost.py b/src/iminuit/cost.py index 49611ee4..1bb943e9 100644 --- a/src/iminuit/cost.py +++ b/src/iminuit/cost.py @@ -1830,7 +1830,7 @@ def prediction(self, args: Sequence[float]) -> Tuple[NDArray, NDArray]: Returns ------- - y, yerr : NDArray, np.array + y, yerr : NDArray, NDArray Template prediction and its standard deviation, based on the statistical uncertainty of the template only. """ @@ -2104,7 +2104,12 @@ def yerror(self, value): @property def model(self): """Get model of the form y = f(x, par0, [par1, ...]).""" - return self._model + if len(self._parameters) == 1: + return lambda x, *args: ( + self._model(x, args) if len(args) > 1 else self._model(x, *args) + ) + else: + return self._model @property def loss(self): @@ -2198,7 +2203,9 @@ def __init__( def _ndata(self): return len(self._masked) - def visualize(self, args: ArrayLike, model_points: Union[int, Sequence[float]] = 0): + def visualize( + self, args: ArrayLike, model_points: Union[int, Sequence[float]] = 0 + ) -> Tuple[Tuple[NDArray, NDArray, NDArray], Tuple[NDArray, NDArray]]: """ Visualize data and model agreement (requires matplotlib). @@ -2233,6 +2240,7 @@ def visualize(self, args: ArrayLike, model_points: Union[int, Sequence[float]] = else: xm, ym = _smart_sampling(lambda x: self.model(x, *args), x[0], x[-1]) plt.plot(xm, ym) + return (x, y, ye), (xm, ym) def prediction(self, args: Sequence[float]) -> NDArray: """ diff --git a/tests/test_cost.py b/tests/test_cost.py index ac9b1fd6..77b17bb9 100644 --- a/tests/test_cost.py +++ b/tests/test_cost.py @@ -1139,13 +1139,14 @@ def test_LeastSquares_mask_2(): def test_LeastSquares_properties(): def model(x, a): - return a + return x + 2 * a c = LeastSquares(1, 2, 3, model) assert_equal(c.x, [1]) assert_equal(c.y, [2]) assert_equal(c.yerror, [3]) - assert c.model is model + assert c.model(1, 1) == model(1, 1) + assert c.model(2, 3) == model(2, 3) with pytest.raises(AttributeError): c.model = model with pytest.raises(ValueError): @@ -1162,14 +1163,34 @@ def test_LeastSquares_visualize(): c = LeastSquares([1, 2], [2, 3], 0.1, line) # auto-sampling - c.visualize((1, 2)) + (x, y, ye), (xm, ym) = c.visualize((1, 2)) + assert_equal(x, (1, 2)) + assert_equal(y, (2, 3)) + assert_equal(ye, 0.1) + assert len(xm) < 10 # linear spacing - c.visualize((1, 2), model_points=10) + (x, y, ye), (xm, ym) = c.visualize((1, 2), model_points=10) + assert len(xm) == 10 + assert_allclose(xm[1:] - xm[:-1], xm[1] - xm[0]) # trigger use of log-spacing c = LeastSquares([1, 10, 100], [2, 3, 4], 0.1, line) - c.visualize((1, 2), model_points=10) + (x, y, ye), (xm, ym) = c.visualize((1, 2), model_points=10) + assert len(xm) == 10 + assert_allclose(xm[1:] / xm[:-1], xm[1] / xm[0]) # manual spacing - c.visualize((1, 2), model_points=np.linspace(1, 100)) + (x, y, ye), (xm, ym) = c.visualize((1, 2), model_points=np.linspace(1, 100)) + assert_equal(xm, np.linspace(1, 100)) + + +def test_LeastSquares_visualize_par_array(): + pytest.importorskip("matplotlib") + + def line(x, par): + return par[0] + par[1] * x + + c = LeastSquares([1, 2], [2, 3], 0.1, line) + + c.visualize((1, 2)) def test_LeastSquares_visualize_2D():