Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-hse-repository committed Mar 29, 2023
1 parent 7bb3c17 commit 1f82430
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions tests/test_pipeline/test_autoregressive_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from tests.test_pipeline.utils import assert_pipeline_equals_loaded_original
from tests.test_pipeline.utils import assert_pipeline_forecasts_given_ts
from tests.test_pipeline.utils import assert_pipeline_forecasts_given_ts_with_prediction_intervals
from tests.utils import to_be_fixed

DEFAULT_METRICS = [MAE(mode=MetricAggregationMode.per_segment)]

Expand Down Expand Up @@ -347,7 +346,6 @@ def test_forecast_given_ts_with_prediction_interval(model, transforms, example_t
assert_pipeline_forecasts_given_ts_with_prediction_intervals(pipeline=pipeline, ts=example_tsds, horizon=horizon)


@to_be_fixed(NotImplementedError, "Adding target components is not currently implemented!")
@pytest.mark.parametrize(
"model_fixture",
(
Expand All @@ -357,14 +355,20 @@ def test_forecast_given_ts_with_prediction_interval(model, transforms, example_t
"prediction_interval_context_required_dummy_model",
),
)
def test_forecast_return_components(example_tsds, model_fixture, request):
def test_forecast_return_components(
example_tsds, model_fixture, request, expected_component_a=10, expected_component_b=90
):
model = request.getfixturevalue(model_fixture)
pipeline = AutoRegressivePipeline(model=model, horizon=10)
pipeline.fit(example_tsds)
forecast = pipeline.forecast(return_components=True)
assert sorted(forecast.target_components_names) == sorted(["target_component_a", "target_component_b"])

target_components_df = TSDataset.to_flatten(forecast.get_target_components())
assert (target_components_df["target_component_a"] == expected_component_a).all()
assert (target_components_df["target_component_b"] == expected_component_b).all()


@to_be_fixed(NotImplementedError, "Adding target components is not currently implemented!")
@pytest.mark.parametrize(
"model_fixture",
(
Expand All @@ -374,8 +378,15 @@ def test_forecast_return_components(example_tsds, model_fixture, request):
"prediction_interval_context_required_dummy_model",
),
)
def test_predict_return_components(example_tsds, model_fixture, request):
def test_predict_return_components(
example_tsds, model_fixture, request, expected_component_a=20, expected_component_b=180
):
model = request.getfixturevalue(model_fixture)
pipeline = AutoRegressivePipeline(model=model, horizon=10)
pipeline.fit(example_tsds)
forecast = pipeline.predict(ts=example_tsds, return_components=True)
assert sorted(forecast.target_components_names) == sorted(["target_component_a", "target_component_b"])

target_components_df = TSDataset.to_flatten(forecast.get_target_components())
assert (target_components_df["target_component_a"] == expected_component_a).all()
assert (target_components_df["target_component_b"] == expected_component_b).all()

0 comments on commit 1f82430

Please sign in to comment.