fix(python): Ensure write_excel recognises the Array dtype and writes it out as a string (#20994)
alexander-beedie authored Feb 1, 2025
1 parent 084ddde commit 4c22f1e
Showing 3 changed files with 44 additions and 30 deletions.
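For context, a minimal sketch of the behaviour this commit targets (the frame contents and buffer below are illustrative, not taken from the changed tests): a `pl.Array` column should now be stringified on export, just as `List` and `Struct` columns already were.

```python
from io import BytesIO

import polars as pl

# A frame with a fixed-size Array column; before this fix the dtype check in
# write_excel only matched List/Struct/Object, so Array columns were not
# converted to strings before being handed to the workbook writer.
df = pl.DataFrame(
    data={"x": [[1, 2], [3, 4]], "y": ["a", "b"]},
    schema_overrides={"x": pl.Array(pl.Int32, 2)},
)

buf = BytesIO()
df.write_excel(buf)  # the Array column is written out via its string representation

print(pl.read_excel(buf))  # "x" comes back as a string column (e.g. "[1, 2]")
```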
10 changes: 4 additions & 6 deletions py-polars/polars/io/spreadsheet/_write_utils.py
@@ -12,9 +12,7 @@
Datetime,
Float64,
Int64,
List,
Object,
Struct,
Time,
)
from polars.datatypes.group import FLOAT_DTYPES, INTEGER_DTYPES
@@ -190,13 +188,13 @@ def _xl_column_range(
include_header: bool,
as_range: bool = True,
) -> tuple[int, int, int, int] | str:
"""Return the excel sheet range of a named column, accounting for all offsets."""
"""Return the Excel sheet range of a named column, accounting for all offsets."""
col_start = (
table_start[0] + int(include_header),
table_start[1] + (df.get_column_index(col) if isinstance(col, str) else col[0]),
)
col_finish = (
col_start[0] + len(df) - 1,
col_start[0] + df.height - 1,
col_start[1] + (0 if isinstance(col, str) else (col[1] - col[0])),
)
if as_range:
@@ -358,7 +356,7 @@ def _map_str(s: Series) -> Series:
cast_cols = [
F.col(col).map_batches(_map_str).alias(col)
for col, tp in df.schema.items()
if tp in (List, Struct, Object)
if (tp.is_nested() or tp == Object)
]
if cast_cols:
df = df.with_columns(cast_cols)
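The key change above: rather than testing membership in an explicit `(List, Struct, Object)` tuple, the cast now uses `tp.is_nested()`, which is true for `List`, `Array` and `Struct` alike. A small illustration of the predicate (the dtypes below are arbitrary examples, not from the commit):

```python
import polars as pl

# is_nested() covers List, Array and Struct, so the rewritten check also picks
# up Array columns that the old explicit tuple missed.
for tp in (pl.List(pl.Int64), pl.Array(pl.Int64, 2), pl.Struct({"a": pl.Int64}), pl.Int64()):
    print(tp, tp.is_nested() or tp == pl.Object)
# prints True for the three nested dtypes and False for Int64
```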
@@ -569,7 +567,7 @@ def _xl_setup_workbook(
workbook: Workbook | BytesIO | Path | str | None,
worksheet: str | Worksheet | None = None,
) -> tuple[Workbook, Worksheet, bool]:
"""Establish the target excel workbook and worksheet."""
"""Establish the target Excel workbook and worksheet."""
from xlsxwriter import Workbook
from xlsxwriter.worksheet import Worksheet

6 changes: 3 additions & 3 deletions py-polars/polars/io/spreadsheet/functions.py
@@ -948,17 +948,17 @@ def _drop_null_data(
col = df[col_name]
if (
col.dtype == Null
or col.null_count() == len(df)
or col.null_count() == df.height
or (
col.dtype in NUMERIC_DTYPES
and col.replace(0, None).null_count() == len(df)
and col.replace(0, None).null_count() == df.height
)
):
null_cols.append(col_name)
if null_cols:
df = df.drop(*null_cols)

if len(df) == 0 and len(df.columns) == 0:
if df.height == df.width == 0:
return _empty_frame(raise_if_empty)
if drop_empty_rows:
return df.filter(~F.all_horizontal(F.all().is_null()))
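The remaining hunks in this file are readability tweaks: `df.height` and `df.width` replace `len(df)` and `len(df.columns)`. A quick illustration (the frame below is arbitrary):

```python
import polars as pl

df = pl.DataFrame({"a": [1, None, None], "b": [None, None, None]})

# height/width are the idiomatic row and column counts used by the updated checks
assert df.height == len(df) == 3
assert df.width == len(df.columns) == 2

# the empty-frame guard now reads `df.height == df.width == 0`
assert pl.DataFrame().height == pl.DataFrame().width == 0
```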
58 changes: 37 additions & 21 deletions py-polars/tests/unit/io/test_spreadsheet.py
@@ -19,7 +19,12 @@
if TYPE_CHECKING:
from collections.abc import Sequence

from polars._typing import ExcelSpreadsheetEngine, SchemaDict, SelectorType
from polars._typing import (
ExcelSpreadsheetEngine,
PolarsDataType,
SchemaDict,
SelectorType,
)


# pytestmark = pytest.mark.slow()
@@ -543,22 +548,6 @@ def test_read_mixed_dtype_columns(
)


@pytest.mark.parametrize("engine", ["calamine", "openpyxl", "xlsx2csv"])
def test_write_excel_bytes(engine: ExcelSpreadsheetEngine) -> None:
df = pl.DataFrame({"colx": [1.5, -2, 0], "coly": ["a", None, "c"]})

excel_bytes = BytesIO()
df.write_excel(excel_bytes)

df_read = pl.read_excel(excel_bytes, engine=engine)
assert_frame_equal(df, df_read)

# also confirm consistent behaviour when 'infer_schema_length=0'
df_read = pl.read_excel(excel_bytes, engine=engine, infer_schema_length=0)
expected = pl.DataFrame({"colx": ["1.5", "-2", "0"], "coly": ["a", None, "c"]})
assert_frame_equal(expected, df_read)


def test_schema_overrides(path_xlsx: Path, path_xlsb: Path, path_ods: Path) -> None:
df1 = pl.read_excel(
path_xlsx,
@@ -900,10 +889,21 @@ def test_excel_write_column_and_row_totals(engine: ExcelSpreadsheetEngine) -> None:
assert xldf.row(-1) == (None, 0.0, 0.0, 0, 0, None, 0.0, 0)


@pytest.mark.parametrize("engine", ["calamine", "openpyxl", "xlsx2csv"])
def test_excel_write_compound_types(engine: ExcelSpreadsheetEngine) -> None:
@pytest.mark.parametrize(
("engine", "list_dtype"),
[
("calamine", pl.List(pl.Int8)),
("openpyxl", pl.List(pl.UInt16)),
("xlsx2csv", pl.Array(pl.Int32, 2)),
],
)
def test_excel_write_compound_types(
engine: ExcelSpreadsheetEngine,
list_dtype: PolarsDataType,
) -> None:
df = pl.DataFrame(
{"x": [[1, 2], [3, 4], [5, 6]], "y": ["a", "b", "c"], "z": [9, 8, 7]}
data={"x": [[1, 2], [3, 4], [5, 6]], "y": ["a", "b", "c"], "z": [9, 8, 7]},
schema_overrides={"x": pl.Array(pl.Int32, 2)},
).select("x", pl.struct(["y", "z"]))

xls = BytesIO()
@@ -966,6 +966,22 @@ def test_excel_read_named_table_with_total_row(tmp_path: Path) -> None:
assert xldf.row(3) == (None, 0, 0)


@pytest.mark.parametrize("engine", ["calamine", "openpyxl", "xlsx2csv"])
def test_excel_write_to_bytesio(engine: ExcelSpreadsheetEngine) -> None:
df = pl.DataFrame({"colx": [1.5, -2, 0], "coly": ["a", None, "c"]})

excel_bytes = BytesIO()
df.write_excel(excel_bytes)

df_read = pl.read_excel(excel_bytes, engine=engine)
assert_frame_equal(df, df_read)

# also confirm consistent behaviour when 'infer_schema_length=0'
df_read = pl.read_excel(excel_bytes, engine=engine, infer_schema_length=0)
expected = pl.DataFrame({"colx": ["1.5", "-2", "0"], "coly": ["a", None, "c"]})
assert_frame_equal(expected, df_read)


@pytest.mark.parametrize("engine", ["xlsx2csv", "openpyxl", "calamine"])
def test_excel_write_to_file_object(
engine: ExcelSpreadsheetEngine, tmp_path: Path
@@ -1352,7 +1368,7 @@ def test_drop_empty_rows(
assert df3.shape == (10, 4)


def test_write_excel_select_col_dtype() -> None:
def test_excel_write_select_col_dtype() -> None:
from openpyxl import load_workbook
from xlsxwriter import Workbook

