Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R-package] add a tree plotting function #6729

Open
wants to merge 9 commits into
base: master
Choose a base branch
from

Conversation

fboudry
Copy link

@fboudry fboudry commented Nov 23, 2024

Feature requested in #1222

Added a R function to plot trees.
Basically used the code posted in #1222 by @SpeckledJim2 and followed the given instruction.

Added DiagrammeR as suggested in DESCRIPTION
Added lgb.plot.tree in _pkgdown.yml
Roxygenized.
@fboudry
Copy link
Author

fboudry commented Nov 23, 2024

@microsoft-github-policy-service agree

@jameslamb jameslamb changed the title R tree plot [R-package] add a tree plotting function Nov 24, 2024
Copy link
Collaborator

@jameslamb jameslamb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your interest in LightGBM.

As I mentioned in the discussion on #1222, I'm supportive of trying to add something like this (especially since xgboost has it as well).

But I hope you'll see from the first round of suggestions I left here... significant work remains before I'd support merging this change into the package. If you are willing to work with us on this and go through multiple rounds of reviews and suggestions, we'd be grateful for the help! But if you don't have the time/interest to get this ready for inclusion in the package, please let me know and we'll close this PR and leave #1222 open for someone else to pick up.

R-package/DESCRIPTION Outdated Show resolved Hide resolved
@@ -0,0 +1,184 @@
#' @name lgb.plot.tree
#' @title Plot a single LightGBM tree using DiagrammeR.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#' @title Plot a single LightGBM tree using DiagrammeR.
#' @title Plot a single LightGBM tree.

Let's simplify this, please.

Comment on lines 45 to 46

# function to plot a single LightGBM tree using DiagrammeR
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# function to plot a single LightGBM tree using DiagrammeR

We do not need to repeat in a comment here the same information that's already in the roxygen comments.

Comment on lines 49 to 51
if (!inherits(model, "lgb.Booster")) {
stop("model: Has to be an object of class lgb.Booster")
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (!inherits(model, "lgb.Booster")) {
stop("model: Has to be an object of class lgb.Booster")
}
if (!.is_Booster(x = model)) {
stop("lgb.plot.tree: model should be an ", sQuote("lgb.Booster"))
}

Please follow the patterns used elsewhere in the library for this:

if (!.is_Booster(x = model)) {
stop("lgb.restore_handle: model should be an ", sQuote("lgb.Booster"))
}

stop("tree: Has to be an integer numeric")
}
# extract data.table model structure
dt <- lgb.model.dt.tree(model)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
dt <- lgb.model.dt.tree(model)
modelDT <- lgb.model.dt.tree(model)

Please don't use the name dt. That is a function in the {stats} package (for finding the density of a t-distribution)... try ?dt to see that.

Shadowing names from the standard library can lead to confusing errors. Please use modelDT as the name for this data.table instead.

nodes_df = nodes,
edges_df = edges,
attr_theme = NULL
) %>%
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this project, by convention we:

  • do not use the %>% operator
  • use comma-first style everywhere

Please update this code and all the other code you're adding to follow that. Keeping all of the code looking the same across the codebase helps us to develop and review changes.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

xgboost's implementation of similar functionality might be useful as a reference. See https://github.com/dmlc/xgboost/blob/e988b7cf1515b08ad0f949c26beb043ce0b33fe8/R-package/R/xgb.plot.tree.R#L159-L181

@@ -0,0 +1,59 @@
test_that("lgb.plot.tree works as expected"){
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please also add tests for the other types of machine learning tasks LightGBM can be used for:

  • binary classification
  • multiclass classification (where, please note, there are num_classes trees produced per iteration)
  • learning-to-rank

And for the following model situations:

  • uses categorical features

These are all cases that could affect the code as written... for example, categorical features have different splitting rules.

Comment on lines 76 to 77
dt[, Value := 0.0]
dt[, Value := leaf_value]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
dt[, Value := 0.0]
dt[, Value := leaf_value]
dt[, Value := leaf_value]

I don't understand this... what's the purpose of setting all rows to 0.0 and then immediately overwriting them? It seems to me that the 0.0 could probably be removed.

Comment on lines 78 to 96
dt[is.na(Value), Value := internal_value]
dt[is.na(Gain), Gain := leaf_value]
dt[is.na(Feature), Feature := "Leaf"]
dt[, Cover := internal_count][Feature == "Leaf", Cover := leaf_count]
dt[, c("leaf_count", "internal_count", "leaf_value", "internal_value") := NULL]
dt[, Node := split_index]
max_node <- max(dt[["Node"]], na.rm = TRUE)
dt[is.na(Node), Node := max_node + leaf_index + 1]
dt[, ID := paste(Tree, Node, sep = "-")]
dt[, c("depth", "leaf_index") := NULL]
dt[, parent := node_parent][is.na(parent), parent := leaf_parent]
dt[, c("node_parent", "leaf_parent", "split_index") := NULL]
dt[, Yes := dt$ID[match(dt$Node, dt$parent)]]
dt <- dt[nrow(dt):1, ]
dt[, No := dt$ID[match(dt$Node, dt$parent)]]
# which way do the NA's go (this path will get a thicker arrow)
# for categorical features, NA gets put into the zero group
dt[default_left == TRUE, Missing := Yes]
dt[default_left == FALSE, Missing := No]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please add some comments to make it a bit easier to understand what's happening in this wall of code? It's very difficult to read (at least for me) as currently written).

