Skip to content

Commit

Permalink
Correctly reserve the grid data dtype by converting ctypes array to n…
Browse files Browse the repository at this point in the history
…umpy array with np.ctypeslib.as_array (#3446)
  • Loading branch information
seisman authored Sep 24, 2024
1 parent f7110e2 commit 7544245
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 41 deletions.
62 changes: 32 additions & 30 deletions pygmt/datatypes/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,15 @@ class _GMT_GRID(ctp.Structure): # noqa: N801
... print(header.pad[:])
... print(header.mem_layout, header.nan_value, header.xy_off)
... # The x and y coordinates
... print(grid.x[: header.n_columns])
... print(grid.y[: header.n_rows])
... x = np.ctypeslib.as_array(grid.x, shape=(header.n_columns,)).copy()
... y = np.ctypeslib.as_array(grid.y, shape=(header.n_rows,)).copy()
... # The data array (with paddings)
... data = np.reshape(
... grid.data[: header.mx * header.my], (header.my, header.mx)
... )
... data = np.ctypeslib.as_array(
... grid.data, shape=(header.my, header.mx)
... ).copy()
... # The data array (without paddings)
... pad = header.pad[:]
... data = data[pad[2] : header.my - pad[3], pad[0] : header.mx - pad[1]]
... print(data)
14 8 1
[-55.0, -47.0, -24.0, -10.0] 190.0 981.0 [1.0, 1.0]
1.0 0.0
Expand All @@ -61,22 +60,26 @@ class _GMT_GRID(ctp.Structure): # noqa: N801
18 1 12 18
[2, 2, 2, 2]
b'' nan 0.5
[-54.5, -53.5, -52.5, -51.5, -50.5, -49.5, -48.5, -47.5]
[-10.5, -11.5, -12.5, -13.5, -14.5, -15.5, ..., -22.5, -23.5]
[[347.5 331.5 309. 282. 190. 208. 299.5 348. ]
[349. 313. 325.5 247. 191. 225. 260. 452.5]
[345.5 320. 335. 292. 207.5 247. 325. 346.5]
[450.5 395.5 366. 248. 250. 354.5 550. 797.5]
[494.5 488.5 357. 254.5 286. 484.5 653.5 930. ]
[601. 526.5 535. 299. 398.5 645. 797.5 964. ]
[308. 595.5 555.5 556. 580. 770. 927. 920. ]
[521.5 682.5 796. 886. 571.5 638.5 739.5 881.5]
[310. 521.5 757. 570.5 538.5 524. 686.5 794. ]
[561.5 539. 446.5 481.5 439.5 553. 726.5 981. ]
[557. 435. 385.5 345.5 413.5 496. 519.5 833.5]
[373. 367.5 349. 352.5 419.5 428. 570. 667.5]
[383. 284.5 344.5 394. 491. 556.5 578.5 618.5]
[347.5 344.5 386. 640.5 617. 579. 646.5 671. ]]
>>> x
array([-54.5, -53.5, -52.5, -51.5, -50.5, -49.5, -48.5, -47.5])
>>> y
array([-10.5, -11.5, -12.5, -13.5, ..., -20.5, -21.5, -22.5, -23.5])
>>> data
array([[347.5, 331.5, 309. , 282. , 190. , 208. , 299.5, 348. ],
[349. , 313. , 325.5, 247. , 191. , 225. , 260. , 452.5],
[345.5, 320. , 335. , 292. , 207.5, 247. , 325. , 346.5],
[450.5, 395.5, 366. , 248. , 250. , 354.5, 550. , 797.5],
[494.5, 488.5, 357. , 254.5, 286. , 484.5, 653.5, 930. ],
[601. , 526.5, 535. , 299. , 398.5, 645. , 797.5, 964. ],
[308. , 595.5, 555.5, 556. , 580. , 770. , 927. , 920. ],
[521.5, 682.5, 796. , 886. , 571.5, 638.5, 739.5, 881.5],
[310. , 521.5, 757. , 570.5, 538.5, 524. , 686.5, 794. ],
[561.5, 539. , 446.5, 481.5, 439.5, 553. , 726.5, 981. ],
[557. , 435. , 385.5, 345.5, 413.5, 496. , 519.5, 833.5],
[373. , 367.5, 349. , 352.5, 419.5, 428. , 570. , 667.5],
[383. , 284.5, 344.5, 394. , 491. , 556.5, 578.5, 618.5],
[347.5, 344.5, 386. , 640.5, 617. , 579. , 646.5, 671. ]],
dtype=float32)
"""

_fields_: ClassVar = [
Expand Down Expand Up @@ -126,7 +129,8 @@ def to_dataarray(self) -> xr.DataArray:
[450.5, 395.5, 366. , 248. , 250. , 354.5, 550. , 797.5],
[345.5, 320. , 335. , 292. , 207.5, 247. , 325. , 346.5],
[349. , 313. , 325.5, 247. , 191. , 225. , 260. , 452.5],
[347.5, 331.5, 309. , 282. , 190. , 208. , 299.5, 348. ]])
[347.5, 331.5, 309. , 282. , 190. , 208. , 299.5, 348. ]],
dtype=float32)
Coordinates:
* lat (lat) float64... -23.5 -22.5 -21.5 -20.5 ... -12.5 -11.5 -10.5
* lon (lon) float64... -54.5 -53.5 -52.5 -51.5 -50.5 -49.5 -48.5 -47.5
Expand Down Expand Up @@ -169,16 +173,14 @@ def to_dataarray(self) -> xr.DataArray:
# Get dimensions and their attributes from the header.
dims, dim_attrs = header.dims, header.dim_attrs
# The coordinates, given as a tuple of the form (dims, data, attrs)
coords = [
(dims[0], self.y[: header.n_rows], dim_attrs[0]),
(dims[1], self.x[: header.n_columns], dim_attrs[1]),
]
x = np.ctypeslib.as_array(self.x, shape=(header.n_columns,)).copy()
y = np.ctypeslib.as_array(self.y, shape=(header.n_rows,)).copy()
coords = [(dims[0], y, dim_attrs[0]), (dims[1], x, dim_attrs[1])]

# The data array without paddings
data = np.ctypeslib.as_array(self.data, shape=(header.my, header.mx)).copy()
pad = header.pad[:]
data = np.reshape(self.data[: header.mx * header.my], (header.my, header.mx))[
pad[2] : header.my - pad[3], pad[0] : header.mx - pad[1]
]
data = data[pad[2] : header.my - pad[3], pad[0] : header.mx - pad[1]]

# Create the xarray.DataArray object
grid = xr.DataArray(
Expand Down
19 changes: 9 additions & 10 deletions pygmt/datatypes/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,12 @@ class _GMT_IMAGE(ctp.Structure): # noqa: N801
... # Image-specific attributes.
... print(image.type, image.n_indexed_colors)
... # The x and y coordinates
... x = image.x[: header.n_columns]
... y = image.y[: header.n_rows]
... x = np.ctypeslib.as_array(image.x, shape=(header.n_columns,)).copy()
... y = np.ctypeslib.as_array(image.y, shape=(header.n_rows,)).copy()
... # The data array (with paddings)
... data = np.reshape(
... image.data[: header.n_bands * header.mx * header.my],
... (header.my, header.mx, header.n_bands),
... )
... data = np.ctypeslib.as_array(
... image.data, shape=(header.my, header.mx, header.n_bands)
... ).copy()
... # The data array (without paddings)
... pad = header.pad[:]
... data = data[pad[2] : header.my - pad[3], pad[0] : header.mx - pad[1], :]
Expand All @@ -60,10 +59,10 @@ class _GMT_IMAGE(ctp.Structure): # noqa: N801
[2, 2, 2, 2]
b'BRPa' 0.5
1 0
>>> x
[-179.5, -178.5, ..., 178.5, 179.5]
>>> y
[89.5, 88.5, ..., -88.5, -89.5]
>>> x # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
array([-179.5, -178.5, ..., 178.5, 179.5])
>>> y # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
array([ 89.5, 88.5, ..., -88.5, -89.5])
>>> data.shape
(180, 360, 3)
>>> data.min(), data.max()
Expand Down
2 changes: 1 addition & 1 deletion pygmt/tests/test_sphinterpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,4 @@ def test_sphinterpolate_no_outgrid(mars):
npt.assert_allclose(temp_grid.max(), 14628.144)
npt.assert_allclose(temp_grid.min(), -6908.1987)
npt.assert_allclose(temp_grid.median(), 118.96849)
npt.assert_allclose(temp_grid.mean(), 272.60578)
npt.assert_allclose(temp_grid.mean(), 272.60593)

0 comments on commit 7544245

Please sign in to comment.