Skip to content

Commit

Permalink
refactor: adjust date col reflection
Browse files (browse the repository at this point in the history)
  • Loading branch information
m-muecke committed Jan 1, 2025
1 parent b907c96 commit 94f8e48
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 60 deletions.
13 changes: 7 additions & 6 deletions R/ResamplingForecastCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,13 @@ ResamplingForecastCV = R6Class("ResamplingForecastCV",

private = list(
.sample = function(ids, task, ...) {
if (length(task$col_roles$order) == 0L) {
stopf(
"Resampling '%s' requires an ordered task, but Task '%s' has no order.",
self$id, task$id
)
}
# if (length(task$col_roles$order) == 0L) {
# stopf(
# "Resampling '%s' requires an ordered task, but Task '%s' has no order.",
# self$id, task$id
# )
# }

pars = self$param_set$get_values()
ids = sort(ids)
train_end = ids[ids <= (max(ids) - pars$horizon) & ids >= pars$window_size]
Expand Down
13 changes: 7 additions & 6 deletions R/ResamplingForecastHoldout.R
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,13 @@ ResamplingForecastHoldout = R6Class("ResamplingForecastHoldout",

private = list(
.sample = function(ids, task, ...) {
if (length(task$col_roles$order) == 0L) {
stopf(
"Resampling '%s' requires an ordered task, but Task '%s' has no order.",
self$id, task$id
)
}
# if (length(task$col_roles$order) == 0L) {
# stopf(
# "Resampling '%s' requires an ordered task, but Task '%s' has no order.",
# self$id, task$id
# )
# }

pars = self$param_set$get_values()
ratio = pars$ratio
n = pars$n
Expand Down
2 changes: 1 addition & 1 deletion R/TaskFcstAirpassengers.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ load_task_airpassengers = function(id = "airpassengers") {
if (!requireNamespace("tsbox", quietly = TRUE)) {
stopf("Package 'tsbox' is required to load the 'AirPassengers' dataset.")
}
dt = tsbox::ts_dt(load_dataset("AirPassengers", package = "datasets"))
dt = tsbox::ts_dt(load_dataset("AirPassengers", "datasets"))
setnames(dt, c("date", "passengers"))
b = as_data_backend(dt)

Expand Down
6 changes: 3 additions & 3 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ mlr3forecast_resamplings = new.env()
mlr3forecast_tasks = new.env()
mlr3forecast_learners = new.env()
mlr3forecast_measures = new.env()
mlr3forecast_feature_types = c(date = "Date")
# TODO: check if this can be moved to mlr3, copy components from @mllg PR
mlr3forecast_feature_types = c(dte = "Date")

named_union = function(x, y) set_names(union(x, y), union(names(x), names(y)))

Expand All @@ -35,7 +36,6 @@ register_mlr3 = function() {
"fcst", "mlr3forecast", "TaskFcst", "LearnerFcst", "PredictionFcst", "PredictionDataFcst", "MeasureFcst" # nolint
), fill = TRUE), "type")
mlr_reflections$learner_predict_types$fcst = mlr_reflections$learner_predict_types$regr
mlr_reflections$task_col_roles$fcst = union(mlr_reflections$task_col_roles$regr, "index")
mlr_reflections$task_feature_types = named_union(
mlr_reflections$task_feature_types, mlr3forecast_feature_types
)
Expand Down Expand Up @@ -79,7 +79,7 @@ register_mlr3 = function() {
mlr_reflections$task_types = mlr_reflections$task_types[!"fcst"]
mlr_reflections$task_feature_types =
mlr_reflections$task_feature_types[mlr_reflections$task_feature_types %nin% mlr3forecast_feature_types] # nolint
reflections = c("learner_predict_types", "task_col_roles", "task_properties")
reflections = c("learner_predict_types", "task_properties")
walk(reflections, function(x) mlr_reflections[[x]] = remove_named(mlr_reflections[[x]], "fcst"))
}

Expand Down
50 changes: 25 additions & 25 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,32 +45,32 @@ prediction = ff$predict_newdata(newdata, task)
prediction
#> <PredictionRegr> for 3 observations:
#> row_ids truth response
#> 1 NA 452.7575
#> 2 NA 474.8485
#> 3 NA 481.4720
#> 1 NA 446.9409
#> 2 NA 477.9439
#> 3 NA 480.5694
prediction = ff$predict(task, 142:144)
prediction
#> <PredictionRegr> for 3 observations:
#> row_ids truth response
#> 1 461 458.1226
#> 2 390 412.7669
#> 3 432 396.2460
#> 1 461 459.4145
#> 2 390 411.2457
#> 3 432 400.4514
prediction$score(measure)
#> regr.rmse
#> 24.52863
#> 21.97883

ff = Forecaster$new(lrn("regr.ranger"), 1:3)
resampling = rsmp("forecast_holdout", ratio = 0.8)
rr = resample(task, ff, resampling)
rr$aggregate(measure)
#> regr.rmse
#> 112.7031
#> 105.0997

resampling = rsmp("forecast_cv")
rr = resample(task, ff, resampling)
rr$aggregate(measure)
#> regr.rmse
#> 48.80621
#> 54.93903
```

### Multivariate
Expand All @@ -90,34 +90,34 @@ ff = Forecaster$new(lrn("regr.ranger"), 1:3)$train(new_task)
prediction = ff$predict(new_task, 142:144)
prediction$score(measure)
#> regr.rmse
#> 19.26131
#> 17.55705

row_ids = new_task$nrow - 0:2
ff$predict_newdata(new_task$data(rows = row_ids), new_task)
#> <PredictionRegr> for 3 observations:
#> row_ids truth response
#> 1 432 407.1069
#> 2 390 391.4622
#> 3 461 393.6115
#> 1 432 405.2216
#> 2 390 388.3066
#> 3 461 385.6412
newdata = new_task$data(rows = row_ids, cols = new_task$feature_names)
ff$predict_newdata(newdata, new_task)
#> <PredictionRegr> for 3 observations:
#> row_ids truth response
#> 1 NA 407.1069
#> 2 NA 391.4622
#> 3 NA 393.6115
#> 1 NA 405.2216
#> 2 NA 388.3066
#> 3 NA 385.6412

resampling = rsmp("forecast_holdout", ratio = 0.8)
rr = resample(new_task, ff, resampling)
rr$aggregate(measure)
#> regr.rmse
#> 81.82989
#> 82.35283

resampling = rsmp("forecast_cv")
rr = resample(new_task, ff, resampling)
rr$aggregate(measure)
#> regr.rmse
#> 45.80208
#> 45.54337
```

### mlr3pipelines integration
Expand All @@ -128,7 +128,7 @@ glrn = as_learner(graph %>>% ff)$train(task)
prediction = glrn$predict(task, 142:144)
prediction$score(measure)
#> regr.rmse
#> 34.39579
#> 34.29322
```

### Example: Forecasting electricity demand
Expand Down Expand Up @@ -166,11 +166,11 @@ prediction = glrn$predict_newdata(newdata, task)
prediction
#> <PredictionRegr> for 14 observations:
#> row_ids truth response
#> 1 NA 187.1619
#> 2 NA 191.8612
#> 3 NA 184.2280
#> 1 NA 187.9399
#> 2 NA 190.5695
#> 3 NA 184.2617
#> --- --- ---
#> 12 NA 214.1141
#> 13 NA 216.5287
#> 14 NA 217.9717
#> 12 NA 214.6350
#> 13 NA 218.8392
#> 14 NA 221.4170
```
20 changes: 1 addition & 19 deletions man/TaskFcst.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 94f8e48

Please sign in to comment.