diff --git a/R/model-partykit.R b/R/model-partykit.R index beac61e..9fa5cda 100644 --- a/R/model-partykit.R +++ b/R/model-partykit.R @@ -1,9 +1,18 @@ partykit_tree_info <- function(model) { model_nodes <- map(seq_along(model), ~ model[[.x]]) is_split <- map_lgl(model_nodes, ~ class(.x$node[1]) == "partynode") - # non-cat model - mean_resp <- map_dbl(model_nodes, ~ mean(.x$fitted[, "(response)"])) - prediction <- ifelse(!is_split, mean_resp, NA) + if (is.numeric(model_nodes[[1]]$fitted[["(response)"]])) { + mean_resp <- map_dbl(model_nodes, ~ mean(.x$fitted[, "(response)"])) + prediction <- ifelse(!is_split, mean_resp, NA) + } else { + stat_mode <- function(x) { + counts <- sort(table(x)) + names(counts)[1] + } + mode_resp <- map_chr(model_nodes, ~ stat_mode(.x$fitted[, "(response)"])) + prediction <- ifelse(!is_split, mode_resp, NA) + } + party_nodes <- map(seq_along(model), ~ partykit::nodeapply(model, .x)) kids <- map(party_nodes, ~ {