Skip to content

Commit

Permalink
feat: try to use custom field for key instead of col roles to make global forecasting work
Browse files Browse the repository at this point in the history
  • Loading branch information
m-muecke committed Jan 3, 2025
1 parent d1f3df2 commit f19594e
Show file tree
Hide file tree
Showing 9 changed files with 195 additions and 37 deletions.
6 changes: 5 additions & 1 deletion R/Forecaster.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,11 @@ Forecaster = R6::R6Class("Forecaster",
lag = self$lag
nms = sprintf("%s_lag_%s", target, lag)
dt = copy(dt)
dt[, (nms) := shift(.SD, n = lag, type = "lag"), .SDcols = target]
if (is.null(private$.task$key)) {
dt[, (nms) := shift(.SD, n = lag, type = "lag"), .SDcols = target]
} else {
dt[, (nms) := shift(.SD, n = lag, type = "lag"), by = get(private$.task$key), .SDcols = target]
}
dt
},

Expand Down
50 changes: 48 additions & 2 deletions R/ResamplingForecastHoldout.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,20 +85,66 @@ ResamplingForecastHoldout = R6Class("ResamplingForecastHoldout",
nr = max(n_obs + n, 0L)
}

ids = sort(ids)
ii = ids[1:nr]
list(train = ii, test = ids[(nr + 1L):n_obs])
},

