Skip to content

Commit

Permalink
Added plot param of the NeuralProphet fit() and uncomment test_progre…
Browse files Browse the repository at this point in the history
…ss_display in test_integration.py. (#958)

* Added plot param of the NeuralProphet fit().

* Applied black formatting to forecaster.py.

Co-authored-by: Kevin R. Chen <[email protected]>
  • Loading branch information
Kevin-Chen0 and Kevin-Chen2 authored Nov 15, 2022
1 parent 974ee77 commit 686335a
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
9 changes: 7 additions & 2 deletions neuralprophet/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,9 @@ def add_seasonality(self, name, period, fourier_order):
self.config_season.append(name=name, period=period, resolution=fourier_order, arg="custom")
return self

def fit(self, df, freq="auto", validation_df=None, progress="bar", minimal=False, continue_training=False):
def fit(
self, df, freq="auto", validation_df=None, progress="bar", minimal=False, continue_training=False, plot=True
):
"""Train, and potentially evaluate model.
Training/validation metrics may be distorted in case of auto-regression,
Expand All @@ -657,6 +659,8 @@ def fit(self, df, freq="auto", validation_df=None, progress="bar", minimal=False
whether to train without any printouts or metrics collection
continue_training : bool
whether to continue training from the last checkpoint
plot : bool
where to show the progress plot or not
Returns
-------
Expand Down Expand Up @@ -725,7 +729,8 @@ def fit(self, df, freq="auto", validation_df=None, progress="bar", minimal=False
_ = plt.plot(metrics_df[["Loss"]])
else:
_ = plt.plot(metrics_df[["Loss", "Loss_val"]])
plt.show()
if plot:
plt.show()

self.fitted = True
return metrics_df
Expand Down
4 changes: 1 addition & 3 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1433,7 +1433,6 @@ def test_metrics():
forecast = m2.predict(df)


"""
def test_progress_display():
log.info("testing: Progress Display")
df = pd.read_csv(AIR_FILE, nrows=100)
Expand All @@ -1445,8 +1444,7 @@ def test_progress_display():
batch_size=BATCH_SIZE,
learning_rate=LR,
)
metrics_df = m.fit(df, progress=progress)
"""
metrics_df = m.fit(df, progress=progress, plot=PLOT)


def test_n_lags_for_regressors():
Expand Down

0 comments on commit 686335a

Please sign in to comment.