Skip to content

Commit

Permalink
Be more strict about the definitions of matrix/vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
seisman committed Jul 23, 2024
1 parent 9932f81 commit f160dd6
Show file tree
Hide file tree
Showing 9 changed files with 51 additions and 46 deletions.
23 changes: 12 additions & 11 deletions pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1701,7 +1701,7 @@ def virtualfile_in( # noqa: PLR0912
if check_kind == "raster":
valid_kinds += ("grid", "image")
elif check_kind == "vector":
valid_kinds += ("matrix", "vectors", "geojson")
valid_kinds += ("none", "matrix", "vectors", "geojson")
if kind not in valid_kinds:
raise GMTInvalidInput(
f"Unrecognized data type for {check_kind}: {type(data)}"
Expand All @@ -1714,11 +1714,9 @@ def virtualfile_in( # noqa: PLR0912
"geojson": tempfile_from_geojson,
"grid": self.virtualfile_from_grid,
"image": tempfile_from_image,
# Note: virtualfile_from_matrix is not used because a matrix can be
# converted to vectors instead, and using vectors allows for better
# handling of string type inputs (e.g. for datetime data types)
"matrix": self.virtualfile_from_vectors,
"matrix": self.virtualfile_from_matrix,
"vectors": self.virtualfile_from_vectors,
"none": self.virtualfile_from_vectors,
}[kind]

# Ensure the data is an iterable (Python list or tuple)
Expand All @@ -1733,25 +1731,28 @@ def virtualfile_in( # noqa: PLR0912
)
warnings.warn(message=msg, category=RuntimeWarning, stacklevel=2)
_data = (data,) if not isinstance(data, pathlib.PurePath) else (str(data),)
elif kind == "vectors":
elif kind == "none":
_data = [np.atleast_1d(x), np.atleast_1d(y)]
if z is not None:
_data.append(np.atleast_1d(z))
if extra_arrays:
_data.extend(extra_arrays)
elif kind == "matrix": # turn 2-D arrays into list of vectors
elif kind == "vectors":
if hasattr(data, "items") and not hasattr(data, "to_frame"):
# pandas.DataFrame or xarray.Dataset types.
# pandas.Series will be handled below like a 1-D numpy.ndarray.
_data = [array for _, array in data.items()]
elif hasattr(data, "ndim") and data.ndim == 2 and data.dtype.kind in "iuf":
else:
# Python list, tuple, numpy.ndarray, and pandas.Series types
_data = np.atleast_2d(np.asanyarray(data).T)
elif kind == "matrix":
if data.dtype.kind in "iuf":
# Just use virtualfile_from_matrix for 2-D numpy.ndarray
# which are signed integer (i), unsigned integer (u) or
# floating point (f) types
_virtualfile_from = self.virtualfile_from_matrix
_data = (data,)
else:
# Python list, tuple, numpy.ndarray, and pandas.Series types
else: # turn 2-D arrays into list of vectors
_virtualfile_from = self.virtualfile_from_vectors
_data = np.atleast_2d(np.asanyarray(data).T)

# Finally create the virtualfile from the data, to be passed into GMT
Expand Down
45 changes: 24 additions & 21 deletions pygmt/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,31 +187,30 @@ def _check_encoding(
return "ISOLatin1+"


def data_kind(
def data_kind( # noqa: PLR0911
data: Any, required: bool = True
) -> Literal["arg", "file", "geojson", "grid", "image", "matrix", "vectors"]:
) -> Literal["none", "arg", "file", "geojson", "grid", "image", "matrix", "vectors"]:
"""
Check the kind of data that is provided to a module.
Recognized data kinds are:
- ``"arg"``: bool, int or float, representing an optional argument, mainly used for
dealing with optional virtual files
- ``"none"``: ``None`` when data is required (``required`` is True). In this case, the
data is usually given via a series of vectors (e.g., x/y/z)
- ``"arg"``: bool, int, float, or None (only when ``required`` is False),
representing an optional argument, mainly used for dealing with optional virtual
files
- ``"file"``: a string or a :class:`pathlib.PurePath` object or a sequence of them,
representing a file name or a list of file names
- ``"geojson"``: a geo-like Python object that implements ``__geo_interface__``
(e.g., geopandas.GeoDataFrame or shapely.geometry)
- ``"grid"``: a :class:`xarray.DataArray` object whose number of dimensions is not 3
- ``"image"``: a :class:`xarray.DataArray` object with 3 dimensions
- ``"matrix"``: a :class:`pandas.DataFrame` object, a 2-D :class:`numpy.ndarray`,
a dictionary with array-like values, or a sequence of sequences
- ``"matrix"``: a 2-D :class:`numpy.ndarray` object
- ``"vectors"``: a :class:`pandas.DataFrame` object, a dictionary with array-like
values, or a sequence of sequences
In addition, the data can be given via a series of vectors (e.g., x/y/z). In this
case, the ``data`` argument is ``None`` and the data kind is determined by the
``required`` argument. The data kind is ``"vectors"`` if ``required`` is ``True``,
otherwise the data kind is ``"arg"``.
The function will fallback to ``"matrix"`` for any unrecognized data.
The function will fall back to ``"vectors"`` for any unrecognized data.
Parameters
----------
Expand All @@ -232,12 +231,12 @@ def data_kind(
>>> import xarray as xr
>>> import pandas as pd
>>> import pathlib
>>> [data_kind(data=data) for data in (2, 2.0, True, False)]
['arg', 'arg', 'arg', 'arg']
>>> data_kind(data=None)
'vectors'
'none'
>>> data_kind(data=None, required=False)
'arg'
>>> [data_kind(data=data) for data in (2, 2.0, True, False)]
['arg', 'arg', 'arg', 'arg']
>>> data_kind(data="my-data-file.txt")
'file'
>>> data_kind(data=pathlib.Path("my-data-file.txt"))
Expand All @@ -251,16 +250,16 @@ def data_kind(
>>> data_kind(data=np.arange(10).reshape((5, 2)))
'matrix'
>>> data_kind(data=pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]}))
'matrix'
'vectors'
>>> data_kind(data={"x": [1, 2], "y": [3, 4]})
'matrix'
'vectors'
>>> data_kind(data=[1, 2, 3])
'matrix'
'vectors'
"""
# data is None, so data must be given via a series of vectors (i.e., x/y/z).
# The only exception is when dealing with optional virtual files.
if data is None:
return "vectors" if required else "arg"
return "none" if required else "arg"

# A file or a list of files
if isinstance(data, str | pathlib.PurePath) or (
Expand All @@ -282,8 +281,12 @@ def data_kind(
if hasattr(data, "__geo_interface__"):
return "geojson"

# Fallback to "matrix" for anything else
return "matrix"
# A 2-D numpy.ndarray
if hasattr(data, "__array_interface__") and data.ndim == 2:
return "matrix"

# Fall back to "vectors" for anything else
return "vectors"


def non_ascii_to_octal(
Expand Down
11 changes: 6 additions & 5 deletions pygmt/src/legend.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,13 @@ def legend(self, spec=None, position="JTR+jTR+o0.2c", box="+gwhite+p1p", **kwarg
if kwargs.get("F") is None:
kwargs["F"] = box

with Session() as lib:
if spec is None:
match data_kind(spec):
case "none":
specfile = ""
elif data_kind(spec) == "file" and not is_nonstr_iter(spec):
# Is a file but not a list of files
case kind if kind == "file" and not is_nonstr_iter(spec):
specfile = spec
else:
case _:
raise GMTInvalidInput(f"Unrecognized data type: {type(spec)}")

with Session() as lib:
lib.call_module(module="legend", args=build_arg_list(kwargs, infile=specfile))
2 changes: 1 addition & 1 deletion pygmt/src/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def plot( # noqa: PLR0912

kind = data_kind(data)
extra_arrays = []
if kind == "vectors": # Add more columns for vectors input
if kind == "none": # Add more columns for vectors input
# Parameters for vector styles
if (
kwargs.get("S") is not None
Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/plot3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def plot3d( # noqa: PLR0912
kind = data_kind(data)
extra_arrays = []

if kind == "vectors": # Add more columns for vectors input
if kind == "none": # Add more columns for vectors input
# Parameters for vector styles
if (
kwargs.get("S") is not None
Expand Down
4 changes: 2 additions & 2 deletions pygmt/src/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def text_( # noqa: PLR0912
"Provide either position only, or x/y pairs, or textfiles."
)
kind = data_kind(textfiles)
if kind == "vectors" and text is None:
if kind == "none" and text is None:
raise GMTInvalidInput("Must provide text with x/y pairs")
else:
if any(v is not None for v in (x, y, textfiles)):
Expand Down Expand Up @@ -227,7 +227,7 @@ def text_( # noqa: PLR0912

# Append text at last column. Text must be passed in as str type.
confdict = {}
if kind == "vectors":
if kind == "none":
text = np.atleast_1d(text).astype(str)
encoding = _check_encoding("".join(text))
if encoding != "ascii":
Expand Down
2 changes: 1 addition & 1 deletion pygmt/src/x2sys_cross.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def x2sys_cross(
match data_kind(track):
case "file":
file_contexts.append(contextlib.nullcontext(track))
case "matrix":
case "vectors":
# find suffix (-E) of trackfiles used (e.g. xyz, csv, etc) from
# $X2SYS_HOME/TAGNAME/TAGNAME.tag file
tagfile = Path(
Expand Down
4 changes: 2 additions & 2 deletions pygmt/tests/test_grdtrack.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def test_grdtrack_profile(dataarray):

def test_grdtrack_wrong_kind_of_points_input(dataarray, dataframe):
"""
Run grdtrack using points input that is not a pandas.DataFrame (matrix) or file.
Run grdtrack using points input that is not a pandas.DataFrame or file.
"""
invalid_points = dataframe.longitude.to_xarray()

Expand All @@ -141,7 +141,7 @@ def test_grdtrack_wrong_kind_of_grid_input(dataarray, dataframe):
"""
invalid_grid = dataarray.to_dataset()

assert data_kind(invalid_grid) == "matrix"
assert data_kind(invalid_grid) == "vectors"
with pytest.raises(GMTInvalidInput):
grdtrack(points=dataframe, grid=invalid_grid, newcolname="bathymetry")

Expand Down
4 changes: 2 additions & 2 deletions pygmt/tests/test_grdview.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_grdview_wrong_kind_of_grid(xrgrid):
Run grdview using grid input that is not an xarray.DataArray or file.
"""
dataset = xrgrid.to_dataset() # convert xarray.DataArray to xarray.Dataset
assert data_kind(dataset) == "matrix"
assert data_kind(dataset) == "vectors"

fig = Figure()
with pytest.raises(GMTInvalidInput):
Expand Down Expand Up @@ -238,7 +238,7 @@ def test_grdview_wrong_kind_of_drapegrid(xrgrid):
Run grdview using drapegrid input that is not an xarray.DataArray or file.
"""
dataset = xrgrid.to_dataset() # convert xarray.DataArray to xarray.Dataset
assert data_kind(dataset) == "matrix"
assert data_kind(dataset) == "vectors"

fig = Figure()
with pytest.raises(GMTInvalidInput):
Expand Down

0 comments on commit f160dd6

Please sign in to comment.