Skip to content

Commit

Permalink
Merge pull request #158 from christophM/grid-parameter
Browse files Browse the repository at this point in the history
Grid parameter
  • Loading branch information
christophM authored Jan 18, 2021
2 parents eb82e71 + 004d6f5 commit 022bec1
Show file tree
Hide file tree
Showing 8 changed files with 212 additions and 48 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# iml 0.11

- Allow computation of importance for groups of features (FeatureImp)
- FeatureEffect can now be computed with user-provided grid points. Works for ice, ale and pdp.

# iml 0.10.1.9000

Expand Down
47 changes: 21 additions & 26 deletions R/FeatureEffect-ale.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,9 @@
#' @param dat the data.frame with same columns as training data
#' @param run.prediction Predict function of type: f(newdata)
#' @param feature.name The column name of the feature for which to compute ALE
#' @param grid.size The number of intervals
#' @param grid.dt data.table with single column with grid values for the numerical feature
#' @keywords internal
calculate.ale.num <- function(dat, run.prediction, feature.name, grid.size) {
# from number of intervals to number of borders
n.borders <- grid.size + 1
# Handling duplicated grid values
grid.dt <- unique(get.grid(dat[, feature.name, with = FALSE], n.borders, type = "quantile"))
calculate.ale.num <- function(dat, run.prediction, feature.name, grid.dt) {
# Matching data instances to intervals
interval.index <- findInterval(dat[[feature.name]], grid.dt[[1]], left.open = TRUE)
# Data point in the left most interval should be in interval 1, not zero
Expand Down Expand Up @@ -68,19 +64,20 @@ calculate.ale.num <- function(dat, run.prediction, feature.name, grid.size) {
#' @param dat the data.frame with same columns as training data
#' @param run.prediction Predict function of type: f(newdata)
#' @param feature.name The column names of the feature for which to compute ALE
#' @param grid.size The number of cells
#' @param grid.dt1 Data.table with single column with the grid value for feature 1
#' @param grid.dt2 Data.table with single column with the grid value for feature 2
#' @keywords internal
calculate.ale.num.num <- function(dat, run.prediction, feature.name, grid.size) {
# Create grid for feature 1
grid.dt1 <- unique(get.grid(dat[, feature.name[1], with = FALSE], grid.size = grid.size[1] + 1, type = "quantile"))
colnames(grid.dt1) <- feature.name[1]
calculate.ale.num.num <- function(dat, run.prediction, feature.name, grid.dt1, grid.dt2) {
# Remove data outside of boundaries
dat <- dat[(dat[[feature.name[1]]] <= max(grid.dt1[[1]])) &
(dat[[feature.name[1]]] >= min(grid.dt1[[1]])) &
(dat[[feature.name[2]]] <= max(grid.dt2[[1]])) &
(dat[[feature.name[2]]] >= min(grid.dt2[[1]])),]
print(dat)
# Matching instances to the grid of feature 1
interval.index1 <- findInterval(dat[[feature.name[1]]], grid.dt1[[1]], left.open = TRUE)
# Data point in the left most interval should be in interval 1, not zero
interval.index1[interval.index1 == 0] <- 1
## Create grid for feature 2
grid.dt2 <- unique(get.grid(dat[, feature.name[2], with = FALSE], grid.size = grid.size[2] + 1, type = "quantile"))
colnames(grid.dt2) <- feature.name[2]
# Matching instances to the grid of feature 2
interval.index2 <- findInterval(dat[[feature.name[2]]], grid.dt2[[1]], left.open = TRUE)
# Data point in the left most interval should be in interval 1, not zero
Expand Down Expand Up @@ -214,6 +211,7 @@ calculate.ale.num.num <- function(dat, run.prediction, feature.name, grid.size)
#' @keywords internal
calculate.ale.cat <- function(dat, run.prediction, feature.name) {
x <- dat[, feature.name, with = FALSE][[1]]

levels.original <- levels(droplevels(x))
nlev <- nlevels(droplevels(x))
# if ordered, then already use that
Expand Down Expand Up @@ -277,16 +275,20 @@ calculate.ale.cat <- function(dat, run.prediction, feature.name) {
#' @param dat the data.frame with same columns as training data
#' @param run.prediction Predict function of type: f(newdata)
#' @param feature.name The column name of the features for which to compute ALE
#' @param grid.size The number of intervals for the numerical feature
#' @param grid.dt data.table with single column with grid values for the numerical feature
#' @keywords internal
calculate.ale.num.cat <- function(dat, run.prediction, feature.name, grid.size) {
calculate.ale.num.cat <- function(dat, run.prediction, feature.name, grid.dt) {

# Figure out which feature is numeric and which is categorical
x.num.index <- ifelse(inherits(dat[, feature.name, with = FALSE][[1]], "numeric"), 1, 2)
x.cat.index <- setdiff(c(1, 2), x.num.index)
x.num <- dat[, feature.name[x.num.index], with = FALSE][[1]]
x.cat <- dat[, feature.name[x.cat.index], with = FALSE][[1]]
# We can only compute ALE within min and max boundaries of given intervals
# This part is only relevant for user-defined intervals
dat <- dat[which((x.num >= min(grid.dt[[1]])) &
(x.num <= max(grid.dt[[1]])))]

x.cat.index <- setdiff(c(1, 2), x.num.index)
x.cat <- dat[, feature.name[x.cat.index], with = FALSE][[1]]
levels.original <- levels(x.cat)
# if ordered, then already use that
if (inherits(x.cat, "ordered")) {
Expand All @@ -302,13 +304,7 @@ calculate.ale.num.cat <- function(dat, run.prediction, feature.name, grid.size)
# The rows for which the category can be increased
row.ind.increase <- (1:nrow(dat))[x.cat.ordered < nlevels(x.cat)]
row.ind.decrease <- (1:nrow(dat))[x.cat.ordered > 1]


## Create ALE for increasing categorical feature
grid.dt <- unique(get.grid(dat[, feature.name[x.num.index], with = FALSE],
grid.size = grid.size[x.num.index] + 1, type = "quantile"
))
colnames(grid.dt) <- feature.name[x.num.index]

interval.index <- findInterval(dat[[feature.name[x.num.index]]], grid.dt[[1]], left.open = TRUE)
# Data point in the left most interval should be in interval 1, not zero
interval.index[interval.index == 0] <- 1
Expand Down Expand Up @@ -476,7 +472,6 @@ calculate.ale.num.cat <- function(dat, run.prediction, feature.name, grid.size)
".level", ".num"
)), with = FALSE]
deltas$.type <- "ale"

data.frame(deltas)
}

Expand Down
77 changes: 62 additions & 15 deletions R/FeatureEffect.R
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,9 @@ FeatureEffect <- R6Class("FeatureEffect",
#' Value at which the plot should be centered. Ignored in the case of two
#' features.
#' @template grid.size
#' @template grid.points
initialize = function(predictor, feature, method = "ale", center.at = NULL,
grid.size = 20) {
grid.size = 20, grid.points = NULL) {

feature_index <- private$sanitize.feature(
feature,
Expand All @@ -159,6 +160,9 @@ FeatureEffect <- R6Class("FeatureEffect",
min.len = 1, max.len = 2
)
assert_numeric(grid.size, min.len = 1, max.len = length(feature))
assert(check_numeric(grid.points, null.ok = TRUE, min.len = 2),
check_list(grid.points), null.ok = TRUE, min.len = 1, max.len = 2,
combine = "or")
assert_number(center.at, null.ok = TRUE)
assert_choice(method, c("ale", "pdp", "ice", "pdp+ice"))
self$method <- method
Expand All @@ -169,6 +173,7 @@ FeatureEffect <- R6Class("FeatureEffect",
stop("ICE is not implemented for two features.")
}
}
if (!is.null(grid.points)) private$grid.points = unique(grid.points)
private$anchor.value <- center.at
super$initialize(predictor)
private$set_feature_from_index(feature_index)
Expand Down Expand Up @@ -258,6 +263,7 @@ FeatureEffect <- R6Class("FeatureEffect",
private = list(
run = function(n) {
if (self$method == "ale") {

private$run.ale()
} else {
private$run.pdp(self$predictor$batch.size)
Expand Down Expand Up @@ -296,34 +302,66 @@ FeatureEffect <- R6Class("FeatureEffect",
anchor.value = NULL,
grid.size.original = NULL,
y_axis_label = NULL,
grid.points = NULL,
# core functionality of self$predict
predict_inner = NULL,
run.ale = function() {

private$dataSample <- private$getData()
dat <- private$dataSample
if (self$n.features == 1) {
if (self$feature.type == "numerical") { # one numerical feature
if (is.null(private$grid.points)) {
grid.dt <- unique(get.grid(dat[, self$feature.name, with = FALSE],
self$grid.size + 1, type = "quantile"))
} else {
grid.dt <- data.table(unique(sort(private$grid.points)))
colnames(grid.dt) <- self$feature.name
}
results <- calculate.ale.num(
dat = private$dataSample, run.prediction = private$run.prediction,
feature.name = self$feature.name, grid.size = self$grid.size
dat = dat, run.prediction = private$run.prediction,
feature.name = self$feature.name, grid.dt = grid.dt
)
} else { # one categorical feature
results <- calculate.ale.cat(
dat = private$dataSample, run.prediction = private$run.prediction,
dat = dat, run.prediction = private$run.prediction,
feature.name = self$feature.name
)
}
} else { # two features
if (all(self$feature.type == "numerical")) { # two numerical features
# Create grid for feature 1
if (is.null(private$grid.points)) {
grid.dt1 <- unique(get.grid(dat[, self$feature.name[1], with = FALSE],
grid.size = self$grid.size[1] + 1, type = "quantile"))
grid.dt2 <- unique(get.grid(dat[, self$feature.name[2], with = FALSE],
grid.size = self$grid.size[2] + 1, type = "quantile"))
} else {
grid.dt1 <- data.table(unique(sort(private$grid.points[[1]])))
grid.dt2 <- data.table(unique(sort(private$grid.points[[2]])))
}
colnames(grid.dt1) <- self$feature.name[[1]]
colnames(grid.dt2) <- self$feature.name[[2]]
results <- calculate.ale.num.num(
dat = private$dataSample, run.prediction = private$run.prediction,
feature.name = self$feature.name, grid.size = self$grid.size
)
dat = dat, run.prediction = private$run.prediction,
feature.name = self$feature.name, grid.dt1 = grid.dt1, grid.dt2 = grid.dt2)
} else if (all(self$feature.type == "categorical")) { # two categorical features
stop("ALE for two categorical features is not yet implemented.")
} else { # mixed numerical and categorical

x.num.index <- ifelse(inherits(dat[, self$feature.name, with = FALSE][[1]], "numeric"), 1, 2)
if (is.null(private$grid.points)) {
grid.dt <- unique(get.grid(dat[, self$feature.name[x.num.index], with = FALSE],
grid.size = self$grid.size[x.num.index] + 1, type = "quantile"
))
} else {
grid.dt <- data.table(sort(unique(private$grid.points)))
}
colnames(grid.dt) <- self$feature.name[x.num.index]

results <- calculate.ale.num.cat(
dat = private$dataSample, run.prediction = private$run.prediction,
feature.name = self$feature.name, grid.size = self$grid.size
dat = dat, run.prediction = private$run.prediction,
feature.name = self$feature.name, grid.dt = grid.dt
)
}
}
Expand All @@ -334,20 +372,29 @@ FeatureEffect <- R6Class("FeatureEffect",
self$results <- results
},
run.pdp = function(n) {

private$dataSample <- private$getData()
grid.dt <- get.grid(private$getData()[, self$feature.name, with = FALSE],
self$grid.size,
anchor.value = private$anchor.value
)
if (is.null(private$grid.points)) {
grid.dt <- get.grid(private$getData()[, self$feature.name, with = FALSE],
self$grid.size,
anchor.value = private$anchor.value
)
} else {
if(self$n.features == 1){
grid.dt <- data.table(private$grid.points)
} else {
grid.dt <- data.table(expand.grid(private$grid.points[[1]],
private$grid.points[[2]]))
}
names(grid.dt) <- self$feature.name
}
mg <- MarginalGenerator$new(grid.dt, private$dataSample,
self$feature.name,
id.dist = TRUE, cartesian = TRUE
)
results.ice <- data.table()
while (!mg$finished) {
results.ice.inter <- mg$next.batch(n)
predictions <- private$run.prediction(results.ice.inter)
predictions <- private$run.prediction(results.ice.inter[,self$predictor$data$feature.names, with = FALSE])
results.ice.inter <- results.ice.inter[, c(self$feature.name, ".id.dist"),
with = FALSE
]
Expand Down
9 changes: 8 additions & 1 deletion man/FeatureEffect.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/calculate.ale.num.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/calculate.ale.num.cat.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions man/calculate.ale.num.num.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 022bec1

Please sign in to comment.