Skip to content

Commit

Permalink
Merge branch 'main' into pandas-2.1
Browse files Browse the repository at this point in the history
  • Loading branch information
seisman authored Nov 9, 2024
2 parents 6083058 + 189f376 commit e812329
Show file tree
Hide file tree
Showing 8 changed files with 291 additions and 43 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci_docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ jobs:
sphinx-copybutton
sphinx-design
sphinx-gallery
sphinx_rtd_theme
sphinx_rtd_theme<3.0
# Download cached remote files (artifacts) from GitHub
- name: Download remote data from GitHub
Expand Down
2 changes: 1 addition & 1 deletion ci/requirements/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ dependencies:
- sphinx-copybutton
- sphinx-design
- sphinx-gallery
- sphinx_rtd_theme
- sphinx_rtd_theme<3.0
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ dependencies:
- sphinx-copybutton
- sphinx-design
- sphinx-gallery>=0.17.0
- sphinx_rtd_theme
- sphinx_rtd_theme<3.0
# Dev dependencies (type hints)
- mypy
- pandas-stubs
4 changes: 0 additions & 4 deletions pygmt/tests/baseline/test_figure_shift_origin.png.dvc

This file was deleted.

5 changes: 5 additions & 0 deletions pygmt/tests/baseline/test_shift_origin.png.dvc
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
outs:
- md5: 39b241fdd879271cf1e8cf1f73454706
size: 9910
hash: md5
path: test_shift_origin.png
240 changes: 240 additions & 0 deletions pygmt/tests/test_clib_to_numpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
"""
Tests for the _to_numpy function in the clib.conversion module.
"""

import sys

import numpy as np
import numpy.testing as npt
import pandas as pd
import pytest
from packaging.version import Version
from pygmt.clib.conversion import _to_numpy

try:
import pyarrow as pa

_HAS_PYARROW = True
except ImportError:
_HAS_PYARROW = False


def _check_result(result, expected_dtype):
"""
A helper function to check if the result of the _to_numpy function is a C-contiguous
NumPy array with the expected dtype.
"""
assert isinstance(result, np.ndarray)
assert result.flags.c_contiguous
assert result.dtype.type == expected_dtype


########################################################################################
# Test the _to_numpy function with Python built-in types.
########################################################################################
@pytest.mark.parametrize(
("data", "expected_dtype"),
[
pytest.param(
[1, 2, 3],
np.int32
if sys.platform == "win32" and Version(np.__version__) < Version("2.0")
else np.int64,
id="int",
),
pytest.param([1.0, 2.0, 3.0], np.float64, id="float"),
pytest.param(
[complex(+1), complex(-2j), complex("-Infinity+NaNj")],
np.complex128,
id="complex",
),
],
)
def test_to_numpy_python_types_numeric(data, expected_dtype):
"""
Test the _to_numpy function with Python built-in numeric types.
"""
result = _to_numpy(data)
_check_result(result, expected_dtype)
npt.assert_array_equal(result, data)


########################################################################################
# Test the _to_numpy function with NumPy arrays.
#
# There are 24 fundamental dtypes in NumPy. Not all of them are supported by PyGMT.
#
# - Numeric dtypes:
# - int8, int16, int32, int64, longlong
# - uint8, uint16, uint32, uint64, ulonglong
# - float16, float32, float64, longdouble
# - complex64, complex128, clongdouble
# - bool
# - datetime64, timedelta64
# - str_
# - bytes_
# - object_
# - void
#
# Reference: https://numpy.org/doc/2.1/reference/arrays.scalars.html
########################################################################################
np_dtype_params = [
pytest.param(np.int8, np.int8, id="int8"),
pytest.param(np.int16, np.int16, id="int16"),
pytest.param(np.int32, np.int32, id="int32"),
pytest.param(np.int64, np.int64, id="int64"),
pytest.param(np.longlong, np.longlong, id="longlong"),
pytest.param(np.uint8, np.uint8, id="uint8"),
pytest.param(np.uint16, np.uint16, id="uint16"),
pytest.param(np.uint32, np.uint32, id="uint32"),
pytest.param(np.uint64, np.uint64, id="uint64"),
pytest.param(np.ulonglong, np.ulonglong, id="ulonglong"),
pytest.param(np.float16, np.float16, id="float16"),
pytest.param(np.float32, np.float32, id="float32"),
pytest.param(np.float64, np.float64, id="float64"),
pytest.param(np.longdouble, np.longdouble, id="longdouble"),
pytest.param(np.complex64, np.complex64, id="complex64"),
pytest.param(np.complex128, np.complex128, id="complex128"),
pytest.param(np.clongdouble, np.clongdouble, id="clongdouble"),
]


