Skip to content

Commit

Permalink
Fix pd.merge to preserve ExtensionArrays dtypes (pandas-dev#20745)
Browse files Browse the repository at this point in the history
  • Loading branch information
jorisvandenbossche authored and jreback committed Apr 22, 2018
1 parent 4de2e9b commit 0ae7e90
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pandas/core/dtypes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1807,7 +1807,7 @@ def _get_dtype(arr_or_dtype):
return arr_or_dtype
elif isinstance(arr_or_dtype, type):
return np.dtype(arr_or_dtype)
elif isinstance(arr_or_dtype, CategoricalDtype):
elif isinstance(arr_or_dtype, ExtensionDtype):
return arr_or_dtype
elif isinstance(arr_or_dtype, DatetimeTZDtype):
return arr_or_dtype
Expand Down
12 changes: 9 additions & 3 deletions pandas/core/internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -5541,8 +5541,14 @@ def concatenate_join_units(join_units, concat_axis, copy):
if len(to_concat) == 1:
# Only one block, nothing to concatenate.
concat_values = to_concat[0]
if copy and concat_values.base is not None:
concat_values = concat_values.copy()
if copy:
if isinstance(concat_values, np.ndarray):
# non-reindexed (=not yet copied) arrays are made into a view
# in JoinUnit.get_reindexed_values
if concat_values.base is not None:
concat_values = concat_values.copy()
else:
concat_values = concat_values.copy()
else:
concat_values = _concat._concat_compat(to_concat, axis=concat_axis)

Expand Down Expand Up @@ -5823,7 +5829,7 @@ def get_reindexed_values(self, empty_dtype, upcasted_na):
# External code requested filling/upcasting, bool values must
# be upcasted to object to avoid being upcasted to numeric.
values = self.block.astype(np.object_).values
elif self.block.is_categorical:
elif self.block.is_extension:
values = self.block.values
else:
# No dtype upcasting is done here, it will be performed during
Expand Down
21 changes: 21 additions & 0 deletions pandas/tests/extension/base/reshaping.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,24 @@ def test_set_frame_overwrite_object(self, data):
df = pd.DataFrame({"A": [1] * len(data)}, dtype=object)
df['A'] = data
assert df.dtypes['A'] == data.dtype

def test_merge(self, data, na_value):
    """Check that ``pd.merge`` keeps the extension column's values and dtype.

    Exercises the default (inner) merge and an outer merge of a frame
    holding the first three elements of ``data`` against a plain integer
    frame, joined on the shared ``key`` column.
    """
    # GH-20743
    column_order = ['ext', 'int1', 'key', 'int2']
    left = pd.DataFrame({'ext': data[:3], 'int1': [1, 2, 3],
                         'key': [0, 1, 2]})
    right = pd.DataFrame({'int2': [1, 2, 3, 4], 'key': [0, 0, 1, 3]})

    # Default merge: key 0 matches two right-hand rows, key 1 matches one.
    merged_inner = pd.merge(left, right)
    inner_ext = data._constructor_from_sequence(
        [data[0], data[0], data[1]])
    expected_inner = pd.DataFrame(
        {'int1': [1, 1, 2], 'int2': [1, 2, 3], 'key': [0, 0, 1],
         'ext': inner_ext})
    self.assert_frame_equal(merged_inner, expected_inner[column_order])

    # Outer merge: key 2 exists only on the left (no 'int2' value) and
    # key 3 only on the right, so its extension entry is ``na_value``.
    merged_outer = pd.merge(left, right, how='outer')
    outer_ext = data._constructor_from_sequence(
        [data[0], data[0], data[1], data[2], na_value])
    expected_outer = pd.DataFrame(
        {'int1': [1, 1, 2, 3, np.nan], 'int2': [1, 2, 3, np.nan, 4],
         'key': [0, 0, 1, 2, 3],
         'ext': outer_ext})
    self.assert_frame_equal(merged_outer, expected_outer[column_order])
4 changes: 4 additions & 0 deletions pandas/tests/extension/category/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ def test_align(self, data, na_value):
def test_align_frame(self, data, na_value):
pass

@pytest.mark.skip(reason="Unobserved categories preserved in concat.")
def test_merge(self, data, na_value):
    # Override of the shared merge test; skipped for this test class for
    # the reason given in the decorator, so the body is intentionally empty.
    pass


class TestGetitem(base.BaseGetitemTests):
@pytest.mark.skip(reason="Backwards compatibility")
Expand Down
8 changes: 8 additions & 0 deletions pandas/tests/extension/decimal/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ def assert_series_equal(self, left, right, *args, **kwargs):

def assert_frame_equal(self, left, right, *args, **kwargs):
# TODO(EA): select_dtypes
tm.assert_index_equal(
left.columns, right.columns,
exact=kwargs.get('check_column_type', 'equiv'),
check_names=kwargs.get('check_names', True),
check_exact=kwargs.get('check_exact', False),
check_categorical=kwargs.get('check_categorical', True),
obj='{obj}.columns'.format(obj=kwargs.get('obj', 'DataFrame')))

decimals = (left.dtypes == 'decimal').index

for col in decimals:
Expand Down

0 comments on commit 0ae7e90

Please sign in to comment.