diff --git a/rdt/transformers/categorical.py b/rdt/transformers/categorical.py
index 8d2f64e93..a2206f4dd 100644
--- a/rdt/transformers/categorical.py
+++ b/rdt/transformers/categorical.py
@@ -223,15 +223,54 @@ class OneHotEncodingTransformer(BaseTransformer):
     dummy_na = None
     dummies = None
 
+    @staticmethod
+    def _prepare_data(data):
+        """Transform data to appropriate format.
+
+        If data is a list or a list of lists, transforms it into an np.array.
+        Data with 2 dimensions and a single column is flattened to 1 dimension.
+
+        Args:
+            data (pandas.Series, numpy.ndarray, list or list of lists):
+                Data to prepare.
+
+        Returns:
+            pandas.Series or numpy.ndarray
+
+        Raises:
+            ValueError:
+                If data has more than 2 dimensions, has 2 dimensions but not
+                exactly one column, or is a ragged or partially nested list.
+        """
+        if isinstance(data, list):
+            # dtype=object keeps np.nan/None as missing values; a plain np.array
+            # call would turn np.nan into the string 'nan' next to other strings.
+            data = np.array(data, dtype=object)
+            # With dtype=object a ragged list becomes a 1-d array of lists
+            # instead of failing, so reject leftover nested lists explicitly.
+            if data.ndim == 1 and any(isinstance(item, list) for item in data):
+                raise ValueError("Unexpected format.")
+
+        if len(data.shape) > 2:
+            raise ValueError("Unexpected format.")
+        if len(data.shape) == 2:
+            if data.shape[1] != 1:
+                raise ValueError("Unexpected format.")
+
+            data = data[:, 0]
+
+        return data
+
     def fit(self, data):
         """Fit the transformer to the data.
 
         Get the pandas `dummies` which will be used later on for OneHotEncoding.
 
         Args:
-            data (pandas.Series or numpy.ndarray):
+            data (pandas.Series, numpy.ndarray, list or list of lists):
                 Data to fit the transformer to.
         """
+        data = self._prepare_data(data)
         self.dummy_na = pd.isnull(data).any()
         self.dummies = list(pd.get_dummies(data, dummy_na=self.dummy_na).columns)
 
@@ -239,12 +278,13 @@ def transform(self, data):
         """Replace each category with the OneHot vectors.
 
         Args:
-            data (pandas.Series or numpy.ndarray):
+            data (pandas.Series, numpy.ndarray, list or list of lists):
                 Data to transform.
 
         Returns:
             numpy.ndarray:
         """
+        data = self._prepare_data(data)
         dummies = pd.get_dummies(data, dummy_na=self.dummy_na)
         return dummies.reindex(columns=self.dummies, fill_value=0).values.astype(int)
 
diff --git a/tests/transformers/test_categorical.py b/tests/transformers/test_categorical.py
index 971c9111f..f4dd144a6 100644
--- a/tests/transformers/test_categorical.py
+++ b/tests/transformers/test_categorical.py
@@ -283,6 +283,68 @@ def test_reversible_mixed(self):
 
 class TestOneHotEncodingTransformer:
 
+    def test__prepare_data_empty_lists(self):
+        # Setup
+        ohet = OneHotEncodingTransformer()
+        data = [[], [], []]
+
+        # Assert
+        with pytest.raises(ValueError):
+            ohet._prepare_data(data)
+
+    def test__prepare_data_nested_lists(self):
+        # Setup
+        ohet = OneHotEncodingTransformer()
+        data = [[[]]]
+
+        # Assert
+        with pytest.raises(ValueError):
+            ohet._prepare_data(data)
+
+    def test__prepare_data_list_of_lists(self):
+        # Setup
+        ohet = OneHotEncodingTransformer()
+
+        # Run
+        data = [['a'], ['b'], ['c']]
+        out = ohet._prepare_data(data)
+
+        # Assert
+        expected = np.array(['a', 'b', 'c'])
+        np.testing.assert_array_equal(out, expected)
+
+    def test__prepare_data_pandas_series(self):
+        # Setup
+        ohet = OneHotEncodingTransformer()
+
+        # Run
+        data = pd.Series(['a', 'b', 'c'])
+        out = ohet._prepare_data(data)
+
+        # Assert
+        expected = pd.Series(['a', 'b', 'c'])
+        np.testing.assert_array_equal(out, expected)
+
+    def test__prepare_data_list_with_nan(self):
+        # Setup
+        ohet = OneHotEncodingTransformer()
+
+        # Run
+        data = ['a', np.nan, 'b']
+        out = ohet._prepare_data(data)
+
+        # Assert
+        assert pd.isnull(out).tolist() == [False, True, False]
+
+    def test__prepare_data_ragged_lists(self):
+        # Setup
+        ohet = OneHotEncodingTransformer()
+        data = [['a'], ['b', 'c']]
+
+        # Assert
+        with pytest.raises(ValueError):
+            ohet._prepare_data(data)
+
     def test_fit_no_nans(self):
         # Setup
         ohet = OneHotEncodingTransformer()