Skip to content

Commit

Permalink
Merge pull request #105 from ModelOriented/dev
Browse files Browse the repository at this point in the history
refactor dependency #103
  • Loading branch information
pbiecek authored Feb 17, 2020
2 parents bf16045 + 272c4b0 commit 857c3f5
Show file tree
Hide file tree
Showing 139 changed files with 1,898 additions and 699 deletions.
8 changes: 4 additions & 4 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: ingredients
Title: Effects and Importances of Model Ingredients
Version: 0.5.2
Version: 1.0
Authors@R: c(person("Przemyslaw", "Biecek", email = "[email protected]",
role = c("aut", "cre"),
comment = c(ORCID = "0000-0001-8423-1823")),
Expand All @@ -11,9 +11,9 @@ Description: Collection of tools for assessment of feature importance and featur
Key functions are:
feature_importance() for assessment of global level feature importance,
ceteris_paribus() for calculation of the what-if plots,
partial_dependency() for partial dependency plots,
conditional_dependency() for conditional dependency plots,
accumulated_dependency() for accumulated local effects plots,
partial_dependence() for partial dependence plots,
conditional_dependence() for conditional dependence plots,
accumulated_dependence() for accumulated local effects plots,
aggregate_profiles() and cluster_profiles() for aggregation of ceteris paribus profiles,
generic print() and plot() for better usability of selected explainers,
generic plotD3() for interactive, D3 based explanations, and
Expand Down
22 changes: 13 additions & 9 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
# Generated by roxygen2: do not edit by hand

S3method(accumulated_dependency,ceteris_paribus_explainer)
S3method(accumulated_dependency,default)
S3method(accumulated_dependency,explainer)
S3method(accumulated_dependence,ceteris_paribus_explainer)
S3method(accumulated_dependence,default)
S3method(accumulated_dependence,explainer)
S3method(ceteris_paribus,default)
S3method(ceteris_paribus,explainer)
S3method(conditional_dependency,ceteris_paribus_explainer)
S3method(conditional_dependency,default)
S3method(conditional_dependency,explainer)
S3method(conditional_dependence,ceteris_paribus_explainer)
S3method(conditional_dependence,default)
S3method(conditional_dependence,explainer)
S3method(describe,ceteris_paribus_explainer)
S3method(describe,feature_importance_explainer)
S3method(describe,partial_dependence_explainer)
S3method(describe,partial_dependency_explainer)
S3method(feature_importance,default)
S3method(feature_importance,explainer)
S3method(partial_dependency,ceteris_paribus_explainer)
S3method(partial_dependency,default)
S3method(partial_dependency,explainer)
S3method(partial_dependence,ceteris_paribus_explainer)
S3method(partial_dependence,default)
S3method(partial_dependence,explainer)
S3method(plot,aggregated_profiles_explainer)
S3method(plot,ceteris_paribus_2d_explainer)
S3method(plot,ceteris_paribus_explainer)
Expand All @@ -29,16 +30,19 @@ S3method(print,ceteris_paribus_explainer)
S3method(print,feature_importance_explainer)
S3method(select_neighbours,default)
S3method(select_sample,default)
export(accumulated_dependence)
export(accumulated_dependency)
export(aggregate_profiles)
export(calculate_oscillations)
export(ceteris_paribus)
export(ceteris_paribus_2d)
export(cluster_profiles)
export(conditional_dependence)
export(conditional_dependency)
export(describe)
export(feature_importance)
export(local_dependency)
export(partial_dependence)
export(partial_dependency)
export(plotD3)
export(select_neighbours)
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
ingredients 1.0
---------------------------------------------------------------
* change `dependency` to `dependence` [#103](https://github.com/ModelOriented/ingredients/issues/103)

ingredients 0.5.2
---------------------------------------------------------------
* `ceteris_paribus` profiles are now working for categorical variables
Expand Down
35 changes: 19 additions & 16 deletions R/accumulated_dependency.R → R/accumulated_dependence.R
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#' Accumulated Local Effects Profiles aka ALEPlots
#'
#' Accumulated Local Effects Profiles accumulate local changes in Ceteris Paribus Profiles.
#' Function \code{\link{accumulated_dependency}} calls \code{\link{ceteris_paribus}} and then \code{\link{aggregate_profiles}}.
#' Function \code{\link{accumulated_dependence}} calls \code{\link{ceteris_paribus}} and then \code{\link{aggregate_profiles}}.
#'
#' Find more details in the \href{https://pbiecek.github.io/ema/accumulatedLocalProfiles.html}{Accumulated Local Dependency Chapter}.
#' Find more details in the \href{https://pbiecek.github.io/ema/accumulatedLocalProfiles.html}{Accumulated Local Dependence Chapter}.
#'
#' @param x an explainer created with function \code{DALEX::explain()}, an object of the class \code{ceteris_paribus_explainer}
#' or a model to be explained.
Expand All @@ -13,7 +13,7 @@
#' @param variables names of variables for which profiles shall be calculated.
#' Will be passed to \code{\link{calculate_variable_split}}.
#' If \code{NULL} then all variables from the validation data will be used.
#' @param N number of observations used for calculation of partial dependency profiles.
#' @param N number of observations used for calculation of partial dependence profiles.
#' By default, 500 observations will be chosen randomly.
#' @param ... other parameters
#' @param variable_splits named list of splits for variables, in most cases created with \code{\link{calculate_variable_split}}.
Expand All @@ -39,7 +39,7 @@
#' y = titanic_imputed[,8],
#' verbose = FALSE)
#'
#' adp_glm <- accumulated_dependency(explain_titanic_glm,
#' adp_glm <- accumulated_dependence(explain_titanic_glm,
#' N = 150, variables = c("age", "fare"))
#' head(adp_glm)
#' plot(adp_glm)
Expand All @@ -54,21 +54,21 @@
#' y = titanic_imputed[,8],
#' verbose = FALSE)
#'
#' adp_rf <- accumulated_dependency(explain_titanic_rf, N = 200, variable_type = "numerical")
#' adp_rf <- accumulated_dependence(explain_titanic_rf, N = 200, variable_type = "numerical")
#' plot(adp_rf)
#'
#' adp_rf <- accumulated_dependency(explain_titanic_rf, N = 200, variable_type = "categorical")
#' adp_rf <- accumulated_dependence(explain_titanic_rf, N = 200, variable_type = "categorical")
#' plotD3(adp_rf, label_margin = 80, scale_plot = TRUE)
#' }
#'
#' @export
#' @rdname accumulated_dependency
accumulated_dependency <- function(x, ...)
UseMethod("accumulated_dependency")
#' @rdname accumulated_dependence
accumulated_dependence <- function(x, ...)
UseMethod("accumulated_dependence")

#' @export
#' @rdname accumulated_dependency
accumulated_dependency.explainer <- function(x,
#' @rdname accumulated_dependence
accumulated_dependence.explainer <- function(x,
variables = NULL,
N = 500,
variable_splits = NULL,
Expand All @@ -81,7 +81,7 @@ accumulated_dependency.explainer <- function(x,
predict_function <- x$predict_function
label <- x$label

accumulated_dependency.default(x = model,
accumulated_dependence.default(x = model,
data = data,
predict_function = predict_function,
label = label,
Expand All @@ -94,8 +94,8 @@ accumulated_dependency.explainer <- function(x,


#' @export
#' @rdname accumulated_dependency
accumulated_dependency.default <- function(x,
#' @rdname accumulated_dependence
accumulated_dependence.default <- function(x,
data,
predict_function = predict,
label = class(x)[1],
Expand Down Expand Up @@ -127,10 +127,13 @@ accumulated_dependency.default <- function(x,


#' @export
#' @rdname accumulated_dependency
accumulated_dependency.ceteris_paribus_explainer <- function(x, ...,
#' @rdname accumulated_dependence
accumulated_dependence.ceteris_paribus_explainer <- function(x, ...,
variables = NULL) {

aggregate_profiles(x, ..., type = "accumulated", variables = variables)
}

# Backward-compatibility alias: binds the old `accumulated_dependency` name to
# the renamed generic. Because this is a copy of the generic itself, calls made
# through the old name still dispatch via UseMethod("accumulated_dependence").
#' @export
#' @rdname accumulated_dependence
accumulated_dependency <- accumulated_dependence
12 changes: 6 additions & 6 deletions R/aggregate_profiles.R
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#' Aggregates Ceteris Paribus Profiles
#'
#' The function \code{aggregate_profiles()} calculates an aggregate of ceteris paribus profiles.
#' It can be: Partial Dependency Profile (average across Ceteris Paribus Profiles),
#' Conditional Dependency Profile (local weighted average across Ceteris Paribus Profiles) or
#' Accumulated Local Dependency Profile (cumulated average local changes in Ceteris Paribus Profiles).
#' It can be: Partial Dependence Profile (average across Ceteris Paribus Profiles),
#' Conditional Dependence Profile (local weighted average across Ceteris Paribus Profiles) or
#' Accumulated Local Dependence Profile (cumulated average local changes in Ceteris Paribus Profiles).
#'
#' @param x a ceteris paribus explainer produced with function \code{ceteris_paribus()}
#' @param ... other explainers that shall be calculated together
Expand Down Expand Up @@ -170,17 +170,17 @@ aggregate_profiles <- function(x, ...,
if (type == "partial") {
aggregated_profiles <- aggregated_profiles_partial(all_profiles, groups)
class(aggregated_profiles) <- c("aggregated_profiles_explainer",
"partial_dependency_explainer", "data.frame")
"partial_dependence_explainer", "data.frame")
}
if (type == "conditional") {
aggregated_profiles <- aggregated_profiles_conditional(all_profiles, groups, span = span)
class(aggregated_profiles) <- c("aggregated_profiles_explainer",
"conditional_dependency_explainer", "data.frame")
"conditional_dependence_explainer", "data.frame")
}
if (type == "accumulated") {
aggregated_profiles <- aggregated_profiles_accumulated(all_profiles, groups, span = span, center = center)
class(aggregated_profiles) <- c("aggregated_profiles_explainer",
"accumulated_dependency_explainer", "data.frame")
"accumulated_dependence_explainer", "data.frame")
}

# calculate mean(all observation's _yhat_), mean of prediction
Expand Down
46 changes: 25 additions & 21 deletions R/conditional_dependency.R → R/conditional_dependence.R
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#' Conditional Dependency Profiles
#' Conditional Dependence Profiles
#'
#' Conditional Dependency Profiles (aka Local Profiles) average Ceteris Paribus Profiles locally.
#' Function 'conditional_dependency' calls 'ceteris_paribus' and then 'aggregate_profiles'.
#' Conditional Dependence Profiles (aka Local Profiles) average Ceteris Paribus Profiles locally.
#' Function 'conditional_dependence' calls 'ceteris_paribus' and then 'aggregate_profiles'.
#'
#' Find more details in the \href{https://pbiecek.github.io/ema/accumulatedLocalProfiles.html}{Accumulated Local Dependency Chapter}.
#' Find more details in the \href{https://pbiecek.github.io/ema/accumulatedLocalProfiles.html}{Accumulated Local Dependence Chapter}.
#'
#' @param x an explainer created with function \code{DALEX::explain()}, an object of the class \code{ceteris_paribus_explainer}
#' or a model to be explained.
Expand All @@ -12,7 +12,7 @@
#' @param predict_function predict function, will be extracted from \code{x} if it's an explainer
#' @param variables names of variables for which profiles shall be calculated.
#' Will be passed to \code{\link{calculate_variable_split}}. If \code{NULL} then all variables from the validation data will be used.
#' @param N number of observations used for calculation of partial dependency profiles. By default 500.
#' @param N number of observations used for calculation of partial dependence profiles. By default 500.
#' @param ... other parameters
#' @param variable_splits named list of splits for variables, in most cases created with \code{\link{calculate_variable_split}}.
#' If \code{NULL} then it will be calculated based on validation data available in the \code{explainer}.
Expand All @@ -36,7 +36,7 @@
#' y = titanic_imputed[,8],
#' verbose = FALSE)
#'
#' cdp_glm <- conditional_dependency(explain_titanic_glm,
#' cdp_glm <- conditional_dependence(explain_titanic_glm,
#' N = 150, variables = c("age", "fare"))
#' head(cdp_glm)
#' plot(cdp_glm)
Expand All @@ -51,21 +51,21 @@
#' y = titanic_imputed[,8],
#' verbose = FALSE)
#'
#' cdp_rf <- conditional_dependency(explain_titanic_rf, N = 200, variable_type = "numerical")
#' cdp_rf <- conditional_dependence(explain_titanic_rf, N = 200, variable_type = "numerical")
#' plot(cdp_rf)
#'
#' cdp_rf <- conditional_dependency(explain_titanic_rf, N = 200, variable_type = "categorical")
#' cdp_rf <- conditional_dependence(explain_titanic_rf, N = 200, variable_type = "categorical")
#' plotD3(cdp_rf, label_margin = 80, scale_plot = TRUE)
#' }
#'
#' @export
#' @rdname conditional_dependency
conditional_dependency <- function(x, ...)
UseMethod("conditional_dependency")
#' @rdname conditional_dependence
conditional_dependence <- function(x, ...)
UseMethod("conditional_dependence")

#' @export
#' @rdname conditional_dependency
conditional_dependency.explainer <- function(x,
#' @rdname conditional_dependence
conditional_dependence.explainer <- function(x,
variables = NULL,
N = 500,
variable_splits = NULL,
Expand All @@ -78,7 +78,7 @@ conditional_dependency.explainer <- function(x,
predict_function <- x$predict_function
label <- x$label

conditional_dependency.default(x = model,
conditional_dependence.default(x = model,
data = data,
predict_function = predict_function,
label = label,
Expand All @@ -91,8 +91,8 @@ conditional_dependency.explainer <- function(x,


#' @export
#' @rdname conditional_dependency
conditional_dependency.default <- function(x,
#' @rdname conditional_dependence
conditional_dependence.default <- function(x,
data,
predict_function = predict,
label = class(x)[1],
Expand All @@ -119,18 +119,22 @@ conditional_dependency.default <- function(x,
variable_splits = variable_splits,
label = label, ...)

conditional_dependency.ceteris_paribus_explainer(cp, variables = variables, variable_type = variable_type, ...)
conditional_dependence.ceteris_paribus_explainer(cp, variables = variables, variable_type = variable_type, ...)
}


#' @export
#' @rdname conditional_dependency
conditional_dependency.ceteris_paribus_explainer <- function(x, ...,
#' @rdname conditional_dependence
conditional_dependence.ceteris_paribus_explainer <- function(x, ...,
variables = NULL) {

aggregate_profiles(x, ..., type = "conditional", variables = variables)
}

#' @export
#' @rdname conditional_dependency
local_dependency <- conditional_dependency
#' @rdname conditional_dependence
local_dependency <- conditional_dependence

# Backward-compatibility alias: binds the old `conditional_dependency` name to
# the renamed generic. Because this is a copy of the generic itself, calls made
# through the old name still dispatch via UseMethod("conditional_dependence").
#' @export
#' @rdname conditional_dependence
conditional_dependency <- conditional_dependence
5 changes: 4 additions & 1 deletion R/describe_aggregated_profiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
#'
#' @export
#' @rdname describe
describe.partial_dependency_explainer <- function(x,
describe.partial_dependence_explainer <- function(x,
nonsignificance_treshold = 0.15,
...,
display_values = FALSE,
Expand Down Expand Up @@ -290,3 +290,6 @@ specify_df_aggregated <- function(x, variables, nonsignificance_treshold) {

list("df" = df, "treshold" = treshold)
}

# Backward-compatibility alias: exposes the same describe() method under the
# old class name `partial_dependency_explainer`. Presumably this is for objects
# that still carry the pre-rename class attribute -- TODO confirm with callers.
#' @export
describe.partial_dependency_explainer <- describe.partial_dependence_explainer
Loading

0 comments on commit 857c3f5

Please sign in to comment.