Skip to content

Commit

Permalink
Now use n * p variational parameters for the spherical covariance cas…
Browse files Browse the repository at this point in the history
…e, fixing #79
  • Loading branch information
jchiquet committed Sep 22, 2021
1 parent aaf6d34 commit 06be69a
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 119 deletions.
2 changes: 1 addition & 1 deletion R/PLNLDAfit-class.R
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ PLNLDAfit <- R6Class(
covariates <- cbind(covariates, model.matrix( ~ grouping + 0))
super$postTreatment(responses, covariates, offsets)
rownames(private$B) <- colnames(private$B) <- colnames(responses)
if (private$covariance != "spherical") colnames(private$S2) <- 1:self$q
colnames(private$S2) <- 1:self$q
self$setVisualization()
},

Expand Down
5 changes: 2 additions & 3 deletions R/PLNfit-class.R
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ PLNfit <- R6Class(
private$Theta <- do.call(rbind, lapply(LMs, coefficients))
residuals <- do.call(cbind, lapply(LMs, residuals))
private$M <- residuals
private$S2 <- matrix(0.1, n, ifelse(control$covariance == "spherical", 1, p))
private$S2 <- matrix(0.1,n,p)
if (control$covariance == "spherical") {
private$Sigma <- diag(sum(residuals^2)/(n*p), p, p)
} else if (control$covariance == "diagonal") {
Expand Down Expand Up @@ -173,8 +173,7 @@ PLNfit <- R6Class(

## Initialize the variational parameters with the appropriate new dimension of the data
optim_out <- VEstep_optimizer(
list(M = matrix(0, n, p),
S = matrix(sqrt(0.1), n, ifelse(self$vcov_model == "spherical", 1, p))),
list(M = matrix(0, n, p), S = matrix(sqrt(0.1), n, p)),
responses,
covariates,
offsets,
Expand Down
6 changes: 1 addition & 5 deletions R/PLNmixturefamily-class.R
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,7 @@ PLNmixturefamily <-
myPLN <- PLNfit$new(responses, covariates, offsets, rep(1, nrow(responses)), formula, xlevels, control)
myPLN$optimize(responses, covariates, offsets, rep(1, nrow(responses)), control)

if(control$covariance == 'spherical')
Sbar <- c(myPLN$var_par$S2) * myPLN$p
else
Sbar <- rowSums(myPLN$var_par$S2)

Sbar <- rowSums(myPLN$var_par$S2)
D <- sqrt(as.matrix(dist(myPLN$var_par$M)^2) + outer(Sbar,rep(1,myPLN$n)) + outer(rep(1, myPLN$n), Sbar))

if (is.numeric(control$init_cl)) {
Expand Down
5 changes: 2 additions & 3 deletions R/PLNmixturefit-class.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ PLNmixturefit <-

M <- private$comp %>% map("var_par") %>% map("M")
S2 <- private$comp %>% map("var_par") %>% map("S2")
if(private$covariance == "spherical") S2 <- map(S2, ~outer(as.numeric(.x), rep(1, self$p) ))

mu <- private$comp %>% map(coef) %>% map(~outer(rep(1, self$n), as.numeric(.x)))

Expand Down Expand Up @@ -252,7 +251,7 @@ PLNmixturefit <-
#' @return a [`ggplot`] graphic
plot_clustering_data = function(main = "Expected counts reorder by clustering", plot = TRUE, log_scale = TRUE) {
M <- private$mix_up('var_par$M')
S2 <- switch(private$covariance, "spherical" = private$mix_up('var_par$S2') %*% rbind(rep(1, ncol(M))), private$mix_up('var_par$S2'))
S2 <- private$mix_up('var_par$S2')
mu <- self$posteriorProb %*% t(self$group_means)
A <- exp(mu + M + .5 * S2)
p <- plot_matrix(A, 'samples', 'variables', self$memberships, log_scale)
Expand Down Expand Up @@ -346,7 +345,7 @@ PLNmixturefit <-
#' @field entropy_latent Entropy of the variational distribution of the latent vector (Gaussian)
entropy_latent = function() {
.5 * (sum(map_dbl(private$comp, function(component) {
sum( diag(component$weights) %*% log(component$var_par$S2) * ifelse(component$vcov_model == "spherical", self$p, 1) )
sum( diag(component$weights) %*% log(component$var_par$S2) )
})) + self$n * self$p * log(2*pi*exp(1)))
},
#' @field entropy Full entropy of the variational distribution (latent vector + clustering)
Expand Down
6 changes: 4 additions & 2 deletions inst/case_studies/oaks_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ plot(myLDA_tree_diagonal)
otu.family <- factor(rep(c("fungi", "E. aphiltoides", "bacteria"), c(47, 1, 66)))
plot(myLDA_tree, "variable", var_cols = otu.family) ## TODO: add color for arrows to check

myLDA_tree_spherical <- PLNLDA(Abundance ~ 1 + offset(log(Offset)), grouping = tree, data = oaks, control = list(covariance = "spherical"))
plot(myLDA_tree_spherical)

## One dimensional check of plot
myLDA_orientation <- PLNLDA(Abundance ~ 1 + offset(log(Offset)), grouping = orientation, data = oaks)
plot(myLDA_orientation)
Expand Down Expand Up @@ -91,8 +94,7 @@ data.frame(
ggplot(aes(x = nb_components, y = value, colour = score)) + geom_line() + theme_bw() + labs(y = "clustering similarity", x = "number of components")

## Mixture model to recover tree structure - with covariates
system.time(my_mixtures <- PLNmixture(Abundance ~ 0 + tree + distTOground + offset(log(Offset)), data = oaks,
control_main = list(covariance = "spherical")))
system.time(my_mixtures <- PLNmixture(Abundance ~ 0 + tree + distTOground + offset(log(Offset)), data = oaks))

plot(my_mixtures, criteria = c("loglik", "ICL", "BIC"), reverse = TRUE)

Expand Down
96 changes: 3 additions & 93 deletions src/optimize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,97 +123,7 @@ Rcpp::List cpp_optimize_spherical(
// Conversion from R, prepare optimization
const auto init_Theta = Rcpp::as<arma::mat>(init_parameters["Theta"]); // (p,d)
const auto init_M = Rcpp::as<arma::mat>(init_parameters["M"]); // (n,p)
const auto init_S = Rcpp::as<arma::vec>(init_parameters["S"]); // (n)

const auto metadata = tuple_metadata(init_Theta, init_M, init_S);
enum { THETA_ID, M_ID, S_ID }; // Names for metadata indexes

auto parameters = std::vector<double>(metadata.packed_size);
metadata.map<THETA_ID>(parameters.data()) = init_Theta;
metadata.map<M_ID>(parameters.data()) = init_M;
metadata.map<S_ID>(parameters.data()) = init_S;

auto optimizer = new_nlopt_optimizer(configuration, parameters.size());
if(configuration.containsElementNamed("xtol_abs")) {
SEXP value = configuration["xtol_abs"];
if(Rcpp::is<double>(value)) {
set_uniform_xtol_abs(optimizer.get(), Rcpp::as<double>(value));
} else {
auto per_param_list = Rcpp::as<Rcpp::List>(value);
auto packed = std::vector<double>(metadata.packed_size);
set_from_r_sexp(metadata.map<THETA_ID>(packed.data()), per_param_list["Theta"]);
set_from_r_sexp(metadata.map<M_ID>(packed.data()), per_param_list["M"]);
set_from_r_sexp(metadata.map<S_ID>(packed.data()), per_param_list["S"]);
set_per_value_xtol_abs(optimizer.get(), packed);
}
}

const double w_bar = accu(w);

// Optimize
auto objective_and_grad = [&metadata, &O, &X, &Y, &w, &w_bar](const double * params, double * grad) -> double {
const arma::mat Theta = metadata.map<THETA_ID>(params);
const arma::mat M = metadata.map<M_ID>(params);
const arma::mat S = metadata.map<S_ID>(params);

arma::mat S2 = S % S;
const arma::uword p = Y.n_cols;
arma::mat Z = O + X * Theta.t() + M;
arma::mat A = exp(Z + 0.5 * S2 * arma::ones(p).t());
double sigma2 = arma::as_scalar(dot(w, sum(pow(M, 2), 1) + double(p) * S2) / (double(p) * w_bar) );
double objective = accu(w.t() * (A - Y % Z)) - 0.5 * double(p) * dot(w, log(S2/sigma2)) ;

metadata.map<THETA_ID>(grad) = (A - Y).t() * (X.each_col() % w);
metadata.map<M_ID>(grad) = diagmat(w) * (M / sigma2 + A - Y);
metadata.map<S_ID>(grad) = w % (double(p) * S / sigma2 + S % sum(A, 1) - double(p) * pow(S, -1));

return objective;
};
OptimizerResult result = minimize_objective_on_parameters(optimizer.get(), objective_and_grad, parameters);

// Variational parameters
arma::mat M = metadata.copy<M_ID>(parameters.data());
arma::mat S = metadata.copy<S_ID>(parameters.data()); // vec(n) -> mat(n, 1)
arma::mat S2 = S % S;
// Regression parameters
arma::mat Theta = metadata.copy<THETA_ID>(parameters.data());
// Variance parameters
const arma::uword p = Y.n_cols;
const double sigma2 = arma::as_scalar(dot(w, sum(pow(M, 2), 1) + double(p) * S2)) / (double(p) * w_bar);
arma::sp_mat Sigma(p,p); Sigma.diag() = arma::ones<arma::vec>(p) * sigma2;
arma::sp_mat Omega(p,p); Omega.diag() = arma::ones<arma::vec>(p) * pow(sigma2, -1);
// Element-wise log-likelihood
arma::mat Z = O + X * Theta.t() + M;
arma::mat A = exp(Z + 0.5 * S2 * arma::ones(p).t());
arma::mat loglik = sum(Y % Z - A - 0.5 * pow(M, 2) / sigma2, 1) - 0.5 * double(p) * S2 / sigma2 +
0.5 * double(p) * log(S2 / sigma2) + ki(Y);

return Rcpp::List::create(
Rcpp::Named("status", static_cast<int>(result.status)),
Rcpp::Named("iterations", result.nb_iterations),
Rcpp::Named("Theta", Theta),
Rcpp::Named("M", M),
Rcpp::Named("S", S),
Rcpp::Named("Z", Z),
Rcpp::Named("A", A),
Rcpp::Named("Sigma", Sigma),
Rcpp::Named("Omega", Omega),
Rcpp::Named("loglik", loglik));
}

// [[Rcpp::export]]
Rcpp::List cpp_optimize_spherical_2(
const Rcpp::List & init_parameters, // List(Theta, M, S)
const arma::mat & Y, // responses (n,p)
const arma::mat & X, // covariates (n,d)
const arma::mat & O, // offsets (n,p)
const arma::vec & w, // weights (n)
const Rcpp::List & configuration // List of config values
) {
// Conversion from R, prepare optimization
const auto init_Theta = Rcpp::as<arma::mat>(init_parameters["Theta"]); // (p,d)
const auto init_M = Rcpp::as<arma::mat>(init_parameters["M"]); // (n,p)
const auto init_S = Rcpp::as<arma::vec>(init_parameters["S"]); // (n,p)
const auto init_S = Rcpp::as<arma::mat>(init_parameters["S"]); // (n,p)

const auto metadata = tuple_metadata(init_Theta, init_M, init_S);
enum { THETA_ID, M_ID, S_ID }; // Names for metadata indexes
Expand Down Expand Up @@ -250,7 +160,7 @@ Rcpp::List cpp_optimize_spherical_2(
const arma::uword p = Y.n_cols;
arma::mat Z = O + X * Theta.t() + M;
arma::mat A = exp(Z + 0.5 * S2);
double sigma2 = arma::as_scalar(dot(w, sum(pow(M, 2) + S2, 1)) / (double(p) * w_bar) );
double sigma2 = accu(diagmat(w) * (pow(M, 2) + S2)) / (double(p) * w_bar) ;
double objective = accu(w.t() * (A - Y % Z - 0.5 * log(S2))) - 0.5 * (double(p) * w_bar) * log(sigma2) ;

metadata.map<THETA_ID>(grad) = (A - Y).t() * (X.each_col() % w);
Expand All @@ -269,7 +179,7 @@ Rcpp::List cpp_optimize_spherical_2(
arma::mat Theta = metadata.copy<THETA_ID>(parameters.data());
// Variance parameters
const arma::uword p = Y.n_cols;
const double sigma2 = arma::as_scalar(dot(w, sum(pow(M, 2) + S2, 1))) / (double(p) * w_bar);
const double sigma2 = accu(diagmat(w) * (pow(M, 2) + S2)) / (double(p) * w_bar) ;
arma::sp_mat Sigma(p,p); Sigma.diag() = arma::ones<arma::vec>(p) * sigma2;
arma::sp_mat Omega(p,p); Omega.diag() = arma::ones<arma::vec>(p) * pow(sigma2, -1);
// Element-wise log-likelihood
Expand Down
24 changes: 12 additions & 12 deletions src/optimize_ve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ Rcpp::List cpp_optimize_vestep_spherical(
) {
// Conversion from R, prepare optimization
const auto init_M = Rcpp::as<arma::mat>(init_parameters["M"]); // (n,p)
const auto init_S = Rcpp::as<arma::vec>(init_parameters["S"]); // (n)
const auto init_S = Rcpp::as<arma::mat>(init_parameters["S"]); // (n)

const auto metadata = tuple_metadata(init_M, init_S);
enum { M_ID, S_ID }; // Names for metadata indexes
Expand Down Expand Up @@ -214,31 +214,31 @@ Rcpp::List cpp_optimize_vestep_spherical(
const arma::mat M = metadata.map<M_ID>(params);
const arma::mat S = metadata.map<S_ID>(params);

arma::vec S2 = S % S;
arma::mat S2 = S % S;
const arma::uword p = Y.n_cols;
arma::mat Z = O + X * Theta.t() + M;
arma::mat A = exp(Z + 0.5 * S2 * arma::ones(p).t());
double n_sigma2 = dot(w, sum(pow(M, 2), 1) + double(p) * S2);
arma::mat A = exp(Z + 0.5 * S2);
double n_sigma2 = accu(diagmat(w) * (pow(M, 2) + S2)) ;
double omega2 = Omega(0, 0);
double objective = accu(w.t() * (A - Y % Z)) - 0.5 * double(p) * dot(w, log(S2)) + 0.5 * n_sigma2 * omega2;
double objective = accu(w.t() * (A - Y % Z - 0.5 * log(S2))) + 0.5 * n_sigma2 * omega2;

metadata.map<M_ID>(grad) = diagmat(w) * (M / omega2 + A - Y);
metadata.map<S_ID>(grad) = diagmat(w) * (S / omega2 + S % A - pow(S, -1));

metadata.map<M_ID>(grad) = diagmat(w) * (M * omega2 + A - Y);
metadata.map<S_ID>(grad) = w % (double(p) * S * omega2 + S % sum(A, 1) - double(p) * pow(S, -1));
return objective;
};
OptimizerResult result = minimize_objective_on_parameters(optimizer.get(), objective_and_grad, parameters);

// Model and variational parameters
arma::mat M = metadata.copy<M_ID>(parameters.data());
arma::mat S = metadata.copy<S_ID>(parameters.data()); // vec(n) -> mat(n, 1)
arma::vec S2 = S % S;
arma::mat S = metadata.copy<S_ID>(parameters.data());
arma::mat S2 = S % S;
double omega2 = Omega(0, 0);
// Element-wise log-likelihood
const arma::uword p = Y.n_cols;
arma::mat Z = O + X * Theta.t() + M;
arma::mat A = exp(Z + 0.5 * S2 * arma::ones(p).t());
arma::mat loglik = sum(Y % Z - A - 0.5 * pow(M, 2) * omega2, 1) - 0.5 * double(p) * omega2 * S2 +
0.5 * double(p) * log(S2 * omega2) + ki(Y);
arma::mat A = exp(Z + 0.5 * S2);
arma::mat loglik = sum(Y % Z - A - 0.5 * (pow(M, 2) + S2 ) * omega2 + 0.5 * log(S2 * omega2), 1) + ki(Y);

return Rcpp::List::create(
Rcpp::Named("status") = (int)result.status,
Expand Down

0 comments on commit 06be69a

Please sign in to comment.