Skip to content

Commit

Permalink
refactor: rename to ForecastLearner
Browse files Browse the repository at this point in the history
  • Loading branch information
m-muecke committed Jan 5, 2025
1 parent dd3c3b7 commit 53a75f9
Show file tree
Hide file tree
Showing 7 changed files with 203 additions and 79 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ Roxygen: list(markdown = TRUE, r6 = TRUE)
RoxygenNote: 7.3.2
Collate:
'DataBackendTimeSeries.R'
'Forecaster.R'
'ForecastLearner.R'
'zzz.R'
'MeasureDirectional.R'
'ResamplingForecastCV.R'
Expand Down
2 changes: 1 addition & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ S3method(as_task_fcst,DataBackend)
S3method(as_task_fcst,TaskFcst)
S3method(as_task_fcst,data.frame)
export(DataBackendTimeSeries)
export(Forecaster)
export(ForecastLearner)
export(ResamplingForecastCV)
export(ResamplingForecastHoldout)
export(TaskFcst)
Expand Down
4 changes: 2 additions & 2 deletions R/Forecaster.R → R/ForecastLearner.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#' @title Forecaster
#' @title Forecast Learner
#'
#' @export
Forecaster = R6::R6Class("Forecaster",
ForecastLearner = R6::R6Class("ForecastLearner",
inherit = Learner,
public = list(
#' @field learner ([Learner])\cr
Expand Down
6 changes: 5 additions & 1 deletion R/ResamplingForecastCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ ResamplingForecastCV = R6Class("ResamplingForecastCV",
),

private = list(
.sample = function(ids, task, ...) {
.sample = function(ids, ...) {
pars = self$param_set$get_values()
ids = sort(ids)
train_end = ids[ids <= (max(ids) - pars$horizon) & ids >= pars$window_size]
Expand All @@ -99,6 +99,10 @@ ResamplingForecastCV = R6Class("ResamplingForecastCV",
list(train = train_ids, test = test_ids)
},

# Placeholder sampler taking `ids` and `task`: none of the arguments are used
# yet; any call signals a "not implemented yet" error via base R's
# .NotYetImplemented().
.sample_new = function(ids, task, ...) {
.NotYetImplemented()
},

# Return the `i`-th element of the `train` list stored in `self$instance`.
# NOTE(review): presumably `self$instance` holds the list(train = , test = )
# produced by `.sample()` -- confirm against the parent Resampling class.
.get_train = function(i) {
self$instance$train[[i]]
},
Expand Down
82 changes: 70 additions & 12 deletions README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,15 @@ library(mlr3learners)
task = tsk("airpassengers")
task$select(setdiff(task$feature_names, "date"))
measure = msr("regr.rmse")
ff = Forecaster$new(lrn("regr.ranger"), 1:3)$train(task)
ff = ForecastLearner$new(lrn("regr.ranger"), 1:3)$train(task)
newdata = data.frame(passengers = rep(NA_real_, 3L))
prediction = ff$predict_newdata(newdata, task)
prediction
prediction = ff$predict(task, 142:144)
prediction
prediction$score(measure)
ff = Forecaster$new(lrn("regr.ranger"), 1:3)
ff = ForecastLearner$new(lrn("regr.ranger"), 1:3)
resampling = rsmp("forecast_holdout", ratio = 0.8)
rr = resample(task, ff, resampling)
rr$aggregate(measure)
Expand All @@ -85,7 +85,7 @@ graph = ppl("convert_types", "Date", "POSIXct") %>>%
param_vals = list(is_day = FALSE, hour = FALSE, minute = FALSE, second = FALSE)
)
new_task = graph$train(task)[[1L]]
ff = Forecaster$new(lrn("regr.ranger"), 1:3)$train(new_task)
ff = ForecastLearner$new(lrn("regr.ranger"), 1:3)$train(new_task)
prediction = ff$predict(new_task, 142:144)
prediction$score(measure)
Expand All @@ -106,7 +106,7 @@ rr$aggregate(measure)
### mlr3pipelines integration

```{r}
ff = Forecaster$new(lrn("regr.ranger"), 1:3)
ff = ForecastLearner$new(lrn("regr.ranger"), 1:3)
glrn = as_learner(graph %>>% ff)$train(task)
prediction = glrn$predict(task, 142:144)
prediction$score(measure)
Expand All @@ -123,16 +123,22 @@ task = tsibbledata::vic_elec |>
setnames(tolower) |>
_[
year(time) == 2014L,
.(demand = sum(demand) / 1e3, temperature = max(temperature), holiday = any(holiday)),
.(
demand = sum(demand) / 1e3,
temperature = max(temperature),
holiday = any(holiday)
),
by = date
] |>
as_task_fcst(target = "demand", index = "date")
graph = ppl("convert_types", "Date", "POSIXct") %>>%
po("datefeatures",
param_vals = list(year = FALSE, is_day = FALSE, hour = FALSE, minute = FALSE, second = FALSE)
param_vals = list(
year = FALSE, is_day = FALSE, hour = FALSE, minute = FALSE, second = FALSE
)
)
ff = Forecaster$new(lrn("regr.ranger"), 1:3)
ff = ForecastLearner$new(lrn("regr.ranger"), 1:3)
glrn = as_learner(graph %>>% ff)$train(task)
max_date = task$data()[.N, date]
Expand All @@ -151,7 +157,7 @@ prediction
```{r, message = FALSE}
library(mlr3learners)
library(mlr3pipelines)
library(tsibble)
library(tsibble) # needs to be loaded for it to somehow work
task = tsibbledata::aus_livestock |>
as.data.table() |>
Expand All @@ -164,18 +170,70 @@ task = tsibbledata::aus_livestock |>
graph = ppl("convert_types", "Date", "POSIXct") %>>%
po("datefeatures",
param_vals = list(
week_of_year = FALSE, day_of_week = FALSE, day_of_month = FALSE, day_of_year = FALSE,
is_day = FALSE, hour = FALSE, minute = FALSE, second = FALSE
week_of_year = FALSE, day_of_week = FALSE, day_of_month = FALSE,
day_of_year = FALSE, is_day = FALSE, hour = FALSE, minute = FALSE,
second = FALSE
)
)
task = graph$train(task)[[1L]]
ff = Forecaster$new(lrn("regr.ranger"), 1:3)$train(task)
ff = ForecastLearner$new(lrn("regr.ranger"), 1:3)$train(task)
prediction = ff$predict(task, 4460:4464)
prediction$score(measure)
ff = Forecaster$new(lrn("regr.ranger"), 1:3)
ff = ForecastLearner$new(lrn("regr.ranger"), 1:3)
resampling = rsmp("forecast_holdout", ratio = 0.8)
rr = resample(task, ff, resampling)
rr$aggregate(measure)
```

### Example: Global vs Local Forecasting

```{r}
# TODO: find better task example, since the effect is minor here
graph = ppl("convert_types", "Date", "POSIXct") %>>%
po("datefeatures",
param_vals = list(
week_of_year = FALSE, day_of_week = FALSE, day_of_month = FALSE,
day_of_year = FALSE, is_day = FALSE, hour = FALSE, minute = FALSE,
second = FALSE
)
)
# local forecasting
task = tsibbledata::aus_livestock |>
as.data.table() |>
setnames(tolower) |>
_[, month := as.Date(month)] |>
_[state == "Western Australia", .(count = sum(count)), by = .(month)] |>
setorder(month) |>
as_task_fcst(target = "count", index = "month")
task = graph$train(task)[[1L]]
ff = ForecastLearner$new(lrn("regr.ranger"), 1L)$train(task)
tab = task$backend$data(
rows = task$row_ids, cols = c(task$backend$primary_key, "month.year")
)
setnames(tab, c("row_id", "year"))
row_ids = tab[year >= 2015, row_id]
prediction = ff$predict(task, row_ids)
prediction$score(measure)
# global forecasting
task = tsibbledata::aus_livestock |>
as.data.table() |>
setnames(tolower) |>
_[, month := as.Date(month)] |>
_[, .(count = sum(count)), by = .(state, month)] |>
setorder(state, month) |>
as_task_fcst(target = "count", index = "month", key = "state")
task = graph$train(task)[[1L]]
ff = ForecastLearner$new(lrn("regr.ranger"), 1L)$train(task)
tab = task$backend$data(
rows = task$row_ids, cols = c(task$backend$primary_key, "month.year", "state")
)
setnames(tab, c("row_id", "year", "state"))
row_ids = tab[year >= 2015 & state == "Western Australia", row_id]
prediction = ff$predict(task, row_ids)
prediction$score(measure)
```
Loading

0 comments on commit 53a75f9

Please sign in to comment.