Skip to content

Commit

Permalink
Fix concat_columns for DataFrames with list features (#165)
Browse files Browse the repository at this point in the history
* Update instance checks in `concat_columns` to use cudf/pd classes.

Add test for `concat_columns` with list columns

* Support CPU-only environments by checking cudf and pandas separately

* Update import formatting
  • Loading branch information
oliverholworthy authored Nov 10, 2022
1 parent c1ddc19 commit 930ea89
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 15 deletions.
16 changes: 7 additions & 9 deletions merlin/core/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,9 +350,13 @@ def concat_columns(args: list):
"""Dispatch function to concatenate DataFrames with axis=1"""
if len(args) == 1:
return args[0]
elif isinstance(args[0], DataFrameLike):
_lib = cudf if HAS_GPU and isinstance(args[0], cudf.DataFrame) else pd
return _lib.concat(
elif cudf is not None and isinstance(args[0], cudf.DataFrame):
return cudf.concat(
[a.reset_index(drop=True) for a in args],
axis=1,
)
elif isinstance(args[0], pd.DataFrame):
return pd.concat(
[a.reset_index(drop=True) for a in args],
axis=1,
)
Expand All @@ -361,12 +365,6 @@ def concat_columns(args: list):
for arg in args:
result.update(arg)
return result
else:
_lib = cudf if HAS_GPU and isinstance(args[0], cudf.DataFrame) else pd
return _lib.concat(
[a.reset_index(drop=True) for a in args],
axis=1,
)
return None


Expand Down
21 changes: 15 additions & 6 deletions tests/unit/core/test_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@
import numpy as np
import pytest

from merlin.core.dispatch import HAS_GPU, is_list_dtype, list_val_dtype, make_df
from merlin.core.dispatch import HAS_GPU, concat_columns, is_list_dtype, list_val_dtype, make_df

if HAS_GPU:
_CPU = [True, False]
_DEVICES = ["cpu", "gpu"]
else:
_CPU = [True]
_DEVICES = ["cpu"]


@pytest.mark.parametrize("cpu", _CPU)
def test_list_dtypes(tmpdir, cpu):
df = make_df(device="cpu" if cpu else "gpu")
@pytest.mark.parametrize("device", _DEVICES)
def test_list_dtypes(tmpdir, device):
df = make_df(device=device)
df["vals"] = [
[[0, 1, 2], [3, 4], [5]],
]
Expand All @@ -35,3 +35,12 @@ def test_list_dtypes(tmpdir, cpu):

assert is_list_dtype(df["vals"])
assert list_val_dtype(df["vals"]) == np.dtype(np.int64)


# Check that concat_columns accepts DataFrames that include a list-valued column.
# Runs once for every device name in _DEVICES.
@pytest.mark.parametrize("device", _DEVICES)
def test_concat_columns(device):
# df1: scalar column "a" plus list column "b" (two rows).
df1 = make_df({"a": [1, 2], "b": [[3], [4, 5]]}, device=device)
# df2: single scalar column "c" (three rows, one more than df1).
df2 = make_df({"c": [3, 4, 5]}, device=device)
data_frames = [df1, df2]
res = concat_columns(data_frames)
# Only the resulting column names and their order are asserted; row contents are not checked.
assert res.columns.to_list() == ["a", "b", "c"]

0 comments on commit 930ea89

Please sign in to comment.