Skip to content

Commit

Permalink
Refactor grd modules to use virtualfile_from_data (#992)
Browse files Browse the repository at this point in the history
Remove duplicated if-elif-else code in grd2cpt, grdcontour,
grdcut, grdfilter, grdimage, grdtrack, and grdview through
the use of the virtualfile_from_data function.
  • Loading branch information
weiji14 authored Mar 3, 2021
1 parent 6d57b64 commit 10482f3
Show file tree
Hide file tree
Showing 8 changed files with 11 additions and 86 deletions.
2 changes: 1 addition & 1 deletion pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1411,7 +1411,7 @@ def virtualfile_from_data(self, check_kind=None, data=None, x=None, y=None, z=No
kind = data_kind(data, x, y, z)

if check_kind == "raster" and kind not in ("file", "grid"):
raise GMTInvalidInput(f"Unrecognized data type: {type(data)}")
raise GMTInvalidInput(f"Unrecognized data type for grid: {type(data)}")
if check_kind == "vector" and kind not in ("file", "matrix", "vectors"):
raise GMTInvalidInput(f"Unrecognized data type: {type(data)}")

Expand Down
17 changes: 2 additions & 15 deletions pygmt/src/grd2cpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,7 @@

from pygmt.clib import Session
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import (
build_arg_string,
data_kind,
dummy_context,
fmt_docstring,
kwargs_to_strings,
use_alias,
)
from pygmt.helpers import build_arg_string, fmt_docstring, kwargs_to_strings, use_alias


@fmt_docstring
Expand Down Expand Up @@ -169,14 +162,8 @@ def grd2cpt(grid, **kwargs):
"""
if "W" in kwargs and "Ww" in kwargs:
raise GMTInvalidInput("Set only categorical or cyclic to True, not both.")
kind = data_kind(grid)
with Session() as lib:
if kind == "file":
file_context = dummy_context(grid)
elif kind == "grid":
file_context = lib.virtualfile_from_grid(grid)
else:
raise GMTInvalidInput(f"Unrecognized data type: {type(grid)}")
file_context = lib.virtualfile_from_data(check_kind="raster", data=grid)
with file_context as infile:
if "H" not in kwargs.keys(): # if no output is set
arg_str = " ".join([infile, build_arg_string(kwargs)])
Expand Down
18 changes: 2 additions & 16 deletions pygmt/src/grdcontour.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,7 @@
grdcontour - Plot a contour figure.
"""
from pygmt.clib import Session
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import (
build_arg_string,
data_kind,
dummy_context,
fmt_docstring,
kwargs_to_strings,
use_alias,
)
from pygmt.helpers import build_arg_string, fmt_docstring, kwargs_to_strings, use_alias


@fmt_docstring
Expand Down Expand Up @@ -103,14 +95,8 @@ def grdcontour(self, grid, **kwargs):
{t}
"""
kwargs = self._preprocess(**kwargs) # pylint: disable=protected-access
kind = data_kind(grid, None, None)
with Session() as lib:
if kind == "file":
file_context = dummy_context(grid)
elif kind == "grid":
file_context = lib.virtualfile_from_grid(grid)
else:
raise GMTInvalidInput("Unrecognized data type: {}".format(type(grid)))
file_context = lib.virtualfile_from_data(check_kind="raster", data=grid)
with file_context as fname:
arg_str = " ".join([fname, build_arg_string(kwargs)])
lib.call_module("grdcontour", arg_str)
13 changes: 1 addition & 12 deletions pygmt/src/grdcut.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,9 @@

import xarray as xr
from pygmt.clib import Session
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import (
GMTTempFile,
build_arg_string,
data_kind,
dummy_context,
fmt_docstring,
kwargs_to_strings,
use_alias,
Expand Down Expand Up @@ -89,17 +86,9 @@ def grdcut(grid, **kwargs):
- None if ``outgrid`` is set (grid output will be stored in file set by
``outgrid``)
"""
kind = data_kind(grid)

with GMTTempFile(suffix=".nc") as tmpfile:
with Session() as lib:
if kind == "file":
file_context = dummy_context(grid)
elif kind == "grid":
file_context = lib.virtualfile_from_grid(grid)
else:
raise GMTInvalidInput("Unrecognized data type: {}".format(type(grid)))

file_context = lib.virtualfile_from_data(check_kind="raster", data=grid)
with file_context as infile:
if "G" not in kwargs.keys(): # if outgrid is unset, output to tempfile
kwargs.update({"G": tmpfile.name})
Expand Down
13 changes: 1 addition & 12 deletions pygmt/src/grdfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,9 @@

import xarray as xr
from pygmt.clib import Session
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import (
GMTTempFile,
build_arg_string,
data_kind,
dummy_context,
fmt_docstring,
kwargs_to_strings,
use_alias,
Expand Down Expand Up @@ -143,17 +140,9 @@ def grdfilter(grid, **kwargs):
>>> grid = pygmt.datasets.load_earth_relief()
>>> smooth_field = pygmt.grdfilter(grid=grid, filter="g600", distance="4")
"""
kind = data_kind(grid)

with GMTTempFile(suffix=".nc") as tmpfile:
with Session() as lib:
if kind == "file":
file_context = dummy_context(grid)
elif kind == "grid":
file_context = lib.virtualfile_from_grid(grid)
else:
raise GMTInvalidInput("Unrecognized data type: {}".format(type(grid)))

file_context = lib.virtualfile_from_data(check_kind="raster", data=grid)
with file_context as infile:
if "G" not in kwargs.keys(): # if outgrid is unset, output to tempfile
kwargs.update({"G": tmpfile.name})
Expand Down
18 changes: 2 additions & 16 deletions pygmt/src/grdimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,7 @@
grdimage - Plot grids or images.
"""
from pygmt.clib import Session
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import (
build_arg_string,
data_kind,
dummy_context,
fmt_docstring,
kwargs_to_strings,
use_alias,
)
from pygmt.helpers import build_arg_string, fmt_docstring, kwargs_to_strings, use_alias


@fmt_docstring
Expand Down Expand Up @@ -157,14 +149,8 @@ def grdimage(self, grid, **kwargs):
{x}
"""
kwargs = self._preprocess(**kwargs) # pylint: disable=protected-access
kind = data_kind(grid, None, None)
with Session() as lib:
if kind == "file":
file_context = dummy_context(grid)
elif kind == "grid":
file_context = lib.virtualfile_from_grid(grid)
else:
raise GMTInvalidInput("Unrecognized data type: {}".format(type(grid)))
file_context = lib.virtualfile_from_data(check_kind="raster", data=grid)
with file_context as fname:
arg_str = " ".join([fname, build_arg_string(kwargs)])
lib.call_module("grdimage", arg_str)
7 changes: 1 addition & 6 deletions pygmt/src/grdtrack.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,7 @@ def grdtrack(points, grid, newcolname=None, outfile=None, **kwargs):
raise GMTInvalidInput(f"Unrecognized data type {type(points)}")

# Store the xarray.DataArray grid in virtualfile
if data_kind(grid) == "grid":
grid_context = lib.virtualfile_from_grid(grid)
elif data_kind(grid) == "file":
grid_context = dummy_context(grid)
else:
raise GMTInvalidInput(f"Unrecognized data type {type(grid)}")
grid_context = lib.virtualfile_from_data(check_kind="raster", data=grid)

# Run grdtrack on the temporary (csv) points table
# and (netcdf) grid virtualfile
Expand Down
9 changes: 1 addition & 8 deletions pygmt/src/grdview.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from pygmt.helpers import (
build_arg_string,
data_kind,
dummy_context,
fmt_docstring,
kwargs_to_strings,
use_alias,
Expand Down Expand Up @@ -112,14 +111,8 @@ def grdview(self, grid, **kwargs):
{t}
"""
kwargs = self._preprocess(**kwargs) # pylint: disable=protected-access
kind = data_kind(grid, None, None)
with Session() as lib:
if kind == "file":
file_context = dummy_context(grid)
elif kind == "grid":
file_context = lib.virtualfile_from_grid(grid)
else:
raise GMTInvalidInput(f"Unrecognized data type for grid: {type(grid)}")
file_context = lib.virtualfile_from_data(check_kind="raster", data=grid)

with contextlib.ExitStack() as stack:
if "G" in kwargs: # deal with kwargs["G"] if drapegrid is xr.DataArray
Expand Down

0 comments on commit 10482f3

Please sign in to comment.