diff --git a/pygmt/base_plotting.py b/pygmt/base_plotting.py index 811516c1335..b310ea60ef1 100644 --- a/pygmt/base_plotting.py +++ b/pygmt/base_plotting.py @@ -4,19 +4,21 @@ """ import contextlib import csv + import numpy as np import pandas as pd from .clib import Session from .exceptions import GMTError, GMTInvalidInput from .helpers import ( + GMTTempFile, build_arg_string, - dummy_context, data_kind, + dummy_context, fmt_docstring, - GMTTempFile, - use_alias, kwargs_to_strings, + tempfile_from_buffer, + use_alias, ) @@ -801,10 +803,11 @@ def legend(self, spec=None, position="JTR+jTR+o0.2c", box="+gwhite+p1p", **kwarg Parameters ---------- - spec : None or str - Either None (default) for using the automatically generated legend - specification file, or a filename pointing to the legend - specification file. + spec : None or str or io.StringIO + Set to None (default) for using the automatically generated legend + specification file. Alternatively, pass in a filename or an + io.StringIO in-memory stream buffer pointing to the legend + specification text. {J} {R} position : str @@ -829,13 +832,17 @@ def legend(self, spec=None, position="JTR+jTR+o0.2c", box="+gwhite+p1p", **kwarg with Session() as lib: if spec is None: - specfile = "" + file_context = dummy_context("") elif data_kind(spec) == "file": - specfile = spec + file_context = dummy_context(spec) + elif data_kind(spec) == "buffer": + file_context = tempfile_from_buffer(spec) else: - raise GMTInvalidInput("Unrecognized data type: {}".format(type(spec))) - arg_str = " ".join([specfile, build_arg_string(kwargs)]) - lib.call_module("legend", arg_str) + raise GMTInvalidInput(f"Unrecognized data type: {type(spec)}") + + with file_context as fname: + arg_str = " ".join([fname, build_arg_string(kwargs)]) + lib.call_module("legend", arg_str) @fmt_docstring @use_alias( diff --git a/pygmt/helpers/utils.py b/pygmt/helpers/utils.py index 60d33f957e1..7e24097871e 100644 --- a/pygmt/helpers/utils.py +++ b/pygmt/helpers/utils.py @@ -1,9 +1,10 @@ """ Utilities and common tasks for wrapping the GMT modules. """ -import sys +import io import shutil import subprocess +import sys import webbrowser from collections.abc import Iterable from contextlib import contextmanager @@ -62,6 +63,8 @@ def data_kind(data, x=None, y=None, z=None): if isinstance(data, str): kind = "file" + elif isinstance(data, io.StringIO): + kind = "buffer" elif isinstance(data, xr.DataArray): kind = "grid" elif data is not None: diff --git a/pygmt/tests/test_legend.py b/pygmt/tests/test_legend.py index 1fa98d6733a..f7008a592b8 100644 --- a/pygmt/tests/test_legend.py +++ b/pygmt/tests/test_legend.py @@ -1,6 +1,8 @@ """ Tests for legend """ +import io + import pytest from .. import Figure @@ -44,9 +46,6 @@ def test_legend_default_position(): return fig -@pytest.mark.xfail( - reason="Baseline image not updated to use earth relief grid in GMT 6.1.0", -) @pytest.mark.mpl_image_compare def test_legend_entries(): """ @@ -73,7 +72,8 @@ def test_legend_entries(): @pytest.mark.mpl_image_compare -def test_legend_specfile(): +@pytest.mark.parametrize("usebuffer", [True, False]) +def test_legend_specfile(usebuffer): """ Test specfile functionality. """ @@ -113,7 +113,10 @@ def test_legend_specfile(): fig = Figure() fig.basemap(projection="x6i", region=[0, 1, 0, 1], frame=True) - fig.legend(specfile.name, position="JTM+jCM+w5i") + + spec = io.StringIO(specfile_contents) if usebuffer else specfile.name + + fig.legend(spec=spec, position="JTM+jCM+w5i") return fig