
Session.virtualfile_to_dataset: Add 'header' parameter to parse column names from table header #3117

Merged · 9 commits · Apr 18, 2024
7 changes: 6 additions & 1 deletion pygmt/clib/session.py
@@ -1810,6 +1810,7 @@ def virtualfile_to_dataset(
self,
vfname: str,
output_type: Literal["pandas", "numpy", "file", "strings"] = "pandas",
+ header: int | None = None,
column_names: list[str] | None = None,
dtype: type | dict[str, type] | None = None,
index_col: str | int | None = None,
@@ -1831,6 +1832,10 @@
- ``"numpy"`` will return a :class:`numpy.ndarray` object.
- ``"file"`` means the result was saved to a file and will return ``None``.
- ``"strings"`` will return the trailing text only as an array of strings.
+ header
+     Row number containing column names for the :class:`pandas.DataFrame` output.
+     ``header=None`` means not to parse the column names from the table header.
+     Ignored if the row number is equal to or larger than the number of header
+     lines in the table.
column_names
The column names for the :class:`pandas.DataFrame` output.
dtype
@@ -1945,7 +1950,7 @@ def virtualfile_to_dataset(
return result.to_strings()

result = result.to_dataframe(
- column_names=column_names, dtype=dtype, index_col=index_col
+ header=header, column_names=column_names, dtype=dtype, index_col=index_col
)
if output_type == "numpy": # numpy.ndarray output
return result.to_numpy()
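For context, here is a minimal usage sketch assembled from the calls that appear in this diff; the input file table.txt and its header line are illustrative assumptions, not part of the PR:

from pygmt.clib import Session

# Assume "table.txt" starts with a header line like "# x y z name".
with Session() as lib:
    with lib.virtualfile_out(kind="dataset") as vouttbl:
        # Read the table into a GMT virtual file as a dataset.
        lib.call_module("read", f"table.txt {vouttbl} -Td")
        # header=0 parses column names from the first header line;
        # header=None (the default) keeps the numeric column names.
        df = lib.virtualfile_to_dataset(vfname=vouttbl, header=0)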
24 changes: 19 additions & 5 deletions pygmt/datatypes/dataset.py
@@ -27,6 +27,7 @@ class _GMT_DATASET(ctp.Structure): # noqa: N801
>>> with GMTTempFile(suffix=".txt") as tmpfile:
... # Prepare the sample data file
... with Path(tmpfile.name).open(mode="w") as fp:
... print("# x y z name", file=fp)
... print(">", file=fp)
... print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp)
... print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp)
@@ -43,7 +44,8 @@ class _GMT_DATASET(ctp.Structure): # noqa: N801
... print(ds.min[: ds.n_columns], ds.max[: ds.n_columns])
... # The table
... tbl = ds.table[0].contents
- ... print(tbl.n_columns, tbl.n_segments, tbl.n_records)
+ ... print(tbl.n_columns, tbl.n_segments, tbl.n_records, tbl.n_headers)
+ ... print(tbl.header[: tbl.n_headers])
... print(tbl.min[: tbl.n_columns], ds.max[: tbl.n_columns])
... for i in range(tbl.n_segments):
... seg = tbl.segment[i].contents
@@ -52,7 +54,8 @@ class _GMT_DATASET(ctp.Structure): # noqa: N801
... print(seg.text[: seg.n_rows])
1 3 2
[1.0, 2.0, 3.0] [10.0, 11.0, 12.0]
- 3 2 4
+ 3 2 4 1
+ [b'x y z name']
[1.0, 2.0, 3.0] [10.0, 11.0, 12.0]
[1.0, 4.0]
[2.0, 5.0]
@@ -169,6 +172,7 @@ def to_strings(self) -> np.ndarray[Any, np.dtype[np.str_]]:

def to_dataframe(
self,
+ header: int | None = None,
column_names: pd.Index | None = None,
dtype: type | Mapping[Any, type] | None = None,
index_col: str | int | None = None,
@@ -187,6 +191,10 @@
----------
column_names
A list of column names.
+ header
+     Row number containing column names. ``header=None`` means not to parse the
+     column names from the table header. Ignored if the row number is equal to
+     or larger than the number of header lines in the table.
dtype
Data type. Can be a single type for all columns or a dictionary mapping
column names to types.
@@ -207,6 +215,7 @@
>>> with GMTTempFile(suffix=".txt") as tmpfile:
... # prepare the sample data file
... with Path(tmpfile.name).open(mode="w") as fp:
... print("# col1 col2 col3 colstr", file=fp)
... print(">", file=fp)
... print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp)
... print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp)
@@ -218,12 +227,12 @@
... lib.call_module("read", f"{tmpfile.name} {vouttbl} -Td")
... ds = lib.read_virtualfile(vouttbl, kind="dataset")
... text = ds.contents.to_strings()
- ... df = ds.contents.to_dataframe()
+ ... df = ds.contents.to_dataframe(header=0)
>>> text
array(['TEXT1 TEXT23', 'TEXT4 TEXT567', 'TEXT8 TEXT90',
'TEXT123 TEXT456789'], dtype='<U18')
>>> df
- 0 1 2 3
+ col1 col2 col3 colstr
0 1.0 2.0 3.0 TEXT1 TEXT23
1 4.0 5.0 6.0 TEXT4 TEXT567
2 7.0 8.0 9.0 TEXT8 TEXT90
Expand All @@ -248,14 +257,19 @@ def to_dataframe(
        if len(textvector) != 0:
            vectors.append(pd.Series(data=textvector, dtype=pd.StringDtype()))

+         if header is not None:
+             tbl = self.table[0].contents  # Use the first table!
+             if header < tbl.n_headers:
+                 column_names = tbl.header[header].decode().split()
+
        if len(vectors) == 0:
            # Return an empty DataFrame if no columns are found.
            df = pd.DataFrame(columns=column_names)
        else:
            # Create a DataFrame object by concatenating multiple columns.
            df = pd.concat(objs=vectors, axis="columns")
            if column_names is not None:  # Assign column names
-                 df.columns = column_names
+                 df.columns = column_names[: df.shape[1]]
        if dtype is not None:  # Set dtype for the whole dataset or individual columns
            df = df.astype(dtype)
        if index_col is not None:  # Use a specific column as index
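The parsing step above is a plain decode-and-split of the stored header line; as the doctest shows, GMT stores table headers without the leading "#". A standalone sketch of that step, using a hypothetical raw header value:

# Hypothetical bytes as GMT would store them for the header line
# "# col1 col2 col3 colstr" (GMT strips the leading "# ").
raw_header = b"col1 col2 col3 colstr"

# Mirror of tbl.header[header].decode().split() above.
column_names = raw_header.decode().split()
print(column_names)  # ['col1', 'col2', 'col3', 'colstr']

# If the header names more columns than the DataFrame has, only the
# first df.shape[1] names are used, via column_names[: df.shape[1]].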
61 changes: 59 additions & 2 deletions pygmt/tests/test_datatypes_dataset.py
@@ -40,14 +40,14 @@ def dataframe_from_pandas(filepath_or_buffer, sep=r"\s+", comment="#", header=None
    return df


- def dataframe_from_gmt(fname):
+ def dataframe_from_gmt(fname, **kwargs):
    """
    Read tabular data as pandas.DataFrame using GMT virtual file.
    """
    with Session() as lib:
        with lib.virtualfile_out(kind="dataset") as vouttbl:
            lib.call_module("read", f"{fname} {vouttbl} -Td")
-             df = lib.virtualfile_to_dataset(vfname=vouttbl)
+             df = lib.virtualfile_to_dataset(vfname=vouttbl, **kwargs)
    return df

@@ -84,6 +84,63 @@ def test_dataset_empty():
    pd.testing.assert_frame_equal(df, expected_df)


+ def test_dataset_header():
+     """
+     Test parsing column names from dataset header.
+     """
+     with GMTTempFile(suffix=".txt") as tmpfile:
+         with Path(tmpfile.name).open(mode="w") as fp:
+             print("# lon lat z text", file=fp)
+             print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp)
+             print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp)
+
+         # Parse column names from the first header line.
+         df = dataframe_from_gmt(tmpfile.name, header=0)
+         assert df.columns.tolist() == ["lon", "lat", "z", "text"]
+         # pd.read_csv() can't parse the header line with a leading '#',
+         # so we skip the header line and set the column names manually.
+         expected_df = dataframe_from_pandas(tmpfile.name, header=None)
+         expected_df.columns = df.columns.tolist()
+         pd.testing.assert_frame_equal(df, expected_df)
+
+
+ def test_dataset_header_greater_than_nheaders():
+     """
+     Test passing a header line number that is greater than the number of header
+     lines.
+     """
+     with GMTTempFile(suffix=".txt") as tmpfile:
+         with Path(tmpfile.name).open(mode="w") as fp:
+             print("# lon lat z text", file=fp)
+             print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp)
+             print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp)
+
+         # Try to parse column names from the second header line.
+         df = dataframe_from_gmt(tmpfile.name, header=1)
+         # There is only one header line, so the default column names are kept.
+         assert df.columns.tolist() == [0, 1, 2, 3]
+         expected_df = dataframe_from_pandas(tmpfile.name, header=None)
+         pd.testing.assert_frame_equal(df, expected_df)
+
+
+ def test_dataset_header_too_many_names():
+     """
+     Test passing a header line with more column names than the number of columns.
+     """
+     with GMTTempFile(suffix=".txt") as tmpfile:
+         with Path(tmpfile.name).open(mode="w") as fp:
+             print("# lon lat z text1 text2", file=fp)
+             print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp)
+             print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp)
+
+         df = dataframe_from_gmt(tmpfile.name, header=0)
+         assert df.columns.tolist() == ["lon", "lat", "z", "text1"]
+         # pd.read_csv() can't parse the header line with a leading '#',
+         # so we skip the header line and set the column names manually.
+         expected_df = dataframe_from_pandas(tmpfile.name, header=None)
+         expected_df.columns = df.columns.tolist()
+         pd.testing.assert_frame_equal(df, expected_df)


def test_dataset_to_strings_with_none_values():
    """
    Test that None values in the trailing text don't raise an exception.