Skip to content

Commit

Permalink
Add object, str implementations for write_report()
Browse files Browse the repository at this point in the history
Expand tests.
  • Loading branch information
khaeru committed Nov 28, 2023
1 parent 642afe7 commit f28a095
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 7 deletions.
46 changes: 40 additions & 6 deletions genno/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1015,24 +1015,56 @@ def _format_header_comment(value: str) -> str:

@singledispatch
def write_report(
quantity: pd.DataFrame, path: Union[str, PathLike], kwargs: Optional[dict] = None
quantity: object, path: Union[str, PathLike], kwargs: Optional[dict] = None
) -> None:
"""Write a quantity to a file.
:py:`write_report()` is a :func:`~functools.singledispatch` function. This means
that user code can extend this operator to support different types for the
`quantity` argument:
.. code-block:: python
import genno.operator
@genno.operator.write_report.register
def my_writer(qty: MyClass, path, kwargs):
... # Code to write MyClass to file
Parameters
----------
quantity :
Object to be written. The base implementation supports :class:`.Quantity` and
:class:`pandas.DataFrame`.
path : str or pathlib.Path
Path to the file to be written.
kwargs :
Keyword arguments. For the default implementation, these are passed to
Keyword arguments. For the base implementation, these are passed to
:meth:~pandas.DataFrame.to_csv` or :meth:~pandas.DataFrame.to_excel` (according
to `path`), except for:
- "header_comment": valid only for `path` ending in :file:`.csv`. Multi-line
text that is prepended to the file, with comment characters ("# ") before
each line.
Raises
------
NotImplementedError
If `quantity` is of a type not supported by the base implementation or any
overloads.
"""
raise NotImplementedError(f"Write {type(quantity)} to file")


@write_report.register
def _(quantity: str, path: Union[str, PathLike], kwargs: Optional[dict] = None):
Path(path).write_text(quantity)


@write_report.register
def _(
quantity: pd.DataFrame, path: Union[str, PathLike], kwargs: Optional[dict] = None
) -> None:
path = Path(path)

if path.suffix == ".csv":
Expand All @@ -1044,15 +1076,17 @@ def write_report(
quantity.to_csv(f, **kwargs)
elif path.suffix == ".xlsx":
kwargs = kwargs or dict()
kwargs.setdefault("index", False)
kwargs.setdefault("merge_cells", False)
kwargs.setdefault("index", False)

quantity.to_excel(path, **kwargs)
else:
path.write_text(quantity) # type: ignore
raise NotImplementedError(f"Write pandas.DataFrame to {path.suffix!r}")


@write_report.register
def _(quantity: Quantity, path, kwargs=None) -> None:
def _(
quantity: Quantity, path: Union[str, PathLike], kwargs: Optional[dict] = None
) -> None:
# Convert the Quantity to a pandas.DataFrame, then write
write_report(quantity.to_dataframe(), path, kwargs)
write_report(quantity.to_dataframe().reset_index(), path, kwargs)
2 changes: 1 addition & 1 deletion genno/tests/core/test_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,7 +798,7 @@ def test_file_formats(test_data_path, tmp_path):

# Write to CSV
p3 = tmp_path / "output.csv"
c.write(k, p3, index=True)
c.write(k, p3)

# Output is identical to input file, except for order
assert sorted(p1.read_text().split("\n")) == sorted(p3.read_text().split("\n"))
Expand Down
26 changes: 26 additions & 0 deletions genno/tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,3 +887,29 @@ def test_sum(data, dimensions):
result = operator.sum(x, dimensions=dimensions)

assert result.name == x.name and result.units == x.units # Pass through


def test_write_report0(tmp_path, data) -> None:
p = tmp_path.joinpath("foo.txt")
*_, x = data

# Unsupported type
with pytest.raises(NotImplementedError, match="Write <class 'list'> to file"):
operator.write_report(list(), p)

# Unsupported path suffix
with pytest.raises(NotImplementedError, match="Write pandas.DataFrame to '.bar'"):
operator.write_report(x, tmp_path.joinpath("foo.bar"))

# Plain text
operator.write_report("Hello, world!", p)
assert "Hello, world!" == p.read_text()


def test_write_report1(tmp_path, data) -> None:
p = tmp_path.joinpath("foo.csv")
*_, x = data

# Header comment is written
operator.write_report(x, p, dict(header_comment="Hello, world!\n"))
assert p.read_text().startswith("# Hello, world!\n#")

0 comments on commit f28a095

Please sign in to comment.