Skip to content

Commit

Permalink
Allow evaluation and fitting on different points.
Browse files Browse the repository at this point in the history
For the users, it is sometimes more convenient to use the sampling
objects to evaluate the basis on arbitrary times and frequencies
rather than go through the basis.u() and basis.uhat() objects.  We
now allow that usecase, even if it may be a little slow.
  • Loading branch information
mwallerb committed Mar 24, 2024
1 parent 88388cf commit 741dfd6
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 5 deletions.
73 changes: 68 additions & 5 deletions src/sparse_ir/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,64 @@ class AbstractSampling:
|________________| fit |___________________|
"""
def evaluate(self, al, axis=None):
"""Evaluate the basis coefficients at the sparse sampling points"""
def evaluate(self, al, axis=None, *, points=None):
"""Evaluate the basis coefficients at sampling points.
Arguments:
al (array):
Array where the `l`-th item along `axis` corresponds to the
`l`-th basis coefficient
axis (integer):
Axis or dimension of `al` along which to evaluate the function.
Defaults to the last, i.e., rightmost axis.
points (vector):
Points on which the results should be evaluated. Defaults
to the sampling points for which the sampling objects was
created.
Return:
Array where the `n`-th item along `axis` corresponds to the
value on the `n`-th sampling point (or value on `point[n]`, if
given.)
Note:
If `points` is given, a new sampling is created at each invocation,
which can result in a performance hit. Consider caching sampling
objects or simply using the `.u()` and `.uhat()` methods of the
underlying basis.
"""
if points is not None:
return self._for_sampling_points(points).evaluate(al, axis)

return self.matrix.matmul(al, axis)

def fit(self, ax, axis=None):
"""Fit basis coefficients from the sparse sampling points"""
def fit(self, ax, axis=None, *, points=None):
"""Fit the basis coefficients from the sampling points.
Arguments:
ax (array):
Array where the `n`-th item along `axis` corresponds to the
value on the `n`-th sampling point (or value on `point[n]`, if
given.)
axis (integer):
Axis or dimension of `ax` along which to fit the function.
Defaults to the last, i.e., rightmost axis.
points (vector):
Points on which the `ax` is given. Defaults to the sampling
points for which the sampling objects was created.
Return:
Array where the `l`-th item along `axis` corresponds to the
`l`-th basis coefficient
Note:
If `points` is given, a new sampling is created at each invocation,
which can result in a performance hit. Consider caching sampling
objects.
"""
if points is not None:
return self._for_sampling_points(points).fit(ax, axis)

matrix = self.matrix
if self.basis.is_well_conditioned and not (matrix.cond <= 1e8):
warn(f"Sampling matrix is poorly conditioned "
Expand Down Expand Up @@ -53,6 +105,9 @@ def basis(self):
"""Basis instance"""
raise NotImplementedError()

def _for_sampling_points(self, x):
raise RuntimeError("Changing sampling points is not possible")


class TauSampling(AbstractSampling):
"""Sparse sampling in imaginary time.
Expand All @@ -73,7 +128,6 @@ def __init__(self, basis, sampling_points=None):
self._sampling_points = sampling_points
self._matrix = DecomposedMatrix(matrix)


@property
def basis(self): return self._basis

Expand All @@ -88,6 +142,10 @@ def tau(self):
"""Sampling points in (reduced) imaginary time"""
return self._sampling_points

def _for_sampling_points(self, x):
x = np.asarray(x)
return TauSampling(self._basis, x)


class MatsubaraSampling(AbstractSampling):
"""Sparse sampling in Matsubara frequencies.
Expand Down Expand Up @@ -146,6 +204,11 @@ def wn(self):
"""Sampling points as (reduced) Matsubara frequencies"""
return self._sampling_points

def _for_sampling_points(self, x):
x = np.asarray(x)
return MatsubaraSampling(self._basis, x,
positive_only=self._positive_only)


class DecomposedMatrix:
"""Matrix in SVD decomposed form for fast and accurate fitting.
Expand Down
33 changes: 33 additions & 0 deletions test/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,36 @@ def test_wn_noise(sve_logistic, stat, lambda_, positive_only):
Gl_n = smpl.fit(Giw_n)
np.testing.assert_allclose(Gl, Gl_n,
atol=12 * np.sqrt(1 + positive_only) * noise * Gl_magn, rtol=0)


@pytest.mark.parametrize("stat, lambda_", [('F', 42)])
@pytest.mark.parametrize("positive_only", [False, True])
def test_wn_eval_other(sve_logistic, stat, lambda_, positive_only):
basis = sparse_ir.FiniteTempBasis(stat, 1, lambda_,
sve_result=sve_logistic[lambda_])
smpl = sparse_ir.MatsubaraSampling(basis, positive_only=positive_only)

n2 = [1, 3, 7]
smpl2 = sparse_ir.MatsubaraSampling(basis, sampling_points=n2)

rhol = basis.v([+.998, -.01, .5]) @ [0.8, -.2, 0.5]
Gl = basis.s * rhol
Gl_magn = np.linalg.norm(Gl)
np.testing.assert_allclose(smpl.evaluate(Gl, points=n2), smpl2.evaluate(Gl),
rtol=1e-15 * Gl_magn)


@pytest.mark.parametrize("stat, lambda_", [('F', 42)])
def test_tau_eval_other(sve_logistic, stat, lambda_):
basis = sparse_ir.FiniteTempBasis(stat, 1, lambda_,
sve_result=sve_logistic[lambda_])
smpl = sparse_ir.TauSampling(basis)

n2 = (0.1, 0.4)
smpl2 = sparse_ir.TauSampling(basis, sampling_points=n2)

rhol = basis.v([+.998, -.01, .5]) @ [0.8, -.2, 0.5]
Gl = basis.s * rhol
Gl_magn = np.linalg.norm(Gl)
np.testing.assert_allclose(smpl.evaluate(Gl, points=n2), smpl2.evaluate(Gl),
rtol=1e-15 * Gl_magn)

0 comments on commit 741dfd6

Please sign in to comment.