# Holdout split that also supports tasks with a `group` column role.
#
# Exactly one of the parameters `ratio` and `n` must be set. From them the
# number of leading observations used for training (`nr`) is derived:
#   * `ratio`: `round(task$nrow * ratio)`; if the task has a group column,
#     this is divided by `length(ids)` and floored to get a per-group count.
#     NOTE(review): this assumes `ids` holds one entry per group in the
#     grouped case and that all groups have the same size -- confirm against
#     the caller.
#   * `n > 0`: the first `n` observations, capped at `task$nrow`.
#   * `n <= 0`: all but the last `-n` observations, never fewer than 0.
#
# Arguments:
#   ids   Ids to split; sorted ascending in the ungrouped case.
#   task  Task providing `nrow`, `row_ids`, `col_roles` and `backend`.
#   ...   Ignored.
#
# Returns a list with the elements `train` and `test`.
.sample_new = function(ids, task, ...) {
  pars = self$param_set$get_values()
  ratio = pars$ratio
  n = pars$n
  n_obs = task$nrow

  has_ratio = !is.null(ratio)
  if (!xor(!has_ratio, is.null(n))) {
    stopf("Either parameter `ratio` (x)or `n` must be provided.")
  }

  group_cols = task$col_roles$group
  has_group = length(group_cols) > 0L

  if (has_ratio) {
    nr = round(n_obs * ratio)
    if (has_group) {
      # `nr` becomes the number of training rows per group
      nr = floor(nr / length(ids))
    }
  } else if (n > 0L) {
    # NOTE(review): for grouped tasks `n` is applied per group below; this
    # should be documented for group usage.
    nr = min(n_obs, n)
  } else {
    nr = max(n_obs + n, 0L)
  }

  if (has_group) {
    tab = task$backend$data(rows = task$row_ids, cols = c(task$backend$primary_key, group_cols))
    setnames(tab, c("row_id", "group"))
    # Assumes the rows within each group are already in the intended order.
    # Logical masks are used instead of `1:nr` / `(nr + 1L):.N` so that
    # `nr == 0` and `nr >= .N` give empty sets rather than stray or NA rows.
    return(list(
      train = tab[, .SD[seq_len(.N) <= nr], by = group][, row_id],
      test = tab[, .SD[seq_len(.N) > nr], by = group][, row_id]
    ))
  }

  # TODO: check whether the split should follow the task's `order` column
  # role instead of the sorted ids.
  ids = sort(ids)
  is_train = seq_along(ids) <= nr
  list(train = ids[is_train], test = ids[!is_train])
},
Expand Down
8 changes: 7 additions & 1 deletion R/TaskFcst.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,19 @@
TaskFcst = R6::R6Class("TaskFcst",
inherit = TaskRegr,
public = list(
#' @field key (`character(1)` | `NULL`)\cr
#' Name of the column that identifies the individual time series, or `NULL` if no key is set.
key = NULL,

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#' The function [as_task_fcst()] provides an alternative way to construct forecast tasks.
#'
#' @template param_target
#' @template param_label
#' @template param_extra_args
initialize = function(id, backend, target, label = NA_character_, extra_args = list()) { # nolint
#' @param key (`character(1)` | `NULL`)\cr
#' Name of the column that identifies the individual time series. Defaults to `NULL` (no key).
initialize = function(id, backend, target, key = NULL, label = NA_character_, extra_args = list()) { # nolint
assert_string(target)

super$initialize(
Expand All @@ -42,6 +47,7 @@ TaskFcst = R6::R6Class("TaskFcst",
label = label,
extra_args = extra_args
)
self$key = key
},

#' @description
Expand Down
9 changes: 5 additions & 4 deletions R/as_task_fcst.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ as_task_fcst.TaskFcst = function(x, clone = FALSE, ...) {

#' @rdname as_task_fcst
#' @export
as_task_fcst.DataBackend = function(x, target = NULL, index = NULL, id = deparse1(substitute(x)), label = NA_character_, ...) { # nolint
as_task_fcst.DataBackend = function(x, target = NULL, index = NULL, key = NULL, id = deparse1(substitute(x)), label = NA_character_, ...) { # nolint
force(id)

assert_choice(target, x$colnames)
assert_choice(index, x$colnames)

task = TaskFcst$new(
id = id, backend = x, target = target, target = target, label = label, ...
id = id, backend = x, target = target, target = target, key = key, label = label, ...
)
task$col_roles$order = index
task
Expand All @@ -49,19 +49,20 @@ as_task_fcst.DataBackend = function(x, target = NULL, index = NULL, id = deparse
#' Name of the column in the data containing the index.
#' @template param_label
#' @export
as_task_fcst.data.frame = function(x, target = NULL, index = NULL, id = deparse1(substitute(x)), label = NA_character_, ...) { # nolint
as_task_fcst.data.frame = function(x, target = NULL, index = NULL, key = NULL, id = deparse1(substitute(x)), label = NA_character_, ...) { # nolint
# NOTE(review): the two signature lines above appear to be the before/after
# versions shown by the diff view; they differ only in the added `key = NULL`.
# Force the lazily evaluated default of `id` (`deparse1(substitute(x))`) up front.
force(id)

# Validate the input: a non-empty data frame with unique column names whose
# columns include `target` and `index`; `key` is optional (`NULL` allowed),
# but must name a column when given.
assert_data_frame(x, min.rows = 1L, min.cols = 1L, col.names = "unique")
assert_choice(target, names(x))
assert_choice(index, names(x))
assert_choice(key, names(x), null.ok = TRUE)

# Warn (without failing) if any double column contains infinite values.
ii = which(map_lgl(keep(x, is.double), anyInfinite))
if (length(ii)) {
warningf("Detected columns with unsupported Inf values in data: %s", str_collapse(names(ii)))
}

# NOTE(review): the next two lines appear to be the before/after versions of
# the same constructor call; the second one additionally passes `key`.
task = TaskFcst$new(id = id, backend = x, target = target, label = label)
task = TaskFcst$new(id = id, backend = x, target = target, key = key, label = label)
# Register the index column under the task's `order` column role.
task$col_roles$order = index
task
}
9 changes: 6 additions & 3 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ mlr3forecast_resamplings = new.env()
mlr3forecast_tasks = new.env()
mlr3forecast_learners = new.env()
mlr3forecast_measures = new.env()
# TODO: check if this can be moved to mlr3, copy components from @mllg PR
mlr3forecast_feature_types = c(dte = "Date")
# mlr3forecast_col_roles = "key"

named_union = function(x, y) set_names(union(x, y), union(names(x), names(y)))

Expand All @@ -33,9 +33,12 @@ register_mlr3 = function() {
mlr_reflections$task_types = mlr_reflections$task_types[!"fcst"]
mlr_reflections$task_types = setkeyv(rbind(mlr_reflections$task_types, rowwise_table(
~type, ~package, ~task, ~learner, ~prediction, ~prediction_data, ~measure,
"fcst", "mlr3forecast", "TaskFcst", "LearnerFcst", "PredictionFcst", "PredictionDataFcst", "MeasureFcst" # nolint
"fcst", "mlr3forecast", "TaskFcst", "LearnerRegr", "PredictionFcst", "PredictionDataFcst", "MeasureRegr" # nolint
), fill = TRUE), "type")
mlr_reflections$learner_predict_types$fcst = mlr_reflections$learner_predict_types$regr
# mlr_reflections$task_col_roles$fcst = union(
# mlr_reflections$task_col_roles$regr, mlr3forecast_col_roles
# )
mlr_reflections$task_feature_types = named_union(
mlr_reflections$task_feature_types, mlr3forecast_feature_types
)
Expand Down Expand Up @@ -79,7 +82,7 @@ register_mlr3 = function() {
mlr_reflections$task_types = mlr_reflections$task_types[!"fcst"]
mlr_reflections$task_feature_types =
mlr_reflections$task_feature_types[mlr_reflections$task_feature_types %nin% mlr3forecast_feature_types] # nolint
reflections = c("learner_predict_types", "task_properties")
reflections = c("learner_predict_types", "task_col_roles", "task_properties")
walk(reflections, function(x) mlr_reflections[[x]] = remove_named(mlr_reflections[[x]], "fcst"))
}

Expand Down
33 changes: 33 additions & 0 deletions README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,36 @@ newdata = data.frame(
prediction = glrn$predict_newdata(newdata, task)
prediction
```

### Global Forecasting

```{r}
library(mlr3learners)
library(mlr3pipelines)
library(tsibble)
# Build the forecast task: lower-case the column names, convert `month` to
# `Date`, sum `count` per `state` and `month`, sort by state and month, and
# use `state` as the key of the task.
task = tsibbledata::aus_livestock |>
as.data.table() |>
setnames(tolower) |>
_[, month := as.Date(month)] |>
_[, .(count = sum(count)), by = .(state, month)] |>
setorder(state, month) |>
as_task_fcst(target = "count", index = "month", key = "state")
# Preprocessing graph: convert `Date` columns to `POSIXct`, then derive date
# features with the features listed below switched off.
graph = ppl("convert_types", "Date", "POSIXct") %>>%
po("datefeatures",
param_vals = list(
week_of_year = FALSE, day_of_week = FALSE, day_of_month = FALSE, day_of_year = FALSE,
is_day = FALSE, hour = FALSE, minute = FALSE, second = FALSE
)
)
# Apply the graph to the task and keep its first output.
task = graph$train(task)[[1L]]
# Train a ranger-based forecaster (second argument presumably the lags 1 to 3)
# and score predictions for row ids 4460 to 4464.
# NOTE(review): `measure` is not defined in this chunk; presumably it is set in
# an earlier chunk of the README.
ff = Forecaster$new(lrn("regr.ranger"), 1:3)$train(task)
prediction = ff$predict(task, 4460:4464)
prediction$score(measure)
# Resampling with a key is left disabled here.
# resampling = rsmp("forecast_holdout", ratio = 0.8)
# rr = resample(task, ff, resampling)
# rr$aggregate(measure)
```
96 changes: 71 additions & 25 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,32 +45,32 @@ prediction = ff$predict_newdata(newdata, task)
prediction
#> <PredictionRegr> for 3 observations:
#> row_ids truth response
#> 1 NA 448.8710
#> 2 NA 475.2456
#> 3 NA 480.5179
#> 1 NA 447.8017
#> 2 NA 473.3637
#> 3 NA 486.9652
prediction = ff$predict(task, 142:144)
prediction
#> <PredictionRegr> for 3 observations:
#> row_ids truth response
#> 1 461 456.4968
#> 2 390 411.1712
#> 3 432 393.9585
#> 1 461 461.4039
#> 2 390 412.0604
#> 3 432 393.8162
prediction$score(measure)
#> regr.rmse
#> 25.26957
#> 25.46126

ff = Forecaster$new(lrn("regr.ranger"), 1:3)
resampling = rsmp("forecast_holdout", ratio = 0.8)
rr = resample(task, ff, resampling)
rr$aggregate(measure)
#> regr.rmse
#> 105.8215
#> 108.0431

resampling = rsmp("forecast_cv")
rr = resample(task, ff, resampling)
rr$aggregate(measure)
#> regr.rmse
#> 54.28352
#> 53.05832
```

### Multivariate
Expand All @@ -90,34 +90,34 @@ ff = Forecaster$new(lrn("regr.ranger"), 1:3)$train(new_task)
prediction = ff$predict(new_task, 142:144)
prediction$score(measure)
#> regr.rmse
#> 17.0878
#> 18.86887

row_ids = new_task$nrow - 0:2
ff$predict_newdata(new_task$data(rows = row_ids), new_task)
#> <PredictionRegr> for 3 observations:
#> row_ids truth response
#> 1 432 405.5814
#> 2 390 388.3657
#> 3 461 390.9778
#> 1 432 409.5521
#> 2 390 390.2928
#> 3 461 392.8769
newdata = new_task$data(rows = row_ids, cols = new_task$feature_names)
ff$predict_newdata(newdata, new_task)
#> <PredictionRegr> for 3 observations:
#> row_ids truth response
#> 1 NA 405.5814
#> 2 NA 388.3657
#> 3 NA 390.9778
#> 1 NA 409.5521
#> 2 NA 390.2928
#> 3 NA 392.8769

resampling = rsmp("forecast_holdout", ratio = 0.8)
rr = resample(new_task, ff, resampling)
rr$aggregate(measure)
#> regr.rmse
#> 81.91252
#> 82.88035

resampling = rsmp("forecast_cv")
rr = resample(new_task, ff, resampling)
rr$aggregate(measure)
#> regr.rmse
#> 41.87113
#> 44.645
```

### mlr3pipelines integration
Expand All @@ -128,7 +128,7 @@ glrn = as_learner(graph %>>% ff)$train(task)
prediction = glrn$predict(task, 142:144)
prediction$score(measure)
#> regr.rmse
#> 33.74039
#> 32.7928
```

### Example: Forecasting electricity demand
Expand Down Expand Up @@ -166,11 +166,57 @@ prediction = glrn$predict_newdata(newdata, task)
prediction
#> <PredictionRegr> for 14 observations:
#> row_ids truth response
#> 1 NA 187.6208
#> 2 NA 191.8121
#> 3 NA 183.6753
#> 1 NA 187.1951
#> 2 NA 191.1492
#> 3 NA 184.2040
#> --- --- ---
#> 12 NA 213.8759
#> 13 NA 218.4198
#> 14 NA 218.8139
#> 12 NA 213.9886
#> 13 NA 217.0293
#> 14 NA 219.1662
```

### Global Forecasting

``` r
library(mlr3learners)
library(mlr3pipelines)
library(tsibble)
#> Registered S3 method overwritten by 'tsibble':
#> method from
#> as_tibble.grouped_df dplyr
#>
#> Attaching package: 'tsibble'
#> The following object is masked from 'package:data.table':
#>
#> key
#> The following objects are masked from 'package:base':
#>
#> intersect, setdiff, union

task = tsibbledata::aus_livestock |>
as.data.table() |>
setnames(tolower) |>
_[, month := as.Date(month)] |>
_[, .(count = sum(count)), by = .(state, month)] |>
setorder(state, month) |>
as_task_fcst(target = "count", index = "month", key = "state")

graph = ppl("convert_types", "Date", "POSIXct") %>>%
po("datefeatures",
param_vals = list(
week_of_year = FALSE, day_of_week = FALSE, day_of_month = FALSE, day_of_year = FALSE,
is_day = FALSE, hour = FALSE, minute = FALSE, second = FALSE
)
)
task = graph$train(task)[[1L]]

ff = Forecaster$new(lrn("regr.ranger"), 1:3)$train(task)
prediction = ff$predict(task, 4460:4464)
prediction$score(measure)
#> regr.rmse
#> 23554.31

# resampling = rsmp("forecast_holdout", ratio = 0.8)
# rr = resample(task, ff, resampling)
# rr$aggregate(measure)
```
Loading

0 comments on commit f19594e

Please sign in to comment.