Skip to content

Commit

Permalink
Merge branch 'master' into issue-1234
Browse files Browse the repository at this point in the history
  • Loading branch information
brsnw250 authored May 10, 2023
2 parents 380fb7f + 634a5c6 commit 6d4fcb7
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[flake8]
ignore = F, E203, W605, E501, W503, D100, D104, C408, B023
ignore = F, E203, W605, E501, W503, D100, D104, C408
max-line-length = 121
max-complexity = 18
docstring-convention=numpy
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
- Add `tsfresh` into optional dependencies, remove instruction about `pip install tsfresh` ([#1246](https://github.com/tinkoff-ai/etna/pull/1246))
- Fix `DeepARModel` and `TFTModel` to work with changed `prediction_size` ([#1251](https://github.com/tinkoff-ai/etna/pull/1251))
- Fix problems with flake8 B023 ([#1252](https://github.com/tinkoff-ai/etna/pull/1252))
- Fix problem with swapped forecast methods in HierarchicalPipeline ([#1259](https://github.com/tinkoff-ai/etna/pull/1259))

## [2.0.0] - 2023-04-11
### Added
Expand Down
4 changes: 3 additions & 1 deletion etna/analysis/eda/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ def _create_holidays_df_str(holidays: str, index, as_is):

holidays_dict = {}
for holiday_name in holiday_names:
cur_holiday_index = pd.Series(timestamp).apply(lambda x: country_holidays.get(x, "") == holiday_name)
cur_holiday_index = pd.Series(timestamp).apply(
lambda x: country_holidays.get(x, "") == holiday_name # noqa: B023
)
holidays_dict[holiday_name] = cur_holiday_index

holidays_df = pd.DataFrame(holidays_dict)
Expand Down
2 changes: 1 addition & 1 deletion etna/analysis/feature_selection/mrmr_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def mrmr(
last_selected_regressor = regressors.loc[pd.IndexSlice[:], pd.IndexSlice[:, last_selected_feature]]

redundancy_table.loc[not_selected_features, last_selected_feature] = (
not_selected_regressors.apply(lambda col: last_selected_regressor.corrwith(col))
not_selected_regressors.apply(lambda col: last_selected_regressor.corrwith(col)) # noqa: B023
.abs()
.groupby("feature")
.apply(redundancy_aggregation_fn)
Expand Down
7 changes: 4 additions & 3 deletions etna/auto/optuna/config_sampler.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
from typing import List
from typing import Optional
from typing import Set
Expand Down Expand Up @@ -115,10 +116,10 @@ def _get_unfinished_hashes(self, study: Study, current_trial: Optional[FrozenTri
finished_trials_hash.append(t.user_attrs["hash"])
elif t.state == TrialState.RUNNING:

def _closure():
return study._storage.get_trial(t._trial_id).user_attrs["hash"]
def _closure(trial):
return study._storage.get_trial(trial._trial_id).user_attrs["hash"]

hash_to_add = retry(_closure, max_retries=self.retries)
hash_to_add = retry(partial(_closure, trial=t), max_retries=self.retries)
running_trials_hash.append(hash_to_add)
else:
pass
Expand Down
25 changes: 12 additions & 13 deletions etna/pipeline/hierarchical_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,25 +312,24 @@ def _forecast_prediction_interval(
self, ts: TSDataset, predictions: TSDataset, quantiles: Sequence[float], n_folds: int
) -> TSDataset:
"""Add prediction intervals to the forecasts."""
# TODO: fix this: what if during backtest KeyboardInterrupt is raised
self.forecast, self.raw_forecast = self.raw_forecast, self.forecast # type: ignore

if self.ts is None:
raise ValueError("Pipeline is not fitted! Fit the Pipeline before calling forecast method.")

# TODO: rework intervals estimation for `BottomUpReconciliator`

with tslogger.disable():
_, forecasts, _ = self.backtest(ts=ts, metrics=[MAE()], n_folds=n_folds)
self.forecast, self.raw_forecast = self.raw_forecast, self.forecast # type: ignore
try:
# TODO: rework intervals estimation for `BottomUpReconciliator`

source_ts = self.reconciliator.aggregate(ts=ts)
self._add_forecast_borders(
ts=source_ts, backtest_forecasts=forecasts, quantiles=quantiles, predictions=predictions
)
with tslogger.disable():
_, forecasts, _ = self.backtest(ts=ts, metrics=[MAE()], n_folds=n_folds)

self.forecast, self.raw_forecast = self.raw_forecast, self.forecast # type: ignore
source_ts = self.reconciliator.aggregate(ts=ts)
self._add_forecast_borders(
ts=source_ts, backtest_forecasts=forecasts, quantiles=quantiles, predictions=predictions
)
return predictions

return predictions
finally:
self.forecast, self.raw_forecast = self.raw_forecast, self.forecast # type: ignore

def save(self, path: pathlib.Path):
"""Save the object.
Expand Down
27 changes: 27 additions & 0 deletions tests/test_pipeline/test_hierarchical_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,33 @@ def test_backtest_w_exog(product_level_constant_hierarchical_ts_with_exog, recon
np.testing.assert_allclose(metrics["MAE"], 0)


@pytest.mark.parametrize(
"reconciliator",
(
TopDownReconciliator(target_level="product", source_level="market", period=1, method="AHP"),
TopDownReconciliator(target_level="product", source_level="market", period=1, method="PHA"),
BottomUpReconciliator(target_level="total", source_level="market"),
),
)
def test_private_forecast_prediction_interval_no_swap_after_error(
product_level_constant_hierarchical_ts_with_exog, reconciliator
):
ts = product_level_constant_hierarchical_ts_with_exog
model = LinearPerSegmentModel()
pipeline = HierarchicalPipeline(reconciliator=reconciliator, model=model, transforms=[], horizon=1)
pipeline.backtest = Mock(side_effect=ValueError("Some error"))
forecast_method = pipeline.forecast
raw_forecast_method = pipeline.raw_forecast

pipeline.fit(ts=ts)
with pytest.raises(ValueError, match="Some error"):
_ = pipeline.forecast(prediction_interval=True, n_folds=1, quantiles=[0.025, 0.5, 0.975])

# check that methods aren't swapped
assert pipeline.forecast == forecast_method
assert pipeline.raw_forecast == raw_forecast_method


@pytest.mark.parametrize(
"reconciliator",
(
Expand Down

0 comments on commit 6d4fcb7

Please sign in to comment.