Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R-package] add a tree plotting function #6729

Open
wants to merge 2 commits into
base: master
Choose a base branch
from

Conversation

fboudry
Copy link

@fboudry fboudry commented Nov 23, 2024

Feature requested in #1222

Added a R function to plot trees.
Basically used the code posted in #1222 by @SpeckledJim2 and followed the given instruction.

Added DiagrammeR as suggested in DESCRIPTION
Added lgb.plot.tree in _pkgdown.yml
Roxygenized.
@fboudry
Copy link
Author

fboudry commented Nov 23, 2024

@microsoft-github-policy-service agree

@jameslamb jameslamb changed the title R tree plot [R-package] add a tree plotting function Nov 24, 2024
Copy link
Collaborator

@jameslamb jameslamb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your interest in LightGBM.

As I mentioned in the discussion on #1222, I'm supportive of trying to add something like this (especially since xgboost has it as well).

But I hope you'll see from the first round of suggestions I left here... significant work remains before I'd support merging this change into the package. If you are willing to work with us on this and go through multiple rounds of reviews and suggestions, we'd be grateful for the help! But if you don't have the time/interest to get this ready for inclusion in the package, please let me know and we'll close this PR and leave #1222 open for someone else to pick up.

@@ -49,7 +49,8 @@ Suggests:
markdown,
processx,
RhpcBLASctl,
testthat
testthat,
DiagrammeR
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please keep this list in alphabetical order (move DiagrammeR to the top of this list).

You will also need to add DiagrammeR to every place in continuous integration scripts that installs optional dependencies for the project. You can find those like this:

git grep processx

@@ -0,0 +1,184 @@
#' @name lgb.plot.tree
#' @title Plot a single LightGBM tree using DiagrammeR.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#' @title Plot a single LightGBM tree using DiagrammeR.
#' @title Plot a single LightGBM tree.

Let's simplify this, please.

Comment on lines +45 to +46

# function to plot a single LightGBM tree using DiagrammeR
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# function to plot a single LightGBM tree using DiagrammeR

We do not need to repeat in a comment here the same information that's already in the roxygen comments.

Comment on lines +49 to +51
if (!inherits(model, "lgb.Booster")) {
stop("model: Has to be an object of class lgb.Booster")
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (!inherits(model, "lgb.Booster")) {
stop("model: Has to be an object of class lgb.Booster")
}
if (!.is_Booster(x = model)) {
stop("lgb.plot.tree: model should be an ", sQuote("lgb.Booster"))
}

Please follow the patterns used elsewhere in the library for this:

if (!.is_Booster(x = model)) {
stop("lgb.restore_handle: model should be an ", sQuote("lgb.Booster"))
}

stop("tree: Has to be an integer numeric")
}
# extract data.table model structure
dt <- lgb.model.dt.tree(model)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
dt <- lgb.model.dt.tree(model)
modelDT <- lgb.model.dt.tree(model)

Please don't use the name dt. That is a function in the {stats} package (for finding the density of a t-distribution)... try ?dt to see that.

Shadowing names from the standard library can lead to confusing errors. Please use modelDT as the name for this data.table instead.

nodes_df = nodes,
edges_df = edges,
attr_theme = NULL
) %>%
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this project, by convention we:

  • do not use the %>% operator
  • use comma-first style everywhere

Please update this code and all the other code you're adding to follow that. Keeping all of the code looking the same across the codebase helps us to develop and review changes.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

xgboost's implementation of similar functionality might be useful as a reference. See https://github.com/dmlc/xgboost/blob/e988b7cf1515b08ad0f949c26beb043ce0b33fe8/R-package/R/xgb.plot.tree.R#L159-L181

@@ -0,0 +1,59 @@
test_that("lgb.plot.tree works as expected"){
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please also add tests for the other types of machine learning tasks LightGBM can be used for:

  • binary classification
  • multiclass classification (where, please note, there are num_classes trees produced per iteration)
  • learning-to-rank

And for the following model situations:

  • uses categorical features

These are all cases that could affect the code as written... for example, categorical features have different splitting rules.

Comment on lines +76 to +77
dt[, Value := 0.0]
dt[, Value := leaf_value]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
dt[, Value := 0.0]
dt[, Value := leaf_value]
dt[, Value := leaf_value]

I don't understand this... what's the purpose of setting all rows to 0.0 and then immediately overwriting them? It seems to me that the 0.0 could probably be removed.

Comment on lines +78 to +96
dt[is.na(Value), Value := internal_value]
dt[is.na(Gain), Gain := leaf_value]
dt[is.na(Feature), Feature := "Leaf"]
dt[, Cover := internal_count][Feature == "Leaf", Cover := leaf_count]
dt[, c("leaf_count", "internal_count", "leaf_value", "internal_value") := NULL]
dt[, Node := split_index]
max_node <- max(dt[["Node"]], na.rm = TRUE)
dt[is.na(Node), Node := max_node + leaf_index + 1]
dt[, ID := paste(Tree, Node, sep = "-")]
dt[, c("depth", "leaf_index") := NULL]
dt[, parent := node_parent][is.na(parent), parent := leaf_parent]
dt[, c("node_parent", "leaf_parent", "split_index") := NULL]
dt[, Yes := dt$ID[match(dt$Node, dt$parent)]]
dt <- dt[nrow(dt):1, ]
dt[, No := dt$ID[match(dt$Node, dt$parent)]]
# which way do the NA's go (this path will get a thicker arrow)
# for categorical features, NA gets put into the zero group
dt[default_left == TRUE, Missing := Yes]
dt[default_left == FALSE, Missing := No]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please add some comments to make it a bit easier to understand what's happening in this wall of code? It's very difficult to read (at least for me) as currently written).

# trees start from 0 in lgb.model.dt.tree
tree_table <- lgb.model.dt.tree(model)
expect_error({
lgb.plot.tree(model, 999)TRUE
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
lgb.plot.tree(model, 999)TRUE
lgb.plot.tree(model, 999)

This looks like it was included accidentally?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants