Skip to content

Commit

Permalink
OneHotEncodingTransformer support for lists and lists of lists (#137)
Browse files Browse the repository at this point in the history
* Created _prepare_data

* Tests _prepare_data

* Improved code.

* Fix lint.

* Fix documentation.
  • Loading branch information
fealho authored Nov 18, 2020
1 parent a69c68f commit c2842b6
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 2 deletions.
33 changes: 31 additions & 2 deletions rdt/transformers/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,28 +223,57 @@ class OneHotEncodingTransformer(BaseTransformer):
dummy_na = None
dummies = None

@staticmethod
def _prepare_data(data):
    """Convert the input into a one-dimensional, array-like structure.

    Lists and lists of lists are converted into a ``numpy.ndarray``. Any
    other input is kept as it is, apart from the shape handling below.

    A two-dimensional input is only accepted when it has exactly one
    column, in which case that column is returned as a one-dimensional
    array.

    Args:
        data (pandas.Series, numpy.ndarray, list or list of lists):
            Data to prepare.

    Returns:
        pandas.Series or numpy.ndarray:
            One-dimensional version of the input data.

    Raises:
        ValueError:
            If the data has more than two dimensions, or has two
            dimensions and the second one is not of length 1.
    """
    if isinstance(data, list):
        converted = np.array(data)
        if converted.dtype.kind in ('U', 'S'):
            # numpy picks a string dtype for lists that mix strings with
            # other scalars, which turns ``np.nan`` into the text 'nan'
            # and hides the missing value from ``pd.isnull``. Keeping the
            # elements as Python objects preserves them untouched.
            converted = np.array(data, dtype=object)

        data = converted

    dimensions = len(data.shape)
    if dimensions > 2:
        raise ValueError("Unexpected format.")

    if dimensions == 2:
        if data.shape[1] != 1:
            raise ValueError("Unexpected format.")

        # Single column: flatten it to one dimension.
        data = data[:, 0]

    return data

def fit(self, data):
    """Learn the one-hot columns from the data.

    Stores whether the data contains null values in ``dummy_na`` and the
    column labels produced by ``pandas.get_dummies`` in ``dummies``.

    Args:
        data (pandas.Series, numpy.ndarray, list or list of lists):
            Data to fit the transformer to.
    """
    prepared = self._prepare_data(data)

    # A dedicated null column is only requested when nulls are present.
    has_nulls = pd.isnull(prepared).any()
    self.dummy_na = has_nulls

    encoded = pd.get_dummies(prepared, dummy_na=has_nulls)
    self.dummies = list(encoded.columns)

def transform(self, data):
    """Encode each value as a one-hot vector.

    Args:
        data (pandas.Series, numpy.ndarray, list or list of lists):
            Data to transform.

    Returns:
        numpy.ndarray:
            Integer array with one row per input value and one column
            per entry of ``self.dummies``.
    """
    prepared = self._prepare_data(data)
    encoded = pd.get_dummies(prepared, dummy_na=self.dummy_na)

    # Align with the fitted columns: fitted categories that are absent
    # here are filled with 0, and columns not seen during fit are dropped.
    aligned = encoded.reindex(columns=self.dummies, fill_value=0)
    return aligned.values.astype(int)

Expand Down
42 changes: 42 additions & 0 deletions tests/transformers/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,48 @@ def test_reversible_mixed(self):

class TestOneHotEncodingTransformer:

def test__prepare_data_empty_lists(self):
    """A list made only of empty lists is rejected with ``ValueError``."""
    # Setup
    transformer = OneHotEncodingTransformer()
    empty_rows = [[], [], []]

    # Run / Assert
    with pytest.raises(ValueError):
        transformer._prepare_data(empty_rows)

def test__prepare_data_nested_lists(self):
    """Lists nested three levels deep are rejected with ``ValueError``."""
    # Setup
    transformer = OneHotEncodingTransformer()
    too_deep = [[[]]]

    # Run / Assert
    with pytest.raises(ValueError):
        transformer._prepare_data(too_deep)

def test__prepare_data_list_of_lists(self):
    """A list of single-element lists is flattened to a 1-D array."""
    # Setup
    transformer = OneHotEncodingTransformer()
    single_column = [['a'], ['b'], ['c']]

    # Run
    result = transformer._prepare_data(single_column)

    # Assert
    np.testing.assert_array_equal(result, np.array(['a', 'b', 'c']))

def test__prepare_data_pandas_series(self):
    """A ``pandas.Series`` comes back with the same values."""
    # Setup
    transformer = OneHotEncodingTransformer()
    values = ['a', 'b', 'c']

    # Run
    result = transformer._prepare_data(pd.Series(values))

    # Assert
    np.testing.assert_array_equal(result, pd.Series(values))

def test_fit_no_nans(self):
# Setup
ohet = OneHotEncodingTransformer()
Expand Down

0 comments on commit c2842b6

Please sign in to comment.