Skip to content

Commit

Permalink
[MNT] isolate matplotlib as soft dependency (#1636)
Browse files Browse the repository at this point in the history
Isolates `matplotlib` as soft dependency in a new soft dep set `all_extras`. #1616

The imports happen in `plot_sth` methods throughout the code base, some attached to classes, some not.
This allows to use `pytorch-forecasting` without plotting or graphical logging, or use a different plotting backend manually.

Isolation strategy:

* where the purpose of the function is creating a plot or nothing else, absence of `matplotlib` raises an exception
* where `matplotlib` is additional or part of the logic, absence of `matplotlib` causes the specific parts to be skipped. Example: `log_gradient_flow` - this is crucial, as raising an exception would prevent the models from running.
  • Loading branch information
fkiraly authored Sep 4, 2024
1 parent 3da947b commit 95fa06c
Show file tree
Hide file tree
Showing 11 changed files with 105 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install ".[dev,github-actions,mqf2]"
python -m pip install ".[dev,all_extras,github-actions]"
- name: Show dependencies
run: python -m pip list
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ dependencies = [
"scipy >=1.8,<2.0",
"pandas >=1.3.0,<3.0.0",
"scikit-learn >=1.2,<2.0",
"matplotlib",
"pytorch-optimizer >=2.5.1,<4.0.0",
]

Expand All @@ -84,6 +83,7 @@ dependencies = [
#
all_extras = [
"cpflows",
"matplotlib",
"optuna >=3.1.0,<4.0.0",
"optuna-integration",
"statsmodels",
Expand Down
10 changes: 6 additions & 4 deletions pytorch_forecasting/data/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from typing import Any, Callable, Dict, List, Tuple, Union
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.exceptions import NotFittedError
Expand All @@ -32,6 +31,7 @@
)
from pytorch_forecasting.data.samplers import TimeSynchronizedBatchSampler
from pytorch_forecasting.utils import repr_class
from pytorch_forecasting.utils._dependencies import _check_matplotlib


def _find_end_indices(diffs: np.ndarray, max_lengths: np.ndarray, min_length: int) -> Tuple[np.ndarray, np.ndarray]:
Expand Down Expand Up @@ -1357,9 +1357,7 @@ def decoded_index(self) -> pd.DataFrame:
)
return index

def plot_randomization(
self, betas: Tuple[float, float] = None, length: int = None, min_length: int = None
) -> Tuple[plt.Figure, torch.Tensor]:
def plot_randomization(self, betas: Tuple[float, float] = None, length: int = None, min_length: int = None):
"""
Plot expected randomized length distribution.
Expand All @@ -1372,6 +1370,10 @@ def plot_randomization(
Returns:
Tuple[plt.Figure, torch.Tensor]: tuple of figure and histogram based on 1000 samples
"""
_check_matplotlib("plot_randomization")

import matplotlib.pyplot as plt

if betas is None:
betas = self.randomize_length
if length is None:
Expand Down
30 changes: 26 additions & 4 deletions pytorch_forecasting/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from lightning.pytorch.callbacks import BasePredictionWriter, LearningRateFinder
from lightning.pytorch.trainer.states import RunningStage
from lightning.pytorch.utilities.parsing import get_init_args
import matplotlib.pyplot as plt
import numpy as np
from numpy import iterable
import pandas as pd
Expand Down Expand Up @@ -55,6 +54,7 @@
groupby_apply,
to_list,
)
from pytorch_forecasting.utils._dependencies import _check_matplotlib

# todo: compile models

Expand Down Expand Up @@ -940,6 +940,12 @@ def log_prediction(
)
else:
log_indices = [0]

mpl_available = _check_matplotlib("plot_prediction", raise_error=False)

if not mpl_available:
return None # don't log matplotlib plots if not available

for idx in log_indices:
fig = self.plot_prediction(x, out, idx=idx, add_loss_to_title=True, **kwargs)
tag = f"{self.current_stage} prediction"
Expand Down Expand Up @@ -971,7 +977,7 @@ def plot_prediction(
ax=None,
quantiles_kwargs: Dict[str, Any] = {},
prediction_kwargs: Dict[str, Any] = {},
) -> plt.Figure:
):
"""
Plot prediction of prediction vs actuals
Expand All @@ -990,6 +996,10 @@ def plot_prediction(
Returns:
matplotlib figure
"""
_check_matplotlib("plot_prediction")

from matplotlib import pyplot as plt

# all true values for y of the first sample in batch
encoder_targets = to_list(x["encoder_target"])
decoder_targets = to_list(x["decoder_target"])
Expand Down Expand Up @@ -1103,6 +1113,14 @@ def log_gradient_flow(self, named_parameters: Dict[str, torch.Tensor]) -> None:
layers.append(name)
ave_grads.append(p.grad.abs().cpu().mean())
self.logger.experiment.add_histogram(tag=name, values=p.grad, global_step=self.global_step)

