diff --git a/src/sparse_ir/sampling.py b/src/sparse_ir/sampling.py index fac169a..ce60d17 100644 --- a/src/sparse_ir/sampling.py +++ b/src/sparse_ir/sampling.py @@ -20,12 +20,64 @@ class AbstractSampling: |________________| fit |___________________| """ - def evaluate(self, al, axis=None): - """Evaluate the basis coefficients at the sparse sampling points""" + def evaluate(self, al, axis=None, *, points=None): + """Evaluate the basis coefficients at sampling points. + + Arguments: + al (array): + Array where the `l`-th item along `axis` corresponds to the + `l`-th basis coefficient + axis (integer): + Axis or dimension of `al` along which to evaluate the function. + Defaults to the last, i.e., rightmost axis. + points (vector): + Points on which the results should be evaluated. Defaults + to the sampling points for which the sampling object was + created. + + Return: + Array where the `n`-th item along `axis` corresponds to the + value on the `n`-th sampling point (or value on `points[n]`, if + given.) + + Note: + If `points` is given, a new sampling is created at each invocation, + which can result in a performance hit. Consider caching sampling + objects or simply using the `.u()` and `.uhat()` methods of the + underlying basis. + """ + if points is not None: + return self._for_sampling_points(points).evaluate(al, axis) + return self.matrix.matmul(al, axis) - def fit(self, ax, axis=None): - """Fit basis coefficients from the sparse sampling points""" + def fit(self, ax, axis=None, *, points=None): + """Fit the basis coefficients from the sampling points. + + Arguments: + ax (array): + Array where the `n`-th item along `axis` corresponds to the + value on the `n`-th sampling point (or value on `points[n]`, if + given.) + axis (integer): + Axis or dimension of `ax` along which to fit the function. + Defaults to the last, i.e., rightmost axis. + points (vector): + Points on which the `ax` is given. 
Defaults to the sampling + points for which the sampling object was created. + + Return: + Array where the `l`-th item along `axis` corresponds to the + `l`-th basis coefficient + + Note: + If `points` is given, a new sampling is created at each invocation, + which can result in a performance hit. Consider caching sampling + objects. + """ + if points is not None: + return self._for_sampling_points(points).fit(ax, axis) + matrix = self.matrix if self.basis.is_well_conditioned and not (matrix.cond <= 1e8): warn(f"Sampling matrix is poorly conditioned " @@ -53,6 +105,9 @@ def basis(self): """Basis instance""" raise NotImplementedError() + def _for_sampling_points(self, x): + raise RuntimeError("Changing sampling points is not possible") + class TauSampling(AbstractSampling): """Sparse sampling in imaginary time. @@ -73,7 +128,6 @@ def __init__(self, basis, sampling_points=None): self._sampling_points = sampling_points self._matrix = DecomposedMatrix(matrix) - @property def basis(self): return self._basis @@ -88,6 +142,10 @@ def tau(self): """Sampling points in (reduced) imaginary time""" return self._sampling_points + def _for_sampling_points(self, x): + x = np.asarray(x) + return TauSampling(self._basis, x) + class MatsubaraSampling(AbstractSampling): """Sparse sampling in Matsubara frequencies. @@ -146,6 +204,11 @@ def wn(self): """Sampling points as (reduced) Matsubara frequencies""" return self._sampling_points + def _for_sampling_points(self, x): + x = np.asarray(x) + return MatsubaraSampling(self._basis, x, + positive_only=self._positive_only) + class DecomposedMatrix: """Matrix in SVD decomposed form for fast and accurate fitting. 
diff --git a/test/test_sampling.py b/test/test_sampling.py index bdf6eae..d47c613 100644 --- a/test/test_sampling.py +++ b/test/test_sampling.py @@ -100,3 +100,36 @@ def test_wn_noise(sve_logistic, stat, lambda_, positive_only): Gl_n = smpl.fit(Giw_n) np.testing.assert_allclose(Gl, Gl_n, atol=12 * np.sqrt(1 + positive_only) * noise * Gl_magn, rtol=0) + + +@pytest.mark.parametrize("stat, lambda_", [('F', 42)]) +@pytest.mark.parametrize("positive_only", [False, True]) +def test_wn_eval_other(sve_logistic, stat, lambda_, positive_only): + basis = sparse_ir.FiniteTempBasis(stat, 1, lambda_, + sve_result=sve_logistic[lambda_]) + smpl = sparse_ir.MatsubaraSampling(basis, positive_only=positive_only) + + n2 = [1, 3, 7] + smpl2 = sparse_ir.MatsubaraSampling(basis, sampling_points=n2) + + rhol = basis.v([+.998, -.01, .5]) @ [0.8, -.2, 0.5] + Gl = basis.s * rhol + Gl_magn = np.linalg.norm(Gl) + np.testing.assert_allclose(smpl.evaluate(Gl, points=n2), smpl2.evaluate(Gl), + rtol=1e-15 * Gl_magn) + + +@pytest.mark.parametrize("stat, lambda_", [('F', 42)]) +def test_tau_eval_other(sve_logistic, stat, lambda_): + basis = sparse_ir.FiniteTempBasis(stat, 1, lambda_, + sve_result=sve_logistic[lambda_]) + smpl = sparse_ir.TauSampling(basis) + + n2 = (0.1, 0.4) + smpl2 = sparse_ir.TauSampling(basis, sampling_points=n2) + + rhol = basis.v([+.998, -.01, .5]) @ [0.8, -.2, 0.5] + Gl = basis.s * rhol + Gl_magn = np.linalg.norm(Gl) + np.testing.assert_allclose(smpl.evaluate(Gl, points=n2), smpl2.evaluate(Gl), + rtol=1e-15 * Gl_magn)