Skip to content

nn-models typings fix #840

Merged
merged 17 commits into from
Aug 10, 2022
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
-
### Fixed
-
- Type hints for `Pipeline.model` match `models.nn`([#768](https://github.com/tinkoff-ai/etna/pull/840))
-
-
- Fix behavior of SARIMAXModel if simple_differencing=True is set ([#837](https://github.com/tinkoff-ai/etna/pull/837))
Expand Down
29 changes: 28 additions & 1 deletion etna/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,4 +820,31 @@ def get_model(self) -> "DeepBaseNet":
return self.net


BaseModel = Union[PerSegmentModel, PerSegmentPredictionIntervalModel, MultiSegmentModel, DeepBaseModel]
class MultiSegmentPredictionIntervalModel(FitAbstractModel, PredictIntervalAbstractModel, BaseMixin):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems we need just empty class without any implementation.

Downstream classes has implemented own logic already. Seems we can't share it

"""Class for holding specific models for multi-segment prediction which are able to build prediction intervals."""

def __init__(self):
"""Init MultiSegmentPredictionIntervalModel."""
self.model = None

def get_model(self) -> Any:
"""Get internal model that is used inside etna class.

Internal model is a model that is used inside etna to forecast segments,
e.g. :py:class:`catboost.CatBoostRegressor` or :py:class:`sklearn.linear_model.Ridge`.

Returns
-------
:
Internal model
"""
return self.model


BaseModel = Union[
PerSegmentModel,
PerSegmentPredictionIntervalModel,
MultiSegmentModel,
DeepBaseModel,
MultiSegmentPredictionIntervalModel,
]
6 changes: 3 additions & 3 deletions etna/models/nn/deepar.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from etna import SETTINGS
from etna.datasets.tsdataset import TSDataset
from etna.loggers import tslogger
from etna.models.base import Model
from etna.models.base import PredictIntervalAbstractModel
from etna.models.base import MultiSegmentPredictionIntervalModel
from etna.models.base import log_decorator
from etna.models.nn.utils import _DeepCopyMixin
from etna.transforms import PytorchForecastingTransform
Expand All @@ -25,7 +24,7 @@
from pytorch_lightning import LightningModule


class DeepARModel(Model, PredictIntervalAbstractModel, _DeepCopyMixin):
class DeepARModel(MultiSegmentPredictionIntervalModel, _DeepCopyMixin):
"""Wrapper for :py:class:`pytorch_forecasting.models.deepar.DeepAR`.

Notes
Expand Down Expand Up @@ -84,6 +83,7 @@ def __init__(
quantiles_kwargs:
Additional arguments for computing quantiles, look at ``to_quantiles()`` method for your loss.
"""
super().__init__()
if loss is None:
loss = NormalDistributionLoss()
self.max_epochs = max_epochs
Expand Down
6 changes: 3 additions & 3 deletions etna/models/nn/tft.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from etna import SETTINGS
from etna.datasets.tsdataset import TSDataset
from etna.loggers import tslogger
from etna.models.base import Model
from etna.models.base import PredictIntervalAbstractModel
from etna.models.base import MultiSegmentPredictionIntervalModel
from etna.models.base import log_decorator
from etna.models.nn.utils import _DeepCopyMixin
from etna.transforms import PytorchForecastingTransform
Expand All @@ -26,7 +25,7 @@
from pytorch_lightning import LightningModule


class TFTModel(Model, PredictIntervalAbstractModel, _DeepCopyMixin):
class TFTModel(MultiSegmentPredictionIntervalModel, _DeepCopyMixin):
"""Wrapper for :py:class:`pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer`.

Notes
Expand Down Expand Up @@ -89,6 +88,7 @@ def __init__(
quantiles_kwargs:
Additional arguments for computing quantiles, look at ``to_quantiles()`` method for your loss.
"""
super().__init__()
if loss is None:
loss = QuantileLoss()
self.max_epochs = max_epochs
Expand Down