Skip to content

Commit

Permalink
🩹 Fix broken 'plot_data_overview' (#42)
Browse files Browse the repository at this point in the history
* 🩹 Moved imports needed at runtime up from TYPE_CHECKING if block

* 👌 Changed defaul figsize of plot_data_overview to (15, 10))

* 👌 Changed dataset in plot_data_overview to also be allowed to be a path

That way there is consistent behavior across high level plotting functions.
  • Loading branch information
s-weigand authored Oct 24, 2021
1 parent fe6e00f commit 1e1ce26
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions pyglotaran_extras/plotting/plot_data.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,37 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import cast

import matplotlib.pyplot as plt
from matplotlib.axis import Axis

from pyglotaran_extras.io.load_data import load_data
from pyglotaran_extras.plotting.plot_svd import plot_lsv_data
from pyglotaran_extras.plotting.plot_svd import plot_rsv_data
from pyglotaran_extras.plotting.plot_svd import plot_sv_data

__all__ = ["plot_data_overview"]

if TYPE_CHECKING:
from typing import cast

import xarray as xr
from matplotlib.axis import Axis
from matplotlib.figure import Figure
from matplotlib.pyplot import Axes

from pyglotaran_extras.types import DatasetConvertible


def plot_data_overview(
dataset: xr.Dataset,
dataset: DatasetConvertible,
title: str = "Data overview",
linlog: bool = False,
linthresh: float = 1,
figsize: tuple[int, int] = (30, 15),
figsize: tuple[int, int] = (15, 10),
) -> tuple[Figure, Axes]:
"""Plot data as filled contour plot and SVD components.
Parameters
----------
dataset : Dataset
dataset : DatasetConvertible
Dataset containing data and SVD of the data.
title : str, optional
Title to add to the figure., by default "Data overview"
Expand All @@ -45,6 +46,8 @@ def plot_data_overview(
tuple[Figure, Axes]
Figure and axes which can then be refined by the user.
"""
dataset = load_data(dataset)

fig = plt.figure(figsize=figsize)
data_ax = cast(Axis, plt.subplot2grid((4, 3), (0, 0), colspan=3, rowspan=3, fig=fig))
lsv_ax = cast(Axis, plt.subplot2grid((4, 3), (3, 0), fig=fig))
Expand Down

0 comments on commit 1e1ce26

Please sign in to comment.