From b446afec0a7e2899bb5ba135639459ca001d3eea Mon Sep 17 00:00:00 2001 From: Grant Irvine-Smith Date: Thu, 7 Jan 2021 19:54:27 +0100 Subject: [PATCH] Joint feature importance Allows for the calculation of joint importance scores for group(s) of features. You can now specify which features/groups you want importance scores calculated on. This is useful for large datasets where permuting all features is computationally expensive. --- R/FeatureImp.R | 66 +++++++++++++++++++++++++------- man/FeatureImp.Rd | 30 ++++++++++++++- tests/testthat/test-FeatureImp.R | 46 ++++++++++++++++++++++ 3 files changed, 128 insertions(+), 14 deletions(-) diff --git a/R/FeatureImp.R b/R/FeatureImp.R index 9c8a14a18..803c4ed10 100644 --- a/R/FeatureImp.R +++ b/R/FeatureImp.R @@ -76,6 +76,18 @@ #' # Plot the results directly #' plot(imp) #' +#' # We can calculate feature importance for a subset of features +#' imp <- FeatureImp$new(mod, loss = "mae", features = c("crim", "zn", "indus")) +#' plot(imp) +#' +#'# We can calculate joint importance of groups of features +#'groups = list( +#' grp1 = c("crim", "zn", "indus", "chas"), +#' grp2 = c("nox", "rm", "age", "dis"), +#' grp3 = c("rad", "tax", "ptratio", "black", "lstat") +#') +#'imp <- FeatureImp$new(mod, loss = "mae", features = groups) +#'plot(imp) #' #' # FeatureImp also works with multiclass classification. #' # In this case, the importance measurement regards all classes @@ -122,6 +134,11 @@ FeatureImp <- R6::R6Class("FeatureImp", #' How often should the shuffling of the feature be repeated? #' The higher the number of repetitions the more stable and accurate the #' results become. + #' @param features (`character or list`)\cr + #' For which features do you want importance scores calculated. The default + #' value of `NULL` implies all features. Use a named list of character vectors + #' to define groups of features for which joint importance will be calculated. + #' See examples. 
#' @return (data.frame)\cr #' data.frame with the results of the feature importance computation. One #' row per feature with the following columns: @@ -134,11 +151,13 @@ FeatureImp <- R6::R6Class("FeatureImp", #' plots, the median importance over the repetitions as a point. #' initialize = function(predictor, loss, compare = "ratio", - n.repetitions = 5) { + n.repetitions = 5, features = NULL) { assert_choice(compare, c("ratio", "difference")) assert_number(n.repetitions) self$compare <- compare + + if (!inherits(loss, "function")) { ## Only allow metrics from Metrics package allowedLosses <- c( @@ -166,6 +185,18 @@ FeatureImp <- R6::R6Class("FeatureImp", warning("Model error is 0, switching from compare='ratio' to compare='difference'") self$compare <- "difference" } + + # process features argument + if (is.null(features)) { + features <- private$sampler$feature.names + } + if (!is.list(features)) { + features <- as.list(features) + names(features) <- unlist(features) + } + assert_subset(unique(unlist(features)), private$sampler$feature.names, empty.ok = FALSE) + self$features <- features + # suppressing package startup messages suppressPackageStartupMessages(private$run(self$predictor$batch.size)) }, @@ -187,7 +218,13 @@ FeatureImp <- R6::R6Class("FeatureImp", #' depending on whether the importance was calculated as difference #' between original model error and model error after permutation or as #' ratio. - compare = NULL + compare = NULL, + + #' @field features (`list`)\cr Features for which importance scores are to + #' be calculated. The names are the feature/group names, while the contents + #' specify which feature(s) are to be permuted. 
+ features = NULL + ), private = list( @@ -204,20 +241,21 @@ FeatureImp <- R6::R6Class("FeatureImp", private$dataSample <- private$getData() result <- NULL - estimate_feature_imp <- function(feature, + estimate_feature_imp <- function(group, + features, data.sample, y, n.repetitions, y.names, pred, loss) { - + cnames <- setdiff(colnames(data.sample), y.names) qResults <- data.table::data.table() y.vec <- data.table::data.table() for (repi in 1:n.repetitions) { mg <- MarginalGenerator$new(data.sample, data.sample, - features = feature, n.sample.dist = 1, y = y, cartesian = FALSE, + features = features, n.sample.dist = 1, y = y, cartesian = FALSE, id.dist = TRUE ) while (!mg$finished) { @@ -231,7 +269,7 @@ FeatureImp <- R6::R6Class("FeatureImp", } # AGGREGATE measurements results <- data.table::data.table( - feature = feature, actual = y.vec[[1]], predicted = qResults[[1]], + feature = group, actual = y.vec[[1]], predicted = qResults[[1]], num_rep = rep(1:n.repetitions, each = nrow(data.sample)) ) results <- results[, list("permutation_error" = loss(actual, predicted)), @@ -248,19 +286,21 @@ FeatureImp <- R6::R6Class("FeatureImp", loss <- self$loss result <- rbindlist(unname( - future.apply::future_lapply(private$sampler$feature.names, function(x) { - estimate_feature_imp(x, + future.apply::future_mapply(estimate_feature_imp, + group = names(self$features), + features = self$features, + MoreArgs = list( data.sample = data.sample, y = y, n.repetitions = n.repetitions, y.names = y.names, pred = pred, loss = loss - ) - }, - future.seed = TRUE, - future.globals = FALSE, - future.packages = loadedNamespaces() + ), + SIMPLIFY = FALSE, + future.seed = TRUE, + future.globals = FALSE, + future.packages = loadedNamespaces() ) ), use.names = TRUE) diff --git a/man/FeatureImp.Rd b/man/FeatureImp.Rd index e4535c239..4709b084c 100644 --- a/man/FeatureImp.Rd +++ b/man/FeatureImp.Rd @@ -80,6 +80,18 @@ imp <- FeatureImp$new(mod, loss = "mae", compare = "difference") # Plot the 
results directly plot(imp) +# We can calculate feature importance for a subset of features +imp <- FeatureImp$new(mod, loss = "mae", features = c("crim", "zn", "indus")) +plot(imp) + +# We can calculate joint importance of groups of features +groups = list( + grp1 = c("crim", "zn", "indus", "chas"), + grp2 = c("nox", "rm", "age", "dis"), + grp3 = c("rad", "tax", "ptratio", "black", "lstat") +) +imp <- FeatureImp$new(mod, loss = "mae", features = groups) +plot(imp) # FeatureImp also works with multiclass classification. # In this case, the importance measurement regards all classes @@ -127,6 +139,10 @@ Number of repetitions.} depending on whether the importance was calculated as difference between original model error and model error after permutation or as ratio.} + +\item{\code{features}}{(\code{list})\cr Features for which importance scores are to +be calculated. The names are the feature/group names, while the contents +specify which feature(s) are to be permuted.} } \if{html}{\out{}} } @@ -151,7 +167,13 @@ ratio.} \subsection{Method \code{new()}}{ Create a FeatureImp object \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{FeatureImp$new(predictor, loss, compare = "ratio", n.repetitions = 5)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{FeatureImp$new( + predictor, + loss, + compare = "ratio", + n.repetitions = 5, + features = NULL +)}\if{html}{\out{
}} } \subsection{Arguments}{ @@ -179,6 +201,12 @@ original model error and model error after permutation? How often should the shuffling of the feature be repeated? The higher the number of repetitions the more stable and accurate the results become.} + +\item{\code{features}}{(\verb{character or list})\cr +For which features do you want importance scores calculated. The default +value of \code{NULL} implies all features. Use a named list of character vectors +to define groups of features for which joint importance will be calculated. +See examples.} } \if{html}{\out{}} } diff --git a/tests/testthat/test-FeatureImp.R b/tests/testthat/test-FeatureImp.R index 8c8e43d1c..5ab778c62 100644 --- a/tests/testthat/test-FeatureImp.R +++ b/tests/testthat/test-FeatureImp.R @@ -140,3 +140,49 @@ test_that("Feature Importance 0", { fimp <- FeatureImp$new(pred, loss = "mae", n.repetitions = 3) expect_equal(fimp$results$importance[3], 1) }) + +test_that("FeatureImp works for a subset of features", { + var.imp <- FeatureImp$new(predictor1, loss = "mse", features = c("a", "b")) + dat <- var.imp$results + expect_class(dat, "data.frame") + expect_false("data.table" %in% class(dat)) + expect_equal(colnames(dat), expected_colnames) + expect_equal(nrow(dat), 2) + p <- plot(var.imp) + expect_s3_class(p, c("gg", "ggplot")) + p +}) + +test_that("Invalid feature names are caught", { + expect_error( + FeatureImp$new(predictor1, loss = "mse", features = c("x", "y", "z")), + "failed: Must be a subset of {'a','b','c','d'}, but is {'x','y','z'}", + fixed = TRUE + ) +}) + +test_that("FeatureImp works for groups of features", { + groups = list(ab = c("a", "b"), cd = c("c", "d")) + var.imp <- FeatureImp$new(predictor1, loss = "mse", features = groups) + dat <- var.imp$results + expect_class(dat, "data.frame") + expect_false("data.table" %in% class(dat)) + expect_equal(colnames(dat), expected_colnames) + expect_equal(nrow(dat), 2) + p <- plot(var.imp) + expect_s3_class(p, c("gg", "ggplot")) + p +}) + 
+test_that("FeatureImp works for overlapping groups of features", { + groups = list(ab = c("a", "b"), bc = c("b", "c")) + var.imp <- FeatureImp$new(predictor1, loss = "mse", features = groups) + dat <- var.imp$results + expect_class(dat, "data.frame") + expect_false("data.table" %in% class(dat)) + expect_equal(colnames(dat), expected_colnames) + expect_equal(nrow(dat), 2) + p <- plot(var.imp) + expect_s3_class(p, c("gg", "ggplot")) + p +})