Skip to content

Commit

Permalink
Enable passing pyarrow.StringArray to clib.Session.put_strings
Browse files Browse the repository at this point in the history
Convert a pyarrow.StringArray via a Python list to a ctypes array in the strings_to_ctypes_array function. Updated docstrings and type hints in `clib.Session.put_strings` method and `clib.conversion.strings_to_ctypes_array` function. Added two parametrized unit tests to ensure that pyarrow.StringArray can be passed into the clib methods.
  • Loading branch information
weiji14 committed Oct 11, 2024
1 parent 07fbca6 commit d379e46
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 25 deletions.
18 changes: 14 additions & 4 deletions pygmt/clib/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
import numpy as np
from pygmt.exceptions import GMTInvalidInput

try:
import pyarrow as pa
except ImportError:
pa = None

Check warning on line 15 in pygmt/clib/conversion.py

View check run for this annotation

Codecov / codecov/patch

pygmt/clib/conversion.py#L14-L15

Added lines #L14 - L15 were not covered by tests

def dataarray_to_matrix(grid):
"""
Expand Down Expand Up @@ -263,14 +267,15 @@ def sequence_to_ctypes_array(
return (ctype * size)(*sequence)


def strings_to_ctypes_array(strings: Sequence[str]) -> ctp.Array:
def strings_to_ctypes_array(strings: Sequence[str] | pa.StringArray) -> ctp.Array:
"""
Convert a sequence (e.g., a list) of strings into a ctypes array.
Convert a sequence (e.g., a list) of strings or a pyarrow.StringArray into a ctypes
array.
Parameters
----------
strings
A sequence of strings.
A sequence of strings or a pyarrow.StringArray.
Returns
-------
Expand All @@ -286,7 +291,12 @@ def strings_to_ctypes_array(strings: Sequence[str]) -> ctp.Array:
>>> [s.decode() for s in ctypes_array]
['first', 'second', 'third']
"""
return (ctp.c_char_p * len(strings))(*[s.encode() for s in strings])
try:
bytes_string_list = [s.encode() for s in strings]
except AttributeError: # 'pyarrow.StringScalar' object has no attribute 'encode'
# Convert pyarrow.StringArray to Python list first
bytes_string_list = [s.encode() for s in strings.to_pylist()]
return (ctp.c_char_p * len(strings))(*bytes_string_list)


def array_to_datetime(array):
Expand Down
43 changes: 26 additions & 17 deletions pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@
tempfile_from_image,
)

try:
import pyarrow as pa
except ImportError:
pa = None

Check warning on line 40 in pygmt/clib/session.py

View check run for this annotation

Codecov / codecov/patch

pygmt/clib/session.py#L39-L40

Added lines #L39 - L40 were not covered by tests

FAMILIES = [
"GMT_IS_DATASET", # Entity is a data table
"GMT_IS_GRID", # Entity is a grid
Expand Down Expand Up @@ -936,39 +941,43 @@ def put_vector(self, dataset, column, vector):
f"in column {column} of dataset."
)

def put_strings(self, dataset, family, strings):
def put_strings(
self,
dataset: ctp.c_void_p,
family: Literal["GMT_IS_VECTOR", "GMT_IS_MATRIX"],
strings: Sequence[str] | pa.StringArray,
):
"""
Attach a numpy 1-D array of dtype str as a column on a GMT dataset.
Attach a 1-D numpy array of dtype str or pyarrow.StringArray as a column on a
GMT dataset.
Use this function to attach string type numpy array data to a GMT
dataset and pass it to GMT modules. Wraps ``GMT_Put_Strings``.
Use this function to attach string type array data to a GMT dataset and pass it
to GMT modules. Wraps ``GMT_Put_Strings``.
The dataset must be created by :meth:`pygmt.clib.Session.create_data`
first.
The dataset must be created by :meth:`pygmt.clib.Session.create_data` first.
.. warning::
The numpy array must be C contiguous in memory. If it comes from a
column slice of a 2-D array, for example, you will have to make a
copy. Use :func:`numpy.ascontiguousarray` to make sure your vector
is contiguous (it won't copy if it already is).
The array must be C contiguous in memory. If it comes from a column slice of
a 2-D array, for example, you will have to make a copy. Use
:func:`numpy.ascontiguousarray` to make sure your vector is contiguous (it
won't copy if it already is).
Parameters
----------
dataset : :class:`ctypes.c_void_p`
dataset
The ctypes void pointer to a ``GMT_Dataset``. Create it with
:meth:`pygmt.clib.Session.create_data`.
family : str
family
The family type of the dataset. Can be either ``GMT_IS_VECTOR`` or
``GMT_IS_MATRIX``.
strings : numpy 1-D array
The array that will be attached to the dataset. Must be a 1-D C
contiguous array.
strings
The array that will be attached to the dataset. Must be a 1-D C contiguous
array.
Raises
------
GMTCLibError
If given invalid input or ``GMT_Put_Strings`` exits with
status != 0.
If given invalid input or ``GMT_Put_Strings`` exits with status != 0.
"""
c_put_strings = self.get_libgmt_func(
"GMT_Put_Strings",
Expand Down
25 changes: 22 additions & 3 deletions pygmt/tests/test_clib_put_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,31 @@
from pygmt import clib
from pygmt.exceptions import GMTCLibError
from pygmt.helpers import GMTTempFile
from pygmt.helpers.testing import skip_if_no

try:
import pyarrow as pa
except ImportError:
pa = None


@pytest.mark.benchmark
def test_put_strings():
@pytest.mark.parametrize(
("array_func", "dtype"),
[
pytest.param(np.array, {"dtype": str}, id="str"),
pytest.param(
getattr(pa, "array", None),
{"type": pa.string()},
marks=skip_if_no(package="pyarrow"),
id="pyarrow",
),
],
)
def test_put_strings(array_func, dtype):
"""
Check that assigning a numpy array of dtype str to a dataset works.
Check that assigning a numpy array of dtype str, or a pyarrow.StringArray to a
dataset works.
"""
with clib.Session() as lib:
dataset = lib.create_data(
Expand All @@ -24,7 +43,7 @@ def test_put_strings():
)
x = np.array([1, 2, 3, 4, 5], dtype=np.int32)
y = np.array([6, 7, 8, 9, 10], dtype=np.int32)
strings = np.array(["a", "bc", "defg", "hijklmn", "opqrst"], dtype=str)
strings = array_func(["a", "bc", "defg", "hijklmn", "opqrst"], **dtype)
lib.put_vector(dataset, column=lib["GMT_X"], vector=x)
lib.put_vector(dataset, column=lib["GMT_Y"], vector=y)
lib.put_strings(
Expand Down
2 changes: 1 addition & 1 deletion pygmt/tests/test_clib_virtualfiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def test_virtualfile_from_vectors(dtypes):
pytest.param(np.array, {"dtype": object}, id="object"),
pytest.param(
getattr(pa, "array", None),
{}, # pa.string()
{"type": pa.string()},
marks=skip_if_no(package="pyarrow"),
id="pyarrow",
),
Expand Down

0 comments on commit d379e46

Please sign in to comment.