mpl_available = _check_matplotlib("log_gradient_flow", raise_error=False)

if not mpl_available:
return None

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot(ave_grads)
ax.set_xlabel("Layers")
Expand Down Expand Up @@ -1842,7 +1860,7 @@ def calculate_prediction_actual_by_variable(

def plot_prediction_actual_by_variable(
self, data: Dict[str, Dict[str, torch.Tensor]], name: str = None, ax=None, log_scale: bool = None
) -> Union[Dict[str, plt.Figure], plt.Figure]:
):
"""
Plot predicions and actual averages by variables
Expand All @@ -1860,6 +1878,10 @@ def plot_prediction_actual_by_variable(
Returns:
Union[Dict[str, plt.Figure], plt.Figure]: matplotlib figure
"""
_check_matplotlib("plot_prediction_actual_by_variable")

from matplotlib import pyplot as plt

if name is None: # run recursion for figures
figs = {name: self.plot_prediction_actual_by_variable(data, name) for name in data["support"].keys()}
return figs
Expand Down Expand Up @@ -2230,7 +2252,7 @@ def plot_prediction(
ax=None,
quantiles_kwargs: Dict[str, Any] = {},
prediction_kwargs: Dict[str, Any] = {},
) -> plt.Figure:
):
"""
Plot prediction of prediction vs actuals
Expand Down
13 changes: 11 additions & 2 deletions pytorch_forecasting/models/nbeats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from typing import Dict, List

import matplotlib.pyplot as plt
import torch
from torch import nn

Expand All @@ -13,6 +12,7 @@
from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric
from pytorch_forecasting.models.base_model import BaseModel
from pytorch_forecasting.models.nbeats.sub_modules import NBEATSGenericBlock, NBEATSSeasonalBlock, NBEATSTrendBlock
from pytorch_forecasting.utils._dependencies import _check_matplotlib


class NBeats(BaseModel):
Expand Down Expand Up @@ -263,6 +263,11 @@ def log_interpretation(self, x, out, batch_idx):
"""
Log interpretation of network predictions in tensorboard.
"""
mpl_available = _check_matplotlib("log_interpretation", raise_error=False)

if not mpl_available:
return None

label = ["val", "train"][self.training]
if self.log_interval > 0 and batch_idx % self.log_interval == 0:
fig = self.plot_interpretation(x, out, idx=0)
Expand All @@ -280,7 +285,7 @@ def plot_interpretation(
idx: int,
ax=None,
plot_seasonality_and_generic_on_secondary_axis: bool = False,
) -> plt.Figure:
):
"""
Plot interpretation.
Expand All @@ -299,6 +304,10 @@ def plot_interpretation(
Returns:
plt.Figure: matplotlib figure
"""
_check_matplotlib("plot_interpretation")

import matplotlib.pyplot as plt

if ax is None:
fig, ax = plt.subplots(2, 1, figsize=(6, 8))
else:
Expand Down
13 changes: 11 additions & 2 deletions pytorch_forecasting/models/nhits/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from copy import copy
from typing import Dict, List, Optional, Tuple, Union

from matplotlib import pyplot as plt
import numpy as np
import torch
from torch import nn
Expand All @@ -17,6 +16,7 @@
from pytorch_forecasting.models.nhits.sub_modules import NHiTS as NHiTSModule
from pytorch_forecasting.models.nn.embeddings import MultiEmbedding
from pytorch_forecasting.utils import create_mask, to_list
from pytorch_forecasting.utils._dependencies import _check_matplotlib


class NHiTS(BaseModelWithCovariates):
Expand Down Expand Up @@ -419,7 +419,7 @@ def plot_interpretation(
output: Dict[str, torch.Tensor],
idx: int,
ax=None,
) -> plt.Figure:
):
"""
Plot interpretation.
Expand All @@ -436,6 +436,10 @@ def plot_interpretation(
Returns:
plt.Figure: matplotlib figure
"""
_check_matplotlib("plot_interpretation")

from matplotlib import pyplot as plt

if not isinstance(self.loss, MultiLoss): # not multi-target
prediction = self.to_prediction(dict(prediction=output["prediction"][[idx]].detach()))[0].cpu()
block_forecasts = [
Expand Down Expand Up @@ -535,6 +539,11 @@ def log_interpretation(self, x, out, batch_idx):
"""
Log interpretation of network predictions in tensorboard.
"""
mpl_available = _check_matplotlib("log_interpretation", raise_error=False)

if not mpl_available:
return None

label = ["val", "train"][self.training]
if self.log_interval > 0 and batch_idx % self.log_interval == 0:
fig = self.plot_interpretation(x, out, idx=0)
Expand Down
18 changes: 14 additions & 4 deletions pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from copy import copy
from typing import Dict, List, Tuple, Union

from matplotlib import pyplot as plt
import numpy as np
import torch
from torch import nn
Expand All @@ -24,6 +23,7 @@
VariableSelectionNetwork,
)
from pytorch_forecasting.utils import create_mask, detach, integer_histogram, masked_op, padded_stack, to_list
from pytorch_forecasting.utils._dependencies import _check_matplotlib


class TemporalFusionTransformer(BaseModelWithCovariates):
Expand Down Expand Up @@ -690,7 +690,7 @@ def plot_prediction(
show_future_observed: bool = True,
ax=None,
**kwargs,
) -> plt.Figure:
):
"""
Plot actuals vs prediction and attention
Expand All @@ -706,7 +706,6 @@ def plot_prediction(
Returns:
plt.Figure: matplotlib figure
"""

# plot prediction as normal
fig = super().plot_prediction(
x,
Expand Down Expand Up @@ -735,7 +734,7 @@ def plot_prediction(
f.tight_layout()
return fig

def plot_interpretation(self, interpretation: Dict[str, torch.Tensor]) -> Dict[str, plt.Figure]:
def plot_interpretation(self, interpretation: Dict[str, torch.Tensor]):
"""
Make figures that interpret model.
Expand All @@ -748,6 +747,10 @@ def plot_interpretation(self, interpretation: Dict[str, torch.Tensor]) -> Dict[s
Returns:
dictionary of matplotlib figures
"""
_check_matplotlib("plot_interpretation")

import matplotlib.pyplot as plt

figs = {}

# attention
Expand Down Expand Up @@ -813,6 +816,13 @@ def log_interpretation(self, outputs):
interpretation["attention"] = interpretation["attention"] / attention_occurances.pow(2).clamp(1.0)
interpretation["attention"] = interpretation["attention"] / interpretation["attention"].sum()

mpl_available = _check_matplotlib("log_interpretation", raise_error=False)

if not mpl_available:
return None

import matplotlib.pyplot as plt

figs = self.plot_interpretation(interpretation) # make interpretation figures
label = self.current_stage
# log to tensorboard
Expand Down
22 changes: 22 additions & 0 deletions pytorch_forecasting/utils/_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,25 @@ def _get_installed_packages():
MAJOR.MINOR.PATCH version format is used for versions, e.g., "1.2.3"
"""
return _get_installed_packages_private().copy()


def _check_matplotlib(ref="This feature", raise_error=True):
"""Check if matplotlib is installed.
Parameters
----------
ref : str, optional (default="This feature")
reference to the feature that requires matplotlib, used in error message
raise_error : bool, optional (default=True)
whether to raise an error if matplotlib is not installed
Returns
-------
bool : whether matplotlib is installed
"""
pkgs = _get_installed_packages()

if raise_error and "matplotlib" not in pkgs:
raise ImportError(f"{ref} requires matplotlib. Please install matplotlib with `pip install matplotlib`.")

return "matplotlib" in pkgs
5 changes: 5 additions & 0 deletions tests/test_models/test_nbeats.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest

from pytorch_forecasting.models import NBeats
from pytorch_forecasting.utils._dependencies import _get_installed_packages


def test_integration(dataloaders_fixed_window_without_covariates, tmp_path):
Expand Down Expand Up @@ -76,6 +77,10 @@ def test_pickle(model):
pickle.loads(pkl)


@pytest.mark.skipif(
"matplotlib" not in _get_installed_packages(),
reason="skip test if required package matplotlib not installed",
)
def test_interpretation(model, dataloaders_fixed_window_without_covariates):
raw_predictions = model.predict(
dataloaders_fixed_window_without_covariates["val"], mode="raw", return_x=True, fast_dev_run=True
Expand Down
4 changes: 4 additions & 0 deletions tests/test_models/test_nhits.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,10 @@ def test_pickle(model):
pickle.loads(pkl)


@pytest.mark.skipif(
"matplotlib" not in _get_installed_packages(),
reason="skip test if required package matplotlib not installed",
)
def test_interpretation(model, dataloaders_with_covariates):
raw_predictions = model.predict(dataloaders_with_covariates["val"], mode="raw", return_x=True, fast_dev_run=True)
model.plot_prediction(raw_predictions.x, raw_predictions.output, idx=0, add_loss_to_title=True)
Expand Down
4 changes: 4 additions & 0 deletions tests/test_models/test_temporal_fusion_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,10 @@ def test_predict_dependency(model, dataloaders_with_covariates, data_with_covari
model.predict_dependency(dataset, variable="agency", values=data_with_covariates.agency.unique()[:2], **kwargs)


@pytest.mark.skipif(
"matplotlib" not in _get_installed_packages(),
reason="skip test if required package matplotlib not installed",
)
def test_actual_vs_predicted_plot(model, dataloaders_with_covariates):
prediction = model.predict(dataloaders_with_covariates["val"], return_x=True)
averages = model.calculate_prediction_actual_by_variable(prediction.x, prediction.output)
Expand Down

0 comments on commit 95fa06c

Please sign in to comment.