diff --git a/.coveragerc b/.coveragerc index 905f7faf66..610e4237ff 100644 --- a/.coveragerc +++ b/.coveragerc @@ -5,4 +5,9 @@ omit = seaborn/colors/* seaborn/cm.py seaborn/conftest.py - seaborn/tests/* + +[report] +exclude_lines = + pragma: no cover + if TYPE_CHECKING: + raise NotImplementedError diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index ea6f006695..8b644c1345 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -2,7 +2,7 @@ name: CI on: push: - branches: master + branches: [master, nextgen/**] pull_request: branches: master workflow_dispatch: @@ -78,7 +78,7 @@ jobs: - name: Install seaborn run: | - pip install --upgrade pip + pip install --upgrade pip wheel if [[ ${{matrix.install}} == 'all' ]]; then EXTRAS='[all]'; fi if [[ ${{matrix.deps }} == 'pinned' ]]; then DEPS='-r ci/deps_pinned.txt'; fi pip install .$EXTRAS $DEPS -r ci/utils.txt @@ -96,3 +96,24 @@ jobs: - name: Upload coverage uses: codecov/codecov-action@v2 if: ${{ success() }} + + lint: + runs-on: ubuntu-latest + strategy: + fail-fast: false + steps: + + - name: Checkout + uses: actions/checkout@v2 + + - name: Setup Python + uses: actions/setup-python@v2 + + - name: Install tools + run: pip install mypy flake8 + + - name: Flake8 + run: make lint + + - name: Type checking + run: make typecheck diff --git a/.gitignore b/.gitignore index 65b6ed00fc..c9e7058fe9 100644 --- a/.gitignore +++ b/.gitignore @@ -11,5 +11,6 @@ htmlcov/ .idea/ .vscode/ .pytest_cache/ -notes/ .DS_Store +notes/ +notebooks/ diff --git a/Makefile b/Makefile index 1e15125f36..af98516b20 100644 --- a/Makefile +++ b/Makefile @@ -8,3 +8,6 @@ unittests: lint: flake8 seaborn + +typecheck: + mypy --follow-imports=skip seaborn/_core seaborn/_marks seaborn/_stats diff --git a/ci/deps_pinned.txt b/ci/deps_pinned.txt index 9949d00c47..0aaa1a3503 100644 --- a/ci/deps_pinned.txt +++ b/ci/deps_pinned.txt @@ -1,5 +1,8 @@ -numpy~=1.16.0 -pandas~=0.24.0 -matplotlib~=3.0.0 -scipy~=1.2.0 -statsmodels~=0.9.0 +numpy~=1.17.0 +pandas~=0.25.0 +matplotlib~=3.1.0 +scipy~=1.3.0 +statsmodels~=0.10.0 +# Pillow added in install_requires for later matplotlibs +pillow>=6.2.0 +typing_extensions \ No newline at end of file diff --git a/ci/utils.txt b/ci/utils.txt index 99f8cc215f..98821bcf71 100644 --- a/ci/utils.txt +++ b/ci/utils.txt @@ -2,3 +2,4 @@ pytest!=5.3.4 pytest-cov pytest-xdist flake8 +mypy diff --git a/doc/nextgen/.gitignore b/doc/nextgen/.gitignore new file mode 100644 index 0000000000..7cc96f6b9f --- /dev/null +++ b/doc/nextgen/.gitignore @@ -0,0 +1,4 @@ +_static/ +api/ +demo.rst +index.rst diff --git a/doc/nextgen/Makefile b/doc/nextgen/Makefile new file mode 100644 index 0000000000..4f25b0e3f7 --- /dev/null +++ b/doc/nextgen/Makefile @@ -0,0 +1,25 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +notebooks: + ./nb_to_doc.py ./index.ipynb + ./nb_to_doc.py ./demo.ipynb + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + diff --git a/doc/nextgen/api.rst b/doc/nextgen/api.rst new file mode 100644 index 0000000000..d58d2d77c4 --- /dev/null +++ b/doc/nextgen/api.rst @@ -0,0 +1,77 @@ +.. _nextgen_api: + +.. currentmodule:: seaborn.objects + +API Reference +============= + +.. warning:: + + This is a provisional API that is under active development, incomplete, and subject to change before release. + +Plot interface +-------------- + +.. autosummary:: + :toctree: api/ + :nosignatures: + + Plot + Plot.add + Plot.scale + Plot.facet + Plot.pair + Plot.configure + Plot.on + Plot.plot + Plot.save + Plot.show + +Marks +----- + +.. autosummary:: + :toctree: api/ + :nosignatures: + + Area + Bar + Dot + Line + Ribbon + Scatter + +Stats +----- + +.. autosummary:: + :toctree: api/ + :nosignatures: + + Agg + Hist + PolyFit + +Moves +----- + +.. autosummary:: + :toctree: api/ + :nosignatures: + + Dodge + Jitter + Shift + Stack + + +Scales +------ + +.. autosummary:: + :toctree: api/ + :nosignatures: + + Nominal + Continuous + Temporal diff --git a/doc/nextgen/conf.py b/doc/nextgen/conf.py new file mode 100644 index 0000000000..25d40fb64e --- /dev/null +++ b/doc/nextgen/conf.py @@ -0,0 +1,88 @@ +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +# import os +# import sys +# sys.path.insert(0, os.path.abspath('.')) + + +# -- Project information ----------------------------------------------------- + +project = 'seaborn' +copyright = '2022, Michael Waskom' +author = 'Michael Waskom' + +# The full version, including alpha/beta/rc tags +release = 'nextgen' + + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "numpydoc", + "IPython.sphinxext.ipython_console_highlighting", +] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', '.ipynb_checkpoints'] + +# The reST default role (used for this markup: `text`) to use for all documents. +default_role = 'literal' + +autosummary_generate = True +numpydoc_show_class_members = False +autodoc_typehints = "none" + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = "pydata_sphinx_theme" + +html_theme_options = { + "show_prev_next": False, + "page_sidebar_items": [], +} + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] + +html_logo = "_static/logo.svg" + +html_sidebars = { + # "**": [], + "demo": ["page-toc"] +} + + +# -- Intersphinx ------------------------------------------------ + +intersphinx_mapping = { + 'numpy': ('https://numpy.org/doc/stable/', None), + 'scipy': ('https://docs.scipy.org/doc/scipy/reference/', None), + 'matplotlib': ('https://matplotlib.org/stable', None), + 'pandas': ('https://pandas.pydata.org/pandas-docs/stable/', None), + 'statsmodels': ('https://www.statsmodels.org/stable/', None) +} diff --git a/doc/nextgen/demo.ipynb b/doc/nextgen/demo.ipynb new file mode 100644 index 0000000000..ee519d29ba --- /dev/null +++ b/doc/nextgen/demo.ipynb @@ -0,0 +1,1064 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "662eff49-63cf-42b5-8a48-ac4145c2e3cc", + "metadata": {}, + "source": [ + "# Demonstration of next-generation seaborn interface" + ] + }, + { + "cell_type": "raw", + "id": "e7636dfe-2eff-4dc7-8f4f-325768c28cb4", + "metadata": {}, + "source": [ + ".. warning::\n", + "\n", + " This API is **experimental** and **unstable**. Please try it out and provide feedback, but expect it to change without warning prior to an official release." + ] + }, + { + "cell_type": "markdown", + "id": "fab541af", + "metadata": {}, + "source": [ + "## The basic interface\n", + "\n", + "The new interface exists as a set of classes that can be acessed through a single namespace import:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c7cc1337", + "metadata": {}, + "outputs": [], + "source": [ + "import seaborn.objects as so" + ] + }, + { + "cell_type": "markdown", + "id": "7fd68dad", + "metadata": {}, + "source": [ + "This is a clean namespace, and I'm leaning towards recommending `from seaborn.objects import *` for interactive usecases. But let's not go so far just yet.\n", + "\n", + "Let's also import the main namespace so we can load our trusty example datasets." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "de5478fd", + "metadata": {}, + "outputs": [], + "source": [ + "import seaborn\n", + "seaborn.set_theme()" + ] + }, + { + "cell_type": "markdown", + "id": "cb0b155c-6a89-4f4d-826b-bf23e513cdad", + "metadata": {}, + "source": [ + "The main object is `seaborn.objects.Plot`. You instantiate it by passing data and some assignments from columns in the data to roles in the plot:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e2c13f9c-15b1-48ce-999e-b59f9a76ae52", + "metadata": {}, + "outputs": [], + "source": [ + "tips = seaborn.load_dataset(\"tips\")\n", + "so.Plot(tips, x=\"total_bill\", y=\"tip\")" + ] + }, + { + "cell_type": "markdown", + "id": "90050ae8-98ef-43b5-a079-523f97a01877", + "metadata": {}, + "source": [ + "But instantiating the `Plot` object doesn't actually plot anything. For that you need to add some layers:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8b1a4bec-aeac-4758-af07-dfc8f4adbf9e", + "metadata": {}, + "outputs": [], + "source": [ + "so.Plot(tips, x=\"total_bill\", y=\"tip\").add(so.Scatter())" + ] + }, + { + "cell_type": "markdown", + "id": "7d9e32f9-ac92-4ef9-8f6a-777ef004424f", + "metadata": {}, + "source": [ + "Variables can be defined globally, or for a specific layer:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b78774e1-b98f-4335-897f-6d9b2c404cfa", + "metadata": {}, + "outputs": [], + "source": [ + "so.Plot(tips).add(so.Scatter(), x=\"total_bill\", y=\"tip\")" + ] + }, + { + "cell_type": "markdown", + "id": "29b96416-6bc4-480b-bc91-86a466b705c3", + "metadata": {}, + "source": [ + "Each layer can also have its own data:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ef21550d-a404-4b73-925b-3b9c8d00ec92", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, x=\"total_bill\", y=\"tip\")\n", + " .add(so.Scatter(color=\".6\"), data=tips.query(\"size != 2\"))\n", + " .add(so.Scatter(), data=tips.query(\"size == 2\"))\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "cfa61787-b6c9-4aef-8a39-533fd566fc74", + "metadata": {}, + "source": [ + "As in the existing interface, variables can be keys to the `data` object or vectors of various kinds:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "707e70c2-9751-4579-b9e9-a74d8d5ba8ad", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips.to_dict(), x=\"total_bill\")\n", + " .add(so.Scatter(), y=tips[\"tip\"].to_numpy())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "2875d1e2-f06a-4166-8fdc-57c71dc0e56a", + "metadata": {}, + "source": [ + "The interface also supports semantic mappings between data and plot variables. But the specification of those mappings uses more explicit parameter names:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1f78ad77-7708-4010-b2ae-3d7430d37e96", + "metadata": {}, + "outputs": [], + "source": [ + "so.Plot(tips, x=\"total_bill\", y=\"tip\", color=\"time\").add(so.Scatter())" + ] + }, + { + "cell_type": "markdown", + "id": "90911104-ec12-4cf1-bcdb-3991ca55f600", + "metadata": {}, + "source": [ + "It also offers a wider range of mappable features:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e56e910c-e4f6-4e13-8913-c01c97a0c296", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, x=\"total_bill\", y=\"tip\", color=\"day\", fill=\"time\")\n", + " .add(so.Scatter(fillalpha=.8))\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "a84fb373", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Core components\n", + "\n", + "### Visual representation: the Mark" + ] + }, + { + "cell_type": "markdown", + "id": "a224ebd6-720b-4645-909e-58a2a0d787d3", + "metadata": {}, + "source": [ + "Each layer needs a `Mark` object, which defines how to draw the plot. There will be marks corresponding to existing seaborn functions and ones offering new functionality. But not many have been implemented yet:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c31d7411-2a87-4e7a-baaf-5d3ef8cc5b91", + "metadata": {}, + "outputs": [], + "source": [ + "fmri = seaborn.load_dataset(\"fmri\").query(\"region == 'parietal'\")\n", + "so.Plot(fmri, x=\"timepoint\", y=\"signal\").add(so.Line())" + ] + }, + { + "cell_type": "markdown", + "id": "c973ed95-924e-47e0-960b-22fbffabae35", + "metadata": {}, + "source": [ + "`Mark` objects will expose an API to set features directly, rather than mapping them:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "df5244c8-60f2-4218-adaf-2036a9e72bc1", + "metadata": {}, + "outputs": [], + "source": [ + "so.Plot(tips, y=\"day\", x=\"total_bill\").add(so.Dot(color=\"#698\", alpha=.5))" + ] + }, + { + "cell_type": "markdown", + "id": "ae0e288e-74cf-461c-8e68-786e364032a1", + "metadata": {}, + "source": [ + "### Data transformation: the Stat\n", + "\n", + "\n", + "Built-in statistical transformations are one of seaborn's key features. But currently, they are tied up with the different visual representations. E.g., you can aggregate data in `lineplot`, but not in `scatterplot`.\n", + "\n", + "In the new interface, these concerns are separated. Each layer can accept a `Stat` object that applies a data transformation:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9edb53ec-7146-43c6-870a-eff46ea282ba", + "metadata": {}, + "outputs": [], + "source": [ + "so.Plot(fmri, x=\"timepoint\", y=\"signal\").add(so.Line(), so.Agg())" + ] + }, + { + "cell_type": "markdown", + "id": "1788d935-5ad5-4262-993f-8d48c66631b9", + "metadata": {}, + "source": [ + "The `Stat` is computed on subsets of data defined by the semantic mappings:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "08fe699f-c6ce-4508-9746-efe1504e67b3", + "metadata": {}, + "outputs": [], + "source": [ + "so.Plot(fmri, x=\"timepoint\", y=\"signal\", color=\"event\").add(so.Line(), so.Agg())" + ] + }, + { + "cell_type": "markdown", + "id": "08e0155f-e290-4378-9f2c-f818993cd8e2", + "metadata": {}, + "source": [ + "Each mark also accepts a `group` mapping that creates subsets without altering visual properties:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8c6c94d2-81c5-42d7-9a53-885547a92bae", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(fmri, x=\"timepoint\", y=\"signal\", color=\"event\")\n", + " .add(so.Line(), so.Agg(), group=\"subject\")\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "aa9409ac-8200-4a4d-8f60-8bee612cd6c0", + "metadata": {}, + "source": [ + "The `Mark` and `Stat` objects allow for more compositionality and customization. There will be guidelines for how to define your own objects to plug into the broader system:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7edd619c-baf4-4acc-99f1-ebe5a9475555", + "metadata": {}, + "outputs": [], + "source": [ + "class PeakAnnotation(so.Mark):\n", + " def plot(self, split_generator, scales, orient):\n", + " for keys, data, ax in split_generator():\n", + " ix = data[\"y\"].idxmax()\n", + " ax.annotate(\n", + " \"The peak\", data.loc[ix, [\"x\", \"y\"]],\n", + " xytext=(10, -100), textcoords=\"offset points\",\n", + " va=\"top\", ha=\"center\",\n", + " arrowprops=dict(arrowstyle=\"->\", color=\".2\"),\n", + "\n", + " )\n", + "\n", + "(\n", + " so.Plot(fmri, x=\"timepoint\", y=\"signal\")\n", + " .add(so.Line(), so.Agg())\n", + " .add(PeakAnnotation(), so.Agg())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "28ac1b3b-c83b-4e06-8ea5-7ba73b6f2498", + "metadata": {}, + "source": [ + "The new interface understands not just `x` and `y`, but also range specifiers; some `Stat` objects will output ranges, and some `Mark` objects will accept them. (This means that it will finally be possible to pass pre-defined error-bars into seaborn):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eb9d0026-01a8-4ac7-a9fb-178144f063d2", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " fmri\n", + " .groupby(\"timepoint\")\n", + " .signal\n", + " .describe()\n", + " .pipe(so.Plot, x=\"timepoint\")\n", + " .add(so.Line(), y=\"mean\")\n", + " .add(so.Ribbon(alpha=.2), ymin=\"min\", ymax=\"max\")\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "6c2dbb64-9569-4e93-9968-532d9d5cbaf1", + "metadata": {}, + "source": [ + "-----\n", + "\n", + "### Overplotting resolution: the Move\n", + "\n", + "Existing seaborn functions have parameters that allow adjustments for overplotting, such as `dodge=` in several categorical functions, `jitter=` in several functions based on scatterplots, and the `multiple=` paramter in distribution functions. In the new interface, those adjustments are abstracted away from the particular visual representation into the concept of a `Move`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9cbd874f-cd3d-4cc2-b029-dddf40dc3965", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, \"day\", \"total_bill\", color=\"time\")\n", + " .add(so.Bar(), so.Agg(), move=so.Dodge())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "a0524b93-56d8-4695-b3c3-164989c3bf51", + "metadata": {}, + "source": [ + "Separating out the positional adjustment makes it possible to add additional flexibility without overwhelming the signature of a single function. For example, there will be more options for handling missing levels when dodging and for fine-tuning the adjustment." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "40916811-440a-49f9-8ae5-601472652a96", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, \"day\", \"total_bill\", color=\"time\")\n", + " .add(so.Bar(), so.Agg(), move=so.Dodge(empty=\"fill\", gap=.1))\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "d3fc22b3-01b0-427f-8ffe-8065daf757c9", + "metadata": {}, + "source": [ + "By default, the `move` will resolve all overlapping semantic mappings:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7e73fb57-450a-4c1d-8e3c-642dd0f032a3", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, \"day\", \"total_bill\", color=\"time\", alpha=\"sex\")\n", + " .add(so.Bar(), so.Agg(), move=so.Dodge())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "0815cf5f-cc23-4104-b50e-589d6d675c51", + "metadata": {}, + "source": [ + "But you can specify a subset:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "68ec1247-4218-41e0-a5bb-2f76bc778ae0", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, \"day\", \"total_bill\", color=\"time\", alpha=\"smoker\")\n", + " .add(so.Dot(), move=so.Dodge(by=[\"color\"]))\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "c001004a-6771-46eb-b231-6accf88fe330", + "metadata": {}, + "source": [ + "It's also possible to stack multiple moves or kinds of moves by passing a list:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "82421309-65f4-44cf-b0dd-5fcde629d784", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, \"day\", \"total_bill\", color=\"time\", alpha=\"smoker\")\n", + " .add(\n", + " so.Dot(),\n", + " move=[so.Dodge(by=[\"color\"]), so.Jitter(.5)]\n", + " )\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "988f245a", + "metadata": {}, + "source": [ + "Separating the `Stat` and `Move` from the visual representation affords more flexibility, greatly expanding the space of graphics that can be created." + ] + }, + { + "cell_type": "markdown", + "id": "937d0e51-95b3-4997-8ca3-a63a09894a6b", + "metadata": { + "tags": [] + }, + "source": [ + "-----\n", + "\n", + "### Semantic mapping: the Scale\n", + "\n", + "The declarative interface allows users to represent dataset variables with visual properites such as position, color or size. A complete plot can be made without doing anything more defining the mappings: users need not be concerned with converting their data into units that matplotlib understands. But what if one wants to alter the mapping that seaborn chooses? This is accomplished through the concept of a `Scale`.\n", + "\n", + "The notion of scaling will probably not be unfamiliar; as in matplotlib, seaborn allows one to apply a mathematical transformation, such as `log`, to the coordinate variables:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "129d44e9-69b5-44e8-9b86-65074455913c", + "metadata": {}, + "outputs": [], + "source": [ + "planets = seaborn.load_dataset(\"planets\").query(\"distance < 1000\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ec1cbc42-5bdd-4287-8167-41f847e988c3", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(planets, x=\"mass\", y=\"distance\")\n", + " .scale(x=\"log\", y=\"log\")\n", + " .add(so.Scatter())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "a43e28d7-99e1-4e17-aa20-d4f3bb8bc86e", + "metadata": {}, + "source": [ + "But the `Scale` concept is much more general in seaborn: a scale can be provided for any mappable property. For example, it is how you specify the palette used for color variables:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4dbdd051-df47-4508-a67b-29517c7c0831", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(planets, x=\"mass\", y=\"distance\", color=\"orbital_period\")\n", + " .scale(x=\"log\", y=\"log\", color=\"rocket\")\n", + " .add(so.Scatter())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "bbb34aca-47df-4029-8a83-994a46d04c65", + "metadata": {}, + "source": [ + "While there are a number of short-hand \"magic\" arguments you can provide for each scale, it is also possible to be more explicit by passing a `Scale` object. There are several distinct `Scale` classes, corresponding to the fundamental scale types (nominal, ordinal, continuous, etc.). Each class exposes a number of relevant parameters that control the details of the mapping:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ec8c0c03-1757-48de-9a71-bef16488296a", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(planets, x=\"mass\", y=\"distance\", color=\"orbital_period\")\n", + " .scale(\n", + " x=\"log\",\n", + " y=so.Continuous(transform=\"log\").tick(at=[3, 10, 30, 100, 300]),\n", + " color=so.Continuous(\"rocket\", transform=\"log\"),\n", + " )\n", + " .add(so.Scatter())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "81565db5-8791-4f6c-bc49-59673081686c", + "metadata": {}, + "source": [ + "There are several different kinds of scales, including scales appropriate for categorical data:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "77b9ca9a-f2f7-48c3-913e-72a70ad1d21e", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(planets, x=\"year\", y=\"distance\", color=\"method\")\n", + " .scale(\n", + " y=\"log\",\n", + " color=so.Nominal([\"b\", \"g\"], order=[\"Radial Velocity\", \"Transit\"])\n", + " )\n", + " .add(so.Scatter())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "9e7c9211-70fe-4f63-9951-7b9af68627a1", + "metadata": {}, + "source": [ + "It's also possible to disable scaling for a variable so that the literal values in the dataset are passed directly through to matplotlib:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc009a51-a725-4bdd-85c9-7b97bc86d96b", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(planets, x=\"distance\", y=\"orbital_period\", pointsize=\"mass\")\n", + " .scale(x=\"log\", y=\"log\", pointsize=None)\n", + " .add(so.Scatter())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "ca5430c5-8690-490a-80fb-698f264a7b6a", + "metadata": {}, + "source": [ + "Scaling interacts with the `Stat` and `Move` transformations. When an axis has a nonlinear scale, any statistical transformations or adjustments take place in the appropriate space:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e657b9f8-0dab-48e8-b074-995097f0e41c", + "metadata": {}, + "outputs": [], + "source": [ + "so.Plot(planets, x=\"distance\").add(so.Bar(), so.Hist()).scale(x=\"log\")" + ] + }, + { + "cell_type": "markdown", + "id": "64de6841-07e1-4fa5-9b88-6a8984db59a0", + "metadata": {}, + "source": [ + "This is also true of the `Move` transformations:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e7ab3109-db3c-4bb6-aa3b-629a8c054ba5", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(\n", + " planets, x=\"distance\",\n", + " color=(planets[\"number\"] > 1).rename(\"multiple\")\n", + " )\n", + " .add(so.Bar(), so.Hist(), so.Dodge())\n", + " .scale(x=\"log\", color=so.Nominal())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "5041491d-b47f-4fb3-af93-7c9490d6b901", + "metadata": {}, + "source": [ + "----\n", + "\n", + "## Defining subplot structure" + ] + }, + { + "cell_type": "markdown", + "id": "92c1a0fd-873f-476b-9e88-d6a2c4f49807", + "metadata": {}, + "source": [ + "Seaborn's faceting functionality (drawing subsets of the data on distinct subplots) is built into the `Plot` object and works interchangably with any `Mark`/`Stat`/`Move`/`Scale` spec:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6cfc9ea6-b5d2-4fc3-9a59-62a09668944a", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, x=\"total_bill\", y=\"tip\")\n", + " .facet(\"time\", order=[\"Dinner\", \"Lunch\"])\n", + " .add(so.Scatter())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "fc429604-d719-44b0-b504-edeaca481583", + "metadata": {}, + "source": [ + "Unlike the existing `FacetGrid` it is simple to *not* facet a layer, so that a plot is simply replicated across each column (or row):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "101e7d02-17b1-44b4-9f0c-6d7c4e194f76", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, x=\"total_bill\", y=\"tip\")\n", + " .facet(col=\"day\")\n", + " .add(so.Scatter(color=\".75\"), col=None)\n", + " .add(so.Scatter(), color=\"day\")\n", + " .configure(figsize=(7, 3))\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "befb9400-f252-49fd-aee6-00a1b371c645", + "metadata": {}, + "source": [ + "The `Plot` object *also* subsumes the `PairGrid` functionality:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "06a63c71-3043-49b8-81c6-a8d7c8025015", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, y=\"day\")\n", + " .pair(x=[\"total_bill\", \"tip\"])\n", + " .add(so.Dot())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "f0f2f885-2e87-41a7-bf21-877c05306067", + "metadata": {}, + "source": [ + "Pairing and faceting can be combined in the same plot:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c0108128-635e-4f92-8621-65627b95b6ea", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, x=\"day\")\n", + " .facet(\"sex\")\n", + " .pair(y=[\"total_bill\", \"tip\"])\n", + " .add(so.Dot())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "f0933fcf-8f11-470c-b5c1-c3c2a1a1c2a1", + "metadata": {}, + "source": [ + "Or the `Plot.pair` functionality can be used to define unique pairings between variables:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3c2d4955-0f85-4318-8cac-7d8d33678bda", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips)\n", + " .pair(x=[\"day\", \"time\"], y=[\"total_bill\", \"tip\"], cross=False)\n", + " .add(so.Dot())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "be694009-ec20-4cdc-8be0-0b2e5a6839a1", + "metadata": {}, + "source": [ + "It's additionally possible to \"pair\" with a single variable, for univariate plots like histograms.\n", + "\n", + "Both faceted and paired plots with subplots along a single dimension can be \"wrapped\", and this works both columwise and rowwise:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c25cfa26-5c90-4699-8deb-9aa6ff41eae6", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips)\n", + " .pair(x=tips.columns, wrap=3)\n", + " .configure(sharey=False)\n", + " .add(so.Bar(), so.Hist())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "862d7901", + "metadata": {}, + "source": [ + "Importantly, there's no distinction between \"axes-level\" and \"figure-level\" here. Any kind of plot can be faceted or paired by adding a method call to the `Plot` definition, without changing anything else about how you are creating the figure." + ] + }, + { + "cell_type": "markdown", + "id": "d1eff6ab-84dd-4b32-9923-3d29fb43a209", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Iterating and displaying" + ] + }, + { + "cell_type": "markdown", + "id": "354b2395-4cad-40c0-a558-60368d5b435f", + "metadata": {}, + "source": [ + "It is possible (and in fact the deafult behavior) to be completely pyplot-free, and all the drawing is done by directly hooking into Jupyter's rich display system. Unlike in normal usage of the inline backend, writing code in a cell to define a plot is indendent from showing it:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e3171891-5e1e-4146-a940-f4327f40be3a", + "metadata": {}, + "outputs": [], + "source": [ + "p = so.Plot(fmri, x=\"timepoint\", y=\"signal\").add(so.Line(), so.Agg())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7bd9fad6-0d9a-4cc8-9523-587270a71dc0", + "metadata": {}, + "outputs": [], + "source": [ + "p" + ] + }, + { + "cell_type": "markdown", + "id": "d7157904-0fcc-4eb8-8a7a-27df91cec68b", + "metadata": {}, + "source": [ + "By default, the methods on `Plot` do *not* mutate the object they are called on. This means that you can define a common base specification and then iterate on different versions of it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bf8e1469-2dae-470f-8599-fe5d45b2b038", + "metadata": {}, + "outputs": [], + "source": [ + "p = (\n", + " so.Plot(fmri, x=\"timepoint\", y=\"signal\", color=\"event\")\n", + " .scale(color=\"crest\")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b343b0e0-698a-4453-a3b8-b780f54724c8", + "metadata": {}, + "outputs": [], + "source": [ + "p.add(so.Line())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ae17bce2-be77-44de-ada8-f546f786407d", + "metadata": {}, + "outputs": [], + "source": [ + "p.add(so.Line(), group=\"subject\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d2e89ef5-3cd3-4ec0-af83-1e69c087bbfb", + "metadata": {}, + "outputs": [], + "source": [ + "p.add(so.Line(), so.Agg())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "166d34d4-2b10-4aae-963d-9ba58f80f79d", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " p\n", + " .add(so.Line(linewidth=.5, alpha=.5), group=\"subject\")\n", + " .add(so.Line(linewidth=3), so.Agg())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "9228ee06-2a6c-41cb-95cf-7bb217a421e0", + "metadata": {}, + "source": [ + "It's also possible to hook into the `pyplot` system by calling `Plot.show`. (As you might in a terminal interface, or to use a GUI). Notice how this looks lower-res: that's because `Plot` is generating \"high-DPI\" figures internally!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c8055ab9-22c6-40cd-98e6-926a100cd173", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " p\n", + " .add(so.Line(linewidth=.5, alpha=.5), group=\"subject\")\n", + " .add(so.Line(linewidth=3), so.Agg())\n", + " .show()\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "278e7ad4-a8e6-4cb7-ac61-9f2530ade898", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Matplotlib integration\n", + "\n", + "It's always been a design aim in seaborn to allow complicated seaborn plots to coexist within the context of a larger matplotlib figure. This is acheived within the \"axes-level\" functions, which accept an `ax=` parameter. The `Plot` object *will* provide a similar functionality:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0701b67e-f037-4cfd-b3f6-304dfb47a13c", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib as mpl\n", + "_, ax = mpl.figure.Figure(constrained_layout=True).subplots(1, 2)\n", + "(\n", + " so.Plot(tips, x=\"total_bill\", y=\"tip\")\n", + " .on(ax)\n", + " .add(so.Scatter())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "432144e8-e490-4213-8cc4-afdeeb467daa", + "metadata": {}, + "source": [ + "But a limitation has been that the \"figure-level\" functions, which can produce multiple subplots, cannot be directed towards an existing figure. That is no longer the case; `Plot.on()` also accepts a `Figure` (created either with or without `pyplot`) object:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d7c8c01e-db55-47ef-82f2-a69124bb4a94", + "metadata": {}, + "outputs": [], + "source": [ + "f = mpl.figure.Figure(constrained_layout=True)\n", + "(\n", + " so.Plot(tips, x=\"total_bill\", y=\"tip\")\n", + " .on(f)\n", + " .add(so.Scatter())\n", + " .facet(\"time\")\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "b5b621be-f8c5-4515-81dd-6c7bd0e956ad", + "metadata": {}, + "source": [ + "Providing an existing figure is perhaps only marginally useful. While it will ease the integration of seaborn with GUI frameworks, seaborn is still using up the whole figure canvas. But with the introduction of the `SubFigure` concept in matplotlib 3.4, it becomes possible to place a small-multiples plot *within* a larger set of subplots:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "192e6587-642d-45da-85bd-ac220ffd66e9", + "metadata": {}, + "outputs": [], + "source": [ + "f = mpl.figure.Figure(constrained_layout=True, figsize=(8, 4))\n", + "sf1, sf2 = f.subfigures(1, 2)\n", + "(\n", + " so.Plot(tips, x=\"total_bill\", y=\"tip\", color=\"day\")\n", + " .add(so.Scatter())\n", + " .on(sf1)\n", + " .plot()\n", + ")\n", + "(\n", + " so.Plot(tips, x=\"total_bill\", y=\"tip\", color=\"day\")\n", + " .facet(\"day\", wrap=2)\n", + " .add(so.Scatter())\n", + " .on(sf2)\n", + " .plot()\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "baff5db0", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "seaborn-py39-latest", + "language": "python", + "name": "seaborn-py39-latest" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/doc/nextgen/index.ipynb b/doc/nextgen/index.ipynb new file mode 100644 index 0000000000..3e7cdbadcb --- /dev/null +++ b/doc/nextgen/index.ipynb @@ -0,0 +1,112 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "b3b7451c-9938-4cc2-a6ee-7298548d3bfa", + "metadata": {}, + "source": [ + "# Next-generation seaborn interface\n", + "\n", + "Over the past year, I have been developing an entirely new interface for making plots with seaborn. The new interface is designed to be declarative, compositional and extensible. If successful, it will greatly expand the space of plots that can be created with seaborn while making the experience of using it simpler and more delightful.\n", + "\n", + "To make that concrete, here is a [hello world example](http://seaborn.pydata.org/introduction.html#our-first-seaborn-plot) with the new interface:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "03997ae0-313d-46d8-9a7a-9b3e13f405fd", + "metadata": {}, + "outputs": [], + "source": [ + "import seaborn as sns\n", + "sns.set_theme()\n", + "tips = sns.load_dataset(\"tips\")\n", + "\n", + "import seaborn.objects as so\n", + "(\n", + " so.Plot(\n", + " tips, \"total_bill\", \"tip\",\n", + " color=\"smoker\", marker=\"smoker\", pointsize=\"size\",\n", + " )\n", + " .facet(\"time\")\n", + " .add(so.Scatter())\n", + " .configure(figsize=(7, 4))\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "c76dbb00-20ee-4508-bca3-76a4763e5640", + "metadata": {}, + "source": [ + "## Testing the alpha release\n", + "\n", + "If you're interested, please install the alpha and kick the tires. It is very far from complete, so expect some rough edges and instability! But feedback will be very helpful in pushing this towards a more stable broad release:\n", + "\n", + " pip install https://github.com/mwaskom/seaborn/archive/refs/tags/v0.12.0a0.tar.gz\n", + "\n", + "The documentation is still a work in progress, but there's a reasonably thorough demo of the main parts, and some basic API documentation for the existing classes." + ] + }, + { + "cell_type": "raw", + "id": "dee35714-b4c9-474d-96a9-a7c1e9312f23", + "metadata": {}, + "source": [ + ".. toctree::\n", + " :maxdepth: 1\n", + "\n", + " demo\n", + " api" + ] + }, + { + "cell_type": "markdown", + "id": "ebb5eb5b-515a-4374-996e-70cb72e883d3", + "metadata": {}, + "source": [ + "## Background and goals\n", + "\n", + "This work grew out of long-running efforts to refactor the seaborn internals so that its functions could rely on common code-paths. At a certain point, I realized that I was developing an API that might also be interesting for external users.\n", + "\n", + "Of course, \"write a new interface\" quickly turned into \"rethink every aspect of the library.\" The current interface has some [pain points](https://michaelwaskom.medium.com/three-common-seaborn-difficulties-10fdd0cc2a8b) that arise from early constraints and path dependence. By starting fresh, these can be avoided.\n", + "\n", + "Originally, seaborn existed as a toolbox of domain-specific statistical graphics to be used alongside matplotlib. As the library grew, it became more common to reach for — or even learn — seaborn first. But one inevitably desires some customization that is not offered within the (already much-too-long) list of parameters in seaborn's functions. Currently, this necessitates direct use of matplotlib. I've always thought that, if you're comfortable with both libraries, this setup offers a powerful blend of convenience and flexibility. But it can be hard to know which library will let you accomplish some specific task.\n", + "\n", + "So the new interface is designed to provide a more comprehensive experience, such that all of the steps involved in the creation of a reasonably-customized plot can be accomplished in the same way. And the compositional nature of the objects provides much more flexibility than currently exists in seaborn with a similar level of abstraction: this lets you focus on *what* you want to show rather than *how* to show it.\n", + "\n", + "One will note that the result looks a bit (a lot?) like ggplot. That's not unintentional, but the goal is also *not* to \"port ggplot2 to Python\". (If that's what you're looking for, check out the very nice [plotnine](https://plotnine.readthedocs.io/en/stable/) package). There is an immense amount of wisdom in the grammar of graphics and in its particular implementation as ggplot2. But, as languages, R and Python are just too different for idioms from one to feel natural when translated literally into the other. So while I have taken much inspiration from ggplot (along with vega-lite, d3, and other great libraries), I've also made plenty of choices differently, for better or for worse." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8cdc2435-9ef5-4b89-b85c-ad4f0c55050a", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "seaborn-py39-latest", + "language": "python", + "name": "seaborn-py39-latest" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/doc/nextgen/index.rst b/doc/nextgen/index.rst new file mode 100644 index 0000000000..a494884a31 --- /dev/null +++ b/doc/nextgen/index.rst @@ -0,0 +1,104 @@ +Next-generation seaborn interface +================================= + +Over the past year, I have been developing an entirely new interface for +making plots with seaborn. The new interface is designed to be +declarative, compositional and extensible. If successful, it will +greatly expand the space of plots that can be created with seaborn while +making the experience of using it simpler and more delightful. + +To make that concrete, here is a `hello world +example `__ +with the new interface: + +.. code:: ipython3 + + import seaborn as sns + sns.set_theme() + tips = sns.load_dataset("tips") + + import seaborn.objects as so + ( + so.Plot( + tips, "total_bill", "tip", + color="smoker", marker="smoker", pointsize="size", + ) + .facet("time") + .add(so.Scatter()) + .configure(figsize=(7, 4)) + ) + + + + +.. image:: index_files/index_1_0.png + :width: 632.8249999999999px + :height: 313.22499999999997px + + + +Testing the alpha release +------------------------- + +If you’re interested, please install the alpha and kick the tires. It is +very far from complete, so expect some rough edges and instability! But +feedback will be very helpful in pushing this towards a more stable +broad release: + +:: + + pip install https://github.com/mwaskom/seaborn/archive/refs/tags/v0.12.0a0.tar.gz + +The documentation is still a work in progress, but there’s a reasonably +thorough demo of the main parts, and some basic API documentation for +the existing classes. + +.. toctree:: + :maxdepth: 1 + + demo + api + +Background and goals +-------------------- + +This work grew out of long-running efforts to refactor the seaborn +internals so that its functions could rely on common code-paths. At a +certain point, I realized that I was developing an API that might also +be interesting for external users. + +Of course, “write a new interface” quickly turned into “rethink every +aspect of the library.” The current interface has some `pain +points `__ +that arise from early constraints and path dependence. By starting +fresh, these can be avoided. + +Originally, seaborn existed as a toolbox of domain-specific statistical +graphics to be used alongside matplotlib. As the library grew, it became +more common to reach for — or even learn — seaborn first. But one +inevitably desires some customization that is not offered within the +(already much-too-long) list of parameters in seaborn’s functions. +Currently, this necessitates direct use of matplotlib. I’ve always +thought that, if you’re comfortable with both libraries, this setup +offers a powerful blend of convenience and flexibility. But it can be +hard to know which library will let you accomplish some specific task. + +So the new interface is designed to provide a more comprehensive +experience, such that all of the steps involved in the creation of a +reasonably-customized plot can be accomplished in the same way. And the +compositional nature of the objects provides much more flexibility than +currently exists in seaborn with a similar level of abstraction: this +lets you focus on *what* you want to show rather than *how* to show it. + +One will note that the result looks a bit (a lot?) like ggplot. That’s +not unintentional, but the goal is also *not* to “port ggplot2 to +Python”. (If that’s what you’re looking for, check out the very nice +`plotnine `__ package). +There is an immense amount of wisdom in the grammar of graphics and in +its particular implementation as ggplot2. But, as languages, R and +Python are just too different for idioms from one to feel natural when +translated literally into the other. So while I have taken much +inspiration from ggplot (along with vega-lite, d3, and other great +libraries), I’ve also made plenty of choices differently, for better or +for worse. + diff --git a/doc/nextgen/nb_to_doc.py b/doc/nextgen/nb_to_doc.py new file mode 100755 index 0000000000..ddb7ca6b89 --- /dev/null +++ b/doc/nextgen/nb_to_doc.py @@ -0,0 +1,178 @@ +#! /usr/bin/env python +"""Execute a .ipynb file, write out a processed .rst and clean .ipynb. + +Some functions in this script were copied from the nbstripout tool: + +Copyright (c) 2015 Min RK, Florian Rathgeber, Michael McNeil Forbes +2019 Casper da Costa-Luis + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +""" +import os +import sys +import nbformat +from nbconvert import RSTExporter +from nbconvert.preprocessors import ( + ExecutePreprocessor, + TagRemovePreprocessor, + ExtractOutputPreprocessor +) +from traitlets.config import Config + + +class MetadataError(Exception): + pass + + +def pop_recursive(d, key, default=None): + """dict.pop(key) where `key` is a `.`-delimited list of nested keys. + >>> d = {'a': {'b': 1, 'c': 2}} + >>> pop_recursive(d, 'a.c') + 2 + >>> d + {'a': {'b': 1}} + """ + nested = key.split('.') + current = d + for k in nested[:-1]: + if hasattr(current, 'get'): + current = current.get(k, {}) + else: + return default + if not hasattr(current, 'pop'): + return default + return current.pop(nested[-1], default) + + +def strip_output(nb): + """ + Strip the outputs, execution count/prompt number and miscellaneous + metadata from a notebook object, unless specified to keep either the + outputs or counts. + """ + keys = {'metadata': [], 'cell': {'metadata': ["execution"]}} + + nb.metadata.pop('signature', None) + nb.metadata.pop('widgets', None) + + for field in keys['metadata']: + pop_recursive(nb.metadata, field) + + for cell in nb.cells: + + # Remove the outputs, unless directed otherwise + if 'outputs' in cell: + + cell['outputs'] = [] + + # Remove the prompt_number/execution_count, unless directed otherwise + if 'prompt_number' in cell: + cell['prompt_number'] = None + if 'execution_count' in cell: + cell['execution_count'] = None + + # Always remove this metadata + for output_style in ['collapsed', 'scrolled']: + if output_style in cell.metadata: + cell.metadata[output_style] = False + if 'metadata' in cell: + for field in ['collapsed', 'scrolled', 'ExecuteTime']: + cell.metadata.pop(field, None) + for (extra, fields) in keys['cell'].items(): + if extra in cell: + for field in fields: + pop_recursive(getattr(cell, extra), field) + return nb + + +if __name__ == "__main__": + + # Get the desired ipynb file path and parse into components + _, fpath = sys.argv + basedir, fname = os.path.split(fpath) + fstem = fname[:-6] + + # Read the notebook + print(f"Executing {fpath} ...", end=" ", flush=True) + with open(fpath) as f: + nb = nbformat.read(f, as_version=4) + + # Run the notebook + kernel = os.environ.get("NB_KERNEL", None) + if kernel is None: + kernel = nb["metadata"]["kernelspec"]["name"] + ep = ExecutePreprocessor( + timeout=600, + kernel_name=kernel, + ) + ep.preprocess(nb, {"metadata": {"path": basedir}}) + + # Remove plain text execution result outputs + for cell in nb.get("cells", {}): + if "show-output" in cell["metadata"].get("tags", []): + continue + fields = cell.get("outputs", []) + for field in fields: + if field["output_type"] == "execute_result": + data_keys = field["data"].keys() + for key in list(data_keys): + if key == "text/plain": + field["data"].pop(key) + if not field["data"]: + fields.remove(field) + + # Convert to .rst formats + exp = RSTExporter() + + c = Config() + c.TagRemovePreprocessor.remove_cell_tags = {"hide"} + c.TagRemovePreprocessor.remove_input_tags = {"hide-input"} + c.TagRemovePreprocessor.remove_all_outputs_tags = {"hide-output"} + c.ExtractOutputPreprocessor.output_filename_template = \ + f"{fstem}_files/{fstem}_" + "{cell_index}_{index}{extension}" + + exp.register_preprocessor(TagRemovePreprocessor(config=c), True) + exp.register_preprocessor(ExtractOutputPreprocessor(config=c), True) + + body, resources = exp.from_notebook_node(nb) + + # Clean the output on the notebook and save a .ipynb back to disk + print(f"Writing clean {fpath} ... ", end=" ", flush=True) + nb = strip_output(nb) + with open(fpath, "wt") as f: + nbformat.write(nb, f) + + # Write the .rst file + rst_path = os.path.join(basedir, f"{fstem}.rst") + print(f"Writing {rst_path}") + with open(rst_path, "w") as f: + f.write(body) + + # Write the individual image outputs + imdir = os.path.join(basedir, f"{fstem}_files") + if not os.path.exists(imdir): + os.mkdir(imdir) + + for imname, imdata in resources["outputs"].items(): + if imname.startswith(fstem): + impath = os.path.join(basedir, f"{imname}") + with open(impath, "wb") as f: + f.write(imdata) diff --git a/doc/requirements.txt b/doc/requirements.txt index 5ac137016a..6ddd964920 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -1,5 +1,6 @@ docutils<0.18 # https://sourceforge.net/p/docutils/bugs/431/ sphinx==3.3.1 +jinja2<3.1 # Needed for compat with pinned sphinx sphinx_bootstrap_theme==0.8.0 jinja2<3.1 # Needed for compat with pinned sphinx numpydoc diff --git a/seaborn/_compat.py b/seaborn/_compat.py new file mode 100644 index 0000000000..44f409b231 --- /dev/null +++ b/seaborn/_compat.py @@ -0,0 +1,125 @@ +import numpy as np +import matplotlib as mpl +from seaborn.external.version import Version + + +def MarkerStyle(marker=None, fillstyle=None): + """ + Allow MarkerStyle to accept a MarkerStyle object as parameter. + + Supports matplotlib < 3.3.0 + https://github.com/matplotlib/matplotlib/pull/16692 + + """ + if isinstance(marker, mpl.markers.MarkerStyle): + if fillstyle is None: + return marker + else: + marker = marker.get_marker() + return mpl.markers.MarkerStyle(marker, fillstyle) + + +def norm_from_scale(scale, norm): + """Produce a Normalize object given a Scale and min/max domain limits.""" + # This is an internal maplotlib function that simplifies things to access + # It is likely to become part of the matplotlib API at some point: + # https://github.com/matplotlib/matplotlib/issues/20329 + if isinstance(norm, mpl.colors.Normalize): + return norm + + if scale is None: + return None + + if norm is None: + vmin = vmax = None + else: + vmin, vmax = norm # TODO more helpful error if this fails? + + class ScaledNorm(mpl.colors.Normalize): + + def __call__(self, value, clip=None): + # From github.com/matplotlib/matplotlib/blob/v3.4.2/lib/matplotlib/colors.py + # See github.com/matplotlib/matplotlib/tree/v3.4.2/LICENSE + value, is_scalar = self.process_value(value) + self.autoscale_None(value) + if self.vmin > self.vmax: + raise ValueError("vmin must be less or equal to vmax") + if self.vmin == self.vmax: + return np.full_like(value, 0) + if clip is None: + clip = self.clip + if clip: + value = np.clip(value, self.vmin, self.vmax) + # ***** Seaborn changes start **** + t_value = self.transform(value).reshape(np.shape(value)) + t_vmin, t_vmax = self.transform([self.vmin, self.vmax]) + # ***** Seaborn changes end ***** + if not np.isfinite([t_vmin, t_vmax]).all(): + raise ValueError("Invalid vmin or vmax") + t_value -= t_vmin + t_value /= (t_vmax - t_vmin) + t_value = np.ma.masked_invalid(t_value, copy=False) + return t_value[0] if is_scalar else t_value + + new_norm = ScaledNorm(vmin, vmax) + new_norm.transform = scale.get_transform().transform + + return new_norm + + +def scale_factory(scale, axis, **kwargs): + """ + Backwards compatability for creation of independent scales. + + Matplotlib scales require an Axis object for instantiation on < 3.4. + But the axis is not used, aside from extraction of the axis_name in LogScale. + + """ + modify_transform = False + if Version(mpl.__version__) < Version("3.4"): + if axis[0] in "xy": + modify_transform = True + axis = axis[0] + base = kwargs.pop("base", None) + if base is not None: + kwargs[f"base{axis}"] = base + nonpos = kwargs.pop("nonpositive", None) + if nonpos is not None: + kwargs[f"nonpos{axis}"] = nonpos + + if isinstance(scale, str): + class Axis: + axis_name = axis + axis = Axis() + + scale = mpl.scale.scale_factory(scale, axis, **kwargs) + + if modify_transform: + transform = scale.get_transform() + transform.base = kwargs.get("base", 10) + if kwargs.get("nonpositive") == "mask": + # Setting a private attribute, but we only get here + # on an old matplotlib, so this won't break going forwards + transform._clip = False + + return scale + + +def set_scale_obj(ax, axis, scale): + """Handle backwards compatability with setting matplotlib scale.""" + if Version(mpl.__version__) < Version("3.4"): + # The ability to pass a BaseScale instance to Axes.set_{}scale was added + # to matplotlib in version 3.4.0: GH: matplotlib/matplotlib/pull/19089 + # Workaround: use the scale name, which is restrictive only if the user + # wants to define a custom scale; they'll need to update the registry too. + if scale.name is None: + # Hack to support our custom Formatter-less CatScale + return + method = getattr(ax, f"set_{axis}scale") + kws = {} + if scale.name == "function": + trans = scale.get_transform() + kws["functions"] = (trans._forward, trans._inverse) + method(scale.name, **kws) + else: + ax.set(**{f"{axis}scale": scale}) diff --git a/seaborn/_core/__init__.py b/seaborn/_core/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/seaborn/_core/data.py b/seaborn/_core/data.py new file mode 100644 index 0000000000..9de8be5fb8 --- /dev/null +++ b/seaborn/_core/data.py @@ -0,0 +1,262 @@ +""" +Components for parsing variable assignments and internally representing plot data. +""" +from __future__ import annotations + +from collections import abc +import pandas as pd + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from pandas import DataFrame + from seaborn._core.typing import DataSource, VariableSpec + + +# TODO Repetition in the docstrings should be reduced with interpolation tools + +class PlotData: + """ + Data table with plot variable schema and mapping to original names. + + Contains logic for parsing variable specification arguments and updating + the table with layer-specific data and/or mappings. + + Parameters + ---------- + data + Input data where variable names map to vector values. + variables + Keys are names of plot variables (x, y, ...) each value is one of: + + - name of a column (or index level, or dictionary entry) in `data` + - vector in any format that can construct a :class:`pandas.DataFrame` + + Attributes + ---------- + frame + Data table with column names having defined plot variables. + names + Dictionary mapping plot variable names to names in source data structure(s). + ids + Dictionary mapping plot variable names to unique data source identifiers. + + """ + frame: DataFrame + frames: dict[tuple, DataFrame] + names: dict[str, str | None] + ids: dict[str, str | int] + source_data: DataSource + source_vars: dict[str, VariableSpec] + + def __init__( + self, + data: DataSource, + variables: dict[str, VariableSpec], + ): + + frame, names, ids = self._assign_variables(data, variables) + + self.frame = frame + self.names = names + self.ids = ids + + self.frames = {} # TODO this is a hack, remove + + self.source_data = data + self.source_vars = variables + + def __contains__(self, key: str) -> bool: + """Boolean check on whether a variable is defined in this dataset.""" + if self.frame is None: + return any(key in df for df in self.frames.values()) + return key in self.frame + + def join( + self, + data: DataSource, + variables: dict[str, VariableSpec] | None, + ) -> PlotData: + """Add, replace, or drop variables and return as a new dataset.""" + # Inherit the original source of the upsteam data by default + if data is None: + data = self.source_data + + # TODO allow `data` to be a function (that is called on the source data?) + + if not variables: + variables = self.source_vars + + # Passing var=None implies that we do not want that variable in this layer + disinherit = [k for k, v in variables.items() if v is None] + + # Create a new dataset with just the info passed here + new = PlotData(data, variables) + + # -- Update the inherited DataSource with this new information + + drop_cols = [k for k in self.frame if k in new.frame or k in disinherit] + parts = [self.frame.drop(columns=drop_cols), new.frame] + + # Because we are combining distinct columns, this is perhaps more + # naturally thought of as a "merge"/"join". But using concat because + # some simple testing suggests that it is marginally faster. + frame = pd.concat(parts, axis=1, sort=False, copy=False) + + names = {k: v for k, v in self.names.items() if k not in disinherit} + names.update(new.names) + + ids = {k: v for k, v in self.ids.items() if k not in disinherit} + ids.update(new.ids) + + new.frame = frame + new.names = names + new.ids = ids + + # Multiple chained operations should always inherit from the original object + new.source_data = self.source_data + new.source_vars = self.source_vars + + return new + + def _assign_variables( + self, + data: DataSource, + variables: dict[str, VariableSpec], + ) -> tuple[DataFrame, dict[str, str | None], dict[str, str | int]]: + """ + Assign values for plot variables given long-form data and/or vector inputs. + + Parameters + ---------- + data + Input data where variable names map to vector values. + variables + Keys are names of plot variables (x, y, ...) each value is one of: + + - name of a column (or index level, or dictionary entry) in `data` + - vector in any format that can construct a :class:`pandas.DataFrame` + + Returns + ------- + frame + Table mapping seaborn variables (x, y, color, ...) to data vectors. + names + Keys are defined seaborn variables; values are names inferred from + the inputs (or None when no name can be determined). + ids + Like the `names` dict, but `None` values are replaced by the `id()` + of the data object that defined the variable. + + Raises + ------ + ValueError + When variables are strings that don't appear in `data`, or when they are + non-indexed vector datatypes that have a different length from `data`. + + """ + source_data: dict | DataFrame + frame: DataFrame + names: dict[str, str | None] + ids: dict[str, str | int] + + plot_data = {} + names = {} + ids = {} + + given_data = data is not None + if given_data: + source_data = data + else: + # Data is optional; all variables can be defined as vectors + # But simplify downstream code by always having a usable source data object + source_data = {} + + # TODO Generally interested in accepting a generic DataFrame interface + # Track https://data-apis.org/ for development + + # Variables can also be extracted from the index of a DataFrame + if isinstance(source_data, pd.DataFrame): + index = source_data.index.to_frame().to_dict("series") + else: + index = {} + + for key, val in variables.items(): + + # Simply ignore variables with no specification + if val is None: + continue + + # Try to treat the argument as a key for the data collection. + # But be flexible about what can be used as a key. + # Usually it will be a string, but allow other hashables when + # taking from the main data object. Allow only strings to reference + # fields in the index, because otherwise there is too much ambiguity. + + # TODO this will be rendered unnecessary by the following pandas fix: + # https://github.com/pandas-dev/pandas/pull/41283 + try: + hash(val) + val_is_hashable = True + except TypeError: + val_is_hashable = False + + val_as_data_key = ( + # See https://github.com/pandas-dev/pandas/pull/41283 + # (isinstance(val, abc.Hashable) and val in source_data) + (val_is_hashable and val in source_data) + or (isinstance(val, str) and val in index) + ) + + if val_as_data_key: + + if val in source_data: + plot_data[key] = source_data[val] + elif val in index: + plot_data[key] = index[val] + names[key] = ids[key] = str(val) + + elif isinstance(val, str): + + # This looks like a column name but, lookup failed. + + err = f"Could not interpret value `{val}` for `{key}`. " + if not given_data: + err += "Value is a string, but `data` was not passed." + else: + err += "An entry with this name does not appear in `data`." + raise ValueError(err) + + else: + + # Otherwise, assume the value somehow represents data + + # Ignore empty data structures + if isinstance(val, abc.Sized) and len(val) == 0: + continue + + # If vector has no index, it must match length of data table + if isinstance(data, pd.DataFrame) and not isinstance(val, pd.Series): + if isinstance(val, abc.Sized) and len(data) != len(val): + val_cls = val.__class__.__name__ + err = ( + f"Length of {val_cls} vectors must match length of `data`" + f" when both are used, but `data` has length {len(data)}" + f" and the vector passed to `{key}` has length {len(val)}." + ) + raise ValueError(err) + + plot_data[key] = val + + # Try to infer the original name using pandas-like metadata + if hasattr(val, "name"): + names[key] = ids[key] = str(val.name) # type: ignore # mypy/1424 + else: + names[key] = None + ids[key] = id(val) + + # Construct a tidy plot DataFrame. This will convert a number of + # types automatically, aligning on index in case of pandas objects + # TODO Note: this fails when variable specs *only* have scalars! + frame = pd.DataFrame(plot_data) + + return frame, names, ids diff --git a/seaborn/_core/groupby.py b/seaborn/_core/groupby.py new file mode 100644 index 0000000000..3809a530f5 --- /dev/null +++ b/seaborn/_core/groupby.py @@ -0,0 +1,124 @@ +"""Simplified split-apply-combine paradigm on dataframes for internal use.""" +from __future__ import annotations + +import pandas as pd + +from seaborn._core.rules import categorical_order + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import Callable + from pandas import DataFrame, MultiIndex, Index + + +class GroupBy: + """ + Interface for Pandas GroupBy operations allowing specified group order. + + Writing our own class to do this has a few advantages: + - It constrains the interface between Plot and Stat/Move objects + - It allows control over the row order of the GroupBy result, which is + important when using in the context of some Move operations (dodge, stack, ...) + - It simplifies some complexities regarding the return type and Index contents + one encounters with Pandas, especially for DataFrame -> DataFrame applies + - It increases future flexibility regarding alternate DataFrame libraries + + """ + def __init__(self, order: list[str] | dict[str, list | None]): + """ + Initialize the GroupBy from grouping variables and optional level orders. + + Parameters + ---------- + order + List of variable names or dict mapping names to desired level orders. + Level order values can be None to use default ordering rules. The + variables can include names that are not expected to appear in the + data; these will be dropped before the groups are defined. + + """ + if not order: + raise ValueError("GroupBy requires at least one grouping variable") + + if isinstance(order, list): + order = {k: None for k in order} + self.order = order + + def _get_groups(self, data: DataFrame) -> MultiIndex: + """Return index with Cartesian product of ordered grouping variable levels.""" + levels = {} + for var, order in self.order.items(): + if var in data: + if order is None: + order = categorical_order(data[var]) + levels[var] = order + + grouper: str | list[str] + groups: Index | MultiIndex | None + if not levels: + grouper = [] + groups = None + elif len(levels) > 1: + grouper = list(levels) + groups = pd.MultiIndex.from_product(levels.values(), names=grouper) + else: + grouper, = list(levels) + groups = pd.Index(levels[grouper], name=grouper) + return grouper, groups + + def _reorder_columns(self, res, data): + """Reorder result columns to match original order with new columns appended.""" + cols = [c for c in data if c in res] + cols += [c for c in res if c not in data] + return res.reindex(columns=pd.Index(cols)) + + def agg(self, data: DataFrame, *args, **kwargs) -> DataFrame: + """ + Reduce each group to a single row in the output. + + The output will have a row for each unique combination of the grouping + variable levels with null values for the aggregated variable(s) where + those combinations do not appear in the dataset. + + """ + grouper, groups = self._get_groups(data) + + if not grouper: + # We will need to see whether there are valid usecases that end up here + raise ValueError("No grouping variables are present in dataframe") + + res = ( + data + .groupby(grouper, sort=False, observed=True) + .agg(*args, **kwargs) + .reindex(groups) + .reset_index() + .pipe(self._reorder_columns, data) + ) + + return res + + def apply( + self, data: DataFrame, func: Callable[..., DataFrame], + *args, **kwargs, + ) -> DataFrame: + """Apply a DataFrame -> DataFrame mapping to each group.""" + grouper, groups = self._get_groups(data) + + if not grouper: + return self._reorder_columns(func(data, *args, **kwargs), data) + + parts = {} + for key, part_df in data.groupby(grouper, sort=False): + parts[key] = func(part_df, *args, **kwargs) + stack = [] + for key in groups: + if key in parts: + if isinstance(grouper, list): + group_ids = dict(zip(grouper, key)) + else: + group_ids = {grouper: key} + stack.append(parts[key].assign(**group_ids)) + + res = pd.concat(stack, ignore_index=True) + return self._reorder_columns(res, data) diff --git a/seaborn/_core/moves.py b/seaborn/_core/moves.py new file mode 100644 index 0000000000..b5b9593de9 --- /dev/null +++ b/seaborn/_core/moves.py @@ -0,0 +1,160 @@ +from __future__ import annotations +from dataclasses import dataclass + +import numpy as np + +from seaborn._core.groupby import GroupBy + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import Optional + from pandas import DataFrame + + +@dataclass +class Move: + + def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame: + raise NotImplementedError + + +@dataclass +class Jitter(Move): + """ + Random displacement of marks along either or both axes to reduce overplotting. + """ + width: float = 0 + x: float = 0 + y: float = 0 + + seed: Optional[int] = None + + # TODO what is the best way to have a reasonable default? + # The problem is that "reasonable" seems dependent on the mark + + def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame: + + # TODO is it a problem that GroupBy is not used for anything here? + # Should we type it as optional? + + data = data.copy() + + rng = np.random.default_rng(self.seed) + + def jitter(data, col, scale): + noise = rng.uniform(-.5, +.5, len(data)) + offsets = noise * scale + return data[col] + offsets + + if self.width: + data[orient] = jitter(data, orient, self.width * data["width"]) + if self.x: + data["x"] = jitter(data, "x", self.x) + if self.y: + data["y"] = jitter(data, "y", self.y) + + return data + + +@dataclass +class Dodge(Move): + """ + Displacement and narrowing of overlapping marks along orientation axis. + """ + empty: str = "keep" # keep, drop, fill + gap: float = 0 + + # TODO accept just a str here? + by: Optional[list[str]] = None + + def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame: + + grouping_vars = [v for v in groupby.order if v in data] + groups = groupby.agg(data, {"width": "max"}) + if self.empty == "fill": + groups = groups.dropna() + + def groupby_pos(s): + grouper = [groups[v] for v in [orient, "col", "row"] if v in data] + return s.groupby(grouper, sort=False, observed=True) + + def scale_widths(w): + # TODO what value to fill missing widths??? Hard problem... + # TODO short circuit this if outer widths has no variance? + empty = 0 if self.empty == "fill" else w.mean() + filled = w.fillna(empty) + scale = filled.max() + norm = filled.sum() + if self.empty == "keep": + w = filled + return w / norm * scale + + def widths_to_offsets(w): + return w.shift(1).fillna(0).cumsum() + (w - w.sum()) / 2 + + new_widths = groupby_pos(groups["width"]).transform(scale_widths) + offsets = groupby_pos(new_widths).transform(widths_to_offsets) + + if self.gap: + new_widths *= 1 - self.gap + + groups["_dodged"] = groups[orient] + offsets + groups["width"] = new_widths + + out = ( + data + .drop("width", axis=1) + .merge(groups, on=grouping_vars, how="left") + .drop(orient, axis=1) + .rename(columns={"_dodged": orient}) + ) + + return out + + +@dataclass +class Stack(Move): + """ + Displacement of overlapping bar or area marks along the value axis. + """ + # TODO center? (or should this be a different move?) + + def _stack(self, df, orient): + + # TODO should stack do something with ymin/ymax style marks? + # Should there be an upstream conversion to baseline/height parameterization? + + if df["baseline"].nunique() > 1: + err = "Stack move cannot be used when baselines are already heterogeneous" + raise RuntimeError(err) + + other = {"x": "y", "y": "x"}[orient] + stacked_lengths = (df[other] - df["baseline"]).dropna().cumsum() + offsets = stacked_lengths.shift(1).fillna(0) + + df[other] = stacked_lengths + df["baseline"] = df["baseline"] + offsets + + return df + + def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame: + + # TODO where to ensure that other semantic variables are sorted properly? + groupers = ["col", "row", orient] + return GroupBy(groupers).apply(data, self._stack, orient) + + +@dataclass +class Shift(Move): + """ + Displacement of all marks with the same magnitude / direction. + """ + x: float = 0 + y: float = 0 + + def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame: + + data = data.copy(deep=False) + data["x"] = data["x"] + self.x + data["y"] = data["y"] + self.y + return data diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py new file mode 100644 index 0000000000..2abd64e79e --- /dev/null +++ b/seaborn/_core/plot.py @@ -0,0 +1,1390 @@ +"""The classes for specifying and compiling a declarative visualization.""" +from __future__ import annotations + +import io +import os +import re +import sys +import inspect +import itertools +import textwrap +from collections import abc +from collections.abc import Callable, Generator, Hashable +from typing import Any + +import pandas as pd +from pandas import DataFrame, Series, Index +import matplotlib as mpl +from matplotlib.axes import Axes +from matplotlib.artist import Artist +from matplotlib.figure import Figure + +from seaborn._marks.base import Mark +from seaborn._stats.base import Stat +from seaborn._core.data import PlotData +from seaborn._core.moves import Move +from seaborn._core.scales import ScaleSpec, Scale +from seaborn._core.subplots import Subplots +from seaborn._core.groupby import GroupBy +from seaborn._core.properties import PROPERTIES, Property, Coordinate +from seaborn._core.typing import DataSource, VariableSpec, OrderSpec +from seaborn._core.rules import categorical_order +from seaborn._compat import set_scale_obj +from seaborn.external.version import Version + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from matplotlib.figure import SubFigure + + +if sys.version_info >= (3, 8): + from typing import TypedDict +else: + from typing_extensions import TypedDict + + +# ---- Definitions for internal specs --------------------------------- # + + +class Layer(TypedDict, total=False): + + mark: Mark # TODO allow list? + stat: Stat | None # TODO allow list? + move: Move | list[Move] | None + data: PlotData + source: DataSource + vars: dict[str, VariableSpec] + orient: str + + +class FacetSpec(TypedDict, total=False): + + variables: dict[str, VariableSpec] + structure: dict[str, list[str]] + wrap: int | None + + +class PairSpec(TypedDict, total=False): + + variables: dict[str, VariableSpec] + structure: dict[str, list[str]] + cross: bool + wrap: int | None + + +# ---- The main interface for declarative plotting -------------------- # + + +def build_plot_signature(cls): + """ + Decorator function for giving Plot a useful signature. + + Currently this mostly saves us some duplicated typing, but we would + like eventually to have a way of registering new semantic properties, + at which point dynamic signature generation would become more important. + + """ + sig = inspect.signature(cls) + params = [ + inspect.Parameter("args", inspect.Parameter.VAR_POSITIONAL), + inspect.Parameter("data", inspect.Parameter.KEYWORD_ONLY, default=None) + ] + params.extend([ + inspect.Parameter(name, inspect.Parameter.KEYWORD_ONLY, default=None) + for name in PROPERTIES + ]) + new_sig = sig.replace(parameters=params) + cls.__signature__ = new_sig + + known_properties = textwrap.fill( + ", ".join(PROPERTIES), 78, subsequent_indent=" " * 8, + ) + + if cls.__doc__ is not None: # support python -OO mode + cls.__doc__ = cls.__doc__.format(known_properties=known_properties) + + return cls + + +@build_plot_signature +class Plot: + """ + An interface for declaratively specifying statistical graphics. + + Plots are constructed by initializing this class and adding one or more + layers, comprising a `Mark` and optional `Stat` or `Move`. Additionally, + faceting variables or variable pairings may be defined to divide the space + into multiple subplots. The mappings from data values to visual properties + can be parametrized using scales, although the plot will try to infer good + defaults when scales are not explicitly defined. + + The constructor accepts a data source (a :class:`pandas.DataFrame` or + dictionary with columnar values) and variable assignments. Variables can be + passed as keys to the data source or directly as data vectors. If multiple + data-containing objects are provided, they will be index-aligned. + + The data source and variables defined in the constructor will be used for + all layers in the plot, unless overridden or disabled when adding a layer. + + The following variables can be defined in the constructor: + {known_properties} + + The `data`, `x`, and `y` variables can be passed as positional arguments or + using keywords. Whether the first positional argument is interpreted as a + data source or `x` variable depends on its type. + + The methods of this class return a copy of the instance; use chaining to + build up a plot through multiple calls. Methods can be called in any order. + + Most methods only add information to the plot spec; no actual processing + happens until the plot is shown or saved. It is also possible to compile + the plot without rendering it to access the lower-level representation. + + """ + # TODO use TypedDict throughout? + + _data: PlotData + _layers: list[Layer] + _scales: dict[str, ScaleSpec] + + _subplot_spec: dict[str, Any] # TODO values type + _facet_spec: FacetSpec + _pair_spec: PairSpec + + def __init__( + self, + *args: DataSource | VariableSpec, + data: DataSource = None, + **variables: VariableSpec, + ): + + if args: + data, variables = self._resolve_positionals(args, data, variables) + + unknown = [x for x in variables if x not in PROPERTIES] + if unknown: + err = f"Plot() got unexpected keyword argument(s): {', '.join(unknown)}" + raise TypeError(err) + + self._data = PlotData(data, variables) + self._layers = [] + self._scales = {} + + self._subplot_spec = {} + self._facet_spec = {} + self._pair_spec = {} + + self._target = None + + def _resolve_positionals( + self, + args: tuple[DataSource | VariableSpec, ...], + data: DataSource, + variables: dict[str, VariableSpec], + ) -> tuple[DataSource, dict[str, VariableSpec]]: + """Handle positional arguments, which may contain data / x / y.""" + if len(args) > 3: + err = "Plot() accepts no more than 3 positional arguments (data, x, y)." + raise TypeError(err) + + # TODO need some clearer way to differentiate data / vector here + # (There might be an abstract DataFrame class to use here?) + if isinstance(args[0], (abc.Mapping, pd.DataFrame)): + if data is not None: + raise TypeError("`data` given by both name and position.") + data, args = args[0], args[1:] + + if len(args) == 2: + x, y = args + elif len(args) == 1: + x, y = *args, None + else: + x = y = None + + for name, var in zip("yx", (y, x)): + if var is not None: + if name in variables: + raise TypeError(f"`{name}` given by both name and position.") + # Keep coordinates at the front of the variables dict + variables = {name: var, **variables} + + return data, variables + + def __add__(self, other): + + if isinstance(other, Mark) or isinstance(other, Stat): + raise TypeError("Sorry, this isn't ggplot! Perhaps try Plot.add?") + + other_type = other.__class__.__name__ + raise TypeError(f"Unsupported operand type(s) for +: 'Plot' and '{other_type}") + + def _repr_png_(self) -> tuple[bytes, dict[str, float]]: + + return self.plot()._repr_png_() + + # TODO _repr_svg_? + + def _clone(self) -> Plot: + """Generate a new object with the same information as the current spec.""" + new = Plot() + + # TODO any way to enforce that data does not get mutated? + new._data = self._data + + new._layers.extend(self._layers) + new._scales.update(self._scales) + + new._subplot_spec.update(self._subplot_spec) + new._facet_spec.update(self._facet_spec) + new._pair_spec.update(self._pair_spec) + + new._target = self._target + + return new + + @property + def _variables(self) -> list[str]: + + variables = ( + list(self._data.frame) + + list(self._pair_spec.get("variables", [])) + + list(self._facet_spec.get("variables", [])) + ) + for layer in self._layers: + variables.extend(c for c in layer["vars"] if c not in variables) + return variables + + def on(self, target: Axes | SubFigure | Figure) -> Plot: + """ + Draw the plot into an existing Matplotlib object. + + Parameters + ---------- + target : Axes, SubFigure, or Figure + Matplotlib object to use. Passing :class:`matplotlib.axes.Axes` will add + artists without otherwise modifying the figure. Otherwise, subplots will be + created within the space of the given :class:`matplotlib.figure.Figure` or + :class:`matplotlib.figure.SubFigure`. + + """ + # TODO alternate name: target? + + accepted_types: tuple # Allow tuple of various length + if hasattr(mpl.figure, "SubFigure"): # Added in mpl 3.4 + accepted_types = ( + mpl.axes.Axes, mpl.figure.SubFigure, mpl.figure.Figure + ) + accepted_types_str = ( + f"{mpl.axes.Axes}, {mpl.figure.SubFigure}, or {mpl.figure.Figure}" + ) + else: + accepted_types = mpl.axes.Axes, mpl.figure.Figure + accepted_types_str = f"{mpl.axes.Axes} or {mpl.figure.Figure}" + + if not isinstance(target, accepted_types): + err = ( + f"The `Plot.on` target must be an instance of {accepted_types_str}. " + f"You passed an instance of {target.__class__} instead." + ) + raise TypeError(err) + + new = self._clone() + new._target = target + + return new + + def add( + self, + mark: Mark, + stat: Stat | None = None, + move: Move | None = None, # TODO or list[Move] + *, + orient: str | None = None, + data: DataSource = None, + **variables: VariableSpec, + ) -> Plot: + """ + Define a layer of the visualization. + + This is the main method for specifying how the data should be visualized. + It can be called multiple times with different arguments to define + a plot with multiple layers. + + Parameters + ---------- + mark : :class:`seaborn.objects.Mark` + The visual representation of the data to use in this layer. + stat : :class:`seaborn.objects.Stat` + A transformation applied to the data before plotting. + move : :class:`seaborn.objects.Move` + Additional transformation(s) to handle over-plotting. + orient : "x", "y", "v", or "h" + The orientation of the mark, which affects how the stat is computed. + Typically corresponds to the axis that defines groups for aggregation. + The "v" (vertical) and "h" (horizontal) options are synonyms for "x" / "y", + but may be more intuitive with some marks. When not provided, an + orientation will be inferred from characteristics of the data and scales. + data : DataFrame or dict + Data source to override the global source provided in the constructor. + variables : data vectors or identifiers + Additional layer-specific variables, including variables that will be + passed directly to the stat without scaling. + + """ + if not isinstance(mark, Mark): + msg = f"mark must be a Mark instance, not {type(mark)!r}." + raise TypeError(msg) + + if stat is not None and not isinstance(stat, Stat): + msg = f"stat must be a Stat instance, not {type(stat)!r}." + raise TypeError(msg) + + # TODO decide how to allow Mark to have default Stat/Move + # if stat is None and hasattr(mark, "default_stat"): + # stat = mark.default_stat() + + # TODO it doesn't work to supply scalars to variables, but that would be nice + + # TODO accept arbitrary variables defined by the stat (/move?) here + # (but not in the Plot constructor) + # Should stat variables ever go in the constructor, or just in the add call? + + new = self._clone() + new._layers.append({ + "mark": mark, + "stat": stat, + "move": move, + "vars": variables, + "source": data, + "orient": {"v": "x", "h": "y"}.get(orient, orient), # type: ignore + }) + + return new + + def pair( + self, + x: list[Hashable] | Index[Hashable] | None = None, + y: list[Hashable] | Index[Hashable] | None = None, + wrap: int | None = None, + cross: bool = True, + # TODO other existing PairGrid things like corner? + # TODO transpose, so that e.g. multiple y axes go across the columns + ) -> Plot: + """ + Produce subplots with distinct `x` and/or `y` variables. + + Parameters + ---------- + x, y : sequence(s) of data identifiers + Variables that will define the grid of subplots. + wrap : int + Maximum height/width of the grid, with additional subplots "wrapped" + on the other dimension. Requires that only one of `x` or `y` are set here. + cross : bool + When True, define a two-dimensional grid using the Cartesian product of `x` + and `y`. Otherwise, define a one-dimensional grid by pairing `x` and `y` + entries in by position. + + """ + # TODO Problems to solve: + # + # - Unclear is how to handle the diagonal plots that PairGrid offers + # + # - Implementing this will require lots of downscale changes in figure setup, + # and especially the axis scaling, which will need to be pair specific + + # TODO lists of vectors currently work, but I'm not sure where best to test + # Will need to update the signature typing to keep them + + # TODO is it weird to call .pair() to create univariate plots? + # i.e. Plot(data).pair(x=[...]). The basic logic is fine. + # But maybe a different verb (e.g. Plot.spread) would be more clear? + # Then Plot(data).pair(x=[...]) would show the given x vars vs all. + + # TODO would like to add transpose=True, which would then draw + # Plot(x=...).pair(y=[...]) across the rows + # This may also be possible by setting `wrap=1`, although currently the axes + # are shared and the interior labels are disabeled (this is a bug either way) + + pair_spec: PairSpec = {} + + if x is None and y is None: + + # Default to using all columns in the input source data, aside from + # those that were assigned to a variable in the constructor + # TODO Do we want to allow additional filtering by variable type? + # (Possibly even default to using only numeric columns) + + if self._data.source_data is None: + err = "You must pass `data` in the constructor to use default pairing." + raise RuntimeError(err) + + all_unused_columns = [ + key for key in self._data.source_data + if key not in self._data.names.values() + ] + if "x" not in self._data: + x = all_unused_columns + if "y" not in self._data: + y = all_unused_columns + + axes = {"x": [] if x is None else x, "y": [] if y is None else y} + for axis, arg in axes.items(): + if isinstance(arg, (str, int)): + err = f"You must pass a sequence of variable keys to `{axis}`" + raise TypeError(err) + + pair_spec["variables"] = {} + pair_spec["structure"] = {} + + for axis in "xy": + keys = [] + for i, col in enumerate(axes[axis]): + key = f"{axis}{i}" + keys.append(key) + pair_spec["variables"][key] = col + + if keys: + pair_spec["structure"][axis] = keys + + # TODO raise here if cross is False and len(x) != len(y)? + pair_spec["cross"] = cross + pair_spec["wrap"] = wrap + + new = self._clone() + new._pair_spec.update(pair_spec) + return new + + def facet( + self, + # TODO require kwargs? + col: VariableSpec = None, + row: VariableSpec = None, + order: OrderSpec | dict[str, OrderSpec] = None, + wrap: int | None = None, + ) -> Plot: + """ + Produce subplots with conditional subsets of the data. + + Parameters + ---------- + col, row : data vectors or identifiers + Variables used to define subsets along the columns and/or rows of the grid. + Can be references to the global data source passed in the constructor. + order : list of strings, or dict with dimensional keys + Define the order of the faceting variables. + wrap : int + Maximum height/width of the grid, with additional subplots "wrapped" + on the other dimension. Requires that only one of `x` or `y` are set here. + + """ + variables = {} + if col is not None: + variables["col"] = col + if row is not None: + variables["row"] = row + + structure = {} + if isinstance(order, dict): + for dim in ["col", "row"]: + dim_order = order.get(dim) + if dim_order is not None: + structure[dim] = list(dim_order) + elif order is not None: + if col is not None and row is not None: + err = " ".join([ + "When faceting on both col= and row=, passing `order` as a list" + "is ambiguous. Use a dict with 'col' and/or 'row' keys instead." + ]) + raise RuntimeError(err) + elif col is not None: + structure["col"] = list(order) + elif row is not None: + structure["row"] = list(order) + + spec: FacetSpec = { + "variables": variables, + "structure": structure, + "wrap": wrap, + } + + new = self._clone() + new._facet_spec.update(spec) + + return new + + # TODO def twin()? + + def scale(self, **scales: ScaleSpec) -> Plot: + """ + Control mappings from data units to visual properties. + + Keywords correspond to variables defined in the plot, including coordinate + variables (`x`, `y`) and semantic variables (`color`, `pointsize`, etc.). + + A number of "magic" arguments are accepted, including: + - The name of a transform (e.g., `"log"`, `"sqrt"`) + - The name of a palette (e.g., `"viridis"`, `"muted"`) + - A tuple of values, defining the output range (e.g. `(1, 5)`) + - A dict, implying a :class:`Nominal` scale (e.g. `{"a": .2, "b": .5}`) + - A list of values, implying a :class:`Nominal` scale (e.g. `["b", "r"]`) + + For more explicit control, pass a scale spec object such as :class:`Continuous` + or :class:`Nominal`. Or use `None` to use an "identity" scale, which treats data + values as literally encoding visual properties. + + """ + new = self._clone() + new._scales.update(**scales) + return new + + def configure( + self, + figsize: tuple[float, float] | None = None, + sharex: bool | str | None = None, + sharey: bool | str | None = None, + ) -> Plot: + """ + Control the figure size and layout. + + Parameters + ---------- + figsize: (width, height) + Size of the resulting figure, in inches. + sharex, sharey : bool, "row", or "col" + Whether axis limits should be shared across subplots. Boolean values apply + across the entire grid, whereas `"row"` or `"col"` have a smaller scope. + Shared axes will have tick labels disabled. + + """ + # TODO add an "auto" mode for figsize that roughly scales with the rcParams + # figsize (so that works), but expands to prevent subplots from being squished + # Also should we have height=, aspect=, exclusive with figsize? Or working + # with figsize when only one is defined? + + new = self._clone() + + # TODO this is a hack; make a proper figure spec object + new._figsize = figsize # type: ignore + + if sharex is not None: + new._subplot_spec["sharex"] = sharex + if sharey is not None: + new._subplot_spec["sharey"] = sharey + + return new + + # TODO def legend (ugh) + + def theme(self) -> Plot: + """ + Control the default appearance of elements in the plot. + + TODO + """ + # TODO Plot-specific themes using the seaborn theming system + raise NotImplementedError() + new = self._clone() + return new + + # TODO decorate? (or similar, for various texts) alt names: label? + + def save(self, fname, **kwargs) -> Plot: + """ + Render the plot and write it to a buffer or file on disk. + + Parameters + ---------- + fname : str, path, or buffer + Location on disk to save the figure, or a buffer to write into. + Other keyword arguments are passed to :meth:`matplotlib.figure.Figure.savefig`. + + """ + # TODO expose important keyword arguments in our signature? + self.plot().save(fname, **kwargs) + return self + + def plot(self, pyplot=False) -> Plotter: + """ + Compile the plot and return the :class:`Plotter` engine. + + """ + # TODO if we have _target object, pyplot should be determined by whether it + # is hooked into the pyplot state machine (how do we check?) + + plotter = Plotter(pyplot=pyplot) + + common, layers = plotter._extract_data(self) + plotter._setup_figure(self, common, layers) + plotter._transform_coords(self, common, layers) + + plotter._compute_stats(self, layers) + plotter._setup_scales(self, layers) + + # TODO Remove these after updating other methods + # ---- Maybe have debug= param that attaches these when True? + plotter._data = common + plotter._layers = layers + + for layer in layers: + plotter._plot_layer(self, layer) + + plotter._make_legend() + + # TODO this should be configurable + if not plotter._figure.get_constrained_layout(): + plotter._figure.set_tight_layout(True) + + return plotter + + def show(self, **kwargs) -> None: + """ + Render and display the plot. + + """ + # TODO make pyplot configurable at the class level, and when not using, + # import IPython.display and call on self to populate cell output? + + # Keep an eye on whether matplotlib implements "attaching" an existing + # figure to pyplot: https://github.com/matplotlib/matplotlib/pull/14024 + + self.plot(pyplot=True).show(**kwargs) + + +# ---- The plot compilation engine ---------------------------------------------- # + + +class Plotter: + """ + Engine for compiling a :class:`Plot` spec into a Matplotlib figure. + + This class is not intended to be instantiated directly by users. + + """ + # TODO decide if we ever want these (Plot.plot(debug=True))? + _data: PlotData + _layers: list[Layer] + _figure: Figure + + def __init__(self, pyplot=False): + + self.pyplot = pyplot + self._legend_contents: list[ + tuple[str, str | int], list[Artist], list[str], + ] = [] + self._scales: dict[str, Scale] = {} + + def save(self, fname, **kwargs) -> Plotter: + kwargs.setdefault("dpi", 96) + self._figure.savefig(os.path.expanduser(fname), **kwargs) + return self + + def show(self, **kwargs) -> None: + # TODO if we did not create the Plotter with pyplot, is it possible to do this? + # If not we should clearly raise. + import matplotlib.pyplot as plt + plt.show(**kwargs) + + # TODO API for accessing the underlying matplotlib objects + # TODO what else is useful in the public API for this class? + + def _repr_png_(self) -> tuple[bytes, dict[str, float]]: + + # TODO better to do this through a Jupyter hook? e.g. + # ipy = IPython.core.formatters.get_ipython() + # fmt = ipy.display_formatter.formatters["text/html"] + # fmt.for_type(Plot, ...) + # Would like to have a svg option too, not sure how to make that flexible + + # TODO use matplotlib backend directly instead of going through savefig? + + # TODO perhaps have self.show() flip a switch to disable this, so that + # user does not end up with two versions of the figure in the output + + # TODO use bbox_inches="tight" like the inline backend? + # pro: better results, con: (sometimes) confusing results + # Better solution would be to default (with option to change) + # to using constrained/tight layout. + + # TODO need to decide what the right default behavior here is: + # - Use dpi=72 to match default InlineBackend figure size? + # - Accept a generic "scaling" somewhere and scale DPI from that, + # either with 1x -> 72 or 1x -> 96 and the default scaling be .75? + # - Listen to rcParams? InlineBackend behavior makes that so complicated :( + # - Do we ever want to *not* use retina mode at this point? + + from PIL import Image + + dpi = 96 + buffer = io.BytesIO() + self._figure.savefig(buffer, dpi=dpi * 2, format="png", bbox_inches="tight") + data = buffer.getvalue() + + scaling = .85 / 2 + # w, h = self._figure.get_size_inches() + w, h = Image.open(buffer).size + metadata = {"width": w * scaling, "height": h * scaling} + return data, metadata + + def _extract_data(self, p: Plot) -> tuple[PlotData, list[Layer]]: + + common_data = ( + p._data + .join(None, p._facet_spec.get("variables")) + .join(None, p._pair_spec.get("variables")) + ) + + layers: list[Layer] = [] + for layer in p._layers: + spec = layer.copy() + spec["data"] = common_data.join(layer.get("source"), layer.get("vars")) + layers.append(spec) + + return common_data, layers + + def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None: + + # --- Parsing the faceting/pairing parameterization to specify figure grid + + # TODO use context manager with theme that has been set + # TODO (maybe wrap THIS function with context manager; would be cleaner) + + subplot_spec = p._subplot_spec.copy() + facet_spec = p._facet_spec.copy() + pair_spec = p._pair_spec.copy() + + for dim in ["col", "row"]: + if dim in common.frame and dim not in facet_spec["structure"]: + order = categorical_order(common.frame[dim]) + facet_spec["structure"][dim] = order + + self._subplots = subplots = Subplots(subplot_spec, facet_spec, pair_spec) + + # --- Figure initialization + figure_kws = {"figsize": getattr(p, "_figsize", None)} # TODO fix + self._figure = subplots.init_figure( + pair_spec, self.pyplot, figure_kws, p._target, + ) + + # --- Figure annotation + for sub in subplots: + ax = sub["ax"] + for axis in "xy": + axis_key = sub[axis] + # TODO Should we make it possible to use only one x/y label for + # all rows/columns in a faceted plot? Maybe using sub{axis}label, + # although the alignments of the labels from that method leaves + # something to be desired (in terms of how it defines 'centered'). + names = [ + common.names.get(axis_key), + *(layer["data"].names.get(axis_key) for layer in layers) + ] + label = next((name for name in names if name is not None), None) + ax.set(**{f"{axis}label": label}) + + # TODO there should be some override (in Plot.configure?) so that + # tick labels can be shown on interior shared axes + axis_obj = getattr(ax, f"{axis}axis") + visible_side = {"x": "bottom", "y": "left"}.get(axis) + show_axis_label = ( + sub[visible_side] + or axis in p._pair_spec and bool(p._pair_spec.get("wrap")) + or not p._pair_spec.get("cross", True) + ) + axis_obj.get_label().set_visible(show_axis_label) + show_tick_labels = ( + show_axis_label + or subplot_spec.get(f"share{axis}") not in ( + True, "all", {"x": "col", "y": "row"}[axis] + ) + ) + for group in ("major", "minor"): + for t in getattr(axis_obj, f"get_{group}ticklabels")(): + t.set_visible(show_tick_labels) + + # TODO title template should be configurable + # ---- Also we want right-side titles for row facets in most cases? + # ---- Or wrapped? That can get annoying too. + # TODO should configure() accept a title= kwarg (for single subplot plots)? + # Let's have what we currently call "margin titles" but properly using the + # ax.set_title interface (see my gist) + title_parts = [] + for dim in ["row", "col"]: + if sub[dim] is not None: + name = common.names.get(dim) # TODO None = val looks bad + title_parts.append(f"{name} = {sub[dim]}") + + has_col = sub["col"] is not None + has_row = sub["row"] is not None + show_title = ( + has_col and has_row + or (has_col or has_row) and p._facet_spec.get("wrap") + or (has_col and sub["top"]) + # TODO or has_row and sub["right"] and + or has_row # TODO and not + ) + if title_parts: + title = " | ".join(title_parts) + title_text = ax.set_title(title) + title_text.set_visible(show_title) + + def _transform_coords(self, p: Plot, common: PlotData, layers: list[Layer]) -> None: + + for var in p._variables: + + # Parse name to identify variable (x, y, xmin, etc.) and axis (x/y) + # TODO should we have xmin0/xmin1 or x0min/x1min? + m = re.match(r"^(?P(?P[x|y])\d*).*", var) + + if m is None: + continue + + prefix = m["prefix"] + axis = m["axis"] + + share_state = self._subplots.subplot_spec[f"share{axis}"] + + # Concatenate layers, using only the relevant coordinate and faceting vars, + # This is unnecessarily wasteful, as layer data will often be redundant. + # But figuring out the minimal amount we need is more complicated. + cols = [var, "col", "row"] + # TODO basically copied from _setup_scales, and very clumsy + layer_values = [common.frame.filter(cols)] + for layer in layers: + if layer["data"].frame is None: + for df in layer["data"].frames.values(): + layer_values.append(df.filter(cols)) + else: + layer_values.append(layer["data"].frame.filter(cols)) + + if layer_values: + var_df = pd.concat(layer_values, ignore_index=True) + else: + var_df = pd.DataFrame(columns=cols) + + prop = Coordinate(axis) + scale_spec = self._get_scale(p, prefix, prop, var_df[var]) + + # Shared categorical axes are broken on matplotlib<3.4.0. + # https://github.com/matplotlib/matplotlib/pull/18308 + # This only affects us when sharing *paired* axes. This is a novel/niche + # behavior, so we will raise rather than hack together a workaround. + if Version(mpl.__version__) < Version("3.4.0"): + from seaborn._core.scales import Nominal + paired_axis = axis in p._pair_spec + cat_scale = isinstance(scale_spec, Nominal) + ok_dim = {"x": "col", "y": "row"}[axis] + shared_axes = share_state not in [False, "none", ok_dim] + if paired_axis and cat_scale and shared_axes: + err = "Sharing paired categorical axes requires matplotlib>=3.4.0" + raise RuntimeError(err) + + # Now loop through each subplot, deriving the relevant seed data to setup + # the scale (so that axis units / categories are initialized properly) + # And then scale the data in each layer. + subplots = [view for view in self._subplots if view[axis] == prefix] + + # Setup the scale on all of the data and plug it into self._scales + # We do this because by the time we do self._setup_scales, coordinate data + # will have been converted to floats already, so scale inference fails + self._scales[var] = scale_spec.setup(var_df[var], prop) + + # Set up an empty series to receive the transformed values. + # We need this to handle piecemeal tranforms of categories -> floats. + transformed_data = [] + for layer in layers: + index = layer["data"].frame.index + transformed_data.append(pd.Series(dtype=float, index=index, name=var)) + + for view in subplots: + axis_obj = getattr(view["ax"], f"{axis}axis") + + if share_state in [True, "all"]: + # The all-shared case is easiest, every subplot sees all the data + seed_values = var_df[var] + else: + # Otherwise, we need to setup separate scales for different subplots + if share_state in [False, "none"]: + # Fully independent axes are also easy: use each subplot's data + idx = self._get_subplot_index(var_df, view) + elif share_state in var_df: + # Sharing within row/col is more complicated + use_rows = var_df[share_state] == view[share_state] + idx = var_df.index[use_rows] + else: + # This configuration doesn't make much sense, but it's fine + idx = var_df.index + + seed_values = var_df.loc[idx, var] + + scale = scale_spec.setup(seed_values, prop, axis=axis_obj) + + for layer, new_series in zip(layers, transformed_data): + layer_df = layer["data"].frame + if var in layer_df: + idx = self._get_subplot_index(layer_df, view) + new_series.loc[idx] = scale(layer_df.loc[idx, var]) + + # TODO need decision about whether to do this or modify axis transform + set_scale_obj(view["ax"], axis, scale.matplotlib_scale) + + # Now the transformed data series are complete, set update the layer data + for layer, new_series in zip(layers, transformed_data): + layer_df = layer["data"].frame + if var in layer_df: + layer_df[var] = new_series + + def _compute_stats(self, spec: Plot, layers: list[Layer]) -> None: + + grouping_vars = [v for v in PROPERTIES if v not in "xy"] + grouping_vars += ["col", "row", "group"] + + pair_vars = spec._pair_spec.get("structure", {}) + + for layer in layers: + + data = layer["data"] + mark = layer["mark"] + stat = layer["stat"] + + if stat is None: + continue + + iter_axes = itertools.product(*[ + pair_vars.get(axis, [axis]) for axis in "xy" + ]) + + old = data.frame + + if pair_vars: + data.frames = {} + data.frame = data.frame.iloc[:0] # TODO to simplify typing + + for coord_vars in iter_axes: + + pairings = "xy", coord_vars + + df = old.copy() + scales = self._scales.copy() + + for axis, var in zip(*pairings): + if axis != var: + df = df.rename(columns={var: axis}) + drop_cols = [x for x in df if re.match(rf"{axis}\d+", x)] + df = df.drop(drop_cols, axis=1) + scales[axis] = scales[var] + + orient = layer["orient"] or mark._infer_orient(scales) + + if stat.group_by_orient: + grouper = [orient, *grouping_vars] + else: + grouper = grouping_vars + groupby = GroupBy(grouper) + res = stat(df, groupby, orient, scales) + + if pair_vars: + data.frames[coord_vars] = res + else: + data.frame = res + + def _get_scale( + self, spec: Plot, var: str, prop: Property, values: Series + ) -> ScaleSpec: + + if var in spec._scales: + arg = spec._scales[var] + if arg is None or isinstance(arg, ScaleSpec): + scale = arg + else: + scale = prop.infer_scale(arg, values) + else: + scale = prop.default_scale(values) + + return scale + + def _setup_scales(self, p: Plot, layers: list[Layer]) -> None: + + # Identify all of the variables that will be used at some point in the plot + variables = set() + for layer in layers: + if layer["data"].frame.empty and layer["data"].frames: + for df in layer["data"].frames.values(): + variables.update(df.columns) + else: + variables.update(layer["data"].frame.columns) + + for var in variables: + + if var in self._scales: + # Scales for coordinate variables added in _transform_coords + continue + + # Get the data all the distinct appearances of this variable. + parts = [] + for layer in layers: + if layer["data"].frame.empty and layer["data"].frames: + for df in layer["data"].frames.values(): + parts.append(df.get(var)) + else: + parts.append(layer["data"].frame.get(var)) + var_values = pd.concat( + parts, axis=0, join="inner", ignore_index=True + ).rename(var) + + # Determine whether this is an coordinate variable + # (i.e., x/y, paired x/y, or derivative such as xmax) + m = re.match(r"^(?P(?Px|y)\d*).*", var) + if m is None: + axis = None + else: + var = m["prefix"] + axis = m["axis"] + + prop = PROPERTIES.get(var if axis is None else axis, Property()) + scale = self._get_scale(p, var, prop, var_values) + + # Initialize the data-dependent parameters of the scale + # Note that this returns a copy and does not mutate the original + # This dictionary is used by the semantic mappings + if scale is None: + # TODO what is the cleanest way to implement identity scale? + # We don't really need a ScaleSpec, and Identity() will be + # overloaded anyway (but maybe a general Identity object + # that can be used as Scale/Mark/Stat/Move?) + # Note that this may not be the right spacer to use + # (but that is only relevant for coordinates where identity scale + # doesn't make sense or is poorly defined — should it mean "pixes"?) + self._scales[var] = Scale([], lambda x: x, None, "identity", None) + else: + self._scales[var] = scale.setup(var_values, prop) + + def _plot_layer(self, p: Plot, layer: Layer) -> None: + + data = layer["data"] + mark = layer["mark"] + move = layer["move"] + + default_grouping_vars = ["col", "row", "group"] # TODO where best to define? + grouping_properties = [v for v in PROPERTIES if v not in "xy"] + + pair_variables = p._pair_spec.get("structure", {}) + + for subplots, df, scales in self._generate_pairings(data, pair_variables): + + orient = layer["orient"] or mark._infer_orient(scales) + + def get_order(var): + # Ignore order for x/y: they have been scaled to numeric indices, + # so any original order is no longer valid. Default ordering rules + # sorted unique numbers will correctly reconstruct intended order + # TODO This is tricky, make sure we add some tests for this + if var not in "xy" and var in scales: + return scales[var].order + + if "width" in mark._mappable_props: + width = mark._resolve(df, "width", None) + else: + width = df.get("width", 0.8) # TODO what default + if orient in df: + df["width"] = width * scales[orient].spacing(df[orient]) + + if "baseline" in mark._mappable_props: + # TODO what marks should have this? + # If we can set baseline with, e.g., Bar(), then the + # "other" (e.g. y for x oriented bars) parameterization + # is somewhat ambiguous. + baseline = mark._resolve(df, "baseline", None) + else: + # TODO unlike width, we might not want to add baseline to data + # if the mark doesn't use it. Practically, there is a concern about + # Mark abstraction like Area / Ribbon + baseline = df.get("baseline", 0) + df["baseline"] = baseline + + if move is not None: + moves = move if isinstance(move, list) else [move] + for move in moves: + move_groupers = [ + orient, + *(getattr(move, "by", None) or grouping_properties), + *default_grouping_vars, + ] + order = {var: get_order(var) for var in move_groupers} + groupby = GroupBy(order) + df = move(df, groupby, orient) + + df = self._unscale_coords(subplots, df, orient) + + grouping_vars = mark._grouping_props + default_grouping_vars + split_generator = self._setup_split_generator( + grouping_vars, df, subplots + ) + + mark._plot(split_generator, scales, orient) + + # TODO is this the right place for this? + for view in self._subplots: + view["ax"].autoscale_view() + + self._update_legend_contents(mark, data, scales) + + def _scale_coords(self, subplots: list[dict], df: DataFrame) -> DataFrame: + # TODO stricter type on subplots + + coord_cols = [c for c in df if re.match(r"^[xy]\D*$", c)] + out_df = ( + df + .copy(deep=False) + .drop(coord_cols, axis=1) + .reindex(df.columns, axis=1) # So unscaled columns retain their place + ) + + for view in subplots: + view_df = self._filter_subplot_data(df, view) + axes_df = view_df[coord_cols] + with pd.option_context("mode.use_inf_as_null", True): + axes_df = axes_df.dropna() + for var, values in axes_df.items(): + scale = view[f"{var[0]}scale"] + out_df.loc[values.index, var] = scale(values) + + return out_df + + def _unscale_coords( + self, subplots: list[dict], df: DataFrame, orient: str, + ) -> DataFrame: + # TODO do we still have numbers in the variable name at this point? + coord_cols = [c for c in df if re.match(r"^[xy]\D*$", c)] + drop_cols = [*coord_cols, "width"] if "width" in df else coord_cols + out_df = ( + df + .drop(drop_cols, axis=1) + .reindex(df.columns, axis=1) # So unscaled columns retain their place + .copy(deep=False) + ) + + for view in subplots: + view_df = self._filter_subplot_data(df, view) + axes_df = view_df[coord_cols] + for var, values in axes_df.items(): + + axis = getattr(view["ax"], f"{var[0]}axis") + # TODO see https://github.com/matplotlib/matplotlib/issues/22713 + transform = axis.get_transform().inverted().transform + inverted = transform(values) + out_df.loc[values.index, var] = inverted + + if var == orient and "width" in view_df: + width = view_df["width"] + out_df.loc[values.index, "width"] = ( + transform(values + width / 2) - transform(values - width / 2) + ) + + return out_df + + def _generate_pairings( + self, data: PlotData, pair_variables: dict, + ) -> Generator[ + tuple[list[dict], DataFrame, dict[str, Scale]], None, None + ]: + # TODO retype return with subplot_spec or similar + + iter_axes = itertools.product(*[ + pair_variables.get(axis, [axis]) for axis in "xy" + ]) + + for x, y in iter_axes: + + subplots = [] + for view in self._subplots: + if (view["x"] == x) and (view["y"] == y): + subplots.append(view) + + if data.frame.empty and data.frames: + out_df = data.frames[(x, y)].copy() + elif not pair_variables: + out_df = data.frame.copy() + else: + if data.frame.empty and data.frames: + out_df = data.frames[(x, y)].copy() + else: + out_df = data.frame.copy() + + scales = self._scales.copy() + if x in out_df: + scales["x"] = self._scales[x] + if y in out_df: + scales["y"] = self._scales[y] + + for axis, var in zip("xy", (x, y)): + if axis != var: + out_df = out_df.rename(columns={var: axis}) + cols = [col for col in out_df if re.match(rf"{axis}\d+", col)] + out_df = out_df.drop(cols, axis=1) + + yield subplots, out_df, scales + + def _get_subplot_index(self, df: DataFrame, subplot: dict) -> DataFrame: + + dims = df.columns.intersection(["col", "row"]) + if dims.empty: + return df.index + + keep_rows = pd.Series(True, df.index, dtype=bool) + for dim in dims: + keep_rows &= df[dim] == subplot[dim] + return df.index[keep_rows] + + def _filter_subplot_data(self, df: DataFrame, subplot: dict) -> DataFrame: + # TODO note redundancies with preceding function ... needs refactoring + dims = df.columns.intersection(["col", "row"]) + if dims.empty: + return df + + keep_rows = pd.Series(True, df.index, dtype=bool) + for dim in dims: + keep_rows &= df[dim] == subplot[dim] + return df[keep_rows] + + def _setup_split_generator( + self, grouping_vars: list[str], df: DataFrame, subplots: list[dict[str, Any]], + ) -> Callable[[], Generator]: + + allow_empty = False # TODO will need to recreate previous categorical plots + + grouping_keys = [] + grouping_vars = [ + v for v in grouping_vars if v in df and v not in ["col", "row"] + ] + for var in grouping_vars: + order = self._scales[var].order + if order is None: + order = categorical_order(df[var]) + grouping_keys.append(order) + + def split_generator() -> Generator: + + for view in subplots: + + axes_df = self._filter_subplot_data(df, view) + + subplot_keys = {} + for dim in ["col", "row"]: + if view[dim] is not None: + subplot_keys[dim] = view[dim] + + if not grouping_vars or not any(grouping_keys): + yield subplot_keys, axes_df.copy(), view["ax"] + continue + + grouped_df = axes_df.groupby(grouping_vars, sort=False, as_index=False) + + for key in itertools.product(*grouping_keys): + + # Pandas fails with singleton tuple inputs + pd_key = key[0] if len(key) == 1 else key + + try: + df_subset = grouped_df.get_group(pd_key) + except KeyError: + # TODO (from initial work on categorical plots refactor) + # We are adding this to allow backwards compatability + # with the empty artists that old categorical plots would + # add (before 0.12), which we may decide to break, in which + # case this option could be removed + df_subset = axes_df.loc[[]] + + if df_subset.empty and not allow_empty: + continue + + sub_vars = dict(zip(grouping_vars, key)) + sub_vars.update(subplot_keys) + + # TODO need copy(deep=...) policy (here, above, anywhere else?) + yield sub_vars, df_subset.copy(), view["ax"] + + return split_generator + + def _update_legend_contents( + self, mark: Mark, data: PlotData, scales: dict[str, Scale] + ) -> None: + """Add legend artists / labels for one layer in the plot.""" + if data.frame.empty and data.frames: + legend_vars = set() + for frame in data.frames.values(): + legend_vars.update(frame.columns.intersection(scales)) + else: + legend_vars = data.frame.columns.intersection(scales) + + # First pass: Identify the values that will be shown for each variable + schema: list[tuple[ + tuple[str | None, str | int], list[str], tuple[list, list[str]] + ]] = [] + schema = [] + for var in legend_vars: + var_legend = scales[var].legend + if var_legend is not None: + values, labels = var_legend + for (_, part_id), part_vars, _ in schema: + if data.ids[var] == part_id: + # Allow multiple plot semantics to represent same data variable + part_vars.append(var) + break + else: + entry = (data.names[var], data.ids[var]), [var], (values, labels) + schema.append(entry) + + # Second pass, generate an artist corresponding to each value + contents = [] + for key, variables, (values, labels) in schema: + artists = [] + for val in values: + artists.append(mark._legend_artist(variables, val, scales)) + contents.append((key, artists, labels)) + + self._legend_contents.extend(contents) + + def _make_legend(self) -> None: + """Create the legend artist(s) and add onto the figure.""" + # Combine artists representing same information across layers + # Input list has an entry for each distinct variable in each layer + # Output dict has an entry for each distinct variable + merged_contents: dict[ + tuple[str | None, str | int], tuple[list[Artist], list[str]], + ] = {} + for key, artists, labels in self._legend_contents: + # Key is (name, id); we need the id to resolve variable uniqueness, + # but will need the name in the next step to title the legend + if key in merged_contents: + # Copy so inplace updates don't propagate back to legend_contents + existing_artists = merged_contents[key][0] + for i, artist in enumerate(existing_artists): + # Matplotlib accepts a tuple of artists and will overlay them + if isinstance(artist, tuple): + artist += artist[i], + else: + existing_artists[i] = artist, artists[i] + else: + merged_contents[key] = artists.copy(), labels + + base_legend = None + for (name, _), (handles, labels) in merged_contents.items(): + + legend = mpl.legend.Legend( + self._figure, + handles, + labels, + title=name, # TODO don't show "None" as title + loc="center left", + bbox_to_anchor=(.98, .55), + ) + + # TODO: This is an illegal hack accessing private attributes on the legend + # We need to sort out how we are going to handle this given that lack of a + # proper API to do things like position legends relative to each other + if base_legend: + base_legend._legend_box._children.extend(legend._legend_box._children) + else: + base_legend = legend + self._figure.legends.append(legend) diff --git a/seaborn/_core/properties.py b/seaborn/_core/properties.py new file mode 100644 index 0000000000..68836cd045 --- /dev/null +++ b/seaborn/_core/properties.py @@ -0,0 +1,761 @@ +from __future__ import annotations +import itertools +import warnings + +import numpy as np +from pandas import Series +import matplotlib as mpl +from matplotlib.colors import to_rgb, to_rgba, to_rgba_array +from matplotlib.path import Path + +from seaborn._core.scales import ScaleSpec, Nominal, Continuous, Temporal +from seaborn._core.rules import categorical_order, variable_type +from seaborn._compat import MarkerStyle +from seaborn.palettes import QUAL_PALETTES, color_palette, blend_palette +from seaborn.utils import get_color_cycle + +from typing import Any, Callable, Tuple, List, Union, Optional + +try: + from numpy.typing import ArrayLike +except ImportError: + # numpy<1.20.0 (Jan 2021) + ArrayLike = Any + +RGBTuple = Tuple[float, float, float] +RGBATuple = Tuple[float, float, float, float] +ColorSpec = Union[RGBTuple, RGBATuple, str] + +DashPattern = Tuple[float, ...] +DashPatternWithOffset = Tuple[float, Optional[DashPattern]] + +MarkerPattern = Union[ + float, + str, + Tuple[int, int, float], + List[Tuple[float, float]], + Path, + MarkerStyle, +] + + +# =================================================================================== # +# Base classes +# =================================================================================== # + + +class Property: + """Base class for visual properties that can be set directly or be data scaling.""" + + # When True, scales for this property will populate the legend by default + legend = False + + # When True, scales for this property normalize data to [0, 1] before mapping + normed = False + + def __init__(self, variable: str | None = None): + """Initialize the property with the name of the corresponding plot variable.""" + if not variable: + variable = self.__class__.__name__.lower() + self.variable = variable + + def default_scale(self, data: Series) -> ScaleSpec: + """Given data, initialize appropriate scale class.""" + # TODO allow variable_type to be "boolean" if that's a scale? + # TODO how will this handle data with units that can be treated as numeric + # if passed through a registered matplotlib converter? + var_type = variable_type(data, boolean_type="numeric") + if var_type == "numeric": + return Continuous() + elif var_type == "datetime": + return Temporal() + # TODO others + # time-based (TimeStamp, TimeDelta, Period) + # boolean scale? + else: + return Nominal() + + def infer_scale(self, arg: Any, data: Series) -> ScaleSpec: + """Given data and a scaling argument, initialize appropriate scale class.""" + # TODO put these somewhere external for validation + # TODO putting this here won't pick it up if subclasses define infer_scale + # (e.g. color). How best to handle that? One option is to call super after + # handling property-specific possibilities (e.g. for color check that the + # arg is not a valid palette name) but that could get tricky. + trans_args = ["log", "symlog", "logit", "pow", "sqrt"] + if isinstance(arg, str): + if any(arg.startswith(k) for k in trans_args): + # TODO validate numeric type? That should happen centrally somewhere + return Continuous(transform=arg) + else: + msg = f"Unknown magic arg for {self.variable} scale: '{arg}'." + raise ValueError(msg) + else: + arg_type = type(arg).__name__ + msg = f"Magic arg for {self.variable} scale must be str, not {arg_type}." + raise TypeError(msg) + + def get_mapping( + self, scale: ScaleSpec, data: Series + ) -> Callable[[ArrayLike], ArrayLike]: + """Return a function that maps from data domain to property range.""" + def identity(x): + return x + return identity + + def standardize(self, val: Any) -> Any: + """Coerce flexible property value to standardized representation.""" + return val + + def _check_dict_entries(self, levels: list, values: dict) -> None: + """Input check when values are provided as a dictionary.""" + missing = set(levels) - set(values) + if missing: + formatted = ", ".join(map(repr, sorted(missing, key=str))) + err = f"No entry in {self.variable} dictionary for {formatted}" + raise ValueError(err) + + def _check_list_length(self, levels: list, values: list) -> list: + """Input check when values are provided as a list.""" + message = "" + if len(levels) > len(values): + message = " ".join([ + f"\nThe {self.variable} list has fewer values ({len(values)})", + f"than needed ({len(levels)}) and will cycle, which may", + "produce an uninterpretable plot." + ]) + values = [x for _, x in zip(levels, itertools.cycle(values))] + + elif len(values) > len(levels): + message = " ".join([ + f"The {self.variable} list has more values ({len(values)})", + f"than needed ({len(levels)}), which may not be intended.", + ]) + values = values[:len(levels)] + + # TODO look into custom PlotSpecWarning with better formatting + if message: + warnings.warn(message, UserWarning) + + return values + + +# =================================================================================== # +# Properties relating to spatial position of marks on the plotting axes +# =================================================================================== # + + +class Coordinate(Property): + """The position of visual marks with respect to the axes of the plot.""" + legend = False + normed = False + + +# =================================================================================== # +# Properties with numeric values where scale range can be defined as an interval +# =================================================================================== # + + +class IntervalProperty(Property): + """A numeric property where scale range can be defined as an interval.""" + legend = True + normed = True + + _default_range: tuple[float, float] = (0, 1) + + @property + def default_range(self) -> tuple[float, float]: + """Min and max values used by default for semantic mapping.""" + return self._default_range + + def _forward(self, values: ArrayLike) -> ArrayLike: + """Transform applied to native values before linear mapping into interval.""" + return values + + def _inverse(self, values: ArrayLike) -> ArrayLike: + """Transform applied to results of mapping that returns to native values.""" + return values + + def infer_scale(self, arg: Any, data: Series) -> ScaleSpec: + """Given data and a scaling argument, initialize appropriate scale class.""" + + # TODO infer continuous based on log/sqrt etc? + + if isinstance(arg, (list, dict)): + return Nominal(arg) + elif variable_type(data) == "categorical": + return Nominal(arg) + elif variable_type(data) == "datetime": + return Temporal(arg) + # TODO other variable types + else: + return Continuous(arg) + + def get_mapping( + self, scale: ScaleSpec, data: ArrayLike + ) -> Callable[[ArrayLike], ArrayLike]: + """Return a function that maps from data domain to property range.""" + if isinstance(scale, Nominal): + return self._get_categorical_mapping(scale, data) + + if scale.values is None: + vmin, vmax = self._forward(self.default_range) + elif isinstance(scale.values, tuple) and len(scale.values) == 2: + vmin, vmax = self._forward(scale.values) + else: + if isinstance(scale.values, tuple): + actual = f"{len(scale.values)}-tuple" + else: + actual = str(type(scale.values)) + scale_class = scale.__class__.__name__ + err = " ".join([ + f"Values for {self.variable} variables with {scale_class} scale", + f"must be 2-tuple; not {actual}.", + ]) + raise TypeError(err) + + def mapping(x): + return self._inverse(np.multiply(x, vmax - vmin) + vmin) + + return mapping + + def _get_categorical_mapping( + self, scale: Nominal, data: ArrayLike + ) -> Callable[[ArrayLike], ArrayLike]: + """Identify evenly-spaced values using interval or explicit mapping.""" + levels = categorical_order(data, scale.order) + + if isinstance(scale.values, dict): + self._check_dict_entries(levels, scale.values) + values = [scale.values[x] for x in levels] + elif isinstance(scale.values, list): + values = self._check_list_length(levels, scale.values) + else: + if scale.values is None: + vmin, vmax = self.default_range + elif isinstance(scale.values, tuple): + vmin, vmax = scale.values + else: + scale_class = scale.__class__.__name__ + err = " ".join([ + f"Values for {self.variable} variables with {scale_class} scale", + f"must be a dict, list or tuple; not {type(scale.values)}", + ]) + raise TypeError(err) + + vmin, vmax = self._forward([vmin, vmax]) + values = self._inverse(np.linspace(vmax, vmin, len(levels))) + + def mapping(x): + ixs = np.asarray(x, np.intp) + out = np.full(len(x), np.nan) + use = np.isfinite(x) + out[use] = np.take(values, ixs[use]) + return out + + return mapping + + +class PointSize(IntervalProperty): + """Size (diameter) of a point mark, in points, with scaling by area.""" + _default_range = 2, 8 # TODO use rcparams? + # TODO N.B. both Scatter and Dot use this but have different expected sizes + # Is that something we need to handle? Or assume Dot size rarely scaled? + # Also will Line marks have a PointSize property? + + def _forward(self, values): + """Square native values to implement linear scaling of point area.""" + return np.square(values) + + def _inverse(self, values): + """Invert areal values back to point diameter.""" + return np.sqrt(values) + + +class LineWidth(IntervalProperty): + """Thickness of a line mark, in points.""" + @property + def default_range(self) -> tuple[float, float]: + """Min and max values used by default for semantic mapping.""" + base = mpl.rcParams["lines.linewidth"] + return base * .5, base * 2 + + +class EdgeWidth(IntervalProperty): + """Thickness of the edges on a patch mark, in points.""" + @property + def default_range(self) -> tuple[float, float]: + """Min and max values used by default for semantic mapping.""" + base = mpl.rcParams["patch.linewidth"] + return base * .5, base * 2 + + +class Stroke(IntervalProperty): + """Thickness of lines that define point glyphs.""" + _default_range = .25, 2.5 + + +class Alpha(IntervalProperty): + """Opacity of the color values for an arbitrary mark.""" + _default_range = .3, .95 + # TODO validate / enforce that output is in [0, 1] + + +# =================================================================================== # +# Properties defined by arbitrary objects with inherently nominal scaling +# =================================================================================== # + + +class ObjectProperty(Property): + """A property defined by arbitrary an object, with inherently nominal scaling.""" + legend = True + normed = False + + # Object representing null data, should appear invisible when drawn by matplotlib + null_value: Any = None + + def _default_values(self, n: int) -> list: + raise NotImplementedError() + + def default_scale(self, data: Series) -> Nominal: + return Nominal() + + def infer_scale(self, arg: Any, data: Series) -> Nominal: + return Nominal(arg) + + def get_mapping( + self, scale: ScaleSpec, data: Series, + ) -> Callable[[ArrayLike], list]: + """Define mapping as lookup into list of object values.""" + order = getattr(scale, "order", None) + levels = categorical_order(data, order) + n = len(levels) + + if isinstance(scale.values, dict): + self._check_dict_entries(levels, scale.values) + values = [scale.values[x] for x in levels] + elif isinstance(scale.values, list): + values = self._check_list_length(levels, scale.values) + elif scale.values is None: + values = self._default_values(n) + else: + msg = " ".join([ + f"Scale values for a {self.variable} variable must be provided", + f"in a dict or list; not {type(scale.values)}." + ]) + raise TypeError(msg) + + values = [self.standardize(x) for x in values] + + def mapping(x): + ixs = np.asarray(x, np.intp) + return [ + values[ix] if np.isfinite(x_i) else self.null_value + for x_i, ix in zip(x, ixs) + ] + + return mapping + + +class Marker(ObjectProperty): + """Shape of points in scatter-type marks or lines with data points marked.""" + null_value = MarkerStyle("") + + # TODO should we have named marker "palettes"? (e.g. see d3 options) + + # TODO need some sort of "require_scale" functionality + # to raise when we get the wrong kind explicitly specified + + def standardize(self, val: MarkerPattern) -> MarkerStyle: + return MarkerStyle(val) + + def _default_values(self, n: int) -> list[MarkerStyle]: + """Build an arbitrarily long list of unique marker styles. + + Parameters + ---------- + n : int + Number of unique marker specs to generate. + + Returns + ------- + markers : list of string or tuples + Values for defining :class:`matplotlib.markers.MarkerStyle` objects. + All markers will be filled. + + """ + # Start with marker specs that are well distinguishable + markers = [ + "o", "X", (4, 0, 45), "P", (4, 0, 0), (4, 1, 0), "^", (4, 1, 45), "v", + ] + + # Now generate more from regular polygons of increasing order + s = 5 + while len(markers) < n: + a = 360 / (s + 1) / 2 + markers.extend([(s + 1, 1, a), (s + 1, 0, a), (s, 1, 0), (s, 0, 0)]) + s += 1 + + markers = [MarkerStyle(m) for m in markers[:n]] + + return markers + + +class LineStyle(ObjectProperty): + """Dash pattern for line-type marks.""" + null_value = "" + + def standardize(self, val: str | DashPattern) -> DashPatternWithOffset: + return self._get_dash_pattern(val) + + def _default_values(self, n: int) -> list[DashPatternWithOffset]: + """Build an arbitrarily long list of unique dash styles for lines. + + Parameters + ---------- + n : int + Number of unique dash specs to generate. + + Returns + ------- + dashes : list of strings or tuples + Valid arguments for the ``dashes`` parameter on + :class:`matplotlib.lines.Line2D`. The first spec is a solid + line (``""``), the remainder are sequences of long and short + dashes. + + """ + # Start with dash specs that are well distinguishable + dashes: list[str | DashPattern] = [ + "-", (4, 1.5), (1, 1), (3, 1.25, 1.5, 1.25), (5, 1, 1, 1), + ] + + # Now programmatically build as many as we need + p = 3 + while len(dashes) < n: + + # Take combinations of long and short dashes + a = itertools.combinations_with_replacement([3, 1.25], p) + b = itertools.combinations_with_replacement([4, 1], p) + + # Interleave the combinations, reversing one of the streams + segment_list = itertools.chain(*zip(list(a)[1:-1][::-1], list(b)[1:-1])) + + # Now insert the gaps + for segments in segment_list: + gap = min(segments) + spec = tuple(itertools.chain(*((seg, gap) for seg in segments))) + dashes.append(spec) + + p += 1 + + return [self._get_dash_pattern(x) for x in dashes] + + @staticmethod + def _get_dash_pattern(style: str | DashPattern) -> DashPatternWithOffset: + """Convert linestyle arguments to dash pattern with offset.""" + # Copied and modified from Matplotlib 3.4 + # go from short hand -> full strings + ls_mapper = {"-": "solid", "--": "dashed", "-.": "dashdot", ":": "dotted"} + if isinstance(style, str): + style = ls_mapper.get(style, style) + # un-dashed styles + if style in ["solid", "none", "None"]: + offset = 0 + dashes = None + # dashed styles + elif style in ["dashed", "dashdot", "dotted"]: + offset = 0 + dashes = tuple(mpl.rcParams[f"lines.{style}_pattern"]) + else: + options = [*ls_mapper.values(), *ls_mapper.keys()] + msg = f"Linestyle string must be one of {options}, not {repr(style)}." + raise ValueError(msg) + + elif isinstance(style, tuple): + if len(style) > 1 and isinstance(style[1], tuple): + offset, dashes = style + elif len(style) > 1 and style[1] is None: + offset, dashes = style + else: + offset = 0 + dashes = style + else: + val_type = type(style).__name__ + msg = f"Linestyle must be str or tuple, not {val_type}." + raise TypeError(msg) + + # Normalize offset to be positive and shorter than the dash cycle + if dashes is not None: + try: + dsum = sum(dashes) + except TypeError as err: + msg = f"Invalid dash pattern: {dashes}" + raise TypeError(msg) from err + if dsum: + offset %= dsum + + return offset, dashes + + +# =================================================================================== # +# Properties with RGB(A) color values +# =================================================================================== # + + +class Color(Property): + """Color, as RGB(A), scalable with nominal palettes or continuous gradients.""" + legend = True + normed = True + + def standardize(self, val: ColorSpec) -> RGBTuple | RGBATuple: + # Return color with alpha channel only if the input spec has it + # This is so that RGBA colors can override the Alpha property + if to_rgba(val) != to_rgba(val, 1): + return to_rgba(val) + else: + return to_rgb(val) + + def _standardize_color_sequence(self, colors: ArrayLike) -> ArrayLike: + """Convert color sequence to RGB(A) array, preserving but not adding alpha.""" + def has_alpha(x): + return to_rgba(x) != to_rgba(x, 1) + + if isinstance(colors, np.ndarray): + needs_alpha = colors.shape[1] == 4 + else: + needs_alpha = any(has_alpha(x) for x in colors) + + if needs_alpha: + return to_rgba_array(colors) + else: + return to_rgba_array(colors)[:, :3] + + def infer_scale(self, arg: Any, data: Series) -> ScaleSpec: + # TODO when inferring Continuous without data, verify type + + # TODO need to rethink the variable type system + # (e.g. boolean, ordered categories as Ordinal, etc).. + var_type = variable_type(data, boolean_type="categorical") + + if isinstance(arg, (dict, list)): + return Nominal(arg) + + if isinstance(arg, tuple): + if var_type == "categorical": + # TODO It seems reasonable to allow a gradient mapping for nominal + # scale but it also feels "technically" wrong. Should this infer + # Ordinal with categorical data and, if so, verify orderedness? + return Nominal(arg) + return Continuous(arg) + + if callable(arg): + return Continuous(arg) + + # TODO Do we accept str like "log", "pow", etc. for semantics? + + # TODO what about + # - Temporal? (i.e. datetime) + # - Boolean? + + if not isinstance(arg, str): + msg = " ".join([ + f"A single scale argument for {self.variable} variables must be", + f"a string, dict, tuple, list, or callable, not {type(arg)}." + ]) + raise TypeError(msg) + + if arg in QUAL_PALETTES: + return Nominal(arg) + elif var_type == "numeric": + return Continuous(arg) + # TODO implement scales for date variables and any others. + else: + return Nominal(arg) + + def _get_categorical_mapping(self, scale, data): + """Define mapping as lookup in list of discrete color values.""" + levels = categorical_order(data, scale.order) + n = len(levels) + values = scale.values + + if isinstance(values, dict): + self._check_dict_entries(levels, values) + # TODO where to ensure that dict values have consistent representation? + colors = [values[x] for x in levels] + elif isinstance(values, list): + colors = self._check_list_length(levels, scale.values) + elif isinstance(values, tuple): + colors = blend_palette(values, n) + elif isinstance(values, str): + colors = color_palette(values, n) + elif values is None: + if n <= len(get_color_cycle()): + # Use current (global) default palette + colors = color_palette(n_colors=n) + else: + colors = color_palette("husl", n) + else: + scale_class = scale.__class__.__name__ + msg = " ".join([ + f"Scale values for {self.variable} with a {scale_class} mapping", + f"must be string, list, tuple, or dict; not {type(scale.values)}." + ]) + raise TypeError(msg) + + # If color specified here has alpha channel, it will override alpha property + colors = self._standardize_color_sequence(colors) + + def mapping(x): + ixs = np.asarray(x, np.intp) + use = np.isfinite(x) + out = np.full((len(ixs), colors.shape[1]), np.nan) + out[use] = np.take(colors, ixs[use], axis=0) + return out + + return mapping + + def get_mapping( + self, scale: ScaleSpec, data: Series + ) -> Callable[[ArrayLike], ArrayLike]: + """Return a function that maps from data domain to color values.""" + # TODO what is best way to do this conditional? + # Should it be class-based or should classes have behavioral attributes? + if isinstance(scale, Nominal): + return self._get_categorical_mapping(scale, data) + + if scale.values is None: + # TODO Rethink best default continuous color gradient + mapping = color_palette("ch:", as_cmap=True) + elif isinstance(scale.values, tuple): + # TODO blend_palette will strip alpha, but we should support + # interpolation on all four channels + mapping = blend_palette(scale.values, as_cmap=True) + elif isinstance(scale.values, str): + # TODO for matplotlib colormaps this will clip extremes, which is + # different from what using the named colormap directly would do + # This may or may not be desireable. + mapping = color_palette(scale.values, as_cmap=True) + elif callable(scale.values): + mapping = scale.values + else: + scale_class = scale.__class__.__name__ + msg = " ".join([ + f"Scale values for {self.variable} with a {scale_class} mapping", + f"must be string, tuple, or callable; not {type(scale.values)}." + ]) + raise TypeError(msg) + + def _mapping(x): + # Remove alpha channel so it does not override alpha property downstream + # TODO this will need to be more flexible to support RGBA tuples (see above) + invalid = ~np.isfinite(x) + out = mapping(x)[:, :3] + out[invalid] = np.nan + return out + + return _mapping + + +# =================================================================================== # +# Properties that can take only two states +# =================================================================================== # + + +class Fill(Property): + """Boolean property of points/bars/patches that can be solid or outlined.""" + legend = True + normed = False + + # TODO default to Nominal scale always? + # Actually this will just not work with Continuous (except 0/1), suggesting we need + # an abstraction for failing gracefully on bad Property <> Scale interactions + + def standardize(self, val: Any) -> bool: + return bool(val) + + def _default_values(self, n: int) -> list: + """Return a list of n values, alternating True and False.""" + if n > 2: + msg = " ".join([ + f"The variable assigned to {self.variable} has more than two levels,", + f"so {self.variable} values will cycle and may be uninterpretable", + ]) + # TODO fire in a "nice" way (see above) + warnings.warn(msg, UserWarning) + return [x for x, _ in zip(itertools.cycle([True, False]), range(n))] + + def default_scale(self, data: Series) -> Nominal: + """Given data, initialize appropriate scale class.""" + return Nominal() + + def infer_scale(self, arg: Any, data: Series) -> ScaleSpec: + """Given data and a scaling argument, initialize appropriate scale class.""" + # TODO infer Boolean where possible? + return Nominal(arg) + + def get_mapping( + self, scale: ScaleSpec, data: Series + ) -> Callable[[ArrayLike], ArrayLike]: + """Return a function that maps each data value to True or False.""" + # TODO categorical_order is going to return [False, True] for booleans, + # and [0, 1] for binary, but the default values order is [True, False]. + # We should special case this to handle it properly, or change + # categorical_order to not "sort" booleans. Note that we need to sync with + # what's going to happen upstream in the scale, so we can't just do it here. + order = getattr(scale, "order", None) + levels = categorical_order(data, order) + + if isinstance(scale.values, list): + values = [bool(x) for x in scale.values] + elif isinstance(scale.values, dict): + values = [bool(scale.values[x]) for x in levels] + elif scale.values is None: + values = self._default_values(len(levels)) + else: + msg = " ".join([ + f"Scale values for {self.variable} must be passed in", + f"a list or dict; not {type(scale.values)}." + ]) + raise TypeError(msg) + + def mapping(x): + return np.take(values, np.asarray(x, np.intp)) + + return mapping + + +# =================================================================================== # +# Enumeration of properties for use by Plot and Mark classes +# =================================================================================== # +# TODO turn this into a property registry with hooks, etc. +# TODO Users do not interact directly with properties, so how to document them? + + +PROPERTY_CLASSES = { + "x": Coordinate, + "y": Coordinate, + "color": Color, + "alpha": Alpha, + "fill": Fill, + "marker": Marker, + "pointsize": PointSize, + "stroke": Stroke, + "linewidth": LineWidth, + "linestyle": LineStyle, + "fillcolor": Color, + "fillalpha": Alpha, + "edgewidth": EdgeWidth, + "edgestyle": LineStyle, + "edgecolor": Color, + "edgealpha": Alpha, + "xmin": Coordinate, + "xmax": Coordinate, + "ymin": Coordinate, + "ymax": Coordinate, + "group": Property, + # TODO pattern? + # TODO gradient? +} + +PROPERTIES = {var: cls(var) for var, cls in PROPERTY_CLASSES.items()} diff --git a/seaborn/_core/rules.py b/seaborn/_core/rules.py new file mode 100644 index 0000000000..d378fb2dc2 --- /dev/null +++ b/seaborn/_core/rules.py @@ -0,0 +1,153 @@ +from __future__ import annotations + +import warnings +from collections import UserString +from numbers import Number +from datetime import datetime + +import numpy as np +import pandas as pd + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import Literal + from pandas import Series + + +class VarType(UserString): + """ + Prevent comparisons elsewhere in the library from using the wrong name. + + Errors are simple assertions because users should not be able to trigger + them. If that changes, they should be more verbose. + + """ + # TODO VarType is an awfully overloaded name, but so is DataType ... + # TODO adding unknown because we are using this in for scales, is that right? + allowed = "numeric", "datetime", "categorical", "unknown" + + def __init__(self, data): + assert data in self.allowed, data + super().__init__(data) + + def __eq__(self, other): + assert other in self.allowed, other + return self.data == other + + +def variable_type( + vector: Series, + boolean_type: Literal["numeric", "categorical"] = "numeric", +) -> VarType: + """ + Determine whether a vector contains numeric, categorical, or datetime data. + + This function differs from the pandas typing API in two ways: + + - Python sequences or object-typed PyData objects are considered numeric if + all of their entries are numeric. + - String or mixed-type data are considered categorical even if not + explicitly represented as a :class:`pandas.api.types.CategoricalDtype`. + + Parameters + ---------- + vector : :func:`pandas.Series`, :func:`numpy.ndarray`, or Python sequence + Input data to test. + boolean_type : 'numeric' or 'categorical' + Type to use for vectors containing only 0s and 1s (and NAs). + + Returns + ------- + var_type : 'numeric', 'categorical', or 'datetime' + Name identifying the type of data in the vector. + """ + + # If a categorical dtype is set, infer categorical + if pd.api.types.is_categorical_dtype(vector): + return VarType("categorical") + + # Special-case all-na data, which is always "numeric" + if pd.isna(vector).all(): + return VarType("numeric") + + # Special-case binary/boolean data, allow caller to determine + # This triggers a numpy warning when vector has strings/objects + # https://github.com/numpy/numpy/issues/6784 + # Because we reduce with .all(), we are agnostic about whether the + # comparison returns a scalar or vector, so we will ignore the warning. + # It triggers a separate DeprecationWarning when the vector has datetimes: + # https://github.com/numpy/numpy/issues/13548 + # This is considered a bug by numpy and will likely go away. + with warnings.catch_warnings(): + warnings.simplefilter( + action='ignore', + category=(FutureWarning, DeprecationWarning) # type: ignore # mypy bug? + ) + if np.isin(vector, [0, 1, np.nan]).all(): + return VarType(boolean_type) + + # Defer to positive pandas tests + if pd.api.types.is_numeric_dtype(vector): + return VarType("numeric") + + if pd.api.types.is_datetime64_dtype(vector): + return VarType("datetime") + + # --- If we get to here, we need to check the entries + + # Check for a collection where everything is a number + + def all_numeric(x): + for x_i in x: + if not isinstance(x_i, Number): + return False + return True + + if all_numeric(vector): + return VarType("numeric") + + # Check for a collection where everything is a datetime + + def all_datetime(x): + for x_i in x: + if not isinstance(x_i, (datetime, np.datetime64)): + return False + return True + + if all_datetime(vector): + return VarType("datetime") + + # Otherwise, our final fallback is to consider things categorical + + return VarType("categorical") + + +def categorical_order(vector: Series, order: list | None = None) -> list: + """ + Return a list of unique data values using seaborn's ordering rules. + + Parameters + ---------- + vector : Series + Vector of "categorical" values + order : list + Desired order of category levels to override the order determined + from the `data` object. + + Returns + ------- + order : list + Ordered list of category levels not including null values. + + """ + if order is not None: + return order + + if vector.dtype.name == "category": + order = list(vector.cat.categories) + else: + order = list(filter(pd.notnull, vector.unique())) + if variable_type(order) == "numeric": + order.sort() + + return order diff --git a/seaborn/_core/scales.py b/seaborn/_core/scales.py new file mode 100644 index 0000000000..eaf4604af8 --- /dev/null +++ b/seaborn/_core/scales.py @@ -0,0 +1,744 @@ +from __future__ import annotations +import re +from copy import copy +from dataclasses import dataclass +from functools import partial + +import numpy as np +import pandas as pd +import matplotlib as mpl +from matplotlib.ticker import ( + Locator, + Formatter, + AutoLocator, + AutoMinorLocator, + FixedLocator, + LinearLocator, + LogLocator, + MaxNLocator, + MultipleLocator, + ScalarFormatter, +) +from matplotlib.dates import ( + AutoDateLocator, + AutoDateFormatter, + ConciseDateFormatter, +) +from matplotlib.axis import Axis + +from seaborn._core.rules import categorical_order + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import Any, Callable, Tuple, Optional, Union + from collections.abc import Sequence + from matplotlib.scale import ScaleBase as MatplotlibScale + from pandas import Series + from numpy.typing import ArrayLike + from seaborn._core.properties import Property + + Transforms = Tuple[ + Callable[[ArrayLike], ArrayLike], Callable[[ArrayLike], ArrayLike] + ] + + # TODO standardize String / ArrayLike interface + Pipeline = Sequence[Optional[Callable[[Union[Series, ArrayLike]], ArrayLike]]] + + +class Scale: + + def __init__( + self, + forward_pipe: Pipeline, + spacer: Callable[[Series], float], + legend: tuple[list[Any], list[str]] | None, + scale_type: str, + matplotlib_scale: MatplotlibScale, + ): + + self.forward_pipe = forward_pipe + self.spacer = spacer + self.legend = legend + self.scale_type = scale_type + self.matplotlib_scale = matplotlib_scale + + # TODO need to make this work + self.order = None + + def __call__(self, data: Series) -> ArrayLike: + + return self._apply_pipeline(data, self.forward_pipe) + + # TODO def as_identity(cls): ? + + def _apply_pipeline( + self, data: ArrayLike, pipeline: Pipeline, + ) -> ArrayLike: + + # TODO sometimes we need to handle scalars (e.g. for Line) + # but what is the best way to do that? + scalar_data = np.isscalar(data) + if scalar_data: + data = np.array([data]) + + for func in pipeline: + if func is not None: + data = func(data) + + if scalar_data: + data = data[0] + + return data + + def spacing(self, data: Series) -> float: + return self.spacer(data) + + def invert_axis_transform(self, x): + # TODO we may no longer need this method as we use the axis + # transform directly in Plotter._unscale_coords + finv = self.matplotlib_scale.get_transform().inverted().transform + out = finv(x) + if isinstance(x, pd.Series): + return pd.Series(out, index=x.index, name=x.name) + return out + + +@dataclass +class ScaleSpec: + + values: tuple | str | list | dict | None = None + + ... + # TODO have Scale define width (/height?) ('space'?) (using data?), so e.g. nominal + # scale sets width=1, continuous scale sets width min(diff(unique(data))), etc. + + def __post_init__(self): + + # TODO do we need anything else here? + self.tick() + self.format() + + def tick(self): + # TODO what is the right base method? + self._major_locator: Locator + self._minor_locator: Locator + return self + + def format(self): + self._major_formatter: Formatter + return self + + def setup( + self, data: Series, prop: Property, axis: Axis | None = None, + ) -> Scale: + ... + + # TODO typing + def _get_scale(self, name, forward, inverse): + + major_locator = self._major_locator + minor_locator = self._minor_locator + + # TODO hack, need to add default to Continuous + major_formatter = getattr(self, "_major_formatter", ScalarFormatter()) + # major_formatter = self._major_formatter + + class Scale(mpl.scale.FuncScale): + def set_default_locators_and_formatters(self, axis): + axis.set_major_locator(major_locator) + if minor_locator is not None: + axis.set_minor_locator(minor_locator) + axis.set_major_formatter(major_formatter) + + return Scale(name, (forward, inverse)) + + +@dataclass +class Nominal(ScaleSpec): + """ + A categorical scale without relative importance / magnitude. + """ + # Categorical (convert to strings), un-sortable + + order: list | None = None + + def setup( + self, data: Series, prop: Property, axis: Axis | None = None, + ) -> Scale: + + class CatScale(mpl.scale.LinearScale): + # TODO turn this into a real thing I guess + name = None # To work around mpl<3.4 compat issues + + def set_default_locators_and_formatters(self, axis): + pass + + # TODO flexibility over format() which isn't great for numbers / dates + stringify = np.vectorize(format) + + units_seed = categorical_order(data, self.order) + + mpl_scale = CatScale(data.name) + if axis is None: + axis = PseudoAxis(mpl_scale) + + # TODO Currently just used in non-Coordinate contexts, but should + # we use this to (A) set the padding we want for categorial plots + # and (B) allow the values parameter for a Coordinate to set xlim/ylim + axis.set_view_interval(0, len(units_seed) - 1) + + # TODO array cast necessary to handle float/int mixture, which we need + # to solve in a more systematic way probably + # (i.e. if we have [1, 2.5], do we want [1.0, 2.5]? Unclear) + axis.update_units(stringify(np.array(units_seed))) + + # TODO define this more centrally + def convert_units(x): + # TODO only do this with explicit order? + # (But also category dtype?) + # TODO isin fails when units_seed mixes numbers and strings (numpy error?) + # but np.isin also does not seem any faster? (Maybe not broadcasting in C) + # keep = x.isin(units_seed) + keep = np.array([x_ in units_seed for x_ in x], bool) + out = np.full(len(x), np.nan) + out[keep] = axis.convert_units(stringify(x[keep])) + return out + + forward_pipe = [ + convert_units, + prop.get_mapping(self, data), + # TODO how to handle color representation consistency? + ] + + def spacer(x): + return 1 + + if prop.legend: + legend = units_seed, list(stringify(units_seed)) + else: + legend = None + + scale_type = self.__class__.__name__.lower() + scale = Scale(forward_pipe, spacer, legend, scale_type, mpl_scale) + return scale + + +@dataclass +class Ordinal(ScaleSpec): + # Categorical (convert to strings), sortable, can skip ticklabels + ... + + +@dataclass +class Discrete(ScaleSpec): + # Numeric, integral, can skip ticks/ticklabels + ... + + +@dataclass +class ContinuousBase(ScaleSpec): + + values: tuple | str | None = None + norm: tuple | None = None + + def setup( + self, data: Series, prop: Property, axis: Axis | None = None, + ) -> Scale: + + new = copy(self) + forward, inverse = self._get_transform() + + mpl_scale = self._get_scale(data.name, forward, inverse) + + if axis is None: + axis = PseudoAxis(mpl_scale) + axis.update_units(data) + + mpl_scale.set_default_locators_and_formatters(axis) + + normalize: Optional[Callable[[ArrayLike], ArrayLike]] + if prop.normed: + if self.norm is None: + vmin, vmax = data.min(), data.max() + else: + vmin, vmax = self.norm + vmin, vmax = axis.convert_units((vmin, vmax)) + a = forward(vmin) + b = forward(vmax) - forward(vmin) + + def normalize(x): + return (x - a) / b + + else: + normalize = vmin = vmax = None + + forward_pipe = [ + axis.convert_units, + forward, + normalize, + prop.get_mapping(new, data) + ] + + def spacer(x): + return np.min(np.diff(np.sort(x.unique()))) + + # TODO make legend optional on per-plot basis with ScaleSpec parameter? + if prop.legend: + axis.set_view_interval(vmin, vmax) + locs = axis.major.locator() + locs = locs[(vmin <= locs) & (locs <= vmax)] + labels = axis.major.formatter.format_ticks(locs) + legend = list(locs), list(labels) + + else: + legend = None + + scale_type = self.__class__.__name__.lower() + return Scale(forward_pipe, spacer, legend, scale_type, mpl_scale) + + def _get_transform(self): + + arg = self.transform + + def get_param(method, default): + if arg == method: + return default + return float(arg[len(method):]) + + if arg is None: + return _make_identity_transforms() + elif isinstance(arg, tuple): + return arg + elif isinstance(arg, str): + if arg == "ln": + return _make_log_transforms() + elif arg == "logit": + base = get_param("logit", 10) + return _make_logit_transforms(base) + elif arg.startswith("log"): + base = get_param("log", 10) + return _make_log_transforms(base) + elif arg.startswith("symlog"): + c = get_param("symlog", 1) + return _make_symlog_transforms(c) + elif arg.startswith("pow"): + exp = get_param("pow", 2) + return _make_power_transforms(exp) + elif arg == "sqrt": + return _make_sqrt_transforms() + else: + # TODO useful error message + raise ValueError() + + +@dataclass +class Continuous(ContinuousBase): + """ + A numeric scale supporting norms and functional transforms. + """ + transform: str | Transforms | None = None + + # TODO Add this to deal with outliers? + # outside: Literal["keep", "drop", "clip"] = "keep" + + # TODO maybe expose matplotlib more directly like this? + # def using(self, scale: mpl.scale.ScaleBase) ? + + def tick( + self, + locator: Locator | None = None, *, + at: Sequence[float] = None, + upto: int | None = None, + count: int | None = None, + every: float | None = None, + between: tuple[float, float] | None = None, + minor: int | None = None, + ) -> Continuous: # TODO type return value as Self + """ + Configure the selection of ticks for the scale's axis or legend. + + Parameters + ---------- + locator: matplotlib Locator + Pre-configured matplotlib locator; other parameters will not be used. + at : sequence of floats + Place ticks at these specific locations (in data units). + upto : int + Choose "nice" locations for ticks, but do not exceed this number. + count : int + Choose exactly this number of ticks, bounded by `between` or axis limits. + every : float + Choose locations at this interval of separation (in data units). + between : pair of floats + Bound upper / lower ticks when using `every` or `count`. + minor : int + Number of unlabeled ticks to draw between labeled "major" ticks. + + Returns + ------- + Returns self with new tick configuration. + + """ + + # TODO what about symlog? + if isinstance(self.transform, str): + m = re.match(r"log(\d*)", self.transform) + log_transform = m is not None + log_base = m[1] or 10 if m is not None else None + forward, inverse = self._get_transform() + else: + log_transform = False + log_base = forward = inverse = None + + if locator is not None: + # TODO accept tuple for major, minor? + if not isinstance(locator, Locator): + err = ( + f"Tick locator must be an instance of {Locator!r}, " + f"not {type(locator)!r}." + ) + raise TypeError(err) + major_locator = locator + + # TODO raise if locator is passed with any other parameters + + elif upto is not None: + if log_transform: + major_locator = LogLocator(base=log_base, numticks=upto) + else: + major_locator = MaxNLocator(upto, steps=[1, 1.5, 2, 2.5, 3, 5, 10]) + + elif count is not None: + if between is None: + if log_transform: + msg = "`count` requires `between` with log transform." + raise RuntimeError(msg) + # This is rarely useful (unless you are setting limits) + major_locator = LinearLocator(count) + else: + if log_transform: + lo, hi = forward(between) + ticks = inverse(np.linspace(lo, hi, num=count)) + else: + ticks = np.linspace(*between, num=count) + major_locator = FixedLocator(ticks) + + elif every is not None: + if log_transform: + msg = "`every` not supported with log transform." + raise RuntimeError(msg) + if between is None: + major_locator = MultipleLocator(every) + else: + lo, hi = between + ticks = np.arange(lo, hi + every, every) + major_locator = FixedLocator(ticks) + + elif at is not None: + major_locator = FixedLocator(at) + + else: + major_locator = LogLocator(log_base) if log_transform else AutoLocator() + + if minor is None: + minor_locator = LogLocator(log_base, subs=None) if log_transform else None + else: + if log_transform: + subs = np.linspace(0, log_base, minor + 2)[1:-1] + minor_locator = LogLocator(log_base, subs=subs) + else: + minor_locator = AutoMinorLocator(minor + 1) + + self._major_locator = major_locator + self._minor_locator = minor_locator + + return self + + # TODO need to fill this out + # def format(self, ...): + + +@dataclass +class Temporal(ContinuousBase): + """ + A scale for date/time data. + """ + # TODO date: bool? + # For when we only care about the time component, would affect + # default formatter and norm conversion. Should also happen in + # Property.default_scale. The alternative was having distinct + # Calendric / Temporal scales, but that feels a bit fussy, and it + # would get in the way of using first-letter shorthands because + # Calendric and Continuous would collide. Still, we haven't implemented + # those yet, and having a clear distinction betewen date(time) / time + # may be more useful. + + transform = None + + def tick( + self, locator: Locator | None = None, *, + upto: int | None = None, + ) -> Temporal: + + if locator is not None: + # TODO accept tuple for major, minor? + if not isinstance(locator, Locator): + err = ( + f"Tick locator must be an instance of {Locator!r}, " + f"not {type(locator)!r}." + ) + raise TypeError(err) + major_locator = locator + + elif upto is not None: + # TODO atleast for minticks? + major_locator = AutoDateLocator(minticks=2, maxticks=upto) + + else: + major_locator = AutoDateLocator(minticks=2, maxticks=6) + + self._major_locator = major_locator + self._minor_locator = None + + self.format() + + return self + + def format( + self, formater: Formatter | None = None, *, + concise: bool = False, + ) -> Temporal: + + # TODO ideally we would have concise coordinate ticks, + # but full semantic ticks. Is that possible? + if concise: + major_formatter = ConciseDateFormatter(self._major_locator) + else: + major_formatter = AutoDateFormatter(self._major_locator) + self._major_formatter = major_formatter + + return self + + +# ----------------------------------------------------------------------------------- # + + +class Calendric(ScaleSpec): + # TODO have this separate from Temporal or have Temporal(date=True) or similar? + ... + + +class Binned(ScaleSpec): + # Needed? Or handle this at layer (in stat or as param, eg binning=) + ... + + +# TODO any need for color-specific scales? +# class Sequential(Continuous): +# class Diverging(Continuous): +# class Qualitative(Nominal): + + +# ----------------------------------------------------------------------------------- # + + +class PseudoAxis: + """ + Internal class implementing minimal interface equivalent to matplotlib Axis. + + Coordinate variables are typically scaled by attaching the Axis object from + the figure where the plot will end up. Matplotlib has no similar concept of + and axis for the other mappable variables (color, etc.), but to simplify the + code, this object acts like an Axis and can be used to scale other variables. + + """ + axis_name = "" # TODO Needs real value? Just used for x/y logic in matplotlib + + def __init__(self, scale): + + self.converter = None + self.units = None + self.scale = scale + self.major = mpl.axis.Ticker() + self.minor = mpl.axis.Ticker() + + # It appears that this needs to be initialized this way on matplotlib 3.1, + # but not later versions. It is unclear whether there are any issues with it. + self._data_interval = None, None + + scale.set_default_locators_and_formatters(self) + # self.set_default_intervals() TODO mock? + + def set_view_interval(self, vmin, vmax): + # TODO this gets called when setting DateTime units, + # but we may not need it to do anything + self._view_interval = vmin, vmax + + def get_view_interval(self): + return self._view_interval + + # TODO do we want to distinguish view/data intervals? e.g. for a legend + # we probably want to represent the full range of the data values, but + # still norm the colormap. If so, we'll need to track data range separately + # from the norm, which we currently don't do. + + def set_data_interval(self, vmin, vmax): + self._data_interval = vmin, vmax + + def get_data_interval(self): + return self._data_interval + + def get_tick_space(self): + # TODO how to do this in a configurable / auto way? + # Would be cool to have legend density adapt to figure size, etc. + return 5 + + def set_major_locator(self, locator): + self.major.locator = locator + locator.set_axis(self) + + def set_major_formatter(self, formatter): + # TODO matplotlib method does more handling (e.g. to set w/format str) + # We will probably handle that in the tick/format interface, though + self.major.formatter = formatter + formatter.set_axis(self) + + def set_minor_locator(self, locator): + self.minor.locator = locator + locator.set_axis(self) + + def set_minor_formatter(self, formatter): + self.minor.formatter = formatter + formatter.set_axis(self) + + def set_units(self, units): + self.units = units + + def update_units(self, x): + """Pass units to the internal converter, potentially updating its mapping.""" + self.converter = mpl.units.registry.get_converter(x) + if self.converter is not None: + self.converter.default_units(x, self) + + info = self.converter.axisinfo(self.units, self) + + if info is None: + return + if info.majloc is not None: + # TODO matplotlib method has more conditions here; are they needed? + self.set_major_locator(info.majloc) + if info.majfmt is not None: + self.set_major_formatter(info.majfmt) + + # TODO this is in matplotlib method; do we need this? + # self.set_default_intervals() + + def convert_units(self, x): + """Return a numeric representation of the input data.""" + if np.issubdtype(np.asarray(x).dtype, np.number): + return x + elif self.converter is None: + return x + return self.converter.convert(x, self.units, self) + + def get_scale(self): + # TODO matplotlib actually returns a string here! + # Currently we just hit it with minor ticks where it checks for + # scale == "log". I'm not sure how you'd actually use log-scale + # minor "ticks" in a legend context, so this is fine..... + return self.scale + + def get_majorticklocs(self): + return self.major.locator() + + +# ------------------------------------------------------------------------------------ + + +def _make_identity_transforms() -> Transforms: + + def identity(x): + return x + + return identity, identity + + +def _make_logit_transforms(base: float = None) -> Transforms: + + log, exp = _make_log_transforms(base) + + def logit(x): + with np.errstate(invalid="ignore", divide="ignore"): + return log(x) - log(1 - x) + + def expit(x): + with np.errstate(invalid="ignore", divide="ignore"): + return exp(x) / (1 + exp(x)) + + return logit, expit + + +def _make_log_transforms(base: float | None = None) -> Transforms: + + if base is None: + fs = np.log, np.exp + elif base == 2: + fs = np.log2, partial(np.power, 2) + elif base == 10: + fs = np.log10, partial(np.power, 10) + else: + def forward(x): + return np.log(x) / np.log(base) + fs = forward, partial(np.power, base) + + def log(x): + with np.errstate(invalid="ignore", divide="ignore"): + return fs[0](x) + + def exp(x): + with np.errstate(invalid="ignore", divide="ignore"): + return fs[1](x) + + return log, exp + + +def _make_symlog_transforms(c: float = 1, base: float = 10) -> Transforms: + + # From https://iopscience.iop.org/article/10.1088/0957-0233/24/2/027001 + + # Note: currently not using base because we only get + # one parameter from the string, and are using c (this is consistent with d3) + + log, exp = _make_log_transforms(base) + + def symlog(x): + with np.errstate(invalid="ignore", divide="ignore"): + return np.sign(x) * log(1 + np.abs(np.divide(x, c))) + + def symexp(x): + with np.errstate(invalid="ignore", divide="ignore"): + return np.sign(x) * c * (exp(np.abs(x)) - 1) + + return symlog, symexp + + +def _make_sqrt_transforms() -> Transforms: + + def sqrt(x): + return np.sign(x) * np.sqrt(np.abs(x)) + + def square(x): + return np.sign(x) * np.square(x) + + return sqrt, square + + +def _make_power_transforms(exp: float) -> Transforms: + + def forward(x): + return np.sign(x) * np.power(np.abs(x), exp) + + def inverse(x): + return np.sign(x) * np.power(np.abs(x), 1 / exp) + + return forward, inverse diff --git a/seaborn/_core/subplots.py b/seaborn/_core/subplots.py new file mode 100644 index 0000000000..88134ba2c0 --- /dev/null +++ b/seaborn/_core/subplots.py @@ -0,0 +1,270 @@ +from __future__ import annotations +from collections.abc import Generator + +import numpy as np +import matplotlib as mpl +import matplotlib.pyplot as plt + +from matplotlib.axes import Axes +from matplotlib.figure import Figure +from typing import TYPE_CHECKING +if TYPE_CHECKING: # TODO move to seaborn._core.typing? + from seaborn._core.plot import FacetSpec, PairSpec + from matplotlib.figure import SubFigure + + +class Subplots: + """ + Interface for creating and using matplotlib subplots based on seaborn parameters. + + Parameters + ---------- + subplot_spec : dict + Keyword args for :meth:`matplotlib.figure.Figure.subplots`. + facet_spec : dict + Parameters that control subplot faceting. + pair_spec : dict + Parameters that control subplot pairing. + data : PlotData + Data used to define figure setup. + + """ + def __init__( + # TODO defined TypedDict types for these specs + self, + subplot_spec: dict, + facet_spec: FacetSpec, + pair_spec: PairSpec, + ): + + self.subplot_spec = subplot_spec + + self._check_dimension_uniqueness(facet_spec, pair_spec) + self._determine_grid_dimensions(facet_spec, pair_spec) + self._handle_wrapping(facet_spec, pair_spec) + self._determine_axis_sharing(pair_spec) + + def _check_dimension_uniqueness( + self, facet_spec: FacetSpec, pair_spec: PairSpec + ) -> None: + """Reject specs that pair and facet on (or wrap to) same figure dimension.""" + err = None + + facet_vars = facet_spec.get("variables", {}) + + if facet_spec.get("wrap") and {"col", "row"} <= set(facet_vars): + err = "Cannot wrap facets when specifying both `col` and `row`." + elif ( + pair_spec.get("wrap") + and pair_spec.get("cross", True) + and len(pair_spec.get("structure", {}).get("x", [])) > 1 + and len(pair_spec.get("structure", {}).get("y", [])) > 1 + ): + err = "Cannot wrap subplots when pairing on both `x` and `y`." + + collisions = {"x": ["columns", "rows"], "y": ["rows", "columns"]} + for pair_axis, (multi_dim, wrap_dim) in collisions.items(): + if pair_axis not in pair_spec.get("structure", {}): + continue + elif multi_dim[:3] in facet_vars: + err = f"Cannot facet the {multi_dim} while pairing on `{pair_axis}``." + elif wrap_dim[:3] in facet_vars and facet_spec.get("wrap"): + err = f"Cannot wrap the {wrap_dim} while pairing on `{pair_axis}``." + elif wrap_dim[:3] in facet_vars and pair_spec.get("wrap"): + err = f"Cannot wrap the {multi_dim} while faceting the {wrap_dim}." + + if err is not None: + raise RuntimeError(err) # TODO what err class? Define PlotSpecError? + + def _determine_grid_dimensions( + self, facet_spec: FacetSpec, pair_spec: PairSpec + ) -> None: + """Parse faceting and pairing information to define figure structure.""" + self.grid_dimensions: dict[str, list] = {} + for dim, axis in zip(["col", "row"], ["x", "y"]): + + facet_vars = facet_spec.get("variables", {}) + if dim in facet_vars: + self.grid_dimensions[dim] = facet_spec["structure"][dim] + elif axis in pair_spec.get("structure", {}): + self.grid_dimensions[dim] = [ + None for _ in pair_spec.get("structure", {})[axis] + ] + else: + self.grid_dimensions[dim] = [None] + + self.subplot_spec[f"n{dim}s"] = len(self.grid_dimensions[dim]) + + if not pair_spec.get("cross", True): + self.subplot_spec["nrows"] = 1 + + self.n_subplots = self.subplot_spec["ncols"] * self.subplot_spec["nrows"] + + def _handle_wrapping( + self, facet_spec: FacetSpec, pair_spec: PairSpec + ) -> None: + """Update figure structure parameters based on facet/pair wrapping.""" + self.wrap = wrap = facet_spec.get("wrap") or pair_spec.get("wrap") + if not wrap: + return + + wrap_dim = "row" if self.subplot_spec["nrows"] > 1 else "col" + flow_dim = {"row": "col", "col": "row"}[wrap_dim] + n_subplots = self.subplot_spec[f"n{wrap_dim}s"] + flow = int(np.ceil(n_subplots / wrap)) + + if wrap < self.subplot_spec[f"n{wrap_dim}s"]: + self.subplot_spec[f"n{wrap_dim}s"] = wrap + self.subplot_spec[f"n{flow_dim}s"] = flow + self.n_subplots = n_subplots + self.wrap_dim = wrap_dim + + def _determine_axis_sharing(self, pair_spec: PairSpec) -> None: + """Update subplot spec with default or specified axis sharing parameters.""" + axis_to_dim = {"x": "col", "y": "row"} + key: str + val: str | bool + for axis in "xy": + key = f"share{axis}" + # Always use user-specified value, if present + if key not in self.subplot_spec: + if axis in pair_spec.get("structure", {}): + # Paired axes are shared along one dimension by default + if self.wrap in [None, 1] and pair_spec.get("cross", True): + val = axis_to_dim[axis] + else: + val = False + else: + # This will pick up faceted plots, as well as single subplot + # figures, where the value doesn't really matter + val = True + self.subplot_spec[key] = val + + def init_figure( + self, + pair_spec: PairSpec, + pyplot: bool = False, + figure_kws: dict | None = None, + target: Axes | Figure | SubFigure = None, + ) -> Figure: + """Initialize matplotlib objects and add seaborn-relevant metadata.""" + # TODO reduce need to pass pair_spec here? + + if figure_kws is None: + figure_kws = {} + + if isinstance(target, mpl.axes.Axes): + + if max(self.subplot_spec["nrows"], self.subplot_spec["ncols"]) > 1: + err = " ".join([ + "Cannot create multiple subplots after calling `Plot.on` with", + f"a {mpl.axes.Axes} object.", + ]) + try: + err += f" You may want to use a {mpl.figure.SubFigure} instead." + except AttributeError: # SubFigure added in mpl 3.4 + pass + raise RuntimeError(err) + + self._subplot_list = [{ + "ax": target, + "left": True, + "right": True, + "top": True, + "bottom": True, + "col": None, + "row": None, + "x": "x", + "y": "y", + }] + self._figure = target.figure + return self._figure + + elif ( + hasattr(mpl.figure, "SubFigure") # Added in mpl 3.4 + and isinstance(target, mpl.figure.SubFigure) + ): + figure = target.figure + elif isinstance(target, mpl.figure.Figure): + figure = target + else: + if pyplot: + figure = plt.figure(**figure_kws) + else: + figure = mpl.figure.Figure(**figure_kws) + target = figure + self._figure = figure + + axs = target.subplots(**self.subplot_spec, squeeze=False) + + if self.wrap: + # Remove unused Axes and flatten the rest into a (2D) vector + axs_flat = axs.ravel({"col": "C", "row": "F"}[self.wrap_dim]) + axs, extra = np.split(axs_flat, [self.n_subplots]) + for ax in extra: + ax.remove() + if self.wrap_dim == "col": + axs = axs[np.newaxis, :] + else: + axs = axs[:, np.newaxis] + + # Get i, j coordinates for each Axes object + # Note that i, j are with respect to faceting/pairing, + # not the subplot grid itself, (which only matters in the case of wrapping). + iter_axs: np.ndenumerate | zip + if not pair_spec.get("cross", True): + indices = np.arange(self.n_subplots) + iter_axs = zip(zip(indices, indices), axs.flat) + else: + iter_axs = np.ndenumerate(axs) + + self._subplot_list = [] + for (i, j), ax in iter_axs: + + info = {"ax": ax} + + nrows, ncols = self.subplot_spec["nrows"], self.subplot_spec["ncols"] + if not self.wrap: + info["left"] = j % ncols == 0 + info["right"] = (j + 1) % ncols == 0 + info["top"] = i == 0 + info["bottom"] = i == nrows - 1 + elif self.wrap_dim == "col": + info["left"] = j % ncols == 0 + info["right"] = ((j + 1) % ncols == 0) or ((j + 1) == self.n_subplots) + info["top"] = j < ncols + info["bottom"] = j >= (self.n_subplots - ncols) + elif self.wrap_dim == "row": + info["left"] = i < nrows + info["right"] = i >= self.n_subplots - nrows + info["top"] = i % nrows == 0 + info["bottom"] = ((i + 1) % nrows == 0) or ((i + 1) == self.n_subplots) + + if not pair_spec.get("cross", True): + info["top"] = j < ncols + info["bottom"] = j >= self.n_subplots - ncols + + for dim in ["row", "col"]: + idx = {"row": i, "col": j}[dim] + info[dim] = self.grid_dimensions[dim][idx] + + for axis in "xy": + + idx = {"x": j, "y": i}[axis] + if axis in pair_spec.get("structure", {}): + key = f"{axis}{idx}" + else: + key = axis + info[axis] = key + + self._subplot_list.append(info) + + return figure + + def __iter__(self) -> Generator[dict, None, None]: # TODO TypedDict? + """Yield each subplot dictionary with Axes object and metadata.""" + yield from self._subplot_list + + def __len__(self) -> int: + """Return the number of subplots in this figure.""" + return len(self._subplot_list) diff --git a/seaborn/_core/typing.py b/seaborn/_core/typing.py new file mode 100644 index 0000000000..3599aaae7a --- /dev/null +++ b/seaborn/_core/typing.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from typing import Any, Optional, Union, Mapping, Tuple, List, Dict +from collections.abc import Hashable, Iterable +from numpy import ndarray # TODO use ArrayLike? +from pandas import DataFrame, Series, Index +from matplotlib.colors import Colormap, Normalize + +Vector = Union[Series, Index, ndarray] +PaletteSpec = Union[str, list, dict, Colormap, None] +VariableSpec = Union[Hashable, Vector, None] +# TODO can we better unify the VarType object and the VariableType alias? +DataSource = Union[DataFrame, Mapping[Hashable, Vector], None] + +OrderSpec = Union[Iterable, None] # TODO technically str is iterable +NormSpec = Union[Tuple[Optional[float], Optional[float]], Normalize, None] + +# TODO for discrete mappings, it would be ideal to use a parameterized type +# as the dict values / list entries should be of specific type(s) for each method +DiscreteValueSpec = Union[dict, list, None] +ContinuousValueSpec = Union[ + Tuple[float, float], List[float], Dict[Any, float], None, +] diff --git a/seaborn/_marks/__init__.py b/seaborn/_marks/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/seaborn/_marks/area.py b/seaborn/_marks/area.py new file mode 100644 index 0000000000..27324f25f4 --- /dev/null +++ b/seaborn/_marks/area.py @@ -0,0 +1,115 @@ +from __future__ import annotations +from collections import defaultdict +from dataclasses import dataclass + +import numpy as np +import matplotlib as mpl + +from seaborn._marks.base import ( + Mark, + Mappable, + MappableBool, + MappableFloat, + MappableColor, + MappableStyle, + resolve_properties, + resolve_color, +) + + +class AreaBase: + + def _plot(self, split_gen, scales, orient): + + kws = {} + + for keys, data, ax in split_gen(): + + kws.setdefault(ax, defaultdict(list)) + + data = self._standardize_coordinate_parameters(data, orient) + resolved = resolve_properties(self, keys, scales) + verts = self._get_verts(data, orient) + + ax.update_datalim(verts) + kws[ax]["verts"].append(verts) + + # TODO fill= is not working here properly + # We could hack a fix, but would be better to handle fill in resolve_color + + kws[ax]["facecolors"].append(resolve_color(self, keys, "", scales)) + kws[ax]["edgecolors"].append(resolve_color(self, keys, "edge", scales)) + + kws[ax]["linewidth"].append(resolved["edgewidth"]) + kws[ax]["linestyle"].append(resolved["edgestyle"]) + + for ax, ax_kws in kws.items(): + ax.add_collection(mpl.collections.PolyCollection(**ax_kws)) + + def _standardize_coordinate_parameters(self, data, orient): + return data + + def _get_verts(self, data, orient): + + dv = {"x": "y", "y": "x"}[orient] + data = data.sort_values(orient) + verts = np.concatenate([ + data[[orient, f"{dv}min"]].to_numpy(), + data[[orient, f"{dv}max"]].to_numpy()[::-1], + ]) + if orient == "y": + verts = verts[:, ::-1] + return verts + + def _legend_artist(self, variables, value, scales): + + keys = {v: value for v in variables} + resolved = resolve_properties(self, keys, scales) + + return mpl.patches.Patch( + facecolor=resolve_color(self, keys, "", scales), + edgecolor=resolve_color(self, keys, "edge", scales), + linewidth=resolved["edgewidth"], + linestyle=resolved["edgestyle"], + **self.artist_kws, + ) + + +@dataclass +class Area(AreaBase, Mark): + """ + An interval mark that fills between baseline and data values. + """ + color: MappableColor = Mappable("C0", ) + alpha: MappableFloat = Mappable(.2, ) + fill: MappableBool = Mappable(True, ) + edgecolor: MappableColor = Mappable(depend="color") + edgealpha: MappableFloat = Mappable(1, ) + edgewidth: MappableFloat = Mappable(rc="patch.linewidth", ) + edgestyle: MappableStyle = Mappable("-", ) + + # TODO should this be settable / mappable? + baseline: MappableFloat = Mappable(0, grouping=False) + + def _standardize_coordinate_parameters(self, data, orient): + dv = {"x": "y", "y": "x"}[orient] + return data.rename(columns={"baseline": f"{dv}min", dv: f"{dv}max"}) + + +@dataclass +class Ribbon(AreaBase, Mark): + """ + An interval mark that fills between minimum and maximum values. + """ + color: MappableColor = Mappable("C0", ) + alpha: MappableFloat = Mappable(.2, ) + fill: MappableBool = Mappable(True, ) + edgecolor: MappableColor = Mappable(depend="color", ) + edgealpha: MappableFloat = Mappable(1, ) + edgewidth: MappableFloat = Mappable(0, ) + edgestyle: MappableFloat = Mappable("-", ) + + def _standardize_coordinate_parameters(self, data, orient): + # dv = {"x": "y", "y": "x"}[orient] + # TODO assert that all(ymax >= ymin)? + return data diff --git a/seaborn/_marks/bars.py b/seaborn/_marks/bars.py new file mode 100644 index 0000000000..514686990f --- /dev/null +++ b/seaborn/_marks/bars.py @@ -0,0 +1,108 @@ +from __future__ import annotations +from dataclasses import dataclass + +import matplotlib as mpl + +from seaborn._marks.base import ( + Mark, + Mappable, + MappableBool, + MappableColor, + MappableFloat, + MappableStyle, + resolve_properties, + resolve_color, +) + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import Any + from matplotlib.artist import Artist + from seaborn._core.scales import Scale + + +@dataclass +class Bar(Mark): + """ + An interval mark drawn between baseline and data values with a width. + """ + color: MappableColor = Mappable("C0", ) + alpha: MappableFloat = Mappable(.7, ) + fill: MappableBool = Mappable(True, ) + edgecolor: MappableColor = Mappable(depend="color", ) + edgealpha: MappableFloat = Mappable(1, ) + edgewidth: MappableFloat = Mappable(rc="patch.linewidth") + edgestyle: MappableStyle = Mappable("-", ) + # pattern: MappableString = Mappable(None, ) # TODO no Property yet + + width: MappableFloat = Mappable(.8, grouping=False) + baseline: MappableFloat = Mappable(0, grouping=False) # TODO *is* this mappable? + + def _resolve_properties(self, data, scales): + + resolved = resolve_properties(self, data, scales) + + resolved["facecolor"] = resolve_color(self, data, "", scales) + resolved["edgecolor"] = resolve_color(self, data, "edge", scales) + + fc = resolved["facecolor"] + if isinstance(fc, tuple): + resolved["facecolor"] = fc[0], fc[1], fc[2], fc[3] * resolved["fill"] + else: + fc[:, 3] = fc[:, 3] * resolved["fill"] # TODO Is inplace mod a problem? + resolved["facecolor"] = fc + + return resolved + + def _plot(self, split_gen, scales, orient): + + def coords_to_geometry(x, y, w, b): + # TODO possible too slow with lots of bars (e.g. dense hist) + # Why not just use BarCollection? + if orient == "x": + w, h = w, y - b + xy = x - w / 2, b + else: + w, h = x - b, w + xy = b, y - h / 2 + return xy, w, h + + for _, data, ax in split_gen(): + + xys = data[["x", "y"]].to_numpy() + data = self._resolve_properties(data, scales) + + bars = [] + for i, (x, y) in enumerate(xys): + + baseline = data["baseline"][i] + width = data["width"][i] + xy, w, h = coords_to_geometry(x, y, width, baseline) + + bar = mpl.patches.Rectangle( + xy=xy, + width=w, + height=h, + facecolor=data["facecolor"][i], + edgecolor=data["edgecolor"][i], + linewidth=data["edgewidth"][i], + linestyle=data["edgestyle"][i], + ) + ax.add_patch(bar) + bars.append(bar) + + # TODO add container object to ax, line ax.bar does + + def _legend_artist( + self, variables: list[str], value: Any, scales: dict[str, Scale], + ) -> Artist: + # TODO return some sensible default? + key = {v: value for v in variables} + key = self._resolve_properties(key, scales) + artist = mpl.patches.Patch( + facecolor=key["facecolor"], + edgecolor=key["edgecolor"], + linewidth=key["edgewidth"], + linestyle=key["edgestyle"], + ) + return artist diff --git a/seaborn/_marks/base.py b/seaborn/_marks/base.py new file mode 100644 index 0000000000..7ea59768e8 --- /dev/null +++ b/seaborn/_marks/base.py @@ -0,0 +1,294 @@ +from __future__ import annotations +from dataclasses import dataclass, fields, field + +import numpy as np +import pandas as pd +import matplotlib as mpl + +from seaborn._core.properties import PROPERTIES, Property + +from typing import Any, Callable, Union +from collections.abc import Generator +from numpy import ndarray +from pandas import DataFrame +from matplotlib.artist import Artist +from seaborn._core.properties import RGBATuple, DashPattern, DashPatternWithOffset +from seaborn._core.scales import Scale + + +class Mappable: + def __init__( + self, + val: Any = None, + depend: str | None = None, + rc: str | None = None, + grouping: bool = True, + ): + """ + Property that can be mapped from data or set directly, with flexible defaults. + + Parameters + ---------- + val : Any + Use this value as the default. + depend : str + Use the value of this feature as the default. + rc : str + Use the value of this rcParam as the default. + grouping : bool + If True, use the mapped variable to define groups. + + """ + if depend is not None: + assert depend in PROPERTIES + if rc is not None: + assert rc in mpl.rcParams + + self._val = val + self._rc = rc + self._depend = depend + self._grouping = grouping + + def __repr__(self): + """Nice formatting for when object appears in Mark init signature.""" + if self._val is not None: + s = f"<{repr(self._val)}>" + elif self._depend is not None: + s = f"" + elif self._rc is not None: + s = f"" + else: + s = "" + return s + + @property + def depend(self) -> Any: + """Return the name of the feature to source a default value from.""" + return self._depend + + @property + def grouping(self) -> bool: + return self._grouping + + @property + def default(self) -> Any: + """Get the default value for this feature, or access the relevant rcParam.""" + if self._val is not None: + return self._val + return mpl.rcParams.get(self._rc) + + +# TODO where is the right place to put this kind of type aliasing? + +MappableBool = Union[bool, Mappable] +MappableString = Union[str, Mappable] +MappableFloat = Union[float, Mappable] +MappableColor = Union[str, tuple, Mappable] +MappableStyle = Union[str, DashPattern, DashPatternWithOffset, Mappable] + + +@dataclass +class Mark: + + artist_kws: dict = field(default_factory=dict) + + @property + def _mappable_props(self): + return { + f.name: getattr(self, f.name) for f in fields(self) + if isinstance(f.default, Mappable) + } + + @property + def _grouping_props(self): + # TODO does it make sense to have variation within a Mark's + # properties about whether they are grouping? + return [ + f.name for f in fields(self) + if isinstance(f.default, Mappable) and f.default.grouping + ] + + # TODO make this method private? Would extender every need to call directly? + def _resolve( + self, + data: DataFrame | dict[str, Any], + name: str, + scales: dict[str, Scale] | None = None, + ) -> Any: + """Obtain default, specified, or mapped value for a named feature. + + Parameters + ---------- + data : DataFrame or dict with scalar values + Container with data values for features that will be semantically mapped. + name : string + Identity of the feature / semantic. + scales: dict + Mapping from variable to corresponding scale object. + + Returns + ------- + value or array of values + Outer return type depends on whether `data` is a dict (implying that + we want a single value) or DataFrame (implying that we want an array + of values with matching length). + + """ + feature = self._mappable_props[name] + prop = PROPERTIES.get(name, Property(name)) + directly_specified = not isinstance(feature, Mappable) + return_multiple = isinstance(data, pd.DataFrame) + return_array = return_multiple and not name.endswith("style") + + # Special case width because it needs to be resolved and added to the dataframe + # during layer prep (so the Move operations use it properly). + # TODO how does width *scaling* work, e.g. for violin width by count? + if name == "width": + directly_specified = directly_specified and name not in data + + if directly_specified: + feature = prop.standardize(feature) + if return_multiple: + feature = [feature] * len(data) + if return_array: + feature = np.array(feature) + return feature + + if name in data: + if scales is None or name not in scales: + # TODO Might this obviate the identity scale? Just don't add a scale? + feature = data[name] + else: + feature = scales[name](data[name]) + if return_array: + feature = np.asarray(feature) + return feature + + if feature.depend is not None: + # TODO add source_func or similar to transform the source value? + # e.g. set linewidth as a proportion of pointsize? + return self._resolve(data, feature.depend, scales) + + default = prop.standardize(feature.default) + if return_multiple: + default = [default] * len(data) + if return_array: + default = np.array(default) + return default + + def _infer_orient(self, scales: dict) -> str: # TODO type scales + + # TODO The original version of this (in seaborn._oldcore) did more checking. + # Paring that down here for the prototype to see what restrictions make sense. + + # TODO rethink this to map from scale type to "DV priority" and use that? + # e.g. Nominal > Discrete > Continuous + + x_type = None if "x" not in scales else scales["x"].scale_type + y_type = None if "y" not in scales else scales["y"].scale_type + + if x_type is None: + return "y" + + elif y_type is None: + return "x" + + elif x_type != "nominal" and y_type == "nominal": + return "y" + + elif x_type != "continuous" and y_type == "continuous": + + # TODO should we try to orient based on number of unique values? + + return "x" + + elif x_type == "continuous" and y_type != "continuous": + return "y" + + else: + return "x" + + def _plot( + self, + split_generator: Callable[[], Generator], + scales: dict[str, Scale], + orient: str, + ) -> None: + """Main interface for creating a plot.""" + raise NotImplementedError() + + def _legend_artist( + self, variables: list[str], value: Any, scales: dict[str, Scale], + ) -> Artist: + # TODO return some sensible default? + raise NotImplementedError + + +def resolve_properties( + mark: Mark, data: DataFrame, scales: dict[str, Scale] +) -> dict[str, Any]: + + props = { + name: mark._resolve(data, name, scales) for name in mark._mappable_props + } + return props + + +def resolve_color( + mark: Mark, + data: DataFrame | dict, + prefix: str = "", + scales: dict[str, Scale] | None = None, +) -> RGBATuple | ndarray: + """ + Obtain a default, specified, or mapped value for a color feature. + + This method exists separately to support the relationship between a + color and its corresponding alpha. We want to respect alpha values that + are passed in specified (or mapped) color values but also make use of a + separate `alpha` variable, which can be mapped. This approach may also + be extended to support mapping of specific color channels (i.e. + luminance, chroma) in the future. + + Parameters + ---------- + mark : + Mark with the color property. + data : + Container with data values for features that will be semantically mapped. + prefix : + Support "color", "fillcolor", etc. + + """ + color = mark._resolve(data, f"{prefix}color", scales) + alpha = mark._resolve(data, f"{prefix}alpha", scales) + + def visible(x, axis=None): + """Detect "invisible" colors to set alpha appropriately.""" + # TODO First clause only needed to handle non-rgba arrays, + # which we are trying to handle upstream + return np.array(x).dtype.kind != "f" or np.isfinite(x).all(axis) + + # Second check here catches vectors of strings with identity scale + # It could probably be handled better upstream. This is a tricky problem + if np.ndim(color) < 2 and all(isinstance(x, float) for x in color): + if len(color) == 4: + return mpl.colors.to_rgba(color) + alpha = alpha if visible(color) else np.nan + return mpl.colors.to_rgba(color, alpha) + else: + if np.ndim(color) == 2 and color.shape[1] == 4: + return mpl.colors.to_rgba_array(color) + alpha = np.where(visible(color, axis=1), alpha, np.nan) + return mpl.colors.to_rgba_array(color, alpha) + + # TODO should we be implementing fill here too? + # (i.e. set fillalpha to 0 when fill=False) + + +class MultiMark(Mark): + + # TODO implement this as a way to wrap multiple marks (e.g. line and ribbon) + # It should be fairly lightweight, the main thing is to expose the union + # of each mark's parameters and then to call them sequentially in _plot. + pass diff --git a/seaborn/_marks/basic.py b/seaborn/_marks/basic.py new file mode 100644 index 0000000000..2d5d8d38a5 --- /dev/null +++ b/seaborn/_marks/basic.py @@ -0,0 +1,67 @@ +from __future__ import annotations +from dataclasses import dataclass + +import matplotlib as mpl + +from seaborn._marks.base import ( + Mark, + Mappable, + MappableFloat, + MappableString, + MappableColor, + resolve_properties, +) + + +# TODO the collection of marks defined here is a holdover from very early +# "let's just got some plots on the screen" phase. They should maybe go elsewhere. + + +@dataclass +class Line(Mark): + """ + A mark connecting data points with sorting along the orientation axis. + """ + + # TODO other semantics (marker?) + + color: MappableColor = Mappable("C0", ) + alpha: MappableFloat = Mappable(1, ) + linewidth: MappableFloat = Mappable(rc="lines.linewidth", ) + linestyle: MappableString = Mappable(rc="lines.linestyle", ) + + # TODO alternately, have Path mark that doesn't sort + sort: bool = True + + def _plot(self, split_gen, scales, orient): + + for keys, data, ax in split_gen(): + + keys = resolve_properties(self, keys, scales) + + if self.sort: + # TODO where to dropna? + data = data.dropna().sort_values(orient) + + line = mpl.lines.Line2D( + data["x"].to_numpy(), + data["y"].to_numpy(), + color=keys["color"], + alpha=keys["alpha"], + linewidth=keys["linewidth"], + linestyle=keys["linestyle"], + **self.artist_kws, # TODO keep? remove? be consistent across marks + ) + ax.add_line(line) + + def _legend_artist(self, variables, value, scales): + + key = resolve_properties(self, {v: value for v in variables}, scales) + + return mpl.lines.Line2D( + [], [], + color=key["color"], + alpha=key["alpha"], + linewidth=key["linewidth"], + linestyle=key["linestyle"], + ) diff --git a/seaborn/_marks/scatter.py b/seaborn/_marks/scatter.py new file mode 100644 index 0000000000..b5d4a0bcc8 --- /dev/null +++ b/seaborn/_marks/scatter.py @@ -0,0 +1,173 @@ +from __future__ import annotations +from dataclasses import dataclass + +import numpy as np +import matplotlib as mpl + +from seaborn._marks.base import ( + Mark, + Mappable, + MappableBool, + MappableFloat, + MappableString, + MappableColor, + MappableStyle, + resolve_properties, + resolve_color, +) + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import Any + from matplotlib.artist import Artist + from seaborn._core.scales import Scale + + +@dataclass +class Scatter(Mark): + """ + A point mark defined by strokes with optional fills. + """ + # TODO retype marker as MappableMarker + marker: MappableString = Mappable(rc="scatter.marker", grouping=False) + stroke: MappableFloat = Mappable(.75, grouping=False) # TODO rcParam? + pointsize: MappableFloat = Mappable(3, grouping=False) # TODO rcParam? + color: MappableColor = Mappable("C0", grouping=False) + alpha: MappableFloat = Mappable(1, grouping=False) # TODO auto alpha? + fill: MappableBool = Mappable(True, grouping=False) + fillcolor: MappableColor = Mappable(depend="color", grouping=False) + fillalpha: MappableFloat = Mappable(.2, grouping=False) + + def _resolve_paths(self, data): + + paths = [] + path_cache = {} + marker = data["marker"] + + def get_transformed_path(m): + return m.get_path().transformed(m.get_transform()) + + if isinstance(marker, mpl.markers.MarkerStyle): + return get_transformed_path(marker) + + for m in marker: + if m not in path_cache: + path_cache[m] = get_transformed_path(m) + paths.append(path_cache[m]) + return paths + + def _resolve_properties(self, data, scales): + + resolved = resolve_properties(self, data, scales) + resolved["path"] = self._resolve_paths(resolved) + + if isinstance(data, dict): # TODO need a better way to check + filled_marker = resolved["marker"].is_filled() + else: + filled_marker = [m.is_filled() for m in resolved["marker"]] + + resolved["linewidth"] = resolved["stroke"] + resolved["fill"] = resolved["fill"] & filled_marker + resolved["size"] = resolved["pointsize"] ** 2 + + resolved["edgecolor"] = resolve_color(self, data, "", scales) + resolved["facecolor"] = resolve_color(self, data, "fill", scales) + + # Because only Dot, and not Scatter, has an edgestyle + resolved.setdefault("edgestyle", (0, None)) + + fc = resolved["facecolor"] + if isinstance(fc, tuple): + resolved["facecolor"] = fc[0], fc[1], fc[2], fc[3] * resolved["fill"] + else: + fc[:, 3] = fc[:, 3] * resolved["fill"] # TODO Is inplace mod a problem? + resolved["facecolor"] = fc + + return resolved + + def _plot(self, split_gen, scales, orient): + + # TODO Not backcompat with allowed (but nonfunctional) univariate plots + # (That should be solved upstream by defaulting to "" for unset x/y?) + # (Be mindful of xmin/xmax, etc!) + + # TODO pass scales *into* split_gen? + for keys, data, ax in split_gen(): + + offsets = np.column_stack([data["x"], data["y"]]) + data = self._resolve_properties(data, scales) + + points = mpl.collections.PathCollection( + offsets=offsets, + paths=data["path"], + sizes=data["size"], + facecolors=data["facecolor"], + edgecolors=data["edgecolor"], + linewidths=data["linewidth"], + linestyles=data["edgestyle"], + transOffset=ax.transData, + transform=mpl.transforms.IdentityTransform(), + ) + ax.add_collection(points) + + def _legend_artist( + self, variables: list[str], value: Any, scales: dict[str, Scale], + ) -> Artist: + + key = {v: value for v in variables} + res = self._resolve_properties(key, scales) + + return mpl.collections.PathCollection( + paths=[res["path"]], + sizes=[res["size"]], + facecolors=[res["facecolor"]], + edgecolors=[res["edgecolor"]], + linewidths=[res["linewidth"]], + linestyles=[res["edgestyle"]], + transform=mpl.transforms.IdentityTransform(), + ) + + +# TODO change this to depend on ScatterBase? +@dataclass +class Dot(Scatter): + """ + A point mark defined by shape with optional edges. + """ + marker: MappableString = Mappable("o", grouping=False) + color: MappableColor = Mappable("C0", grouping=False) + alpha: MappableFloat = Mappable(1, grouping=False) + fill: MappableBool = Mappable(True, grouping=False) + edgecolor: MappableColor = Mappable(depend="color", grouping=False) + edgealpha: MappableFloat = Mappable(depend="alpha", grouping=False) + pointsize: MappableFloat = Mappable(6, grouping=False) # TODO rcParam? + edgewidth: MappableFloat = Mappable(.5, grouping=False) # TODO rcParam? + edgestyle: MappableStyle = Mappable("-", grouping=False) + + def _resolve_properties(self, data, scales): + # TODO this is maybe a little hacky, is there a better abstraction? + resolved = super()._resolve_properties(data, scales) + + filled = resolved["fill"] + + main_stroke = resolved["stroke"] + edge_stroke = resolved["edgewidth"] + resolved["linewidth"] = np.where(filled, edge_stroke, main_stroke) + + # Overwrite the colors that the super class set + main_color = resolve_color(self, data, "", scales) + edge_color = resolve_color(self, data, "edge", scales) + + if not np.isscalar(filled): + # Expand dims to use in np.where with rgba arrays + filled = filled[:, None] + resolved["edgecolor"] = np.where(filled, edge_color, main_color) + + filled = np.squeeze(filled) + if isinstance(main_color, tuple): + main_color = tuple([*main_color[:3], main_color[3] * filled]) + else: + main_color = np.c_[main_color[:, :3], main_color[:, 3] * filled] + resolved["facecolor"] = main_color + + return resolved diff --git a/seaborn/_core.py b/seaborn/_oldcore.py similarity index 99% rename from seaborn/_core.py rename to seaborn/_oldcore.py index 24ddff1d27..b23fd9fb88 100644 --- a/seaborn/_core.py +++ b/seaborn/_oldcore.py @@ -1431,6 +1431,7 @@ class VariableType(UserString): them. If that changes, they should be more verbose. """ + # TODO we can replace this with typing.Literal on Python 3.8+ allowed = "numeric", "datetime", "categorical" def __init__(self, data): diff --git a/seaborn/_stats/__init__.py b/seaborn/_stats/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/seaborn/_stats/aggregation.py b/seaborn/_stats/aggregation.py new file mode 100644 index 0000000000..d870499901 --- /dev/null +++ b/seaborn/_stats/aggregation.py @@ -0,0 +1,66 @@ +from __future__ import annotations +from dataclasses import dataclass +from typing import ClassVar + +from seaborn._stats.base import Stat + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import Callable + from numbers import Number + from seaborn._core.typing import Vector + + +@dataclass +class Agg(Stat): + """ + Aggregate data along the value axis using given method. + + Parameters + ---------- + func + Name of a method understood by Pandas or an arbitrary vector -> scalar function. + + """ + # TODO In current practice we will always have a numeric x/y variable, + # but they may represent non-numeric values. Needs clear documentation. + func: str | Callable[[Vector], Number] = "mean" + + group_by_orient: ClassVar[bool] = True + + def __call__(self, data, groupby, orient, scales): + + var = {"x": "y", "y": "x"}.get(orient) + res = ( + groupby + .agg(data, {var: self.func}) + # TODO Could be an option not to drop NA? + .dropna() + .reset_index(drop=True) + ) + return res + + +@dataclass +class Est(Stat): + + # TODO a string here must be a numpy ufunc? + func: str | Callable[[Vector], Number] = "mean" + + # TODO type errorbar options with literal? + errorbar: str | tuple[str, float] = ("ci", 95) + + group_by_orient: ClassVar[bool] = True + + def __call__(self, data, groupby, orient, scales): + + # TODO port code over from _statistics + ... + + +@dataclass +class Rolling(Stat): + ... + + def __call__(self, data, groupby, orient, scales): + ... diff --git a/seaborn/_stats/base.py b/seaborn/_stats/base.py new file mode 100644 index 0000000000..bf0d1ddeb0 --- /dev/null +++ b/seaborn/_stats/base.py @@ -0,0 +1,41 @@ +"""Base module for statistical transformations.""" +from __future__ import annotations +from dataclasses import dataclass +from typing import ClassVar + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from pandas import DataFrame + from seaborn._core.groupby import GroupBy + from seaborn._core.scales import Scale + + +@dataclass +class Stat: + """ + Base class for objects that define statistical transformations on plot data. + + The class supports a partial-function application pattern. The object is + initialized with desired parameters and the result is a callable that + accepts and returns dataframes. + + The statistical transformation logic should not add any state to the instance + beyond what is defined with the initialization parameters. + + """ + # Subclasses can declare whether the orient dimension should be used in grouping + # TODO consider whether this should be a parameter. Motivating example: + # use the same KDE class violin plots and univariate density estimation. + # In the former case, we would expect separate densities for each unique + # value on the orient axis, but we would not in the latter case. + group_by_orient: ClassVar[bool] = False + + def __call__( + self, + data: DataFrame, + groupby: GroupBy, + orient: str, + scales: dict[str, Scale], + ) -> DataFrame: + """Apply statistical transform to data subgroups and return combined result.""" + return data diff --git a/seaborn/_stats/histograms.py b/seaborn/_stats/histograms.py new file mode 100644 index 0000000000..5e2a565f7b --- /dev/null +++ b/seaborn/_stats/histograms.py @@ -0,0 +1,149 @@ +from __future__ import annotations +from dataclasses import dataclass +from functools import partial + +import numpy as np +import pandas as pd + +from seaborn._core.groupby import GroupBy +from seaborn._stats.base import Stat + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from numpy.typing import ArrayLike + + +@dataclass +class Hist(Stat): + """ + Bin observations, count them, and optionally normalize or cumulate. + """ + stat: str = "count" # TODO how to do validation on this arg? + + bins: str | int | ArrayLike = "auto" + binwidth: float | None = None + binrange: tuple[float, float] | None = None + common_norm: bool | list[str] = True + common_bins: bool | list[str] = True + cumulative: bool = False + + # TODO Require this to be set here or have interface with scale? + # Q: would Discrete() scale imply binwidth=1 or bins centered on integers? + discrete: bool = False + + def _define_bin_edges(self, vals, weight, bins, binwidth, binrange, discrete): + """Inner function that takes bin parameters as arguments.""" + vals = vals.dropna() + + if binrange is None: + start, stop = vals.min(), vals.max() + else: + start, stop = binrange + + if discrete: + bin_edges = np.arange(start - .5, stop + 1.5) + elif binwidth is not None: + step = binwidth + bin_edges = np.arange(start, stop + step, step) + else: + bin_edges = np.histogram_bin_edges(vals, bins, binrange, weight) + + # TODO warning or cap on too many bins? + + return bin_edges + + def _define_bin_params(self, data, orient, scale_type): + """Given data, return numpy.histogram parameters to define bins.""" + vals = data[orient] + weight = data.get("weight", None) + + # TODO We'll want this for ordinal / discrete scales too + # (Do we need discrete as a parameter or just infer from scale?) + discrete = self.discrete or scale_type == "nominal" + + bin_edges = self._define_bin_edges( + vals, weight, self.bins, self.binwidth, self.binrange, discrete, + ) + + if isinstance(self.bins, (str, int)): + n_bins = len(bin_edges) - 1 + bin_range = bin_edges.min(), bin_edges.max() + bin_kws = dict(bins=n_bins, range=bin_range) + else: + bin_kws = dict(bins=bin_edges) + + return bin_kws + + def _get_bins_and_eval(self, data, orient, groupby, scale_type): + + bin_kws = self._define_bin_params(data, orient, scale_type) + return groupby.apply(data, self._eval, orient, bin_kws) + + def _eval(self, data, orient, bin_kws): + + vals = data[orient] + weight = data.get("weight", None) + + density = self.stat == "density" + hist, bin_edges = np.histogram( + vals, **bin_kws, weights=weight, density=density, + ) + + width = np.diff(bin_edges) + pos = bin_edges[:-1] + width / 2 + other = {"x": "y", "y": "x"}[orient] + + return pd.DataFrame({orient: pos, other: hist, "space": width}) + + def _normalize(self, data, orient): + + other = "y" if orient == "x" else "x" + hist = data[other] + + if self.stat == "probability" or self.stat == "proportion": + hist = hist.astype(float) / hist.sum() + elif self.stat == "percent": + hist = hist.astype(float) / hist.sum() * 100 + elif self.stat == "frequency": + hist = hist.astype(float) / data["space"] + + if self.cumulative: + if self.stat in ["density", "frequency"]: + hist = (hist * data["space"]).cumsum() + else: + hist = hist.cumsum() + + return data.assign(**{other: hist}) + + def __call__(self, data, groupby, orient, scales): + + scale_type = scales[orient].scale_type + grouping_vars = [v for v in data if v in groupby.order] + if not grouping_vars or self.common_bins is True: + bin_kws = self._define_bin_params(data, orient, scale_type) + data = groupby.apply(data, self._eval, orient, bin_kws) + else: + if self.common_bins is False: + bin_groupby = GroupBy(grouping_vars) + else: + bin_groupby = GroupBy(self.common_bins) + data = bin_groupby.apply( + data, self._get_bins_and_eval, orient, groupby, scale_type, + ) + + # TODO Make this an option? + # (This needs to be tested if enabled, and maybe should be in _eval) + # other = {"x": "y", "y": "x"}[orient] + # data = data[data[other] > 0] + + if not grouping_vars or self.common_norm is True: + data = self._normalize(data, orient) + else: + if self.common_norm is False: + norm_grouper = grouping_vars + else: + norm_grouper = self.common_norm + normalize = partial(self._normalize, orient=orient) + data = GroupBy(norm_grouper).apply(data, normalize) + + return data diff --git a/seaborn/_stats/regression.py b/seaborn/_stats/regression.py new file mode 100644 index 0000000000..7b7ddc8d82 --- /dev/null +++ b/seaborn/_stats/regression.py @@ -0,0 +1,47 @@ +from __future__ import annotations +from dataclasses import dataclass + +import numpy as np +import pandas as pd + +from seaborn._stats.base import Stat + + +@dataclass +class PolyFit(Stat): + """ + Fit a polynomial of the given order and resample data onto predicted curve. + """ + # This is a provisional class that is useful for building out functionality. + # It may or may not change substantially in form or dissappear as we think + # through the organization of the stats subpackage. + + order: int = 2 + gridsize: int = 100 + + def _fit_predict(self, data): + + x = data["x"] + y = data["y"] + if x.nunique() <= self.order: + # TODO warn? + xx = yy = [] + else: + p = np.polyfit(x, y, self.order) + xx = np.linspace(x.min(), x.max(), self.gridsize) + yy = np.polyval(p, xx) + + return pd.DataFrame(dict(x=xx, y=yy)) + + # TODO we should have a way of identifying the method that will be applied + # and then only define __call__ on a base-class of stats with this pattern + + def __call__(self, data, groupby, orient, scales): + + return groupby.apply(data, self._fit_predict) + + +@dataclass +class OLSFit(Stat): + + ... diff --git a/seaborn/axisgrid.py b/seaborn/axisgrid.py index 62375ba47d..754ff8912c 100644 --- a/seaborn/axisgrid.py +++ b/seaborn/axisgrid.py @@ -1,3 +1,4 @@ +from __future__ import annotations from itertools import product from inspect import signature import warnings @@ -8,7 +9,7 @@ import matplotlib as mpl import matplotlib.pyplot as plt -from ._core import VectorPlotter, variable_type, categorical_order +from ._oldcore import VectorPlotter, variable_type, categorical_order from . import utils from .utils import _check_argument, adjust_legend_subtitles, _draw_figure from .palettes import color_palette, blend_palette @@ -17,7 +18,6 @@ _core_docs, ) - __all__ = ["FacetGrid", "PairGrid", "JointGrid", "pairplot", "jointplot"] @@ -308,6 +308,7 @@ def legend(self): class FacetGrid(Grid): """Multi-plot grid for plotting conditional relationships.""" + def __init__( self, data, *, row=None, col=None, hue=None, col_wrap=None, @@ -315,7 +316,7 @@ def __init__( row_order=None, col_order=None, hue_order=None, hue_kws=None, dropna=False, legend_out=True, despine=True, margin_titles=False, xlim=None, ylim=None, subplot_kws=None, - gridspec_kws=None, size=None + gridspec_kws=None, size=None, ): super(FacetGrid, self).__init__() diff --git a/seaborn/categorical.py b/seaborn/categorical.py index a2d77771c9..5ce9268b70 100644 --- a/seaborn/categorical.py +++ b/seaborn/categorical.py @@ -18,7 +18,7 @@ import matplotlib.patches as Patches import matplotlib.pyplot as plt -from ._core import ( +from ._oldcore import ( VectorPlotter, variable_type, infer_orient, diff --git a/seaborn/conftest.py b/seaborn/conftest.py index 0797ced5cd..c3ab49ba12 100644 --- a/seaborn/conftest.py +++ b/seaborn/conftest.py @@ -75,6 +75,7 @@ def wide_array(wide_df): return wide_df.to_numpy() +# TODO s/flat/thin? @pytest.fixture def flat_series(rng): diff --git a/seaborn/distributions.py b/seaborn/distributions.py index 8ee7a05972..276e028531 100644 --- a/seaborn/distributions.py +++ b/seaborn/distributions.py @@ -13,7 +13,7 @@ from matplotlib.colors import to_rgba from matplotlib.collections import LineCollection -from ._core import ( +from ._oldcore import ( VectorPlotter, ) from ._statistics import ( diff --git a/seaborn/objects.py b/seaborn/objects.py new file mode 100644 index 0000000000..470566b6b1 --- /dev/null +++ b/seaborn/objects.py @@ -0,0 +1,19 @@ +""" +TODO Give this module a useful docstring +""" +from seaborn._core.plot import Plot # noqa: F401 + +from seaborn._marks.base import Mark # noqa: F401 +from seaborn._marks.basic import Line # noqa: F401 +from seaborn._marks.area import Area, Ribbon # noqa: F401 +from seaborn._marks.bars import Bar # noqa: F401 +from seaborn._marks.scatter import Dot, Scatter # noqa: F401 + +from seaborn._stats.base import Stat # noqa: F401 +from seaborn._stats.aggregation import Agg # noqa: F401 +from seaborn._stats.regression import OLSFit, PolyFit # noqa: F401 +from seaborn._stats.histograms import Hist # noqa: F401 + +from seaborn._core.moves import Dodge, Jitter, Shift, Stack # noqa: F401 + +from seaborn._core.scales import Nominal, Continuous, Temporal # noqa: F401 diff --git a/seaborn/relational.py b/seaborn/relational.py index bf5319da7b..1ac2f3c93a 100644 --- a/seaborn/relational.py +++ b/seaborn/relational.py @@ -5,7 +5,7 @@ import matplotlib as mpl import matplotlib.pyplot as plt -from ._core import ( +from ._oldcore import ( VectorPlotter, ) from .utils import ( @@ -561,6 +561,7 @@ def plot(self, ax, kws): # See https://github.com/matplotlib/matplotlib/issues/17849 for context m = kws.get("marker", mpl.rcParams.get("marker", "o")) if not isinstance(m, mpl.markers.MarkerStyle): + # TODO in more recent matplotlib (which?) can pass a MarkerStyle here m = mpl.markers.MarkerStyle(m) if m.is_filled(): kws.setdefault("edgecolor", "w") diff --git a/seaborn/tests/_core/__init__.py b/seaborn/tests/_core/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/seaborn/tests/_core/test_data.py b/seaborn/tests/_core/test_data.py new file mode 100644 index 0000000000..b3e0026c19 --- /dev/null +++ b/seaborn/tests/_core/test_data.py @@ -0,0 +1,398 @@ +import functools +import numpy as np +import pandas as pd + +import pytest +from numpy.testing import assert_array_equal +from pandas.testing import assert_series_equal + +from seaborn._core.data import PlotData + + +assert_vector_equal = functools.partial(assert_series_equal, check_names=False) + + +class TestPlotData: + + @pytest.fixture + def long_variables(self): + variables = dict(x="x", y="y", color="a", size="z", style="s_cat") + return variables + + def test_named_vectors(self, long_df, long_variables): + + p = PlotData(long_df, long_variables) + assert p.source_data is long_df + assert p.source_vars is long_variables + for key, val in long_variables.items(): + assert p.names[key] == val + assert_vector_equal(p.frame[key], long_df[val]) + + def test_named_and_given_vectors(self, long_df, long_variables): + + long_variables["y"] = long_df["b"] + long_variables["size"] = long_df["z"].to_numpy() + + p = PlotData(long_df, long_variables) + + assert_vector_equal(p.frame["color"], long_df[long_variables["color"]]) + assert_vector_equal(p.frame["y"], long_df["b"]) + assert_vector_equal(p.frame["size"], long_df["z"]) + + assert p.names["color"] == long_variables["color"] + assert p.names["y"] == "b" + assert p.names["size"] is None + + assert p.ids["color"] == long_variables["color"] + assert p.ids["y"] == "b" + assert p.ids["size"] == id(long_variables["size"]) + + def test_index_as_variable(self, long_df, long_variables): + + index = pd.Index(np.arange(len(long_df)) * 2 + 10, name="i", dtype=int) + long_variables["x"] = "i" + p = PlotData(long_df.set_index(index), long_variables) + + assert p.names["x"] == p.ids["x"] == "i" + assert_vector_equal(p.frame["x"], pd.Series(index, index)) + + def test_multiindex_as_variables(self, long_df, long_variables): + + index_i = pd.Index(np.arange(len(long_df)) * 2 + 10, name="i", dtype=int) + index_j = pd.Index(np.arange(len(long_df)) * 3 + 5, name="j", dtype=int) + index = pd.MultiIndex.from_arrays([index_i, index_j]) + long_variables.update({"x": "i", "y": "j"}) + + p = PlotData(long_df.set_index(index), long_variables) + assert_vector_equal(p.frame["x"], pd.Series(index_i, index)) + assert_vector_equal(p.frame["y"], pd.Series(index_j, index)) + + def test_int_as_variable_key(self, rng): + + df = pd.DataFrame(rng.uniform(size=(10, 3))) + + var = "x" + key = 2 + + p = PlotData(df, {var: key}) + assert_vector_equal(p.frame[var], df[key]) + assert p.names[var] == p.ids[var] == str(key) + + def test_int_as_variable_value(self, long_df): + + p = PlotData(long_df, {"x": 0, "y": "y"}) + assert (p.frame["x"] == 0).all() + assert p.names["x"] is None + assert p.ids["x"] == id(0) + + def test_tuple_as_variable_key(self, rng): + + cols = pd.MultiIndex.from_product([("a", "b", "c"), ("x", "y")]) + df = pd.DataFrame(rng.uniform(size=(10, 6)), columns=cols) + + var = "color" + key = ("b", "y") + p = PlotData(df, {var: key}) + assert_vector_equal(p.frame[var], df[key]) + assert p.names[var] == p.ids[var] == str(key) + + def test_dict_as_data(self, long_dict, long_variables): + + p = PlotData(long_dict, long_variables) + assert p.source_data is long_dict + for key, val in long_variables.items(): + assert_vector_equal(p.frame[key], pd.Series(long_dict[val])) + + @pytest.mark.parametrize( + "vector_type", + ["series", "numpy", "list"], + ) + def test_vectors_various_types(self, long_df, long_variables, vector_type): + + variables = {key: long_df[val] for key, val in long_variables.items()} + if vector_type == "numpy": + variables = {key: val.to_numpy() for key, val in variables.items()} + elif vector_type == "list": + variables = {key: val.to_list() for key, val in variables.items()} + + p = PlotData(None, variables) + + assert list(p.names) == list(long_variables) + if vector_type == "series": + assert p.source_vars is variables + assert p.names == p.ids == {key: val.name for key, val in variables.items()} + else: + assert p.names == {key: None for key in variables} + assert p.ids == {key: id(val) for key, val in variables.items()} + + for key, val in long_variables.items(): + if vector_type == "series": + assert_vector_equal(p.frame[key], long_df[val]) + else: + assert_array_equal(p.frame[key], long_df[val]) + + def test_none_as_variable_value(self, long_df): + + p = PlotData(long_df, {"x": "z", "y": None}) + assert list(p.frame.columns) == ["x"] + assert p.names == p.ids == {"x": "z"} + + def test_frame_and_vector_mismatched_lengths(self, long_df): + + vector = np.arange(len(long_df) * 2) + with pytest.raises(ValueError): + PlotData(long_df, {"x": "x", "y": vector}) + + @pytest.mark.parametrize( + "arg", [[], np.array([]), pd.DataFrame()], + ) + def test_empty_data_input(self, arg): + + p = PlotData(arg, {}) + assert p.frame.empty + assert not p.names + + if not isinstance(arg, pd.DataFrame): + p = PlotData(None, dict(x=arg, y=arg)) + assert p.frame.empty + assert not p.names + + def test_index_alignment_series_to_dataframe(self): + + x = [1, 2, 3] + x_index = pd.Index(x, dtype=int) + + y_values = [3, 4, 5] + y_index = pd.Index(y_values, dtype=int) + y = pd.Series(y_values, y_index, name="y") + + data = pd.DataFrame(dict(x=x), index=x_index) + + p = PlotData(data, {"x": "x", "y": y}) + + x_col_expected = pd.Series([1, 2, 3, np.nan, np.nan], np.arange(1, 6)) + y_col_expected = pd.Series([np.nan, np.nan, 3, 4, 5], np.arange(1, 6)) + assert_vector_equal(p.frame["x"], x_col_expected) + assert_vector_equal(p.frame["y"], y_col_expected) + + def test_index_alignment_between_series(self): + + x_index = [1, 2, 3] + x_values = [10, 20, 30] + x = pd.Series(x_values, x_index, name="x") + + y_index = [3, 4, 5] + y_values = [300, 400, 500] + y = pd.Series(y_values, y_index, name="y") + + p = PlotData(None, {"x": x, "y": y}) + + x_col_expected = pd.Series([10, 20, 30, np.nan, np.nan], np.arange(1, 6)) + y_col_expected = pd.Series([np.nan, np.nan, 300, 400, 500], np.arange(1, 6)) + assert_vector_equal(p.frame["x"], x_col_expected) + assert_vector_equal(p.frame["y"], y_col_expected) + + def test_key_not_in_data_raises(self, long_df): + + var = "x" + key = "what" + msg = f"Could not interpret value `{key}` for `{var}`. An entry with this name" + with pytest.raises(ValueError, match=msg): + PlotData(long_df, {var: key}) + + def test_key_with_no_data_raises(self): + + var = "x" + key = "what" + msg = f"Could not interpret value `{key}` for `{var}`. Value is a string," + with pytest.raises(ValueError, match=msg): + PlotData(None, {var: key}) + + def test_data_vector_different_lengths_raises(self, long_df): + + vector = np.arange(len(long_df) - 5) + msg = "Length of ndarray vectors must match length of `data`" + with pytest.raises(ValueError, match=msg): + PlotData(long_df, {"y": vector}) + + def test_undefined_variables_raise(self, long_df): + + with pytest.raises(ValueError): + PlotData(long_df, dict(x="not_in_df")) + + with pytest.raises(ValueError): + PlotData(long_df, dict(x="x", y="not_in_df")) + + with pytest.raises(ValueError): + PlotData(long_df, dict(x="x", y="y", color="not_in_df")) + + def test_contains_operation(self, long_df): + + p = PlotData(long_df, {"x": "y", "color": long_df["a"]}) + assert "x" in p + assert "y" not in p + assert "color" in p + + def test_join_add_variable(self, long_df): + + v1 = {"x": "x", "y": "f"} + v2 = {"color": "a"} + + p1 = PlotData(long_df, v1) + p2 = p1.join(None, v2) + + for var, key in dict(**v1, **v2).items(): + assert var in p2 + assert p2.names[var] == key + assert_vector_equal(p2.frame[var], long_df[key]) + + def test_join_replace_variable(self, long_df): + + v1 = {"x": "x", "y": "y"} + v2 = {"y": "s"} + + p1 = PlotData(long_df, v1) + p2 = p1.join(None, v2) + + variables = v1.copy() + variables.update(v2) + + for var, key in variables.items(): + assert var in p2 + assert p2.names[var] == key + assert_vector_equal(p2.frame[var], long_df[key]) + + def test_join_remove_variable(self, long_df): + + variables = {"x": "x", "y": "f"} + drop_var = "y" + + p1 = PlotData(long_df, variables) + p2 = p1.join(None, {drop_var: None}) + + assert drop_var in p1 + assert drop_var not in p2 + assert drop_var not in p2.frame + assert drop_var not in p2.names + + def test_join_all_operations(self, long_df): + + v1 = {"x": "x", "y": "y", "color": "a"} + v2 = {"y": "s", "size": "s", "color": None} + + p1 = PlotData(long_df, v1) + p2 = p1.join(None, v2) + + for var, key in v2.items(): + if key is None: + assert var not in p2 + else: + assert p2.names[var] == key + assert_vector_equal(p2.frame[var], long_df[key]) + + def test_join_all_operations_same_data(self, long_df): + + v1 = {"x": "x", "y": "y", "color": "a"} + v2 = {"y": "s", "size": "s", "color": None} + + p1 = PlotData(long_df, v1) + p2 = p1.join(long_df, v2) + + for var, key in v2.items(): + if key is None: + assert var not in p2 + else: + assert p2.names[var] == key + assert_vector_equal(p2.frame[var], long_df[key]) + + def test_join_add_variable_new_data(self, long_df): + + d1 = long_df[["x", "y"]] + d2 = long_df[["a", "s"]] + + v1 = {"x": "x", "y": "y"} + v2 = {"color": "a"} + + p1 = PlotData(d1, v1) + p2 = p1.join(d2, v2) + + for var, key in dict(**v1, **v2).items(): + assert p2.names[var] == key + assert_vector_equal(p2.frame[var], long_df[key]) + + def test_join_replace_variable_new_data(self, long_df): + + d1 = long_df[["x", "y"]] + d2 = long_df[["a", "s"]] + + v1 = {"x": "x", "y": "y"} + v2 = {"x": "a"} + + p1 = PlotData(d1, v1) + p2 = p1.join(d2, v2) + + variables = v1.copy() + variables.update(v2) + + for var, key in variables.items(): + assert p2.names[var] == key + assert_vector_equal(p2.frame[var], long_df[key]) + + def test_join_add_variable_different_index(self, long_df): + + d1 = long_df.iloc[:70] + d2 = long_df.iloc[30:] + + v1 = {"x": "a"} + v2 = {"y": "z"} + + p1 = PlotData(d1, v1) + p2 = p1.join(d2, v2) + + (var1, key1), = v1.items() + (var2, key2), = v2.items() + + assert_vector_equal(p2.frame.loc[d1.index, var1], d1[key1]) + assert_vector_equal(p2.frame.loc[d2.index, var2], d2[key2]) + + assert p2.frame.loc[d2.index.difference(d1.index), var1].isna().all() + assert p2.frame.loc[d1.index.difference(d2.index), var2].isna().all() + + def test_join_replace_variable_different_index(self, long_df): + + d1 = long_df.iloc[:70] + d2 = long_df.iloc[30:] + + var = "x" + k1, k2 = "a", "z" + v1 = {var: k1} + v2 = {var: k2} + + p1 = PlotData(d1, v1) + p2 = p1.join(d2, v2) + + (var1, key1), = v1.items() + (var2, key2), = v2.items() + + assert_vector_equal(p2.frame.loc[d2.index, var], d2[k2]) + assert p2.frame.loc[d1.index.difference(d2.index), var].isna().all() + + def test_join_subset_data_inherit_variables(self, long_df): + + sub_df = long_df[long_df["a"] == "b"] + + var = "y" + p1 = PlotData(long_df, {var: var}) + p2 = p1.join(sub_df, None) + + assert_vector_equal(p2.frame.loc[sub_df.index, var], sub_df[var]) + assert p2.frame.loc[long_df.index.difference(sub_df.index), var].isna().all() + + def test_join_multiple_inherits_from_orig(self, rng): + + d1 = pd.DataFrame(dict(a=rng.normal(0, 1, 100), b=rng.normal(0, 1, 100))) + d2 = pd.DataFrame(dict(a=rng.normal(0, 1, 100))) + + p = PlotData(d1, {"x": "a"}).join(d2, {"y": "a"}).join(None, {"y": "a"}) + assert_vector_equal(p.frame["x"], d1["a"]) + assert_vector_equal(p.frame["y"], d1["a"]) diff --git a/seaborn/tests/_core/test_groupby.py b/seaborn/tests/_core/test_groupby.py new file mode 100644 index 0000000000..46888db577 --- /dev/null +++ b/seaborn/tests/_core/test_groupby.py @@ -0,0 +1,134 @@ + +import numpy as np +import pandas as pd + +import pytest +from numpy.testing import assert_array_equal + +from seaborn._core.groupby import GroupBy + + +@pytest.fixture +def df(): + + return pd.DataFrame( + columns=["a", "b", "x", "y"], + data=[ + ["a", "g", 1, .2], + ["b", "h", 3, .5], + ["a", "f", 2, .8], + ["a", "h", 1, .3], + ["b", "f", 2, .4], + ] + ) + + +def test_init_from_list(): + g = GroupBy(["a", "c", "b"]) + assert g.order == {"a": None, "c": None, "b": None} + + +def test_init_from_dict(): + order = {"a": [3, 2, 1], "c": None, "b": ["x", "y", "z"]} + g = GroupBy(order) + assert g.order == order + + +def test_init_requires_order(): + + with pytest.raises(ValueError, match="GroupBy requires at least one"): + GroupBy([]) + + +def test_at_least_one_grouping_variable_required(df): + + with pytest.raises(ValueError, match="No grouping variables are present"): + GroupBy(["z"]).agg(df, x="mean") + + +def test_agg_one_grouper(df): + + res = GroupBy(["a"]).agg(df, {"y": "max"}) + assert_array_equal(res.index, [0, 1]) + assert_array_equal(res.columns, ["a", "y"]) + assert_array_equal(res["a"], ["a", "b"]) + assert_array_equal(res["y"], [.8, .5]) + + +def test_agg_two_groupers(df): + + res = GroupBy(["a", "x"]).agg(df, {"y": "min"}) + assert_array_equal(res.index, [0, 1, 2, 3, 4, 5]) + assert_array_equal(res.columns, ["a", "x", "y"]) + assert_array_equal(res["a"], ["a", "a", "a", "b", "b", "b"]) + assert_array_equal(res["x"], [1, 2, 3, 1, 2, 3]) + assert_array_equal(res["y"], [.2, .8, np.nan, np.nan, .4, .5]) + + +def test_agg_two_groupers_ordered(df): + + order = {"b": ["h", "g", "f"], "x": [3, 2, 1]} + res = GroupBy(order).agg(df, {"a": "min", "y": lambda x: x.iloc[0]}) + assert_array_equal(res.index, [0, 1, 2, 3, 4, 5, 6, 7, 8]) + assert_array_equal(res.columns, ["a", "b", "x", "y"]) + assert_array_equal(res["b"], ["h", "h", "h", "g", "g", "g", "f", "f", "f"]) + assert_array_equal(res["x"], [3, 2, 1, 3, 2, 1, 3, 2, 1]) + + T, F = True, False + assert_array_equal(res["a"].isna(), [F, T, F, T, T, F, T, F, T]) + assert_array_equal(res["a"].dropna(), ["b", "a", "a", "a"]) + assert_array_equal(res["y"].dropna(), [.5, .3, .2, .8]) + + +def test_apply_no_grouper(df): + + df = df[["x", "y"]] + res = GroupBy(["a"]).apply(df, lambda x: x.sort_values("x")) + assert_array_equal(res.columns, ["x", "y"]) + assert_array_equal(res["x"], df["x"].sort_values()) + assert_array_equal(res["y"], df.loc[np.argsort(df["x"]), "y"]) + + +def test_apply_one_grouper(df): + + res = GroupBy(["a"]).apply(df, lambda x: x.sort_values("x")) + assert_array_equal(res.index, [0, 1, 2, 3, 4]) + assert_array_equal(res.columns, ["a", "b", "x", "y"]) + assert_array_equal(res["a"], ["a", "a", "a", "b", "b"]) + assert_array_equal(res["b"], ["g", "h", "f", "f", "h"]) + assert_array_equal(res["x"], [1, 1, 2, 2, 3]) + + +def test_apply_mutate_columns(df): + + xx = np.arange(0, 5) + hats = [] + + def polyfit(df): + fit = np.polyfit(df["x"], df["y"], 1) + hat = np.polyval(fit, xx) + hats.append(hat) + return pd.DataFrame(dict(x=xx, y=hat)) + + res = GroupBy(["a"]).apply(df, polyfit) + assert_array_equal(res.index, np.arange(xx.size * 2)) + assert_array_equal(res.columns, ["a", "x", "y"]) + assert_array_equal(res["a"], ["a"] * xx.size + ["b"] * xx.size) + assert_array_equal(res["x"], xx.tolist() + xx.tolist()) + assert_array_equal(res["y"], np.concatenate(hats)) + + +def test_apply_replace_columns(df): + + def add_sorted_cumsum(df): + + x = df["x"].sort_values() + z = df.loc[x.index, "y"].cumsum() + return pd.DataFrame(dict(x=x.values, z=z.values)) + + res = GroupBy(["a"]).apply(df, add_sorted_cumsum) + assert_array_equal(res.index, df.index) + assert_array_equal(res.columns, ["a", "x", "z"]) + assert_array_equal(res["a"], ["a", "a", "a", "b", "b"]) + assert_array_equal(res["x"], [1, 1, 2, 2, 3]) + assert_array_equal(res["z"], [.2, .5, 1.3, .4, .9]) diff --git a/seaborn/tests/_core/test_moves.py b/seaborn/tests/_core/test_moves.py new file mode 100644 index 0000000000..1a55789a02 --- /dev/null +++ b/seaborn/tests/_core/test_moves.py @@ -0,0 +1,320 @@ + +from itertools import product + +import numpy as np +import pandas as pd +from pandas.testing import assert_series_equal +from numpy.testing import assert_array_equal, assert_array_almost_equal + +from seaborn._core.moves import Dodge, Jitter, Shift, Stack +from seaborn._core.rules import categorical_order +from seaborn._core.groupby import GroupBy + +import pytest + + +class MoveFixtures: + + @pytest.fixture + def df(self, rng): + + n = 50 + data = { + "x": rng.choice([0., 1., 2., 3.], n), + "y": rng.normal(0, 1, n), + "grp2": rng.choice(["a", "b"], n), + "grp3": rng.choice(["x", "y", "z"], n), + "width": 0.8, + "baseline": 0, + } + return pd.DataFrame(data) + + @pytest.fixture + def toy_df(self): + + data = { + "x": [0, 0, 1], + "y": [1, 2, 3], + "grp": ["a", "b", "b"], + "width": .8, + "baseline": 0, + } + return pd.DataFrame(data) + + @pytest.fixture + def toy_df_widths(self, toy_df): + + toy_df["width"] = [.8, .2, .4] + return toy_df + + @pytest.fixture + def toy_df_facets(self): + + data = { + "x": [0, 0, 1, 0, 1, 2], + "y": [1, 2, 3, 1, 2, 3], + "grp": ["a", "b", "a", "b", "a", "b"], + "col": ["x", "x", "x", "y", "y", "y"], + "width": .8, + "baseline": 0, + } + return pd.DataFrame(data) + + +class TestJitter(MoveFixtures): + + def get_groupby(self, data, orient): + other = {"x": "y", "y": "x"}[orient] + variables = [v for v in data if v not in [other, "width"]] + return GroupBy(variables) + + def check_same(self, res, df, *cols): + for col in cols: + assert_series_equal(res[col], df[col]) + + def check_pos(self, res, df, var, limit): + + assert (res[var] != df[var]).all() + assert (res[var] < df[var] + limit / 2).all() + assert (res[var] > df[var] - limit / 2).all() + + def test_width(self, df): + + width = .4 + orient = "x" + groupby = self.get_groupby(df, orient) + res = Jitter(width=width)(df, groupby, orient) + self.check_same(res, df, "y", "grp2", "width") + self.check_pos(res, df, "x", width * df["width"]) + + def test_x(self, df): + + val = .2 + orient = "x" + groupby = self.get_groupby(df, orient) + res = Jitter(x=val)(df, groupby, orient) + self.check_same(res, df, "y", "grp2", "width") + self.check_pos(res, df, "x", val) + + def test_y(self, df): + + val = .2 + orient = "x" + groupby = self.get_groupby(df, orient) + res = Jitter(y=val)(df, groupby, orient) + self.check_same(res, df, "x", "grp2", "width") + self.check_pos(res, df, "y", val) + + def test_seed(self, df): + + kws = dict(width=.2, y=.1, seed=0) + orient = "x" + groupby = self.get_groupby(df, orient) + res1 = Jitter(**kws)(df, groupby, orient) + res2 = Jitter(**kws)(df, groupby, orient) + for var in "xy": + assert_series_equal(res1[var], res2[var]) + + +class TestDodge(MoveFixtures): + + # First some very simple toy examples + + def test_default(self, toy_df): + + groupby = GroupBy(["x", "grp"]) + res = Dodge()(toy_df, groupby, "x") + + assert_array_equal(res["y"], [1, 2, 3]), + assert_array_almost_equal(res["x"], [-.2, .2, 1.2]) + assert_array_almost_equal(res["width"], [.4, .4, .4]) + + def test_fill(self, toy_df): + + groupby = GroupBy(["x", "grp"]) + res = Dodge(empty="fill")(toy_df, groupby, "x") + + assert_array_equal(res["y"], [1, 2, 3]), + assert_array_almost_equal(res["x"], [-.2, .2, 1]) + assert_array_almost_equal(res["width"], [.4, .4, .8]) + + def test_drop(self, toy_df): + + groupby = GroupBy(["x", "grp"]) + res = Dodge("drop")(toy_df, groupby, "x") + + assert_array_equal(res["y"], [1, 2, 3]) + assert_array_almost_equal(res["x"], [-.2, .2, 1]) + assert_array_almost_equal(res["width"], [.4, .4, .4]) + + def test_gap(self, toy_df): + + groupby = GroupBy(["x", "grp"]) + res = Dodge(gap=.25)(toy_df, groupby, "x") + + assert_array_equal(res["y"], [1, 2, 3]) + assert_array_almost_equal(res["x"], [-.2, .2, 1.2]) + assert_array_almost_equal(res["width"], [.3, .3, .3]) + + def test_widths_default(self, toy_df_widths): + + groupby = GroupBy(["x", "grp"]) + res = Dodge()(toy_df_widths, groupby, "x") + + assert_array_equal(res["y"], [1, 2, 3]) + assert_array_almost_equal(res["x"], [-.08, .32, 1.1]) + assert_array_almost_equal(res["width"], [.64, .16, .2]) + + def test_widths_fill(self, toy_df_widths): + + groupby = GroupBy(["x", "grp"]) + res = Dodge(empty="fill")(toy_df_widths, groupby, "x") + + assert_array_equal(res["y"], [1, 2, 3]) + assert_array_almost_equal(res["x"], [-.08, .32, 1]) + assert_array_almost_equal(res["width"], [.64, .16, .4]) + + def test_widths_drop(self, toy_df_widths): + + groupby = GroupBy(["x", "grp"]) + res = Dodge(empty="drop")(toy_df_widths, groupby, "x") + + assert_array_equal(res["y"], [1, 2, 3]) + assert_array_almost_equal(res["x"], [-.08, .32, 1]) + assert_array_almost_equal(res["width"], [.64, .16, .2]) + + def test_faceted_default(self, toy_df_facets): + + groupby = GroupBy(["x", "grp", "col"]) + res = Dodge()(toy_df_facets, groupby, "x") + + assert_array_equal(res["y"], [1, 2, 3, 1, 2, 3]) + assert_array_almost_equal(res["x"], [-.2, .2, .8, .2, .8, 2.2]) + assert_array_almost_equal(res["width"], [.4] * 6) + + def test_faceted_fill(self, toy_df_facets): + + groupby = GroupBy(["x", "grp", "col"]) + res = Dodge(empty="fill")(toy_df_facets, groupby, "x") + + assert_array_equal(res["y"], [1, 2, 3, 1, 2, 3]) + assert_array_almost_equal(res["x"], [-.2, .2, 1, 0, 1, 2]) + assert_array_almost_equal(res["width"], [.4, .4, .8, .8, .8, .8]) + + def test_faceted_drop(self, toy_df_facets): + + groupby = GroupBy(["x", "grp", "col"]) + res = Dodge(empty="drop")(toy_df_facets, groupby, "x") + + assert_array_equal(res["y"], [1, 2, 3, 1, 2, 3]) + assert_array_almost_equal(res["x"], [-.2, .2, 1, 0, 1, 2]) + assert_array_almost_equal(res["width"], [.4] * 6) + + def test_orient(self, toy_df): + + df = toy_df.assign(x=toy_df["y"], y=toy_df["x"]) + + groupby = GroupBy(["y", "grp"]) + res = Dodge("drop")(df, groupby, "y") + + assert_array_equal(res["x"], [1, 2, 3]) + assert_array_almost_equal(res["y"], [-.2, .2, 1]) + assert_array_almost_equal(res["width"], [.4, .4, .4]) + + # Now tests with slightly more complicated data + + @pytest.mark.parametrize("grp", ["grp2", "grp3"]) + def test_single_semantic(self, df, grp): + + groupby = GroupBy(["x", grp]) + res = Dodge()(df, groupby, "x") + + levels = categorical_order(df[grp]) + w, n = 0.8, len(levels) + + shifts = np.linspace(0, w - w / n, n) + shifts -= shifts.mean() + + assert_series_equal(res["y"], df["y"]) + assert_series_equal(res["width"], df["width"] / n) + + for val, shift in zip(levels, shifts): + rows = df[grp] == val + assert_series_equal(res.loc[rows, "x"], df.loc[rows, "x"] + shift) + + def test_two_semantics(self, df): + + groupby = GroupBy(["x", "grp2", "grp3"]) + res = Dodge()(df, groupby, "x") + + levels = categorical_order(df["grp2"]), categorical_order(df["grp3"]) + w, n = 0.8, len(levels[0]) * len(levels[1]) + + shifts = np.linspace(0, w - w / n, n) + shifts -= shifts.mean() + + assert_series_equal(res["y"], df["y"]) + assert_series_equal(res["width"], df["width"] / n) + + for (v2, v3), shift in zip(product(*levels), shifts): + rows = (df["grp2"] == v2) & (df["grp3"] == v3) + assert_series_equal(res.loc[rows, "x"], df.loc[rows, "x"] + shift) + + +class TestStack(MoveFixtures): + + def test_basic(self, toy_df): + + groupby = GroupBy(["color", "group"]) + res = Stack()(toy_df, groupby, "x") + + assert_array_equal(res["x"], [0, 0, 1]) + assert_array_equal(res["y"], [1, 3, 3]) + assert_array_equal(res["baseline"], [0, 1, 0]) + + def test_faceted(self, toy_df_facets): + + groupby = GroupBy(["color", "group"]) + res = Stack()(toy_df_facets, groupby, "x") + + assert_array_equal(res["x"], [0, 0, 1, 0, 1, 2]) + assert_array_equal(res["y"], [1, 3, 3, 1, 2, 3]) + assert_array_equal(res["baseline"], [0, 1, 0, 0, 0, 0]) + + def test_misssing_data(self, toy_df): + + df = pd.DataFrame({ + "x": [0, 0, 0], + "y": [2, np.nan, 1], + "baseline": [0, 0, 0], + }) + res = Stack()(df, None, "x") + assert_array_equal(res["y"], [2, np.nan, 3]) + assert_array_equal(res["baseline"], [0, np.nan, 2]) + + def test_baseline_homogeneity_check(self, toy_df): + + toy_df["baseline"] = [0, 1, 2] + groupby = GroupBy(["color", "group"]) + move = Stack() + err = "Stack move cannot be used when baselines" + with pytest.raises(RuntimeError, match=err): + move(toy_df, groupby, "x") + + +class TestShift(MoveFixtures): + + def test_default(self, toy_df): + + gb = GroupBy(["color", "group"]) + res = Shift()(toy_df, gb, "x") + for col in toy_df: + assert_series_equal(toy_df[col], res[col]) + + @pytest.mark.parametrize("x,y", [(.3, 0), (0, .2), (.1, .3)]) + def test_moves(self, toy_df, x, y): + + gb = GroupBy(["color", "group"]) + res = Shift(x=x, y=y)(toy_df, gb, "x") + assert_array_equal(res["x"], toy_df["x"] + x) + assert_array_equal(res["y"], toy_df["y"] + y) diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py new file mode 100644 index 0000000000..35c6a1715d --- /dev/null +++ b/seaborn/tests/_core/test_plot.py @@ -0,0 +1,1703 @@ +import functools +import itertools +import warnings +import imghdr + +import numpy as np +import pandas as pd +import matplotlib as mpl +import matplotlib.pyplot as plt + +import pytest +from pandas.testing import assert_frame_equal, assert_series_equal +from numpy.testing import assert_array_equal + +from seaborn._core.plot import Plot +from seaborn._core.scales import Nominal, Continuous +from seaborn._core.rules import categorical_order +from seaborn._core.moves import Move +from seaborn._marks.base import Mark +from seaborn._stats.base import Stat +from seaborn.external.version import Version + +assert_vector_equal = functools.partial( + # TODO do we care about int/float dtype consistency? + # Eventually most variables become floats ... but does it matter when? + # (Or rather, does it matter if it happens too early?) + assert_series_equal, check_names=False, check_dtype=False, +) + + +def assert_gridspec_shape(ax, nrows=1, ncols=1): + + gs = ax.get_gridspec() + if Version(mpl.__version__) < Version("3.2"): + assert gs._nrows == nrows + assert gs._ncols == ncols + else: + assert gs.nrows == nrows + assert gs.ncols == ncols + + +class MockMark(Mark): + + _grouping_props = ["color"] + + def __init__(self, *args, **kwargs): + + super().__init__(*args, **kwargs) + self.passed_keys = [] + self.passed_data = [] + self.passed_axes = [] + self.passed_scales = None + self.passed_orient = None + self.n_splits = 0 + + def _plot(self, split_gen, scales, orient): + + for keys, data, ax in split_gen(): + self.n_splits += 1 + self.passed_keys.append(keys) + self.passed_data.append(data) + self.passed_axes.append(ax) + + self.passed_scales = scales + self.passed_orient = orient + + def _legend_artist(self, variables, value, scales): + + a = mpl.lines.Line2D([], []) + a.variables = variables + a.value = value + return a + + +class TestInit: + + def test_empty(self): + + p = Plot() + assert p._data.source_data is None + assert p._data.source_vars == {} + + def test_data_only(self, long_df): + + p = Plot(long_df) + assert p._data.source_data is long_df + assert p._data.source_vars == {} + + def test_df_and_named_variables(self, long_df): + + variables = {"x": "a", "y": "z"} + p = Plot(long_df, **variables) + for var, col in variables.items(): + assert_vector_equal(p._data.frame[var], long_df[col]) + assert p._data.source_data is long_df + assert p._data.source_vars.keys() == variables.keys() + + def test_df_and_mixed_variables(self, long_df): + + variables = {"x": "a", "y": long_df["z"]} + p = Plot(long_df, **variables) + for var, col in variables.items(): + if isinstance(col, str): + assert_vector_equal(p._data.frame[var], long_df[col]) + else: + assert_vector_equal(p._data.frame[var], col) + assert p._data.source_data is long_df + assert p._data.source_vars.keys() == variables.keys() + + def test_vector_variables_only(self, long_df): + + variables = {"x": long_df["a"], "y": long_df["z"]} + p = Plot(**variables) + for var, col in variables.items(): + assert_vector_equal(p._data.frame[var], col) + assert p._data.source_data is None + assert p._data.source_vars.keys() == variables.keys() + + def test_vector_variables_no_index(self, long_df): + + variables = {"x": long_df["a"].to_numpy(), "y": long_df["z"].to_list()} + p = Plot(**variables) + for var, col in variables.items(): + assert_vector_equal(p._data.frame[var], pd.Series(col)) + assert p._data.names[var] is None + assert p._data.source_data is None + assert p._data.source_vars.keys() == variables.keys() + + def test_data_only_named(self, long_df): + + p = Plot(data=long_df) + assert p._data.source_data is long_df + assert p._data.source_vars == {} + + def test_positional_and_named_data(self, long_df): + + err = "`data` given by both name and position" + with pytest.raises(TypeError, match=err): + Plot(long_df, data=long_df) + + @pytest.mark.parametrize("var", ["x", "y"]) + def test_positional_and_named_xy(self, long_df, var): + + err = f"`{var}` given by both name and position" + with pytest.raises(TypeError, match=err): + Plot(long_df, "a", "b", **{var: "c"}) + + def test_positional_data_x_y(self, long_df): + + p = Plot(long_df, "a", "b") + assert p._data.source_data is long_df + assert list(p._data.source_vars) == ["x", "y"] + + def test_positional_x_y(self, long_df): + + p = Plot(long_df["a"], long_df["b"]) + assert p._data.source_data is None + assert list(p._data.source_vars) == ["x", "y"] + + def test_positional_data_x(self, long_df): + + p = Plot(long_df, "a") + assert p._data.source_data is long_df + assert list(p._data.source_vars) == ["x"] + + def test_positional_x(self, long_df): + + p = Plot(long_df["a"]) + assert p._data.source_data is None + assert list(p._data.source_vars) == ["x"] + + def test_positional_too_many(self, long_df): + + err = r"Plot\(\) accepts no more than 3 positional arguments \(data, x, y\)" + with pytest.raises(TypeError, match=err): + Plot(long_df, "x", "y", "z") + + def test_unknown_keywords(self, long_df): + + err = r"Plot\(\) got unexpected keyword argument\(s\): bad" + with pytest.raises(TypeError, match=err): + Plot(long_df, bad="x") + + +class TestLayerAddition: + + def test_without_data(self, long_df): + + p = Plot(long_df, x="x", y="y").add(MockMark()).plot() + layer, = p._layers + assert_frame_equal(p._data.frame, layer["data"].frame, check_dtype=False) + + def test_with_new_variable_by_name(self, long_df): + + p = Plot(long_df, x="x").add(MockMark(), y="y").plot() + layer, = p._layers + assert layer["data"].frame.columns.to_list() == ["x", "y"] + for var in "xy": + assert_vector_equal(layer["data"].frame[var], long_df[var]) + + def test_with_new_variable_by_vector(self, long_df): + + p = Plot(long_df, x="x").add(MockMark(), y=long_df["y"]).plot() + layer, = p._layers + assert layer["data"].frame.columns.to_list() == ["x", "y"] + for var in "xy": + assert_vector_equal(layer["data"].frame[var], long_df[var]) + + def test_with_late_data_definition(self, long_df): + + p = Plot().add(MockMark(), data=long_df, x="x", y="y").plot() + layer, = p._layers + assert layer["data"].frame.columns.to_list() == ["x", "y"] + for var in "xy": + assert_vector_equal(layer["data"].frame[var], long_df[var]) + + def test_with_new_data_definition(self, long_df): + + long_df_sub = long_df.sample(frac=.5) + + p = Plot(long_df, x="x", y="y").add(MockMark(), data=long_df_sub).plot() + layer, = p._layers + assert layer["data"].frame.columns.to_list() == ["x", "y"] + for var in "xy": + assert_vector_equal( + layer["data"].frame[var], long_df_sub[var].reindex(long_df.index) + ) + + def test_drop_variable(self, long_df): + + p = Plot(long_df, x="x", y="y").add(MockMark(), y=None).plot() + layer, = p._layers + assert layer["data"].frame.columns.to_list() == ["x"] + assert_vector_equal(layer["data"].frame["x"], long_df["x"], check_dtype=False) + + @pytest.mark.xfail(reason="Need decision on default stat") + def test_stat_default(self): + + class MarkWithDefaultStat(Mark): + default_stat = Stat + + p = Plot().add(MarkWithDefaultStat()) + layer, = p._layers + assert layer["stat"].__class__ is Stat + + def test_stat_nondefault(self): + + class MarkWithDefaultStat(Mark): + default_stat = Stat + + class OtherMockStat(Stat): + pass + + p = Plot().add(MarkWithDefaultStat(), OtherMockStat()) + layer, = p._layers + assert layer["stat"].__class__ is OtherMockStat + + @pytest.mark.parametrize( + "arg,expected", + [("x", "x"), ("y", "y"), ("v", "x"), ("h", "y")], + ) + def test_orient(self, arg, expected): + + class MockStatTrackOrient(Stat): + def __call__(self, data, groupby, orient, scales): + self.orient_at_call = orient + return data + + class MockMoveTrackOrient(Move): + def __call__(self, data, groupby, orient): + self.orient_at_call = orient + return data + + s = MockStatTrackOrient() + m = MockMoveTrackOrient() + Plot(x=[1, 2, 3], y=[1, 2, 3]).add(MockMark(), s, m, orient=arg).plot() + + assert s.orient_at_call == expected + assert m.orient_at_call == expected + + def test_variable_list(self, long_df): + + p = Plot(long_df, x="x", y="y") + assert p._variables == ["x", "y"] + + p = Plot(long_df).add(MockMark(), x="x", y="y") + assert p._variables == ["x", "y"] + + p = Plot(long_df, y="x", color="a").add(MockMark(), x="y") + assert p._variables == ["y", "color", "x"] + + p = Plot(long_df, x="x", y="y", color="a").add(MockMark(), color=None) + assert p._variables == ["x", "y", "color"] + + p = ( + Plot(long_df, x="x", y="y") + .add(MockMark(), color="a") + .add(MockMark(), alpha="s") + ) + assert p._variables == ["x", "y", "color", "alpha"] + + p = Plot(long_df, y="x").pair(x=["a", "b"]) + assert p._variables == ["y", "x0", "x1"] + + def test_type_checks(self): + + p = Plot() + with pytest.raises(TypeError, match="mark must be a Mark instance"): + p.add(MockMark) + + class MockStat(Stat): + pass + + with pytest.raises(TypeError, match="stat must be a Stat instance"): + p.add(MockMark(), MockStat) + + +class TestScaling: + + def test_inference(self, long_df): + + for col, scale_type in zip("zat", ["continuous", "nominal", "temporal"]): + p = Plot(long_df, x=col, y=col).add(MockMark()).plot() + for var in "xy": + assert p._scales[var].scale_type == scale_type + + def test_inference_from_layer_data(self): + + p = Plot().add(MockMark(), x=["a", "b", "c"]).plot() + assert p._scales["x"]("b") == 1 + + def test_inference_joins(self): + + p = ( + Plot(y=pd.Series([1, 2, 3, 4])) + .add(MockMark(), x=pd.Series([1, 2])) + .add(MockMark(), x=pd.Series(["a", "b"], index=[2, 3])) + .plot() + ) + assert p._scales["x"]("a") == 2 + + def test_inferred_categorical_converter(self): + + p = Plot(x=["b", "c", "a"]).add(MockMark()).plot() + ax = p._figure.axes[0] + assert ax.xaxis.convert_units("c") == 1 + + def test_explicit_categorical_converter(self): + + p = Plot(y=[2, 1, 3]).scale(y=Nominal()).add(MockMark()).plot() + ax = p._figure.axes[0] + assert ax.yaxis.convert_units("3") == 2 + + @pytest.mark.xfail(reason="Temporal auto-conversion not implemented") + def test_categorical_as_datetime(self): + + dates = ["1970-01-03", "1970-01-02", "1970-01-04"] + p = Plot(x=dates).scale(...).add(MockMark()).plot() + p # TODO + ... + + def test_faceted_log_scale(self): + + p = Plot(y=[1, 10]).facet(col=["a", "b"]).scale(y="log").plot() + for ax in p._figure.axes: + xfm = ax.yaxis.get_transform().transform + assert_array_equal(xfm([1, 10, 100]), [0, 1, 2]) + + def test_paired_single_log_scale(self): + + x0, x1 = [1, 2, 3], [1, 10, 100] + p = Plot().pair(x=[x0, x1]).scale(x1="log").plot() + ax_lin, ax_log = p._figure.axes + xfm_lin = ax_lin.xaxis.get_transform().transform + assert_array_equal(xfm_lin([1, 10, 100]), [1, 10, 100]) + xfm_log = ax_log.xaxis.get_transform().transform + assert_array_equal(xfm_log([1, 10, 100]), [0, 1, 2]) + + @pytest.mark.xfail(reason="Custom log scale needs log name for consistency") + def test_log_scale_name(self): + + p = Plot().scale(x="log").plot() + ax = p._figure.axes[0] + assert ax.get_xscale() == "log" + assert ax.get_yscale() == "linear" + + def test_mark_data_log_transform_is_inverted(self, long_df): + + col = "z" + m = MockMark() + Plot(long_df, x=col).scale(x="log").add(m).plot() + assert_vector_equal(m.passed_data[0]["x"], long_df[col]) + + def test_mark_data_log_transfrom_with_stat(self, long_df): + + class Mean(Stat): + group_by_orient = True + + def __call__(self, data, groupby, orient, scales): + other = {"x": "y", "y": "x"}[orient] + return groupby.agg(data, {other: "mean"}) + + col = "z" + grouper = "a" + m = MockMark() + s = Mean() + + Plot(long_df, x=grouper, y=col).scale(y="log").add(m, s).plot() + + expected = ( + long_df[col] + .pipe(np.log) + .groupby(long_df[grouper], sort=False) + .mean() + .pipe(np.exp) + .reset_index(drop=True) + ) + assert_vector_equal(m.passed_data[0]["y"], expected) + + def test_mark_data_from_categorical(self, long_df): + + col = "a" + m = MockMark() + Plot(long_df, x=col).add(m).plot() + + levels = categorical_order(long_df[col]) + level_map = {x: float(i) for i, x in enumerate(levels)} + assert_vector_equal(m.passed_data[0]["x"], long_df[col].map(level_map)) + + def test_mark_data_from_datetime(self, long_df): + + col = "t" + m = MockMark() + Plot(long_df, x=col).add(m).plot() + + expected = long_df[col].map(mpl.dates.date2num) + if Version(mpl.__version__) < Version("3.3"): + expected = expected + mpl.dates.date2num(np.datetime64('0000-12-31')) + + assert_vector_equal(m.passed_data[0]["x"], expected) + + def test_facet_categories(self): + + m = MockMark() + p = Plot(x=["a", "b", "a", "c"]).facet(col=["x", "x", "y", "y"]).add(m).plot() + ax1, ax2 = p._figure.axes + assert len(ax1.get_xticks()) == 3 + assert len(ax2.get_xticks()) == 3 + assert_vector_equal(m.passed_data[0]["x"], pd.Series([0., 1.], [0, 1])) + assert_vector_equal(m.passed_data[1]["x"], pd.Series([0., 2.], [2, 3])) + + def test_facet_categories_unshared(self): + + m = MockMark() + p = ( + Plot(x=["a", "b", "a", "c"]) + .facet(col=["x", "x", "y", "y"]) + .configure(sharex=False) + .add(m) + .plot() + ) + ax1, ax2 = p._figure.axes + assert len(ax1.get_xticks()) == 2 + assert len(ax2.get_xticks()) == 2 + assert_vector_equal(m.passed_data[0]["x"], pd.Series([0., 1.], [0, 1])) + assert_vector_equal(m.passed_data[1]["x"], pd.Series([0., 1.], [2, 3])) + + def test_facet_categories_single_dim_shared(self): + + data = [ + ("a", 1, 1), ("b", 1, 1), + ("a", 1, 2), ("c", 1, 2), + ("b", 2, 1), ("d", 2, 1), + ("e", 2, 2), ("e", 2, 1), + ] + df = pd.DataFrame(data, columns=["x", "row", "col"]).assign(y=1) + m = MockMark() + p = ( + Plot(df, x="x") + .facet(row="row", col="col") + .add(m) + .configure(sharex="row") + .plot() + ) + + axs = p._figure.axes + for ax in axs: + assert ax.get_xticks() == [0, 1, 2] + + assert_vector_equal(m.passed_data[0]["x"], pd.Series([0., 1.], [0, 1])) + assert_vector_equal(m.passed_data[1]["x"], pd.Series([0., 2.], [2, 3])) + assert_vector_equal(m.passed_data[2]["x"], pd.Series([0., 1., 2.], [4, 5, 7])) + assert_vector_equal(m.passed_data[3]["x"], pd.Series([2.], [6])) + + def test_pair_categories(self): + + data = [("a", "a"), ("b", "c")] + df = pd.DataFrame(data, columns=["x1", "x2"]).assign(y=1) + m = MockMark() + p = Plot(df, y="y").pair(x=["x1", "x2"]).add(m).plot() + + ax1, ax2 = p._figure.axes + assert ax1.get_xticks() == [0, 1] + assert ax2.get_xticks() == [0, 1] + assert_vector_equal(m.passed_data[0]["x"], pd.Series([0., 1.], [0, 1])) + assert_vector_equal(m.passed_data[1]["x"], pd.Series([0., 1.], [0, 1])) + + @pytest.mark.xfail( + Version(mpl.__version__) < Version("3.4.0"), + reason="Sharing paired categorical axes requires matplotlib>3.4.0" + ) + def test_pair_categories_shared(self): + + data = [("a", "a"), ("b", "c")] + df = pd.DataFrame(data, columns=["x1", "x2"]).assign(y=1) + m = MockMark() + p = Plot(df, y="y").pair(x=["x1", "x2"]).add(m).configure(sharex=True).plot() + + for ax in p._figure.axes: + assert ax.get_xticks() == [0, 1, 2] + print(m.passed_data) + assert_vector_equal(m.passed_data[0]["x"], pd.Series([0., 1.], [0, 1])) + assert_vector_equal(m.passed_data[1]["x"], pd.Series([0., 2.], [0, 1])) + + def test_identity_mapping_linewidth(self): + + m = MockMark() + x = y = [1, 2, 3, 4, 5] + lw = pd.Series([.5, .1, .1, .9, 3]) + Plot(x=x, y=y, linewidth=lw).scale(linewidth=None).add(m).plot() + assert_vector_equal(m.passed_scales["linewidth"](lw), lw) + + def test_pair_single_coordinate_stat_orient(self, long_df): + + class MockStat(Stat): + def __call__(self, data, groupby, orient, scales): + self.orient = orient + return data + + s = MockStat() + Plot(long_df).pair(x=["x", "y"]).add(MockMark(), s).plot() + assert s.orient == "x" + + def test_inferred_nominal_passed_to_stat(self): + + class MockStat(Stat): + def __call__(self, data, groupby, orient, scales): + self.scales = scales + return data + + s = MockStat() + y = ["a", "a", "b", "c"] + Plot(y=y).add(MockMark(), s).plot() + assert s.scales["y"].scale_type == "nominal" + + # TODO where should RGB consistency be enforced? + @pytest.mark.xfail( + reason="Correct output representation for color with identity scale undefined" + ) + def test_identity_mapping_color_strings(self): + + m = MockMark() + x = y = [1, 2, 3] + c = ["C0", "C2", "C1"] + Plot(x=x, y=y, color=c).scale(color=None).add(m).plot() + expected = mpl.colors.to_rgba_array(c)[:, :3] + assert_array_equal(m.passed_scales["color"](c), expected) + + def test_identity_mapping_color_tuples(self): + + m = MockMark() + x = y = [1, 2, 3] + c = [(1, 0, 0), (0, 1, 0), (1, 0, 0)] + Plot(x=x, y=y, color=c).scale(color=None).add(m).plot() + expected = mpl.colors.to_rgba_array(c)[:, :3] + assert_array_equal(m.passed_scales["color"](c), expected) + + @pytest.mark.xfail( + reason="Need decision on what to do with scale defined for unused variable" + ) + def test_undefined_variable_raises(self): + + p = Plot(x=[1, 2, 3], color=["a", "b", "c"]).scale(y=Continuous()) + err = r"No data found for variable\(s\) with explicit scale: {'y'}" + with pytest.raises(RuntimeError, match=err): + p.plot() + + +class TestPlotting: + + def test_matplotlib_object_creation(self): + + p = Plot().plot() + assert isinstance(p._figure, mpl.figure.Figure) + for sub in p._subplots: + assert isinstance(sub["ax"], mpl.axes.Axes) + + def test_empty(self): + + m = MockMark() + Plot().plot() + assert m.n_splits == 0 + + def test_single_split_single_layer(self, long_df): + + m = MockMark() + p = Plot(long_df, x="f", y="z").add(m).plot() + assert m.n_splits == 1 + + assert m.passed_keys[0] == {} + assert m.passed_axes == [sub["ax"] for sub in p._subplots] + for col in p._data.frame: + assert_series_equal(m.passed_data[0][col], p._data.frame[col]) + + def test_single_split_multi_layer(self, long_df): + + vs = [{"color": "a", "linewidth": "z"}, {"color": "b", "pattern": "c"}] + + class NoGroupingMark(MockMark): + _grouping_props = [] + + ms = [NoGroupingMark(), NoGroupingMark()] + Plot(long_df).add(ms[0], **vs[0]).add(ms[1], **vs[1]).plot() + + for m, v in zip(ms, vs): + for var, col in v.items(): + assert_vector_equal(m.passed_data[0][var], long_df[col]) + + def check_splits_single_var( + self, data, mark, data_vars, split_var, split_col, split_keys + ): + + assert mark.n_splits == len(split_keys) + assert mark.passed_keys == [{split_var: key} for key in split_keys] + + for i, key in enumerate(split_keys): + + split_data = data[data[split_col] == key] + for var, col in data_vars.items(): + assert_array_equal(mark.passed_data[i][var], split_data[col]) + + def check_splits_multi_vars( + self, data, mark, data_vars, split_vars, split_cols, split_keys + ): + + assert mark.n_splits == np.prod([len(ks) for ks in split_keys]) + + expected_keys = [ + dict(zip(split_vars, level_keys)) + for level_keys in itertools.product(*split_keys) + ] + assert mark.passed_keys == expected_keys + + for i, keys in enumerate(itertools.product(*split_keys)): + + use_rows = pd.Series(True, data.index) + for var, col, key in zip(split_vars, split_cols, keys): + use_rows &= data[col] == key + split_data = data[use_rows] + for var, col in data_vars.items(): + assert_array_equal(mark.passed_data[i][var], split_data[col]) + + @pytest.mark.parametrize( + "split_var", [ + "color", # explicitly declared on the Mark + "group", # implicitly used for all Mark classes + ]) + def test_one_grouping_variable(self, long_df, split_var): + + split_col = "a" + data_vars = {"x": "f", "y": "z", split_var: split_col} + + m = MockMark() + p = Plot(long_df, **data_vars).add(m).plot() + + split_keys = categorical_order(long_df[split_col]) + sub, *_ = p._subplots + assert m.passed_axes == [sub["ax"] for _ in split_keys] + self.check_splits_single_var( + long_df, m, data_vars, split_var, split_col, split_keys + ) + + def test_two_grouping_variables(self, long_df): + + split_vars = ["color", "group"] + split_cols = ["a", "b"] + data_vars = {"y": "z", **{var: col for var, col in zip(split_vars, split_cols)}} + + m = MockMark() + p = Plot(long_df, **data_vars).add(m).plot() + + split_keys = [categorical_order(long_df[col]) for col in split_cols] + sub, *_ = p._subplots + assert m.passed_axes == [ + sub["ax"] for _ in itertools.product(*split_keys) + ] + self.check_splits_multi_vars( + long_df, m, data_vars, split_vars, split_cols, split_keys + ) + + def test_facets_no_subgroups(self, long_df): + + split_var = "col" + split_col = "b" + data_vars = {"x": "f", "y": "z"} + + m = MockMark() + p = Plot(long_df, **data_vars).facet(**{split_var: split_col}).add(m).plot() + + split_keys = categorical_order(long_df[split_col]) + assert m.passed_axes == list(p._figure.axes) + self.check_splits_single_var( + long_df, m, data_vars, split_var, split_col, split_keys + ) + + def test_facets_one_subgroup(self, long_df): + + facet_var, facet_col = fx = "col", "a" + group_var, group_col = gx = "group", "b" + split_vars, split_cols = zip(*[fx, gx]) + data_vars = {"x": "f", "y": "z", group_var: group_col} + + m = MockMark() + p = ( + Plot(long_df, **data_vars) + .facet(**{facet_var: facet_col}) + .add(m) + .plot() + ) + + split_keys = [categorical_order(long_df[col]) for col in [facet_col, group_col]] + assert m.passed_axes == [ + ax + for ax in list(p._figure.axes) + for _ in categorical_order(long_df[group_col]) + ] + self.check_splits_multi_vars( + long_df, m, data_vars, split_vars, split_cols, split_keys + ) + + def test_layer_specific_facet_disabling(self, long_df): + + axis_vars = {"x": "y", "y": "z"} + row_var = "a" + + m = MockMark() + p = Plot(long_df, **axis_vars).facet(row=row_var).add(m, row=None).plot() + + col_levels = categorical_order(long_df[row_var]) + assert len(p._figure.axes) == len(col_levels) + + for data in m.passed_data: + for var, col in axis_vars.items(): + assert_vector_equal(data[var], long_df[col]) + + def test_paired_variables(self, long_df): + + x = ["x", "y"] + y = ["f", "z"] + + m = MockMark() + Plot(long_df).pair(x, y).add(m).plot() + + var_product = itertools.product(x, y) + + for data, (x_i, y_i) in zip(m.passed_data, var_product): + assert_vector_equal(data["x"], long_df[x_i].astype(float)) + assert_vector_equal(data["y"], long_df[y_i].astype(float)) + + def test_paired_one_dimension(self, long_df): + + x = ["y", "z"] + + m = MockMark() + Plot(long_df).pair(x).add(m).plot() + + for data, x_i in zip(m.passed_data, x): + assert_vector_equal(data["x"], long_df[x_i].astype(float)) + + def test_paired_variables_one_subset(self, long_df): + + x = ["x", "y"] + y = ["f", "z"] + group = "a" + + long_df["x"] = long_df["x"].astype(float) # simplify vector comparison + + m = MockMark() + Plot(long_df, group=group).pair(x, y).add(m).plot() + + groups = categorical_order(long_df[group]) + var_product = itertools.product(x, y, groups) + + for data, (x_i, y_i, g_i) in zip(m.passed_data, var_product): + rows = long_df[group] == g_i + assert_vector_equal(data["x"], long_df.loc[rows, x_i]) + assert_vector_equal(data["y"], long_df.loc[rows, y_i]) + + def test_paired_and_faceted(self, long_df): + + x = ["y", "z"] + y = "f" + row = "c" + + m = MockMark() + Plot(long_df, y=y).facet(row=row).pair(x).add(m).plot() + + facets = categorical_order(long_df[row]) + var_product = itertools.product(x, facets) + + for data, (x_i, f_i) in zip(m.passed_data, var_product): + rows = long_df[row] == f_i + assert_vector_equal(data["x"], long_df.loc[rows, x_i]) + assert_vector_equal(data["y"], long_df.loc[rows, y]) + + def test_movement(self, long_df): + + orig_df = long_df.copy(deep=True) + + class MockMove(Move): + def __call__(self, data, groupby, orient): + return data.assign(x=data["x"] + 1) + + m = MockMark() + Plot(long_df, x="z", y="z").add(m, move=MockMove()).plot() + assert_vector_equal(m.passed_data[0]["x"], long_df["z"] + 1) + assert_vector_equal(m.passed_data[0]["y"], long_df["z"]) + + assert_frame_equal(long_df, orig_df) # Test data was not mutated + + def test_movement_log_scale(self, long_df): + + class MockMove(Move): + def __call__(self, data, groupby, orient): + return data.assign(x=data["x"] - 1) + + m = MockMark() + Plot( + long_df, x="z", y="z" + ).scale(x="log").add(m, move=MockMove()).plot() + assert_vector_equal(m.passed_data[0]["x"], long_df["z"] / 10) + + def test_methods_clone(self, long_df): + + p1 = Plot(long_df, "x", "y") + p2 = p1.add(MockMark()).facet("a") + + assert p1 is not p2 + assert not p1._layers + assert not p1._facet_spec + + def test_default_is_no_pyplot(self): + + p = Plot().plot() + + assert not plt.get_fignums() + assert isinstance(p._figure, mpl.figure.Figure) + + def test_with_pyplot(self): + + p = Plot().plot(pyplot=True) + + assert len(plt.get_fignums()) == 1 + fig = plt.gcf() + assert p._figure is fig + + def test_show(self): + + p = Plot() + + with warnings.catch_warnings(record=True) as msg: + out = p.show(block=False) + assert out is None + assert not hasattr(p, "_figure") + + assert len(plt.get_fignums()) == 1 + fig = plt.gcf() + + gui_backend = ( + # From https://github.com/matplotlib/matplotlib/issues/20281 + fig.canvas.manager.show != mpl.backend_bases.FigureManagerBase.show + ) + if not gui_backend: + assert msg + + def test_png_representation(self): + + p = Plot() + data, metadata = p._repr_png_() + + assert not hasattr(p, "_figure") + assert isinstance(data, bytes) + assert imghdr.what("", data) == "png" + assert sorted(metadata) == ["height", "width"] + # TODO test retina scaling + + @pytest.mark.xfail(reason="Plot.save not yet implemented") + def test_save(self): + + Plot().save() + + def test_on_axes(self): + + ax = mpl.figure.Figure().subplots() + m = MockMark() + p = Plot().on(ax).add(m).plot() + assert m.passed_axes == [ax] + assert p._figure is ax.figure + + @pytest.mark.parametrize("facet", [True, False]) + def test_on_figure(self, facet): + + f = mpl.figure.Figure() + m = MockMark() + p = Plot().on(f).add(m) + if facet: + p = p.facet(["a", "b"]) + p = p.plot() + assert m.passed_axes == f.axes + assert p._figure is f + + @pytest.mark.skipif( + Version(mpl.__version__) < Version("3.4"), + reason="mpl<3.4 does not have SubFigure", + ) + @pytest.mark.parametrize("facet", [True, False]) + def test_on_subfigure(self, facet): + + sf1, sf2 = mpl.figure.Figure().subfigures(2) + sf1.subplots() + m = MockMark() + p = Plot().on(sf2).add(m) + if facet: + p = p.facet(["a", "b"]) + p = p.plot() + assert m.passed_axes == sf2.figure.axes[1:] + assert p._figure is sf2.figure + + def test_on_type_check(self): + + p = Plot() + with pytest.raises(TypeError, match="The `Plot.on`.+"): + p.on([]) + + def test_on_axes_with_subplots_error(self): + + ax = mpl.figure.Figure().subplots() + + p1 = Plot().facet(["a", "b"]).on(ax) + with pytest.raises(RuntimeError, match="Cannot create multiple subplots"): + p1.plot() + + p2 = Plot().pair([["a", "b"], ["x", "y"]]).on(ax) + with pytest.raises(RuntimeError, match="Cannot create multiple subplots"): + p2.plot() + + def test_axis_labels_from_constructor(self, long_df): + + ax, = Plot(long_df, x="a", y="b").plot()._figure.axes + assert ax.get_xlabel() == "a" + assert ax.get_ylabel() == "b" + + ax, = Plot(x=long_df["a"], y=long_df["b"].to_numpy()).plot()._figure.axes + assert ax.get_xlabel() == "a" + assert ax.get_ylabel() == "" + + def test_axis_labels_from_layer(self, long_df): + + m = MockMark() + + ax, = Plot(long_df).add(m, x="a", y="b").plot()._figure.axes + assert ax.get_xlabel() == "a" + assert ax.get_ylabel() == "b" + + p = Plot().add(m, x=long_df["a"], y=long_df["b"].to_list()) + ax, = p.plot()._figure.axes + assert ax.get_xlabel() == "a" + assert ax.get_ylabel() == "" + + def test_axis_labels_are_first_name(self, long_df): + + m = MockMark() + p = ( + Plot(long_df, x=long_df["z"].to_list(), y="b") + .add(m, x="a") + .add(m, x="x", y="y") + ) + ax, = p.plot()._figure.axes + assert ax.get_xlabel() == "a" + assert ax.get_ylabel() == "b" + + +class TestFacetInterface: + + @pytest.fixture(scope="class", params=["row", "col"]) + def dim(self, request): + return request.param + + @pytest.fixture(scope="class", params=["reverse", "subset", "expand"]) + def reorder(self, request): + return { + "reverse": lambda x: x[::-1], + "subset": lambda x: x[:-1], + "expand": lambda x: x + ["z"], + }[request.param] + + def check_facet_results_1d(self, p, df, dim, key, order=None): + + p = p.plot() + + order = categorical_order(df[key], order) + assert len(p._figure.axes) == len(order) + + other_dim = {"row": "col", "col": "row"}[dim] + + for subplot, level in zip(p._subplots, order): + assert subplot[dim] == level + assert subplot[other_dim] is None + assert subplot["ax"].get_title() == f"{key} = {level}" + assert_gridspec_shape(subplot["ax"], **{f"n{dim}s": len(order)}) + + def test_1d(self, long_df, dim): + + key = "a" + p = Plot(long_df).facet(**{dim: key}) + self.check_facet_results_1d(p, long_df, dim, key) + + def test_1d_as_vector(self, long_df, dim): + + key = "a" + p = Plot(long_df).facet(**{dim: long_df[key]}) + self.check_facet_results_1d(p, long_df, dim, key) + + def test_1d_with_order(self, long_df, dim, reorder): + + key = "a" + order = reorder(categorical_order(long_df[key])) + p = Plot(long_df).facet(**{dim: key, "order": order}) + self.check_facet_results_1d(p, long_df, dim, key, order) + + def check_facet_results_2d(self, p, df, variables, order=None): + + p = p.plot() + + if order is None: + order = {dim: categorical_order(df[key]) for dim, key in variables.items()} + + levels = itertools.product(*[order[dim] for dim in ["row", "col"]]) + assert len(p._subplots) == len(list(levels)) + + for subplot, (row_level, col_level) in zip(p._subplots, levels): + assert subplot["row"] == row_level + assert subplot["col"] == col_level + assert subplot["axes"].get_title() == ( + f"{variables['row']} = {row_level} | {variables['col']} = {col_level}" + ) + assert_gridspec_shape( + subplot["axes"], len(levels["row"]), len(levels["col"]) + ) + + def test_2d(self, long_df): + + variables = {"row": "a", "col": "c"} + p = Plot(long_df).facet(**variables) + self.check_facet_results_2d(p, long_df, variables) + + def test_2d_with_order(self, long_df, reorder): + + variables = {"row": "a", "col": "c"} + order = { + dim: reorder(categorical_order(long_df[key])) + for dim, key in variables.items() + } + + p = Plot(long_df).facet(**variables, order=order) + self.check_facet_results_2d(p, long_df, variables, order) + + def test_axis_sharing(self, long_df): + + variables = {"row": "a", "col": "c"} + + p = Plot(long_df).facet(**variables) + + p1 = p.plot() + root, *other = p1._figure.axes + for axis in "xy": + shareset = getattr(root, f"get_shared_{axis}_axes")() + assert all(shareset.joined(root, ax) for ax in other) + + p2 = p.configure(sharex=False, sharey=False).plot() + root, *other = p2._figure.axes + for axis in "xy": + shareset = getattr(root, f"get_shared_{axis}_axes")() + assert not any(shareset.joined(root, ax) for ax in other) + + p3 = p.configure(sharex="col", sharey="row").plot() + shape = ( + len(categorical_order(long_df[variables["row"]])), + len(categorical_order(long_df[variables["col"]])), + ) + axes_matrix = np.reshape(p3._figure.axes, shape) + + for (shared, unshared), vectors in zip( + ["yx", "xy"], [axes_matrix, axes_matrix.T] + ): + for root, *other in vectors: + shareset = { + axis: getattr(root, f"get_shared_{axis}_axes")() for axis in "xy" + } + assert all(shareset[shared].joined(root, ax) for ax in other) + assert not any(shareset[unshared].joined(root, ax) for ax in other) + + def test_col_wrapping(self): + + cols = list("abcd") + wrap = 3 + p = Plot().facet(col=cols, wrap=wrap).plot() + + assert len(p._figure.axes) == 4 + assert_gridspec_shape(p._figure.axes[0], len(cols) // wrap + 1, wrap) + + # TODO test axis labels and titles + + def test_row_wrapping(self): + + rows = list("abcd") + wrap = 3 + p = Plot().facet(row=rows, wrap=wrap).plot() + + assert_gridspec_shape(p._figure.axes[0], wrap, len(rows) // wrap + 1) + assert len(p._figure.axes) == 4 + + # TODO test axis labels and titles + + +class TestPairInterface: + + def check_pair_grid(self, p, x, y): + + xys = itertools.product(y, x) + + for (y_i, x_j), subplot in zip(xys, p._subplots): + + ax = subplot["ax"] + assert ax.get_xlabel() == "" if x_j is None else x_j + assert ax.get_ylabel() == "" if y_i is None else y_i + assert_gridspec_shape(subplot["ax"], len(y), len(x)) + + @pytest.mark.parametrize( + "vector_type", [list, np.array, pd.Series, pd.Index] + ) + def test_all_numeric(self, long_df, vector_type): + + x, y = ["x", "y", "z"], ["s", "f"] + p = Plot(long_df).pair(vector_type(x), vector_type(y)).plot() + self.check_pair_grid(p, x, y) + + def test_single_variable_key_raises(self, long_df): + + p = Plot(long_df) + err = "You must pass a sequence of variable keys to `y`" + with pytest.raises(TypeError, match=err): + p.pair(x=["x", "y"], y="z") + + @pytest.mark.parametrize("dim", ["x", "y"]) + def test_single_dimension(self, long_df, dim): + + variables = {"x": None, "y": None} + variables[dim] = ["x", "y", "z"] + p = Plot(long_df).pair(**variables).plot() + variables = {k: [v] if v is None else v for k, v in variables.items()} + self.check_pair_grid(p, **variables) + + def test_non_cross(self, long_df): + + x = ["x", "y"] + y = ["f", "z"] + + p = Plot(long_df).pair(x, y, cross=False).plot() + + for i, subplot in enumerate(p._subplots): + ax = subplot["ax"] + assert ax.get_xlabel() == x[i] + assert ax.get_ylabel() == y[i] + assert_gridspec_shape(ax, 1, len(x)) + + root, *other = p._figure.axes + for axis in "xy": + shareset = getattr(root, f"get_shared_{axis}_axes")() + assert not any(shareset.joined(root, ax) for ax in other) + + def test_with_no_variables(self, long_df): + + all_cols = long_df.columns + + p1 = Plot(long_df).pair() + for axis in "xy": + actual = [ + v for k, v in p1._pair_spec["variables"].items() if k.startswith(axis) + ] + assert actual == all_cols.to_list() + + p2 = Plot(long_df, y="y").pair() + x_vars = [ + v for k, v in p2._pair_spec["variables"].items() if k.startswith("x") + ] + assert all_cols.difference(x_vars).item() == "y" + assert "y" not in p2._pair_spec + + p3 = Plot(long_df, color="a").pair() + for axis in "xy": + x_vars = [ + v for k, v in p3._pair_spec["variables"].items() if k.startswith("x") + ] + assert all_cols.difference(x_vars).item() == "a" + + with pytest.raises(RuntimeError, match="You must pass `data`"): + Plot().pair() + + def test_with_facets(self, long_df): + + x = "x" + y = ["y", "z"] + col = "a" + + p = Plot(long_df, x=x).facet(col).pair(y=y).plot() + + facet_levels = categorical_order(long_df[col]) + dims = itertools.product(y, facet_levels) + + for (y_i, col_i), subplot in zip(dims, p._subplots): + + ax = subplot["ax"] + assert ax.get_xlabel() == x + assert ax.get_ylabel() == y_i + assert ax.get_title() == f"{col} = {col_i}" + assert_gridspec_shape(ax, len(y), len(facet_levels)) + + @pytest.mark.parametrize("variables", [("rows", "y"), ("columns", "x")]) + def test_error_on_facet_overlap(self, long_df, variables): + + facet_dim, pair_axis = variables + p = Plot(long_df).facet(**{facet_dim[:3]: "a"}).pair(**{pair_axis: ["x", "y"]}) + expected = f"Cannot facet the {facet_dim} while pairing on `{pair_axis}`." + with pytest.raises(RuntimeError, match=expected): + p.plot() + + @pytest.mark.parametrize("variables", [("columns", "y"), ("rows", "x")]) + def test_error_on_wrap_overlap(self, long_df, variables): + + facet_dim, pair_axis = variables + p = ( + Plot(long_df) + .facet(wrap=2, **{facet_dim[:3]: "a"}) + .pair(**{pair_axis: ["x", "y"]}) + ) + expected = f"Cannot wrap the {facet_dim} while pairing on `{pair_axis}``." + with pytest.raises(RuntimeError, match=expected): + p.plot() + + def test_axis_sharing(self, long_df): + + p = Plot(long_df).pair(x=["a", "b"], y=["y", "z"]) + shape = 2, 2 + + p1 = p.plot() + axes_matrix = np.reshape(p1._figure.axes, shape) + + for root, *other in axes_matrix: # Test row-wise sharing + x_shareset = getattr(root, "get_shared_x_axes")() + assert not any(x_shareset.joined(root, ax) for ax in other) + y_shareset = getattr(root, "get_shared_y_axes")() + assert all(y_shareset.joined(root, ax) for ax in other) + + for root, *other in axes_matrix.T: # Test col-wise sharing + x_shareset = getattr(root, "get_shared_x_axes")() + assert all(x_shareset.joined(root, ax) for ax in other) + y_shareset = getattr(root, "get_shared_y_axes")() + assert not any(y_shareset.joined(root, ax) for ax in other) + + p2 = p.configure(sharex=False, sharey=False).plot() + root, *other = p2._figure.axes + for axis in "xy": + shareset = getattr(root, f"get_shared_{axis}_axes")() + assert not any(shareset.joined(root, ax) for ax in other) + + def test_axis_sharing_with_facets(self, long_df): + + p = Plot(long_df, y="y").pair(x=["a", "b"]).facet(row="c").plot() + shape = 2, 2 + + axes_matrix = np.reshape(p._figure.axes, shape) + + for root, *other in axes_matrix: # Test row-wise sharing + x_shareset = getattr(root, "get_shared_x_axes")() + assert not any(x_shareset.joined(root, ax) for ax in other) + y_shareset = getattr(root, "get_shared_y_axes")() + assert all(y_shareset.joined(root, ax) for ax in other) + + for root, *other in axes_matrix.T: # Test col-wise sharing + x_shareset = getattr(root, "get_shared_x_axes")() + assert all(x_shareset.joined(root, ax) for ax in other) + y_shareset = getattr(root, "get_shared_y_axes")() + assert all(y_shareset.joined(root, ax) for ax in other) + + def test_x_wrapping(self, long_df): + + x_vars = ["f", "x", "y", "z"] + wrap = 3 + p = Plot(long_df, y="y").pair(x=x_vars, wrap=wrap).plot() + + assert_gridspec_shape(p._figure.axes[0], len(x_vars) // wrap + 1, wrap) + assert len(p._figure.axes) == len(x_vars) + + # TODO test axis labels and visibility + + def test_y_wrapping(self, long_df): + + y_vars = ["f", "x", "y", "z"] + wrap = 3 + p = Plot(long_df, x="x").pair(y=y_vars, wrap=wrap).plot() + + assert_gridspec_shape(p._figure.axes[0], wrap, len(y_vars) // wrap + 1) + assert len(p._figure.axes) == len(y_vars) + + # TODO test axis labels and visibility + + def test_non_cross_wrapping(self, long_df): + + x_vars = ["a", "b", "c", "t"] + y_vars = ["f", "x", "y", "z"] + wrap = 3 + + p = ( + Plot(long_df, x="x") + .pair(x=x_vars, y=y_vars, wrap=wrap, cross=False) + .plot() + ) + + assert_gridspec_shape(p._figure.axes[0], len(x_vars) // wrap + 1, wrap) + assert len(p._figure.axes) == len(x_vars) + + def test_orient_inference(self, long_df): + + orient_list = [] + + class CaptureMoveOrient(Move): + def __call__(self, data, groupby, orient): + orient_list.append(orient) + return data + + ( + Plot(long_df, x="x") + .pair(y=["b", "z"]) + .add(MockMark(), move=CaptureMoveOrient()) + .plot() + ) + + assert orient_list == ["y", "x"] + + def test_two_variables_single_order_error(self, long_df): + + p = Plot(long_df) + err = "When faceting on both col= and row=, passing `order`" + with pytest.raises(RuntimeError, match=err): + p.facet(col="a", row="b", order=["a", "b", "c"]) + + +class TestLabelVisibility: + + def test_single_subplot(self, long_df): + + x, y = "a", "z" + p = Plot(long_df, x=x, y=y).plot() + subplot, *_ = p._subplots + ax = subplot["ax"] + assert ax.xaxis.get_label().get_visible() + assert ax.yaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_xticklabels()) + assert all(t.get_visible() for t in ax.get_yticklabels()) + + @pytest.mark.parametrize( + "facet_kws,pair_kws", [({"col": "b"}, {}), ({}, {"x": ["x", "y", "f"]})] + ) + def test_1d_column(self, long_df, facet_kws, pair_kws): + + x = None if "x" in pair_kws else "a" + y = "z" + p = Plot(long_df, x=x, y=y).plot() + first, *other = p._subplots + + ax = first["ax"] + assert ax.xaxis.get_label().get_visible() + assert ax.yaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_xticklabels()) + assert all(t.get_visible() for t in ax.get_yticklabels()) + + for s in other: + ax = s["ax"] + assert ax.xaxis.get_label().get_visible() + assert not ax.yaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_xticklabels()) + assert not any(t.get_visible() for t in ax.get_yticklabels()) + + @pytest.mark.parametrize( + "facet_kws,pair_kws", [({"row": "b"}, {}), ({}, {"y": ["x", "y", "f"]})] + ) + def test_1d_row(self, long_df, facet_kws, pair_kws): + + x = "z" + y = None if "y" in pair_kws else "z" + p = Plot(long_df, x=x, y=y).plot() + first, *other = p._subplots + + ax = first["ax"] + assert ax.xaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_xticklabels()) + assert ax.yaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_yticklabels()) + + for s in other: + ax = s["ax"] + assert not ax.xaxis.get_label().get_visible() + assert ax.yaxis.get_label().get_visible() + assert not any(t.get_visible() for t in ax.get_xticklabels()) + assert all(t.get_visible() for t in ax.get_yticklabels()) + + def test_1d_column_wrapped(self): + + p = Plot().facet(col=["a", "b", "c", "d"], wrap=3).plot() + subplots = list(p._subplots) + + for s in [subplots[0], subplots[-1]]: + ax = s["ax"] + assert ax.yaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_yticklabels()) + + for s in subplots[1:]: + ax = s["ax"] + assert ax.xaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_xticklabels()) + + for s in subplots[1:-1]: + ax = s["ax"] + assert not ax.yaxis.get_label().get_visible() + assert not any(t.get_visible() for t in ax.get_yticklabels()) + + ax = subplots[0]["ax"] + assert not ax.xaxis.get_label().get_visible() + assert not any(t.get_visible() for t in ax.get_xticklabels()) + + def test_1d_row_wrapped(self): + + p = Plot().facet(row=["a", "b", "c", "d"], wrap=3).plot() + subplots = list(p._subplots) + + for s in subplots[:-1]: + ax = s["ax"] + assert ax.yaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_yticklabels()) + + for s in subplots[-2:]: + ax = s["ax"] + assert ax.xaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_xticklabels()) + + for s in subplots[:-2]: + ax = s["ax"] + assert not ax.xaxis.get_label().get_visible() + assert not any(t.get_visible() for t in ax.get_xticklabels()) + + ax = subplots[-1]["ax"] + assert not ax.yaxis.get_label().get_visible() + assert not any(t.get_visible() for t in ax.get_yticklabels()) + + def test_1d_column_wrapped_non_cross(self, long_df): + + p = ( + Plot(long_df) + .pair(x=["a", "b", "c"], y=["x", "y", "z"], wrap=2, cross=False) + .plot() + ) + for s in p._subplots: + ax = s["ax"] + assert ax.xaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_xticklabels()) + assert ax.yaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_yticklabels()) + + def test_2d(self): + + p = Plot().facet(col=["a", "b"], row=["x", "y"]).plot() + subplots = list(p._subplots) + + for s in subplots[:2]: + ax = s["ax"] + assert not ax.xaxis.get_label().get_visible() + assert not any(t.get_visible() for t in ax.get_xticklabels()) + + for s in subplots[2:]: + ax = s["ax"] + assert ax.xaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_xticklabels()) + + for s in [subplots[0], subplots[2]]: + ax = s["ax"] + assert ax.yaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_yticklabels()) + + for s in [subplots[1], subplots[3]]: + ax = s["ax"] + assert not ax.yaxis.get_label().get_visible() + assert not any(t.get_visible() for t in ax.get_yticklabels()) + + def test_2d_unshared(self): + + p = ( + Plot() + .facet(col=["a", "b"], row=["x", "y"]) + .configure(sharex=False, sharey=False) + .plot() + ) + subplots = list(p._subplots) + + for s in subplots[:2]: + ax = s["ax"] + assert not ax.xaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_xticklabels()) + + for s in subplots[2:]: + ax = s["ax"] + assert ax.xaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_xticklabels()) + + for s in [subplots[0], subplots[2]]: + ax = s["ax"] + assert ax.yaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_yticklabels()) + + for s in [subplots[1], subplots[3]]: + ax = s["ax"] + assert not ax.yaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_yticklabels()) + + +class TestLegend: + + @pytest.fixture + def xy(self): + return dict(x=[1, 2, 3, 4], y=[1, 2, 3, 4]) + + def test_single_layer_single_variable(self, xy): + + s = pd.Series(["a", "b", "a", "c"], name="s") + p = Plot(**xy).add(MockMark(), color=s).plot() + e, = p._legend_contents + + labels = categorical_order(s) + + assert e[0] == (s.name, s.name) + assert e[-1] == labels + + artists = e[1] + assert len(artists) == len(labels) + for a, label in zip(artists, labels): + assert isinstance(a, mpl.artist.Artist) + assert a.value == label + assert a.variables == ["color"] + + def test_single_layer_common_variable(self, xy): + + s = pd.Series(["a", "b", "a", "c"], name="s") + sem = dict(color=s, marker=s) + p = Plot(**xy).add(MockMark(), **sem).plot() + e, = p._legend_contents + + labels = categorical_order(s) + + assert e[0] == (s.name, s.name) + assert e[-1] == labels + + artists = e[1] + assert len(artists) == len(labels) + for a, label in zip(artists, labels): + assert isinstance(a, mpl.artist.Artist) + assert a.value == label + assert a.variables == list(sem) + + def test_single_layer_common_unnamed_variable(self, xy): + + s = np.array(["a", "b", "a", "c"]) + sem = dict(color=s, marker=s) + p = Plot(**xy).add(MockMark(), **sem).plot() + + e, = p._legend_contents + + labels = list(np.unique(s)) # assumes sorted order + + assert e[0] == (None, id(s)) + assert e[-1] == labels + + artists = e[1] + assert len(artists) == len(labels) + for a, label in zip(artists, labels): + assert isinstance(a, mpl.artist.Artist) + assert a.value == label + assert a.variables == list(sem) + + def test_single_layer_multi_variable(self, xy): + + s1 = pd.Series(["a", "b", "a", "c"], name="s1") + s2 = pd.Series(["m", "m", "p", "m"], name="s2") + sem = dict(color=s1, marker=s2) + p = Plot(**xy).add(MockMark(), **sem).plot() + e1, e2 = p._legend_contents + + variables = {v.name: k for k, v in sem.items()} + + for e, s in zip([e1, e2], [s1, s2]): + assert e[0] == (s.name, s.name) + + labels = categorical_order(s) + assert e[-1] == labels + + artists = e[1] + assert len(artists) == len(labels) + for a, label in zip(artists, labels): + assert isinstance(a, mpl.artist.Artist) + assert a.value == label + assert a.variables == [variables[s.name]] + + def test_multi_layer_single_variable(self, xy): + + s = pd.Series(["a", "b", "a", "c"], name="s") + p = Plot(**xy, color=s).add(MockMark()).add(MockMark()).plot() + e1, e2 = p._legend_contents + + labels = categorical_order(s) + + for e in [e1, e2]: + assert e[0] == (s.name, s.name) + + labels = categorical_order(s) + assert e[-1] == labels + + artists = e[1] + assert len(artists) == len(labels) + for a, label in zip(artists, labels): + assert isinstance(a, mpl.artist.Artist) + assert a.value == label + assert a.variables == ["color"] + + def test_multi_layer_multi_variable(self, xy): + + s1 = pd.Series(["a", "b", "a", "c"], name="s1") + s2 = pd.Series(["m", "m", "p", "m"], name="s2") + sem = dict(color=s1), dict(marker=s2) + variables = {"s1": "color", "s2": "marker"} + p = Plot(**xy).add(MockMark(), **sem[0]).add(MockMark(), **sem[1]).plot() + e1, e2 = p._legend_contents + + for e, s in zip([e1, e2], [s1, s2]): + assert e[0] == (s.name, s.name) + + labels = categorical_order(s) + assert e[-1] == labels + + artists = e[1] + assert len(artists) == len(labels) + for a, label in zip(artists, labels): + assert isinstance(a, mpl.artist.Artist) + assert a.value == label + assert a.variables == [variables[s.name]] + + def test_multi_layer_different_artists(self, xy): + + class MockMark1(MockMark): + def _legend_artist(self, variables, value, scales): + return mpl.lines.Line2D([], []) + + class MockMark2(MockMark): + def _legend_artist(self, variables, value, scales): + return mpl.patches.Patch() + + s = pd.Series(["a", "b", "a", "c"], name="s") + p = Plot(**xy, color=s).add(MockMark1()).add(MockMark2()).plot() + + legend, = p._figure.legends + + names = categorical_order(s) + labels = [t.get_text() for t in legend.get_texts()] + assert labels == names + + if Version(mpl.__version__) >= Version("3.2"): + contents = legend.get_children()[0] + assert len(contents.findobj(mpl.lines.Line2D)) == len(names) + assert len(contents.findobj(mpl.patches.Patch)) == len(names) + + def test_identity_scale_ignored(self, xy): + + s = pd.Series(["r", "g", "b", "g"]) + p = Plot(**xy).add(MockMark(), color=s).scale(color=None).plot() + assert not p._legend_contents diff --git a/seaborn/tests/_core/test_properties.py b/seaborn/tests/_core/test_properties.py new file mode 100644 index 0000000000..caca153922 --- /dev/null +++ b/seaborn/tests/_core/test_properties.py @@ -0,0 +1,582 @@ + +import numpy as np +import pandas as pd +import matplotlib as mpl +from matplotlib.colors import same_color, to_rgb, to_rgba + +import pytest +from numpy.testing import assert_array_equal + +from seaborn.external.version import Version +from seaborn._core.rules import categorical_order +from seaborn._core.scales import Nominal, Continuous +from seaborn._core.properties import ( + Alpha, + Color, + Coordinate, + EdgeWidth, + Fill, + LineStyle, + LineWidth, + Marker, + PointSize, +) +from seaborn._compat import MarkerStyle +from seaborn.palettes import color_palette + + +class DataFixtures: + + @pytest.fixture + def num_vector(self, long_df): + return long_df["s"] + + @pytest.fixture + def num_order(self, num_vector): + return categorical_order(num_vector) + + @pytest.fixture + def cat_vector(self, long_df): + return long_df["a"] + + @pytest.fixture + def cat_order(self, cat_vector): + return categorical_order(cat_vector) + + @pytest.fixture + def dt_num_vector(self, long_df): + return long_df["t"] + + @pytest.fixture + def dt_cat_vector(self, long_df): + return long_df["d"] + + @pytest.fixture + def vectors(self, num_vector, cat_vector): + return {"num": num_vector, "cat": cat_vector} + + +class TestCoordinate(DataFixtures): + + def test_bad_scale_arg_str(self, num_vector): + + err = "Unknown magic arg for x scale: 'xxx'." + with pytest.raises(ValueError, match=err): + Coordinate("x").infer_scale("xxx", num_vector) + + def test_bad_scale_arg_type(self, cat_vector): + + err = "Magic arg for x scale must be str, not list." + with pytest.raises(TypeError, match=err): + Coordinate("x").infer_scale([1, 2, 3], cat_vector) + + +class TestColor(DataFixtures): + + def assert_same_rgb(self, a, b): + assert_array_equal(a[:, :3], b[:, :3]) + + def test_nominal_default_palette(self, cat_vector, cat_order): + + m = Color().get_mapping(Nominal(), cat_vector) + n = len(cat_order) + actual = m(np.arange(n)) + expected = color_palette(None, n) + for have, want in zip(actual, expected): + assert same_color(have, want) + + def test_nominal_default_palette_large(self): + + vector = pd.Series(list("abcdefghijklmnopqrstuvwxyz")) + m = Color().get_mapping(Nominal(), vector) + actual = m(np.arange(26)) + expected = color_palette("husl", 26) + for have, want in zip(actual, expected): + assert same_color(have, want) + + def test_nominal_named_palette(self, cat_vector, cat_order): + + palette = "Blues" + m = Color().get_mapping(Nominal(palette), cat_vector) + n = len(cat_order) + actual = m(np.arange(n)) + expected = color_palette(palette, n) + for have, want in zip(actual, expected): + assert same_color(have, want) + + def test_nominal_list_palette(self, cat_vector, cat_order): + + palette = color_palette("Reds", len(cat_order)) + m = Color().get_mapping(Nominal(palette), cat_vector) + actual = m(np.arange(len(palette))) + expected = palette + for have, want in zip(actual, expected): + assert same_color(have, want) + + def test_nominal_dict_palette(self, cat_vector, cat_order): + + colors = color_palette("Greens") + palette = dict(zip(cat_order, colors)) + m = Color().get_mapping(Nominal(palette), cat_vector) + n = len(cat_order) + actual = m(np.arange(n)) + expected = colors + for have, want in zip(actual, expected): + assert same_color(have, want) + + def test_nominal_dict_with_missing_keys(self, cat_vector, cat_order): + + palette = dict(zip(cat_order[1:], color_palette("Purples"))) + with pytest.raises(ValueError, match="No entry in color dict"): + Color("color").get_mapping(Nominal(palette), cat_vector) + + def test_nominal_list_too_short(self, cat_vector, cat_order): + + n = len(cat_order) - 1 + palette = color_palette("Oranges", n) + msg = rf"The edgecolor list has fewer values \({n}\) than needed \({n + 1}\)" + with pytest.warns(UserWarning, match=msg): + Color("edgecolor").get_mapping(Nominal(palette), cat_vector) + + def test_nominal_list_too_long(self, cat_vector, cat_order): + + n = len(cat_order) + 1 + palette = color_palette("Oranges", n) + msg = rf"The edgecolor list has more values \({n}\) than needed \({n - 1}\)" + with pytest.warns(UserWarning, match=msg): + Color("edgecolor").get_mapping(Nominal(palette), cat_vector) + + def test_continuous_default_palette(self, num_vector): + + cmap = color_palette("ch:", as_cmap=True) + m = Color().get_mapping(Continuous(), num_vector) + self.assert_same_rgb(m(num_vector), cmap(num_vector)) + + def test_continuous_named_palette(self, num_vector): + + pal = "flare" + cmap = color_palette(pal, as_cmap=True) + m = Color().get_mapping(Continuous(pal), num_vector) + self.assert_same_rgb(m(num_vector), cmap(num_vector)) + + def test_continuous_tuple_palette(self, num_vector): + + vals = ("blue", "red") + cmap = color_palette("blend:" + ",".join(vals), as_cmap=True) + m = Color().get_mapping(Continuous(vals), num_vector) + self.assert_same_rgb(m(num_vector), cmap(num_vector)) + + def test_continuous_callable_palette(self, num_vector): + + cmap = mpl.cm.get_cmap("viridis") + m = Color().get_mapping(Continuous(cmap), num_vector) + self.assert_same_rgb(m(num_vector), cmap(num_vector)) + + def test_continuous_missing(self): + + x = pd.Series([1, 2, np.nan, 4]) + m = Color().get_mapping(Continuous(), x) + assert np.isnan(m(x)[2]).all() + + def test_bad_scale_values_continuous(self, num_vector): + + with pytest.raises(TypeError, match="Scale values for color with a Continuous"): + Color().get_mapping(Continuous(["r", "g", "b"]), num_vector) + + def test_bad_scale_values_nominal(self, cat_vector): + + with pytest.raises(TypeError, match="Scale values for color with a Nominal"): + Color().get_mapping(Nominal(mpl.cm.get_cmap("viridis")), cat_vector) + + def test_bad_inference_arg(self, cat_vector): + + with pytest.raises(TypeError, match="A single scale argument for color"): + Color().infer_scale(123, cat_vector) + + @pytest.mark.parametrize( + "data_type,scale_class", + [("cat", Nominal), ("num", Continuous)] + ) + def test_default(self, data_type, scale_class, vectors): + + scale = Color().default_scale(vectors[data_type]) + assert isinstance(scale, scale_class) + + def test_default_numeric_data_category_dtype(self, num_vector): + + scale = Color().default_scale(num_vector.astype("category")) + assert isinstance(scale, Nominal) + + def test_default_binary_data(self): + + x = pd.Series([0, 0, 1, 0, 1], dtype=int) + scale = Color().default_scale(x) + assert isinstance(scale, Continuous) + + # TODO default scales for other types + + @pytest.mark.parametrize( + "values,data_type,scale_class", + [ + ("viridis", "cat", Nominal), # Based on variable type + ("viridis", "num", Continuous), # Based on variable type + ("muted", "num", Nominal), # Based on qualitative palette + (["r", "g", "b"], "num", Nominal), # Based on list palette + ({2: "r", 4: "g", 8: "b"}, "num", Nominal), # Based on dict palette + (("r", "b"), "num", Continuous), # Based on tuple / variable type + (("g", "m"), "cat", Nominal), # Based on tuple / variable type + (mpl.cm.get_cmap("inferno"), "num", Continuous), # Based on callable + ] + ) + def test_inference(self, values, data_type, scale_class, vectors): + + scale = Color().infer_scale(values, vectors[data_type]) + assert isinstance(scale, scale_class) + assert scale.values == values + + def test_inference_binary_data(self): + + x = pd.Series([0, 0, 1, 0, 1], dtype=int) + scale = Color().infer_scale("viridis", x) + assert isinstance(scale, Nominal) + + def test_standardization(self): + + f = Color().standardize + assert f("C3") == to_rgb("C3") + assert f("dodgerblue") == to_rgb("dodgerblue") + + assert f((.1, .2, .3)) == (.1, .2, .3) + assert f((.1, .2, .3, .4)) == (.1, .2, .3, .4) + + assert f("#123456") == to_rgb("#123456") + assert f("#12345678") == to_rgba("#12345678") + + if Version(mpl.__version__) >= Version("3.4.0"): + assert f("#123") == to_rgb("#123") + assert f("#1234") == to_rgba("#1234") + + +class ObjectPropertyBase(DataFixtures): + + def assert_equal(self, a, b): + + assert self.unpack(a) == self.unpack(b) + + def unpack(self, x): + return x + + @pytest.mark.parametrize("data_type", ["cat", "num"]) + def test_default(self, data_type, vectors): + + scale = self.prop().default_scale(vectors[data_type]) + assert isinstance(scale, Nominal) + + @pytest.mark.parametrize("data_type", ["cat", "num"]) + def test_inference_list(self, data_type, vectors): + + scale = self.prop().infer_scale(self.values, vectors[data_type]) + assert isinstance(scale, Nominal) + assert scale.values == self.values + + @pytest.mark.parametrize("data_type", ["cat", "num"]) + def test_inference_dict(self, data_type, vectors): + + x = vectors[data_type] + values = dict(zip(categorical_order(x), self.values)) + scale = self.prop().infer_scale(values, x) + assert isinstance(scale, Nominal) + assert scale.values == values + + def test_dict_missing(self, cat_vector): + + levels = categorical_order(cat_vector) + values = dict(zip(levels, self.values[:-1])) + scale = Nominal(values) + name = self.prop.__name__.lower() + msg = f"No entry in {name} dictionary for {repr(levels[-1])}" + with pytest.raises(ValueError, match=msg): + self.prop().get_mapping(scale, cat_vector) + + @pytest.mark.parametrize("data_type", ["cat", "num"]) + def test_mapping_default(self, data_type, vectors): + + x = vectors[data_type] + mapping = self.prop().get_mapping(Nominal(), x) + n = x.nunique() + for i, expected in enumerate(self.prop()._default_values(n)): + actual, = mapping([i]) + self.assert_equal(actual, expected) + + @pytest.mark.parametrize("data_type", ["cat", "num"]) + def test_mapping_from_list(self, data_type, vectors): + + x = vectors[data_type] + scale = Nominal(self.values) + mapping = self.prop().get_mapping(scale, x) + for i, expected in enumerate(self.standardized_values): + actual, = mapping([i]) + self.assert_equal(actual, expected) + + @pytest.mark.parametrize("data_type", ["cat", "num"]) + def test_mapping_from_dict(self, data_type, vectors): + + x = vectors[data_type] + levels = categorical_order(x) + values = dict(zip(levels, self.values[::-1])) + standardized_values = dict(zip(levels, self.standardized_values[::-1])) + + scale = Nominal(values) + mapping = self.prop().get_mapping(scale, x) + for i, level in enumerate(levels): + actual, = mapping([i]) + expected = standardized_values[level] + self.assert_equal(actual, expected) + + def test_mapping_with_null_value(self, cat_vector): + + mapping = self.prop().get_mapping(Nominal(self.values), cat_vector) + actual = mapping(np.array([0, np.nan, 2])) + v0, _, v2 = self.standardized_values + expected = [v0, self.prop.null_value, v2] + for a, b in zip(actual, expected): + self.assert_equal(a, b) + + def test_unique_default_large_n(self): + + n = 24 + x = pd.Series(np.arange(n)) + mapping = self.prop().get_mapping(Nominal(), x) + assert len({self.unpack(x_i) for x_i in mapping(x)}) == n + + def test_bad_scale_values(self, cat_vector): + + var_name = self.prop.__name__.lower() + with pytest.raises(TypeError, match=f"Scale values for a {var_name} variable"): + self.prop().get_mapping(Nominal(("o", "s")), cat_vector) + + +class TestMarker(ObjectPropertyBase): + + prop = Marker + values = ["o", (5, 2, 0), MarkerStyle("^")] + standardized_values = [MarkerStyle(x) for x in values] + + def unpack(self, x): + return ( + x.get_path(), + x.get_joinstyle(), + x.get_transform().to_values(), + x.get_fillstyle(), + ) + + +class TestLineStyle(ObjectPropertyBase): + + prop = LineStyle + values = ["solid", "--", (1, .5)] + standardized_values = [LineStyle._get_dash_pattern(x) for x in values] + + def test_bad_type(self): + + p = LineStyle() + with pytest.raises(TypeError, match="^Linestyle must be .+, not list.$"): + p.standardize([1, 2]) + + def test_bad_style(self): + + p = LineStyle() + with pytest.raises(ValueError, match="^Linestyle string must be .+, not 'o'.$"): + p.standardize("o") + + def test_bad_dashes(self): + + p = LineStyle() + with pytest.raises(TypeError, match="^Invalid dash pattern"): + p.standardize((1, 2, "x")) + + +class TestFill(DataFixtures): + + @pytest.fixture + def vectors(self): + + return { + "cat": pd.Series(["a", "a", "b"]), + "num": pd.Series([1, 1, 2]), + "bool": pd.Series([True, True, False]) + } + + @pytest.fixture + def cat_vector(self, vectors): + return vectors["cat"] + + @pytest.fixture + def num_vector(self, vectors): + return vectors["num"] + + @pytest.mark.parametrize("data_type", ["cat", "num", "bool"]) + def test_default(self, data_type, vectors): + + x = vectors[data_type] + scale = Fill().default_scale(x) + assert isinstance(scale, Nominal) + + @pytest.mark.parametrize("data_type", ["cat", "num", "bool"]) + def test_inference_list(self, data_type, vectors): + + x = vectors[data_type] + scale = Fill().infer_scale([True, False], x) + assert isinstance(scale, Nominal) + assert scale.values == [True, False] + + @pytest.mark.parametrize("data_type", ["cat", "num", "bool"]) + def test_inference_dict(self, data_type, vectors): + + x = vectors[data_type] + values = dict(zip(x.unique(), [True, False])) + scale = Fill().infer_scale(values, x) + assert isinstance(scale, Nominal) + assert scale.values == values + + def test_mapping_categorical_data(self, cat_vector): + + mapping = Fill().get_mapping(Nominal(), cat_vector) + assert_array_equal(mapping([0, 1, 0]), [True, False, True]) + + def test_mapping_numeric_data(self, num_vector): + + mapping = Fill().get_mapping(Nominal(), num_vector) + assert_array_equal(mapping([0, 1, 0]), [True, False, True]) + + def test_mapping_list(self, cat_vector): + + mapping = Fill().get_mapping(Nominal([False, True]), cat_vector) + assert_array_equal(mapping([0, 1, 0]), [False, True, False]) + + def test_mapping_truthy_list(self, cat_vector): + + mapping = Fill().get_mapping(Nominal([0, 1]), cat_vector) + assert_array_equal(mapping([0, 1, 0]), [False, True, False]) + + def test_mapping_dict(self, cat_vector): + + values = dict(zip(cat_vector.unique(), [False, True])) + mapping = Fill().get_mapping(Nominal(values), cat_vector) + assert_array_equal(mapping([0, 1, 0]), [False, True, False]) + + def test_cycle_warning(self): + + x = pd.Series(["a", "b", "c"]) + with pytest.warns(UserWarning, match="The variable assigned to fill"): + Fill().get_mapping(Nominal(), x) + + def test_values_error(self): + + x = pd.Series(["a", "b"]) + with pytest.raises(TypeError, match="Scale values for fill must be"): + Fill().get_mapping(Nominal("bad_values"), x) + + +class IntervalBase(DataFixtures): + + def norm(self, x): + return (x - x.min()) / (x.max() - x.min()) + + @pytest.mark.parametrize("data_type,scale_class", [ + ("cat", Nominal), + ("num", Continuous), + ]) + def test_default(self, data_type, scale_class, vectors): + + x = vectors[data_type] + scale = self.prop().default_scale(x) + assert isinstance(scale, scale_class) + + @pytest.mark.parametrize("arg,data_type,scale_class", [ + ((1, 3), "cat", Nominal), + ((1, 3), "num", Continuous), + ([1, 2, 3], "cat", Nominal), + ([1, 2, 3], "num", Nominal), + ({"a": 1, "b": 3, "c": 2}, "cat", Nominal), + ({2: 1, 4: 3, 8: 2}, "num", Nominal), + ]) + def test_inference(self, arg, data_type, scale_class, vectors): + + x = vectors[data_type] + scale = self.prop().infer_scale(arg, x) + assert isinstance(scale, scale_class) + assert scale.values == arg + + def test_mapped_interval_numeric(self, num_vector): + + mapping = self.prop().get_mapping(Continuous(), num_vector) + assert_array_equal(mapping([0, 1]), self.prop().default_range) + + def test_mapped_interval_categorical(self, cat_vector): + + mapping = self.prop().get_mapping(Nominal(), cat_vector) + n = cat_vector.nunique() + assert_array_equal(mapping([n - 1, 0]), self.prop().default_range) + + def test_bad_scale_values_numeric_data(self, num_vector): + + prop_name = self.prop.__name__.lower() + err_stem = ( + f"Values for {prop_name} variables with Continuous scale must be 2-tuple" + ) + + with pytest.raises(TypeError, match=f"{err_stem}; not ."): + self.prop().get_mapping(Continuous("abc"), num_vector) + + with pytest.raises(TypeError, match=f"{err_stem}; not 3-tuple."): + self.prop().get_mapping(Continuous((1, 2, 3)), num_vector) + + def test_bad_scale_values_categorical_data(self, cat_vector): + + prop_name = self.prop.__name__.lower() + err_text = f"Values for {prop_name} variables with Nominal scale" + with pytest.raises(TypeError, match=err_text): + self.prop().get_mapping(Nominal("abc"), cat_vector) + + +class TestAlpha(IntervalBase): + prop = Alpha + + +class TestLineWidth(IntervalBase): + prop = LineWidth + + def test_rcparam_default(self): + + with mpl.rc_context({"lines.linewidth": 2}): + assert self.prop().default_range == (1, 4) + + +class TestEdgeWidth(IntervalBase): + prop = EdgeWidth + + def test_rcparam_default(self): + + with mpl.rc_context({"patch.linewidth": 2}): + assert self.prop().default_range == (1, 4) + + +class TestPointSize(IntervalBase): + prop = PointSize + + def test_areal_scaling_numeric(self, num_vector): + + limits = 5, 10 + scale = Continuous(limits) + mapping = self.prop().get_mapping(scale, num_vector) + x = np.linspace(0, 1, 6) + expected = np.sqrt(np.linspace(*np.square(limits), num=len(x))) + assert_array_equal(mapping(x), expected) + + def test_areal_scaling_categorical(self, cat_vector): + + limits = (2, 4) + scale = Nominal(limits) + mapping = self.prop().get_mapping(scale, cat_vector) + assert_array_equal(mapping(np.arange(3)), [4, np.sqrt(10), 2]) diff --git a/seaborn/tests/_core/test_rules.py b/seaborn/tests/_core/test_rules.py new file mode 100644 index 0000000000..655840a8d1 --- /dev/null +++ b/seaborn/tests/_core/test_rules.py @@ -0,0 +1,94 @@ + +import numpy as np +import pandas as pd + +import pytest + +from seaborn._core.rules import ( + VarType, + variable_type, + categorical_order, +) + + +def test_vartype_object(): + + v = VarType("numeric") + assert v == "numeric" + assert v != "categorical" + with pytest.raises(AssertionError): + v == "number" + with pytest.raises(AssertionError): + VarType("date") + + +def test_variable_type(): + + s = pd.Series([1., 2., 3.]) + assert variable_type(s) == "numeric" + assert variable_type(s.astype(int)) == "numeric" + assert variable_type(s.astype(object)) == "numeric" + assert variable_type(s.to_numpy()) == "numeric" + assert variable_type(s.to_list()) == "numeric" + + s = pd.Series([1, 2, 3, np.nan], dtype=object) + assert variable_type(s) == "numeric" + + s = pd.Series([np.nan, np.nan]) + # s = pd.Series([pd.NA, pd.NA]) + assert variable_type(s) == "numeric" + + s = pd.Series(["1", "2", "3"]) + assert variable_type(s) == "categorical" + assert variable_type(s.to_numpy()) == "categorical" + assert variable_type(s.to_list()) == "categorical" + + s = pd.Series([True, False, False]) + assert variable_type(s) == "numeric" + assert variable_type(s, boolean_type="categorical") == "categorical" + s_cat = s.astype("category") + assert variable_type(s_cat, boolean_type="categorical") == "categorical" + assert variable_type(s_cat, boolean_type="numeric") == "categorical" + + s = pd.Series([pd.Timestamp(1), pd.Timestamp(2)]) + assert variable_type(s) == "datetime" + assert variable_type(s.astype(object)) == "datetime" + assert variable_type(s.to_numpy()) == "datetime" + assert variable_type(s.to_list()) == "datetime" + + +def test_categorical_order(): + + x = pd.Series(["a", "c", "c", "b", "a", "d"]) + y = pd.Series([3, 2, 5, 1, 4]) + order = ["a", "b", "c", "d"] + + out = categorical_order(x) + assert out == ["a", "c", "b", "d"] + + out = categorical_order(x, order) + assert out == order + + out = categorical_order(x, ["b", "a"]) + assert out == ["b", "a"] + + out = categorical_order(y) + assert out == [1, 2, 3, 4, 5] + + out = categorical_order(pd.Series(y)) + assert out == [1, 2, 3, 4, 5] + + y_cat = pd.Series(pd.Categorical(y, y)) + out = categorical_order(y_cat) + assert out == list(y) + + x = pd.Series(x).astype("category") + out = categorical_order(x) + assert out == list(x.cat.categories) + + out = categorical_order(x, ["b", "a"]) + assert out == ["b", "a"] + + x = pd.Series(["a", np.nan, "c", "c", "b", "a", "d"]) + out = categorical_order(x) + assert out == ["a", "c", "b", "d"] diff --git a/seaborn/tests/_core/test_scales.py b/seaborn/tests/_core/test_scales.py new file mode 100644 index 0000000000..5883d9f2d4 --- /dev/null +++ b/seaborn/tests/_core/test_scales.py @@ -0,0 +1,540 @@ + +import numpy as np +import pandas as pd +import matplotlib as mpl + +import pytest +from numpy.testing import assert_array_equal +from pandas.testing import assert_series_equal + +from seaborn._core.scales import ( + Nominal, + Continuous, + Temporal, + PseudoAxis, +) +from seaborn._core.properties import ( + IntervalProperty, + ObjectProperty, + Coordinate, + Alpha, + Color, + Fill, +) +from seaborn.palettes import color_palette + + +class TestContinuous: + + @pytest.fixture + def x(self): + return pd.Series([1, 3, 9], name="x", dtype=float) + + def test_coordinate_defaults(self, x): + + s = Continuous().setup(x, Coordinate()) + assert_series_equal(s(x), x) + assert_series_equal(s.invert_axis_transform(s(x)), x) + + def test_coordinate_transform(self, x): + + s = Continuous(transform="log").setup(x, Coordinate()) + assert_series_equal(s(x), np.log10(x)) + assert_series_equal(s.invert_axis_transform(s(x)), x) + + def test_coordinate_transform_with_parameter(self, x): + + s = Continuous(transform="pow3").setup(x, Coordinate()) + assert_series_equal(s(x), np.power(x, 3)) + assert_series_equal(s.invert_axis_transform(s(x)), x) + + def test_interval_defaults(self, x): + + s = Continuous().setup(x, IntervalProperty()) + assert_array_equal(s(x), [0, .25, 1]) + + def test_interval_with_range(self, x): + + s = Continuous((1, 3)).setup(x, IntervalProperty()) + assert_array_equal(s(x), [1, 1.5, 3]) + + def test_interval_with_norm(self, x): + + s = Continuous(norm=(3, 7)).setup(x, IntervalProperty()) + assert_array_equal(s(x), [-.5, 0, 1.5]) + + def test_interval_with_range_norm_and_transform(self, x): + + x = pd.Series([1, 10, 100]) + # TODO param order? + s = Continuous((2, 3), (10, 100), "log").setup(x, IntervalProperty()) + assert_array_equal(s(x), [1, 2, 3]) + + def test_color_defaults(self, x): + + cmap = color_palette("ch:", as_cmap=True) + s = Continuous().setup(x, Color()) + assert_array_equal(s(x), cmap([0, .25, 1])[:, :3]) # FIXME RGBA + + def test_color_named_values(self, x): + + cmap = color_palette("viridis", as_cmap=True) + s = Continuous("viridis").setup(x, Color()) + assert_array_equal(s(x), cmap([0, .25, 1])[:, :3]) # FIXME RGBA + + def test_color_tuple_values(self, x): + + cmap = color_palette("blend:b,g", as_cmap=True) + s = Continuous(("b", "g")).setup(x, Color()) + assert_array_equal(s(x), cmap([0, .25, 1])[:, :3]) # FIXME RGBA + + def test_color_callable_values(self, x): + + cmap = color_palette("light:r", as_cmap=True) + s = Continuous(cmap).setup(x, Color()) + assert_array_equal(s(x), cmap([0, .25, 1])[:, :3]) # FIXME RGBA + + def test_color_with_norm(self, x): + + cmap = color_palette("ch:", as_cmap=True) + s = Continuous(norm=(3, 7)).setup(x, Color()) + assert_array_equal(s(x), cmap([-.5, 0, 1.5])[:, :3]) # FIXME RGBA + + def test_color_with_transform(self, x): + + x = pd.Series([1, 10, 100], name="x", dtype=float) + cmap = color_palette("ch:", as_cmap=True) + s = Continuous(transform="log").setup(x, Color()) + assert_array_equal(s(x), cmap([0, .5, 1])[:, :3]) # FIXME RGBA + + def test_tick_locator(self, x): + + locs = [.2, .5, .8] + locator = mpl.ticker.FixedLocator(locs) + s = Continuous().tick(locator).setup(x, Coordinate()) + a = PseudoAxis(s.matplotlib_scale) + a.set_view_interval(0, 1) + assert_array_equal(a.major.locator(), locs) + + def test_tick_locator_input_check(self, x): + + err = "Tick locator must be an instance of .*?, not ." + with pytest.raises(TypeError, match=err): + Continuous().tick((1, 2)) + + def test_tick_upto(self, x): + + for n in [2, 5, 10]: + s = Continuous().tick(upto=n).setup(x, Coordinate()) + a = PseudoAxis(s.matplotlib_scale) + a.set_view_interval(0, 1) + assert len(a.major.locator()) <= (n + 1) + + def test_tick_every(self, x): + + for d in [.05, .2, .5]: + s = Continuous().tick(every=d).setup(x, Coordinate()) + a = PseudoAxis(s.matplotlib_scale) + a.set_view_interval(0, 1) + assert np.allclose(np.diff(a.major.locator()), d) + + def test_tick_every_between(self, x): + + lo, hi = .2, .8 + for d in [.05, .2, .5]: + s = Continuous().tick(every=d, between=(lo, hi)).setup(x, Coordinate()) + a = PseudoAxis(s.matplotlib_scale) + a.set_view_interval(0, 1) + expected = np.arange(lo, hi + d, d) + assert_array_equal(a.major.locator(), expected) + + def test_tick_at(self, x): + + locs = [.2, .5, .9] + s = Continuous().tick(at=locs).setup(x, Coordinate()) + a = PseudoAxis(s.matplotlib_scale) + a.set_view_interval(0, 1) + assert_array_equal(a.major.locator(), locs) + + def test_tick_count(self, x): + + n = 8 + s = Continuous().tick(count=n).setup(x, Coordinate()) + a = PseudoAxis(s.matplotlib_scale) + a.set_view_interval(0, 1) + assert_array_equal(a.major.locator(), np.linspace(0, 1, n)) + + def test_tick_count_between(self, x): + + n = 5 + lo, hi = .2, .7 + s = Continuous().tick(count=n, between=(lo, hi)).setup(x, Coordinate()) + a = PseudoAxis(s.matplotlib_scale) + a.set_view_interval(0, 1) + assert_array_equal(a.major.locator(), np.linspace(lo, hi, n)) + + def test_tick_minor(self, x): + + n = 3 + s = Continuous().tick(count=2, minor=n).setup(x, Coordinate()) + a = PseudoAxis(s.matplotlib_scale) + a.set_view_interval(0, 1) + # I am not sure why matplotlib's minor ticks include the + # largest major location but exclude the smalllest one ... + expected = np.linspace(0, 1, n + 2)[1:] + assert_array_equal(a.minor.locator(), expected) + + def test_log_tick_default(self, x): + + s = Continuous(transform="log").setup(x, Coordinate()) + a = PseudoAxis(s.matplotlib_scale) + a.set_view_interval(.5, 1050) + ticks = a.major.locator() + assert np.allclose(np.diff(np.log10(ticks)), 1) + + def test_log_tick_upto(self, x): + + n = 3 + s = Continuous(transform="log").tick(upto=n).setup(x, Coordinate()) + a = PseudoAxis(s.matplotlib_scale) + assert a.major.locator.numticks == n + + def test_log_tick_count(self, x): + + with pytest.raises(RuntimeError, match="`count` requires"): + Continuous(transform="log").tick(count=4) + + s = Continuous(transform="log").tick(count=4, between=(1, 1000)) + a = PseudoAxis(s.setup(x, Coordinate()).matplotlib_scale) + a.set_view_interval(.5, 1050) + assert_array_equal(a.major.locator(), [1, 10, 100, 1000]) + + def test_log_tick_every(self, x): + + with pytest.raises(RuntimeError, match="`every` not supported"): + Continuous(transform="log").tick(every=2) + + +class TestNominal: + + @pytest.fixture + def x(self): + return pd.Series(["a", "c", "b", "c"], name="x") + + @pytest.fixture + def y(self): + return pd.Series([1, -1.5, 3, -1.5], name="y") + + def test_coordinate_defaults(self, x): + + s = Nominal().setup(x, Coordinate()) + assert_array_equal(s(x), np.array([0, 1, 2, 1], float)) + assert_array_equal(s.invert_axis_transform(s(x)), s(x)) + + def test_coordinate_with_order(self, x): + + s = Nominal(order=["a", "b", "c"]).setup(x, Coordinate()) + assert_array_equal(s(x), np.array([0, 2, 1, 2], float)) + assert_array_equal(s.invert_axis_transform(s(x)), s(x)) + + def test_coordinate_with_subset_order(self, x): + + s = Nominal(order=["c", "a"]).setup(x, Coordinate()) + assert_array_equal(s(x), np.array([1, 0, np.nan, 0], float)) + assert_array_equal(s.invert_axis_transform(s(x)), s(x)) + + def test_coordinate_axis(self, x): + + ax = mpl.figure.Figure().subplots() + s = Nominal().setup(x, Coordinate(), ax.xaxis) + assert_array_equal(s(x), np.array([0, 1, 2, 1], float)) + f = ax.xaxis.get_major_formatter() + assert f.format_ticks([0, 1, 2]) == ["a", "c", "b"] + + def test_coordinate_axis_with_order(self, x): + + order = ["a", "b", "c"] + ax = mpl.figure.Figure().subplots() + s = Nominal(order=order).setup(x, Coordinate(), ax.xaxis) + assert_array_equal(s(x), np.array([0, 2, 1, 2], float)) + f = ax.xaxis.get_major_formatter() + assert f.format_ticks([0, 1, 2]) == order + + def test_coordinate_axis_with_subset_order(self, x): + + order = ["c", "a"] + ax = mpl.figure.Figure().subplots() + s = Nominal(order=order).setup(x, Coordinate(), ax.xaxis) + assert_array_equal(s(x), np.array([1, 0, np.nan, 0], float)) + f = ax.xaxis.get_major_formatter() + assert f.format_ticks([0, 1, 2]) == [*order, ""] + + def test_coordinate_axis_with_category_dtype(self, x): + + order = ["b", "a", "d", "c"] + x = x.astype(pd.CategoricalDtype(order)) + ax = mpl.figure.Figure().subplots() + s = Nominal().setup(x, Coordinate(), ax.xaxis) + assert_array_equal(s(x), np.array([1, 3, 0, 3], float)) + f = ax.xaxis.get_major_formatter() + assert f.format_ticks([0, 1, 2, 3]) == order + + def test_coordinate_numeric_data(self, y): + + ax = mpl.figure.Figure().subplots() + s = Nominal().setup(y, Coordinate(), ax.yaxis) + assert_array_equal(s(y), np.array([1, 0, 2, 0], float)) + f = ax.yaxis.get_major_formatter() + assert f.format_ticks([0, 1, 2]) == ["-1.5", "1.0", "3.0"] + + def test_coordinate_numeric_data_with_order(self, y): + + order = [1, 4, -1.5] + ax = mpl.figure.Figure().subplots() + s = Nominal(order=order).setup(y, Coordinate(), ax.yaxis) + assert_array_equal(s(y), np.array([0, 2, np.nan, 2], float)) + f = ax.yaxis.get_major_formatter() + assert f.format_ticks([0, 1, 2]) == ["1.0", "4.0", "-1.5"] + + def test_color_defaults(self, x): + + s = Nominal().setup(x, Color()) + cs = color_palette() + assert_array_equal(s(x), [cs[0], cs[1], cs[2], cs[1]]) + + def test_color_named_palette(self, x): + + pal = "flare" + s = Nominal(pal).setup(x, Color()) + cs = color_palette(pal, 3) + assert_array_equal(s(x), [cs[0], cs[1], cs[2], cs[1]]) + + def test_color_list_palette(self, x): + + cs = color_palette("crest", 3) + s = Nominal(cs).setup(x, Color()) + assert_array_equal(s(x), [cs[0], cs[1], cs[2], cs[1]]) + + def test_color_dict_palette(self, x): + + cs = color_palette("crest", 3) + pal = dict(zip("bac", cs)) + s = Nominal(pal).setup(x, Color()) + assert_array_equal(s(x), [cs[1], cs[2], cs[0], cs[2]]) + + def test_color_numeric_data(self, y): + + s = Nominal().setup(y, Color()) + cs = color_palette() + assert_array_equal(s(y), [cs[1], cs[0], cs[2], cs[0]]) + + def test_color_numeric_with_order_subset(self, y): + + s = Nominal(order=[-1.5, 1]).setup(y, Color()) + c1, c2 = color_palette(n_colors=2) + null = (np.nan, np.nan, np.nan) + assert_array_equal(s(y), [c2, c1, null, c1]) + + @pytest.mark.xfail(reason="Need to sort out float/int order") + def test_color_numeric_int_float_mix(self): + + z = pd.Series([1, 2], name="z") + s = Nominal(order=[1.0, 2]).setup(z, Color()) + c1, c2 = color_palette(n_colors=2) + null = (np.nan, np.nan, np.nan) + assert_array_equal(s(z), [c1, null, c2]) + + def test_color_alpha_in_palette(self, x): + + cs = [(.2, .2, .3, .5), (.1, .2, .3, 1), (.5, .6, .2, 0)] + s = Nominal(cs).setup(x, Color()) + assert_array_equal(s(x), [cs[0], cs[1], cs[2], cs[1]]) + + def test_color_unknown_palette(self, x): + + pal = "not_a_palette" + err = f"{pal} is not a valid palette name" + with pytest.raises(ValueError, match=err): + Nominal(pal).setup(x, Color()) + + def test_object_defaults(self, x): + + class MockProperty(ObjectProperty): + def _default_values(self, n): + return list("xyz"[:n]) + + s = Nominal().setup(x, MockProperty()) + assert s(x) == ["x", "y", "z", "y"] + + def test_object_list(self, x): + + vs = ["x", "y", "z"] + s = Nominal(vs).setup(x, ObjectProperty()) + assert s(x) == ["x", "y", "z", "y"] + + def test_object_dict(self, x): + + vs = {"a": "x", "b": "y", "c": "z"} + s = Nominal(vs).setup(x, ObjectProperty()) + assert s(x) == ["x", "z", "y", "z"] + + def test_object_order(self, x): + + vs = ["x", "y", "z"] + s = Nominal(vs, order=["c", "a", "b"]).setup(x, ObjectProperty()) + assert s(x) == ["y", "x", "z", "x"] + + def test_object_order_subset(self, x): + + vs = ["x", "y"] + s = Nominal(vs, order=["a", "c"]).setup(x, ObjectProperty()) + assert s(x) == ["x", "y", None, "y"] + + def test_objects_that_are_weird(self, x): + + vs = [("x", 1), (None, None, 0), {}] + s = Nominal(vs).setup(x, ObjectProperty()) + assert s(x) == [vs[0], vs[1], vs[2], vs[1]] + + def test_alpha_default(self, x): + + s = Nominal().setup(x, Alpha()) + assert_array_equal(s(x), [.95, .625, .3, .625]) + + def test_fill(self): + + x = pd.Series(["a", "a", "b", "a"], name="x") + s = Nominal().setup(x, Fill()) + assert_array_equal(s(x), [True, True, False, True]) + + def test_fill_dict(self): + + x = pd.Series(["a", "a", "b", "a"], name="x") + vs = {"a": False, "b": True} + s = Nominal(vs).setup(x, Fill()) + assert_array_equal(s(x), [False, False, True, False]) + + def test_fill_nunique_warning(self): + + x = pd.Series(["a", "b", "c", "a", "b"], name="x") + with pytest.warns(UserWarning, match="The variable assigned to fill"): + s = Nominal().setup(x, Fill()) + assert_array_equal(s(x), [True, False, True, True, False]) + + def test_interval_defaults(self, x): + + class MockProperty(IntervalProperty): + _default_range = (1, 2) + + s = Nominal().setup(x, MockProperty()) + assert_array_equal(s(x), [2, 1.5, 1, 1.5]) + + def test_interval_tuple(self, x): + + s = Nominal((1, 2)).setup(x, IntervalProperty()) + assert_array_equal(s(x), [2, 1.5, 1, 1.5]) + + def test_interval_tuple_numeric(self, y): + + s = Nominal((1, 2)).setup(y, IntervalProperty()) + assert_array_equal(s(y), [1.5, 2, 1, 2]) + + def test_interval_list(self, x): + + vs = [2, 5, 4] + s = Nominal(vs).setup(x, IntervalProperty()) + assert_array_equal(s(x), [2, 5, 4, 5]) + + def test_interval_dict(self, x): + + vs = {"a": 3, "b": 4, "c": 6} + s = Nominal(vs).setup(x, IntervalProperty()) + assert_array_equal(s(x), [3, 6, 4, 6]) + + def test_interval_with_transform(self, x): + + class MockProperty(IntervalProperty): + _forward = np.square + _inverse = np.sqrt + + s = Nominal((2, 4)).setup(x, MockProperty()) + assert_array_equal(s(x), [4, np.sqrt(10), 2, np.sqrt(10)]) + + +class TestTemporal: + + @pytest.fixture + def t(self): + dates = pd.to_datetime(["1972-09-27", "1975-06-24", "1980-12-14"]) + return pd.Series(dates, name="x") + + @pytest.fixture + def x(self, t): + return pd.Series(mpl.dates.date2num(t), name=t.name) + + def test_coordinate_defaults(self, t, x): + + s = Temporal().setup(t, Coordinate()) + assert_array_equal(s(t), x) + + def test_interval_defaults(self, t, x): + + s = Temporal().setup(t, IntervalProperty()) + normed = (x - x.min()) / (x.max() - x.min()) + assert_array_equal(s(t), normed) + + def test_interval_with_range(self, t, x): + + values = (1, 3) + s = Temporal((1, 3)).setup(t, IntervalProperty()) + normed = (x - x.min()) / (x.max() - x.min()) + expected = normed * (values[1] - values[0]) + values[0] + assert_array_equal(s(t), expected) + + def test_interval_with_norm(self, t, x): + + norm = t[1], t[2] + s = Temporal(norm=norm).setup(t, IntervalProperty()) + n = mpl.dates.date2num(norm) + normed = (x - n[0]) / (n[1] - n[0]) + assert_array_equal(s(t), normed) + + def test_color_defaults(self, t, x): + + cmap = color_palette("ch:", as_cmap=True) + s = Temporal().setup(t, Color()) + normed = (x - x.min()) / (x.max() - x.min()) + assert_array_equal(s(t), cmap(normed)[:, :3]) # FIXME RGBA + + def test_color_named_values(self, t, x): + + name = "viridis" + cmap = color_palette(name, as_cmap=True) + s = Temporal(name).setup(t, Color()) + normed = (x - x.min()) / (x.max() - x.min()) + assert_array_equal(s(t), cmap(normed)[:, :3]) # FIXME RGBA + + def test_coordinate_axis(self, t, x): + + ax = mpl.figure.Figure().subplots() + s = Temporal().setup(t, Coordinate(), ax.xaxis) + assert_array_equal(s(t), x) + locator = ax.xaxis.get_major_locator() + formatter = ax.xaxis.get_major_formatter() + assert isinstance(locator, mpl.dates.AutoDateLocator) + assert isinstance(formatter, mpl.dates.AutoDateFormatter) + + def test_concise_format(self, t, x): + + ax = mpl.figure.Figure().subplots() + Temporal().format(concise=True).setup(t, Coordinate(), ax.xaxis) + formatter = ax.xaxis.get_major_formatter() + assert isinstance(formatter, mpl.dates.ConciseDateFormatter) + + def test_tick_upto(self, t, x): + + n = 8 + ax = mpl.figure.Figure().subplots() + Temporal().tick(upto=n).setup(t, Coordinate(), ax.xaxis) + locator = ax.xaxis.get_major_locator() + assert set(locator.maxticks.values()) == {n} diff --git a/seaborn/tests/_core/test_subplots.py b/seaborn/tests/_core/test_subplots.py new file mode 100644 index 0000000000..b7705dfb7b --- /dev/null +++ b/seaborn/tests/_core/test_subplots.py @@ -0,0 +1,513 @@ +import itertools + +import numpy as np +import pytest + +from seaborn._core.subplots import Subplots + + +class TestSpecificationChecks: + + def test_both_facets_and_wrap(self): + + err = "Cannot wrap facets when specifying both `col` and `row`." + facet_spec = {"wrap": 3, "variables": {"col": "a", "row": "b"}} + with pytest.raises(RuntimeError, match=err): + Subplots({}, facet_spec, {}) + + def test_cross_xy_pairing_and_wrap(self): + + err = "Cannot wrap subplots when pairing on both `x` and `y`." + pair_spec = {"wrap": 3, "structure": {"x": ["a", "b"], "y": ["y", "z"]}} + with pytest.raises(RuntimeError, match=err): + Subplots({}, {}, pair_spec) + + def test_col_facets_and_x_pairing(self): + + err = "Cannot facet the columns while pairing on `x`." + facet_spec = {"variables": {"col": "a"}} + pair_spec = {"structure": {"x": ["x", "y"]}} + with pytest.raises(RuntimeError, match=err): + Subplots({}, facet_spec, pair_spec) + + def test_wrapped_columns_and_y_pairing(self): + + err = "Cannot wrap the columns while pairing on `y`." + facet_spec = {"variables": {"col": "a"}, "wrap": 2} + pair_spec = {"structure": {"y": ["x", "y"]}} + with pytest.raises(RuntimeError, match=err): + Subplots({}, facet_spec, pair_spec) + + def test_wrapped_x_pairing_and_facetd_rows(self): + + err = "Cannot wrap the columns while faceting the rows." + facet_spec = {"variables": {"row": "a"}} + pair_spec = {"structure": {"x": ["x", "y"]}, "wrap": 2} + with pytest.raises(RuntimeError, match=err): + Subplots({}, facet_spec, pair_spec) + + +class TestSubplotSpec: + + def test_single_subplot(self): + + s = Subplots({}, {}, {}) + + assert s.n_subplots == 1 + assert s.subplot_spec["ncols"] == 1 + assert s.subplot_spec["nrows"] == 1 + assert s.subplot_spec["sharex"] is True + assert s.subplot_spec["sharey"] is True + + def test_single_facet(self): + + key = "a" + order = list("abc") + spec = {"variables": {"col": key}, "structure": {"col": order}} + s = Subplots({}, spec, {}) + + assert s.n_subplots == len(order) + assert s.subplot_spec["ncols"] == len(order) + assert s.subplot_spec["nrows"] == 1 + assert s.subplot_spec["sharex"] is True + assert s.subplot_spec["sharey"] is True + + def test_two_facets(self): + + col_key = "a" + row_key = "b" + col_order = list("xy") + row_order = list("xyz") + spec = { + "variables": {"col": col_key, "row": row_key}, + "structure": {"col": col_order, "row": row_order}, + + } + s = Subplots({}, spec, {}) + + assert s.n_subplots == len(col_order) * len(row_order) + assert s.subplot_spec["ncols"] == len(col_order) + assert s.subplot_spec["nrows"] == len(row_order) + assert s.subplot_spec["sharex"] is True + assert s.subplot_spec["sharey"] is True + + def test_col_facet_wrapped(self): + + key = "b" + wrap = 3 + order = list("abcde") + spec = {"variables": {"col": key}, "structure": {"col": order}, "wrap": wrap} + s = Subplots({}, spec, {}) + + assert s.n_subplots == len(order) + assert s.subplot_spec["ncols"] == wrap + assert s.subplot_spec["nrows"] == len(order) // wrap + 1 + assert s.subplot_spec["sharex"] is True + assert s.subplot_spec["sharey"] is True + + def test_row_facet_wrapped(self): + + key = "b" + wrap = 3 + order = list("abcde") + spec = {"variables": {"row": key}, "structure": {"row": order}, "wrap": wrap} + s = Subplots({}, spec, {}) + + assert s.n_subplots == len(order) + assert s.subplot_spec["ncols"] == len(order) // wrap + 1 + assert s.subplot_spec["nrows"] == wrap + assert s.subplot_spec["sharex"] is True + assert s.subplot_spec["sharey"] is True + + def test_col_facet_wrapped_single_row(self): + + key = "b" + order = list("abc") + wrap = len(order) + 2 + spec = {"variables": {"col": key}, "structure": {"col": order}, "wrap": wrap} + s = Subplots({}, spec, {}) + + assert s.n_subplots == len(order) + assert s.subplot_spec["ncols"] == len(order) + assert s.subplot_spec["nrows"] == 1 + assert s.subplot_spec["sharex"] is True + assert s.subplot_spec["sharey"] is True + + def test_x_and_y_paired(self): + + x = ["x", "y", "z"] + y = ["a", "b"] + s = Subplots({}, {}, {"structure": {"x": x, "y": y}}) + + assert s.n_subplots == len(x) * len(y) + assert s.subplot_spec["ncols"] == len(x) + assert s.subplot_spec["nrows"] == len(y) + assert s.subplot_spec["sharex"] == "col" + assert s.subplot_spec["sharey"] == "row" + + def test_x_paired(self): + + x = ["x", "y", "z"] + s = Subplots({}, {}, {"structure": {"x": x}}) + + assert s.n_subplots == len(x) + assert s.subplot_spec["ncols"] == len(x) + assert s.subplot_spec["nrows"] == 1 + assert s.subplot_spec["sharex"] == "col" + assert s.subplot_spec["sharey"] is True + + def test_y_paired(self): + + y = ["x", "y", "z"] + s = Subplots({}, {}, {"structure": {"y": y}}) + + assert s.n_subplots == len(y) + assert s.subplot_spec["ncols"] == 1 + assert s.subplot_spec["nrows"] == len(y) + assert s.subplot_spec["sharex"] is True + assert s.subplot_spec["sharey"] == "row" + + def test_x_paired_and_wrapped(self): + + x = ["a", "b", "x", "y", "z"] + wrap = 3 + s = Subplots({}, {}, {"structure": {"x": x}, "wrap": wrap}) + + assert s.n_subplots == len(x) + assert s.subplot_spec["ncols"] == wrap + assert s.subplot_spec["nrows"] == len(x) // wrap + 1 + assert s.subplot_spec["sharex"] is False + assert s.subplot_spec["sharey"] is True + + def test_y_paired_and_wrapped(self): + + y = ["a", "b", "x", "y", "z"] + wrap = 2 + s = Subplots({}, {}, {"structure": {"y": y}, "wrap": wrap}) + + assert s.n_subplots == len(y) + assert s.subplot_spec["ncols"] == len(y) // wrap + 1 + assert s.subplot_spec["nrows"] == wrap + assert s.subplot_spec["sharex"] is True + assert s.subplot_spec["sharey"] is False + + def test_col_faceted_y_paired(self): + + y = ["x", "y", "z"] + key = "a" + order = list("abc") + facet_spec = {"variables": {"col": key}, "structure": {"col": order}} + pair_spec = {"structure": {"y": y}} + s = Subplots({}, facet_spec, pair_spec) + + assert s.n_subplots == len(order) * len(y) + assert s.subplot_spec["ncols"] == len(order) + assert s.subplot_spec["nrows"] == len(y) + assert s.subplot_spec["sharex"] is True + assert s.subplot_spec["sharey"] == "row" + + def test_row_faceted_x_paired(self): + + x = ["f", "s"] + key = "a" + order = list("abc") + facet_spec = {"variables": {"row": key}, "structure": {"row": order}} + pair_spec = {"structure": {"x": x}} + s = Subplots({}, facet_spec, pair_spec) + + assert s.n_subplots == len(order) * len(x) + assert s.subplot_spec["ncols"] == len(x) + assert s.subplot_spec["nrows"] == len(order) + assert s.subplot_spec["sharex"] == "col" + assert s.subplot_spec["sharey"] is True + + def test_x_any_y_paired_non_cross(self): + + x = ["a", "b", "c"] + y = ["x", "y", "z"] + spec = {"structure": {"x": x, "y": y}, "cross": False} + s = Subplots({}, {}, spec) + + assert s.n_subplots == len(x) + assert s.subplot_spec["ncols"] == len(y) + assert s.subplot_spec["nrows"] == 1 + assert s.subplot_spec["sharex"] is False + assert s.subplot_spec["sharey"] is False + + def test_x_any_y_paired_non_cross_wrapped(self): + + x = ["a", "b", "c"] + y = ["x", "y", "z"] + wrap = 2 + spec = {"structure": {"x": x, "y": y}, "cross": False, "wrap": wrap} + s = Subplots({}, {}, spec) + + assert s.n_subplots == len(x) + assert s.subplot_spec["ncols"] == wrap + assert s.subplot_spec["nrows"] == len(x) // wrap + 1 + assert s.subplot_spec["sharex"] is False + assert s.subplot_spec["sharey"] is False + + def test_forced_unshared_facets(self): + + s = Subplots({"sharex": False, "sharey": "row"}, {}, {}) + assert s.subplot_spec["sharex"] is False + assert s.subplot_spec["sharey"] == "row" + + +class TestSubplotElements: + + def test_single_subplot(self): + + s = Subplots({}, {}, {}) + f = s.init_figure({}, {}) + + assert len(s) == 1 + for i, e in enumerate(s): + for side in ["left", "right", "bottom", "top"]: + assert e[side] + for dim in ["col", "row"]: + assert e[dim] is None + for axis in "xy": + assert e[axis] == axis + assert e["ax"] == f.axes[i] + + @pytest.mark.parametrize("dim", ["col", "row"]) + def test_single_facet_dim(self, dim): + + key = "a" + order = list("abc") + spec = {"variables": {dim: key}, "structure": {dim: order}} + s = Subplots({}, spec, {}) + s.init_figure(spec, {}) + + assert len(s) == len(order) + + for i, e in enumerate(s): + assert e[dim] == order[i] + for axis in "xy": + assert e[axis] == axis + assert e["top"] == (dim == "col" or i == 0) + assert e["bottom"] == (dim == "col" or i == len(order) - 1) + assert e["left"] == (dim == "row" or i == 0) + assert e["right"] == (dim == "row" or i == len(order) - 1) + + @pytest.mark.parametrize("dim", ["col", "row"]) + def test_single_facet_dim_wrapped(self, dim): + + key = "b" + order = list("abc") + wrap = len(order) - 1 + spec = {"variables": {dim: key}, "structure": {dim: order}, "wrap": wrap} + s = Subplots({}, spec, {}) + s.init_figure(spec, {}) + + assert len(s) == len(order) + + for i, e in enumerate(s): + assert e[dim] == order[i] + for axis in "xy": + assert e[axis] == axis + + sides = { + "col": ["top", "bottom", "left", "right"], + "row": ["left", "right", "top", "bottom"], + } + tests = ( + i < wrap, + i >= wrap or i >= len(s) % wrap, + i % wrap == 0, + i % wrap == wrap - 1 or i + 1 == len(s), + ) + + for side, expected in zip(sides[dim], tests): + assert e[side] == expected + + def test_both_facet_dims(self): + + col = "a" + row = "b" + col_order = list("ab") + row_order = list("xyz") + facet_spec = { + "variables": {"col": col, "row": row}, + "structure": {"col": col_order, "row": row_order}, + } + s = Subplots({}, facet_spec, {}) + s.init_figure(facet_spec, {}) + + n_cols = len(col_order) + n_rows = len(row_order) + assert len(s) == n_cols * n_rows + es = list(s) + + for e in es[:n_cols]: + assert e["top"] + for e in es[::n_cols]: + assert e["left"] + for e in es[n_cols - 1::n_cols]: + assert e["right"] + for e in es[-n_cols:]: + assert e["bottom"] + + for e, (row_, col_) in zip(es, itertools.product(row_order, col_order)): + assert e["col"] == col_ + assert e["row"] == row_ + + for e in es: + assert e["x"] == "x" + assert e["y"] == "y" + + @pytest.mark.parametrize("var", ["x", "y"]) + def test_single_paired_var(self, var): + + other_var = {"x": "y", "y": "x"}[var] + pairings = ["x", "y", "z"] + pair_spec = { + "variables": {f"{var}{i}": v for i, v in enumerate(pairings)}, + "structure": {var: [f"{var}{i}" for i, _ in enumerate(pairings)]}, + } + + s = Subplots({}, {}, pair_spec) + s.init_figure(pair_spec) + + assert len(s) == len(pair_spec["structure"][var]) + + for i, e in enumerate(s): + assert e[var] == f"{var}{i}" + assert e[other_var] == other_var + assert e["col"] is e["row"] is None + + tests = i == 0, True, True, i == len(s) - 1 + sides = { + "x": ["left", "right", "top", "bottom"], + "y": ["top", "bottom", "left", "right"], + } + + for side, expected in zip(sides[var], tests): + assert e[side] == expected + + @pytest.mark.parametrize("var", ["x", "y"]) + def test_single_paired_var_wrapped(self, var): + + other_var = {"x": "y", "y": "x"}[var] + pairings = ["x", "y", "z", "a", "b"] + wrap = len(pairings) - 2 + pair_spec = { + "variables": {f"{var}{i}": val for i, val in enumerate(pairings)}, + "structure": {var: [f"{var}{i}" for i, _ in enumerate(pairings)]}, + "wrap": wrap + } + s = Subplots({}, {}, pair_spec) + s.init_figure(pair_spec) + + assert len(s) == len(pairings) + + for i, e in enumerate(s): + assert e[var] == f"{var}{i}" + assert e[other_var] == other_var + assert e["col"] is e["row"] is None + + tests = ( + i < wrap, + i >= wrap or i >= len(s) % wrap, + i % wrap == 0, + i % wrap == wrap - 1 or i + 1 == len(s), + ) + sides = { + "x": ["top", "bottom", "left", "right"], + "y": ["left", "right", "top", "bottom"], + } + for side, expected in zip(sides[var], tests): + assert e[side] == expected + + def test_both_paired_variables(self): + + x = ["x0", "x1"] + y = ["y0", "y1", "y2"] + pair_spec = {"structure": {"x": x, "y": y}} + s = Subplots({}, {}, pair_spec) + s.init_figure(pair_spec) + + n_cols = len(x) + n_rows = len(y) + assert len(s) == n_cols * n_rows + es = list(s) + + for e in es[:n_cols]: + assert e["top"] + for e in es[::n_cols]: + assert e["left"] + for e in es[n_cols - 1::n_cols]: + assert e["right"] + for e in es[-n_cols:]: + assert e["bottom"] + + for e in es: + assert e["col"] is e["row"] is None + + for i in range(len(y)): + for j in range(len(x)): + e = es[i * len(x) + j] + assert e["x"] == f"x{j}" + assert e["y"] == f"y{i}" + + def test_both_paired_non_cross(self): + + pair_spec = { + "structure": {"x": ["x0", "x1", "x2"], "y": ["y0", "y1", "y2"]}, + "cross": False + } + s = Subplots({}, {}, pair_spec) + s.init_figure(pair_spec) + + for i, e in enumerate(s): + assert e["x"] == f"x{i}" + assert e["y"] == f"y{i}" + assert e["col"] is e["row"] is None + assert e["left"] == (i == 0) + assert e["right"] == (i == (len(s) - 1)) + assert e["top"] + assert e["bottom"] + + @pytest.mark.parametrize("dim,var", [("col", "y"), ("row", "x")]) + def test_one_facet_one_paired(self, dim, var): + + other_var = {"x": "y", "y": "x"}[var] + other_dim = {"col": "row", "row": "col"}[dim] + order = list("abc") + facet_spec = {"variables": {dim: "s"}, "structure": {dim: order}} + + pairings = ["x", "y", "t"] + pair_spec = { + "variables": {f"{var}{i}": val for i, val in enumerate(pairings)}, + "structure": {var: [f"{var}{i}" for i, _ in enumerate(pairings)]}, + } + + s = Subplots({}, facet_spec, pair_spec) + s.init_figure(pair_spec) + + n_cols = len(order) if dim == "col" else len(pairings) + n_rows = len(order) if dim == "row" else len(pairings) + + assert len(s) == len(order) * len(pairings) + + es = list(s) + + for e in es[:n_cols]: + assert e["top"] + for e in es[::n_cols]: + assert e["left"] + for e in es[n_cols - 1::n_cols]: + assert e["right"] + for e in es[-n_cols:]: + assert e["bottom"] + + if dim == "row": + es = np.reshape(es, (n_rows, n_cols)).T.ravel() + + for i, e in enumerate(es): + assert e[dim] == order[i % len(pairings)] + assert e[other_dim] is None + assert e[var] == f"{var}{i // len(order)}" + assert e[other_var] == other_var diff --git a/seaborn/tests/_marks/__init__.py b/seaborn/tests/_marks/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/seaborn/tests/_marks/test_area.py b/seaborn/tests/_marks/test_area.py new file mode 100644 index 0000000000..70f31541aa --- /dev/null +++ b/seaborn/tests/_marks/test_area.py @@ -0,0 +1,103 @@ + +import matplotlib as mpl +from matplotlib.colors import to_rgba_array + +from numpy.testing import assert_array_equal + +from seaborn._core.plot import Plot +from seaborn._marks.area import Area, Ribbon + + +class TestAreaMarks: + + def test_single_defaults(self): + + x, y = [1, 2, 3], [1, 2, 1] + p = Plot(x=x, y=y).add(Area()).plot() + ax = p._figure.axes[0] + poly, = ax.collections + verts = poly.get_paths()[0].vertices.T + + expected_x = [1, 2, 3, 3, 2, 1, 1] + assert_array_equal(verts[0], expected_x) + + expected_y = [0, 0, 0, 1, 2, 1, 0] + assert_array_equal(verts[1], expected_y) + + fc = poly.get_facecolor() + assert_array_equal(fc, to_rgba_array("C0", .2)) + + ec = poly.get_edgecolor() + assert_array_equal(ec, to_rgba_array("C0", 1)) + + lw = poly.get_linewidth() + assert_array_equal(lw, mpl.rcParams["patch.linewidth"]) + + def test_direct_parameters(self): + + x, y = [1, 2, 3], [1, 2, 1] + mark = Area( + color="C2", + alpha=.3, + edgecolor="k", + edgealpha=.8, + edgewidth=2, + edgestyle=(0, (2, 1)), + ) + p = Plot(x=x, y=y).add(mark).plot() + ax = p._figure.axes[0] + poly, = ax.collections + + fc = poly.get_facecolor() + assert_array_equal(fc, to_rgba_array(mark.color, mark.alpha)) + + ec = poly.get_edgecolor() + assert_array_equal(ec, to_rgba_array(mark.edgecolor, mark.edgealpha)) + + lw = poly.get_linewidth() + assert_array_equal(lw, mark.edgewidth) + + ls = poly.get_linestyle() + dash_on, dash_off = mark.edgestyle[1] + expected = [(0, [mark.edgewidth * dash_on, mark.edgewidth * dash_off])] + assert ls == expected + + def test_mapped(self): + + x, y = [1, 2, 3, 2, 3, 4], [1, 2, 1, 1, 3, 2] + g = ["a", "a", "a", "b", "b", "b"] + p = Plot(x=x, y=y, color=g, edgewidth=g).add(Area()).plot() + ax = p._figure.axes[0] + polys, = ax.collections + + paths = polys.get_paths() + expected_x = [1, 2, 3, 3, 2, 1, 1], [2, 3, 4, 4, 3, 2, 2] + expected_y = [0, 0, 0, 1, 2, 1, 0], [0, 0, 0, 2, 3, 1, 0] + + for i, path in enumerate(paths): + verts = path.vertices.T + assert_array_equal(verts[0], expected_x[i]) + assert_array_equal(verts[1], expected_y[i]) + + fc = polys.get_facecolor() + assert_array_equal(fc, to_rgba_array(["C0", "C1"], .2)) + + ec = polys.get_edgecolor() + assert_array_equal(ec, to_rgba_array(["C0", "C1"], 1)) + + lw = polys.get_linewidths() + assert lw[0] > lw[1] + + def test_ribbon(self): + + x, ymin, ymax = [1, 2, 4], [2, 1, 4], [3, 3, 5] + p = Plot(x=x, ymin=ymin, ymax=ymax).add(Ribbon()).plot() + ax = p._figure.axes[0] + poly, = ax.collections + verts = poly.get_paths()[0].vertices.T + + expected_x = [1, 2, 4, 4, 2, 1, 1] + assert_array_equal(verts[0], expected_x) + + expected_y = [2, 1, 4, 5, 3, 3, 2] + assert_array_equal(verts[1], expected_y) diff --git a/seaborn/tests/_marks/test_bars.py b/seaborn/tests/_marks/test_bars.py new file mode 100644 index 0000000000..ae4849a6b1 --- /dev/null +++ b/seaborn/tests/_marks/test_bars.py @@ -0,0 +1,122 @@ +import pytest + +from matplotlib.colors import to_rgba + +from seaborn._core.plot import Plot +from seaborn._marks.bars import Bar + + +class TestBar: + + def plot_bars(self, variables, mark_kws, layer_kws): + + p = Plot(**variables).add(Bar(**mark_kws), **layer_kws).plot() + ax = p._figure.axes[0] + return [bar for barlist in ax.containers for bar in barlist] + + def check_bar(self, bar, x, y, width, height): + + assert bar.get_x() == pytest.approx(x) + assert bar.get_y() == pytest.approx(y) + assert bar.get_width() == pytest.approx(width) + assert bar.get_height() == pytest.approx(height) + + def test_categorical_positions_vertical(self): + + x = ["a", "b"] + y = [1, 2] + w = .8 + bars = self.plot_bars({"x": x, "y": y}, {}, {}) + for i, bar in enumerate(bars): + self.check_bar(bar, i - w / 2, 0, w, y[i]) + + def test_categorical_positions_horizontal(self): + + x = [1, 2] + y = ["a", "b"] + w = .8 + bars = self.plot_bars({"x": x, "y": y}, {}, {}) + for i, bar in enumerate(bars): + self.check_bar(bar, 0, i - w / 2, x[i], w) + + def test_numeric_positions_vertical(self): + + x = [1, 2] + y = [3, 4] + w = .8 + bars = self.plot_bars({"x": x, "y": y}, {}, {}) + for i, bar in enumerate(bars): + self.check_bar(bar, x[i] - w / 2, 0, w, y[i]) + + def test_numeric_positions_horizontal(self): + + x = [1, 2] + y = [3, 4] + w = .8 + bars = self.plot_bars({"x": x, "y": y}, {}, {"orient": "h"}) + for i, bar in enumerate(bars): + self.check_bar(bar, 0, y[i] - w / 2, x[i], w) + + @pytest.mark.xfail(reason="new dodge api") + def test_categorical_dodge_vertical(self): + + x = ["a", "a", "b", "b"] + y = [1, 2, 3, 4] + group = ["x", "y", "x", "y"] + w = .8 + bars = self.plot_bars( + {"x": x, "y": y, "group": group}, {"multiple": "dodge"}, {} + ) + for i, bar in enumerate(bars[:2]): + self.check_bar(bar, i - w / 2, 0, w / 2, y[i * 2]) + for i, bar in enumerate(bars[2:]): + self.check_bar(bar, i, 0, w / 2, y[i * 2 + 1]) + + @pytest.mark.xfail(reason="new dodge api") + def test_categorical_dodge_horizontal(self): + + x = [1, 2, 3, 4] + y = ["a", "a", "b", "b"] + group = ["x", "y", "x", "y"] + w = .8 + bars = self.plot_bars( + {"x": x, "y": y, "group": group}, {"multiple": "dodge"}, {} + ) + for i, bar in enumerate(bars[:2]): + self.check_bar(bar, 0, i - w / 2, x[i * 2], w / 2) + for i, bar in enumerate(bars[2:]): + self.check_bar(bar, 0, i, x[i * 2 + 1], w / 2) + + def test_direct_properties(self): + + x = ["a", "b", "c"] + y = [1, 3, 2] + + mark = Bar( + color="C2", + alpha=.5, + edgecolor="k", + edgealpha=.9, + edgestyle=(2, 1), + edgewidth=1.5, + ) + + p = Plot(x, y).add(mark).plot() + ax = p._figure.axes[0] + for bar in ax.patches: + assert bar.get_facecolor() == to_rgba(mark.color, mark.alpha) + assert bar.get_edgecolor() == to_rgba(mark.edgecolor, mark.edgealpha) + assert bar.get_linewidth() == mark.edgewidth + assert bar.get_linestyle() == (0, mark.edgestyle) + + def test_mapped_properties(self): + + x = ["a", "b"] + y = [1, 2] + mark = Bar(alpha=.2) + p = Plot(x, y, color=x, edgewidth=y).add(mark).plot() + ax = p._figure.axes[0] + for i, bar in enumerate(ax.patches): + assert bar.get_facecolor() == to_rgba(f"C{i}", mark.alpha) + assert bar.get_edgecolor() == to_rgba(f"C{i}", 1) + assert ax.patches[0].get_linewidth() < ax.patches[1].get_linewidth() diff --git a/seaborn/tests/_marks/test_base.py b/seaborn/tests/_marks/test_base.py new file mode 100644 index 0000000000..ba31585811 --- /dev/null +++ b/seaborn/tests/_marks/test_base.py @@ -0,0 +1,157 @@ +from dataclasses import dataclass + +import numpy as np +import pandas as pd +import matplotlib as mpl + +import pytest +from numpy.testing import assert_array_equal + +from seaborn._marks.base import Mark, Mappable, resolve_color + + +class TestMappable: + + def mark(self, **features): + + @dataclass + class MockMark(Mark): + linewidth: float = Mappable(rc="lines.linewidth") + pointsize: float = Mappable(4) + color: str = Mappable("C0") + fillcolor: str = Mappable(depend="color") + alpha: float = Mappable(1) + fillalpha: float = Mappable(depend="alpha") + + m = MockMark(**features) + return m + + def test_repr(self): + + assert str(Mappable(.5)) == "<0.5>" + assert str(Mappable("CO")) == "<'CO'>" + assert str(Mappable(rc="lines.linewidth")) == "" + assert str(Mappable(depend="color")) == "" + + def test_input_checks(self): + + with pytest.raises(AssertionError): + Mappable(rc="bogus.parameter") + with pytest.raises(AssertionError): + Mappable(depend="nonexistent_feature") + + def test_value(self): + + val = 3 + m = self.mark(linewidth=val) + assert m._resolve({}, "linewidth") == val + + df = pd.DataFrame(index=pd.RangeIndex(10)) + assert_array_equal(m._resolve(df, "linewidth"), np.full(len(df), val)) + + def test_default(self): + + val = 3 + m = self.mark(linewidth=Mappable(val)) + assert m._resolve({}, "linewidth") == val + + df = pd.DataFrame(index=pd.RangeIndex(10)) + assert_array_equal(m._resolve(df, "linewidth"), np.full(len(df), val)) + + def test_rcparam(self): + + param = "lines.linewidth" + val = mpl.rcParams[param] + + m = self.mark(linewidth=Mappable(rc=param)) + assert m._resolve({}, "linewidth") == val + + df = pd.DataFrame(index=pd.RangeIndex(10)) + assert_array_equal(m._resolve(df, "linewidth"), np.full(len(df), val)) + + def test_depends(self): + + val = 2 + df = pd.DataFrame(index=pd.RangeIndex(10)) + + m = self.mark(pointsize=Mappable(val), linewidth=Mappable(depend="pointsize")) + assert m._resolve({}, "linewidth") == val + assert_array_equal(m._resolve(df, "linewidth"), np.full(len(df), val)) + + m = self.mark(pointsize=val * 2, linewidth=Mappable(depend="pointsize")) + assert m._resolve({}, "linewidth") == val * 2 + assert_array_equal(m._resolve(df, "linewidth"), np.full(len(df), val * 2)) + + def test_mapped(self): + + values = {"a": 1, "b": 2, "c": 3} + + def f(x): + return np.array([values[x_i] for x_i in x]) + + m = self.mark(linewidth=Mappable(2)) + scales = {"linewidth": f} + + assert m._resolve({"linewidth": "c"}, "linewidth", scales) == 3 + + df = pd.DataFrame({"linewidth": ["a", "b", "c"]}) + expected = np.array([1, 2, 3], float) + assert_array_equal(m._resolve(df, "linewidth", scales), expected) + + def test_color(self): + + c, a = "C1", .5 + m = self.mark(color=c, alpha=a) + + assert resolve_color(m, {}) == mpl.colors.to_rgba(c, a) + + df = pd.DataFrame(index=pd.RangeIndex(10)) + cs = [c] * len(df) + assert_array_equal(resolve_color(m, df), mpl.colors.to_rgba_array(cs, a)) + + def test_color_mapped_alpha(self): + + c = "r" + values = {"a": .2, "b": .5, "c": .8} + + m = self.mark(color=c, alpha=Mappable(1)) + scales = {"alpha": lambda s: np.array([values[s_i] for s_i in s])} + + assert resolve_color(m, {"alpha": "b"}, "", scales) == mpl.colors.to_rgba(c, .5) + + df = pd.DataFrame({"alpha": list(values.keys())}) + + # Do this in two steps for mpl 3.2 compat + expected = mpl.colors.to_rgba_array([c] * len(df)) + expected[:, 3] = list(values.values()) + + assert_array_equal(resolve_color(m, df, "", scales), expected) + + def test_color_scaled_as_strings(self): + + colors = ["C1", "dodgerblue", "#445566"] + m = self.mark() + scales = {"color": lambda s: colors} + + actual = resolve_color(m, {"color": pd.Series(["a", "b", "c"])}, "", scales) + expected = mpl.colors.to_rgba_array(colors) + assert_array_equal(actual, expected) + + def test_fillcolor(self): + + c, a = "green", .8 + fa = .2 + m = self.mark( + color=c, alpha=a, + fillcolor=Mappable(depend="color"), fillalpha=Mappable(fa), + ) + + assert resolve_color(m, {}) == mpl.colors.to_rgba(c, a) + assert resolve_color(m, {}, "fill") == mpl.colors.to_rgba(c, fa) + + df = pd.DataFrame(index=pd.RangeIndex(10)) + cs = [c] * len(df) + assert_array_equal(resolve_color(m, df), mpl.colors.to_rgba_array(cs, a)) + assert_array_equal( + resolve_color(m, df, "fill"), mpl.colors.to_rgba_array(cs, fa) + ) diff --git a/seaborn/tests/_marks/test_scatter.py b/seaborn/tests/_marks/test_scatter.py new file mode 100644 index 0000000000..e5f0a6b8ca --- /dev/null +++ b/seaborn/tests/_marks/test_scatter.py @@ -0,0 +1,141 @@ +from matplotlib.colors import to_rgba, to_rgba_array + +from numpy.testing import assert_array_equal + +from seaborn._core.plot import Plot +from seaborn._marks.scatter import Dot, Scatter + + +class ScatterBase: + + def check_offsets(self, points, x, y): + + offsets = points.get_offsets().T + assert_array_equal(offsets[0], x) + assert_array_equal(offsets[1], y) + + def check_colors(self, part, points, colors, alpha=None): + + rgba = to_rgba_array(colors, alpha) + + getter = getattr(points, f"get_{part}colors") + assert_array_equal(getter(), rgba) + + +class TestScatter(ScatterBase): + + def test_simple(self): + + x = [1, 2, 3] + y = [4, 5, 2] + p = Plot(x=x, y=y).add(Scatter()).plot() + ax = p._figure.axes[0] + points, = ax.collections + self.check_offsets(points, x, y) + self.check_colors("face", points, ["C0"] * 3, .2) + self.check_colors("edge", points, ["C0"] * 3, 1) + + def test_color_direct(self): + + x = [1, 2, 3] + y = [4, 5, 2] + p = Plot(x=x, y=y).add(Scatter(color="g")).plot() + ax = p._figure.axes[0] + points, = ax.collections + self.check_offsets(points, x, y) + self.check_colors("face", points, ["g"] * 3, .2) + self.check_colors("edge", points, ["g"] * 3, 1) + + def test_color_mapped(self): + + x = [1, 2, 3] + y = [4, 5, 2] + c = ["a", "b", "a"] + p = Plot(x=x, y=y, color=c).add(Scatter()).plot() + ax = p._figure.axes[0] + points, = ax.collections + self.check_offsets(points, x, y) + self.check_colors("face", points, ["C0", "C1", "C0"], .2) + self.check_colors("edge", points, ["C0", "C1", "C0"], 1) + + def test_fill(self): + + x = [1, 2, 3] + y = [4, 5, 2] + c = ["a", "b", "a"] + p = Plot(x=x, y=y, color=c).add(Scatter(fill=False)).plot() + ax = p._figure.axes[0] + points, = ax.collections + self.check_offsets(points, x, y) + self.check_colors("face", points, ["C0", "C1", "C0"], 0) + self.check_colors("edge", points, ["C0", "C1", "C0"], 1) + + def test_pointsize(self): + + x = [1, 2, 3] + y = [4, 5, 2] + s = 3 + p = Plot(x=x, y=y).add(Scatter(pointsize=s)).plot() + ax = p._figure.axes[0] + points, = ax.collections + self.check_offsets(points, x, y) + assert_array_equal(points.get_sizes(), [s ** 2] * 3) + + def test_stroke(self): + + x = [1, 2, 3] + y = [4, 5, 2] + s = 3 + p = Plot(x=x, y=y).add(Scatter(stroke=s)).plot() + ax = p._figure.axes[0] + points, = ax.collections + self.check_offsets(points, x, y) + assert_array_equal(points.get_linewidths(), [s] * 3) + + def test_filled_unfilled_mix(self): + + x = [1, 2] + y = [4, 5] + marker = ["a", "b"] + shapes = ["o", "x"] + + mark = Scatter(stroke=2) + p = Plot(x=x, y=y).add(mark, marker=marker).scale(marker=shapes).plot() + ax = p._figure.axes[0] + points, = ax.collections + self.check_offsets(points, x, y) + self.check_colors("face", points, [to_rgba("C0", .2), to_rgba("C0", 0)], None) + self.check_colors("edge", points, ["C0", "C0"], 1) + assert_array_equal(points.get_linewidths(), [mark.stroke] * 2) + + +class TestDot(ScatterBase): + + def test_simple(self): + + x = [1, 2, 3] + y = [4, 5, 2] + p = Plot(x=x, y=y).add(Dot()).plot() + ax = p._figure.axes[0] + points, = ax.collections + self.check_offsets(points, x, y) + self.check_colors("face", points, ["C0"] * 3, 1) + self.check_colors("edge", points, ["C0"] * 3, 1) + + def test_filled_unfilled_mix(self): + + x = [1, 2] + y = [4, 5] + marker = ["a", "b"] + shapes = ["o", "x"] + + mark = Dot(edgecolor="k", stroke=2, edgewidth=1) + p = Plot(x=x, y=y).add(mark, marker=marker).scale(marker=shapes).plot() + ax = p._figure.axes[0] + points, = ax.collections + self.check_offsets(points, x, y) + self.check_colors("face", points, ["C0", to_rgba("C0", 0)], None) + self.check_colors("edge", points, ["k", "C0"], 1) + + expected = [mark.edgewidth, mark.stroke] + assert_array_equal(points.get_linewidths(), expected) diff --git a/seaborn/tests/_stats/__init__.py b/seaborn/tests/_stats/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/seaborn/tests/_stats/test_aggregation.py b/seaborn/tests/_stats/test_aggregation.py new file mode 100644 index 0000000000..ed5b7e4d03 --- /dev/null +++ b/seaborn/tests/_stats/test_aggregation.py @@ -0,0 +1,71 @@ + +import pandas as pd + +import pytest +from pandas.testing import assert_frame_equal + +from seaborn._core.groupby import GroupBy +from seaborn._stats.aggregation import Agg + + +class TestAgg: + + @pytest.fixture + def df(self, rng): + + n = 30 + return pd.DataFrame(dict( + x=rng.uniform(0, 7, n).round(), + y=rng.normal(size=n), + color=rng.choice(["a", "b", "c"], n), + group=rng.choice(["x", "y"], n), + )) + + def get_groupby(self, df, orient): + + other = {"x": "y", "y": "x"}[orient] + cols = [c for c in df if c != other] + return GroupBy(cols) + + def test_default(self, df): + + ori = "x" + df = df[["x", "y"]] + gb = self.get_groupby(df, ori) + res = Agg()(df, gb, ori, {}) + + expected = df.groupby("x", as_index=False)["y"].mean() + assert_frame_equal(res, expected) + + def test_default_multi(self, df): + + ori = "x" + gb = self.get_groupby(df, ori) + res = Agg()(df, gb, ori, {}) + + grp = ["x", "color", "group"] + index = pd.MultiIndex.from_product( + [sorted(df["x"].unique()), df["color"].unique(), df["group"].unique()], + names=["x", "color", "group"] + ) + expected = ( + df + .groupby(grp) + .agg("mean") + .reindex(index=index) + .dropna() + .reset_index() + .reindex(columns=df.columns) + ) + assert_frame_equal(res, expected) + + @pytest.mark.parametrize("func", ["max", lambda x: float(len(x) % 2)]) + def test_func(self, df, func): + + ori = "x" + df = df[["x", "y"]] + gb = self.get_groupby(df, ori) + res = Agg(func)(df, gb, ori, {}) + + expected = df.groupby("x", as_index=False)["y"].agg(func) + assert_frame_equal(res, expected) diff --git a/seaborn/tests/_stats/test_histograms.py b/seaborn/tests/_stats/test_histograms.py new file mode 100644 index 0000000000..f67ae64b52 --- /dev/null +++ b/seaborn/tests/_stats/test_histograms.py @@ -0,0 +1,207 @@ + +import numpy as np +import pandas as pd + +import pytest +from numpy.testing import assert_array_equal + +from seaborn._core.groupby import GroupBy +from seaborn._stats.histograms import Hist + + +class TestHist: + + @pytest.fixture + def single_args(self): + + groupby = GroupBy(["group"]) + + class Scale: + scale_type = "continuous" + + return groupby, "x", {"x": Scale()} + + @pytest.fixture + def triple_args(self): + + groupby = GroupBy(["group", "a", "s"]) + + class Scale: + scale_type = "continuous" + + return groupby, "x", {"x": Scale()} + + def test_string_bins(self, long_df): + + h = Hist(bins="sqrt") + bin_kws = h._define_bin_params(long_df, "x", "continuous") + assert bin_kws["range"] == (long_df["x"].min(), long_df["x"].max()) + assert bin_kws["bins"] == int(np.sqrt(len(long_df))) + + def test_int_bins(self, long_df): + + n = 24 + h = Hist(bins=n) + bin_kws = h._define_bin_params(long_df, "x", "continuous") + assert bin_kws["range"] == (long_df["x"].min(), long_df["x"].max()) + assert bin_kws["bins"] == n + + def test_array_bins(self, long_df): + + bins = [-3, -2, 1, 2, 3] + h = Hist(bins=bins) + bin_kws = h._define_bin_params(long_df, "x", "continuous") + assert_array_equal(bin_kws["bins"], bins) + + def test_binwidth(self, long_df): + + binwidth = .5 + h = Hist(binwidth=binwidth) + bin_kws = h._define_bin_params(long_df, "x", "continuous") + n_bins = bin_kws["bins"] + left, right = bin_kws["range"] + assert (right - left) / n_bins == pytest.approx(binwidth) + + def test_binrange(self, long_df): + + binrange = (-4, 4) + h = Hist(binrange=binrange) + bin_kws = h._define_bin_params(long_df, "x", "continuous") + assert bin_kws["range"] == binrange + + def test_discrete_bins(self, long_df): + + h = Hist(discrete=True) + x = long_df["x"].astype(int) + bin_kws = h._define_bin_params(long_df.assign(x=x), "x", "continuous") + assert bin_kws["range"] == (x.min() - .5, x.max() + .5) + assert bin_kws["bins"] == (x.max() - x.min() + 1) + + def test_discrete_bins_from_nominal_scale(self, rng): + + h = Hist() + x = rng.randint(0, 5, 10) + df = pd.DataFrame({"x": x}) + bin_kws = h._define_bin_params(df, "x", "nominal") + assert bin_kws["range"] == (x.min() - .5, x.max() + .5) + assert bin_kws["bins"] == (x.max() - x.min() + 1) + + def test_count_stat(self, long_df, single_args): + + h = Hist(stat="count") + out = h(long_df, *single_args) + assert out["y"].sum() == len(long_df) + + def test_probability_stat(self, long_df, single_args): + + h = Hist(stat="probability") + out = h(long_df, *single_args) + assert out["y"].sum() == 1 + + def test_proportion_stat(self, long_df, single_args): + + h = Hist(stat="proportion") + out = h(long_df, *single_args) + assert out["y"].sum() == 1 + + def test_percent_stat(self, long_df, single_args): + + h = Hist(stat="percent") + out = h(long_df, *single_args) + assert out["y"].sum() == 100 + + def test_density_stat(self, long_df, single_args): + + h = Hist(stat="density") + out = h(long_df, *single_args) + assert (out["y"] * out["space"]).sum() == 1 + + def test_frequency_stat(self, long_df, single_args): + + h = Hist(stat="frequency") + out = h(long_df, *single_args) + assert (out["y"] * out["space"]).sum() == len(long_df) + + def test_cumulative_count(self, long_df, single_args): + + h = Hist(stat="count", cumulative=True) + out = h(long_df, *single_args) + assert out["y"].max() == len(long_df) + + def test_cumulative_proportion(self, long_df, single_args): + + h = Hist(stat="proportion", cumulative=True) + out = h(long_df, *single_args) + assert out["y"].max() == 1 + + def test_cumulative_density(self, long_df, single_args): + + h = Hist(stat="density", cumulative=True) + out = h(long_df, *single_args) + assert out["y"].max() == 1 + + def test_common_norm_default(self, long_df, triple_args): + + h = Hist(stat="percent") + out = h(long_df, *triple_args) + assert out["y"].sum() == pytest.approx(100) + + def test_common_norm_false(self, long_df, triple_args): + + h = Hist(stat="percent", common_norm=False) + out = h(long_df, *triple_args) + for _, out_part in out.groupby(["a", "s"]): + assert out_part["y"].sum() == pytest.approx(100) + + def test_common_norm_subset(self, long_df, triple_args): + + h = Hist(stat="percent", common_norm=["a"]) + out = h(long_df, *triple_args) + for _, out_part in out.groupby(["a"]): + assert out_part["y"].sum() == pytest.approx(100) + + def test_common_bins_default(self, long_df, triple_args): + + h = Hist() + out = h(long_df, *triple_args) + bins = [] + for _, out_part in out.groupby(["a", "s"]): + bins.append(tuple(out_part["x"])) + assert len(set(bins)) == 1 + + def test_common_bins_false(self, long_df, triple_args): + + h = Hist(common_bins=False) + out = h(long_df, *triple_args) + bins = [] + for _, out_part in out.groupby(["a", "s"]): + bins.append(tuple(out_part["x"])) + assert len(set(bins)) == len(out.groupby(["a", "s"])) + + def test_common_bins_subset(self, long_df, triple_args): + + h = Hist(common_bins=False) + out = h(long_df, *triple_args) + bins = [] + for _, out_part in out.groupby(["a"]): + bins.append(tuple(out_part["x"])) + assert len(set(bins)) == out["a"].nunique() + + def test_histogram_single(self, long_df, single_args): + + h = Hist() + out = h(long_df, *single_args) + hist, edges = np.histogram(long_df["x"], bins="auto") + assert_array_equal(out["y"], hist) + assert_array_equal(out["space"], np.diff(edges)) + + def test_histogram_multiple(self, long_df, triple_args): + + h = Hist() + out = h(long_df, *triple_args) + bins = np.histogram_bin_edges(long_df["x"], "auto") + for (a, s), out_part in out.groupby(["a", "s"]): + x = long_df.loc[(long_df["a"] == a) & (long_df["s"] == s), "x"] + hist, edges = np.histogram(x, bins=bins) + assert_array_equal(out_part["y"], hist) + assert_array_equal(out_part["space"], np.diff(edges)) diff --git a/seaborn/tests/_stats/test_regression.py b/seaborn/tests/_stats/test_regression.py new file mode 100644 index 0000000000..7facf75d32 --- /dev/null +++ b/seaborn/tests/_stats/test_regression.py @@ -0,0 +1,52 @@ + +import numpy as np +import pandas as pd + +import pytest +from numpy.testing import assert_array_equal, assert_array_almost_equal + +from seaborn._core.groupby import GroupBy +from seaborn._stats.regression import PolyFit + + +class TestPolyFit: + + @pytest.fixture + def df(self, rng): + + n = 100 + return pd.DataFrame(dict( + x=rng.normal(0, 1, n), + y=rng.normal(0, 1, n), + color=rng.choice(["a", "b", "c"], n), + group=rng.choice(["x", "y"], n), + )) + + def test_no_grouper(self, df): + + groupby = GroupBy(["group"]) + res = PolyFit(order=1, gridsize=100)(df[["x", "y"]], groupby, "x", {}) + + assert_array_equal(res.columns, ["x", "y"]) + + grid = np.linspace(df["x"].min(), df["x"].max(), 100) + assert_array_equal(res["x"], grid) + assert_array_almost_equal( + res["y"].diff().diff().dropna(), np.zeros(grid.size - 2) + ) + + def test_one_grouper(self, df): + + groupby = GroupBy(["group"]) + gridsize = 50 + res = PolyFit(gridsize=gridsize)(df, groupby, "x", {}) + + assert res.columns.to_list() == ["x", "y", "group"] + + ngroups = df["group"].nunique() + assert_array_equal(res.index, np.arange(ngroups * gridsize)) + + for _, part in res.groupby("group"): + grid = np.linspace(part["x"].min(), part["x"].max(), gridsize) + assert_array_equal(part["x"], grid) + assert part["y"].diff().diff().dropna().abs().gt(0).all() diff --git a/seaborn/tests/test_axisgrid.py b/seaborn/tests/test_axisgrid.py index 96c8797e27..d7858464e3 100644 --- a/seaborn/tests/test_axisgrid.py +++ b/seaborn/tests/test_axisgrid.py @@ -11,7 +11,7 @@ except ImportError: import pandas.util.testing as tm -from .._core import categorical_order +from .._oldcore import categorical_order from .. import rcmod from ..palettes import color_palette from ..relational import scatterplot diff --git a/seaborn/tests/test_categorical.py b/seaborn/tests/test_categorical.py index 67a225f56f..fd37b809b6 100644 --- a/seaborn/tests/test_categorical.py +++ b/seaborn/tests/test_categorical.py @@ -18,8 +18,8 @@ from .. import categorical as cat from .. import palettes -from .._core import categorical_order from ..external.version import Version +from .._oldcore import categorical_order from ..categorical import ( _CategoricalPlotterNew, Beeswarm, diff --git a/seaborn/tests/test_core.py b/seaborn/tests/test_core.py index 423c814c67..31d177e5c7 100644 --- a/seaborn/tests/test_core.py +++ b/seaborn/tests/test_core.py @@ -9,7 +9,7 @@ from pandas.testing import assert_frame_equal from ..axisgrid import FacetGrid -from .._core import ( +from .._oldcore import ( SemanticMapping, HueMapping, SizeMapping, @@ -144,11 +144,6 @@ def test_hue_map_categorical(self, wide_df, long_df): assert m.palette == palette assert m.lookup_table == palette - # Test dict with missing keys - palette = dict(zip(wide_df.columns[:-1], colors)) - with pytest.raises(ValueError): - HueMapping(p, palette=palette) - # Test dict with missing keys palette = dict(zip(wide_df.columns[:-1], colors)) with pytest.raises(ValueError): diff --git a/seaborn/tests/test_distributions.py b/seaborn/tests/test_distributions.py index fd2d7d2af5..6691b7ff03 100644 --- a/seaborn/tests/test_distributions.py +++ b/seaborn/tests/test_distributions.py @@ -13,7 +13,7 @@ color_palette, light_palette, ) -from .._core import ( +from .._oldcore import ( categorical_order, ) from .._statistics import ( diff --git a/seaborn/utils.py b/seaborn/utils.py index 69a4634448..e14e80121c 100644 --- a/seaborn/utils.py +++ b/seaborn/utils.py @@ -624,6 +624,10 @@ def load_dataset(name, cache=True, data_home=None, **kws): df["cut"], ["Ideal", "Premium", "Very Good", "Good", "Fair"], ) + elif name == "taxis": + df["pickup"] = pd.to_datetime(df["pickup"]) + df["dropoff"] = pd.to_datetime(df["dropoff"]) + return df diff --git a/setup.cfg b/setup.cfg index 5fe3a51f96..81e966b885 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,3 +5,10 @@ license_file = LICENSE max-line-length = 88 exclude = seaborn/cm.py,seaborn/external ignore = E741,F522,W503 + +[mypy] +# Currently this ignores pandas and matplotlib +# We may want to make custom stub files for the parts we use +# I have found the available third party stubs to be less +# complete than they would need to be useful +ignore_missing_imports = True \ No newline at end of file diff --git a/setup.py b/setup.py index 76cd05730d..530705d2dc 100644 --- a/setup.py +++ b/setup.py @@ -30,24 +30,31 @@ PYTHON_REQUIRES = ">=3.7" INSTALL_REQUIRES = [ - 'numpy>=1.16', - 'pandas>=0.24', - 'matplotlib>=3.0', + 'numpy>=1.17', + 'pandas>=0.25', + 'matplotlib>=3.1', + 'typing_extensions; python_version < "3.8"', ] EXTRAS_REQUIRE = { 'all': [ - 'scipy>=1.2', - 'statsmodels>=0.9', + 'scipy>=1.3', + 'statsmodels>=0.10', ] } PACKAGES = [ 'seaborn', + 'seaborn._core', + 'seaborn._marks', + 'seaborn._stats', 'seaborn.colors', 'seaborn.external', 'seaborn.tests', + 'seaborn._core', + 'seaborn._marks', + 'seaborn._stats', ] CLASSIFIERS = [