From 004f752ec692b05d485575951c6b4dec04e55e21 Mon Sep 17 00:00:00 2001 From: Karsten Chu Date: Tue, 10 Aug 2021 12:38:52 -0400 Subject: [PATCH] Changed the target imputer to just look for all nulls in the target and raise. --- .../components/transformers/imputers/target_imputer.py | 7 ++----- evalml/tests/component_tests/test_target_imputer.py | 4 ++-- evalml/tests/pipeline_tests/test_pipeline_utils.py | 2 +- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/evalml/pipelines/components/transformers/imputers/target_imputer.py b/evalml/pipelines/components/transformers/imputers/target_imputer.py index 2a4a1fed7d..eab81a03e6 100644 --- a/evalml/pipelines/components/transformers/imputers/target_imputer.py +++ b/evalml/pipelines/components/transformers/imputers/target_imputer.py @@ -74,15 +74,12 @@ def fit(self, X, y): Returns: self """ - from woodwork.logical_types import Unknown - if y is None: return self y = infer_feature_types(y) - if isinstance(y.ww.logical_type, Unknown): - raise TypeError("Provided target full of pd.NA.") + if all(y.isnull()): + raise TypeError("Provided target full of nulls.") y = y.to_frame() - # should y be an un-inited dataframe? # Convert all bool dtypes to category for fitting if (y.dtypes == bool).all(): diff --git a/evalml/tests/component_tests/test_target_imputer.py b/evalml/tests/component_tests/test_target_imputer.py index 2b00d1292f..5a77c41db2 100644 --- a/evalml/tests/component_tests/test_target_imputer.py +++ b/evalml/tests/component_tests/test_target_imputer.py @@ -134,11 +134,11 @@ def test_target_imputer_fit_transform_all_nan_empty(y): imputer = TargetImputer() - with pytest.raises(TypeError, match="Provided target full of pd.NA."): + with pytest.raises(TypeError, match="Provided target full of nulls."): imputer.fit(None, y) imputer = TargetImputer() - with pytest.raises(TypeError, match="Provided target full of pd.NA."): + with pytest.raises(TypeError, match="Provided target full of nulls."): imputer.fit_transform(None, y) diff --git a/evalml/tests/pipeline_tests/test_pipeline_utils.py b/evalml/tests/pipeline_tests/test_pipeline_utils.py index 3e04ab77b5..ed59003767 100644 --- a/evalml/tests/pipeline_tests/test_pipeline_utils.py +++ b/evalml/tests/pipeline_tests/test_pipeline_utils.py @@ -203,7 +203,7 @@ def test_make_pipeline( ) imputer = ( [] - if (column_names == ["dates"] and input_type == "ww") + if (column_names in [["dates"], ["all_null"]] and input_type == "ww") or ((column_names in [["text"], ["dates"]]) and input_type == "pd") else [Imputer] )