Skip to content

Commit

Permalink
Merge branch 'main' into memory-efficient-stride
Browse files Browse the repository at this point in the history
  • Loading branch information
leoniewgnr authored Aug 14, 2023
2 parents fddf249 + 4c31824 commit 19cebdb
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 27 deletions.
7 changes: 5 additions & 2 deletions neuralprophet/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,12 +242,15 @@ class NeuralProphet:
* (default) ``True``: [``mae``, ``rmse``]
* ``False``: No metrics
* ``list``: Valid options: [``mae``, ``rmse``, ``mse``]
* ``dict``: Collection of torchmetrics.Metric objects
* ``dict``: Collection of names of torchmetrics.Metric objects
Examples
--------
>>> from neuralprophet import NeuralProphet
>>> # computer MSE, MAE and RMSE
>>> m = NeuralProphet(collect_metrics=["MSE", "MAE", "RMSE"])
>>> # use custorm torchmetrics names
>>> m = NeuralProphet(collect_metrics={"MAPE": "MeanAbsolutePercentageError", "MSLE": "MeanSquaredLogError",
COMMENT
Uncertainty Estimation
Expand Down Expand Up @@ -366,7 +369,7 @@ def __init__(
impute_linear: int = 10,
impute_rolling: int = 10,
drop_missing: bool = False,
collect_metrics: np_types.CollectMetricsMode = True,
collect_metrics: Union[bool, list, dict] = True,
normalize: np_types.NormalizeMode = "auto",
global_normalization: bool = False,
global_time_normalization: bool = True,
Expand Down
5 changes: 2 additions & 3 deletions neuralprophet/np_types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import sys
from typing import Dict, List, Union
from typing import Dict, Union

import torch
import torchmetrics

# Ensure compatibility with python 3.7
if sys.version_info >= (3, 8):
Expand All @@ -19,7 +18,7 @@

GrowthMode = Literal["off", "linear", "discontinuous"]

CollectMetricsMode = Union[List[str], bool, Dict[str, torchmetrics.Metric]]
CollectMetricsMode = Union[Dict, bool]

SeasonGlobalLocalMode = Literal["global", "local"]

Expand Down
1 change: 1 addition & 0 deletions neuralprophet/time_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def __init__(
# Metrics Config
self.metrics_enabled = bool(metrics) # yields True if metrics is not an empty dictionary
if self.metrics_enabled:
metrics = {metric: torchmetrics.__dict__[metrics[metric][0]](**metrics[metric][1]) for metric in metrics}
self.log_args = {
"on_step": False,
"on_epoch": True,
Expand Down
30 changes: 17 additions & 13 deletions neuralprophet/utils_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@
log = logging.getLogger("NP.metrics")

METRICS = {
"MAE": torchmetrics.MeanAbsoluteError(),
"MSE": torchmetrics.MeanSquaredError(squared=True),
"RMSE": torchmetrics.MeanSquaredError(squared=False),
# "short_name": [torchmetrics.Metric name, {optional args}]
"MAE": ["MeanAbsoluteError", {}],
"MSE": ["MeanSquaredError", {"squared": True}],
"RMSE": ["MeanSquaredError", {"squared": False}],
}


def get_metrics(metric_input):
"""
Returns a list of metrics.
Returns a dict of metrics.
Parameters
----------
Expand All @@ -23,29 +24,32 @@ def get_metrics(metric_input):
Returns
-------
dict
Dict of torchmetrics.Metric metrics.
Dict of names of torchmetrics.Metric metrics
"""
if metric_input is None:
return {}
elif metric_input is True:
return {k: v for k, v in METRICS.items() if k in ["MAE", "RMSE"]}
return {"MAE": METRICS["MAE"], "RMSE": METRICS["RMSE"]}
elif isinstance(metric_input, str):
if metric_input.upper() in METRICS.keys():
return {metric_input: METRICS[metric_input]}
return {metric_input: METRICS[metric_input.upper()]}
else:
raise ValueError("Received unsupported argument for collect_metrics.")
elif isinstance(metric_input, list):
if all([m.upper() in METRICS.keys() for m in metric_input]):
return {k: v for k, v in METRICS.items() if k in metric_input}
return {m: METRICS[m.upper()] for m in metric_input}
else:
raise ValueError("Received unsupported argument for collect_metrics.")
elif isinstance(metric_input, dict):
if all([isinstance(_metric, torchmetrics.Metric) for _, _metric in metric_input.items()]):
return metric_input
else:
# check if all values are names belonging to torchmetrics.Metric
try:
for _metric in metric_input.values():
torchmetrics.__dict__[_metric]()
except KeyError:
raise ValueError(
"Received unsupported argument for collect_metrics. All metrics must be an instance of "
"torchmetrics.Metric."
"Received unsupported argument for collect_metrics."
"All metrics must be valid names of torchmetrics.Metric objects."
)
return {k: [v, {}] for k, v in metric_input.items()}
elif metric_input is not False:
raise ValueError("Received unsupported argument for collect_metrics.")
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "neuralprophet"
version = "1.0.0rc2"
version = "1.0.0rc3"
description = "NeuralProphet is an easy to learn framework for interpretable time series forecasting."
authors = ["Oskar Triebe <[email protected]>"]
license = "MIT"
Expand Down
50 changes: 42 additions & 8 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import pandas as pd
import pytest
import torch
import torchmetrics

from neuralprophet import NeuralProphet, df_utils, set_random_seed
from neuralprophet.data.process import _handle_missing_data, _validate_column_name
Expand Down Expand Up @@ -1367,25 +1366,60 @@ def test_get_latest_forecast():
def test_metrics():
log.info("testing: Plotting")
df = pd.read_csv(PEYTON_FILE, nrows=NROWS)
m = NeuralProphet(
# list
m_list = NeuralProphet(
epochs=EPOCHS,
batch_size=BATCH_SIZE,
learning_rate=LR,
collect_metrics=["MAE", "MSE", "RMSE"],
)
metrics_df = m.fit(df, freq="D")
metrics_df = m_list.fit(df, freq="D")
assert all([metric in metrics_df.columns for metric in ["MAE", "MSE", "RMSE"]])
m.predict(df)
m_list.predict(df)

m2 = NeuralProphet(
# dict
m_dict = NeuralProphet(
epochs=EPOCHS,
batch_size=BATCH_SIZE,
learning_rate=LR,
collect_metrics={"ABC": torchmetrics.MeanAbsoluteError()},
collect_metrics={"ABC": "MeanSquaredLogError"},
)
metrics_df = m2.fit(df, freq="D")
metrics_df = m_dict.fit(df, freq="D")
assert "ABC" in metrics_df.columns
m2.predict(df)
m_dict.predict(df)

# string
m_string = NeuralProphet(
epochs=EPOCHS,
batch_size=BATCH_SIZE,
learning_rate=LR,
collect_metrics="MAE",
)
metrics_df = m_string.fit(df, freq="D")
assert "MAE" in metrics_df.columns
m_string.predict(df)

# False
m_false = NeuralProphet(
epochs=EPOCHS,
batch_size=BATCH_SIZE,
learning_rate=LR,
collect_metrics=False,
)
metrics_df = m_false.fit(df, freq="D")
assert metrics_df is None
m_false.predict(df)

# None
m_none = NeuralProphet(
epochs=EPOCHS,
batch_size=BATCH_SIZE,
learning_rate=LR,
collect_metrics=None,
)
metrics_df = m_none.fit(df, freq="D")
assert metrics_df is None
m_none.predict(df)


def test_progress_display():
Expand Down

0 comments on commit 19cebdb

Please sign in to comment.