From 926982a87cba34316785b058c7a6f2387e898568 Mon Sep 17 00:00:00 2001 From: Mr-Geekman <36005824+Mr-Geekman@users.noreply.github.com> Date: Fri, 14 Oct 2022 15:04:43 +0300 Subject: [PATCH] Change returned model in `get_model` of `BATSModel`, `TBATSModel` (#987) --- CHANGELOG.md | 2 +- etna/models/tbats.py | 15 +++++++++++---- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 260972a15..24861a2e2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,7 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - ### Changed - -- +- Change returned model in get_model of BATSModel, TBATSModel ([#987](https://github.com/tinkoff-ai/etna/pull/987)) - - - diff --git a/etna/models/tbats.py b/etna/models/tbats.py index 629e63728..cafae4116 100644 --- a/etna/models/tbats.py +++ b/etna/models/tbats.py @@ -16,7 +16,7 @@ class _TBATSAdapter(BaseAdapter): def __init__(self, model: Estimator): - self.model = model + self._model = model self._fitted_model: Optional[Model] = None self._last_train_timestamp = None self._freq = None @@ -27,7 +27,7 @@ def fit(self, df: pd.DataFrame, regressors: Iterable[str]): raise ValueError("Can't determine frequency of a given dataframe") target = df["target"] - self._fitted_model = self.model.fit(target) + self._fitted_model = self._model.fit(target) self._last_train_timestamp = df["timestamp"].max() self._freq = freq @@ -68,8 +68,15 @@ def predict(self, df: pd.DataFrame, prediction_interval: bool, quantiles: Iterab return y_pred - def get_model(self) -> Estimator: - return self.model + def get_model(self) -> Model: + """Get internal :py:class:`tbats.tbats.Model` model that was fitted inside etna class. + + Returns + ------- + : + Internal model + """ + return self._fitted_model class BATSModel(PerSegmentPredictionIntervalModel):