diff --git a/tests/estimation/test_get_estimation.py b/tests/estimation/test_get_estimation.py index 16953a5..37878d3 100644 --- a/tests/estimation/test_get_estimation.py +++ b/tests/estimation/test_get_estimation.py @@ -179,6 +179,7 @@ def x(data): return data[0] +# t is raveled because some estimators fail with (n,1) inputs @pytest.fixture def t(data): return data[1].ravel() @@ -191,7 +192,7 @@ def m(data): @pytest.fixture def y(data): - return data[3].ravel() + return data[3].ravel() # same reason as t @pytest.fixture @@ -249,3 +250,15 @@ def test_total_is_direct_plus_indirect(effects_chap): assert effects_chap[0] == pytest.approx(effects_chap[1] + effects_chap[4]) if not np.isnan(effects_chap[2]): assert effects_chap[0] == pytest.approx(effects_chap[2] + effects_chap[3]) + + +@pytest.mark.xfail +def test_robustness_to_ravel_format(data, estimator, config, effects_chap): + if "forest" in estimator: + pytest.skip("Forest estimator skipped") + assert np.all( + get_estimation(data[0], data[1], data[2], data[3], estimator, config)[0:5] + == pytest.approx( + effects_chap, nan_ok=True + ) # effects_chap is obtained with data[1].ravel() and data[3].ravel() + )