diff --git a/merlin/core/dispatch.py b/merlin/core/dispatch.py index 177af12af..4d9c81a47 100644 --- a/merlin/core/dispatch.py +++ b/merlin/core/dispatch.py @@ -25,8 +25,7 @@ import pyarrow.parquet as pq from merlin.core.compat import HAS_GPU -from merlin.core.protocols import DataFrameLike, SeriesLike -from merlin.dag import DictArray +from merlin.core.protocols import DataFrameLike, DictLike, SeriesLike cp = None cudf = None @@ -351,8 +350,14 @@ def concat_columns(args: list): """Dispatch function to concatenate DataFrames with axis=1""" if len(args) == 1: return args[0] - elif isinstance(args[0], DictArray): - result = DictArray({}) + elif isinstance(args[0], DataFrameLike): + _lib = cudf if HAS_GPU and isinstance(args[0], cudf.DataFrame) else pd + return _lib.concat( + [a.reset_index(drop=True) for a in args], + axis=1, + ) + elif isinstance(args[0], DictLike): + result = type(args[0])() for arg in args: result.update(arg) return result diff --git a/merlin/dag/dictarray.py b/merlin/dag/dictarray.py index 8435069f0..7c595fd35 100644 --- a/merlin/dag/dictarray.py +++ b/merlin/dag/dictarray.py @@ -43,9 +43,11 @@ class DictArray(Transformable): A simple dataframe-like wrapper around a dictionary of values """ - def __init__(self, values: Dict, dtypes: Optional[Dict] = None): + def __init__(self, values: Optional[Dict] = None, dtypes: Optional[Dict] = None): super().__init__() + values = values or {} + array_values = {} for key, value in values.items(): array_values[key] = np.array(value) if isinstance(value, list) else value