diff --git a/tests/simulate_data/test_get_simulated_data.py b/tests/simulate_data/test_get_simulated_data.py index d94ecc8..40cde5d 100644 --- a/tests/simulate_data/test_get_simulated_data.py +++ b/tests/simulate_data/test_get_simulated_data.py @@ -108,9 +108,15 @@ def test_dimension_y(y, dict_param): def test_m_is_binary(m, dict_param): if dict_param["type_m"] == "binary": - assert sum(m.ravel() == 1) + sum(m.ravel() == 0) == dict_param["n"] + assert ( + sum(m.ravel() == 1) + sum(m.ravel() == 0) + == dict_param["n"] * dict_param["dim_m"] + ) else: - assert sum(m.ravel() == 1) + sum(m.ravel() == 0) < dict_param["n"] + assert ( + sum(m.ravel() == 1) + sum(m.ravel() == 0) + < dict_param["n"] * dict_param["dim_m"] + ) def test_total_is_direct_plus_indirect(effects):