Skip to content

Commit

Permalink
Wrap the GMT API function GMT_Read_Data to read data into GMT data co…
Browse files Browse the repository at this point in the history
…ntainers (#3324)

Co-authored-by: Wei Ji <[email protected]>
  • Loading branch information
seisman and weiji14 authored Jul 19, 2024
1 parent 6c436a3 commit 917b3aa
Show file tree
Hide file tree
Showing 4 changed files with 239 additions and 2 deletions.
1 change: 1 addition & 0 deletions doc/api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ Low level access (these are mostly used by the :mod:`pygmt.clib` package):
clib.Session.put_matrix
clib.Session.put_strings
clib.Session.put_vector
clib.Session.read_data
clib.Session.write_data
clib.Session.open_virtualfile
clib.Session.read_virtualfile
Expand Down
4 changes: 3 additions & 1 deletion pygmt/clib/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,9 @@ def as_c_contiguous(array):
return array


def sequence_to_ctypes_array(sequence: Sequence, ctype, size: int) -> ctp.Array | None:
def sequence_to_ctypes_array(
sequence: Sequence | None, ctype, size: int
) -> ctp.Array | None:
"""
Convert a sequence of numbers into a ctypes array variable.
Expand Down
95 changes: 94 additions & 1 deletion pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pathlib
import sys
import warnings
from collections.abc import Generator
from collections.abc import Generator, Sequence
from typing import Literal

import numpy as np
Expand Down Expand Up @@ -1067,6 +1067,99 @@ def put_matrix(self, dataset, matrix, pad=0):
if status != 0:
raise GMTCLibError(f"Failed to put matrix of type {matrix.dtype}.")

def read_data(
self,
infile: str,
kind: Literal["dataset", "grid"],
family: str | None = None,
geometry: str | None = None,
mode: str = "GMT_READ_NORMAL",
region: Sequence[float] | None = None,
data=None,
):
"""
Read a data file into a GMT data container.
Wraps ``GMT_Read_Data`` but only allows reading from a file. The function
definition is different from the original C API function.
Parameters
----------
infile
The input file name.
kind
The data kind of the input file. Valid values are ``"dataset"`` and
``"grid"``.
family
A valid GMT data family name (e.g., ``"GMT_IS_DATASET"``). See the
``FAMILIES`` attribute for valid names. If ``None``, will determine the data
family from the ``kind`` parameter.
geometry
A valid GMT data geometry name (e.g., ``"GMT_IS_POINT"``). See the
``GEOMETRIES`` attribute for valid names. If ``None``, will determine the
data geometry from the ``kind`` parameter.
mode
How the data is to be read from the file. This option varies depending on
the given family. See the
:gmt-docs:`GMT API documentation <devdocs/api.html#import-from-a-file-stream-or-handle>`
for details. Default is ``GMT_READ_NORMAL`` which corresponds to the default
read mode value of 0 in the ``GMT_enum_read`` enum.
region
Subregion of the data, in the form of [xmin, xmax, ymin, ymax, zmin, zmax].
If ``None``, the whole data is read.
data
``None`` or the pointer returned by this function after a first call. It's
useful when reading grids/images/cubes in two steps (get a grid/image/cube
structure with a header, then read the data).
Returns
-------
Pointer to the data container, or ``None`` if there were errors.
Raises
------
GMTCLibError
If the GMT API function fails to read the data.
""" # noqa: W505
c_read_data = self.get_libgmt_func(
"GMT_Read_Data",
argtypes=[
ctp.c_void_p, # V_API
ctp.c_uint, # family
ctp.c_uint, # method
ctp.c_uint, # geometry
ctp.c_uint, # mode
ctp.POINTER(ctp.c_double), # wesn
ctp.c_char_p, # infile
ctp.c_void_p, # data
],
restype=ctp.c_void_p, # data_ptr
)

# Determine the family, geometry and data container from kind
_family, _geometry, dtype = {
"dataset": ("GMT_IS_DATASET", "GMT_IS_PLP", _GMT_DATASET),
"grid": ("GMT_IS_GRID", "GMT_IS_SURFACE", _GMT_GRID),
}[kind]
if family is None:
family = _family
if geometry is None:
geometry = _geometry

data_ptr = c_read_data(
self.session_pointer,
self[family],
self["GMT_IS_FILE"], # Reading from a file
self[geometry],
self[mode],
sequence_to_ctypes_array(region, ctp.c_double, 6),
infile.encode(),
data,
)
if data_ptr is None:
raise GMTCLibError(f"Failed to read dataset from '{infile}'.")
return ctp.cast(data_ptr, ctp.POINTER(dtype))

def write_data(self, family, geometry, mode, wesn, output, data):
"""
Write a GMT data container to a file.
Expand Down
141 changes: 141 additions & 0 deletions pygmt/tests/test_clib_read_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
"""
Test the Session.read_data method.
"""

from pathlib import Path

import pandas as pd
import pytest
import xarray as xr
from pygmt.clib import Session
from pygmt.exceptions import GMTCLibError
from pygmt.helpers import GMTTempFile
from pygmt.io import load_dataarray
from pygmt.src import which

try:
import rioxarray # noqa: F401

_HAS_RIOXARRAY = True
except ImportError:
_HAS_RIOXARRAY = False


@pytest.fixture(scope="module", name="expected_xrgrid")
def fixture_expected_xrgrid():
"""
The expected xr.DataArray object for the static_earth_relief.nc file.
"""
return load_dataarray(which("@static_earth_relief.nc"))


def test_clib_read_data_dataset():
"""
Test the Session.read_data method for datasets.
"""
with GMTTempFile(suffix=".txt") as tmpfile:
# Prepare the sample data file
with Path(tmpfile.name).open(mode="w", encoding="utf-8") as fp:
print("# x y z name", file=fp)
print(">", file=fp)
print("1.0 2.0 3.0 TEXT1 TEXT23", file=fp)
print("4.0 5.0 6.0 TEXT4 TEXT567", file=fp)
print(">", file=fp)
print("7.0 8.0 9.0 TEXT8 TEXT90", file=fp)
print("10.0 11.0 12.0 TEXT123 TEXT456789", file=fp)

with Session() as lib:
ds = lib.read_data(tmpfile.name, kind="dataset").contents
df = ds.to_dataframe(header=0)
expected_df = pd.DataFrame(
data={
"x": [1.0, 4.0, 7.0, 10.0],
"y": [2.0, 5.0, 8.0, 11.0],
"z": [3.0, 6.0, 9.0, 12.0],
"name": pd.Series(
[
"TEXT1 TEXT23",
"TEXT4 TEXT567",
"TEXT8 TEXT90",
"TEXT123 TEXT456789",
],
dtype=pd.StringDtype(),
),
}
)
pd.testing.assert_frame_equal(df, expected_df)


def test_clib_read_data_grid(expected_xrgrid):
"""
Test the Session.read_data method for grids.
"""
with Session() as lib:
grid = lib.read_data("@static_earth_relief.nc", kind="grid").contents
xrgrid = grid.to_dataarray()
xr.testing.assert_equal(xrgrid, expected_xrgrid)
assert grid.header.contents.n_bands == 1 # Explicitly check n_bands


def test_clib_read_data_grid_two_steps(expected_xrgrid):
"""
Test the Session.read_data method for grids in two steps, first reading the header
and then the data.
"""
infile = "@static_earth_relief.nc"
with Session() as lib:
# Read the header first
data_ptr = lib.read_data(infile, kind="grid", mode="GMT_CONTAINER_ONLY")
grid = data_ptr.contents
header = grid.header.contents
assert header.n_rows == 14
assert header.n_columns == 8
assert header.wesn[:] == [-55.0, -47.0, -24.0, -10.0]
assert header.z_min == 190.0
assert header.z_max == 981.0
assert header.n_bands == 1 # Explicitly check n_bands
assert not grid.data # The data is not read yet

# Read the data
lib.read_data(infile, kind="grid", mode="GMT_DATA_ONLY", data=data_ptr)
xrgrid = data_ptr.contents.to_dataarray()
xr.testing.assert_equal(xrgrid, expected_xrgrid)


def test_clib_read_data_grid_actual_image():
"""
Test the Session.read_data method for grid, but actually the file is an image.
"""
with Session() as lib:
data_ptr = lib.read_data(
"@earth_day_01d_p", kind="grid", mode="GMT_CONTAINER_AND_DATA"
)
image = data_ptr.contents
header = image.header.contents
assert header.n_rows == 180
assert header.n_columns == 360
assert header.wesn[:] == [-180.0, 180.0, -90.0, 90.0]
# Explicitly check n_bands. Only one band is read for 3-band images.
assert header.n_bands == 1

if _HAS_RIOXARRAY: # Full check if rioxarray is installed.
xrimage = image.to_dataarray()
expected_xrimage = xr.open_dataarray(
which("@earth_day_01d_p"), engine="rasterio"
)
assert expected_xrimage.band.size == 3 # 3-band image.
xr.testing.assert_equal(
xrimage,
expected_xrimage.isel(band=0)
.drop_vars(["band", "spatial_ref"])
.sortby("y"),
)


def test_clib_read_data_fails():
"""
Test that the Session.read_data method raises an exception if there are errors.
"""
with Session() as lib:
with pytest.raises(GMTCLibError):
lib.read_data("not-exsits.txt", kind="dataset")

0 comments on commit 917b3aa

Please sign in to comment.