From 38b37dfb3ff1aa3af9e3d0d58128a0c0301dd4e2 Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Wed, 20 Nov 2024 16:37:48 -0500 Subject: [PATCH] update partykit_tree_info() to handle classification outputs --- R/model-partykit.R | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/R/model-partykit.R b/R/model-partykit.R index beac61e..9fa5cda 100644 --- a/R/model-partykit.R +++ b/R/model-partykit.R @@ -1,9 +1,18 @@ partykit_tree_info <- function(model) { model_nodes <- map(seq_along(model), ~ model[[.x]]) is_split <- map_lgl(model_nodes, ~ class(.x$node[1]) == "partynode") - # non-cat model - mean_resp <- map_dbl(model_nodes, ~ mean(.x$fitted[, "(response)"])) - prediction <- ifelse(!is_split, mean_resp, NA) + if (is.numeric(model_nodes[[1]]$fitted[["(response)"]])) { + mean_resp <- map_dbl(model_nodes, ~ mean(.x$fitted[, "(response)"])) + prediction <- ifelse(!is_split, mean_resp, NA) + } else { + stat_mode <- function(x) { + counts <- sort(table(x)) + names(counts)[1] + } + mode_resp <- map_chr(model_nodes, ~ stat_mode(.x$fitted[, "(response)"])) + prediction <- ifelse(!is_split, mode_resp, NA) + } + party_nodes <- map(seq_along(model), ~ partykit::nodeapply(model, .x)) kids <- map(party_nodes, ~ {