Skip to content

Commit

Permalink
Gp mc cov fix (#57)
Browse files Browse the repository at this point in the history
* gp_mc_cov_fix do off-block diags

* bugfix: set smoothed_mc_cov_anaflat[sel] = np.diag(smoothed_block_diag), not smoothed_mc_cov_anaflat[sel] = smoothed_block_diag

* remove deprecated correct_analytical_cov_eigenspectrum_ratio_gp

* doc update
  • Loading branch information
zatkins2 authored Oct 22, 2024
1 parent c0d7b12 commit ddbd63e
Showing 1 changed file with 52 additions and 36 deletions.
88 changes: 52 additions & 36 deletions pspipe_utils/covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,21 +368,21 @@ def skew(cov, dir=1):
return corrected_cov



def correct_analytical_cov_eigenspectrum_ratio_gp(lb, an_full_cov, mc_full_cov,
var_eigenspectrum_ratios_by_block=None,
idx_arrs_by_block=None, return_all=False):
def correct_analytical_cov_block_diag_gp(lb, an_full_cov, mc_full_cov,
var_mc_cov_anaflat=None,
idx_arrs_by_block=None, return_all=False):
"""Correct an analytical covariance matrix with a monte carlo covariance
matrix. Assumes the following rotated monte carlo matrix is diagonal:
matrix. Assumes the following rotated monte carlo matrix is diagonal
in each cross-block:
mc_rot = (ana**-.5) @ mc @ (ana**-.5).T
and then fits the diagonal in that basis with a Gaussian process (GP) on
the observed values:
and then fits the cross-block diagonals in that basis with a Gaussian process
(GP) on the observed values:
ana_corrected = (ana**.5) @ np.diag(GP(np.diag(mc_rot))) @ (ana**.5).T
ana_corrected = (ana**.5) @ GP(mc_rot) @ (ana**.5).T
The GP is applied to each block diagonal separately.
The GP is applied to each diagonal separately.
Parameters
----------
Expand All @@ -392,10 +392,10 @@ def correct_analytical_cov_eigenspectrum_ratio_gp(lb, an_full_cov, mc_full_cov,
Analytic covariance matrix to be corrected.
mc_full_cov : (nblock*nbin, nblock*nbin) np.ndarray
Noisy monte carlo matrix to use to correct the analytic matrix.
var_eigenspectrum_ratios_by_block : (nblock, nbin) np.ndarray, optional
The variance of the diagonal elements of mc_rot, by default None. These
var_mc_cov_anaflat : (nblock*nbin, nblock*nbin) np.ndarray, optional
The variance of the elements of mc_rot, by default None. These
would need to be precomputed in the rotated basis. If None, the
noise level in the diagonal elements of mc_rot is estimated from the
noise level in the elements of mc_rot is estimated from the
elements themselves by the Gaussian process.
idx_arrs_by_block : (nblock,) list, optional
A list of np.ndarrays, each of which may have between 0 and nbin elements,
Expand All @@ -407,51 +407,67 @@ def correct_analytical_cov_eigenspectrum_ratio_gp(lb, an_full_cov, mc_full_cov,
observed values are used without any smoothing applied to them.
return_all : bool, optional
If True, in addition to returning ana_corrected, also return the
diagonal of the mc_rot matrix, the smoothed diagonal, and a list
mc_rot matrix, the smoothed mc_rot matrix, and a dict
of the Gaussian process regression objects for each block,
by default False.
Returns
-------
(nblock*nbin, nblock*nbin) np.ndarray, {(nblock*nbin,) np.ndarray,
(nblock*nbin,) np.ndarray, (nblock,) list of sklearn.GaussianProcessRegressor
objects}
ana_corrected as above, and if return_all, np.diag(mc_rot),
GP(np.diag(mc_rot)), and a list of the GPs for each block. If for any
(nblock*nbin, nblock*nbin) np.ndarray, {(nblock*nbin, nblock*nbin) np.ndarray,
(nblock*nbin, nblock*nbin) np.ndarray, dict[(nblock, nblock)] of
sklearn.GaussianProcessRegressor objects}
ana_corrected as above, and if return_all, mc_rot,
GP(mc_rot), and a dict of the GPs for each block. If for any
block the idx_arrs_by_block array is empty, the returned GP is None
since no GP was actually used for that block.
"""
sqrt_an_full_cov = utils.eigpow(an_full_cov, 0.5)
inv_sqrt_an_full_cov = np.linalg.inv(sqrt_an_full_cov)
res = inv_sqrt_an_full_cov @ mc_full_cov @ inv_sqrt_an_full_cov.T # res should be close to the identity if an_full_cov is good
res_diag = np.diag(res)

n_spec = len(res_diag) // len(lb)

res_diags = np.split(res_diag, n_spec)

if var_eigenspectrum_ratios_by_block is None:
var_eigenspectrum_ratios_by_block = [None] * n_spec
# mc_cov_anaflat should be close to the identity if an_full_cov is good
mc_cov_anaflat = inv_sqrt_an_full_cov @ mc_full_cov @ inv_sqrt_an_full_cov.T
n_spec = mc_cov_anaflat.shape[0] // len(lb)
n_bins = len(lb)

if idx_arrs_by_block is None:
idx_arrs_by_block = [None] * n_spec

smoothed_res_diags = []
gprs = []
# iterate over each block diagonal
smoothed_mc_cov_anaflat = np.zeros_like(mc_cov_anaflat)
gprs = {}
for i in range(n_spec):
smoothed_gp_diag, gpr = smooth_gp_diag(lb, res_diags[i], var_eigenspectrum_ratios_by_block[i],
idx_arr=idx_arrs_by_block[i], return_gpr=True)
smoothed_res_diags.append(smoothed_gp_diag)
gprs.append(gpr)
for j in range(i, n_spec):
sel = np.s_[i*n_bins:(i+1)*n_bins, j*n_bins:(j+1)*n_bins]

# diag of this block and var of this block (None if not supplied)
block_diag = np.diag(mc_cov_anaflat[sel])
if var_mc_cov_anaflat is not None:
var_block_diag = np.diag(var_mc_cov_anaflat[sel])
else:
var_block_diag = None

# which idxs to use
idxs_i = idx_arrs_by_block[i]
idxs_j = idx_arrs_by_block[j]
if idxs_i is not None and idxs_j is not None:
idxs = np.intersect1d(idxs_i, idxs_j)
elif idxs_i is None:
idxs = idxs_j # idxs_i is "all idxs" so use idxs_j (which might also be "all idxs")
else:
idxs = idxs_i # idxs_j is "all idxs" so use idxs_i (which might also be "all idxs")

smoothed_res_diag = np.hstack(smoothed_res_diags)
smoothed_block_diag, gpr = smooth_gp_diag(lb, block_diag, var_block_diag,
idx_arr=idxs, return_gpr=True)
smoothed_mc_cov_anaflat[sel] = np.diag(smoothed_block_diag)
smoothed_mc_cov_anaflat.T[sel] = np.diag(smoothed_block_diag)
gprs[i, j] = gpr

corrected_cov = sqrt_an_full_cov @ np.diag(smoothed_res_diag) @ sqrt_an_full_cov.T
corrected_mc_cov = sqrt_an_full_cov @ smoothed_mc_cov_anaflat @ sqrt_an_full_cov.T

if return_all:
return corrected_cov, res_diag, smoothed_res_diag, gprs
return corrected_mc_cov, mc_cov_anaflat, smoothed_mc_cov_anaflat, gprs
else:
return corrected_cov
return corrected_mc_cov


def smooth_gp_diag(lb, arr_diag, var_diag=None, idx_arr=None,
Expand Down

0 comments on commit ddbd63e

Please sign in to comment.