Skip to content

Commit

Permalink
Fix DataFrame constructor to broadcast scalar inputs properly (#12997)
Browse files Browse the repository at this point in the history
Fixes: #12646

This PR fixes an issue with `DataFrame` where broadcasting scalar inputs was order dependent.

Authors:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

Approvers:
  - Lawrence Mitchell (https://github.com/wence-)

URL: #12997
  • Loading branch information
galipremsagar authored Mar 23, 2023
1 parent dd5252b commit 33e2387
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 7 deletions.
24 changes: 17 additions & 7 deletions python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,14 +907,24 @@ def _init_from_dict_like(
if index is None:
num_rows = 0
if data:
col_name = next(iter(data))
if is_scalar(data[col_name]):
num_rows = num_rows or 1
else:
data[col_name] = column.as_column(
data[col_name], nan_as_null=nan_as_null
keys, values, lengths = zip(
*(
(k, v, 1)
if is_scalar(v)
else (
k,
vc := as_column(v, nan_as_null=nan_as_null),
len(vc),
)
for k, v in data.items()
)
num_rows = len(data[col_name])
)
data = dict(zip(keys, values))
try:
(num_rows,) = (set(lengths) - {1}) or {1}
except ValueError:
raise ValueError("All arrays must be the same length")

self._index = RangeIndex(0, num_rows)
else:
self._index = as_index(index)
Expand Down
30 changes: 30 additions & 0 deletions python/cudf/cudf/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -10057,3 +10057,33 @@ def test_dataframe_from_arrow_slice():
actual = cudf.DataFrame.from_arrow(table_slice)

assert_eq(expected, actual)


@pytest.mark.parametrize(
"data",
[
{"a": [1, 2, 3], "b": ["x", "y", "z"], "c": 4},
{"c": 4, "a": [1, 2, 3], "b": ["x", "y", "z"]},
{"a": [1, 2, 3], "c": 4},
],
)
def test_dataframe_init_from_scalar_and_lists(data):
actual = cudf.DataFrame(data)
expected = pd.DataFrame(data)

assert_eq(expected, actual)


def test_dataframe_init_length_error():
assert_exceptions_equal(
lfunc=pd.DataFrame,
rfunc=cudf.DataFrame,
lfunc_args_and_kwargs=(
[],
{"data": {"a": [1, 2, 3], "b": ["x", "y", "z", "z"], "c": 4}},
),
rfunc_args_and_kwargs=(
[],
{"data": {"a": [1, 2, 3], "b": ["x", "y", "z", "z"], "c": 4}},
),
)

0 comments on commit 33e2387

Please sign in to comment.