# trees start from 0 in lgb.model.dt.tree
tree_table <- lgb.model.dt.tree(model)
expect_error({
lgb.plot.tree(model, 999)TRUE
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
lgb.plot.tree(model, 999)TRUE
lgb.plot.tree(model, 999)

This looks like it was included accidentally?

DiagrammeR in CI.
Error messages.
Default parameters.
Changed tests.
…ree.R)

Now tests regressions, binary, multiclass classification and ranks.
Added a warning to functions and shorter stop message to make tests work.
@fboudry
Copy link
Author

fboudry commented Dec 31, 2024

Thanks for the review @jameslamb, helped me a lot!
I think I've made all suggested changes, however the checks are failing on the "R CMD check". When run locally I also have an error about the Description file (missing maintainer field). I don't really know what I'm supposed to do with that/how you want to handle it in this package so I'll take any suggestion!

R-package/DESCRIPTION Outdated Show resolved Hide resolved
Copy link
Collaborator

@jameslamb jameslamb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I left a few more comments for your consideration. To be clear, I still haven't very-thoroughly reviewed so this is not a comprehensive list... these are just quick things I noticed in the few minutes I had to review.

In addition to those... it'd be helpful if you could include some screenshots of what the plots look like, in the description of the PR. That'll really help me understand what the goal is here, without needing to run this code myself.

#' }
#'
#' @export
lgb.plot.tree <- function(model = NULL, tree = NULL, rules = NULL) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
lgb.plot.tree <- function(model = NULL, tree = NULL, rules = NULL) {
lgb.plot.tree <- function(model, tree, rules = NULL) {

I can't think of any situation where it would be ok for model or tree to be NULL, can you?

If not, let's please require callers to provide values explicitly.

Comment on lines +54 to +61
# tree must be numeric
if (!inherits(tree, "numeric")) {
stop("lgb.plot.tree: Has to be an integer numeric")
}
# tree must be integer
if (tree %% 1 != 0) {
stop("lgb.plot.tree: Has to be an integer numeric")
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# tree must be numeric
if (!inherits(tree, "numeric")) {
stop("lgb.plot.tree: Has to be an integer numeric")
}
# tree must be integer
if (tree %% 1 != 0) {
stop("lgb.plot.tree: Has to be an integer numeric")
}
# tree must be numeric
tree <- as.integer(tree)
if (length(tree) != 1L || tree < 1L) {
stop(sprintf("lgb.plot.tree: 'tree' must be a single, positive integer.)
}

Let's combine these, and make it clear what has to be an integer.

Comment on lines +66 to +67
warning("lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model (", max(modelDT$tree_index), "). Got: ", tree, ".")
stop("lgb.plot.tree: Invalid tree number")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
warning("lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model (", max(modelDT$tree_index), "). Got: ", tree, ".")
stop("lgb.plot.tree: Invalid tree number")
stop("lgb.plot.tree: Value of 'tree' should be between 1 and the total number of trees in the model (", max(modelDT$tree_index), "). Got: ", tree, ".")

What's the reason for having all of the information in a warning() and then immediately raising an error after? If there isn't a specific reason, then let's please combine these for simplicity and to make the logs easier for users to understand.

return(invisible(NULL))
}

.levels.to.names <- function(x, feature_name, rules) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
.levels.to.names <- function(x, feature_name, rules) {
.levels_to_names <- function(x, feature_name, rules) {

Please avoid using . in any of these private functions' names.

#' @description The \code{lgb.plot.tree} function creates a DiagrammeR plot of a single LightGBM tree.
#' @param model a \code{lgb.Booster} object.
#' @param tree an integer specifying the tree to plot. This is 1-based, so e.g. a value of '7' means 'the 7th tree' (tree_index=6 in LightGBM's underlying representation).
#' @param rules a list of rules to replace the split values with feature levels.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not totally convinced about this idea... it should be possible to recover the feature names from the model directly.

But before you remove this... can you please expand this doc and add examples and tests showing what this would look like? Right now, it's hard for me to understand what the content of rules is supposed to be.

Comment on lines +72 to +74
data.table::setnames(modelDT
, old = c("tree_index", "split_feature", "threshold", "split_gain")
, new = c("Tree", "Feature", "Split", "Gain"))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
data.table::setnames(modelDT
, old = c("tree_index", "split_feature", "threshold", "split_gain")
, new = c("Tree", "Feature", "Split", "Gain"))
data.table::setnames(
modelDT
, old = c("tree_index", "split_feature", "threshold", "split_gain")
, new = c("Tree", "Feature", "Split", "Gain")
)

Please, follow the style the rest of the project uses. I suspect that the linting configuration here would have caught this (not sure, as I haven't run it myself and it failed in CI for other unrelated reasons).

From this point forward, before you push a commit please run the R-code linting and fix any issues it reports.

From the root of the repo:

Rscript ./.ci/lint-r-code.R ./R-package

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants