Skip to content

Commit

Permalink
Raise informative error when including target in X (#962)
Browse files Browse the repository at this point in the history
* add error when target is in X_df

* add test

* rename test

* format
  • Loading branch information
cluhmann authored Aug 22, 2024
1 parent 255eac1 commit 943a2c4
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
4 changes: 4 additions & 0 deletions pymc_marketing/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,10 @@ def fit(
self._generate_and_preprocess_model_data(X, y_df.values.flatten())
if self.X is None or self.y is None:
raise ValueError("X and y must be set before calling build_model!")
if self.output_var in X.columns:
raise ValueError(
f"X includes a column named '{self.output_var}', which conflicts with the target variable."
)

if not hasattr(self, "model"):
self.build_model(self.X, self.y)
Expand Down
8 changes: 8 additions & 0 deletions tests/test_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,14 @@ def test_fit_no_t(toy_X):
assert "posterior" in model_builder.idata.groups()


def test_fit_dup_Y(toy_X):
    """Check that ``fit`` rejects an ``X`` that already holds the target column.

    The previous version of this test was marked ``xfail`` and referenced
    ``toy_y`` without receiving it as an argument, so it could "pass" on any
    unrelated exception raised while building the input frame, without ever
    reaching the validation in ``fit``. Here the colliding column is built
    directly from the model's own ``output_var`` and the specific error is
    asserted.
    """
    model_builder = ModelBuilderTest()

    # Work on a copy so the fixture object handed to other tests is untouched.
    toy_X = toy_X.copy()
    # Create a redundant target column in X, named exactly like the target.
    toy_X[model_builder.output_var] = 0.0

    # Only the ValueError raised for the name clash counts as success.
    with pytest.raises(ValueError, match="conflicts with the target variable"):
        model_builder.fit(X=toy_X, chains=1, draws=100, tune=100)


@pytest.mark.skipif(
sys.platform == "win32",
reason="Permissions for temp files not granted on windows CI.",
Expand Down

0 comments on commit 943a2c4

Please sign in to comment.