From 32eee268624d6ff8384ac52a3722682d75d5c78f Mon Sep 17 00:00:00 2001 From: Dongdong Tian Date: Mon, 8 Jul 2024 21:23:52 +0800 Subject: [PATCH] Merge the WIP wrapper of GMT_IMAGE for further experiments --- pygmt/clib/session.py | 22 +++-- pygmt/datatypes/__init__.py | 1 + pygmt/datatypes/header.py | 3 +- pygmt/datatypes/image.py | 182 ++++++++++++++++++++++++++++++++++++ 4 files changed, 199 insertions(+), 9 deletions(-) create mode 100644 pygmt/datatypes/image.py diff --git a/pygmt/clib/session.py b/pygmt/clib/session.py index 4c81d0b70d6..b1a4d5d8dc0 100644 --- a/pygmt/clib/session.py +++ b/pygmt/clib/session.py @@ -26,7 +26,7 @@ vectors_to_arrays, ) from pygmt.clib.loading import load_libgmt -from pygmt.datatypes import _GMT_DATASET, _GMT_GRID +from pygmt.datatypes import _GMT_DATASET, _GMT_GRID, _GMT_IMAGE from pygmt.exceptions import ( GMTCLibError, GMTCLibNoSessionError, @@ -1769,7 +1769,9 @@ def virtualfile_from_data( @contextlib.contextmanager def virtualfile_out( - self, kind: Literal["dataset", "grid"] = "dataset", fname: str | None = None + self, + kind: Literal["dataset", "grid", "image"] = "dataset", + fname: str | None = None, ): r""" Create a virtual file or an actual file for storing output data. @@ -1782,8 +1784,8 @@ def virtualfile_out( Parameters ---------- kind - The data kind of the virtual file to create. Valid values are ``"dataset"`` - and ``"grid"``. Ignored if ``fname`` is specified. + The data kind of the virtual file to create. Valid values are ``"dataset"``, + ``"grid"``, and ``"image"``. Ignored if ``fname`` is specified. fname The name of the actual file to write the output data. No virtual file will be created. @@ -1826,8 +1828,11 @@ def virtualfile_out( family, geometry = { "dataset": ("GMT_IS_DATASET", "GMT_IS_PLP"), "grid": ("GMT_IS_GRID", "GMT_IS_SURFACE"), + "image": ("GMT_IS_IMAGE", "GMT_IS_SURFACE"), }[kind] - with self.open_virtualfile(family, geometry, "GMT_OUT", None) as vfile: + with self.open_virtualfile( + family, geometry, "GMT_OUT|GMT_IS_REFERENCE", None + ) as vfile: yield vfile def inquire_virtualfile(self, vfname: str) -> int: @@ -1873,7 +1878,8 @@ def read_virtualfile( Name of the virtual file to read. kind Cast the data into a GMT data container. Valid values are ``"dataset"``, - ``"grid"`` and ``None``. If ``None``, will return a ctypes void pointer. + ``"grid"``, ``"image"`` and ``None``. If ``None``, will return a ctypes void + pointer. Examples -------- @@ -1921,9 +1927,9 @@ def read_virtualfile( # _GMT_DATASET). if kind is None: # Return the ctypes void pointer return pointer - if kind in {"image", "cube"}: + if kind == "cube": raise NotImplementedError(f"kind={kind} is not supported yet.") - dtype = {"dataset": _GMT_DATASET, "grid": _GMT_GRID}[kind] + dtype = {"dataset": _GMT_DATASET, "grid": _GMT_GRID, "image": _GMT_IMAGE}[kind] return ctp.cast(pointer, ctp.POINTER(dtype)) def virtualfile_to_dataset( diff --git a/pygmt/datatypes/__init__.py b/pygmt/datatypes/__init__.py index 237a050a9f7..3489dd19d10 100644 --- a/pygmt/datatypes/__init__.py +++ b/pygmt/datatypes/__init__.py @@ -4,3 +4,4 @@ from pygmt.datatypes.dataset import _GMT_DATASET from pygmt.datatypes.grid import _GMT_GRID +from pygmt.datatypes.image import _GMT_IMAGE diff --git a/pygmt/datatypes/header.py b/pygmt/datatypes/header.py index 04e10ac0c72..ab109521131 100644 --- a/pygmt/datatypes/header.py +++ b/pygmt/datatypes/header.py @@ -203,7 +203,8 @@ def data_attrs(self) -> dict[str, Any]: Attributes for the data variable from the grid header. """ attrs: dict[str, Any] = {} - attrs["Conventions"] = "CF-1.7" + if self.type == 18: # Grid file format: ns = GMT netCDF format + attrs["Conventions"] = "CF-1.7" attrs["title"] = self.title.decode() attrs["history"] = self.command.decode() attrs["description"] = self.remark.decode() diff --git a/pygmt/datatypes/image.py b/pygmt/datatypes/image.py new file mode 100644 index 00000000000..ded7c818ae3 --- /dev/null +++ b/pygmt/datatypes/image.py @@ -0,0 +1,182 @@ +""" +Wrapper for the GMT_IMAGE data type. +""" + +import ctypes as ctp +from typing import ClassVar + +import numpy as np +import xarray as xr +from pygmt.datatypes.header import _GMT_GRID_HEADER + + +class _GMT_IMAGE(ctp.Structure): # noqa: N801 + """ + GMT image data structure. + + Examples + -------- + >>> from pygmt.clib import Session + >>> import numpy as np + >>> import xarray as xr + + >>> with Session() as lib: + ... with lib.virtualfile_out(kind="image") as voutimg: + ... lib.call_module("read", f"@earth_day_01d {voutimg} -Ti") + ... # Read the image from the virtual file + ... image = lib.read_virtualfile(vfname=voutimg, kind="image").contents + ... # The image header + ... header = image.header.contents + ... # Access the header properties + ... print(image.type, header.n_bands, header.n_rows, header.n_columns) + ... print(header.pad[:]) + ... # The x and y coordinates + ... x = image.x[: header.n_columns] + ... y = image.y[: header.n_rows] + ... # The data array (with paddings) + ... data = np.reshape( + ... image.data[: header.n_bands * header.mx * header.my], + ... (header.my, header.mx, header.n_bands), + ... ) + ... # The data array (without paddings) + ... pad = header.pad[:] + ... data = data[pad[2] : header.my - pad[3], pad[0] : header.mx - pad[1], :] + ... print(data.shape) + 1 3 180 360 + [2, 2, 2, 2] + (180, 360, 3) + """ + + _fields_: ClassVar = [ + # Data type, e.g. GMT_FLOAT + ("type", ctp.c_int), + # Array with color lookup values + ("colormap", ctp.POINTER(ctp.c_int)), + # Number of colors in a paletted image + ("n_indexed_colors", ctp.c_int), + # Pointer to full GMT header for the image + ("header", ctp.POINTER(_GMT_GRID_HEADER)), + # Pointer to actual image + ("data", ctp.POINTER(ctp.c_ubyte)), + # Pointer to an optional transparency layer stored in a separate variable + ("alpha", ctp.POINTER(ctp.c_ubyte)), + # Color interpolation + ("color_interp", ctp.c_char_p), + # Pointer to the x-coordinate vector + ("x", ctp.POINTER(ctp.c_double)), + # Pointer to the y-coordinate vector + ("y", ctp.POINTER(ctp.c_double)), + # Book-keeping variables "hidden" from the API + ("hidden", ctp.c_void_p), + ] + + def to_dataarray(self) -> xr.DataArray: + """ + Convert a _GMT_IMAGE object to an :class:`xarray.DataArray` object. + + Returns + ------- + dataarray + A :class:`xarray.DataArray` object. + + Examples + -------- + >>> from pygmt.clib import Session + >>> with Session() as lib: + ... with lib.virtualfile_out(kind="image") as voutimg: + ... lib.call_module("read", ["@earth_day_01d", voutimg, "-Ti"]) + ... # Read the image from the virtual file + ... image = lib.read_virtualfile(voutimg, kind="image") + ... # Convert to xarray.DataArray and use it later + ... da = image.contents.to_dataarray() + >>> da # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS + Size: 2MB + array([[[ 10, 10, 10, ..., 10, 10, 10], + [ 10, 10, 10, ..., 10, 10, 10], + [ 10, 10, 10, ..., 10, 10, 10], + ..., + [192, 193, 193, ..., 193, 192, 191], + [204, 206, 206, ..., 205, 206, 204], + [208, 210, 210, ..., 210, 210, 208]], + + [[ 10, 10, 10, ..., 10, 10, 10], + [ 10, 10, 10, ..., 10, 10, 10], + [ 10, 10, 10, ..., 10, 10, 10], + ..., + [186, 187, 188, ..., 187, 186, 185], + [196, 198, 198, ..., 197, 197, 196], + [199, 201, 201, ..., 201, 202, 199]], + + [[ 51, 51, 51, ..., 51, 51, 51], + [ 51, 51, 51, ..., 51, 51, 51], + [ 51, 51, 51, ..., 51, 51, 51], + ..., + [177, 179, 179, ..., 178, 177, 177], + [185, 187, 187, ..., 187, 186, 185], + [189, 191, 191, ..., 191, 191, 189]]]) + Coordinates: + * x (x) float64 3kB -179.5 -178.5 -177.5 -176.5 ... 177.5 178.5 179.5 + * y (y) float64 1kB 89.5 88.5 87.5 86.5 ... -86.5 -87.5 -88.5 -89.5 + * band (band) uint8 3B 0 1 2 + Attributes: + title: + history: + description: + long_name: z + actual_range: [ 1.79769313e+308 -1.79769313e+308] + + >>> da.coords["x"] # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS + Size: 3kB + array([-179.5, -178.5, -177.5, ..., 177.5, 178.5, 179.5]) + Coordinates: + * x (x) float64 3kB -179.5 -178.5 -177.5 -176.5 ... 177.5 178.5 179.5 + + >>> da.coords["y"] # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS + Size: 1kB + array([ 89.5, 88.5, 87.5, 86.5, 85.5, 84.5, 83.5, 82.5, 81.5, 80.5, + 79.5, 78.5, 77.5, 76.5, 75.5, 74.5, 73.5, 72.5, 71.5, 70.5, + 69.5, 68.5, 67.5, 66.5, 65.5, 64.5, 63.5, 62.5, 61.5, 60.5, + ... + -0.5, -1.5, -2.5, -3.5, -4.5, -5.5, -6.5, -7.5, -8.5, -9.5, + ... + -60.5, -61.5, -62.5, -63.5, -64.5, -65.5, -66.5, -67.5, -68.5, -69.5, + -70.5, -71.5, -72.5, -73.5, -74.5, -75.5, -76.5, -77.5, -78.5, -79.5, + -80.5, -81.5, -82.5, -83.5, -84.5, -85.5, -86.5, -87.5, -88.5, -89.5]) + Coordinates: + * y (y) float64 1kB 89.5 88.5 87.5 86.5 ... -86.5 -87.5 -88.5 -89.5 + + >>> da.gmt.registration, da.gmt.gtype + (1, 0) + """ + + # Get image header + header: _GMT_GRID_HEADER = self.header.contents + + # Get DataArray without padding + pad = header.pad[:] + data: np.ndarray = np.reshape( + a=self.data[: header.n_bands * header.mx * header.my], + newshape=(header.my, header.mx, header.n_bands), + )[pad[2] : header.my - pad[3], pad[0] : header.mx - pad[1], :] + + # Get x and y coordinates + coords: dict[str, list | np.ndarray] = { + "x": self.x[: header.n_columns], + "y": self.y[: header.n_rows], + "band": np.array([0, 1, 2], dtype=np.uint8), + } + + # Create the xarray.DataArray object + image = xr.DataArray( + data=data, + coords=coords, + dims=("y", "x", "band"), + name=header.name, + attrs=header.data_attrs, + ).transpose("band", "y", "x") + + # Set GMT accessors. + # Must put at the end, otherwise info gets lost after certain image operations. + image.gmt.registration = header.registration + image.gmt.gtype = header.gtype + return image