Skip to content

Commit

Permalink
Add implementation for TGCCA with tau != 1 and separable = FALSE
Browse files Browse the repository at this point in the history
  • Loading branch information
GFabien committed Jan 12, 2024
1 parent 7e7d8be commit 9fc9c4e
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 1 deletion.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ S3method(block_project,sparse_block)
S3method(block_project,tensor_block)
S3method(block_update,block)
S3method(block_update,dual_block)
S3method(block_update,regularized_tensor_block)
S3method(block_update,tensor_block)
S3method(plot,rgcca)
S3method(plot,rgcca_bootstrap)
Expand Down
29 changes: 28 additions & 1 deletion R/block_init.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,34 @@ block_init.tensor_block <- function(x, init = "svd") {

#' @export
block_init.regularized_tensor_block <- function(x, init = "svd") {
NextMethod()
# Compute the highest singular value of the regularization matrix
p <- prod(dim(x$x)[-1])
if (p > x$n) {
x$M <- eigen(
pm(matrix(x$x, x$n), t(matrix(x$x, x$n)), na.rm = x$na.rm),
symmetric = TRUE, only.values = TRUE
)$values[1]
} else {
x$M <- eigen(
pm(t(matrix(x$x, x$n)), matrix(x$x, x$n), na.rm = x$na.rm),
symmetric = TRUE, only.values = TRUE
)$values[1]
}
x$M <- x$tau + (1 - x$tau) * x$M / x$N

# Initialize the factors and weights using the tau = 1 strategy
x <- NextMethod()

# Change weights to satisfy the constraints
x$weights <- x$weights / sqrt(x$M)
x$a <- x$a / sqrt(x$M)
x$Y <- x$Y / sqrt(x$M)
return(x)
}

#' @export
block_init.separable_regularized_tensor_block <- function(x, init = "svd") {
# Compute separable estimation of the regularization matrix
d <- length(dim(x$x)) - 1
x$M <- estimate_separable_covariance(x$x)
x$M <- lapply(x$M, function(y) {
Expand All @@ -77,8 +100,12 @@ block_init.separable_regularized_tensor_block <- function(x, init = "svd") {
inv = TRUE
)
})

# Make a change of variables
for (m in seq_len(d)) {
x$x <- mode_product(x$x, x$M[[m]], m = m + 1)
}

# Initialize the factors and weights using the tau = 1 strategy
NextMethod()
}
60 changes: 60 additions & 0 deletions R/block_update.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,63 @@ block_update.tensor_block <- function(x, grad) {
x$weights <- drop(x$weights) / norm(drop(x$weights), type = "2")
return(block_project(x))
}

#' @export
block_update.regularized_tensor_block <- function(x, grad) {
grad <- array(
pm(t(matrix(x$x, nrow = nrow(x$x))), grad, na.rm = x$na.rm),
dim = dim(x$x)[-1]
)
other_factors <- NULL
# Update factors
for (m in seq_along(dim(x$x)[-1])) {
grad_m <- matrix(
aperm(grad, c(m, seq_along(dim(grad))[-m])), nrow = dim(x$x)[m + 1]
)
grad_m <- grad_m %*% khatri_rao(
Reduce(khatri_rao, rev(x$factors[-seq_len(m)])), other_factors
) %*% diag(x$weights, nrow = x$rank)
if (m == x$mode_orth) {
SVD <- svd(grad_m, nu = x$rank, nv = x$rank)
x$factors[[m]] <- SVD$u %*% t(SVD$v)
} else {
x$factors[[m]] <- apply(grad_m, 2, function(y) y / norm(y, type = "2"))
}

other_factors <- khatri_rao(x$factors[[m]], other_factors)
}
# Update weights
u <- drop(t(other_factors) %*% as.vector(grad))

w_ref <- drop(ginv(
x$tau * diag(x$rank) + (1 - x$tau) * crossprod(
pm(matrix(x$x, x$n), other_factors, na.rm = x$na.rm)
) / x$N
) %*% u)
w_ref_norm <- w_ref / (norm(w_ref, type = "2") * sqrt(x$M))

w_opt <- u / (norm(u, type = "2") * sqrt(x$M))

eps <- 0.5 * drop(t(u) %*% (x$weights + w_opt))

# If w_ref is satisfying, keep w_ref, otherwise find a point that
# increases the criterion and satisfies the constraints between
# w_ref and w_opt
if (all(w_ref_norm == w_opt)) {
x$weights <- w_ref_norm
}
else if (t(u) %*% w_ref_norm >= eps) {
x$weights <- w_ref_norm
} else {
if (1 / x$M - eps^2 / crossprod(u) > .Machine$double.eps) {
Pu <- diag(x$rank) - tcrossprod(u / norm(u, type = "2"))
mu <- norm(Pu %*% w_ref, type = "2") / drop(sqrt(
1 / x$M - eps^2 / crossprod(u)
))
x$weights <- eps / drop(crossprod(u)) * u + drop(Pu %*% w_ref) / mu
} else { # collinearity between u and w_ref
x$weights <- w_opt
}
}
return(block_project(x))
}

0 comments on commit 9fc9c4e

Please sign in to comment.