diff --git a/pyopenms_viz/BasePlotter.py b/pyopenms_viz/BasePlotter.py
deleted file mode 100644
index d32aa8fe..00000000
--- a/pyopenms_viz/BasePlotter.py
+++ /dev/null
@@ -1,124 +0,0 @@
-from abc import ABC, abstractmethod
-from dataclasses import dataclass, field
-from enum import Enum
-from typing import Literal, List, Tuple
-import matplotlib.pyplot as plt
-import numpy as np
-import pandas as pd
-
-class Engine(Enum):
- PLOTLY = 1
- BOKEH = 2
- MATPLOTLIB = 3
-
-# A colorset suitable for color blindness
-class Colors(str, Enum):
- BLUE = "#4575B4"
- RED = "#D73027"
- LIGHTBLUE = "#91BFDB"
- ORANGE = "#FC8D59"
- PURPLE = "#7B2C65"
- YELLOW = "#FCCF53"
- DARKGRAY = "#555555"
- LIGHTGRAY = "#BBBBBB"
-
-@dataclass(kw_only=True)
-class LegendConfig:
- loc: str = 'right'
- title: str = 'Legend'
- fontsize: int = 10
- show: bool = True
- onClick: Literal["hide", "mute"] = 'mute' # legend click policy, only valid for bokeh
- bbox_to_anchor: Tuple[float, float] = (1.2, 0.5) # for fine control legend positioning in matplotlib
-
- @staticmethod
- def _matplotlibLegendLocationMapper(loc):
- '''
- Maps the legend location to the matplotlib equivalent
- '''
- loc_mapper = {'right':'center right', 'left':'center left', 'above':'upper center', 'below':'lower center'}
- return loc_mapper[loc]
-
-@dataclass(kw_only=True)
-class _BasePlotterConfig(ABC):
- title: str = "1D Plot"
- xlabel: str = "X-axis"
- ylabel: str = "Y-axis"
- engine: Literal["PLOTLY", "BOKEH", "MATPLOTLIB"] = "PLOTLY"
- height: int = 500
- width: int = 500
- relative_intensity: bool = False
- show_legend: bool = True
- show: bool = True
- colormap: str = 'viridis'
- legend: LegendConfig = field(default_factory=LegendConfig)
- grid: bool = True
- lineStyle: str = 'solid'
- lineWidth: float = 1
-
- @property
- def engine_enum(self):
- return Engine[self.engine]
-
-
-# Abstract Class for Plotting
-class _BasePlotter(ABC):
- def __init__(self, config: _BasePlotterConfig) -> None:
- self.config = config
- self.fig = None # holds the figure object
- self.main_palette = None # holds the main color palette
- self.feature_palette = None # holds the feature color palette
-
- def updateConfig(self, **kwargs):
- for key, value in kwargs.items():
- if hasattr(self.config, key):
- setattr(self.config, key, value)
- else:
- raise ValueError(f"Invalid config setting: {key}")
-
- @staticmethod
- def generate_colors(colormap, n):
- # Use Matplotlib's built-in color palettes
- cmap = plt.get_cmap(colormap, n)
- colors = cmap(np.linspace(0, 1, n))
-
- # Convert colors to hex format
- hex_colors = ['#{:02X}{:02X}{:02X}'.format(int(r*255), int(g*255), int(b*255)) for r, g, b, _ in colors]
-
- return hex_colors
-
- def _get_n_grayscale_colors(self, n: int) -> List[str]:
- """Returns n evenly spaced grayscale colors in hex format."""
- hex_list = []
- for v in np.linspace(50, 200, n):
- hex = "#"
- for _ in range(3):
- hex += f"{int(round(v)):02x}"
- hex_list.append(hex)
- return hex_list
-
- def plot(self, data : pd.DataFrame , featureData : pd.DataFrame = None, **kwargs):
-
- # TODO: Assert throws errors at startup if using a test streamlit app
- # ### Assert color palettes are set
- # assert(self.main_palette is not None)
- # assert(self.feature_palette is not None if featureData is not None else True)
-
- if self.config.engine_enum == Engine.PLOTLY:
- return self._plotPlotly(data, **kwargs)
- elif self.config.engine_enum == Engine.BOKEH:
- return self._plotBokeh(data, **kwargs)
- else: # self.config.engine_enum == Engine.MATPLOTLIB:
- return self._plotMatplotlib(data, **kwargs)
-
- @abstractmethod
- def _plotBokeh(self, data, **kwargs):
- pass
-
- @abstractmethod
- def _plotPlotly(self, data, **kwargs):
- pass
-
- @abstractmethod
- def _plotMatplotlib(self, data, **kwargs):
- pass
diff --git a/pyopenms_viz/ChromatogramPlotter.py b/pyopenms_viz/ChromatogramPlotter.py
deleted file mode 100644
index 21a51652..00000000
--- a/pyopenms_viz/ChromatogramPlotter.py
+++ /dev/null
@@ -1,819 +0,0 @@
-from typing import List, Tuple
-from pandas import DataFrame
-import matplotlib.pyplot as plt
-from matplotlib import cm
-import numpy as np
-import pandas as pd
-from dataclasses import dataclass, field, fields
-from typing import Literal
-
-from .BasePlotter import _BasePlotter, _BasePlotterConfig, Engine, LegendConfig
-from .util._decorators import filter_unexpected_fields
-
-@dataclass(kw_only=True)
-class ChromatogramFeatureConfig:
- def default_legend_factory():
- return LegendConfig(title="Features", loc='right', bbox_to_anchor=(1.5, 0.5))
-
- colormap: str = "viridis"
- lineWidth: float = 1
- lineStyle: str = 'solid'
- legend: LegendConfig = field(default_factory=default_legend_factory)
-
-@filter_unexpected_fields
-@dataclass(kw_only=True)
-class ChromatogramPlotterConfig(_BasePlotterConfig):
- def default_legend_factory():
- return LegendConfig(title="Transitions")
-
- # Plot Aesthetics
- title: str = "Chromatogram Plot"
- xlabel: str = "Retention Time"
- ylabel: str = "Intensity"
- x_axis_col: str = "rt"
- y_axis_col: str = "int"
- x_axis_location: str = "below"
- y_axis_location: str = "left"
- min_border: int = 0
- show: bool = True
- lineWidth: float = 1
- lineStyle: str = 'solid'
- plot_type: str = "lineplot"
- add_marginals: bool = False
- featureConfig: ChromatogramFeatureConfig = field(default_factory=ChromatogramFeatureConfig)
- legend: LegendConfig = field(default_factory=default_legend_factory)
-
- # Data Specific Attributes
- ion_mobility: bool = False # if True, plot ion mobility as well in a heatmap
-
-
-class ChromatogramPlotter(_BasePlotter):
-
- def __init__(self, config: _BasePlotterConfig, **kwargs) -> None:
- super().__init__(config, **kwargs)
-
- @staticmethod
- def rgb_to_hex(rgb):
- """
- Converts an RGB color value to its corresponding hexadecimal representation.
-
- Args:
- rgb (tuple): A tuple containing the RGB values as floats between 0 and 1.
-
- Returns:
- str: The hexadecimal representation of the RGB color.
-
- Example:
- >>> rgb_to_hex((0.5, 0.75, 1.0))
- '#7fbfff'
- """
- return "#{:02x}{:02x}{:02x}".format(int(rgb[0] * 255), int(rgb[1] * 255), int(rgb[2] * 255))
-
-
- @staticmethod
- def _get_data_ranges(arr: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray, float, float, float, float, float, float]:
- """
- Get parameters for plotting.
-
- Args:
- arr (pd.DataFrame): The input DataFrame.
-
- Returns:
- Tuple[np.ndarray, np.ndarray, float, float, float, float, float, float]: The parameters for plotting.
- """
- im_arr = arr.index.to_numpy()
- rt_arr = arr.columns.to_numpy()
-
- dw_main = rt_arr.max() - rt_arr.min()
- dh_main = im_arr.max() - im_arr.min()
-
- rt_min, rt_max, im_min, im_max = rt_arr.min(), rt_arr.max(), im_arr.min(), im_arr.max()
-
- return im_arr, rt_arr, dw_main, dh_main, rt_min, rt_max, im_min, im_max
-
- @staticmethod
- def _prepare_array(arr: pd.DataFrame) -> np.ndarray:
- """
- Prepare the array for plotting. Also performs equalization and/or smoothing if specified in the configuration.
-
- Args:
- arr (pd.DataFrame): The input DataFrame.
-
- Returns:
- np.ndarray: The prepared array.
- """
- arr = arr.to_numpy()
- arr[np.isnan(arr)] = 0
-
- return arr
-
- @staticmethod
- def _get_data_as_two_dimenstional_array(data: pd.DataFrame) -> np.ndarray:
- feat_arrs = {ion_trace:grp_df.pivot_table(index='im', columns='rt', values='int', aggfunc="sum") for ion_trace, grp_df in data.groupby('Annotation')}
- return feat_arrs
-
- @staticmethod
- def _integrate_data_along_dim(data: pd.DataFrame, col_dim: str) -> pd.DataFrame:
- # TODO: Double check which columns are required
- grouped = data.fillna({'native_id': 'NA'}).groupby(['native_id', 'ms_level', 'precursor_mz', 'Annotation', 'product_mz', col_dim])['int'].sum().reset_index()
- return grouped
-
- ### assume that the chromatogram have the following columns: intensity, time
- ### optional column to be used: annotation
- ### assume that the chromatogramFeatures have the following columns: left_width, right_width (optional columns: area, q_value)
- def plot(self, chromatogram, chromatogramFeatures = None, **kwargs):
- #### General Data Processing before plotting ####
- # sort by q_value if available
- if chromatogramFeatures is not None:
- if "q_value" in chromatogramFeatures.columns:
- chromatogramFeatures = chromatogramFeatures.sort_values(by="q_value")
-
- #### compute apex intensity for features if not already computed
- if chromatogramFeatures is not None:
- if "apexIntensity" not in chromatogramFeatures.columns:
- all_apexIntensity = []
- for _, feature in chromatogramFeatures.iterrows():
- apexIntensity = 0
- for _, row in chromatogram.iterrows():
- if row["rt"] >= feature["leftWidth"] and row["rt"] <= feature["rightWidth"] and row["int"] > apexIntensity:
- apexIntensity = row["int"]
- all_apexIntensity.append(apexIntensity)
-
- chromatogramFeatures["apexIntensity"] = all_apexIntensity
-
- # compute colormaps based on the number of transitions and features
- self.main_palette = self.generate_colors(self.config.colormap, len(chromatogram["Annotation"].unique()) if 'Annotation' in chromatogram.columns else 1)
- self.feature_palette = self.generate_colors(self.config.featureConfig.colormap, len(chromatogramFeatures)) if chromatogramFeatures is not None else None
-
- return super().plot(chromatogram, chromatogramFeatures, **kwargs)
-
- def _plotBokeh(self, data: DataFrame, chromatogramFeatures: DataFrame = None):
-
- def _plotLines(self, data: pd.DataFrame, chromatogramFeatures: DataFrame = None):
- from bokeh.plotting import figure, show
- from bokeh.models import ColumnDataSource, Legend
-
- # Tooltips for interactive information
- TOOLTIPS = [
- ("index", "$index"),
- ("Retention Time", "@rt{0.2f}"),
- ("Intensity", "@int{0.2f}"),
- ("m/z", "@mz{0.4f}")
- ]
-
- if "Annotation" in data.columns:
- TOOLTIPS.append(("Annotation", "@Annotation"))
- if "product_mz" in data.columns:
- TOOLTIPS.append(("Target m/z", "@product_mz{0.4f}"))
-
- # Create the Bokeh plot
- p = figure(title=self.config.title,
- x_axis_label=self.config.xlabel,
- y_axis_label=self.config.ylabel,
- x_axis_location=self.config.x_axis_location,
- y_axis_location=self.config.y_axis_location,
- width=self.config.width,
- height=self.config.height,
- tooltips=TOOLTIPS)
-
- # Create a legend
- legend = Legend()
-
- # Create a list to store legend items
- if 'Annotation' in data.columns:
- legend_items = []
- i = 0
- for annotation, group_df in data.groupby('Annotation'):
- source = ColumnDataSource(group_df)
- line = p.line(x=self.config.x_axis_col, y=self.config.y_axis_col, source=source, line_width=self.config.lineWidth, line_color=self.main_palette[i], line_dash=self.config.lineStyle)
- legend_items.append((annotation, [line]))
- i+=1
-
- # Add legend items to the legend
- legend.items = legend_items
-
- # Add the legend to the plot
- p.add_layout(legend, self.config.legend.loc)
-
- p.legend.click_policy=self.config.legend.onClick
- p.legend.title = self.config.legend.title
- p.legend.label_text_font_size = str(self.config.legend.fontsize) + 'pt'
-
- else:
- source = ColumnDataSource(data)
- line = p.line(x=self.config.x_axis_col, y=self.config.y_axis_col, source=source, line_width=self.config.lineWidth, line_color=self.main_palette[0], line_alpha=0.5, line_dash=self.config.lineStyle)
- # Customize the plot
- p.grid.visible = self.config.grid
- p.toolbar_location = "above" #NOTE: This is hardcoded
-
- ##### Plotting chromatogram features #####
- if chromatogramFeatures is not None:
-
- for idx, (_, feature) in enumerate(chromatogramFeatures.iterrows()):
-
- leftWidth_line = p.line(x=[feature['leftWidth']] * 2, y=[0, feature['apexIntensity']], width=self.config.featureConfig.lineWidth, color=self.feature_palette[idx], line_dash=self.config.featureConfig.lineStyle )
- rightWidth_line = p.line(x=[feature['rightWidth']] * 2, y=[0, feature['apexIntensity']], width=self.config.featureConfig.lineWidth, color=self.feature_palette[idx], line_dash = self.config.featureConfig.lineStyle)
-
- if self.config.featureConfig.legend.show:
- feature_legend_items = []
- if "q_value" in chromatogramFeatures.columns:
- legend_msg = f'Feature {idx} (q={feature["q_value"]:.2f})'
- else:
- legend_msg = f'Feature {idx}'
- feature_legend_items.append((legend_msg, [leftWidth_line]))
-
- legend = Legend(items=feature_legend_items, title=self.config.legend.title )
- p.add_layout(legend, self.config.legend.loc)
-
- if self.config.show:
- show(p)
-
- return p
-
- def _plotHeatmap(self, data: pd.DataFrame):
- from bokeh.plotting import figure, show
- from bokeh.models import ColumnDataSource, Legend, HoverTool, CrosshairTool
-
- AFMHOT_CMAP = [self.rgb_to_hex(cm.afmhot_r(i)[:3]) for i in range(256)]
-
- # Get the data as a two-dimensional array
- feat_arrs = self._get_data_as_two_dimenstional_array(data)
-
- # Get the data ranges
- im_range = data.groupby('Annotation')['im'].agg(lambda x: x.max() - x.min())
- rt_range = data.groupby('Annotation')['rt'].agg(lambda x: x.max() - x.min())
-
- p_hm_legends = []
- # Create a legend
- legend_hm = Legend()
- p = figure(x_range=(data.rt.min(), data.rt.max()),
- y_range=(data.im.min(), data.im.max()),
- x_axis_label="Retention Time [sec]",
- y_axis_label="Ion Mobility",
- # y_axis_label=None,
- # y_axis_location = None,
- width=700,
- height=700,
- min_border=0
- )
-
- for annotation, df_wide in feat_arrs.items():
- arr = self._prepare_array(df_wide)
- heatmap_img = p.image(image=[arr], x=data.rt.min(), y=data.im.min(), dw=rt_range.mean(), dh=im_range.min(), palette=AFMHOT_CMAP)
- p_hm_legends.append((annotation, [heatmap_img]))
-
- # Add legend items to the legend
- legend_hm.items = p_hm_legends
-
- # Add the legend to the plot
- p.add_layout(legend_hm, 'right')
-
- hover = HoverTool(renderers=[heatmap_img], tooltips=[("Value", "@image")])
- linked_crosshair = CrosshairTool(dimensions="both")
- p.add_tools(hover)
- p.add_tools(linked_crosshair)
-
- p.grid.visible = False
-
- if self.config.add_marginals:
- # Store original config values that need to be changed for marginals
- show_org = self.config.show
- self.config.show = False
-
- # Integrate the data along the retention time dimension
- rt_integrated = self._integrate_data_along_dim(data, 'rt')
- # Generate a lineplot for XIC
- self.config.y_axis_location = "right"
- p_xic = _plotLines(self, rt_integrated, chromatogramFeatures)
-
- # Link range of XIC plot with the main plot
- p_xic.x_range = p.x_range
- p_xic.width = p.width
-
- # Modify labels
- p_xic.title = "Integrated Ion Chromatogram"
- # Hide x-axis for grouped plot
- p_xic.xaxis.visible = False
-
- # Make border 0
- p_xic.min_border = 0
-
- # Integrate the data along the ion mobility dimension
- im_integrated = self._integrate_data_along_dim(data, 'im')
-
- # Generate a lineplot for XIM
- self.config.x_axis_col = 'int'
- self.config.y_axis_col = 'im'
- self.config.y_axis_location = "left"
- self.config.legend.loc = 'below'
- p_xim = _plotLines(self, im_integrated)
-
- # Link y-axis with heatmap
- p_xim.y_range = p.y_range
- p_xim.height = p.height
-
- p_xim.legend.orientation = "horizontal"
-
- # Flip x-axis range
- p_xim.x_range.flipped = True
- # p_mobi.y_range
-
- # Modify labels
- p_xim.title = "Integrated Ion Mobilogram"
- p_xim.title_location = "below"
- p_xim.xaxis.axis_label = "Intensity"
- p_xim.yaxis.axis_label = "Ion Mobility"
-
- # Make border 0
- p_xim.min_border = 0
-
- # Heatmap mod
- p.yaxis.visible = False
-
- # Construct Marginal Plot
- from bokeh.layouts import gridplot
-
- # Combine the plots into a grid layout
- layout = gridplot([[None, p_xic], [p_xim, p]], sizing_mode="stretch_both")
-
- # Reset the config values to org
- self.config.show = show_org
-
- if self.config.show:
- show(layout)
-
- return layout
-
- if self.config.show:
- show(p)
-
- return p
-
-
-
- if self.config.plot_type == "lineplot":
- return _plotLines(self, data, chromatogramFeatures)
- elif self.config.plot_type == "heatmap":
- return _plotHeatmap(self, data)
- else:
- raise ValueError(f"Invalid plot type: {type}")
-
- def _plotPlotly(self, data: DataFrame, chromatogramFeatures: DataFrame = None):
-
- def _plotLines(self, data: pd.DataFrame, chromatogramFeatures: DataFrame = None):
-
- import plotly.graph_objects as go
-
- # Create a trace for each unique annotation
- traces = []
- if "Annotation" in data.columns:
- for i, (annotation, group_df) in enumerate(data.groupby('Annotation')):
- trace = go.Scatter(
- x=group_df[self.config.x_axis_col],
- y=group_df[self.config.y_axis_col],
- mode='lines',
- name=annotation,
- line=dict(
- color=self.main_palette[i],
- width=self.config.lineWidth,
- dash=self.config.lineStyle
- )
- )
- traces.append(trace)
- else:
- trace = go.Scatter(
- x=data[self.config.x_axis_col],
- y=data[self.config.y_axis_col],
- mode='lines',
- name="Transition",
- line=dict(
- color=self.main_palette[0],
- width=self.config.lineWidth,
- dash=self.config.lineStyle
- ))
- traces.append(trace)
-
-
- # Create the Plotly figure
- fig = go.Figure(data=traces)
- fig.update_layout(
- title=self.config.title,
- xaxis_title=self.config.xlabel,
- yaxis_title=self.config.ylabel,
- width=self.config.width,
- height=self.config.height,
- legend_title="Transition",
- legend_font_size=self.config.legend.fontsize
- )
-
- available_columns = data.columns.tolist()
- available_columns = data.columns.tolist()
- custom_hover_data = [data[col] for col in ["index", "mz"] if col in available_columns]
-
- hover_template_parts = [
- "Index: %{customdata[0]}",
- "Retention Time: %{x:.2f}",
- "Intensity: %{y:.2f}",
- ]
-
- if "mz" in available_columns:
- hover_template_parts.append("m/z: %{customdata[1]:.4f}")
- custom_hover_data_index = 2
- else:
- custom_hover_data_index = 1
-
- if "Annotation" in available_columns:
- hover_template_parts.append("Annotation: %{customdata[" + str(custom_hover_data_index) + "]}")
- custom_hover_data.append(data["Annotation"])
-
- hovertemplate = "
".join(hover_template_parts)
-
- fig.update_traces(
- hovertemplate=hovertemplate,
- customdata=np.column_stack(custom_hover_data)
- )
-
- # Customize the plot
- fig.update_layout(
- plot_bgcolor='white',
- paper_bgcolor='white',
- xaxis_showgrid=True,
- yaxis_showgrid=True,
- xaxis_zeroline=False,
- yaxis_zeroline=False
- )
-
- ##### Plotting chromatogram features #####
- if chromatogramFeatures is not None:
- for idx, (_, feature) in enumerate(chromatogramFeatures.iterrows()):
- feature_group = f"Feature {idx}"
-
- feature_boundary_box = fig.add_shape(type='rect',
- x0=feature['leftWidth'],
- y0=0,
- x1=feature['rightWidth'],
- y1=feature['apexIntensity'],
- legendgroup=feature_group,
- legendgrouptitle_text="Features",
- showlegend=self.config.featureConfig.legend.show,
- name=f'Feature {idx}' if "q_value" not in chromatogramFeatures.columns else f'Feature {idx} (q={feature["q_value"]:.2f})',
- line=dict(
- color=self.feature_palette[idx],
- width=self.config.featureConfig.lineWidth,
- dash=self.config.featureConfig.lineStyle)
- )
-
- if self.config.show:
- fig.show()
- return fig
-
- def _plotHeatmap(self, data: pd.DataFrame):
- import plotly.graph_objects as go
-
- AFMHOT_CMAP = [self.rgb_to_hex(cm.afmhot_r(i)[:3]) for i in range(256)]
-
- # Get the data as a two-dimensional array
- feat_arrs = self._get_data_as_two_dimenstional_array(data)
-
- # Create the Plotly figure
- fig = go.Figure()
-
- # Create a trace for each unique annotation
- for annotation, df_wide in feat_arrs.items():
- arr = self._prepare_array(df_wide)
- fig.add_trace(go.Heatmap(z=arr, x=df_wide.columns, y=df_wide.index, colorscale=AFMHOT_CMAP, coloraxis="coloraxis", name=annotation, showlegend=True))
-
- fig.update_layout(coloraxis = {'colorscale':AFMHOT_CMAP})
-
- # fig.update_traces(name="Transitions", showlegend=True)
-
- # Customize the plot
- fig.update_layout(
- title=self.config.title,
- xaxis_title=self.config.xlabel,
- yaxis_title=self.config.ylabel,
- width=self.config.width,
- height=self.config.height,
- legend=dict(
- orientation="h", # Set the legend orientation to horizontal
- y=-0.2, # Adjust the vertical position of the legend
- x=0.5, # Adjust the horizontal position of the legend
- xanchor="center" # Center the legend horizontally
- )
- )
-
- if self.config.add_marginals:
- from plotly.subplots import make_subplots
-
- # Store original config values that need to be changed for marginals
- show_org = self.config.show
- self.config.show = False
-
- # Integrate the data along the retention time dimension
- rt_integrated = self._integrate_data_along_dim(data, 'rt')
- # Generate a lineplot for XIC
- self.config.y_axis_location = "right"
- fig_xic = _plotLines(self, rt_integrated, chromatogramFeatures)
- fig_xic.update_layout(title="Integrated Ion Chromatogram")
- fig_xic.update_xaxes(visible=False)
-
- # Integrate the data along the ion mobility dimension
- im_integrated = self._integrate_data_along_dim(data, 'im')
- # Generate a lineplot for XIM
- self.config.x_axis_col = 'int'
- self.config.y_axis_col = 'im'
- self.config.y_axis_location = "left"
- fig_xim = _plotLines(self, im_integrated)
- fig_xim.update_layout(title="Integrated Ion Mobilogram")
- fig_xim.update_xaxes(range=[0, im_integrated['int'].max()])
- fig_xim.update_yaxes(range=[im_integrated['im'].min(), im_integrated['im'].max()])
- fig_xim.update_layout(xaxis_title="Intensity", yaxis_title="Ion Mobility")
-
- # Create a figure with subplots
- fig_m = make_subplots(
- rows=2, cols=2,
- shared_xaxes=True, shared_yaxes=True,
- vertical_spacing=0, horizontal_spacing=0,
- subplot_titles=(None, "Integrated Ion Chromatogram", "Integrated Ion Mobilogram", None),
- specs=[[{}, {"type": "xy", "rowspan": 1, "secondary_y":True}],
- [{"type": "xy", "rowspan": 1, "secondary_y":False}, {"type": "xy", "rowspan": 1, "secondary_y":False}]]
- )
-
- # Add the heatmap to the first row
- for trace in fig.data:
- trace.showlegend = False
- trace.legendgroup = trace.name
- fig_m.add_trace(trace, row=2, col=2, secondary_y=False)
-
- # Update the heatmap layout
- fig_m.update_layout(fig.layout)
- fig_m.update_yaxes(row=2, col=2, secondary_y=False)
-
- # Add the XIC plot to the second row
- for trace in fig_xic.data:
- trace.legendgroup = trace.name
- fig_m.add_trace(trace, row=1, col=2, secondary_y=True)
-
- # Update the XIC layout
- fig_m.update_layout(fig_xic.layout)
-
- # Make the y-axis of fig_xic independent
- fig_m.update_yaxes(overwrite=True, row=1, col=2, secondary_y=True)
-
- # Manually adjust the domain of secondary y-axis to only span the first row of the subplot
- fig_m['layout']['yaxis3']['domain'] = [0.5, 1.0]
-
- # Add the XIM plot to the second row
- for trace in fig_xim.data:
- trace.showlegend = False
- trace.legendgroup = trace.name
- fig_m.add_trace(trace, row=2, col=1)
-
- # Update the XIM layout
- fig_m.update_layout(fig_xim.layout)
-
- # Make the x-axis of fig_xim independent
- fig_m.update_xaxes(overwrite=True, row=2, col=1)
-
- # Reverse the x-axis range for the XIM subplot
- fig_m.update_xaxes(autorange="reversed", row=2, col=1)
-
- # Update xaxis properties
- fig_m.update_xaxes(title_text="Retention Time [sec]", row=2, col=2)
- fig_m.update_xaxes(title_text="Intensity", row=2, col=1)
-
- # Update yaxis properties
- fig_m.update_yaxes(title_text="Intensity", row=1, col=2)
- fig_m.update_yaxes(title_text="Ion Mobility", row=2, col=1)
-
- # Update the layout
- fig_m.update_layout(
- height=800,
- width=1200,
- title=self.config.title
- )
-
- # Reset the config values to org
- self.config.show = show_org
-
- if self.config.show:
- fig_m.show()
- return fig_m
-
- if self.config.show:
- fig.show()
- return fig
-
- if self.config.plot_type == "lineplot":
- return _plotLines(self, data, chromatogramFeatures)
- elif self.config.plot_type == "heatmap":
- return _plotHeatmap(self, data)
- else:
- raise ValueError(f"Invalid plot type: {type}")
-
- def _plotMatplotlib(self, data: DataFrame, chromatogramFeatures: DataFrame = None):
-
- def _plotLines(self, data: pd.DataFrame, chromatogramFeatures: DataFrame = None, ax=None):
- import matplotlib.pyplot as plt
- from matplotlib.lines import Line2D
-
- # Create a figure and axis
- if ax is None:
- fig, ax = plt.subplots(figsize=(self.config.width/100, self.config.height/100), dpi=100)
- else:
- fig = ax.get_figure()
-
-
- # Set plot title and axis labels
- ax.set_title(self.config.title)
- ax.set_xlabel(self.config.xlabel)
- ax.set_ylabel(self.config.ylabel)
-
- # Create a legend
- legend_lines = []
- legend_labels = []
-
- # Plot each unique annotation
- if "Annotation" in data.columns:
- for i, (annotation, group_df) in enumerate(data.groupby('Annotation')):
- line, = ax.plot(group_df[self.config.x_axis_col], group_df[self.config.y_axis_col], color=self.main_palette[i], linewidth=self.config.lineWidth, ls=self.config.lineStyle)
- legend_lines.append(line)
- legend_labels.append(annotation)
-
- # Add legend
- matplotlibLegendLoc= LegendConfig._matplotlibLegendLocationMapper(self.config.legend.loc)
- legend = ax.legend(legend_lines, legend_labels, loc=matplotlibLegendLoc, bbox_to_anchor=self.config.legend.bbox_to_anchor, title=self.config.legend.title, prop={'size': self.config.legend.fontsize})
- legend.get_title().set_fontsize(str(self.config.legend.fontsize))
-
- else: # only one transition
- line, = ax.plot(data[self.config.x_axis_col], data[self.config.y_axis_col], color=self.main_palette[0], linewidth=self.config.lineWidth, ls=self.config.lineStyle)
-
- # Customize the plot
- ax.grid(self.config.grid)
-
- ## add 10% padding to the plot
- padding = (data[self.config.y_axis_col].max() - data[self.config.y_axis_col].min() ) * 0.1
- ax.set_xlim(data[self.config.x_axis_col].min(), data[self.config.x_axis_col].max())
- ax.set_ylim(data[self.config.y_axis_col].min(), data[self.config.y_axis_col].max() + padding)
-
- ##### Plotting chromatogram features #####
- if chromatogramFeatures is not None:
- ax.add_artist(legend)
-
- for idx, (_, feature) in enumerate(chromatogramFeatures.iterrows()):
-
- ax.vlines(x=feature['leftWidth'], ymin=0, ymax=feature['apexIntensity'], lw=self.config.featureConfig.lineWidth, color=self.feature_palette[idx], ls=self.config.featureConfig.lineStyle)
- ax.vlines(x=feature['rightWidth'], ymin=0, ymax=feature['apexIntensity'], lw=self.config.featureConfig.lineWidth, color=self.feature_palette[idx], ls=self.config.featureConfig.lineStyle)
-
- if self.config.featureConfig.legend.show:
- custom_lines = [Line2D([0], [0], color=self.feature_palette[i], lw=2) for i in range(len(chromatogramFeatures))]
- if "q_value" in chromatogramFeatures.columns:
- legend_labels = [f'Feature {i} (q={feature["q_value"]:.2f})' for i, (_,feature) in enumerate(chromatogramFeatures.iterrows())]
- else:
- legend_labels = [f'Feature {i}' for i in range(len(chromatogramFeatures))]
-
- if self.config.featureConfig.legend.show:
-
- matplotlibLegendLoc= LegendConfig._matplotlibLegendLocationMapper(self.config.featureConfig.legend.loc)
- ax.legend(custom_lines, legend_labels, loc=matplotlibLegendLoc, bbox_to_anchor=self.config.featureConfig.legend.bbox_to_anchor, title=self.config.featureConfig.legend.title)
-
- if self.config.show:
- plt.show()
- return fig
-
- def _plotHeatmap(self, data: pd.DataFrame):
- import matplotlib.pyplot as plt
-
- if not self.config.add_marginals:
- # Create a figure and axis
- fig, ax = plt.subplots(figsize=(self.config.width/100, self.config.height/100), dpi=200, constrained_layout=True)
-
- # Plot each unique annotation
- for annotation, df_group in data.groupby("Annotation"):
- x = df_group.rt
- y = df_group.im
- values = df_group.int
-
- scatter = ax.scatter(x, y, c=values, cmap='afmhot_r', marker='s', s=20, edgecolors='none')
-
- # Customize the plot
- ax.set_title(self.config.title)
- ax.set_xlabel("Retention Time [sec]")
- ax.set_ylabel("Ion Mobility")
-
- else:
- # Store original config values that need to be changed for marginals
- show_org = self.config.show
- self.config.show = False
-
- # Create a figure and axis
- fig, ax = plt.subplots(2, 2, figsize=(self.config.width/100, self.config.height/100), dpi=200)
-
- # Plot each unique annotation
- for annotation, df_group in data.groupby("Annotation"):
- x = df_group.rt
- y = df_group.im
- values = df_group.int
-
- scatter = ax[1, 1].scatter(x, y, c=values, cmap='afmhot_r', marker='s', s=20, edgecolors='none')
-
- # Customize the plot
- ax[1, 1].set_title(None)
- ax[1, 1].set_xlabel("Retention Time [sec]")
- ax[1, 1].set_ylabel(None)
- ax[1, 1].set_yticklabels([])
- ax[1, 1].set_yticks([])
-
- # Integrate the data along the retention time dimension
- rt_integrated = self._integrate_data_along_dim(data, 'rt')
- _plotLines(self, rt_integrated, chromatogramFeatures, ax=ax[0, 1])
- # Generate a lineplot for XIC
- # ax[0, 1].plot(rt_integrated['rt'], rt_integrated['int'])
- # ax[0, 1].set_title("Integrated Ion Chromatogram")
- ax[0, 1].set_title(None)
- ax[0, 1].set_xlabel(None)
- ax[0, 1].set_xticklabels([])
- ax[0, 1].set_xticks([])
- ax[0, 1].set_ylabel("Intensity")
- ax[0, 1].yaxis.set_ticks_position('right')
- ax[0, 1].yaxis.set_label_position('right')
- ax[0, 1].yaxis.tick_right()
- ax[0, 1].legend_ = None
-
- # Integrate the data along the ion mobility dimension
- im_integrated = self._integrate_data_along_dim(data, 'im')
- self.config.x_axis_col = 'int'
- self.config.y_axis_col = 'im'
- self.config.legend.loc = 'below'
- # self.config.y_axis_location = "left"
- # self.config.legend.loc = 'below'
- _plotLines(self, im_integrated, ax=ax[1, 0])
- # Generate a lineplot for XIM
- # ax[1, 0].plot(im_integrated['int'], im_integrated['im'])
- ax[1, 0].invert_xaxis()
- ax[1, 0].set_title(None)
- ax[1, 0].set_xlabel("Intensity")
- ax[1, 0].set_ylabel("Ion Mobility")
- ax[1, 0].legend_ = None
-
-
- # Hide the first subplot
- ax[0, 0].axis('off')
-
- # Adjust the layout
- plt.subplots_adjust(wspace=0, hspace=0)
- # plt.tight_layout()
-
- # Reset the config values to org
- self.config.show = show_org
-
- if self.config.show:
- plt.show()
- return fig
-
- if self.config.plot_type == "lineplot":
- return _plotLines(self, data, chromatogramFeatures)
- elif self.config.plot_type == "heatmap":
- return _plotHeatmap(self, data)
- else:
- raise ValueError(f"Invalid plot type: {type}")
-
-# ============================================================================= #
-## FUNCTIONAL API ##
-# ============================================================================= #
-def plotChromatogram(chromatogram: pd.DataFrame,
- chromatogram_features: pd.DataFrame = None,
- plot_type: str = "lineplot",
- add_marginals: bool = False,
- title: str = "Chromatogram Plot",
- show_plot: bool = True,
- ion_mobility: bool = False,
- width: int = 500,
- height: int = 500,
- engine: Literal['PLOTLY', 'BOKEH', 'MATPLOTLIB'] = 'PLOTLY',
- **kwargs):
- """
- Plot a Chromatogram from a MSChromatogram Object
-
- Args:
- chromatogram (DataFrame): DataFrame containing chromatogram data
- chromatogram_features (DataFrame, optional): DataFrame containing chromatogram features. Defaults to None.
- plot_type (str, optional): Type of plot to generate. Defaults to "lineplot". Can be either "lineplot" or "heatmap".
- add_marginals (bool, optional): If True, adds marginal plots for integrated ion chromatogram and ion mobilogram. Defaults to False.
- title (str, optional): title of plot. Defaults to "Chromatogram Plot".
- show_plot (bool, optional): If True, shows the plot. Defaults to True.
- ion_mobility (bool, optional): If True, plots a heatmap of Retention Time vs ion mobility with intensity as the color. Defaults to False.
- width (int, optional): width of the figure. Defaults to 500.
- height (int, optional): height of the figure. Defaults to 500.
- engine (Literal['PLOTLY', 'BOKEH'], optional): Plotting engine to use. Defaults to 'PLOTLY'. Can be either 'PLOTLY' or 'BOKEH'
-
- Returns:
- PLOTLY figure or BOKEH figure depending on engine
- """
-
-
- config = ChromatogramPlotterConfig(title=title, show=show_plot, ion_mobility=ion_mobility, width=width, height=height, plot_type=plot_type, add_marginals=add_marginals, engine=engine, **kwargs)
-
- plotter = ChromatogramPlotter(config)
- return plotter.plot(chromatogram=chromatogram, chromatogramFeatures=chromatogram_features)
-
diff --git a/pyopenms_viz/MSExperimentPlotter.py b/pyopenms_viz/MSExperimentPlotter.py
deleted file mode 100644
index ed6ef1a5..00000000
--- a/pyopenms_viz/MSExperimentPlotter.py
+++ /dev/null
@@ -1,332 +0,0 @@
-from dataclasses import dataclass
-from typing import Literal, Union
-
-import matplotlib.pyplot as plt
-import pandas as pd
-import plotly.graph_objects as go
-from bokeh.models import ColorBar, ColumnDataSource, HoverTool, PrintfTickFormatter
-from bokeh.palettes import Plasma256
-from bokeh.plotting import figure
-from bokeh.transform import linear_cmap
-
-from .BasePlotter import Colors, _BasePlotter, _BasePlotterConfig
-
-
-@dataclass(kw_only=True)
-class MSExperimentPlotterConfig(_BasePlotterConfig):
- bin_peaks: Union[Literal["auto"], bool] = "auto"
- num_RT_bins: int = 50
- num_mz_bins: int = 50
- plot3D: bool = False
-
-
-class MSExperimentPlotter(_BasePlotter):
- def __init__(self, config: MSExperimentPlotterConfig, **kwargs) -> None:
- """
- Initialize the MSExperimentPlotter with a given configuration and optional parameters.
-
- Args:
- config (MSExperimentPlotterConfig): Configuration settings for the spectrum plotter.
- **kwargs: Additional keyword arguments for customization.
- """
- super().__init__(config=config, **kwargs)
-
- def _prepare_data(self, exp: pd.DataFrame) -> pd.DataFrame:
- """Prepares data for plotting based on configuration (binning, relative intensity, hover text)."""
- if self.config.bin_peaks == True or (
- exp.shape[0] > self.config.num_mz_bins * self.config.num_RT_bins
- and self.config.bin_peaks == "auto"
- ):
- exp["mz"] = pd.cut(exp["mz"], bins=self.config.num_mz_bins)
- exp["RT"] = pd.cut(exp["RT"], bins=self.config.num_RT_bins)
-
- # Group by x and y bins and calculate the mean intensity within each bin
- exp = (
- exp.groupby(["mz", "RT"], observed=True)
- .agg({"inty": "mean"})
- .reset_index()
- )
- exp["mz"] = exp["mz"].apply(lambda interval: interval.mid).astype(float)
- exp["RT"] = exp["RT"].apply(lambda interval: interval.mid).astype(float)
- exp = exp.fillna(0)
- else:
- self.config.bin_peaks = False
-
- if self.config.relative_intensity:
- exp["inty"] = exp["inty"] / max(exp["inty"]) * 100
-
- exp["hover_text"] = exp.apply(
- lambda x: f"m/z: {round(x['mz'], 6)}
RT: {round(x['RT'], 2)}
intensity: {int(x['inty'])}",
- axis=1,
- )
-
- return exp.sort_values("inty")
-
- def _plotMatplotlib3D(
- self,
- exp: pd.DataFrame,
- ) -> plt.Figure:
- """Plot 3D peak map with mz, RT and intensity dimensions. Colored peaks based on intensity."""
- fig = plt.figure(
- figsize=(self.config.width / 100, self.config.height / 100),
- layout="constrained",
- )
- ax = fig.add_subplot(111, projection="3d")
-
- if self.config.title:
- ax.set_title(self.config.title, fontsize=12, loc="left")
- ax.set_xlabel(
- self.config.ylabel,
- fontsize=9,
- labelpad=-2,
- color=Colors["DARKGRAY"],
- style="italic",
- )
- ax.set_ylabel(
- self.config.xlabel,
- fontsize=9,
- labelpad=-2,
- color=Colors["DARKGRAY"],
- )
- ax.set_zlabel("intensity", fontsize=10, color=Colors["DARKGRAY"], labelpad=-2)
- for axis in ("x", "y", "z"):
- ax.tick_params(axis=axis, labelsize=8, pad=-2, colors=Colors["DARKGRAY"])
-
- ax.set_box_aspect(aspect=None, zoom=0.88)
- ax.ticklabel_format(axis="z", style="sci", useMathText=True)
- ax.grid(color="#FF0000", linewidth=0.8)
- ax.xaxis.pane.fill = False
- ax.yaxis.pane.fill = False
- ax.zaxis.pane.fill = False
- ax.view_init(elev=25, azim=-45, roll=0)
-
- # Plot lines to the bottom with colored based on inty
- for i in range(len(exp)):
- ax.plot(
- [exp["RT"].iloc[i], exp["RT"].iloc[i]],
- [exp["inty"].iloc[i], 0],
- [exp["mz"].iloc[i], exp["mz"].iloc[i]],
- zdir="x",
- color=plt.cm.magma_r(exp["inty"].iloc[i] / exp["inty"].max()),
- )
- return fig
-
- def _plotMatplotlib2D(
- self,
- exp: pd.DataFrame,
- ) -> plt.Figure:
- """Plot 2D peak map with mz and RT dimensions. Colored peaks based on intensity."""
- if self.config.plot3D:
- return self._p
- fig, ax = plt.subplots(
- figsize=(self.config.width / 100, self.config.height / 100)
- )
- if self.config.title:
- ax.set_title(self.config.title, fontsize=12, loc="left", pad=20)
- ax.set_xlabel(self.config.xlabel, fontsize=10, color=Colors["DARKGRAY"])
- ax.set_ylabel(
- self.config.ylabel, fontsize=10, style="italic", color=Colors["DARKGRAY"]
- )
- ax.xaxis.label.set_color(Colors["DARKGRAY"])
- ax.tick_params(axis="x", colors=Colors["DARKGRAY"])
- ax.yaxis.label.set_color(Colors["DARKGRAY"])
- ax.tick_params(axis="y", colors=Colors["DARKGRAY"])
- ax.spines[["right", "top"]].set_visible(False)
-
- scatter = ax.scatter(
- exp["RT"],
- exp["mz"],
- c=exp["inty"],
- cmap="magma_r",
- s=20,
- marker="s",
- )
- if self.config.show_legend:
- cb = fig.colorbar(scatter, aspect=40)
- cb.outline.set_visible(False)
- if self.config.relative_intensity:
- cb.ax.yaxis.set_major_formatter(
- plt.FuncFormatter(lambda x, _: f"{int(x)}%")
- )
- else:
- cb.formatter.set_powerlimits((0, 0))
- cb.formatter.set_useMathText(True)
- return fig
-
- def _plotMatplotlib(
- self,
- exp: pd.DataFrame,
- ) -> plt.Figure:
- """Prepares data and returns Matplotlib 2D or 3D plot."""
- exp = self._prepare_data(exp)
- if self.config.plot3D:
- return self._plotMatplotlib3D(exp)
- return self._plotMatplotlib2D(exp)
-
- def _plotBokeh2D(
- self,
- exp: pd.DataFrame,
- ) -> figure:
- """Plot 2D peak map with mz and RT dimensions. Colored peaks based on intensity."""
- # Initialize figure
- p = figure(
- title=self.config.title,
- x_axis_label=self.config.xlabel,
- y_axis_label=self.config.ylabel,
- width=self.config.width,
- height=self.config.height,
- )
-
- p.grid.grid_line_color = None
- p.border_fill_color = None
- p.outline_line_color = None
-
- mapper = linear_cmap(
- field_name="inty",
- palette=Plasma256[::-1],
- low=exp["inty"].min(),
- high=exp["inty"].max(),
- )
- source = ColumnDataSource(exp)
- p.scatter(
- x="RT",
- y="mz",
- size=6,
- source=source,
- color=mapper,
- marker="square",
- )
- # if not self.config.bin_peaks:
- hover = HoverTool(
- tooltips="""
-
- @hover_text{safe}
-
- """
- )
- p.add_tools(hover)
- if self.config.show_legend:
- # Add color bar
- color_bar = ColorBar(
- color_mapper=mapper["transform"],
- width=8,
- location=(0, 0),
- formatter=PrintfTickFormatter(format="%e"),
- )
- p.add_layout(color_bar, "right")
-
- return p
-
- def _plotBokeh(
- self,
- exp: pd.DataFrame,
- ) -> figure:
- """Prepares data and returns Bokeh 2D plot."""
- exp = self._prepare_data(exp)
- return self._plotBokeh2D(exp)
-
- def _plotPlotly2D(
- self,
- exp: pd.DataFrame,
- ) -> go.Figure:
- """Plot 2D peak map with mz and RT dimensions. Colored peaks based on intensity."""
- layout = go.Layout(
- title=dict(text=self.config.title),
- xaxis=dict(title=self.config.xlabel),
- yaxis=dict(title=self.config.ylabel),
- showlegend=self.config.show_legend,
- template="simple_white",
- dragmode="select",
- height=self.config.height,
- width=self.config.width,
- )
- fig = go.Figure(layout=layout)
- fig.add_trace(
- go.Scattergl(
- name="peaks",
- x=exp["RT"],
- y=exp["mz"],
- mode="markers",
- marker=dict(
- color=exp["inty"],
- colorscale="sunset",
- size=8,
- symbol="square",
- colorbar=(
- dict(thickness=8, outlinewidth=0, tickformat=".0e")
- if self.config.show_legend
- else None
- ),
- ),
- hovertext=exp["hover_text"],# if not self.config.bin_peaks else None,
- hoverinfo="text",
- showlegend=False,
- )
- )
- return fig
-
- def _plotPlotly(
- self,
- exp: pd.DataFrame,
- ) -> go.Figure:
- """Prepares data and returns Plotly 2D plot."""
- exp = self._prepare_data(exp)
- return self._plotPlotly2D(exp)
-
-# ============================================================================= #
-## FUNCTIONAL API ##
-# ============================================================================= #
-
-
-def plotMSExperiment(
- exp: pd.DataFrame,
- plot3D: bool = False,
- relative_intensity: bool = False,
- bin_peaks: Union[Literal["auto"], bool] = "auto",
- num_RT_bins: int = 50,
- num_mz_bins: int = 50,
- width: int = 750,
- height: int = 500,
- title: str = "Peak Map",
- xlabel: str = "RT (s)",
- ylabel: str = "m/z",
- show_legend: bool = False,
- engine: Literal["PLOTLY", "BOKEH", "MATPLOTLIB"] = "PLOTLY",
-):
- """
- Plots a Spectrum from an MSSpectrum object
-
- Args:
- spectrum (pd.DataFrame): OpenMS MSSpectrum Object
- plot3D: (bool = False, optional): Plot peak map 3D with peaks colored based on intensity. Disables colorbar legend. Works with "MATPLOTLIB" engine only. Defaults to False.
- relative_intensity (bool, optional): If true, plot relative intensity values. Defaults to False.
- bin_peaks: (Union[Literal["auto"], bool], optional): Bin peaks to reduce complexity and improve plotting speed. Hovertext disabled if activated. If set to "auto" any MSExperiment with more then num_RT_bins x num_mz_bins peaks will be binned. Defaults to "auto".
- num_RT_bins: (int, optional): Number of bins in RT dimension. Defaults to 50.
- num_mz_bins: (int, optional): Number of bins in m/z dimension. Defaults to 50.
- width (int, optional): Width of plot. Defaults to 500px.
- height (int, optional): Height of plot. Defaults to 500px.
- title (str, optional): Plot title. Defaults to "Spectrum Plot".
- xlabel (str, optional): X-axis label. Defaults to "m/z".
- ylabel (str, optional): Y-axis label. Defaults to "intensity" or "ion mobility".
- show_legend (int, optional): Show legend. Defaults to False.
- engine (Literal['PLOTLY', 'BOKEH', 'MATPLOTLIB'], optional): Plotting engine to use. Defaults to 'PLOTLY' can be either 'PLOTLY', 'BOKEH' or 'MATPLOTLIB'.
-
- Returns:
- Plot: The generated plot using the specified engine.
- """
- config = MSExperimentPlotterConfig(
- plot3D=plot3D,
- relative_intensity=relative_intensity,
- bin_peaks=bin_peaks,
- num_RT_bins=num_RT_bins,
- num_mz_bins=num_mz_bins,
- width=width,
- height=height,
- title=title,
- xlabel=xlabel,
- ylabel=ylabel,
- show_legend=show_legend,
- engine=engine,
- )
- plotter = MSExperimentPlotter(config)
- return plotter.plot(exp.copy())
diff --git a/pyopenms_viz/SpectrumPlotter.py b/pyopenms_viz/SpectrumPlotter.py
deleted file mode 100644
index 5d7d5808..00000000
--- a/pyopenms_viz/SpectrumPlotter.py
+++ /dev/null
@@ -1,622 +0,0 @@
-import re
-from dataclasses import dataclass
-from itertools import cycle
-from typing import Literal, Optional, Union
-
-import matplotlib.pyplot as plt
-import numpy as np
-import pandas as pd
-import plotly.graph_objects as go
-from bokeh.models import ColorBar, ColumnDataSource, HoverTool, Label, Span
-from bokeh.palettes import Plasma256
-from bokeh.plotting import figure
-from bokeh.transform import linear_cmap
-
-from .BasePlotter import Colors, _BasePlotter, _BasePlotterConfig
-
-
-@dataclass(kw_only=True)
-class SpectrumPlotterConfig(_BasePlotterConfig):
- ion_mobility: bool = False
- annotate_mz: bool = False
- annotate_ions: bool = False
- annotate_sequence: bool = False
- mirror_spectrum: bool = (False,)
- custom_peak_color: bool = (False,)
- custom_annotation_color: bool = (False,)
- custom_annotation_text: bool = False
-
-
-class SpectrumPlotter(_BasePlotter):
- def __init__(self, config: SpectrumPlotterConfig, **kwargs) -> None:
- """
- Initialize the SpectrumPlotter with a given configuration and optional parameters.
-
- Args:
- config (SpectrumPlotterConfig): Configuration settings for the spectrum plotter.
- **kwargs: Additional keyword arguments for customization.
- """
- super().__init__(config=config, **kwargs)
- # If y-axis label is default ("intensity") and ion_mobility is True, update label
- if self.config.ylabel == "intensity" and self.config.ion_mobility:
- self.config.ylabel = "ion mobility"
-
- def _prepare_data(
- self,
- spectrum: Union[pd.DataFrame, list[pd.DataFrame]],
- reference_spectrum: Union[pd.DataFrame, list[pd.DataFrame], None],
- ) -> tuple[list, list]:
- """Prepares data for plotting based on configuration (ensures list format for input spectra, relative intensity, hover text)."""
-
- # Ensure input spectra dataframes are in lists
- if not isinstance(spectrum, list):
- spectrum = [spectrum]
-
- if reference_spectrum is None:
- reference_spectrum = []
- elif not isinstance(reference_spectrum, list):
- reference_spectrum = [reference_spectrum]
- # Convert to relative intensity if required
- if self.config.relative_intensity or self.config.mirror_spectrum:
- combined_spectra = spectrum + (
- reference_spectrum if reference_spectrum else []
- )
- for df in combined_spectra:
- df["intensity"] = df["intensity"] / df["intensity"].max() * 100
- # Add hover text
- for spec in spectrum+reference_spectrum:
- spec["hover_text_ion_mobility"] = spec.apply(
- lambda x: f"{x['native_id']}
m/z: {x['mz']}
ion mobility: {x['ion_mobility']}
intensity: {x['intensity']}",
- axis=1,
- )
- return spectrum, reference_spectrum
-
- def _get_ion_color_annotation(self, annotation: str) -> str:
- """Retrieve the color associated with a specific ion annotation from a predefined colormap."""
- colormap = {
- "a": Colors["PURPLE"],
- "b": Colors["BLUE"],
- "c": Colors["LIGHTBLUE"],
- "x": Colors["YELLOW"],
- "y": Colors["RED"],
- "z": Colors["ORANGE"],
- }
- for key in colormap.keys():
- # Exact matches
- if annotation == key:
- return colormap[key]
- # Fragment ions via regex
- x = re.search(r"^[abcxyz]{1}[0-9]*[+-]$", annotation)
- if x:
- return colormap[annotation[0]]
- return Colors["DARKGRAY"]
-
- def _get_peak_color(self, default_color: str, peak: pd.Series) -> str:
- """Determine the color of a peak based on custom settings or annotation."""
- if self.config.custom_peak_color and "color_peak" in peak:
- return peak["color_peak"]
-
- if self.config.annotate_ions:
- return self._get_annotation_color(peak, default_color)
-
- return default_color
-
- def _get_annotation_text(self, peak: pd.Series) -> str:
- """Generate the annotation text for a given peak based on the configuration."""
- if "custom_annotation" in peak and self.config.custom_annotation_text:
- return peak["custom_annotation"]
-
- texts = []
-
- if self.config.annotate_ions and peak["ion_annotation"] != "none":
- texts.append(peak["ion_annotation"])
-
- if self.config.annotate_sequence and peak["sequence"]:
- texts.append(peak["sequence"])
-
- if self.config.annotate_mz:
- texts.append(str(peak["mz"]))
-
- return "
".join(texts)
-
- def _get_annotation_color(
- self, peak: pd.Series, fallback_color: str = "black"
- ) -> str:
- """Determine the color for annotations based on custom settings or ion type."""
- if "color_annotation" in peak and self.config.custom_annotation_color:
- return peak["color_annotation"]
-
- if self.config.annotate_ions:
- return self._get_ion_color_annotation(peak["ion_annotation"])
-
- if self.config.custom_peak_color and "color_peak" in peak:
- return peak["color_peak"]
-
- return fallback_color
-
- def _get_relative_intensity_ticks(self) -> tuple[list[int], list[str]]:
- """Generate the ticks and labels for relative intensity on the y-axis."""
- ticks = [0, 25, 50, 75, 100]
- labels = ["0%", "25%", "50%", "75%", "100%"]
-
- if self.config.mirror_spectrum:
- mirror_ticks = [-100, -75, -50, -25]
- mirror_labels = ["-100%", "-75%", "-50%", "-25%"]
- ticks = mirror_ticks + ticks
- labels = mirror_labels + labels
-
- return ticks, labels
-
-
- def _combine_sort_spectra_by_intensity(
- self, spectra: list[pd.DataFrame]
- ) -> pd.DataFrame:
- """Combine and sort spectra by intensity."""
- combined_df = pd.concat(spectra).reset_index(drop=True)
- sorted_df = combined_df.sort_values(by="intensity").reset_index(drop=True)
- return sorted_df
-
- def _plotMatplotlib(
- self,
- spectrum: Union[pd.DataFrame, list[pd.DataFrame]],
- reference_spectrum: Optional[Union[pd.DataFrame, list[pd.DataFrame]]] = None,
- ) -> plt.Figure:
- """Plot the spectrum using Matplotlib."""
-
- def plot_spectrum(ax, df, color, mirror=False):
- for i, peak in df.iterrows():
- intensity = -peak["intensity"] if mirror else peak["intensity"]
- peak_color = self._get_peak_color(color, peak)
- ax.plot(
- [peak["mz"], peak["mz"]],
- [0, intensity],
- color=peak_color,
- linewidth=1.5,
- label=peak["native_id"] if i == 0 else None,
- )
- if any(
- [
- self.config.annotate_mz,
- self.config.annotate_ions,
- self.config.annotate_sequence,
- self.config.custom_annotation_text,
- ]
- ):
- text = self._get_annotation_text(peak).replace("
", "\n")
- annotation_color = self._get_annotation_color(peak, peak_color)
- ax.annotate(
- text,
- xy=(peak["mz"], intensity),
- xytext=(1, 0),
- textcoords="offset points",
- fontsize=8,
- color=annotation_color,
- )
-
- spectrum, reference_spectrum = self._prepare_data(
- spectrum, reference_spectrum
- )
-
- fig, ax = plt.subplots(
- figsize=(self.config.width / 100, self.config.height / 100)
- )
- ax.set_title(self.config.title, fontsize=12, loc="left", pad=20)
- ax.set_xlabel(
- self.config.xlabel, fontsize=10, style="italic", color=Colors["DARKGRAY"]
- )
- ax.set_ylabel(self.config.ylabel, fontsize=10, color=Colors["DARKGRAY"])
- ax.xaxis.label.set_color(Colors["DARKGRAY"])
- ax.tick_params(axis="x", colors=Colors["DARKGRAY"])
- ax.yaxis.label.set_color(Colors["DARKGRAY"])
- ax.tick_params(axis="y", colors=Colors["DARKGRAY"])
- ax.spines[["right", "top"]].set_visible(False)
-
- if self.config.ion_mobility:
- df = self._combine_sort_spectra_by_intensity(spectrum)
- scatter = ax.scatter(
- df["mz"],
- df["ion_mobility"],
- c=df["intensity"],
- cmap="plasma_r",
- s=20,
- marker="s",
- )
- if self.config.show_legend:
- cb = fig.colorbar(scatter, aspect=40)
- cb.outline.set_visible(False)
- return fig
-
- gs_colors = self._get_n_grayscale_colors(
- max(len(spectrum), len(reference_spectrum or []))
- )
- colors = cycle(gs_colors)
-
- for spec in spectrum:
- plot_spectrum(ax, spec, next(colors))
-
- if self.config.mirror_spectrum:
- colors = cycle(gs_colors)
- for ref_spec in reference_spectrum:
- plot_spectrum(ax, ref_spec, next(colors), mirror=True)
- ax.plot(ax.get_xlim(), [0, 0], color="#EEEEEE", linewidth=1.5)
-
- ax.set_ylim([0 if not self.config.mirror_spectrum else None, None])
- if self.config.relative_intensity or self.config.mirror_spectrum:
- ticks, labels = self._get_relative_intensity_ticks()
- ax.set_yticks(ticks)
- ax.set_yticklabels(labels)
- else:
- ax.ticklabel_format(axis="both", style="sci", useMathText=True)
-
- if self.config.show_legend:
- ax.legend(loc="best")
-
- return fig
-
- def _plotBokeh(
- self,
- spectrum: Union[pd.DataFrame, list[pd.DataFrame]],
- reference_spectrum: Optional[Union[pd.DataFrame, list[pd.DataFrame]]] = None,
- ) -> figure:
- """Plot the spectrum using Bokeh."""
-
- def plot_spectrum(p, df, color, mirror=False):
- """Add spectrum plot to figure including annotations based on configuration, default peak color and mirror spectrum."""
- for i, peak in df.iterrows():
- intensity = -peak["intensity"] if mirror else peak["intensity"]
- peak_color = self._get_peak_color(color, peak)
- if i == 0 and self.config.show_legend:
- p.line(
- [peak["mz"], peak["mz"]],
- [0, intensity],
- line_color=peak_color,
- line_width=2,
- legend_label=peak["native_id"],
- )
- else:
- p.line(
- [peak["mz"], peak["mz"]],
- [0, intensity],
- line_color=peak_color,
- line_width=2,
- )
- if any(
- [
- self.config.annotate_mz,
- self.config.annotate_ions,
- self.config.annotate_sequence,
- self.config.custom_annotation_text,
- ]
- ):
- text = self._get_annotation_text(peak).replace("
", "\n")
- annotation_color = self._get_annotation_color(peak, peak_color)
- label = Label(
- x=peak["mz"],
- y=intensity,
- text=text,
- text_font_size="8pt",
- text_color=annotation_color,
- x_offset=1,
- y_offset=0,
- )
- p.add_layout(label)
-
- def _get_ion_mobility_plot(p: figure) -> figure:
- """Adds ion mobility traces to figure."""
- df = self._combine_sort_spectra_by_intensity(spectrum)
- mapper = linear_cmap(
- field_name="intensity",
- palette=Plasma256[::-1],
- low=df["intensity"].min(),
- high=df["intensity"].max(),
- )
- source = ColumnDataSource(df)
- p.scatter(
- x="mz",
- y="ion_mobility",
- size=6,
- source=source,
- color=mapper,
- marker="square",
- )
- hover = HoverTool(
- tooltips="""
-
- @hover_text_ion_mobility{safe}
-
- """
- )
- p.add_tools(hover)
- if self.config.show_legend:
- color_bar = ColorBar(
- color_mapper=mapper["transform"], width=8, location=(0, 0)
- )
- p.add_layout(color_bar, "right")
- return p
-
- spectrum, reference_spectrum = self._prepare_data(
- spectrum, reference_spectrum
- )
-
- # Initialize figure
- p = figure(
- title=self.config.title,
- x_axis_label=self.config.xlabel,
- y_axis_label=self.config.ylabel,
- width=self.config.width,
- height=self.config.height,
- )
-
- p.grid.grid_line_color = None
- p.border_fill_color = None
- p.outline_line_color = None
-
- if self.config.ion_mobility:
- return _get_ion_mobility_plot(p)
-
- gs_colors = self._get_n_grayscale_colors(
- max(len(spectrum), len(reference_spectrum or []))
- )
- colors = cycle(gs_colors)
-
- for spec in spectrum:
- plot_spectrum(p, spec, next(colors))
-
- if self.config.mirror_spectrum:
- colors = cycle(gs_colors)
- for ref_spec in reference_spectrum:
- plot_spectrum(p, ref_spec, next(colors), mirror=True)
- zero_line = Span(
- location=0, dimension="width", line_color="#EEEEEE", line_width=1.5
- )
- p.add_layout(zero_line)
-
- if self.config.relative_intensity or self.config.mirror_spectrum:
- ticks, labels = self._get_relative_intensity_ticks()
- p.yaxis.ticker = ticks
- p.yaxis.major_label_overrides = {
- tick: label for tick, label in zip(ticks, labels)
- }
- else:
- p.yaxis.formatter.use_scientific = True
-
- p.y_range.start = -110 if self.config.mirror_spectrum else 0
- # adjust x-axis limits to not cut peaks and annotations
- x_values = [
- glyph.data_source.data["x"][0]
- for glyph in p.renderers
- if hasattr(glyph, "data_source")
- ]
- xmin = min(x_values)
- xmax = max(x_values)
- padding = 0.15 * (xmax - xmin)
- p.x_range.end = xmax + padding
- return p
-
- def _plotPlotly(
- self,
- spectrum: Union[pd.DataFrame, list[pd.DataFrame]],
- reference_spectrum: Optional[Union[pd.DataFrame, list[pd.DataFrame]]] = None,
- ) -> go.Figure:
- """Plot the spectrum using Plotly."""
-
- def _create_peak_traces(
- spectrum: pd.DataFrame,
- line_color: str,
- intensity_direction: Literal[1, -1] = 1,
- ) -> list[go.Scattergl]:
- """Create peak traces based on given line color and orientation (-1 for mirror spectra)."""
- return [
- go.Scattergl(
- x=[peak["mz"]] * 2,
- y=[0, intensity_direction * peak["intensity"]],
- mode="lines",
- line=dict(color=self._get_peak_color(line_color, peak)),
- name=peak["native_id"],
- text=f"{peak['native_id']}
m/z: {peak['mz']}
intensity: {peak['intensity']}",
- hoverinfo="text",
- showlegend=(i == 0),
- )
- for i, peak in spectrum.iterrows()
- ]
-
- def _create_annotations(
- spectra: list[pd.DataFrame],
- intensity_sign: Literal[1, -1] = 1,
- ) -> list[dict]:
- """Create peak annotations based on configuration."""
- if not any(
- [
- self.config.annotate_mz,
- self.config.annotate_ions,
- self.config.annotate_sequence,
- self.config.custom_annotation_text,
- ]
- ):
- return []
-
- annotations = []
- colors = cycle(self.gs_colors)
- for spectrum in spectra:
- for _, peak in spectrum.iterrows():
- text = self._get_annotation_text(peak)
- color = self._get_annotation_color(peak, next(colors))
- annotations.append(
- dict(
- x=peak["mz"],
- y=intensity_sign * peak["intensity"],
- text=text,
- showarrow=False,
- xanchor="left",
- font=dict(
- family="Open Sans Mono, monospace",
- size=12,
- color=color,
- ),
- )
- )
- return annotations
-
- def _get_ion_mobility_plot(fig: go.Figure) -> go.Figure:
- """Adds ion mobility traces to figure."""
- df = self._combine_sort_spectra_by_intensity(spectrum)
- fig.add_trace(
- go.Scattergl(
- name="peaks",
- x=df["mz"],
- y=df["ion_mobility"],
- mode="markers",
- marker=dict(
- color=df["intensity"],
- colorscale="sunset",
- size=8,
- symbol="square",
- colorbar=(
- dict(thickness=8, outlinewidth=0)
- if self.config.show_legend
- else None
- ),
- ),
- hovertext=df["hover_text_ion_mobility"],
- hoverinfo="text",
- showlegend=False,
- )
- )
- return fig
-
- spectrum, reference_spectrum = self._prepare_data(
- spectrum, reference_spectrum
- )
-
- layout = go.Layout(
- title=dict(text=self.config.title),
- xaxis=dict(title=self.config.xlabel),
- yaxis=dict(title=self.config.ylabel, rangemode="tozero"),
- showlegend=self.config.show_legend,
- template="simple_white",
- )
- fig = go.Figure(layout=layout)
-
- if self.config.ion_mobility:
- return _get_ion_mobility_plot(fig)
-
- self.gs_colors = self._get_n_grayscale_colors(
- max(len(spectrum), len(reference_spectrum or []))
- )
- colors = cycle(self.gs_colors)
-
- traces = []
- for spec in spectrum:
- traces += _create_peak_traces(spec, next(colors))
-
- if self.config.mirror_spectrum:
- colors = cycle(self.gs_colors)
- for ref_spec in reference_spectrum:
- traces += _create_peak_traces(
- ref_spec, next(colors), intensity_direction=-1
- )
-
- fig.add_traces(traces)
-
- annotations = _create_annotations(spectrum, 1)
- for annotation in annotations:
- fig.add_annotation(annotation)
-
- if self.config.mirror_spectrum:
- annotations = _create_annotations(reference_spectrum, -1)
- for annotation in annotations:
- fig.add_annotation(annotation)
- fig.add_hline(y=0, line_color=Colors["LIGHTGRAY"], line_width=2)
-
- if self.config.relative_intensity or self.config.mirror_spectrum:
- ticks, labels = self._get_relative_intensity_ticks()
- fig.update_layout(
- yaxis=dict(tickmode="array", tickvals=ticks, ticktext=labels),
- yaxis_range=[-110 if self.config.mirror_spectrum else 0, 110],
- )
- # adjust x-axis limits to not cut peaks and annotations
- x_values = [trace.x for trace in fig.data]
- xmin = min([min(values) for values in x_values])
- xmax = max([max(values) for values in x_values])
- padding = 0.15 * (xmax - xmin)
- fig.update_layout(
- xaxis_range=[
- xmin - 1,
- xmax + padding,
- ]
- )
-
- return fig
-
-
-# ============================================================================= #
-## FUNCTIONAL API ##
-# ============================================================================= #
-
-
-def plotSpectrum(
- spectrum: Union[pd.DataFrame, list[pd.DataFrame]],
- reference_spectrum: Union[pd.DataFrame, list[pd.DataFrame]] = None,
- ion_mobility: bool = False,
- annotate_mz: bool = False,
- annotate_ions: bool = False,
- annotate_sequence: bool = False,
- mirror_spectrum: bool = False,
- relative_intensity: bool = False,
- custom_peak_color: bool = False,
- custom_annotation_text: bool = False,
- custom_annotation_color: bool = False,
- width: int = 750,
- height: int = 500,
- title: str = "Spectrum Plot",
- xlabel: str = "m/z",
- ylabel: str = "intensity",
- show_legend: bool = False,
- engine: Literal["PLOTLY", "BOKEH", "MATPLOTLIB"] = "PLOTLY",
-):
- """
- Plots a Spectrum from an MSSpectrum object
-
- Args:
- spectrum (Union[pd.DataFrame, list[pd.DataFrame]]): OpenMS MSSpectrum Object
- reference_spectrum (Union[pd.DataFrame, list[pd.DataFrame]], optional): Optional OpenMS Spectrum object to plot in mirror or used in annotation. Defaults to None.
- ion_mobility (bool, optional): If true, plots spectra (not including reference spectra) as heatmap of m/z vs ion mobility with intensity as color. Defaults to False.
- annotate_mz (bool, optional): If true, annotate peaks with m/z values. Defaults to False.
- annotate_ions (bool, optional): If true, annotate fragment ions. Defaults to False.
- annotate_sequence (bool, optional): Annotate peaks based on sequence provided. Defaults to False
- mirror_spectrum (bool, optional): If true, plot mirror spectrum. Defaults to True, if no mirror reference_spectrum is provided, this is ignored.
- relative_intensity (bool, optional): If true, plot relative intensity values. Defaults to False.
- custom_peak_color (bool, optional): If true, plot peaks with colors from "color_peak" column.
- custom_annotation_text (bool, optional): If true, annotate peaks with custom text from "custom_annotation" column. Overwrites all other annotations.Use
for line breaks.
- custom_annotation_color (bool, optional): If true, plot annotations with colors from "color_annotation" column.
- width (int, optional): Width of plot. Defaults to 500px.
- height (int, optional): Height of plot. Defaults to 500px.
- title (str, optional): Plot title. Defaults to "Spectrum Plot".
- xlabel (str, optional): X-axis label. Defaults to "m/z".
- ylabel (str, optional): Y-axis label. Defaults to "intensity" or "ion mobility".
- show_legend (int, optional): Show legend. Defaults to False.
- engine (Literal['PLOTLY', 'BOKEH', 'MATPLOTLIB'], optional): Plotting engine to use. Defaults to 'PLOTLY' can be either 'PLOTLY', 'BOKEH' OR 'MATPLOTLIB'.
-
- Returns:
- Plot: The generated plot using the specified engine.
- """
- config = SpectrumPlotterConfig(
- ion_mobility=ion_mobility,
- annotate_mz=annotate_mz,
- mirror_spectrum=mirror_spectrum,
- annotate_sequence=annotate_sequence,
- annotate_ions=annotate_ions,
- relative_intensity=relative_intensity,
- custom_peak_color=custom_peak_color,
- custom_annotation_text=custom_annotation_text,
- custom_annotation_color=custom_annotation_color,
- width=width,
- height=height,
- title=title,
- xlabel=xlabel,
- ylabel=ylabel,
- show_legend=show_legend,
- engine=engine,
- )
- plotter = SpectrumPlotter(config)
- return plotter.plot(spectrum.copy(), reference_spectrum=reference_spectrum.copy())
diff --git a/pyopenms_viz/__init__.py b/pyopenms_viz/__init__.py
index 2df4ba87..ad3a6504 100644
--- a/pyopenms_viz/__init__.py
+++ b/pyopenms_viz/__init__.py
@@ -1,8 +1,169 @@
-"""
-pyopenms_viz is a package for visualizing OpenMS data using Bokeh and Plotly.
-"""
+from pandas.plotting._core import PlotAccessor
+from pandas.core.frame import DataFrame
+from typing import Any
+from pandas.core.dtypes.generic import ABCDataFrame
+import types
-from .ChromatogramPlotter import plotChromatogram, ChromatogramPlotterConfig, ChromatogramPlotter
-from .SpectrumPlotter import plotSpectrum, SpectrumPlotterConfig, SpectrumPlotter
-__all__ = [ 'plotChromatogram', 'ChromatogramPlotterConfig', 'ChromatogramPlotter', 'plotSpectrum', 'SpectrumPlotterConfig', 'SpectrumPlotter']
\ No newline at end of file
+class PlotAccessor:
+ """
+ Make plots of MassSpec data using dataframes
+
+ """
+
+ _common_kinds = ("line", "vline", "scatter")
+ _msdata_kinds = ("chromatogram", "mobilogram", "spectrum", "feature_heatmap")
+ _all_kinds = _common_kinds + _msdata_kinds
+
+ def __init__(self, data: DataFrame) -> None:
+ self._parent = data
+
+ def __call__(self, *args: Any, **kwargs: Any) -> Any:
+ backend_name = kwargs.get("backend", None)
+ if backend_name is None:
+ backend_name = "matplotlib"
+
+ plot_backend = _get_plot_backend(backend_name)
+
+ x, y, kind, kwargs = self._get_call_args(
+ plot_backend.__name__, self._parent, args, kwargs
+ )
+
+ if kind not in self._all_kinds:
+ raise ValueError(
+ f"{kind} is not a valid plot kind "
+ f"Valid plot kinds: {self._all_kinds}"
+ )
+
+ # Call the plot method of the selected backend
+ if "backend" in kwargs:
+ kwargs.pop("backend")
+ return plot_backend.plot(self._parent, x=x, y=y, kind=kind, **kwargs)
+
+ @staticmethod
+ def _get_call_args(backend_name: str, data: DataFrame, args, kwargs):
+ """
+ Get the arguments to pass to the plotting backend.
+
+ Parameters
+ ----------
+ backend_name : str
+ The name of the backend.
+ data : DataFrame
+ The data to plot.
+ args : tuple
+ The positional arguments passed to the plotting function.
+ kwargs : dict
+ The keyword arguments passed to the plotting function.
+
+ Returns
+ -------
+ dict
+ The arguments to pass to the plotting backend.
+ """
+ if isinstance(data, ABCDataFrame):
+ arg_def = [
+ ("x", None),
+ ("y", None),
+ ("kind", "line"),
+ ("by", None),
+ ("subplots", None),
+ ("sharex", None),
+ ("sharey", None),
+ ("height", None),
+ ("width", None),
+ ("grid", None),
+ ("toolbar_location", None),
+ ("fig", None),
+ ("title", None),
+ ("xlabel", None),
+ ("ylabel", None),
+ ("x_axis_location", None),
+ ("y_axis_location", None),
+ ("line_type", None),
+ ("line_width", None),
+ ("min_border", None),
+ ("show_plot", None),
+ ("legend", None),
+ ("feature_config", None),
+ ("_config", None),
+ ("backend", backend_name),
+ ]
+ else:
+ raise ValueError(
+ f"Unsupported data type: {type(data).__name__}, expected DataFrame."
+ )
+
+ pos_args = {name: value for (name, _), value in zip(arg_def, args)}
+
+ kwargs = dict(arg_def, **pos_args, **kwargs)
+
+ x = kwargs.pop("x", None)
+ y = kwargs.pop("y", None)
+ kind = kwargs.pop("kind", "line")
+ return x, y, kind, kwargs
+
+
+_backends: dict[str, types.ModuleType] = {}
+
+
+def _load_backend(backend: str) -> types.ModuleType:
+ """
+ Load a plotting backend.
+
+ Parameters
+ ----------
+ backend : str
+ The identifier for the backend. Either "bokeh", "matplotlib", "plotly",
+ or a module name.
+
+ Returns
+ -------
+ types.ModuleType
+ The imported backend.
+ """
+ if backend == "bokeh":
+ try:
+ module = importlib.import_module("pyopenms_viz.plotting._bokeh")
+ except ImportError:
+ raise ImportError(
+ "Bokeh is required for plotting when the 'bokeh' backend is selected."
+ ) from None
+ return module
+
+ elif backend == "matplotlib":
+ try:
+ module = importlib.import_module("pyopenms_viz.plotting._matplotlib")
+ except ImportError:
+ raise ImportError(
+ "Matplotlib is required for plotting when the 'matplotlib' backend is selected."
+ ) from None
+ return module
+
+ elif backend == "plotly":
+ try:
+ module = importlib.import_module("pyopenms_viz.plotting._plotly")
+ except ImportError:
+ raise ImportError(
+ "Plotly is required for plotting when the 'plotly' backend is selected."
+ ) from None
+ return module
+
+ raise ValueError(
+ f"Could not find plotting backend '{backend}'. Needs to be one of 'bokeh', 'matplotlib', or 'plotly'."
+ )
+
+
+def _get_plot_backend(backend: str | None = None):
+
+ backend_str: str = backend or "matplotlib"
+
+ if backend_str in _backends:
+ return _backends[backend_str]
+
+ module = _load_backend(backend_str)
+ _backends[backend_str] = module
+ return module
+
+
+__all__ = ["PlotAccessor"]
diff --git a/pyopenms_viz/plotting/_bokeh/__init__.py b/pyopenms_viz/_bokeh/__init__.py
similarity index 80%
rename from pyopenms_viz/plotting/_bokeh/__init__.py
rename to pyopenms_viz/_bokeh/__init__.py
index f9030883..38547b99 100644
--- a/pyopenms_viz/plotting/_bokeh/__init__.py
+++ b/pyopenms_viz/_bokeh/__init__.py
@@ -2,8 +2,9 @@
from typing import TYPE_CHECKING
+from ..constants import IS_SPHINX_BUILD
-from pyopenms_viz.plotting._bokeh.core import (
+from .core import (
BOKEHLinePlot,
BOKEHVLinePlot,
BOKEHScatterPlot,
@@ -14,7 +15,11 @@
)
if TYPE_CHECKING:
- from pyopenms_viz.plotting._bokeh.core import BOKEHPlot
+ from .core import BOKEHPlot
+
+if IS_SPHINX_BUILD:
+ from .core import BOKEH_MSPlotter, BOKEHPlot
+
PLOT_CLASSES: dict[str, type[BOKEHPlot]] = {
"line": BOKEHLinePlot,
diff --git a/pyopenms_viz/plotting/_bokeh/core.py b/pyopenms_viz/_bokeh/core.py
similarity index 96%
rename from pyopenms_viz/plotting/_bokeh/core.py
rename to pyopenms_viz/_bokeh/core.py
index 68ed23b4..eacb121b 100644
--- a/pyopenms_viz/plotting/_bokeh/core.py
+++ b/pyopenms_viz/_bokeh/core.py
@@ -25,15 +25,15 @@
LinePlot,
VLinePlot,
ScatterPlot,
- ComplexPlot,
+ BaseMSPlotter,
ChromatogramPlot,
MobilogramPlot,
FeatureHeatmapPlot,
SpectrumPlot,
- APPEND_PLOT_DOC
+ APPEND_PLOT_DOC,
)
from .._misc import ColorGenerator
-from ...constants import PEAK_BOUNDARY_ICON, FEATURE_BOUNDARY_ICON
+from ..constants import PEAK_BOUNDARY_ICON, FEATURE_BOUNDARY_ICON
class BOKEHPlot(BasePlotter, ABC):
@@ -302,7 +302,7 @@ def plot(cls, fig, data, x, y, by: str | None = None, **kwargs):
return fig, legend
-class BOKEHComplexPlot(ComplexPlot, BOKEHPlot, ABC):
+class BOKEH_MSPlotter(BaseMSPlotter, BOKEHPlot, ABC):
def get_line_renderer(self, data, x, y, **kwargs) -> None:
return BOKEHLinePlot(data, x, y, **kwargs)
@@ -334,7 +334,7 @@ def _create_tooltips(self):
return TOOLTIPS, None
-class BOKEHChromatogramPlot(BOKEHComplexPlot, ChromatogramPlot):
+class BOKEHChromatogramPlot(BOKEH_MSPlotter, ChromatogramPlot):
"""
Class for assembling a Bokeh extracted ion chromatogram plot
"""
@@ -363,8 +363,8 @@ def _add_peak_boundaries(self, annotation_data):
line_dash=self.feature_config.line_type,
line_width=self.feature_config.line_width,
)
- if 'name' in annotation_data.columns:
- use_name = feature['name']
+ if "name" in annotation_data.columns:
+ use_name = feature["name"]
else:
use_name = f"Feature {idx}"
if "q_value" in annotation_data.columns:
@@ -404,7 +404,7 @@ class BOKEHMobilogramPlot(BOKEHChromatogramPlot, MobilogramPlot):
pass
-class BOKEHSpectrumPlot(BOKEHComplexPlot, SpectrumPlot):
+class BOKEHSpectrumPlot(BOKEH_MSPlotter, SpectrumPlot):
"""
Class for assembling a Bokeh spectrum plot
"""
@@ -412,7 +412,7 @@ class BOKEHSpectrumPlot(BOKEHComplexPlot, SpectrumPlot):
pass
-class BOKEHFeatureHeatmapPlot(BOKEHComplexPlot, FeatureHeatmapPlot):
+class BOKEHFeatureHeatmapPlot(BOKEH_MSPlotter, FeatureHeatmapPlot):
"""
Class for assembling a Bokeh feature heatmap plot
"""
@@ -429,7 +429,7 @@ def create_main_plot(self, x, y, z, class_kwargs, other_kwargs):
self.fig = scatterPlot.generate(
marker="square", line_color=mapper, fill_color=mapper, **other_kwargs
)
-
+
if self.annotation_data is not None:
self._add_box_boundaries(self.annotation_data)
@@ -489,7 +489,7 @@ def get_manual_bounding_box_coords(self):
"y1": f"{self.y}_1",
}
)
-
+
def _add_box_boundaries(self, annotation_data):
color_gen = ColorGenerator(
colormap=self.feature_config.colormap, n=annotation_data.shape[0]
@@ -515,10 +515,10 @@ def _add_box_boundaries(self, annotation_data):
color=next(color_gen),
line_dash=self.feature_config.line_type,
line_width=self.feature_config.line_width,
- fill_alpha=0
+ fill_alpha=0,
)
- if 'name' in annotation_data.columns:
- use_name = feature['name']
+ if "name" in annotation_data.columns:
+ use_name = feature["name"]
else:
use_name = f"Feature {idx}"
if "q_value" in annotation_data.columns:
@@ -536,4 +536,3 @@ def _add_box_boundaries(self, annotation_data):
str(self.feature_config.legend.fontsize) + "pt"
)
self.fig.add_layout(legend, self.feature_config.legend.loc)
-
\ No newline at end of file
diff --git a/pyopenms_viz/plotting/_config.py b/pyopenms_viz/_config.py
similarity index 100%
rename from pyopenms_viz/plotting/_config.py
rename to pyopenms_viz/_config.py
diff --git a/pyopenms_viz/plotting/_core.py b/pyopenms_viz/_core.py
similarity index 97%
rename from pyopenms_viz/plotting/_core.py
rename to pyopenms_viz/_core.py
index 929f0c1c..1e96f131 100644
--- a/pyopenms_viz/plotting/_core.py
+++ b/pyopenms_viz/_core.py
@@ -10,11 +10,7 @@
from pandas.core.dtypes.common import is_integer
from pandas.util._decorators import Appender
-from ._config import (
- LegendConfig,
- FeatureConfig,
- _BasePlotterConfig
-)
+from ._config import LegendConfig, FeatureConfig, _BasePlotterConfig
from ._misc import ColorGenerator
@@ -137,7 +133,7 @@ def __init__(
self.kind = kind
self.by = by
self.relative_intensity = relative_intensity
-
+
# Plotting attributes
self.subplots = subplots
self.sharex = sharex
@@ -156,21 +152,21 @@ def __init__(
self.line_width = line_width
self.min_border = min_border
self.show_plot = show_plot
-
+
self.legend = legend
self.feature_config = feature_config
-
+
self._config = _config
-
+
if _config is not None:
self._update_from_config(_config)
-
+
if self.legend is not None and isinstance(self.legend, dict):
self.legend = LegendConfig.from_dict(self.legend)
-
+
if self.feature_config is not None and isinstance(self.feature_config, dict):
self.feature_config = FeatureConfig.from_dict(self.feature_config)
-
+
### get x and y data
if self._kind in {
"line",
@@ -390,7 +386,7 @@ def _kind(self):
return "scatter"
-class ComplexPlot(BasePlotter, ABC):
+class BaseMSPlotter(BasePlotter, ABC):
"""
Abstract class for complex plots, such as chromatograms and mobilograms which are made up of simple plots such as ScatterPlots, VLines and LinePlots.
@@ -423,7 +419,7 @@ def _create_tooltips(self):
pass
-class ChromatogramPlot(BasePlotter, ABC):
+class ChromatogramPlot(BaseMSPlotter, ABC):
@property
def _kind(self):
return "chromatogram"
@@ -431,10 +427,10 @@ def _kind(self):
def __init__(
self, data, x, y, annotation_data: DataFrame | None = None, **kwargs
) -> None:
-
+
# Set default config attributes if not passed as keyword arguments
kwargs["_config"] = _BasePlotterConfig(kind=self._kind)
-
+
super().__init__(data, x, y, **kwargs)
if annotation_data is not None:
@@ -500,13 +496,19 @@ def plot(self, data, x, y, **kwargs):
self._modify_y_range((0, self.data[y].max()), (0, 0.1))
-class SpectrumPlot(ComplexPlot, ABC):
+class SpectrumPlot(BaseMSPlotter, ABC):
@property
def _kind(self):
return "spectrum"
def __init__(
- self, data, x, y, reference_spectrum: DataFrame | None = None, mirror_spectrum: bool = False, **kwargs
+ self,
+ data,
+ x,
+ y,
+ reference_spectrum: DataFrame | None = None,
+ mirror_spectrum: bool = False,
+ **kwargs,
) -> None:
# Set default config attributes if not passed as keyword arguments
@@ -578,16 +580,24 @@ def _prepare_data(
return spectrum, reference_spectrum
-class FeatureHeatmapPlot(ComplexPlot, ABC):
+class FeatureHeatmapPlot(BaseMSPlotter, ABC):
# need to inherit from ChromatogramPlot and SpectrumPlot for get_line_renderer and get_vline_renderer methods respectively
@property
def _kind(self):
return "feature_heatmap"
def __init__(
- self, data, x, y, z, zlabel=None, add_marginals=False, annotation_data: DataFrame | None = None, **kwargs
+ self,
+ data,
+ x,
+ y,
+ z,
+ zlabel=None,
+ add_marginals=False,
+ annotation_data: DataFrame | None = None,
+ **kwargs,
) -> None:
-
+
# Set default config attributes if not passed as keyword arguments
kwargs["_config"] = _BasePlotterConfig(kind=self._kind)
@@ -596,11 +606,11 @@ def __init__(
self.zlabel = zlabel
self.add_marginals = add_marginals
-
+
if annotation_data is not None:
self.annotation_data = annotation_data.copy()
else:
- self.annotation_data = None
+ self.annotation_data = None
super().__init__(data, x, y, z=z, **kwargs)
self.plot(x, y, z, **kwargs)
@@ -666,11 +676,11 @@ def create_x_axis_plot(self, x, z, class_kwargs) -> "figure":
x_config = self._config.copy()
x_config.ylabel = self.zlabel
x_config.y_axis_location = "right"
- x_config.legend.show = True
+ x_config.legend.show = True
x_config.legend.loc = "right"
color_gen = ColorGenerator()
-
+
# remove legend from class_kwargs to update legend args for x axis plot
class_kwargs.pop("legend", None)
class_kwargs.pop("ylabel", None)
@@ -697,7 +707,7 @@ def create_y_axis_plot(self, y, z, class_kwargs) -> "figure":
y_config.y_axis_location = "left"
y_config.legend.show = True
y_config.legend.loc = "below"
-
+
# remove legend from class_kwargs to update legend args for y axis plot
class_kwargs.pop("legend", None)
class_kwargs.pop("xlabel", None)
@@ -715,7 +725,7 @@ def create_y_axis_plot(self, y, z, class_kwargs) -> "figure":
@abstractmethod
def combine_plots(self, x_fig, y_fig):
pass
-
+
@abstractmethod
def _add_box_boundaries(self, annotation_data):
"""
diff --git a/pyopenms_viz/plotting/_matplotlib/__init__.py b/pyopenms_viz/_matplotlib/__init__.py
similarity index 81%
rename from pyopenms_viz/plotting/_matplotlib/__init__.py
rename to pyopenms_viz/_matplotlib/__init__.py
index 2f873236..3b8a4b58 100644
--- a/pyopenms_viz/plotting/_matplotlib/__init__.py
+++ b/pyopenms_viz/_matplotlib/__init__.py
@@ -1,8 +1,9 @@
from __future__ import annotations
from typing import TYPE_CHECKING
+from ..constants import IS_SPHINX_BUILD
-from pyopenms_viz.plotting._matplotlib.core import (
+from .core import (
MATPLOTLIBLinePlot,
MATPLOTLIBVLinePlot,
MATPLOTLIBScatterPlot,
@@ -13,7 +14,10 @@
)
if TYPE_CHECKING:
- from pyopenms_viz.plotting._matplotlib.core import MATPLOTLIBPlot
+ from .core import MATPLOTLIBPlot
+
+if IS_SPHINX_BUILD:
+ from .core import MATPLOTLIB_MSPlotter, MATPLOTLIBPlot
PLOT_CLASSES: dict[str, type[MATPLOTLIBPlot]] = {
"line": MATPLOTLIBLinePlot,
diff --git a/pyopenms_viz/plotting/_matplotlib/core.py b/pyopenms_viz/_matplotlib/core.py
similarity index 95%
rename from pyopenms_viz/plotting/_matplotlib/core.py
rename to pyopenms_viz/_matplotlib/core.py
index 494e47e5..40cbfa77 100644
--- a/pyopenms_viz/plotting/_matplotlib/core.py
+++ b/pyopenms_viz/_matplotlib/core.py
@@ -15,12 +15,12 @@
LinePlot,
VLinePlot,
ScatterPlot,
- ComplexPlot,
+ BaseMSPlotter,
ChromatogramPlot,
MobilogramPlot,
SpectrumPlot,
FeatureHeatmapPlot,
- APPEND_PLOT_DOC
+ APPEND_PLOT_DOC,
)
@@ -259,7 +259,7 @@ def plot(
return ax, (legend_lines, legend_labels)
-class MATPLOTLIBComplexPlot(ComplexPlot, MATPLOTLIBPlot, ABC):
+class MATPLOTLIB_MSPlotter(BaseMSPlotter, MATPLOTLIBPlot, ABC):
def get_line_renderer(self, data, x, y, **kwargs) -> None:
return MATPLOTLIBLinePlot(data, x, y, **kwargs)
@@ -277,8 +277,9 @@ def _create_tooltips(self):
# No tooltips for MATPLOTLIB because it is not interactive
return None, None
+
@APPEND_PLOT_DOC
-class MATPLOTLIBChromatogramPlot(MATPLOTLIBComplexPlot, ChromatogramPlot):
+class MATPLOTLIBChromatogramPlot(MATPLOTLIB_MSPlotter, ChromatogramPlot):
"""
Class for assembling a matplotlib extracted ion chromatogram plot
"""
@@ -302,7 +303,7 @@ def _add_peak_boundaries(self, annotation_data):
)
legend_items = []
- legend_labels = []
+ legend_labels = []
for idx, (_, feature) in enumerate(annotation_data.iterrows()):
use_color = next(color_gen)
left_vlne = self.fig.vlines(
@@ -323,8 +324,8 @@ def _add_peak_boundaries(self, annotation_data):
)
legend_items.append(left_vlne)
- if 'name' in annotation_data.columns:
- use_name = feature['name']
+ if "name" in annotation_data.columns:
+ use_name = feature["name"]
else:
use_name = f"Feature {idx}"
if "q_value" in annotation_data.columns:
@@ -360,7 +361,7 @@ class MATPLOTLIBMobilogramPlot(MATPLOTLIBChromatogramPlot, MobilogramPlot):
@APPEND_PLOT_DOC
-class MATPLOTLIBSpectrumPlot(MATPLOTLIBComplexPlot, SpectrumPlot):
+class MATPLOTLIBSpectrumPlot(MATPLOTLIB_MSPlotter, SpectrumPlot):
"""
Class for assembling a matplotlib spectrum plot
"""
@@ -368,7 +369,7 @@ class MATPLOTLIBSpectrumPlot(MATPLOTLIBComplexPlot, SpectrumPlot):
pass
-class MATPLOTLIBFeatureHeatmapPlot(MATPLOTLIBComplexPlot, FeatureHeatmapPlot):
+class MATPLOTLIBFeatureHeatmapPlot(MATPLOTLIB_MSPlotter, FeatureHeatmapPlot):
"""
Class for assembling a matplotlib feature heatmap plot
"""
@@ -428,7 +429,7 @@ def create_y_axis_plot(self, y, z, class_kwargs) -> "figure":
y_config.legend.loc = "below"
y_config.legend.orientation = "horizontal"
y_config.legend.bbox_to_anchor = (1, -0.4)
-
+
# remove legend from class_kwargs to update legend args for y axis plot
class_kwargs.pop("legend", None)
class_kwargs.pop("xlabel", None)
@@ -460,7 +461,7 @@ def create_main_plot(self, x, y, z, class_kwargs, other_kwargs):
cmap="afmhot_r",
**other_kwargs,
)
-
+
if self.annotation_data is not None:
self._add_box_boundaries(self.annotation_data)
@@ -482,12 +483,12 @@ def create_main_plot_marginals(self, x, y, z, class_kwargs, other_kwargs):
self.ax_grid[1, 1].set_yticklabels([])
self.ax_grid[1, 1].set_yticks([])
self.ax_grid[1, 1].legend_ = None
-
+
def _add_box_boundaries(self, annotation_data):
if self.by is not None:
legend = self.fig.get_legend()
self.fig.add_artist(legend)
-
+
color_gen = ColorGenerator(
colormap=self.feature_config.colormap, n=annotation_data.shape[0]
)
@@ -504,15 +505,19 @@ def _add_box_boundaries(self, annotation_data):
height = abs(y1 - y0)
color = next(color_gen)
- custom_lines = Rectangle((x0, y0), width, height,
- fill=False,
- edgecolor=color,
- linestyle=self.feature_config.line_type,
- linewidth=self.feature_config.line_width)
+ custom_lines = Rectangle(
+ (x0, y0),
+ width,
+ height,
+ fill=False,
+ edgecolor=color,
+ linestyle=self.feature_config.line_type,
+ linewidth=self.feature_config.line_width,
+ )
self.fig.add_patch(custom_lines)
- if 'name' in annotation_data.columns:
- use_name = feature['name']
+ if "name" in annotation_data.columns:
+ use_name = feature["name"]
else:
use_name = f"Feature {idx}"
if "q_value" in annotation_data.columns:
diff --git a/pyopenms_viz/plotting/_misc.py b/pyopenms_viz/_misc.py
similarity index 100%
rename from pyopenms_viz/plotting/_misc.py
rename to pyopenms_viz/_misc.py
diff --git a/pyopenms_viz/plotting/_plotly/__init__.py b/pyopenms_viz/_plotly/__init__.py
similarity index 74%
rename from pyopenms_viz/plotting/_plotly/__init__.py
rename to pyopenms_viz/_plotly/__init__.py
index 5655072b..4c820125 100644
--- a/pyopenms_viz/plotting/_plotly/__init__.py
+++ b/pyopenms_viz/_plotly/__init__.py
@@ -1,8 +1,9 @@
from __future__ import annotations
from typing import TYPE_CHECKING
+from ..constants import IS_SPHINX_BUILD
-from pyopenms_viz.plotting._plotly.core import (
+from .core import (
PLOTLYLinePlot,
PLOTLYVLinePlot,
PLOTLYScatterPlot,
@@ -13,9 +14,12 @@
)
if TYPE_CHECKING:
- from pyopenms_viz.plotting._plotly.core import PLOTLYPlot
+ from .core import PLOTLYPlotter
-PLOT_CLASSES: dict[str, type[PLOTLYPlot]] = {
+if IS_SPHINX_BUILD:
+ from .core import PLOTLY_MSPlotter, PLOTLYPlotter
+
+PLOT_CLASSES: dict[str, type[PLOTLYPlotter]] = {
"line": PLOTLYLinePlot,
"vline": PLOTLYVLinePlot,
"scatter": PLOTLYScatterPlot,
diff --git a/pyopenms_viz/plotting/_plotly/core.py b/pyopenms_viz/_plotly/core.py
similarity index 91%
rename from pyopenms_viz/plotting/_plotly/core.py
rename to pyopenms_viz/_plotly/core.py
index 164ff699..fd9bb5a0 100644
--- a/pyopenms_viz/plotting/_plotly/core.py
+++ b/pyopenms_viz/_plotly/core.py
@@ -2,7 +2,7 @@
from abc import ABC
-from typing import TYPE_CHECKING, Literal, List, Tuple, Union
+from typing import List, Tuple, Union
import plotly.graph_objects as go
from plotly.graph_objs import Figure
@@ -17,20 +17,20 @@
LinePlot,
VLinePlot,
ScatterPlot,
- ComplexPlot,
+ BaseMSPlotter,
ChromatogramPlot,
MobilogramPlot,
SpectrumPlot,
FeatureHeatmapPlot,
- APPEND_PLOT_DOC
+ APPEND_PLOT_DOC,
)
from .._config import bokeh_line_dash_mapper
-from pyopenms_viz.plotting._misc import ColorGenerator
-from pyopenms_viz.constants import PEAK_BOUNDARY_ICON, FEATURE_BOUNDARY_ICON
+from .._misc import ColorGenerator
+from ..constants import PEAK_BOUNDARY_ICON, FEATURE_BOUNDARY_ICON
-class PLOTLYPlot(BasePlotter, ABC):
+class PLOTLYPlotter(BasePlotter, ABC):
"""
Base class for assembling a Ploty plot
"""
@@ -189,7 +189,7 @@ def show(self, **kwargs):
self.fig.show(**kwargs)
-class PLOTLYLinePlot(PLOTLYPlot, LinePlot):
+class PLOTLYLinePlot(PLOTLYPlotter, LinePlot):
"""
Class for assembling a set of line plots in plotly
"""
@@ -228,7 +228,7 @@ def plot( # type: ignore[override]
return fig, None
-class PLOTLYVLinePlot(PLOTLYPlot, VLinePlot):
+class PLOTLYVLinePlot(PLOTLYPlotter, VLinePlot):
@classmethod
@APPEND_PLOT_DOC
@@ -280,7 +280,7 @@ def plot(cls, fig, data, x, y, by=None, **kwargs) -> Tuple[Figure, "Legend"]:
return fig, None
-class PLOTLYScatterPlot(PLOTLYPlot, ScatterPlot):
+class PLOTLYScatterPlot(PLOTLYPlotter, ScatterPlot):
@classmethod
@APPEND_PLOT_DOC
@@ -317,7 +317,7 @@ def plot(cls, fig, data, x, y, by=None, **kwargs) -> Tuple[Figure, "Legend"]:
return fig, None
-class PLOTLYComplexPlot(ComplexPlot, PLOTLYPlot, ABC):
+class PLOTLY_MSPlotter(BaseMSPlotter, PLOTLYPlotter, ABC):
def get_line_renderer(self, data, x, y, **kwargs) -> None:
return PLOTLYLinePlot(data, x, y, **kwargs)
@@ -367,7 +367,7 @@ def _create_tooltips(self):
return "
".join(TOOLTIPS), column_stack(custom_hover_data)
-class PLOTLYChromatogramPlot(PLOTLYComplexPlot, ChromatogramPlot):
+class PLOTLYChromatogramPlot(PLOTLY_MSPlotter, ChromatogramPlot):
def _add_peak_boundaries(self, annotation_data, **kwargs):
color_gen = ColorGenerator(
@@ -390,10 +390,12 @@ def _add_peak_boundaries(self, annotation_data, **kwargs):
y=[feature["apexIntensity"], 0, 0, feature["apexIntensity"]],
opacity=0.5,
line=dict(
- color = next(color_gen),
- dash=bokeh_line_dash_mapper(self.feature_config.line_type, 'plotly'),
- width=self.feature_config.line_width
+ color=next(color_gen),
+ dash=bokeh_line_dash_mapper(
+ self.feature_config.line_type, "plotly"
),
+ width=self.feature_config.line_width,
+ ),
name=legend_label,
)
)
@@ -407,7 +409,7 @@ class PLOTLYMobilogramPlot(PLOTLYChromatogramPlot, MobilogramPlot):
pass
-class PLOTLYSpectrumPlot(PLOTLYComplexPlot, SpectrumPlot):
+class PLOTLYSpectrumPlot(PLOTLY_MSPlotter, SpectrumPlot):
def _prepare_data(
self, spectrum: DataFrame, y: str, reference_spectrum: DataFrame | None
) -> Tuple[List]:
@@ -423,7 +425,7 @@ def _prepare_data(
return spectrum, reference_spectrum
-class PLOTLYFeatureHeatmapPlot(PLOTLYComplexPlot, FeatureHeatmapPlot):
+class PLOTLYFeatureHeatmapPlot(PLOTLY_MSPlotter, FeatureHeatmapPlot):
def create_main_plot(self, x, y, z, class_kwargs, other_kwargs):
scatterPlot = self.get_scatter_renderer(self.data, x, y, **class_kwargs)
@@ -440,7 +442,7 @@ def create_main_plot(self, x, y, z, class_kwargs, other_kwargs):
),
**other_kwargs,
)
-
+
if self.annotation_data is not None:
self._add_box_boundaries(self.annotation_data)
@@ -532,7 +534,7 @@ def combine_plots(self, x_fig, y_fig):
# Update yaxis properties
fig_m.update_yaxes(title_text=self.zlabel, row=1, col=2)
fig_m.update_yaxes(title_text=self.ylabel, row=2, col=1)
-
+
# Remove axes for first quadrant
fig_m.update_xaxes(visible=False, row=1, col=1)
fig_m.update_yaxes(visible=False, row=1, col=1)
@@ -554,11 +556,11 @@ def _add_box_boundaries(self, annotation_data, **kwargs):
x1 = feature["rightWidth"]
y0 = feature["IM_leftWidth"]
y1 = feature["IM_rightWidth"]
-
+
color = next(color_gen)
-
- if 'name' in annotation_data.columns:
- use_name = feature['name']
+
+ if "name" in annotation_data.columns:
+ use_name = feature["name"]
else:
use_name = f"Feature {idx}"
if "q_value" in annotation_data.columns:
@@ -566,17 +568,25 @@ def _add_box_boundaries(self, annotation_data, **kwargs):
else:
legend_label = f"{use_name}"
self.fig.add_trace(
- go.Scatter(
- x=[x0, x1, x1, x0, x0], # Start and end at the same point to close the shape
- y=[y0, y0, y1, y1, y0],
- mode='lines',
- fill='none',
- opacity=0.5,
- line=dict(
- color = color,
- width=self.feature_config.line_width,
- dash=bokeh_line_dash_mapper(self.feature_config.line_type, 'plotly')
+ go.Scatter(
+ x=[
+ x0,
+ x1,
+ x1,
+ x0,
+ x0,
+ ], # Start and end at the same point to close the shape
+ y=[y0, y0, y1, y1, y0],
+ mode="lines",
+ fill="none",
+ opacity=0.5,
+ line=dict(
+ color=color,
+ width=self.feature_config.line_width,
+ dash=bokeh_line_dash_mapper(
+ self.feature_config.line_type, "plotly"
),
- name=legend_label
- )
- )
\ No newline at end of file
+ ),
+ name=legend_label,
+ )
+ )
diff --git a/pyopenms_viz/constants.py b/pyopenms_viz/constants.py
index 76d6c976..7236c8f3 100644
--- a/pyopenms_viz/constants.py
+++ b/pyopenms_viz/constants.py
@@ -11,5 +11,22 @@
######################
## Icons
-PEAK_BOUNDARY_ICON = Image.open(os.path.normpath(os.path.join(PYOPENMS_VIZ_DIRNAME, 'assets/img/peak_boundary.png')))
-FEATURE_BOUNDARY_ICON = Image.open(os.path.normpath(os.path.join(PYOPENMS_VIZ_DIRNAME, 'assets/img/feature_boundary.png')))
\ No newline at end of file
+PEAK_BOUNDARY_ICON = Image.open(
+ os.path.normpath(os.path.join(PYOPENMS_VIZ_DIRNAME, "assets/img/peak_boundary.png"))
+)
+FEATURE_BOUNDARY_ICON = Image.open(
+ os.path.normpath(
+ os.path.join(PYOPENMS_VIZ_DIRNAME, "assets/img/feature_boundary.png")
+ )
+)
+
+
+######################
+## Determine if running in SPHINX build
+IS_SPHINX_BUILD = False
+try:
+ import sphinx
+
+ IS_SPHINX_BUILD = hasattr(sphinx, "application")
+except ImportError:
+ pass # Not running SPHINX
diff --git a/pyopenms_viz/datastructures/MSChromatogram.py b/pyopenms_viz/datastructures/MSChromatogram.py
deleted file mode 100644
index 417df95d..00000000
--- a/pyopenms_viz/datastructures/MSChromatogram.py
+++ /dev/null
@@ -1,22 +0,0 @@
-"""
-Schema definition for mass spectrometry chromatogram. This is used to store any chromatogram object
-"""
-
-REQUIRED_CHROMATOGRAM_DATAFRAME_COLUMNS = {
- "time": "Numeric column representing the retention time (in seconds) of the chromatographic peaks.",
- "intensity": "Numeric column representing the intensity (abundance) of the signal at each time point.",
-}
-
-OPTIONAL_METADATA_CHROMATOGRAM_DATAFRAME_COLUMNS = {
- "native_id" : "Chromatogram id, necessary if multiple chromatograms are in the same dataframe.",
- "chromatogram_type": "Type of chromatogram must be one of: MASS_CHROMATOGRAM, TOTAL_ION_CURRENT_CHROMATOGRAM, SELECTED_ION_CURRENT_CHROMATOGRAM, BASEPEAK_CHROMATOGRAM, SELECTED_ION_MONITORING_CHROMATOGRAM, SELECTED_REACTION_MONITORING_CHROMATOGRAM, ELECTROMAGNETIC_RADIATION_CHROMATOGRAM, ABSORPTION_CHROMATOGRAM, EMISSION_CHROMATOGRAM",
- "ms_level": "Integer column indicating the MS level (1 for MS1, 2 for MS2, etc.).",
- "sequence": "String column representing the peptide sequence.",
- "modified_sequence": "String column representing the modified peptide sequence. Modification can be represented using the UniMod ontology, either with the UniMod Accession (UniMod: 21) or with the UniMod Codename (Phospho).",
- "precursor_mz": "Numeric column representing the mass-to-charge ratio (m/z) of the precursor ion.",
- "precursor_charge": "Integer column representing the charge state of the precursor ion.",
- "product_mz": "Numeric column representing the mass-to-charge ratio (m/z) of the product ion.",
- "product_charge": "Integer column representing the charge state of the product ion.",
- "annotation": "String column representing the annotation of the spectrum, such as the fragment ion series."
-}
-
diff --git a/pyopenms_viz/datastructures/MSSpectrum.py b/pyopenms_viz/datastructures/MSSpectrum.py
deleted file mode 100644
index f69df7ae..00000000
--- a/pyopenms_viz/datastructures/MSSpectrum.py
+++ /dev/null
@@ -1,22 +0,0 @@
-"""
-Schema definition for mass spectrometry spectrum.
-"""
-
-REQUIRED_SPECTRUM_DATAFRAME_COLUMNS = {
- "mz": "Numeric column representing the mass-to-charge ratio (m/z) values of the peaks in the mass spectrum.",
- "intensity": "Numeric column representing the intensity (abundance) of the peaks in the mass spectrum."
-}
-
-OPTIONAL_METADATA_SPECTRUM_DATAFRAME_COLUMNS = {
- "native_id": "String column representing the native identifier (i.e. scan number) of the spectrum.",
- "ms_level": "Integer column indicating the MS level (1 for MS1, 2 for MS2, etc.).",
- "time": "Numeric column representing the retention time (in seconds) of the spectrum.",
- "ion_mobility": "Numeric column representing the ion mobility.",
- "sequence": "String column representing the peptide sequence.",
- "modified_sequence": "String column representing the modified peptide sequence.",
- "precursor_mz": "Numeric column representing the mass-to-charge ratio (m/z) of the precursor ion.",
- "precursor_charge": "Integer column representing the charge state of the precursor ion.",
- "product_mz": "Numeric column representing the mass-to-charge ratio (m/z) of the product ion.",
- "product_charge": "Integer column representing the charge state of the product ion.",
- "annotation": "String column representing the annotation of the peak, such as the fragment ion series."
-}
\ No newline at end of file
diff --git a/pyopenms_viz/datastructures/SchemaDefinition.py b/pyopenms_viz/datastructures/SchemaDefinition.py
deleted file mode 100644
index a9a2d53f..00000000
--- a/pyopenms_viz/datastructures/SchemaDefinition.py
+++ /dev/null
@@ -1,180 +0,0 @@
-from enum import Enum
-from typing import List, Union
-import pandas as pd
-
-""" Outlines Column Types. Other columns are permitted in the class however they will be ignored """
-
-class DataType(Enum):
- NUMERIC=0
- STRING=1
-
-class ColumnType(Enum):
- REQUIRED=0 # Plots cannot be generated without these columns
- OPTIONAL=1 # Used to add supplementary information to the plot, but not required for plotting the object
-
-class ColumnSchema:
- def __init__(self,
- name: Union[str, List[str]],
- dataType: DataType = None,
- description: str = None,
- columnType: ColumnType = ColumnType.OPTIONAL):
-
- """Define a column schema
-
- Args:
- name (str): Name of column
- dataType (DataType, Optional): Datatype of column. Defaults to None.
- description (str, Optional): Description of information that column is storing. Defaults to None.
- columnType (ColumnType, optional): Whether column is REQUIRED or OPTIONAL. Defaults to OPTIONAL
- """
-
- self.name = name
- self.dataType = dataType
- self.description = description
- self.columnType = columnType
-
- def __str__(self):
- return f"ColumnSchema: {self.name} ({self.dataType})"
-
-class DataFrameSchema:
- def __init__(self,
- name: str,
- columns: List[ColumnSchema],
- description: str = None):
-
- """Define a table schema
-
- Args:
- columns (List[ColumnSchema]): List of ColumnSchema objects
- description (str, Optional): Description of table. Defaults to None.
- """
- self.name = name
- self.columns = columns
- self.description = description
-
- def __str__(self):
- return f"GenericTableSchema: {self.name}. Required Columns: {[i.name for i in self.getRequiredColumns()]}"
-
- def getOptionalColumns(self):
- return [col for col in self.columns if col.columnType == ColumnType.OPTIONAL]
-
- def getRequiredColumns(self):
- return [col for col in self.columns if col.columnType == ColumnType.REQUIRED]
-
- def validateDataFrame(self, df):
- for col in self.getRequiredColumns():
- # if column name is a list, then exactly one column must be present
- if isinstance(col.name, List):
- intersect = set(df.columns).intersection(set(col.name))
- if len(intersect) == 0:
- raise ValueError(f"Column {col.name} is required for {self.name} schema")
- elif len(intersect) > 1:
- raise ValueError(f"Only one of {col.name} columns is allowed for {self.name} schema")
- nameFound = False
- else: # assume column name is a string
- if col.name not in df.columns:
- raise ValueError(f"Column {col.name} is required for {self.name} schema")
-
- if col.dataType == DataType.NUMERIC:
- assert pd.str.isnumeric(df[col.name])
-
- for col in self.getOptionalColumns():
- if col.dataType == DataType.NUMERIC:
- assert pd.str.isnumeric(df[col.name])
- return True
-
-
-############################################################################################################
-#### DEFINED CHROMATOGRAM SCHEMAS ####
-############################################################################################################
-
-REQUIRED_CHROMATOGRAM_COLUMNS = [
- ColumnSchema("time",
- dataType=DataType.NUMERIC,
- description="Numeric column representing the retention time (in seconds) of the chromatographic peaks.",
- columnType=ColumnType.REQUIRED),
-
- ColumnSchema("intensity",
- dataType=DataType.NUMERIC,
- description="Numeric column representing the intensity (abundance) of the signal at each time point.",
- columnType=ColumnType.REQUIRED)]
-
-MULTIPLE_CHROMATOGRAMS_COLUMNS = REQUIRED_CHROMATOGRAM_COLUMNS + [ ColumnSchema("label",
- dataType=DataType.STRING,
- description="Chromatogram label, necessary if multiple chromatograms are in the same dataframe. This column is used to label the chromatograms", columnType=ColumnType.REQUIRED)]
-
-CHROMATOGRAM = DataFrameSchema(name='chromatogram', columns=REQUIRED_CHROMATOGRAM_COLUMNS, description="Base chromatogram schema")
-MULTIPLE_CHROMATOGRAMS = DataFrameSchema(name='multiple chromatograms', columns=REQUIRED_CHROMATOGRAM_COLUMNS + REQUIRED_CHROMATOGRAM_COLUMNS, description="Multiple chromatograms in a single DataFrame schema")
-
-############################################################################################################
-#### DEFINED SPECTRUM SCHEMAS ####
-############################################################################################################
-SPECTRUM_COLUMNS = [
- ColumnSchema("mz",
- dataType=DataType.NUMERIC,
- description="Numeric column representing the mass-to-charge ratio (m/z) values of the peaks in the mass spectrum.",
- columnType=ColumnType.REQUIRED),
-
- ColumnSchema("intensity",
- dataType=DataType.NUMERIC,
- description="Numeric column representing the intensity (abundance) of the peaks in the mass spectrum.",
- columnType=ColumnType.REQUIRED),
-
- ColumnSchema("annotation",
- dataType=DataType.STRING,
- description="String column representing the annotation of the spectrum, such as the fragment ion series.",
- columnType=ColumnType.OPTIONAL)
-]
-
-MULTIPLE_SPECTRA_COLUMNS = SPECTRUM_COLUMNS + [ ColumnSchema(name=["label", "native_id"],
- dataType=DataType.STRING,
- description="Spectrum label, necessary if multiple spectra are in the same dataframe. This column is used to label the spectra", columnType=ColumnType.REQUIRED)]
-
-SPECTRUM = DataFrameSchema(name='spectrum', columns=SPECTRUM_COLUMNS, description="Base spectrum schema, only a single spectrum is present")
-MULTIPLE_SPECTRA = DataFrameSchema(name='multiple spectra', columns=SPECTRUM_COLUMNS + SPECTRUM_COLUMNS, description="Multiple spectra in a single DataFrame")
-
-############################################################################################################
-#### DEFINED CHROMATOGRAM FEATURE SCHEMAS ####
-############################################################################################################
-REQUIRED_CHROMATOGRAM_FEATURE_COLUMNS = [ ColumnSchema(name=['RT', 'rt_apex'],
- dataType=DataType.NUMERIC,
- description="Numeric column representing the retention time of the peak apex.")]
-
-# Note: only include columns that could be used in plotting (e.g. exclude area because this will not be plotted)
-CHROMATOGRAM_FEATURE = DataFrameSchema(name='chromatogram feature',
- columns=[
- ColumnSchema(name=["rt_apex", 'RT'],
- dataType=DataType.NUMERIC,
- description="Numeric column representing the retention time of the peak apex."),
-
- ColumnSchema(name=["feature_id", 'id', 'label'],
- dataType=DataType.STRING,
- description="Represent the Id of the feature",
- columnType=ColumnType.OPTIONAL),
-
- ColumnSchema(name=["left_width", 'left_boundary'],
- dataType=DataType.NUMERIC,
- description="Numeric column representing the left boundary of the peak.",
- columnType=ColumnType.OPTIONAL),
-
- ColumnSchema(name=["right_width", 'right_boundary'],
- dataType=DataType.NUMERIC,
- description="Numeric column representing the right boundary of the peak.",
- columnType=ColumnType.OPTIONAL),
-
- ColumnSchema(name=["q_value"],
- dataType=DataType.NUMERIC,
- description="Numeric column representing the q-value of the peak.",
- columnType=ColumnType.OPTIONAL),
-
- ColumnSchema(name=["rank"],
- dataType=DataType.NUMERIC,
- description="Numeric column representing the rank of the peak. (1 is best rank)",
- columnType=ColumnType.OPTIONAL)
- ],
- description="Chromatogram feature schema")
-
-############################################################################################################
-#### DEFINED SPECTRUM FEATURE SCHEMAS ####
-############################################################################################################
-##TODO, Is this needed?
\ No newline at end of file
diff --git a/pyopenms_viz/plotting/__init__.py b/pyopenms_viz/plotting/__init__.py
deleted file mode 100644
index a99f028d..00000000
--- a/pyopenms_viz/plotting/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from pandas.plotting._core import PlotAccessor
-
-__all__ = ["PlotAccessor"]
diff --git a/pyopenms_viz/pyopenmsviz.py b/pyopenms_viz/pyopenmsviz.py
deleted file mode 100644
index d85b9095..00000000
--- a/pyopenms_viz/pyopenmsviz.py
+++ /dev/null
@@ -1,14 +0,0 @@
-from pandas.core.frame import DataFrame
-
-from pyopenms_viz.plotting._core import PlotAccessor
-
-
-class pyopenmsviz:
- def __init__(self, data: DataFrame) -> None:
- # Validate the input data frame.
- if not isinstance(data, DataFrame):
- raise TypeError(f"Input data must be a pandas DataFrame, not {type(data)}")
- self._data = data
-
- def plot(self, *args, **kwargs):
- return PlotAccessor(self._data)(*args, **kwargs)
\ No newline at end of file
diff --git a/pyopenms_viz/testing/BokehSnapshotExtension.py b/pyopenms_viz/testing/BokehSnapshotExtension.py
deleted file mode 100644
index 07eba69c..00000000
--- a/pyopenms_viz/testing/BokehSnapshotExtension.py
+++ /dev/null
@@ -1,155 +0,0 @@
-"""
-pyopenms_viz/testing/BokehSnapshotExtension
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-"""
-from typing import Any
-from bokeh.embed import file_html
-import json
-from syrupy.data import SnapshotCollection
-from syrupy.extensions.single_file import SingleFileSnapshotExtension
-from syrupy.types import SerializableData
-from bokeh.resources import CDN
-from html.parser import HTMLParser
-
-class BokehHTMLParser(HTMLParser):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.recording = False # boolean flag to indicate if we are currently recording the data
- self.bokehJson = None # data to extract
-
- def handle_starttag(self, tag, attrs):
- if tag == 'script' and self.bokehJson is None:
- attrs_dict = dict(attrs)
- if attrs_dict.get('type') == 'application/json':
- self.recording = True
-
- def handle_endtag(self, tag):
- if tag == 'script' and self.recording:
- self.recording = False
-
- def handle_data(self, data):
- if self.recording and self.bokehJson is None:
- self.bokehJson = data
-
-class BokehSnapshotExtension(SingleFileSnapshotExtension):
- """
- Handles Bokeh Snapshots. Snapshots are stored as html files and the bokeh .json output from the html files are compared.
- """
- _file_extension = "html"
-
- def matches(self, *, serialized_data, snapshot_data):
- """
- Determine if the serialized data matches the snapshot data.
-
- Args:
- serialized_data: Data produced by the test
- snapshot_data: Saved data from a previous test run
-
- """
- json_snapshot = self.extract_bokeh_json(snapshot_data)
- json_serialized = self.extract_bokeh_json(serialized_data)
-
- # get the keys which store the json
- # NOTE: keys are unique identifiers and are not supposed to be equal
- # but the json objects they contain should be equal
- key_json_snapshot = list(json_snapshot.keys())[0]
- key_json_serialized = list(json_serialized.keys())[0]
-
- return BokehSnapshotExtension.compare_json(json_snapshot[key_json_snapshot], json_serialized[key_json_serialized])
-
- def extract_bokeh_json(self, html: str) -> json:
- """
- Extract the bokeh json from the html file.
-
- Args:
- html (str): string containing the html data
-
- Returns:
- json: bokeh json found in the html
- """
- parser = BokehHTMLParser()
- parser.feed(html)
- return json.loads(parser.bokehJson)
-
- @staticmethod
- def compare_json(json1, json2):
- """
- Compare two bokeh json objects. This function acts recursively
-
- Args:
- json1: first object
- json2: second object
-
- Returns:
- bool: True if the objects are equal, False otherwise
- """
- if isinstance(json1, dict) and isinstance(json2, dict):
- for key in json1.keys():
- if key not in json2:
- print(f'Key {key} not in second json')
- return False
- elif key not in ['id', 'root_ids']: # add keys to ignore here
- pass
- elif not BokehSnapshotExtension.compare_json(json1[key], json2[key]):
- print(f'Values for key {key} not equal')
- return False
- return True
- elif isinstance(json1, list) and isinstance(json2, list):
- if len(json1) != len(json2):
- print('Lists have different lengths')
- return False
- # lists are unordered so we need to compare every element one by one
- for i in json1:
- if isinstance(i, dict):
- # find the corresponding dictionary in json2
- for j in json2:
- if j['type'] == i['type']:
- if not BokehSnapshotExtension.compare_json(i, j):
- print(f'Element {i} not equal to {j}')
- return False
- return True
- print(f'Element {i} not in second list')
- return False
- else:
- return json1[i] == json2[i]
- return True
- else:
- if json1 != json2:
- print(f'Values not equal: {json1} != {json2}')
- return json1 == json2
-
- def _read_snapshot_data_from_location(
- self, *, snapshot_location: str, snapshot_name: str, session_id: str
- ):
- # see https://github.com/tophat/syrupy/blob/f4bc8453466af2cfa75cdda1d50d67bc8c4396c3/src/syrupy/extensions/base.py#L139
- try:
- with open(snapshot_location, 'r') as f:
- a = f.read()
- return a
- except OSError:
- return None
-
- @classmethod
- def _write_snapshot_collection(
- cls, *, snapshot_collection: SnapshotCollection
- ) -> None:
- # see https://github.com/tophat/syrupy/blob/f4bc8453466af2cfa75cdda1d50d67bc8c4396c3/src/syrupy/extensions/base.py#L161
-
- filepath, data = (
- snapshot_collection.location,
- next(iter(snapshot_collection)).data,
- )
- with open(filepath, 'w') as f:
- f.write(data)
-
- def serialize(self, data: SerializableData, **kwargs: Any) -> str:
- """
- Serialize the bokeh plot as an html string (which is output to a file)
-
- Args:
- data (SerializableData): Data to serialize
-
- Returns:
- str: html string
- """
- return file_html(data, CDN)
\ No newline at end of file
diff --git a/pyopenms_viz/testing/PlotlySnapshotExtension.py b/pyopenms_viz/testing/PlotlySnapshotExtension.py
deleted file mode 100644
index 4caee48d..00000000
--- a/pyopenms_viz/testing/PlotlySnapshotExtension.py
+++ /dev/null
@@ -1,94 +0,0 @@
-from typing import Any
-from syrupy.data import SnapshotCollection
-from syrupy.extensions.single_file import SingleFileSnapshotExtension
-from syrupy.types import SerializableData
-from plotly.io import to_json
-import json
-import math
-
-class PlotlySnapshotExtension(SingleFileSnapshotExtension):
- """
- Handles Plotly Snapshots. Snapshots are stored as json files and the json output from the files are compared.
- """
- _file_extension = "json"
-
- def matches(self, *, serialized_data, snapshot_data):
- json1 = json.loads(serialized_data)
- json2 = json.loads(snapshot_data)
- return PlotlySnapshotExtension.compare_json(json1, json2)
-
- @staticmethod
- def compare_json(json1, json2) -> bool:
- """
- Compare two plotly json objects. This function acts recursively
-
- Args:
- json1: first json
- json2: second json
-
- Returns:
- bool: True if the objects are equal, False otherwise
- """
- if isinstance(json1, dict) and isinstance(json2, dict):
- for key in json1.keys():
- if key not in json2:
- print(f'Key {key} not in second json')
- return False
- if not PlotlySnapshotExtension.compare_json(json1[key], json2[key]):
- print(f'Values for key {key} not equal')
- return False
- return True
- elif isinstance(json1, list) and isinstance(json2, list):
- if len(json1) != len(json2):
- print('Lists have different lengths')
- return False
- for i, j in zip(json1, json2):
- if not PlotlySnapshotExtension.compare_json(i, j):
- return False
- return True
- else:
- if isinstance(json1, float):
- if not math.isclose(json1, json2):
- print(f'Values not equal: {json1} != {json2}')
- return False
- else:
- if json1 != json2:
- print(f'Values not equal: {json1} != {json2}')
- return False
- return True
-
- def _read_snapshot_data_from_location(
- self, *, snapshot_location: str, snapshot_name: str, session_id: str
- ):
- # see https://github.com/tophat/syrupy/blob/f4bc8453466af2cfa75cdda1d50d67bc8c4396c3/src/syrupy/extensions/base.py#L139
- try:
- with open(snapshot_location, 'r') as f:
- a = f.read()
- return a
- except OSError:
- return None
-
- @classmethod
- def _write_snapshot_collection(
- cls, *, snapshot_collection: SnapshotCollection
- ) -> None:
- # see https://github.com/tophat/syrupy/blob/f4bc8453466af2cfa75cdda1d50d67bc8c4396c3/src/syrupy/extensions/base.py#L161
-
- filepath, data = (
- snapshot_collection.location,
- next(iter(snapshot_collection)).data,
- )
- with open(filepath, 'w') as f:
- f.write(data)
-
- def serialize(self, data: SerializableData, **kwargs: Any) -> str:
- """
- Serialize the data to a json string
-
- Args:
- data (SerializableData): plotly data to serialize
-
- Returns:
- str: json string of plotly plot
- """
- return to_json(data, pretty=True, engine='json')
\ No newline at end of file
diff --git a/pyopenms_viz/testing/__init__.py b/pyopenms_viz/testing/__init__.py
deleted file mode 100644
index 462624fd..00000000
--- a/pyopenms_viz/testing/__init__.py
+++ /dev/null
@@ -1,13 +0,0 @@
-"""
-pyopenms_viz/testing
-~~~~~~~~~~~~~~~~
-
-This package contains classes for testing pyopenms_viz. SnapShotExtension classes are based off of syrupy snapshots
-"""
-
-from .BokehSnapshotExtension import BokehSnapshotExtension
-from .PlotlySnapshotExtension import PlotlySnapshotExtension
-
-__all__ = [
- "BokehSnapshotExtension",
- "PlotlySnapshotExtension"]
\ No newline at end of file
diff --git a/pyopenms_viz/util/_decorators.py b/pyopenms_viz/util/_decorators.py
deleted file mode 100644
index ba26458f..00000000
--- a/pyopenms_viz/util/_decorators.py
+++ /dev/null
@@ -1,32 +0,0 @@
-from dataclasses import fields
-
-def filter_unexpected_fields(cls):
- """
- Decorator function that filters unexpected fields from the keyword arguments passed to the class constructor.
-
- Args:
- cls: The class to decorate.
-
- Returns:
- The decorated class.
-
- Example:
- @filter_unexpected_fields
- class MyClass:
- def __init__(self, name, age):
- self.name = name
- self.age = age
-
- obj = MyClass(name='John', age=25, gender='Male')
- # The 'gender' keyword argument will be filtered out and not passed to the class constructor.
- """
- original_init = cls.__init__
-
- def new_init(self, *args, **kwargs):
- expected_fields = {field.name for field in fields(cls)}
- cleaned_kwargs = {key: value for key, value in kwargs.items() if key in expected_fields}
- print(f"WARNING: Filtered out unexpected fields: {set(kwargs) - expected_fields}")
- original_init(self, *args, **cleaned_kwargs)
-
- cls.__init__ = new_init
- return cls
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 4b5a5f84..0a354ac3 100644
--- a/setup.py
+++ b/setup.py
@@ -22,9 +22,9 @@
],
entry_points={
"pandas_plotting_backends": [
- "pomsvib = pyopenms_viz.plotting._bokeh",
- "pomsvip = pyopenms_viz.plotting._plotly",
- "pomsvim = pyopenms_viz.plotting._matplotlib"
+ "MSBokeh = pyopenms_viz._bokeh",
+ "MSPlotly = pyopenms_viz._plotly",
+ "MSMatplotlib = pyopenms_viz._matplotlib",
],
},
classifiers=[