From 811bcacca076415ac40fcb0d37bf78f896b52454 Mon Sep 17 00:00:00 2001 From: GFabien Date: Mon, 22 Apr 2024 15:23:43 +0200 Subject: [PATCH] Add probs as output of rgcca_predict --- R/rgcca_predict.R | 24 ++++++++++++++---------- tests/testthat/test_rgcca_predict.r | 19 +++++++++++++++++++ 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/R/rgcca_predict.R b/R/rgcca_predict.R index 91eb2063..68af15cf 100644 --- a/R/rgcca_predict.R +++ b/R/rgcca_predict.R @@ -156,9 +156,15 @@ rgcca_predict <- function(rgcca_res, })) }) + probs <- lapply(c("train", "test"), function(mode) { + as.data.frame(lapply(results, function(res) { + res[["probs"]][[mode]] + })) + }) + confusion <- results[[1]]$confusion - names(prediction) <- names(metric) <- c("train", "test") + names(prediction) <- names(metric) <- names(probs) <- c("train", "test") model <- lapply(results, "[[", "model") score <- mean(unlist(lapply(results, "[[", "score")), na.rm = TRUE) @@ -169,6 +175,7 @@ rgcca_predict <- function(rgcca_res, prediction = prediction, confusion = confusion, metric = metric, + probs = probs, model = model, score = score ) @@ -221,6 +228,8 @@ core_prediction <- function(prediction_model, X_train, X_test, idx_train <- !(is.na(prediction_train$obs) | is.na(prediction_train$pred)) idx_test <- !(is.na(prediction_test$obs) | is.na(prediction_test$pred)) + probs_train <- probs_test <- NULL + if (classification) { confusion_train <- confusionMatrix(prediction_train$pred, reference = prediction_train$obs @@ -228,15 +237,9 @@ core_prediction <- function(prediction_model, X_train, X_test, confusion_test <- confusionMatrix(prediction_test$pred, reference = prediction_test$obs ) - if (is.null(prediction_model$prob)) { - prediction_train <- data.frame(cbind( - prediction_train, - predict(model, X_train, type = "prob") - )) - prediction_test <- data.frame(cbind( - prediction_test, - predict(model, X_test, type = "prob") - )) + if (is.function(prediction_model$prob)) { + probs_train <- data.frame(predict(model, X_train, type = "prob")) + probs_test <- data.frame(predict(model, X_test, type = "prob")) } metric_train <- multiClassSummary( data = prediction_train[idx_train, ], @@ -268,6 +271,7 @@ core_prediction <- function(prediction_model, X_train, X_test, return(list( score = score, model = model, + probs = list(train = probs_train, test = probs_test), metric = list(train = metric_train, test = metric_test), confusion = list(train = confusion_train, test = confusion_test), prediction = list(train = prediction_train, test = prediction_test) diff --git a/tests/testthat/test_rgcca_predict.r b/tests/testthat/test_rgcca_predict.r index 4fa08854..80268c9c 100644 --- a/tests/testthat/test_rgcca_predict.r +++ b/tests/testthat/test_rgcca_predict.r @@ -90,6 +90,12 @@ test_that("rgcca_predict with lm predictor gives the same prediction as expect_equal(as.matrix(A[[response]] - res_predict$prediction$test), res_lm) }) +test_that("rgcca_predict returns an empty probs in regression", { + res_predict <- rgcca_predict(rgcca_res = fit_rgcca) + expect_equal(nrow(res_predict$probs$train), 0) + expect_equal(nrow(res_predict$probs$test), 0) +}) + # Classification #--------------- test_that("rgcca_predict with lda predictor gives the same prediction as @@ -109,3 +115,16 @@ test_that("rgcca_predict with lda predictor gives the same prediction as data.frame(politic = prediction_lda) ) }) + +test_that("rgcca_predict returns probs in classification with adequate model", { + A <- lapply(blocks_classif, function(x) x[1:32, ]) + B <- lapply(blocks_classif, function(x) x[33:47, ]) + response <- 3 + fit_rgcca <- rgcca(A, tau = 1, ncomp = c(3, 2, 1), response = response) + res_predict <- rgcca_predict(fit_rgcca, + blocks_test = B[-3], + prediction_model = "lda" + ) + expect_equal(nrow(res_predict$probs$train), 32) + expect_equal(nrow(res_predict$probs$test), 15) +})