[R-package] add a tree plotting function #6729

Open
wants to merge 9 commits into
base: master
Changes from 8 commits
2 changes: 1 addition & 1 deletion .ci/test-r-package-windows.ps1
@@ -177,7 +177,7 @@ Write-Output "Done installing CMake"

Write-Output "Installing dependencies"
$packages = -join @(
-"c('data.table', 'jsonlite', 'knitr', 'markdown', 'Matrix', 'processx', 'R6', 'RhpcBLASctl', 'testthat'), ",
+"c('data.table', 'DiagrammeR', 'jsonlite', 'knitr', 'markdown', 'Matrix', 'processx', 'R6', 'RhpcBLASctl', 'testthat'), ",
"dependencies = c('Imports', 'Depends', 'LinkingTo')"
)
$params = -join @(
2 changes: 1 addition & 1 deletion .ci/test-r-package.sh
@@ -114,7 +114,7 @@ Rscript --vanilla -e "install.packages('https://cran.r-project.org/src/contrib/A

# Manually install Depends and Imports libraries + 'knitr', 'markdown', 'RhpcBLASctl', 'testthat'
# to avoid a CI-time dependency on devtools (for devtools::install_deps())
-packages="c('data.table', 'jsonlite', 'knitr', 'markdown', 'R6', 'RhpcBLASctl', 'testthat')"
+packages="c('data.table', 'DiagrammeR', 'jsonlite', 'knitr', 'markdown', 'R6', 'RhpcBLASctl', 'testthat')"
compile_from_source="both"
if [[ $OS_NAME == "macos" ]]; then
packages+=", type = 'binary'"
4 changes: 2 additions & 2 deletions .github/workflows/r_package.yml
@@ -230,7 +230,7 @@ jobs:
- name: Install packages
shell: bash
run: |
-RDscript${{ matrix.r_customization }} -e "install.packages(c('R6', 'data.table', 'jsonlite', 'knitr', 'markdown', 'Matrix', 'RhpcBLASctl', 'testthat'), repos = 'https://cran.rstudio.com', Ncpus = parallel::detectCores())"
+RDscript${{ matrix.r_customization }} -e "install.packages(c('R6', 'data.table', 'DiagrammeR', 'jsonlite', 'knitr', 'markdown', 'Matrix', 'RhpcBLASctl', 'testthat'), repos = 'https://cran.rstudio.com', Ncpus = parallel::detectCores())"
sh build-cran-package.sh --r-executable=RD${{ matrix.r_customization }}
RD${{ matrix.r_customization }} CMD INSTALL lightgbm_*.tar.gz || exit 1
- name: Run tests with sanitizers
@@ -295,7 +295,7 @@ jobs:
- name: Install packages and run tests
shell: bash
run: |
-Rscript -e "install.packages(c('R6', 'data.table', 'jsonlite', 'knitr', 'markdown', 'Matrix', 'RhpcBLASctl'), repos = 'https://cran.rstudio.com', Ncpus = parallel::detectCores())"
+Rscript -e "install.packages(c('R6', 'data.table', 'DiagrammeR', 'jsonlite', 'knitr', 'markdown', 'Matrix', 'RhpcBLASctl'), repos = 'https://cran.rstudio.com', Ncpus = parallel::detectCores())"
sh build-cran-package.sh

# 'rchk' isn't run through 'R CMD check', use the approach documented at
2 changes: 1 addition & 1 deletion .github/workflows/static_analysis.yml
@@ -64,7 +64,7 @@ jobs:
- name: Install packages
shell: bash
run: |
-Rscript -e "install.packages(c('R6', 'data.table', 'jsonlite', 'knitr', 'markdown', 'Matrix', 'RhpcBLASctl', 'roxygen2', 'testthat'), repos = 'https://cran.rstudio.com', Ncpus = parallel::detectCores())"
+Rscript -e "install.packages(c('R6', 'data.table', 'DiagrammeR', 'jsonlite', 'knitr', 'markdown', 'Matrix', 'RhpcBLASctl', 'roxygen2', 'testthat'), repos = 'https://cran.rstudio.com', Ncpus = parallel::detectCores())"
sh build-cran-package.sh || exit 1
R CMD INSTALL --with-keep.source lightgbm_*.tar.gz || exit 1
- name: Test documentation
2 changes: 1 addition & 1 deletion .vsts-ci.yml
@@ -392,7 +392,7 @@ jobs:
R_LIB_PATH=~/Rlib
export R_LIBS=${R_LIB_PATH}
mkdir -p ${R_LIB_PATH}
-RDscript -e "install.packages(c('R6', 'data.table', 'jsonlite', 'knitr', 'markdown', 'Matrix', 'RhpcBLASctl'), lib = '${R_LIB_PATH}', dependencies = c('Depends', 'Imports', 'LinkingTo'), repos = 'https://cran.rstudio.com', Ncpus = parallel::detectCores())" || exit 1
+RDscript -e "install.packages(c('R6', 'data.table', 'DiagrammeR', 'jsonlite', 'knitr', 'markdown', 'Matrix', 'RhpcBLASctl'), lib = '${R_LIB_PATH}', dependencies = c('Depends', 'Imports', 'LinkingTo'), repos = 'https://cran.rstudio.com', Ncpus = parallel::detectCores())" || exit 1
sh build-cran-package.sh --r-executable=RD || exit 1
mv lightgbm_${LGB_VER}.tar.gz $(Build.ArtifactStagingDirectory)/lightgbm-${LGB_VER}-r-cran.tar.gz
displayName: 'Build CRAN R-package'
3 changes: 2 additions & 1 deletion R-package/DESCRIPTION
@@ -45,11 +45,12 @@ NeedsCompilation: yes
Biarch: true
VignetteBuilder: knitr
Suggests:
+DiagrammeR,
knitr,
markdown,
processx,
RhpcBLASctl,
-testthat
+testthat,
jameslamb marked this conversation as resolved.
Depends:
R (>= 3.5)
Imports:
1 change: 1 addition & 0 deletions R-package/NAMESPACE
@@ -29,6 +29,7 @@ export(lgb.make_serializable)
export(lgb.model.dt.tree)
export(lgb.plot.importance)
export(lgb.plot.interpretation)
+export(lgb.plot.tree)
export(lgb.restore_handle)
export(lgb.save)
export(lgb.slice.Dataset)
204 changes: 204 additions & 0 deletions R-package/R/lgb.plot.tree.R
@@ -0,0 +1,204 @@
#' @name lgb.plot.tree
#' @title Plot a single LightGBM tree.
#' @description The \code{lgb.plot.tree} function creates a DiagrammeR plot of a single LightGBM tree.
#' @param model a \code{lgb.Booster} object.
#' @param tree an integer specifying the tree to plot. This is 1-based, so e.g. a value of '7' means 'the 7th tree' (tree_index=6 in LightGBM's underlying representation).
#' @param rules a list of rules to replace the split values with feature levels.
Collaborator

I'm not totally convinced about this idea... it should be possible to recover the feature names from the model directly.

But before you remove this... can you please expand this doc and add examples and tests showing what this would look like? Right now, it's hard for me to understand what the content of rules is supposed to be.
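
For illustration, here is a minimal sketch of what `rules` might contain. It is inferred only from how the PR's helper indexes `names(sort(rules[[feature_name]]))`, so the structure is an assumption: a named list with one named integer vector per categorical feature. The feature name `cap_color`, its levels, and the function `split_codes_to_levels()` are hypothetical.

```r
# Hypothetical 'rules' object: one entry per categorical feature, each entry
# mapping level names to the integer codes used when the data was encoded.
rules <- list(
    cap_color = c("brown" = 1L, "red" = 2L, "white" = 3L)
)

# Stand-alone version of the lookup (illustrative name, not part of lightgbm).
# 'split' is a categorical threshold string such as "1||3".
split_codes_to_levels <- function(split, feature_name, rules) {
    lvls <- sort(rules[[feature_name]])
    codes <- strsplit(split, "||", fixed = TRUE)
    codes <- lapply(codes, as.numeric)
    # each code is used as a position in the sorted, named vector
    labels <- lapply(codes, function(idx) names(lvls)[idx])
    return(vapply(labels, paste, character(1L), collapse = "\n"))
}

# a split on codes 1 and 3 is rendered as the level names, one per line
split_codes_to_levels("1||3", "cap_color", rules)
```

Note that this lookup only works if the codes are consecutive and start at 1, which is worth stating in the docs if that is the intended contract.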

#'
#' @return
#' The \code{lgb.plot.tree} function creates a DiagrammeR plot.
#'
#' @details
#' The \code{lgb.plot.tree} function creates a DiagrammeR plot of a single LightGBM tree. The tree is extracted from the model and displayed as a directed graph. The nodes are labelled with the feature, split value, gain, cover and value. The edges are labelled with the decision type and split value.
#'
#' @examples
#' \donttest{
#' # EXAMPLE: use the LightGBM example dataset to build a model with a single tree
#' data(agaricus.train, package = "lightgbm")
#' train <- agaricus.train
#' dtrain <- lgb.Dataset(train$data, label = train$label)
#' data(agaricus.test, package = "lightgbm")
#' test <- agaricus.test
#' dtest <- lgb.Dataset.create.valid(dtrain, test$data, label = test$label)
#' # define model parameters and build a single tree
#' params <- list(
#'     objective = "regression"
#'     , min_data = 1L
#' )
#' valids <- list(test = dtest)
#' model <- lgb.train(
#' params = params,
#' data = dtrain,
#' nrounds = 1L,
#' valids = valids,
#' early_stopping_rounds = 1L
#' )
#' # plot the first tree and compare to the tree table
#' # 'tree' is 1-based here, while tree_index in lgb.model.dt.tree() starts from 0
#' tree_table <- lgb.model.dt.tree(model)
#' lgb.plot.tree(model, 1)
#' }
#'
#' @export
lgb.plot.tree <- function(model = NULL, tree = NULL, rules = NULL) {
Collaborator

Suggested change
-lgb.plot.tree <- function(model = NULL, tree = NULL, rules = NULL) {
+lgb.plot.tree <- function(model, tree, rules = NULL) {

I can't think of any situation where it would be ok for model or tree to be NULL, can you?

If not, let's please require callers to provide values explicitly.
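
A small sketch of the effect, using a stub rather than the PR's function: with no defaults, R itself reports the missing argument as soon as it is evaluated. `plot_tree_stub()` is a hypothetical name and its body is only a placeholder.

```r
# Stub with the suggested signature; the body is a placeholder, not the PR's code.
plot_tree_stub <- function(model, tree, rules = NULL) {
    # force() evaluates each argument, so a missing one fails immediately
    force(model)
    force(tree)
    return(invisible(NULL))
}

# calling without 'model' fails with R's standard "argument is missing" error
msg <- tryCatch(
    plot_tree_stub(tree = 1L)
    , error = function(e) conditionMessage(e)
)
print(msg)
```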

# check model is lgb.Booster
if (!.is_Booster(x = model)) {
stop("lgb.plot.tree: model should be an ", sQuote("lgb.Booster"))
}
# check DiagrammeR is available
if (!requireNamespace("DiagrammeR", quietly = TRUE)) {
stop("lgb.plot.tree: DiagrammeR package is required",
call. = FALSE
)
}
# tree must be numeric
if (!inherits(tree, "numeric")) {
stop("lgb.plot.tree: Has to be an integer numeric")
}
# tree must be integer
if (tree %% 1 != 0) {
stop("lgb.plot.tree: Has to be an integer numeric")
}
Comment on lines +54 to +61
Collaborator

Suggested change
-# tree must be numeric
-if (!inherits(tree, "numeric")) {
-    stop("lgb.plot.tree: Has to be an integer numeric")
-}
-# tree must be integer
-if (tree %% 1 != 0) {
-    stop("lgb.plot.tree: Has to be an integer numeric")
-}
+# tree must be numeric
+tree <- as.integer(tree)
+if (length(tree) != 1L || tree < 1L) {
+    stop("lgb.plot.tree: 'tree' must be a single, positive integer.")
+}

Let's combine these, and make it clear what has to be an integer.
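
One possible variant, sketched here as an alternative rather than as the suggestion itself: since `as.integer(1.5)` silently truncates to `1`, a stricter check could reject non-whole numbers before converting. The helper name `validate_tree_arg()` is hypothetical.

```r
# Hypothetical helper: validate 'tree' in one place and return it as an integer.
validate_tree_arg <- function(tree) {
    is_valid <- is.numeric(tree) &&
        length(tree) == 1L &&
        !is.na(tree) &&
        tree >= 1 &&
        tree %% 1 == 0
    if (!is_valid) {
        stop("lgb.plot.tree: 'tree' must be a single, positive integer.", call. = FALSE)
    }
    return(as.integer(tree))
}

# whole-number doubles and integers are both accepted
validate_tree_arg(7)
validate_tree_arg(7L)
```

The `&&` chain short-circuits, so vectors, `NA`, and non-numeric input are rejected before the arithmetic checks run.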

# extract data.table model structure
modelDT <- lgb.model.dt.tree(model)
# check that tree is less than or equal to the maximum tree index in the model
if (tree > max(modelDT$tree_index) || tree < 1) {
warning("lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model (", max(modelDT$tree_index), "). Got: ", tree, ".")
stop("lgb.plot.tree: Invalid tree number")
Comment on lines +66 to +67
Collaborator

Suggested change
-warning("lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model (", max(modelDT$tree_index), "). Got: ", tree, ".")
-stop("lgb.plot.tree: Invalid tree number")
+stop("lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model (", max(modelDT$tree_index), "). Got: ", tree, ".")

What's the reason for having all of the information in a warning() and then immediately raising an error after? If there isn't a specific reason, then let's please combine these for simplicity and to make the logs easier for users to understand.

}
Collaborator

Please modify this error message so that it has enough information for someone to quickly debug the issue, like the provided value of tree and the number of trees in the model. And please combine it with the other check that the value is `>= 1`.

Something like this:

lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model (125). Got: 181.
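
A minimal sketch of a single range check that produces a message in that shape. The helper `check_tree_in_range()` and its `num_trees` argument are hypothetical; how the number of trees is obtained from the model is left out.

```r
# Hypothetical helper: one check, one error, with both values in the message.
check_tree_in_range <- function(tree, num_trees) {
    if (tree < 1L || tree > num_trees) {
        stop(
            sprintf(
                paste0(
                    "lgb.plot.tree: Value of 'tree' should be between 1 and the total "
                    , "number of trees in the model (%d). Got: %d."
                )
                , as.integer(num_trees)
                , as.integer(tree)
            )
            , call. = FALSE
        )
    }
    return(invisible(tree))
}

# capture the message for an out-of-range request
msg <- tryCatch(
    check_tree_in_range(tree = 181L, num_trees = 125L)
    , error = function(e) conditionMessage(e)
)
print(msg)
```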

# filter modelDT to just the rows for the selected tree
# ('tree' is 1-based, while tree_index is 0-based)
modelDT <- modelDT[tree_index == tree - 1L, ]
# change the column names to shorter more diagram friendly versions
data.table::setnames(modelDT
, old = c("tree_index", "split_feature", "threshold", "split_gain")
, new = c("Tree", "Feature", "Split", "Gain"))
Comment on lines +72 to +74
Collaborator

Suggested change
-data.table::setnames(modelDT
-, old = c("tree_index", "split_feature", "threshold", "split_gain")
-, new = c("Tree", "Feature", "Split", "Gain"))
+data.table::setnames(
+    modelDT
+    , old = c("tree_index", "split_feature", "threshold", "split_gain")
+    , new = c("Tree", "Feature", "Split", "Gain")
+)

Please, follow the style the rest of the project uses. I suspect that the linting configuration here would have caught this (not sure, as I haven't run it myself and it failed in CI for other unrelated reasons).

From this point forward, before you push a commit please run the R-code linting and fix any issues it reports.

From the root of the repo:

Rscript ./.ci/lint-r-code.R ./R-package

# assign leaf_value to the Value column in modelDT
modelDT[, Value := leaf_value]
# assign new values if NA
modelDT[is.na(Value), Value := internal_value]
modelDT[is.na(Gain), Gain := leaf_value]
modelDT[is.na(Feature), Feature := "Leaf"]
# assign internal_count to Cover, and if Feature is "Leaf", assign leaf_count to Cover
modelDT[, Cover := internal_count][Feature == "Leaf", Cover := leaf_count]
# remove unnecessary columns
modelDT[, c("leaf_count", "internal_count", "leaf_value", "internal_value") := NULL]
# assign split_index to Node
modelDT[, Node := split_index]
# find the maximum value of Node, if Node is NA, assign max_node + leaf_index + 1 to Node
max_node <- max(modelDT[["Node"]], na.rm = TRUE)
modelDT[is.na(Node), Node := max_node + leaf_index + 1]
# adding ID column
modelDT[, ID := paste(Tree, Node, sep = "-")]
# remove unnecessary columns
modelDT[, c("depth", "leaf_index") := NULL]
modelDT[, parent := node_parent][is.na(parent), parent := leaf_parent]
modelDT[, c("node_parent", "leaf_parent", "split_index") := NULL]
# assign the IDs of the matching parent nodes to Yes and No
modelDT[, Yes := modelDT$ID[match(modelDT$Node, modelDT$parent)]]
modelDT <- modelDT[nrow(modelDT):1, ]
modelDT[, No := modelDT$ID[match(modelDT$Node, modelDT$parent)]]
# which way do the NA's go (this path will get a thicker arrow)
# for categorical features, NA gets put into the zero group
modelDT[default_left == TRUE, Missing := Yes]
modelDT[default_left == FALSE, Missing := No]
modelDT[.zero_present(Split), Missing := Yes]
# create the label text
modelDT[, label := paste0(
    Feature
    , "\nCover: "
    , Cover
    , ifelse(Feature == "Leaf", "", "\nGain: ")
    , ifelse(Feature == "Leaf", "", round(Gain, 4))
    , "\nValue: "
    , round(Value, 4)
)]
# style the nodes - same format as xgboost
modelDT[Node == 0, label := paste0("Tree ", Tree, "\n", label)]
modelDT[, shape := "rectangle"][Feature == "Leaf", shape := "oval"]
modelDT[, filledcolor := "Beige"][Feature == "Leaf", filledcolor := "Khaki"]
# in order to draw the first tree on top:
modelDT <- modelDT[order(-Tree)]
nodes <- DiagrammeR::create_node_df(
n = nrow(modelDT)
, ID = modelDT$ID
, label = modelDT$label
, fillcolor = modelDT$filledcolor
, shape = modelDT$shape
, data = modelDT$Feature
, fontcolor = "black"
)
# round the edge labels to 4 s.f. if they are numeric
# as otherwise get too many decimal places and the diagram looks bad
# would rather not use suppressWarnings
numeric_idx <- suppressWarnings(!is.na(as.numeric(modelDT[["Split"]])))
modelDT[numeric_idx, Split := round(as.numeric(Split), 4)]
# replace indices with feature levels if rules supplied

if (!is.null(rules)) {
for (f in names(rules)) {
modelDT[Feature == f & decision_type == "==", Split := .levels.to.names(Split, f, rules)]
}
}
# replace long split names with a message
modelDT[nchar(Split) > 500, Split := "Split too long to render"]
# create the edge labels
# (written without the magrittr pipe, which this package does not import)
edges <- DiagrammeR::create_edge_df(
    from = match(rep(modelDT[Feature != "Leaf", ID], 2L), modelDT$ID)
    , to = match(modelDT[Feature != "Leaf", c(Yes, No)], modelDT$ID)
    , label = c(
        modelDT[Feature != "Leaf", paste(decision_type, Split)]
        , rep("", nrow(modelDT[Feature != "Leaf"]))
    )
    , style = c(
        modelDT[Feature != "Leaf", ifelse(Missing == Yes, "bold", "solid")]
        , modelDT[Feature != "Leaf", ifelse(Missing == No, "bold", "solid")]
    )
    , rel = "leading_to"
)
# create the graph
graph <- DiagrammeR::create_graph(
nodes_df = nodes
, edges_df = edges
, attr_theme = NULL
)
graph <- DiagrammeR::add_global_graph_attrs(
graph = graph
, attr_type = "graph"
, attr = c("layout", "rankdir")
, value = c("dot", "LR")
)
graph <- DiagrammeR::add_global_graph_attrs(
graph = graph
, attr_type = "node"
, attr = c("color", "style", "fontname")
, value = c("DimGray", "filled", "Helvetica")
)
graph <- DiagrammeR::add_global_graph_attrs(
graph = graph
, attr_type = "edge"
, attr = c("color", "arrowsize", "arrowhead", "fontname")
, value = c("DimGray", "1.5", "vee", "Helvetica")
)
# render the graph and return the widget, so that it is displayed when printed
return(DiagrammeR::render_graph(graph))
}

# returns TRUE for each element of x whose '||'-separated values include "0"
.zero_present <- function(x) {
    has_zero <- sapply(strsplit(as.character(x), "||", fixed = TRUE), function(el) {
        any(el == "0")
    })
    return(has_zero)
}

.levels.to.names <- function(x, feature_name, rules) {
Collaborator

Suggested change
-.levels.to.names <- function(x, feature_name, rules) {
+.levels_to_names <- function(x, feature_name, rules) {

Please avoid using . in any of these private functions' names.

    lvls <- sort(rules[[feature_name]])
    result <- strsplit(x, "||", fixed = TRUE)
    result <- lapply(result, as.numeric)
    # pass 'lvls' explicitly: it is not visible inside the helper otherwise
    result <- lapply(result, .levels_to_names, lvls = lvls)
    result <- lapply(result, paste, collapse = "\n")
    result <- as.character(result)
    return(result)
}

# map integer level codes to their names in 'lvls' (a sorted, named vector)
.levels_to_names <- function(x, lvls) {
    return(names(lvls)[as.numeric(x)])
}
4 changes: 2 additions & 2 deletions R-package/README.md
@@ -428,7 +428,7 @@ docker run \

# install dependencies
RDscript${R_CUSTOMIZATION} \
--e "install.packages(c('R6', 'data.table', 'jsonlite', 'knitr', 'markdown', 'Matrix', 'RhpcBLASctl', 'testthat'), repos = 'https://cran.r-project.org', Ncpus = parallel::detectCores())"
+-e "install.packages(c('R6', 'data.table', 'DiagrammeR', 'jsonlite', 'knitr', 'markdown', 'Matrix', 'RhpcBLASctl', 'testthat'), repos = 'https://cran.r-project.org', Ncpus = parallel::detectCores())"

# install lightgbm
sh build-cran-package.sh --r-executable=RD${R_CUSTOMIZATION}
@@ -459,7 +459,7 @@ docker run \
-it \
wch1/r-debug

-RDscriptvalgrind -e "install.packages(c('R6', 'data.table', 'jsonlite', 'knitr', 'markdown', 'Matrix', 'RhpcBLASctl', 'testthat'), repos = 'https://cran.rstudio.com', Ncpus = parallel::detectCores())"
+RDscriptvalgrind -e "install.packages(c('R6', 'data.table', 'DiagrammeR', 'jsonlite', 'knitr', 'markdown', 'Matrix', 'RhpcBLASctl', 'testthat'), repos = 'https://cran.rstudio.com', Ncpus = parallel::detectCores())"

sh build-cran-package.sh \
--r-executable=RDvalgrind
55 changes: 55 additions & 0 deletions R-package/man/lgb.plot.tree.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions R-package/pkgdown/_pkgdown.yml
@@ -97,6 +97,7 @@ reference:
- '`lgb.interprete`'
- '`lgb.plot.importance`'
- '`lgb.plot.interpretation`'
+- '`lgb.plot.tree`'
- '`print.lgb.Booster`'
- '`summary.lgb.Booster`'
- title: Multithreading Control