@pytest.mark.parametrize(("dtype", "expected_dtype"), np_dtype_params)
def test_to_numpy_ndarray_numpy_dtypes_numeric(dtype, expected_dtype):
"""
Test the _to_numpy function with NumPy arrays of NumPy numeric dtypes.
Test both 1-D and 2-D arrays which are not C-contiguous.
"""
# 1-D array that is not C-contiguous
array = np.array([1, 2, 3, 4, 5, 6], dtype=dtype)[::2]
assert array.flags.c_contiguous is False
result = _to_numpy(array)
_check_result(result, expected_dtype)
npt.assert_array_equal(result, array, strict=True)

# 2-D array that is not C-contiguous
array = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=dtype)[::2, ::2]
assert array.flags.c_contiguous is False
result = _to_numpy(array)
_check_result(result, expected_dtype)
npt.assert_array_equal(result, array, strict=True)


########################################################################################
# Test the _to_numpy function with pandas.Series.
#
# In pandas, dtype can be specified by
#
# 1. NumPy dtypes (see above)
# 2. pandas dtypes
# 3. PyArrow types (see below)
#
# pandas provides following dtypes:
#
# - Numeric dtypes:
# - Int8, Int16, Int32, Int64
# - UInt8, UInt16, UInt32, UInt64
# - Float32, Float64
# - DatetimeTZDtype
# - PeriodDtype
# - IntervalDtype
# - StringDtype
# - CategoricalDtype
# - SparseDtype
# - BooleanDtype
# - ArrowDtype: a special dtype used to store data in the PyArrow format.
#
# References:
# 1. https://pandas.pydata.org/docs/reference/arrays.html
# 2. https://pandas.pydata.org/docs/user_guide/basics.html#basics-dtypes
# 3. https://pandas.pydata.org/docs/user_guide/pyarrow.html
########################################################################################
@pytest.mark.parametrize(("dtype", "expected_dtype"), np_dtype_params)
def test_to_numpy_pandas_series_numpy_dtypes_numeric(dtype, expected_dtype):
"""
Test the _to_numpy function with pandas.Series of NumPy numeric dtypes.
"""
series = pd.Series([1, 2, 3, 4, 5, 6], dtype=dtype)[::2] # Not C-contiguous
result = _to_numpy(series)
_check_result(result, expected_dtype)
npt.assert_array_equal(result, series)


########################################################################################
# Test the _to_numpy function with PyArrow arrays.
#
# PyArrow provides the following types:
#
# - Numeric types:
# - int8, int16, int32, int64
# - uint8, uint16, uint32, uint64
# - float16, float32, float64
#
# In PyArrow, array types can be specified in two ways:
#
# - Using string aliases (e.g., "int8")
# - Using pyarrow.DataType (e.g., ``pa.int8()``)
#
# Reference: https://arrow.apache.org/docs/python/api/datatypes.html
########################################################################################
@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed")
@pytest.mark.parametrize(
("dtype", "expected_dtype"),
[
pytest.param("int8", np.int8, id="int8"),
pytest.param("int16", np.int16, id="int16"),
pytest.param("int32", np.int32, id="int32"),
pytest.param("int64", np.int64, id="int64"),
pytest.param("uint8", np.uint8, id="uint8"),
pytest.param("uint16", np.uint16, id="uint16"),
pytest.param("uint32", np.uint32, id="uint32"),
pytest.param("uint64", np.uint64, id="uint64"),
pytest.param("float16", np.float16, id="float16"),
pytest.param("float32", np.float32, id="float32"),
pytest.param("float64", np.float64, id="float64"),
],
)
def test_to_numpy_pyarrow_array_pyarrow_dtypes_numeric(dtype, expected_dtype):
"""
Test the _to_numpy function with PyArrow arrays of PyArrow numeric types.
"""
data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
if dtype == "float16": # float16 needs special handling
# Example from https://arrow.apache.org/docs/python/generated/pyarrow.float16.html
data = np.array(data, dtype=np.float16)
array = pa.array(data, type=dtype)[::2]
result = _to_numpy(array)
_check_result(result, expected_dtype)
npt.assert_array_equal(result, array)


