diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py index f59954aaf08..cda4e8cbd4c 100644 --- a/python/cudf/cudf/core/frame.py +++ b/python/cudf/cudf/core/frame.py @@ -32,6 +32,7 @@ is_numerical_dtype, is_scalar, min_scalar_type, + find_common_type, ) T = TypeVar("T", bound="Frame") @@ -4029,8 +4030,11 @@ def _find_common_dtypes_and_categories(non_null_columns, dtypes): # default to the first non-null dtype dtypes[idx] = cols[0].dtype # If all the non-null dtypes are int/float, find a common dtype - if all(is_numerical_dtype(col.dtype) for col in cols): - dtypes[idx] = np.find_common_type([col.dtype for col in cols], []) + if all( + is_numerical_dtype(col.dtype) or is_decimal_dtype(col.dtype) + for col in cols + ): + dtypes[idx] = find_common_type([col.dtype for col in cols]) # If all categorical dtypes, combine the categories elif all( isinstance(col, cudf.core.column.CategoricalColumn) for col in cols @@ -4045,17 +4049,6 @@ def _find_common_dtypes_and_categories(non_null_columns, dtypes): # Set the column dtype to the codes' dtype. The categories # will be re-assigned at the end dtypes[idx] = min_scalar_type(len(categories[idx])) - elif all( - isinstance(col, cudf.core.column.DecimalColumn) for col in cols - ): - # Find the largest scale and the largest difference between - # precision and scale of the columns to be concatenated - s = max([col.dtype.scale for col in cols]) - lhs = max([col.dtype.precision - col.dtype.scale for col in cols]) - # Combine to get the necessary precision and clip at the maximum - # precision - p = min(cudf.Decimal64Dtype.MAX_PRECISION, s + lhs) - dtypes[idx] = cudf.Decimal64Dtype(p, s) # Otherwise raise an error if columns have different dtypes elif not all(is_dtype_equal(c.dtype, dtypes[idx]) for c in cols): raise ValueError("All columns must be the same type") diff --git a/python/cudf/cudf/core/series.py b/python/cudf/cudf/core/series.py index d812214caf8..a894baf8235 100644 --- a/python/cudf/cudf/core/series.py +++ b/python/cudf/cudf/core/series.py @@ -45,7 +45,6 @@ from cudf.utils import cudautils, docutils, ioutils from cudf.utils.docutils import copy_docstring from cudf.utils.dtypes import ( - _decimal_normalize_types, can_convert_to_column, is_decimal_dtype, is_list_dtype, @@ -53,7 +52,7 @@ is_mixed_with_object_dtype, is_scalar, min_scalar_type, - numeric_normalize_types, + find_common_type, ) from cudf.utils.utils import ( get_appropriate_dispatched_func, @@ -2402,10 +2401,8 @@ def _concat(cls, objs, axis=0, index=True): ) if dtype_mismatch: - if isinstance(objs[0]._column, cudf.core.column.DecimalColumn): - objs = _decimal_normalize_types(*objs) - else: - objs = numeric_normalize_types(*objs) + common_dtype = find_common_type([obj.dtype for obj in objs]) + objs = [obj.astype(common_dtype) for obj in objs] col = _concat_columns([o._column for o in objs]) diff --git a/python/cudf/cudf/tests/test_concat.py b/python/cudf/cudf/tests/test_concat.py index 31dc6012905..5c4c121db4d 100644 --- a/python/cudf/cudf/tests/test_concat.py +++ b/python/cudf/cudf/tests/test_concat.py @@ -5,6 +5,7 @@ import numpy as np import pandas as pd import pytest +from decimal import Decimal import cudf as gd from cudf.tests.utils import assert_eq, assert_exceptions_equal @@ -1262,3 +1263,267 @@ def test_concat_decimal_series(ltype, rtype): expected = pd.concat([ps1, ps2]) assert_eq(expected, got) + + +@pytest.mark.parametrize( + "df1, df2, df3, expected", + [ + ( + gd.DataFrame( + {"val": [Decimal("42.5"), Decimal("8.7")]}, + dtype=Decimal64Dtype(5, 2), + ), + gd.DataFrame( + {"val": [Decimal("9.23"), Decimal("-67.49")]}, + dtype=Decimal64Dtype(6, 4), + ), + gd.DataFrame({"val": [8, -5]}, dtype="int32"), + gd.DataFrame( + { + "val": [ + Decimal("42.5"), + Decimal("8.7"), + Decimal("9.23"), + Decimal("-67.49"), + Decimal("8"), + Decimal("-5"), + ] + }, + dtype=Decimal64Dtype(7, 4), + index=[0, 1, 0, 1, 0, 1], + ), + ), + ( + gd.DataFrame( + {"val": [Decimal("95.2"), Decimal("23.4")]}, + dtype=Decimal64Dtype(5, 2), + ), + gd.DataFrame({"val": [54, 509]}, dtype="uint16"), + gd.DataFrame({"val": [24, -48]}, dtype="int32"), + gd.DataFrame( + { + "val": [ + Decimal("95.2"), + Decimal("23.4"), + Decimal("54"), + Decimal("509"), + Decimal("24"), + Decimal("-48"), + ] + }, + dtype=Decimal64Dtype(5, 2), + index=[0, 1, 0, 1, 0, 1], + ), + ), + ( + gd.DataFrame( + {"val": [Decimal("36.56"), Decimal("-59.24")]}, + dtype=Decimal64Dtype(9, 4), + ), + gd.DataFrame({"val": [403.21, 45.13]}, dtype="float32"), + gd.DataFrame({"val": [52.262, -49.25]}, dtype="float64"), + gd.DataFrame( + { + "val": [ + Decimal("36.56"), + Decimal("-59.24"), + Decimal("403.21"), + Decimal("45.13"), + Decimal("52.262"), + Decimal("-49.25"), + ] + }, + dtype=Decimal64Dtype(9, 4), + index=[0, 1, 0, 1, 0, 1], + ), + ), + ( + gd.DataFrame( + {"val": [Decimal("9563.24"), Decimal("236.633")]}, + dtype=Decimal64Dtype(9, 4), + ), + gd.DataFrame({"val": [5393, -95832]}, dtype="int64"), + gd.DataFrame({"val": [-29.234, -31.945]}, dtype="float64"), + gd.DataFrame( + { + "val": [ + Decimal("9563.24"), + Decimal("236.633"), + Decimal("5393"), + Decimal("-95832"), + Decimal("-29.234"), + Decimal("-31.945"), + ] + }, + dtype=Decimal64Dtype(9, 4), + index=[0, 1, 0, 1, 0, 1], + ), + ), + ], +) +def test_concat_decimal_numeric_dataframe(df1, df2, df3, expected): + df = gd.concat([df1, df2, df3]) + assert_eq(df, expected) + assert_eq(df.val.dtype, expected.val.dtype) + + +@pytest.mark.parametrize( + "s1, s2, s3, expected", + [ + ( + gd.Series( + [Decimal("32.8"), Decimal("-87.7")], dtype=Decimal64Dtype(6, 2) + ), + gd.Series( + [Decimal("101.243"), Decimal("-92.449")], + dtype=Decimal64Dtype(9, 6), + ), + gd.Series([94, -22], dtype="int32"), + gd.Series( + [ + Decimal("32.8"), + Decimal("-87.7"), + Decimal("101.243"), + Decimal("-92.449"), + Decimal("94"), + Decimal("-22"), + ], + dtype=Decimal64Dtype(10, 6), + index=[0, 1, 0, 1, 0, 1], + ), + ), + ( + gd.Series( + [Decimal("7.2"), Decimal("122.1")], dtype=Decimal64Dtype(5, 2) + ), + gd.Series([33, 984], dtype="uint32"), + gd.Series([593, -702], dtype="int32"), + gd.Series( + [ + Decimal("7.2"), + Decimal("122.1"), + Decimal("33"), + Decimal("984"), + Decimal("593"), + Decimal("-702"), + ], + dtype=Decimal64Dtype(5, 2), + index=[0, 1, 0, 1, 0, 1], + ), + ), + ( + gd.Series( + [Decimal("982.94"), Decimal("-493.626")], + dtype=Decimal64Dtype(9, 4), + ), + gd.Series([847.98, 254.442], dtype="float32"), + gd.Series([5299.262, -2049.25], dtype="float64"), + gd.Series( + [ + Decimal("982.94"), + Decimal("-493.626"), + Decimal("847.98"), + Decimal("254.442"), + Decimal("5299.262"), + Decimal("-2049.25"), + ], + dtype=Decimal64Dtype(9, 4), + index=[0, 1, 0, 1, 0, 1], + ), + ), + ( + gd.Series( + [Decimal("492.204"), Decimal("-72824.455")], + dtype=Decimal64Dtype(9, 4), + ), + gd.Series([8438, -27462], dtype="int64"), + gd.Series([-40.292, 49202.953], dtype="float64"), + gd.Series( + [ + Decimal("492.204"), + Decimal("-72824.455"), + Decimal("8438"), + Decimal("-27462"), + Decimal("-40.292"), + Decimal("49202.953"), + ], + dtype=Decimal64Dtype(9, 4), + index=[0, 1, 0, 1, 0, 1], + ), + ), + ], +) +def test_concat_decimal_numeric_series(s1, s2, s3, expected): + s = gd.concat([s1, s2, s3]) + assert_eq(s, expected) + + +@pytest.mark.parametrize( + "s1, s2, expected", + [ + ( + gd.Series( + [Decimal("955.22"), Decimal("8.2")], dtype=Decimal64Dtype(5, 2) + ), + gd.Series(["2007-06-12", "2006-03-14"], dtype="datetime64"), + gd.Series( + [ + "955.22", + "8.20", + "2007-06-12 00:00:00", + "2006-03-14 00:00:00", + ], + index=[0, 1, 0, 1], + ), + ), + ( + gd.Series( + [Decimal("-52.44"), Decimal("365.22")], + dtype=Decimal64Dtype(5, 2), + ), + gd.Series( + np.arange( + "2005-02-01T12", "2005-02-01T15", dtype="datetime64[h]" + ), + dtype="datetime64[s]", + ), + gd.Series( + [ + "-52.44", + "365.22", + "2005-02-01 12:00:00", + "2005-02-01 13:00:00", + "2005-02-01 14:00:00", + ], + index=[0, 1, 0, 1, 2], + ), + ), + ( + gd.Series( + [Decimal("753.0"), Decimal("94.22")], + dtype=Decimal64Dtype(5, 2), + ), + gd.Series([np.timedelta64(111, "s"), np.timedelta64(509, "s")]), + gd.Series( + ["753.00", "94.22", "0 days 00:01:51", "0 days 00:08:29"], + index=[0, 1, 0, 1], + ), + ), + ( + gd.Series( + [Decimal("753.0"), Decimal("94.22")], + dtype=Decimal64Dtype(5, 2), + ), + gd.Series( + [np.timedelta64(940252, "s"), np.timedelta64(758385, "s")] + ), + gd.Series( + ["753.00", "94.22", "10 days 21:10:52", "8 days 18:39:45"], + index=[0, 1, 0, 1], + ), + ), + ], +) +def test_concat_decimal_non_numeric(s1, s2, expected): + s = gd.concat([s1, s2]) + assert_eq(s, expected) diff --git a/python/cudf/cudf/utils/dtypes.py b/python/cudf/cudf/utils/dtypes.py index 16c35bab4b1..0b59116f8e6 100644 --- a/python/cudf/cudf/utils/dtypes.py +++ b/python/cudf/cudf/utils/dtypes.py @@ -290,13 +290,15 @@ def is_decimal_dtype(obj): ) -def _decimal_normalize_types(*args): - s = max([a.dtype.scale for a in args]) - lhs = max([a.dtype.precision - a.dtype.scale for a in args]) +def _find_common_type_decimal(dtypes): + # Find the largest scale and the largest difference between + # precision and scale of the columns to be concatenated + s = max([dtype.scale for dtype in dtypes]) + lhs = max([dtype.precision - dtype.scale for dtype in dtypes]) + # Combine to get the necessary precision and clip at the maximum + # precision p = min(cudf.Decimal64Dtype.MAX_PRECISION, s + lhs) - dtype = cudf.Decimal64Dtype(p, s) - - return [a.astype(dtype) for a in args] + return cudf.Decimal64Dtype(p, s) def cudf_dtype_from_pydata_dtype(dtype): @@ -690,9 +692,15 @@ def find_common_type(dtypes): dtypes = set(dtypes) if any(is_decimal_dtype(dtype) for dtype in dtypes): - raise NotImplementedError( - "DecimalDtype is not yet supported in find_common_type" - ) + if all( + is_decimal_dtype(dtype) or is_numerical_dtype(dtype) + for dtype in dtypes + ): + return _find_common_type_decimal( + [dtype for dtype in dtypes if is_decimal_dtype(dtype)] + ) + else: + return np.dtype("O") # Corner case 1: # Resort to np.result_type to handle "M" and "m" types separately