From 39f5ae2e897e4e0c8ac92b34de3131b5c70feada Mon Sep 17 00:00:00 2001 From: Wei Ji <23487320+weiji14@users.noreply.github.com> Date: Wed, 3 Mar 2021 16:33:10 +1300 Subject: [PATCH] Refactor grd modules to use virtualfile_from_data (#992) Remove duplicated if-elif-else code in grd2cpt, grdcontour, grdcut, grdfilter, grdimage, grdtrack, and grdview through the use of the virtualfile_from_data function. --- pygmt/clib/session.py | 2 +- pygmt/src/grd2cpt.py | 17 ++--------------- pygmt/src/grdcontour.py | 18 ++---------------- pygmt/src/grdcut.py | 13 +------------ pygmt/src/grdfilter.py | 13 +------------ pygmt/src/grdimage.py | 18 ++---------------- pygmt/src/grdtrack.py | 7 +------ pygmt/src/grdview.py | 9 +-------- 8 files changed, 11 insertions(+), 86 deletions(-) diff --git a/pygmt/clib/session.py b/pygmt/clib/session.py index 09919dc9394..64bcd55cf4e 100644 --- a/pygmt/clib/session.py +++ b/pygmt/clib/session.py @@ -1411,7 +1411,7 @@ def virtualfile_from_data(self, check_kind=None, data=None, x=None, y=None, z=No kind = data_kind(data, x, y, z) if check_kind == "raster" and kind not in ("file", "grid"): - raise GMTInvalidInput(f"Unrecognized data type: {type(data)}") + raise GMTInvalidInput(f"Unrecognized data type for grid: {type(data)}") if check_kind == "vector" and kind not in ("file", "matrix", "vectors"): raise GMTInvalidInput(f"Unrecognized data type: {type(data)}") diff --git a/pygmt/src/grd2cpt.py b/pygmt/src/grd2cpt.py index dc51d11bb44..5f5339436f7 100644 --- a/pygmt/src/grd2cpt.py +++ b/pygmt/src/grd2cpt.py @@ -4,14 +4,7 @@ from pygmt.clib import Session from pygmt.exceptions import GMTInvalidInput -from pygmt.helpers import ( - build_arg_string, - data_kind, - dummy_context, - fmt_docstring, - kwargs_to_strings, - use_alias, -) +from pygmt.helpers import build_arg_string, fmt_docstring, kwargs_to_strings, use_alias @fmt_docstring @@ -169,14 +162,8 @@ def grd2cpt(grid, **kwargs): """ if "W" in kwargs and "Ww" in kwargs: raise GMTInvalidInput("Set only categorical or cyclic to True, not both.") - kind = data_kind(grid) with Session() as lib: - if kind == "file": - file_context = dummy_context(grid) - elif kind == "grid": - file_context = lib.virtualfile_from_grid(grid) - else: - raise GMTInvalidInput(f"Unrecognized data type: {type(grid)}") + file_context = lib.virtualfile_from_data(check_kind="raster", data=grid) with file_context as infile: if "H" not in kwargs.keys(): # if no output is set arg_str = " ".join([infile, build_arg_string(kwargs)]) diff --git a/pygmt/src/grdcontour.py b/pygmt/src/grdcontour.py index bf700961fcb..15193cd0767 100644 --- a/pygmt/src/grdcontour.py +++ b/pygmt/src/grdcontour.py @@ -2,15 +2,7 @@ grdcontour - Plot a contour figure. """ from pygmt.clib import Session -from pygmt.exceptions import GMTInvalidInput -from pygmt.helpers import ( - build_arg_string, - data_kind, - dummy_context, - fmt_docstring, - kwargs_to_strings, - use_alias, -) +from pygmt.helpers import build_arg_string, fmt_docstring, kwargs_to_strings, use_alias @fmt_docstring @@ -103,14 +95,8 @@ def grdcontour(self, grid, **kwargs): {t} """ kwargs = self._preprocess(**kwargs) # pylint: disable=protected-access - kind = data_kind(grid, None, None) with Session() as lib: - if kind == "file": - file_context = dummy_context(grid) - elif kind == "grid": - file_context = lib.virtualfile_from_grid(grid) - else: - raise GMTInvalidInput("Unrecognized data type: {}".format(type(grid))) + file_context = lib.virtualfile_from_data(check_kind="raster", data=grid) with file_context as fname: arg_str = " ".join([fname, build_arg_string(kwargs)]) lib.call_module("grdcontour", arg_str) diff --git a/pygmt/src/grdcut.py b/pygmt/src/grdcut.py index 5c33abbdb3b..0140898b90a 100644 --- a/pygmt/src/grdcut.py +++ b/pygmt/src/grdcut.py @@ -4,12 +4,9 @@ import xarray as xr from pygmt.clib import Session -from pygmt.exceptions import GMTInvalidInput from pygmt.helpers import ( GMTTempFile, build_arg_string, - data_kind, - dummy_context, fmt_docstring, kwargs_to_strings, use_alias, @@ -89,17 +86,9 @@ def grdcut(grid, **kwargs): - None if ``outgrid`` is set (grid output will be stored in file set by ``outgrid``) """ - kind = data_kind(grid) - with GMTTempFile(suffix=".nc") as tmpfile: with Session() as lib: - if kind == "file": - file_context = dummy_context(grid) - elif kind == "grid": - file_context = lib.virtualfile_from_grid(grid) - else: - raise GMTInvalidInput("Unrecognized data type: {}".format(type(grid))) - + file_context = lib.virtualfile_from_data(check_kind="raster", data=grid) with file_context as infile: if "G" not in kwargs.keys(): # if outgrid is unset, output to tempfile kwargs.update({"G": tmpfile.name}) diff --git a/pygmt/src/grdfilter.py b/pygmt/src/grdfilter.py index 55de5cbb020..ce1e3b23d5e 100644 --- a/pygmt/src/grdfilter.py +++ b/pygmt/src/grdfilter.py @@ -4,12 +4,9 @@ import xarray as xr from pygmt.clib import Session -from pygmt.exceptions import GMTInvalidInput from pygmt.helpers import ( GMTTempFile, build_arg_string, - data_kind, - dummy_context, fmt_docstring, kwargs_to_strings, use_alias, @@ -143,17 +140,9 @@ def grdfilter(grid, **kwargs): >>> grid = pygmt.datasets.load_earth_relief() >>> smooth_field = pygmt.grdfilter(grid=grid, filter="g600", distance="4") """ - kind = data_kind(grid) - with GMTTempFile(suffix=".nc") as tmpfile: with Session() as lib: - if kind == "file": - file_context = dummy_context(grid) - elif kind == "grid": - file_context = lib.virtualfile_from_grid(grid) - else: - raise GMTInvalidInput("Unrecognized data type: {}".format(type(grid))) - + file_context = lib.virtualfile_from_data(check_kind="raster", data=grid) with file_context as infile: if "G" not in kwargs.keys(): # if outgrid is unset, output to tempfile kwargs.update({"G": tmpfile.name}) diff --git a/pygmt/src/grdimage.py b/pygmt/src/grdimage.py index 8922fbaad3b..1fbe26160c6 100644 --- a/pygmt/src/grdimage.py +++ b/pygmt/src/grdimage.py @@ -2,15 +2,7 @@ grdimage - Plot grids or images. """ from pygmt.clib import Session -from pygmt.exceptions import GMTInvalidInput -from pygmt.helpers import ( - build_arg_string, - data_kind, - dummy_context, - fmt_docstring, - kwargs_to_strings, - use_alias, -) +from pygmt.helpers import build_arg_string, fmt_docstring, kwargs_to_strings, use_alias @fmt_docstring @@ -157,14 +149,8 @@ def grdimage(self, grid, **kwargs): {x} """ kwargs = self._preprocess(**kwargs) # pylint: disable=protected-access - kind = data_kind(grid, None, None) with Session() as lib: - if kind == "file": - file_context = dummy_context(grid) - elif kind == "grid": - file_context = lib.virtualfile_from_grid(grid) - else: - raise GMTInvalidInput("Unrecognized data type: {}".format(type(grid))) + file_context = lib.virtualfile_from_data(check_kind="raster", data=grid) with file_context as fname: arg_str = " ".join([fname, build_arg_string(kwargs)]) lib.call_module("grdimage", arg_str) diff --git a/pygmt/src/grdtrack.py b/pygmt/src/grdtrack.py index 5630795e25f..ea80d5c409d 100644 --- a/pygmt/src/grdtrack.py +++ b/pygmt/src/grdtrack.py @@ -84,12 +84,7 @@ def grdtrack(points, grid, newcolname=None, outfile=None, **kwargs): raise GMTInvalidInput(f"Unrecognized data type {type(points)}") # Store the xarray.DataArray grid in virtualfile - if data_kind(grid) == "grid": - grid_context = lib.virtualfile_from_grid(grid) - elif data_kind(grid) == "file": - grid_context = dummy_context(grid) - else: - raise GMTInvalidInput(f"Unrecognized data type {type(grid)}") + grid_context = lib.virtualfile_from_data(check_kind="raster", data=grid) # Run grdtrack on the temporary (csv) points table # and (netcdf) grid virtualfile diff --git a/pygmt/src/grdview.py b/pygmt/src/grdview.py index 15c9fc3bfc1..bf176738fe0 100644 --- a/pygmt/src/grdview.py +++ b/pygmt/src/grdview.py @@ -8,7 +8,6 @@ from pygmt.helpers import ( build_arg_string, data_kind, - dummy_context, fmt_docstring, kwargs_to_strings, use_alias, @@ -112,14 +111,8 @@ def grdview(self, grid, **kwargs): {t} """ kwargs = self._preprocess(**kwargs) # pylint: disable=protected-access - kind = data_kind(grid, None, None) with Session() as lib: - if kind == "file": - file_context = dummy_context(grid) - elif kind == "grid": - file_context = lib.virtualfile_from_grid(grid) - else: - raise GMTInvalidInput(f"Unrecognized data type for grid: {type(grid)}") + file_context = lib.virtualfile_from_data(check_kind="raster", data=grid) with contextlib.ExitStack() as stack: if "G" in kwargs: # deal with kwargs["G"] if drapegrid is xr.DataArray