Skip to content

Commit

Permalink
DataFrame with namedtuples uses ._fields as column names (#13824)
Browse files Browse the repository at this point in the history
Allow a namedtuple's `_fields` attribute to be mapped to DataFrame column labels, matching pandas behavior

closes #13823

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - GALI PREM SAGAR (https://github.com/galipremsagar)
  - Lawrence Mitchell (https://github.com/wence-)

URL: #13824
  • Loading branch information
mroeschke authored Aug 9, 2023
1 parent edb25a8 commit da6ac73
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
9 changes: 9 additions & 0 deletions python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,15 @@ def _init_from_list_like(self, data, index=None, columns=None):
for col in data
):
raise TypeError("Inputs should be an iterable or sequence.")
if (
len(data) > 0
and columns is None
and isinstance(data[0], tuple)
and hasattr(data[0], "_fields")
):
# pandas behavior is to use the fields from the first
# namedtuple as the column names
columns = data[0]._fields

data = list(itertools.zip_longest(*data))

Expand Down
20 changes: 19 additions & 1 deletion python/cudf/cudf/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import string
import textwrap
import warnings
from collections import OrderedDict, defaultdict
from collections import OrderedDict, defaultdict, namedtuple
from copy import copy

import cupy
Expand Down Expand Up @@ -10261,3 +10261,21 @@ def __getitem__(self, key):

with pytest.raises(TypeError):
cudf.DataFrame({"a": A()})


def test_dataframe_constructor_from_namedtuple():
    """Check that DataFrame construction from namedtuple rows matches pandas.

    Two scenarios are covered:

    * The first row is the three-field namedtuple: the cudf result must
      compare equal to the pandas result built from the same data.
    * The first row is the two-field namedtuple followed by a three-value
      row: both cudf and pandas must raise ``ValueError``.
    """
    Point1 = namedtuple("Point1", ["a", "b", "c"])
    # Give each namedtuple class its own typename so reprs in failure
    # output identify which class a row came from.
    Point2 = namedtuple("Point2", ["x", "y"])

    # Column labels are taken from the first row's ``_fields``; the result
    # is validated against pandas rather than against hard-coded values.
    data = [Point1(1, 2, 3), Point2(4, 5)]
    idx = ["a", "b"]
    gdf = cudf.DataFrame(data, index=idx)
    pdf = pd.DataFrame(data, index=idx)

    assert_eq(gdf, pdf)

    # Reversed order: the first row now provides only two field names while
    # the second row carries three values. Both libraries must reject this.
    data = [Point2(4, 5), Point1(1, 2, 3)]
    with pytest.raises(ValueError):
        cudf.DataFrame(data, index=idx)
    with pytest.raises(ValueError):
        pd.DataFrame(data, index=idx)

0 comments on commit da6ac73

Please sign in to comment.