Skip to content

Commit

Permalink
Update downcast_nullable_types for series consistency
Browse files: browse the repository at this point in the history
  • Loading branch information
eccabay committed Sep 16, 2022
1 parent 8b7cc18 commit 56eb074
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 7 deletions.
32 changes: 25 additions & 7 deletions evalml/tests/utils_tests/test_woodwork_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,19 +297,37 @@ def test_schema_is_equal_fraud(fraud_100):
assert _schema_is_equal(X.ww.schema, X2.ww.schema)


@pytest.mark.parametrize("ignore_null_cols", [True, False])
def test_downcast_nullable_types_series(ignore_null_cols):
    """Check ``downcast_nullable_types`` on individual woodwork series.

    Three columns are exercised, for both values of ``ignore_null_cols``:

    * ``"int"`` (IntegerNullable, no nulls) is expected to become Double.
    * ``"bool"`` (BooleanNullable, no nulls) is expected to become Boolean.
    * ``"int with nan"`` (IntegerNullable, contains ``pd.NA``) is expected to
      keep its nullable type when ``ignore_null_cols`` is True and to become
      Double otherwise.

    Args:
        ignore_null_cols: Forwarded to ``downcast_nullable_types``; when True,
            a series containing nulls is returned unchanged.
    """
    X = pd.DataFrame(
        {
            "int": [1, 0, 1, 1, 0.0],
            "bool": [True, False, True, True, False],
            "int with nan": [1, 0, 1, 1, pd.NA],
        },
    )
    X.ww.init(
        logical_types={
            "int": IntegerNullable,
            "bool": BooleanNullable,
            "int with nan": IntegerNullable,
        },
    )

    # Null-free columns are downcast regardless of ignore_null_cols.
    y_int = X.ww["int"]
    y_int_t = downcast_nullable_types(y_int, ignore_null_cols)
    assert y_int_t.ww.logical_type.type_string == Double.type_string

    y_bool = X.ww["bool"]
    y_bool_t = downcast_nullable_types(y_bool, ignore_null_cols)
    assert y_bool_t.ww.logical_type.type_string == Boolean.type_string

    # The column with nulls is the only one whose outcome depends on the flag.
    y_int_nan = X.ww["int with nan"]
    y_int_nan_t = downcast_nullable_types(y_int_nan, ignore_null_cols)
    if ignore_null_cols:
        assert y_int_nan_t.ww.logical_type.type_string == IntegerNullable.type_string
    else:
        # Assert on the series with nulls itself; checking ``y_int_t`` here
        # would only repeat the null-free assertion above and leave this
        # branch untested.
        assert y_int_nan_t.ww.logical_type.type_string == Double.type_string


def test_downcast_nullable_types_can_handle_no_schema():
Expand Down
2 changes: 2 additions & 0 deletions evalml/utils/woodwork_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ def downcast_nullable_types(data, ignore_null_cols=True):
data.ww.init()

if isinstance(data, pd.Series):
if ignore_null_cols and data.isna().any():
return data
if isinstance(data.ww.logical_type, ww.logical_types.BooleanNullable):
data = data.ww.set_logical_type("Boolean")
if isinstance(
Expand Down

0 comments on commit 56eb074

Please sign in to comment.