diff --git a/DESCRIPTION b/DESCRIPTION index 855f3cb0..5f7d7c46 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -53,7 +53,7 @@ Suggests: VignetteBuilder: knitr LazyData: true -RoxygenNote: 7.2.3 +RoxygenNote: 7.3.2 Encoding: UTF-8 URL: https://github.com/rgcca-factory/RGCCA, https://rgcca-factory.github.io/RGCCA/ diff --git a/R/rgcca.R b/R/rgcca.R index 1721d6ac..e572a68e 100644 --- a/R/rgcca.R +++ b/R/rgcca.R @@ -487,10 +487,15 @@ rgcca <- function(blocks, connection = NULL, tau = 1, ncomp = 1, gcca_args[["blocks"]] <- blocks gcca_args[["disjunction"]] <- opt$disjunction gcca_args[[opt$param]] <- rgcca_args[[opt$param]] - gcca_args <- modifyList(gcca_args, opt$supplementary_parameters) - func_out <- do.call(rgcca_outer_loop, gcca_args) + # if (method == "netsgcca") { + # gcca_args <- modifyList(gcca_args, rgcca_args[c("lambda", "graph_laplacians")]) + # } + # func_out <- do.call(opt$gcca, gcca_args) + # gcca_args <- modifyList(gcca_args, opt$supplementary_parameters) + # func_out <- do.call(opt$gcca, gcca_args) + ### Format the output func_out <- format_output(func_out, rgcca_args, opt, blocks) diff --git a/R/rgcca_predict.R b/R/rgcca_predict.R index 91eb2063..21afeb52 100644 --- a/R/rgcca_predict.R +++ b/R/rgcca_predict.R @@ -24,6 +24,10 @@ #' block is missing.} #' @return \item{model}{A list of the models trained using caret to make the #' predictions and compute the scores.} +#' @return \item{probs}{A list of data.frames with the class probabilities +#' of the test and train response blocks predicted by the prediction +#' model. If the prediction model does not compute class probabilities, the +#' data.frames are empty.} #' @return \item{metric}{A list of data.frames containing the scores obtained #' on the training and testing sets.} #' @return \item{confusion}{A list containing NA for regression tasks. @@ -156,9 +160,15 @@ rgcca_predict <- function(rgcca_res, })) }) + probs <- lapply(c("train", "test"), function(mode) { + as.data.frame(lapply(results, function(res) { + res[["probs"]][[mode]] + })) + }) + confusion <- results[[1]]$confusion - names(prediction) <- names(metric) <- c("train", "test") + names(prediction) <- names(metric) <- names(probs) <- c("train", "test") model <- lapply(results, "[[", "model") score <- mean(unlist(lapply(results, "[[", "score")), na.rm = TRUE) @@ -169,6 +179,7 @@ rgcca_predict <- function(rgcca_res, prediction = prediction, confusion = confusion, metric = metric, + probs = probs, model = model, score = score ) @@ -221,6 +232,8 @@ core_prediction <- function(prediction_model, X_train, X_test, idx_train <- !(is.na(prediction_train$obs) | is.na(prediction_train$pred)) idx_test <- !(is.na(prediction_test$obs) | is.na(prediction_test$pred)) + probs_train <- probs_test <- NULL + if (classification) { confusion_train <- confusionMatrix(prediction_train$pred, reference = prediction_train$obs @@ -228,15 +241,9 @@ core_prediction <- function(prediction_model, X_train, X_test, confusion_test <- confusionMatrix(prediction_test$pred, reference = prediction_test$obs ) - if (is.null(prediction_model$prob)) { - prediction_train <- data.frame(cbind( - prediction_train, - predict(model, X_train, type = "prob") - )) - prediction_test <- data.frame(cbind( - prediction_test, - predict(model, X_test, type = "prob") - )) + if (is.function(prediction_model$prob)) { + probs_train <- data.frame(predict(model, X_train, type = "prob")) + probs_test <- data.frame(predict(model, X_test, type = "prob")) } metric_train <- multiClassSummary( data = prediction_train[idx_train, ], @@ -268,6 +275,7 @@ core_prediction <- function(prediction_model, X_train, X_test, return(list( score = score, model = model, + probs = list(train = probs_train, test = probs_test), metric = list(train = metric_train, test = metric_test), confusion = list(train = confusion_train, test = confusion_test), prediction = list(train = prediction_train, test = prediction_test) diff --git a/R/select_analysis.R b/R/select_analysis.R index 10d9a21a..ec638bcd 100644 --- a/R/select_analysis.R +++ b/R/select_analysis.R @@ -464,6 +464,14 @@ select_analysis <- function(rgcca_args, blocks) { rgcca_args[[param]] <- penalty + ### FIX HERE #### -> netsgcca needs other parameters + if (method == "netsgcca") { + param_list <- list(lambda = lambda, graph_laplacians = graph_laplacians) + } else { + param_list <- list() + } + ### end ### + rgcca_args <- modifyList(rgcca_args, list( ncomp = ncomp, scheme = scheme, @@ -477,8 +485,8 @@ select_analysis <- function(rgcca_args, blocks) { return(list( rgcca_args = rgcca_args, opt = list( - gcca = gcca, - supplementary_parameters = param_list, + # gcca = gcca, + # supplementary_parameters = param_list, param = param ) )) diff --git a/man/rgcca.Rd b/man/rgcca.Rd index 4e3c3df4..63d40a05 100644 --- a/man/rgcca.Rd +++ b/man/rgcca.Rd @@ -17,6 +17,8 @@ rgcca( verbose = FALSE, scale_block = "inertia", method = "rgcca", + lambda = 0, + graph_laplacians = NA, sparsity = 1, response = NULL, superblock = FALSE, diff --git a/man/rgcca_predict.Rd b/man/rgcca_predict.Rd index 7a4c8a37..898d23f0 100644 --- a/man/rgcca_predict.Rd +++ b/man/rgcca_predict.Rd @@ -43,6 +43,11 @@ block is missing.} \item{model}{A list of the models trained using caret to make the predictions and compute the scores.} +\item{probs}{A list of data.frames with the class probabilities +of the test and train response blocks predicted by the prediction +model. If the prediction model does not compute class probabilities, the +data.frames are empty.} + \item{metric}{A list of data.frames containing the scores obtained on the training and testing sets.} diff --git a/tests/testthat/test_rgcca_predict.r b/tests/testthat/test_rgcca_predict.r index 4fa08854..80268c9c 100644 --- a/tests/testthat/test_rgcca_predict.r +++ b/tests/testthat/test_rgcca_predict.r @@ -90,6 +90,12 @@ test_that("rgcca_predict with lm predictor gives the same prediction as expect_equal(as.matrix(A[[response]] - res_predict$prediction$test), res_lm) }) +test_that("rgcca_predict returns an empty probs in regression", { + res_predict <- rgcca_predict(rgcca_res = fit_rgcca) + expect_equal(nrow(res_predict$probs$train), 0) + expect_equal(nrow(res_predict$probs$test), 0) +}) + # Classification #--------------- test_that("rgcca_predict with lda predictor gives the same prediction as @@ -109,3 +115,16 @@ test_that("rgcca_predict with lda predictor gives the same prediction as data.frame(politic = prediction_lda) ) }) + +test_that("rgcca_predict returns probs in classification with adequate model", { + A <- lapply(blocks_classif, function(x) x[1:32, ]) + B <- lapply(blocks_classif, function(x) x[33:47, ]) + response <- 3 + fit_rgcca <- rgcca(A, tau = 1, ncomp = c(3, 2, 1), response = response) + res_predict <- rgcca_predict(fit_rgcca, + blocks_test = B[-3], + prediction_model = "lda" + ) + expect_equal(nrow(res_predict$probs$train), 32) + expect_equal(nrow(res_predict$probs$test), 15) +})