@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed")
@pytest.mark.parametrize(
("dtype", "expected_dtype"),
[
pytest.param("int8", np.float64, id="int8"),
pytest.param("int16", np.float64, id="int16"),
pytest.param("int32", np.float64, id="int32"),
pytest.param("int64", np.float64, id="int64"),
pytest.param("uint8", np.float64, id="uint8"),
pytest.param("uint16", np.float64, id="uint16"),
pytest.param("uint32", np.float64, id="uint32"),
pytest.param("uint64", np.float64, id="uint64"),
pytest.param("float16", np.float16, id="float16"),
pytest.param("float32", np.float32, id="float32"),
pytest.param("float64", np.float64, id="float64"),
],
)
def test_to_numpy_pyarrow_array_pyarrow_dtypes_numeric_with_na(dtype, expected_dtype):
"""
Test the _to_numpy function with PyArrow arrays of PyArrow numeric types and NA.
"""
data = [1.0, 2.0, None, 4.0, 5.0, 6.0]
if dtype == "float16": # float16 needs special handling
# Example from https://arrow.apache.org/docs/python/generated/pyarrow.float16.html
data = np.array(data, dtype=np.float16)
array = pa.array(data, type=dtype)[::2]
result = _to_numpy(array)
_check_result(result, expected_dtype)
npt.assert_array_equal(result, array)
36 changes: 0 additions & 36 deletions pygmt/tests/test_figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,26 +292,6 @@ def test_figure_show():
fig.show()


@pytest.mark.mpl_image_compare
def test_figure_shift_origin():
"""
Test if fig.shift_origin works.
"""
kwargs = {"region": [0, 3, 0, 5], "projection": "X3c/5c", "frame": 0}
fig = Figure()
# First call shift_origin without projection and region.
# Test issue https://github.com/GenericMappingTools/pygmt/issues/514
fig.shift_origin(xshift="2c", yshift="3c")
fig.basemap(**kwargs)
fig.shift_origin(xshift="4c")
fig.basemap(**kwargs)
fig.shift_origin(yshift="6c")
fig.basemap(**kwargs)
fig.shift_origin(xshift="-4c", yshift="6c")
fig.basemap(**kwargs)
return fig


def test_figure_show_invalid_method():
"""
Test to check if an error is raised when an invalid method is passed to show.
Expand Down Expand Up @@ -407,22 +387,6 @@ def test_invalid_method(self):
set_display(method="invalid")


def test_figure_unsupported_xshift_yshift():
"""
Raise an exception if X/Y/xshift/yshift is used.
"""
fig = Figure()
fig.basemap(region=[0, 1, 0, 1], projection="X1c/1c", frame=True)
with pytest.raises(GMTInvalidInput):
fig.plot(x=1, y=1, style="c3c", xshift="3c")
with pytest.raises(GMTInvalidInput):
fig.plot(x=1, y=1, style="c3c", X="3c")
with pytest.raises(GMTInvalidInput):
fig.plot(x=1, y=1, style="c3c", yshift="3c")
with pytest.raises(GMTInvalidInput):
fig.plot(x=1, y=1, style="c3c", Y="3c")


class TestGetDefaultDisplayMethod:
"""
Test the _get_default_display_method function.
Expand Down
43 changes: 43 additions & 0 deletions pygmt/tests/test_shift_origin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""
Test Figure.shift_origin.
"""

import pytest
from pygmt.exceptions import GMTInvalidInput
from pygmt.figure import Figure


@pytest.mark.mpl_image_compare
def test_shift_origin():
"""
Test if fig.shift_origin works.
"""
kwargs = {"region": [0, 3, 0, 5], "projection": "X3c/5c", "frame": 0}
fig = Figure()
# First call shift_origin without projection and region.
# Test issue https://github.com/GenericMappingTools/pygmt/issues/514
fig.shift_origin(xshift="2c", yshift="3c")
fig.basemap(**kwargs)
fig.shift_origin(xshift="4c")
fig.basemap(**kwargs)
fig.shift_origin(yshift="6c")
fig.basemap(**kwargs)
fig.shift_origin(xshift="-4c", yshift="6c")
fig.basemap(**kwargs)
return fig


def test_shift_origin_unsupported_xshift_yshift():
"""
Raise an exception if X/Y/xshift/yshift is used.
"""
fig = Figure()
fig.basemap(region=[0, 1, 0, 1], projection="X1c/1c", frame=True)
with pytest.raises(GMTInvalidInput):
fig.plot(x=1, y=1, style="c3c", xshift="3c")
with pytest.raises(GMTInvalidInput):
fig.plot(x=1, y=1, style="c3c", X="3c")
with pytest.raises(GMTInvalidInput):
fig.plot(x=1, y=1, style="c3c", yshift="3c")
with pytest.raises(GMTInvalidInput):
fig.plot(x=1, y=1, style="c3c", Y="3c")

0 comments on commit e812329

Please sign in to comment.