Skip to content

Commit

Permalink
BUG (string dtype): fix handling of string dtype in interchange protocol (pandas-dev#60333)
Browse files Browse the repository at this point in the history

Co-authored-by: Joris Van den Bossche <[email protected]>
  • Loading branch information
WillAyd and jorisvandenbossche authored Nov 17, 2024
1 parent 34c080c commit 720a6e7
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 11 deletions.
12 changes: 8 additions & 4 deletions pandas/core/interchange/from_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import numpy as np

from pandas._config import using_string_dtype

from pandas.compat._optional import import_optional_dependency

import pandas as pd
Expand Down Expand Up @@ -147,8 +149,6 @@ def protocol_df_chunk_to_pandas(df: DataFrameXchg) -> pd.DataFrame:
-------
pd.DataFrame
"""
# We need a dict of columns here, with each column being a NumPy array (at
# least for now, deal with non-NumPy dtypes later).
columns: dict[str, Any] = {}
buffers = [] # hold on to buffers, keeps memory alive
for name in df.column_names():
Expand Down Expand Up @@ -347,8 +347,12 @@ def string_column_to_ndarray(col: Column) -> tuple[np.ndarray, Any]:
# Add to our list of strings
str_list[i] = string

# Convert the string list to a NumPy array
return np.asarray(str_list, dtype="object"), buffers
if using_string_dtype():
res = pd.Series(str_list, dtype="str")
else:
res = np.asarray(str_list, dtype="object") # type: ignore[assignment]

return res, buffers # type: ignore[return-value]


def parse_datetime_format_str(format_str, data) -> pd.Series | np.ndarray:
Expand Down
9 changes: 2 additions & 7 deletions pandas/tests/interchange/test_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import numpy as np
import pytest

from pandas._config import using_string_dtype

from pandas._libs.tslibs import iNaT
from pandas.compat import (
is_ci_environment,
Expand Down Expand Up @@ -401,7 +399,6 @@ def test_interchange_from_corrected_buffer_dtypes(monkeypatch) -> None:
pd.api.interchange.from_dataframe(df)


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_empty_string_column():
# https://github.com/pandas-dev/pandas/issues/56703
df = pd.DataFrame({"a": []}, dtype=str)
Expand All @@ -410,13 +407,12 @@ def test_empty_string_column():
tm.assert_frame_equal(df, result)


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_large_string():
# GH#56702
pytest.importorskip("pyarrow")
df = pd.DataFrame({"a": ["x"]}, dtype="large_string[pyarrow]")
result = pd.api.interchange.from_dataframe(df.__dataframe__())
expected = pd.DataFrame({"a": ["x"]}, dtype="object")
expected = pd.DataFrame({"a": ["x"]}, dtype="str")
tm.assert_frame_equal(result, expected)


Expand All @@ -427,7 +423,6 @@ def test_non_str_names():
assert names == ["0"]


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_non_str_names_w_duplicates():
# https://github.com/pandas-dev/pandas/issues/56701
df = pd.DataFrame({"0": [1, 2, 3], 0: [4, 5, 6]})
Expand All @@ -438,7 +433,7 @@ def test_non_str_names_w_duplicates():
"Expected a Series, got a DataFrame. This likely happened because you "
"called __dataframe__ on a DataFrame which, after converting column "
r"names to string, resulted in duplicated names: Index\(\['0', '0'\], "
r"dtype='object'\). Please rename these columns before using the "
r"dtype='(str|object)'\). Please rename these columns before using the "
"interchange protocol."
),
):
Expand Down

0 comments on commit 720a6e7

Please sign in to comment.