[lightning] refactored Trainer args #959

Merged 2 commits on Nov 16, 2022
16 changes: 2 additions & 14 deletions neuralprophet/forecaster.py
@@ -294,16 +294,6 @@ class NeuralProphet:
Options
* ``True``: test data is normalized with global data params even if trained with local data params (global modeling with local normalization)
* (default) ``False``: no global modeling with local normalization
logger: str
Name of logger from pytorch_lightning.loggers to log metrics to.

Options
* TensorBoardLogger
* CSVLogger
* (MLFlowLogger)
* (NeptuneLogger)
* (CometLogger)
* (WandbLogger)
trainer_config: dict
Dictionary of additional trainer configuration parameters.
"""
@@ -346,7 +336,6 @@ def __init__(
global_normalization=False,
global_time_normalization=True,
unknown_data_normalization=False,
logger=None,
trainer_config={},
):
kwargs = locals()
@@ -411,7 +400,6 @@ def __init__(

# Pytorch Lightning Trainer
self.metrics_logger = MetricsLogger(save_dir=os.getcwd())
self.additional_logger = logger
self.trainer_config = trainer_config
self.trainer = None

@@ -2387,8 +2375,9 @@ def _train(self, df, df_val=None, minimal=False, continue_training=False):
config_train=self.config_train,
config=self.trainer_config,
metrics_logger=self.metrics_logger,
additional_logger=self.additional_logger,
early_stopping_target="Loss_val" if df_val is not None else "Loss",
minimal=minimal,
num_batches_per_epoch=len(train_loader),
)

# Set parameters for the learning rate finder
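For context on the new num_batches_per_epoch argument: calling len() on a torch DataLoader returns the number of batches per epoch, which is what _train forwards here. A small illustration with assumed sizes (the dataset and batch size are made up):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.zeros(1000, 3))   # hypothetical dataset with 1000 samples
train_loader = DataLoader(dataset, batch_size=64)
num_batches_per_epoch = len(train_loader)       # ceil(1000 / 64) == 16 batches per epoch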
@@ -2449,7 +2438,6 @@ def restore_trainer(self):
config_train=self.config_train,
config=self.trainer_config,
metrics_logger=self.metrics_logger,
additional_logger=self.additional_logger,
)
self.metrics = metrics.get_metrics(self.collect_metrics)

40 changes: 19 additions & 21 deletions neuralprophet/utils.py
@@ -747,8 +747,9 @@ def configure_trainer(
config_train: dict,
config: dict,
metrics_logger,
additional_logger: str = None,
early_stopping_target: str = "Loss",
minimal=False,
num_batches_per_epoch=100,
):
"""
Configures the PyTorch Lightning trainer.
@@ -761,10 +762,12 @@
dictionary containing the custom PyTorch Lightning trainer configuration.
metrics_logger : MetricsLogger
MetricsLogger object to log metrics to.
additional_logger : str
Name of logger from pytorch_lightning.loggers to log metrics to.
early_stopping_target : str
Target metric to use for early stopping.
minimal : bool
If True, no metrics are logged and no progress bar is displayed.
num_batches_per_epoch : int
Number of batches per epoch.

Returns
-------
@@ -782,37 +785,32 @@
if config_train.epochs is not None:
config["max_epochs"] = config_train.epochs

# Auto-configure the metric logging frequency
if "log_every_n_steps" not in config.keys() and "max_epochs" in config.keys():
config["log_every_n_steps"] = math.ceil(config["max_epochs"] / 10)

# Configure the Lightning logs directory
if "default_root_dir" not in config.keys():
config["default_root_dir"] = os.getcwd()

# Configure the loggers
# TODO: technically additional loggers work, but somehow the TensorBoard logger interferes with the custom
# metrics logger. Resolve before activating this feature. Docs: https://pytorch-lightning.readthedocs.io/en/stable/extensions/logging.html
# if additional_logger in pl.loggers.__all__: # TODO: pl.loggers.__all__ seems to be incomplete
# Logger = importlib.import_module(f"pytorch_lightning.loggers.__init__").__dict__[additional_logger]
# config["logger"] = [Logger(config["default_root_dir"]), metrics_logger]
# elif additional_logger is not None:
# log.error(f"Additional logger {additional_logger} not found in pytorch_lightning.loggers")
if additional_logger:
log.error("Additional loggers are not yet supported")
# Configure callbacks
callbacks = []

# Configure the logger
if minimal:
config["enable_progress_bar"] = False
config["enable_model_summary"] = False
config["logger"] = False
else:
config["logger"] = metrics_logger

# Configure callbacks
config["callbacks"] = []
# Configure the progress bar to refresh about four times per epoch
prog_bar_callback = pl.callbacks.TQDMProgressBar(refresh_rate=max(1, int(num_batches_per_epoch / 4)))
callbacks.append(prog_bar_callback)

# Early stopping monitor
if config_train.early_stopping:
early_stop_callback = pl.callbacks.EarlyStopping(
monitor=early_stopping_target, mode="min", patience=20, divergence_threshold=5.0
)
config["callbacks"].append(early_stop_callback)
callbacks.append(early_stop_callback)

config["callbacks"] = callbacks
config["num_sanity_val_steps"] = 0

return pl.Trainer(**config)
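
A worked example of the two auto-configured values above, using assumed numbers (max_epochs=100, num_batches_per_epoch=16); this only illustrates the arithmetic and is not part of the change:

import math

max_epochs = 100                   # assumed value of config["max_epochs"]
num_batches_per_epoch = 16         # assumed value, e.g. len(train_loader)

log_every_n_steps = math.ceil(max_epochs / 10)          # 10
refresh_rate = max(1, int(num_batches_per_epoch / 4))   # 4: the bar updates every 4th batch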