diff --git a/pandas/tests/extension/json/array.py b/pandas/tests/extension/json/array.py index 95f868e89ac39..2e75bb3b8c326 100644 --- a/pandas/tests/extension/json/array.py +++ b/pandas/tests/extension/json/array.py @@ -105,6 +105,12 @@ def take(self, indexer, allow_fill=True, fill_value=None): def copy(self, deep=False): return type(self)(self.data[:]) + def astype(self, dtype, copy=True): + # NumPy has issues when all the dicts are the same length. + # np.array([UserDict(...), UserDict(...)]) fails, + # but np.array([{...}, {...}]) works, so cast. + return np.array([dict(x) for x in self], dtype=dtype, copy=copy) + def unique(self): # Parent method doesn't work since np.array will try to infer # a 2-dim object. diff --git a/pandas/tests/extension/json/test_json.py b/pandas/tests/extension/json/test_json.py index dcf08440738e7..0ef34c3b0f679 100644 --- a/pandas/tests/extension/json/test_json.py +++ b/pandas/tests/extension/json/test_json.py @@ -1,8 +1,10 @@ import operator +import collections import pytest - +import pandas as pd +import pandas.util.testing as tm from pandas.compat import PY2, PY36 from pandas.tests.extension import base @@ -59,27 +61,76 @@ def data_for_grouping(): ]) -class TestDtype(base.BaseDtypeTests): +class BaseJSON(object): + # NumPy doesn't handle an array of equal-length UserDicts. + # The default assert_series_equal eventually does a + # Series.values, which raises. We work around it by + # converting the UserDicts to dicts. + def assert_series_equal(self, left, right, **kwargs): + if left.dtype.name == 'json': + assert left.dtype == right.dtype + left = pd.Series(JSONArray(left.values.astype(object)), + index=left.index, name=left.name) + right = pd.Series(JSONArray(right.values.astype(object)), + index=right.index, name=right.name) + tm.assert_series_equal(left, right, **kwargs) + + def assert_frame_equal(self, left, right, *args, **kwargs): + tm.assert_index_equal( + left.columns, right.columns, + exact=kwargs.get('check_column_type', 'equiv'), + check_names=kwargs.get('check_names', True), + check_exact=kwargs.get('check_exact', False), + check_categorical=kwargs.get('check_categorical', True), + obj='{obj}.columns'.format(obj=kwargs.get('obj', 'DataFrame'))) + + jsons = (left.dtypes == 'json').index + + for col in jsons: + self.assert_series_equal(left[col], right[col], + *args, **kwargs) + + left = left.drop(columns=jsons) + right = right.drop(columns=jsons) + tm.assert_frame_equal(left, right, *args, **kwargs) + + +class TestDtype(BaseJSON, base.BaseDtypeTests): pass -class TestInterface(base.BaseInterfaceTests): - pass +class TestInterface(BaseJSON, base.BaseInterfaceTests): + def test_custom_asserts(self): + # This would always trigger the KeyError from trying to put + # an array of equal-length UserDicts inside an ndarray. + data = JSONArray([collections.UserDict({'a': 1}), + collections.UserDict({'b': 2}), + collections.UserDict({'c': 3})]) + a = pd.Series(data) + self.assert_series_equal(a, a) + self.assert_frame_equal(a.to_frame(), a.to_frame()) + + b = pd.Series(data.take([0, 0, 1])) + with pytest.raises(AssertionError): + self.assert_series_equal(a, b) + + with pytest.raises(AssertionError): + self.assert_frame_equal(a.to_frame(), b.to_frame()) -class TestConstructors(base.BaseConstructorsTests): +class TestConstructors(BaseJSON, base.BaseConstructorsTests): pass -class TestReshaping(base.BaseReshapingTests): +class TestReshaping(BaseJSON, base.BaseReshapingTests): pass -class TestGetitem(base.BaseGetitemTests): +class TestGetitem(BaseJSON, base.BaseGetitemTests): pass -class TestMissing(base.BaseMissingTests): +class TestMissing(BaseJSON, base.BaseMissingTests): @pytest.mark.xfail(reason="Setting a dict as a scalar") def test_fillna_series(self): """We treat dictionaries as a mapping in fillna, not a scalar.""" @@ -94,7 +145,7 @@ def test_fillna_frame(self): reason="Dictionary order unstable") -class TestMethods(base.BaseMethodsTests): +class TestMethods(BaseJSON, base.BaseMethodsTests): @unhashable def test_value_counts(self, all_data, dropna): pass @@ -126,7 +177,7 @@ def test_sort_values_missing(self, data_missing_for_sorting, ascending): data_missing_for_sorting, ascending) -class TestCasting(base.BaseCastingTests): +class TestCasting(BaseJSON, base.BaseCastingTests): @pytest.mark.xfail def test_astype_str(self): """This currently fails in NumPy on np.array(self, dtype=str) with @@ -139,7 +190,7 @@ def test_astype_str(self): # internals has trouble setting sequences of values into scalar positions. -class TestGroupby(base.BaseGroupbyTests): +class TestGroupby(BaseJSON, base.BaseGroupbyTests): @unhashable def test_groupby_extension_transform(self):