Skip to content

Commit

Permalink
Merge pull request #584 from Sand-jrd/temporal-pca-local
Browse files Browse the repository at this point in the history
Annular Temporal PCA
  • Loading branch information
VChristiaens authored Mar 30, 2023
2 parents 4114d47 + 8b45d22 commit 3282af8
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 27 deletions.
5 changes: 5 additions & 0 deletions tests/test_pipeline_adi.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ def algo_pca_annular(ds):
return vip.psfsub.pca_annular(ds.cube, ds.angles, fwhm=ds.fwhm,
n_segments='auto')

def algo_pca_annular_left_eigv(ds):
return vip.psfsub.pca_annular(ds.cube, ds.angles, fwhm=ds.fwhm,
n_segments='auto', left_eigv=True)

def algo_pca_annular_auto(ds):
return vip.psfsub.pca_annular(ds.cube, ds.angles, fwhm=ds.fwhm,
ncomp='auto')
Expand Down Expand Up @@ -173,6 +177,7 @@ def verify_expcoord(vectory, vectorx, exp_yx):
(algo_pca_grid, snrmap_fast),
(algo_pca_incremental, snrmap_fast),
(algo_pca_annular, snrmap_fast),
(algo_pca_annular_left_eigv, snrmap_fast),
(algo_pca_annular_auto, snrmap_fast),
],
ids=lambda x: (x.__name__.replace("algo_", "") if callable(x) else x))
Expand Down
80 changes: 57 additions & 23 deletions vip_hci/psfsub/pca_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def pca_annular(
cube_sig=None,
full_output=False,
verbose=True,
left_eigv=False,
**rot_options
):
"""PCA model PSF subtraction for ADI, ADI+RDI or ADI+mSDI (IFS) data.
Expand Down Expand Up @@ -244,6 +245,12 @@ def pca_annular(
global start_time
start_time = time_ini()

if left_eigv :
if (cube_ref is not None) or (cube_sig is not None) or (ncomp=='auto'):
raise NotImplementedError( "left_eigv is not compatible"
"with 'cube_ref', 'cube_sig', ncomp='auto'"
)

# ADI or ADI+RDI data
if cube.ndim == 3:
res = _pca_adi_rdi(
Expand All @@ -270,6 +277,7 @@ def pca_annular(
theta_init,
weights,
cube_sig,
left_eigv,
**rot_options
)

Expand Down Expand Up @@ -328,6 +336,7 @@ def pca_annular(
theta_init,
weights,
cube_sig,
left_eigv,
**rot_options
)
cube_out.append(res_pca[0])
Expand Down Expand Up @@ -442,6 +451,7 @@ def pca_annular(
theta_init,
weights,
cube_sig,
left_eigv=left_eigv,
**rot_options
)
if full_output:
Expand Down Expand Up @@ -595,6 +605,7 @@ def _pca_adi_rdi(
theta_init=0,
weights=None,
cube_sig=None,
left_eigv=False,
**rot_options
):
"""PCA exploiting angular variability (ADI fashion)."""
Expand All @@ -604,7 +615,7 @@ def _pca_adi_rdi(
if array.shape[0] != angle_list.shape[0]:
raise TypeError("Input vector or parallactic angles has wrong length")

n, y, _ = array.shape
n, y, x = array.shape

angle_list = check_pa_vector(angle_list)
n_annuli = int((y / 2 - radius_int) / asize)
Expand Down Expand Up @@ -665,6 +676,13 @@ def _pca_adi_rdi(
indices = get_annulus_segments(
array[0], inner_radius, asize, n_segments_ann, theta_init
)

if left_eigv :
indices_out = get_annulus_segments(array[0], inner_radius, asize,
n_segments_ann, theta_init,
out=True
)

# Library matrix is created for each segment and scaled if needed
for j in range(n_segments_ann):
yy = indices[j][0]
Expand All @@ -681,28 +699,44 @@ def _pca_adi_rdi(
else:
matrix_sig_segm = None

res = pool_map(
nproc,
do_pca_patch,
matrix_segm,
iterable(range(n)),
angle_list,
fwhm,
pa_thr,
ann_center,
svd_mode,
ncompann,
min_frames_lib,
max_frames_lib,
tol,
matrix_segm_ref,
matrix_sig_segm,
)

res = np.array(res, dtype=object)
residuals = np.array(res[:, 0])
ncomps = res[:, 1]
nfrslib = res[:, 2]
if not left_eigv:
res = pool_map(
nproc,
do_pca_patch,
matrix_segm,
iterable(range(n)),
angle_list,
fwhm,
pa_thr,
ann_center,
svd_mode,
ncompann,
min_frames_lib,
max_frames_lib,
tol,
matrix_segm_ref,
matrix_sig_segm,
)

res = np.array(res, dtype=object)
residuals = np.array(res[:, 0])
ncomps = res[:, 1]
nfrslib = res[:, 2]
else:
yy_out = indices_out[j][0]
xx_out = indices_out[j][1]
matrix_out_segm = array[:, yy_out, xx_out] # shape [nframes x npx_out_segment]
matrix_out_segm = matrix_scaling(matrix_out_segm, scaling)

V = get_eigenvectors(
ncomp, matrix_out_segm, svd_mode, noise_error=tol, left_eigv=True
)

transformed = np.dot(V, matrix_segm.T)
reconstructed = np.dot(transformed.T, V)
residuals = matrix_segm - reconstructed
nfrslib = matrix_out_segm.shape[0]

for fr in range(n):
cube_out[fr][yy, xx] = residuals[fr]

Expand Down
4 changes: 2 additions & 2 deletions vip_hci/psfsub/svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ def svd_wrapper(matrix, mode, ncomp, verbose, full_output=False,

def get_eigenvectors(ncomp, data, svd_mode, mode='noise', noise_error=1e-3,
cevr=0.9, max_evs=None, data_ref=None, debug=False,
collapse=False, scaling=None):
collapse=False, scaling=None, left_eigv=False):
""" Getting ``ncomp`` eigenvectors. Choosing the size of the PCA truncation
when ``ncomp`` is set to ``auto``. Used in ``pca_annular`` and ``llsg``.
"""
Expand Down Expand Up @@ -683,7 +683,7 @@ def get_eigenvectors(ncomp, data, svd_mode, mode='noise', noise_error=1e-3,
else:
# Performing SVD/PCA according to "svd_mode" flag
ncomp = min(ncomp, min(data_ref.shape[0], data_ref.shape[1]))
V = svd_wrapper(data_ref, svd_mode, ncomp, verbose=False)
V = svd_wrapper(data_ref, svd_mode, ncomp, verbose=False, left_eigv=False)

return V

Expand Down
7 changes: 5 additions & 2 deletions vip_hci/var/shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def distance(yc, xc, y1, x1):


def get_annulus_segments(data, inner_radius, width, nsegm=1, theta_init=0,
optim_scale_fact=1, mode="ind"):
optim_scale_fact=1, mode="ind", out=False):
"""
Return indices or values in segments of a centered annulus.
Expand All @@ -415,7 +415,8 @@ def get_annulus_segments(data, inner_radius, width, nsegm=1, theta_init=0,
mode : {'ind', 'val', 'mask'}, optional
Controls what is returned: indices of selected pixels, values of
selected pixels, or a boolean mask.
out : bool; optional
Return all indices or values outside the centered annulus.
Returns
-------
indices : list of ndarrays
Expand Down Expand Up @@ -484,6 +485,8 @@ def get_annulus_segments(data, inner_radius, width, nsegm=1, theta_init=0,
masks.append((rad >= inner_radius) & (rad < outer_radius) &
(phirot >= phi_start) & (phirot < phi_end))

if out: mask = ~np.array(masks)

if mode == "ind":
return [np.where(mask) for mask in masks]
elif mode == "val":
Expand Down

0 comments on commit 3282af8

Please sign in to comment.