Skip to content

Commit

Permalink
Implemented test in respond to #15 (#36)
Browse files Browse the repository at this point in the history
* implemented test to check robustness of methods to the input format in respond to #15 

---------

Co-authored-by: Sami Boumaiza <[email protected]>
  • Loading branch information
sami6mz and Sami Boumaiza authored Jul 10, 2023
1 parent d7272e2 commit 13c97da
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion tests/estimation/test_get_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def x(data):
return data[0]


# t is raveled because some estimators fail with (n,1) inputs
@pytest.fixture
def t(data):
return data[1].ravel()
Expand All @@ -191,7 +192,7 @@ def m(data):

@pytest.fixture
def y(data):
return data[3].ravel()
return data[3].ravel() # same reason as t


@pytest.fixture
Expand Down Expand Up @@ -249,3 +250,15 @@ def test_total_is_direct_plus_indirect(effects_chap):
assert effects_chap[0] == pytest.approx(effects_chap[1] + effects_chap[4])
if not np.isnan(effects_chap[2]):
assert effects_chap[0] == pytest.approx(effects_chap[2] + effects_chap[3])


@pytest.mark.xfail
def test_robustness_to_ravel_format(data, estimator, config, effects_chap):
if "forest" in estimator:
pytest.skip("Forest estimator skipped")
assert np.all(
get_estimation(data[0], data[1], data[2], data[3], estimator, config)[0:5]
== pytest.approx(
effects_chap, nan_ok=True
) # effects_chap is obtained with data[1].ravel() and data[3].ravel()
)

0 comments on commit 13c97da

Please sign in to comment.