diff --git a/pyopenms_viz/BasePlotter.py b/pyopenms_viz/BasePlotter.py deleted file mode 100644 index d32aa8fe..00000000 --- a/pyopenms_viz/BasePlotter.py +++ /dev/null @@ -1,124 +0,0 @@ -from abc import ABC, abstractmethod -from dataclasses import dataclass, field -from enum import Enum -from typing import Literal, List, Tuple -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd - -class Engine(Enum): - PLOTLY = 1 - BOKEH = 2 - MATPLOTLIB = 3 - -# A colorset suitable for color blindness -class Colors(str, Enum): - BLUE = "#4575B4" - RED = "#D73027" - LIGHTBLUE = "#91BFDB" - ORANGE = "#FC8D59" - PURPLE = "#7B2C65" - YELLOW = "#FCCF53" - DARKGRAY = "#555555" - LIGHTGRAY = "#BBBBBB" - -@dataclass(kw_only=True) -class LegendConfig: - loc: str = 'right' - title: str = 'Legend' - fontsize: int = 10 - show: bool = True - onClick: Literal["hide", "mute"] = 'mute' # legend click policy, only valid for bokeh - bbox_to_anchor: Tuple[float, float] = (1.2, 0.5) # for fine control legend positioning in matplotlib - - @staticmethod - def _matplotlibLegendLocationMapper(loc): - ''' - Maps the legend location to the matplotlib equivalent - ''' - loc_mapper = {'right':'center right', 'left':'center left', 'above':'upper center', 'below':'lower center'} - return loc_mapper[loc] - -@dataclass(kw_only=True) -class _BasePlotterConfig(ABC): - title: str = "1D Plot" - xlabel: str = "X-axis" - ylabel: str = "Y-axis" - engine: Literal["PLOTLY", "BOKEH", "MATPLOTLIB"] = "PLOTLY" - height: int = 500 - width: int = 500 - relative_intensity: bool = False - show_legend: bool = True - show: bool = True - colormap: str = 'viridis' - legend: LegendConfig = field(default_factory=LegendConfig) - grid: bool = True - lineStyle: str = 'solid' - lineWidth: float = 1 - - @property - def engine_enum(self): - return Engine[self.engine] - - -# Abstract Class for Plotting -class _BasePlotter(ABC): - def __init__(self, config: _BasePlotterConfig) -> None: - self.config = config - self.fig = None # holds the figure object - self.main_palette = None # holds the main color palette - self.feature_palette = None # holds the feature color palette - - def updateConfig(self, **kwargs): - for key, value in kwargs.items(): - if hasattr(self.config, key): - setattr(self.config, key, value) - else: - raise ValueError(f"Invalid config setting: {key}") - - @staticmethod - def generate_colors(colormap, n): - # Use Matplotlib's built-in color palettes - cmap = plt.get_cmap(colormap, n) - colors = cmap(np.linspace(0, 1, n)) - - # Convert colors to hex format - hex_colors = ['#{:02X}{:02X}{:02X}'.format(int(r*255), int(g*255), int(b*255)) for r, g, b, _ in colors] - - return hex_colors - - def _get_n_grayscale_colors(self, n: int) -> List[str]: - """Returns n evenly spaced grayscale colors in hex format.""" - hex_list = [] - for v in np.linspace(50, 200, n): - hex = "#" - for _ in range(3): - hex += f"{int(round(v)):02x}" - hex_list.append(hex) - return hex_list - - def plot(self, data : pd.DataFrame , featureData : pd.DataFrame = None, **kwargs): - - # TODO: Assert throws errors at startup if using a test streamlit app - # ### Assert color palettes are set - # assert(self.main_palette is not None) - # assert(self.feature_palette is not None if featureData is not None else True) - - if self.config.engine_enum == Engine.PLOTLY: - return self._plotPlotly(data, **kwargs) - elif self.config.engine_enum == Engine.BOKEH: - return self._plotBokeh(data, **kwargs) - else: # self.config.engine_enum == Engine.MATPLOTLIB: - return self._plotMatplotlib(data, **kwargs) - - @abstractmethod - def _plotBokeh(self, data, **kwargs): - pass - - @abstractmethod - def _plotPlotly(self, data, **kwargs): - pass - - @abstractmethod - def _plotMatplotlib(self, data, **kwargs): - pass diff --git a/pyopenms_viz/ChromatogramPlotter.py b/pyopenms_viz/ChromatogramPlotter.py deleted file mode 100644 index 21a51652..00000000 --- a/pyopenms_viz/ChromatogramPlotter.py +++ /dev/null @@ -1,819 +0,0 @@ -from typing import List, Tuple -from pandas import DataFrame -import matplotlib.pyplot as plt -from matplotlib import cm -import numpy as np -import pandas as pd -from dataclasses import dataclass, field, fields -from typing import Literal - -from .BasePlotter import _BasePlotter, _BasePlotterConfig, Engine, LegendConfig -from .util._decorators import filter_unexpected_fields - -@dataclass(kw_only=True) -class ChromatogramFeatureConfig: - def default_legend_factory(): - return LegendConfig(title="Features", loc='right', bbox_to_anchor=(1.5, 0.5)) - - colormap: str = "viridis" - lineWidth: float = 1 - lineStyle: str = 'solid' - legend: LegendConfig = field(default_factory=default_legend_factory) - -@filter_unexpected_fields -@dataclass(kw_only=True) -class ChromatogramPlotterConfig(_BasePlotterConfig): - def default_legend_factory(): - return LegendConfig(title="Transitions") - - # Plot Aesthetics - title: str = "Chromatogram Plot" - xlabel: str = "Retention Time" - ylabel: str = "Intensity" - x_axis_col: str = "rt" - y_axis_col: str = "int" - x_axis_location: str = "below" - y_axis_location: str = "left" - min_border: int = 0 - show: bool = True - lineWidth: float = 1 - lineStyle: str = 'solid' - plot_type: str = "lineplot" - add_marginals: bool = False - featureConfig: ChromatogramFeatureConfig = field(default_factory=ChromatogramFeatureConfig) - legend: LegendConfig = field(default_factory=default_legend_factory) - - # Data Specific Attributes - ion_mobility: bool = False # if True, plot ion mobility as well in a heatmap - - -class ChromatogramPlotter(_BasePlotter): - - def __init__(self, config: _BasePlotterConfig, **kwargs) -> None: - super().__init__(config, **kwargs) - - @staticmethod - def rgb_to_hex(rgb): - """ - Converts an RGB color value to its corresponding hexadecimal representation. - - Args: - rgb (tuple): A tuple containing the RGB values as floats between 0 and 1. - - Returns: - str: The hexadecimal representation of the RGB color. - - Example: - >>> rgb_to_hex((0.5, 0.75, 1.0)) - '#7fbfff' - """ - return "#{:02x}{:02x}{:02x}".format(int(rgb[0] * 255), int(rgb[1] * 255), int(rgb[2] * 255)) - - - @staticmethod - def _get_data_ranges(arr: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray, float, float, float, float, float, float]: - """ - Get parameters for plotting. - - Args: - arr (pd.DataFrame): The input DataFrame. - - Returns: - Tuple[np.ndarray, np.ndarray, float, float, float, float, float, float]: The parameters for plotting. - """ - im_arr = arr.index.to_numpy() - rt_arr = arr.columns.to_numpy() - - dw_main = rt_arr.max() - rt_arr.min() - dh_main = im_arr.max() - im_arr.min() - - rt_min, rt_max, im_min, im_max = rt_arr.min(), rt_arr.max(), im_arr.min(), im_arr.max() - - return im_arr, rt_arr, dw_main, dh_main, rt_min, rt_max, im_min, im_max - - @staticmethod - def _prepare_array(arr: pd.DataFrame) -> np.ndarray: - """ - Prepare the array for plotting. Also performs equalization and/or smoothing if specified in the configuration. - - Args: - arr (pd.DataFrame): The input DataFrame. - - Returns: - np.ndarray: The prepared array. - """ - arr = arr.to_numpy() - arr[np.isnan(arr)] = 0 - - return arr - - @staticmethod - def _get_data_as_two_dimenstional_array(data: pd.DataFrame) -> np.ndarray: - feat_arrs = {ion_trace:grp_df.pivot_table(index='im', columns='rt', values='int', aggfunc="sum") for ion_trace, grp_df in data.groupby('Annotation')} - return feat_arrs - - @staticmethod - def _integrate_data_along_dim(data: pd.DataFrame, col_dim: str) -> pd.DataFrame: - # TODO: Double check which columns are required - grouped = data.fillna({'native_id': 'NA'}).groupby(['native_id', 'ms_level', 'precursor_mz', 'Annotation', 'product_mz', col_dim])['int'].sum().reset_index() - return grouped - - ### assume that the chromatogram have the following columns: intensity, time - ### optional column to be used: annotation - ### assume that the chromatogramFeatures have the following columns: left_width, right_width (optional columns: area, q_value) - def plot(self, chromatogram, chromatogramFeatures = None, **kwargs): - #### General Data Processing before plotting #### - # sort by q_value if available - if chromatogramFeatures is not None: - if "q_value" in chromatogramFeatures.columns: - chromatogramFeatures = chromatogramFeatures.sort_values(by="q_value") - - #### compute apex intensity for features if not already computed - if chromatogramFeatures is not None: - if "apexIntensity" not in chromatogramFeatures.columns: - all_apexIntensity = [] - for _, feature in chromatogramFeatures.iterrows(): - apexIntensity = 0 - for _, row in chromatogram.iterrows(): - if row["rt"] >= feature["leftWidth"] and row["rt"] <= feature["rightWidth"] and row["int"] > apexIntensity: - apexIntensity = row["int"] - all_apexIntensity.append(apexIntensity) - - chromatogramFeatures["apexIntensity"] = all_apexIntensity - - # compute colormaps based on the number of transitions and features - self.main_palette = self.generate_colors(self.config.colormap, len(chromatogram["Annotation"].unique()) if 'Annotation' in chromatogram.columns else 1) - self.feature_palette = self.generate_colors(self.config.featureConfig.colormap, len(chromatogramFeatures)) if chromatogramFeatures is not None else None - - return super().plot(chromatogram, chromatogramFeatures, **kwargs) - - def _plotBokeh(self, data: DataFrame, chromatogramFeatures: DataFrame = None): - - def _plotLines(self, data: pd.DataFrame, chromatogramFeatures: DataFrame = None): - from bokeh.plotting import figure, show - from bokeh.models import ColumnDataSource, Legend - - # Tooltips for interactive information - TOOLTIPS = [ - ("index", "$index"), - ("Retention Time", "@rt{0.2f}"), - ("Intensity", "@int{0.2f}"), - ("m/z", "@mz{0.4f}") - ] - - if "Annotation" in data.columns: - TOOLTIPS.append(("Annotation", "@Annotation")) - if "product_mz" in data.columns: - TOOLTIPS.append(("Target m/z", "@product_mz{0.4f}")) - - # Create the Bokeh plot - p = figure(title=self.config.title, - x_axis_label=self.config.xlabel, - y_axis_label=self.config.ylabel, - x_axis_location=self.config.x_axis_location, - y_axis_location=self.config.y_axis_location, - width=self.config.width, - height=self.config.height, - tooltips=TOOLTIPS) - - # Create a legend - legend = Legend() - - # Create a list to store legend items - if 'Annotation' in data.columns: - legend_items = [] - i = 0 - for annotation, group_df in data.groupby('Annotation'): - source = ColumnDataSource(group_df) - line = p.line(x=self.config.x_axis_col, y=self.config.y_axis_col, source=source, line_width=self.config.lineWidth, line_color=self.main_palette[i], line_dash=self.config.lineStyle) - legend_items.append((annotation, [line])) - i+=1 - - # Add legend items to the legend - legend.items = legend_items - - # Add the legend to the plot - p.add_layout(legend, self.config.legend.loc) - - p.legend.click_policy=self.config.legend.onClick - p.legend.title = self.config.legend.title - p.legend.label_text_font_size = str(self.config.legend.fontsize) + 'pt' - - else: - source = ColumnDataSource(data) - line = p.line(x=self.config.x_axis_col, y=self.config.y_axis_col, source=source, line_width=self.config.lineWidth, line_color=self.main_palette[0], line_alpha=0.5, line_dash=self.config.lineStyle) - # Customize the plot - p.grid.visible = self.config.grid - p.toolbar_location = "above" #NOTE: This is hardcoded - - ##### Plotting chromatogram features ##### - if chromatogramFeatures is not None: - - for idx, (_, feature) in enumerate(chromatogramFeatures.iterrows()): - - leftWidth_line = p.line(x=[feature['leftWidth']] * 2, y=[0, feature['apexIntensity']], width=self.config.featureConfig.lineWidth, color=self.feature_palette[idx], line_dash=self.config.featureConfig.lineStyle ) - rightWidth_line = p.line(x=[feature['rightWidth']] * 2, y=[0, feature['apexIntensity']], width=self.config.featureConfig.lineWidth, color=self.feature_palette[idx], line_dash = self.config.featureConfig.lineStyle) - - if self.config.featureConfig.legend.show: - feature_legend_items = [] - if "q_value" in chromatogramFeatures.columns: - legend_msg = f'Feature {idx} (q={feature["q_value"]:.2f})' - else: - legend_msg = f'Feature {idx}' - feature_legend_items.append((legend_msg, [leftWidth_line])) - - legend = Legend(items=feature_legend_items, title=self.config.legend.title ) - p.add_layout(legend, self.config.legend.loc) - - if self.config.show: - show(p) - - return p - - def _plotHeatmap(self, data: pd.DataFrame): - from bokeh.plotting import figure, show - from bokeh.models import ColumnDataSource, Legend, HoverTool, CrosshairTool - - AFMHOT_CMAP = [self.rgb_to_hex(cm.afmhot_r(i)[:3]) for i in range(256)] - - # Get the data as a two-dimensional array - feat_arrs = self._get_data_as_two_dimenstional_array(data) - - # Get the data ranges - im_range = data.groupby('Annotation')['im'].agg(lambda x: x.max() - x.min()) - rt_range = data.groupby('Annotation')['rt'].agg(lambda x: x.max() - x.min()) - - p_hm_legends = [] - # Create a legend - legend_hm = Legend() - p = figure(x_range=(data.rt.min(), data.rt.max()), - y_range=(data.im.min(), data.im.max()), - x_axis_label="Retention Time [sec]", - y_axis_label="Ion Mobility", - # y_axis_label=None, - # y_axis_location = None, - width=700, - height=700, - min_border=0 - ) - - for annotation, df_wide in feat_arrs.items(): - arr = self._prepare_array(df_wide) - heatmap_img = p.image(image=[arr], x=data.rt.min(), y=data.im.min(), dw=rt_range.mean(), dh=im_range.min(), palette=AFMHOT_CMAP) - p_hm_legends.append((annotation, [heatmap_img])) - - # Add legend items to the legend - legend_hm.items = p_hm_legends - - # Add the legend to the plot - p.add_layout(legend_hm, 'right') - - hover = HoverTool(renderers=[heatmap_img], tooltips=[("Value", "@image")]) - linked_crosshair = CrosshairTool(dimensions="both") - p.add_tools(hover) - p.add_tools(linked_crosshair) - - p.grid.visible = False - - if self.config.add_marginals: - # Store original config values that need to be changed for marginals - show_org = self.config.show - self.config.show = False - - # Integrate the data along the retention time dimension - rt_integrated = self._integrate_data_along_dim(data, 'rt') - # Generate a lineplot for XIC - self.config.y_axis_location = "right" - p_xic = _plotLines(self, rt_integrated, chromatogramFeatures) - - # Link range of XIC plot with the main plot - p_xic.x_range = p.x_range - p_xic.width = p.width - - # Modify labels - p_xic.title = "Integrated Ion Chromatogram" - # Hide x-axis for grouped plot - p_xic.xaxis.visible = False - - # Make border 0 - p_xic.min_border = 0 - - # Integrate the data along the ion mobility dimension - im_integrated = self._integrate_data_along_dim(data, 'im') - - # Generate a lineplot for XIM - self.config.x_axis_col = 'int' - self.config.y_axis_col = 'im' - self.config.y_axis_location = "left" - self.config.legend.loc = 'below' - p_xim = _plotLines(self, im_integrated) - - # Link y-axis with heatmap - p_xim.y_range = p.y_range - p_xim.height = p.height - - p_xim.legend.orientation = "horizontal" - - # Flip x-axis range - p_xim.x_range.flipped = True - # p_mobi.y_range - - # Modify labels - p_xim.title = "Integrated Ion Mobilogram" - p_xim.title_location = "below" - p_xim.xaxis.axis_label = "Intensity" - p_xim.yaxis.axis_label = "Ion Mobility" - - # Make border 0 - p_xim.min_border = 0 - - # Heatmap mod - p.yaxis.visible = False - - # Construct Marginal Plot - from bokeh.layouts import gridplot - - # Combine the plots into a grid layout - layout = gridplot([[None, p_xic], [p_xim, p]], sizing_mode="stretch_both") - - # Reset the config values to org - self.config.show = show_org - - if self.config.show: - show(layout) - - return layout - - if self.config.show: - show(p) - - return p - - - - if self.config.plot_type == "lineplot": - return _plotLines(self, data, chromatogramFeatures) - elif self.config.plot_type == "heatmap": - return _plotHeatmap(self, data) - else: - raise ValueError(f"Invalid plot type: {type}") - - def _plotPlotly(self, data: DataFrame, chromatogramFeatures: DataFrame = None): - - def _plotLines(self, data: pd.DataFrame, chromatogramFeatures: DataFrame = None): - - import plotly.graph_objects as go - - # Create a trace for each unique annotation - traces = [] - if "Annotation" in data.columns: - for i, (annotation, group_df) in enumerate(data.groupby('Annotation')): - trace = go.Scatter( - x=group_df[self.config.x_axis_col], - y=group_df[self.config.y_axis_col], - mode='lines', - name=annotation, - line=dict( - color=self.main_palette[i], - width=self.config.lineWidth, - dash=self.config.lineStyle - ) - ) - traces.append(trace) - else: - trace = go.Scatter( - x=data[self.config.x_axis_col], - y=data[self.config.y_axis_col], - mode='lines', - name="Transition", - line=dict( - color=self.main_palette[0], - width=self.config.lineWidth, - dash=self.config.lineStyle - )) - traces.append(trace) - - - # Create the Plotly figure - fig = go.Figure(data=traces) - fig.update_layout( - title=self.config.title, - xaxis_title=self.config.xlabel, - yaxis_title=self.config.ylabel, - width=self.config.width, - height=self.config.height, - legend_title="Transition", - legend_font_size=self.config.legend.fontsize - ) - - available_columns = data.columns.tolist() - available_columns = data.columns.tolist() - custom_hover_data = [data[col] for col in ["index", "mz"] if col in available_columns] - - hover_template_parts = [ - "Index: %{customdata[0]}", - "Retention Time: %{x:.2f}", - "Intensity: %{y:.2f}", - ] - - if "mz" in available_columns: - hover_template_parts.append("m/z: %{customdata[1]:.4f}") - custom_hover_data_index = 2 - else: - custom_hover_data_index = 1 - - if "Annotation" in available_columns: - hover_template_parts.append("Annotation: %{customdata[" + str(custom_hover_data_index) + "]}") - custom_hover_data.append(data["Annotation"]) - - hovertemplate = "
".join(hover_template_parts) - - fig.update_traces( - hovertemplate=hovertemplate, - customdata=np.column_stack(custom_hover_data) - ) - - # Customize the plot - fig.update_layout( - plot_bgcolor='white', - paper_bgcolor='white', - xaxis_showgrid=True, - yaxis_showgrid=True, - xaxis_zeroline=False, - yaxis_zeroline=False - ) - - ##### Plotting chromatogram features ##### - if chromatogramFeatures is not None: - for idx, (_, feature) in enumerate(chromatogramFeatures.iterrows()): - feature_group = f"Feature {idx}" - - feature_boundary_box = fig.add_shape(type='rect', - x0=feature['leftWidth'], - y0=0, - x1=feature['rightWidth'], - y1=feature['apexIntensity'], - legendgroup=feature_group, - legendgrouptitle_text="Features", - showlegend=self.config.featureConfig.legend.show, - name=f'Feature {idx}' if "q_value" not in chromatogramFeatures.columns else f'Feature {idx} (q={feature["q_value"]:.2f})', - line=dict( - color=self.feature_palette[idx], - width=self.config.featureConfig.lineWidth, - dash=self.config.featureConfig.lineStyle) - ) - - if self.config.show: - fig.show() - return fig - - def _plotHeatmap(self, data: pd.DataFrame): - import plotly.graph_objects as go - - AFMHOT_CMAP = [self.rgb_to_hex(cm.afmhot_r(i)[:3]) for i in range(256)] - - # Get the data as a two-dimensional array - feat_arrs = self._get_data_as_two_dimenstional_array(data) - - # Create the Plotly figure - fig = go.Figure() - - # Create a trace for each unique annotation - for annotation, df_wide in feat_arrs.items(): - arr = self._prepare_array(df_wide) - fig.add_trace(go.Heatmap(z=arr, x=df_wide.columns, y=df_wide.index, colorscale=AFMHOT_CMAP, coloraxis="coloraxis", name=annotation, showlegend=True)) - - fig.update_layout(coloraxis = {'colorscale':AFMHOT_CMAP}) - - # fig.update_traces(name="Transitions", showlegend=True) - - # Customize the plot - fig.update_layout( - title=self.config.title, - xaxis_title=self.config.xlabel, - yaxis_title=self.config.ylabel, - width=self.config.width, - height=self.config.height, - legend=dict( - orientation="h", # Set the legend orientation to horizontal - y=-0.2, # Adjust the vertical position of the legend - x=0.5, # Adjust the horizontal position of the legend - xanchor="center" # Center the legend horizontally - ) - ) - - if self.config.add_marginals: - from plotly.subplots import make_subplots - - # Store original config values that need to be changed for marginals - show_org = self.config.show - self.config.show = False - - # Integrate the data along the retention time dimension - rt_integrated = self._integrate_data_along_dim(data, 'rt') - # Generate a lineplot for XIC - self.config.y_axis_location = "right" - fig_xic = _plotLines(self, rt_integrated, chromatogramFeatures) - fig_xic.update_layout(title="Integrated Ion Chromatogram") - fig_xic.update_xaxes(visible=False) - - # Integrate the data along the ion mobility dimension - im_integrated = self._integrate_data_along_dim(data, 'im') - # Generate a lineplot for XIM - self.config.x_axis_col = 'int' - self.config.y_axis_col = 'im' - self.config.y_axis_location = "left" - fig_xim = _plotLines(self, im_integrated) - fig_xim.update_layout(title="Integrated Ion Mobilogram") - fig_xim.update_xaxes(range=[0, im_integrated['int'].max()]) - fig_xim.update_yaxes(range=[im_integrated['im'].min(), im_integrated['im'].max()]) - fig_xim.update_layout(xaxis_title="Intensity", yaxis_title="Ion Mobility") - - # Create a figure with subplots - fig_m = make_subplots( - rows=2, cols=2, - shared_xaxes=True, shared_yaxes=True, - vertical_spacing=0, horizontal_spacing=0, - subplot_titles=(None, "Integrated Ion Chromatogram", "Integrated Ion Mobilogram", None), - specs=[[{}, {"type": "xy", "rowspan": 1, "secondary_y":True}], - [{"type": "xy", "rowspan": 1, "secondary_y":False}, {"type": "xy", "rowspan": 1, "secondary_y":False}]] - ) - - # Add the heatmap to the first row - for trace in fig.data: - trace.showlegend = False - trace.legendgroup = trace.name - fig_m.add_trace(trace, row=2, col=2, secondary_y=False) - - # Update the heatmap layout - fig_m.update_layout(fig.layout) - fig_m.update_yaxes(row=2, col=2, secondary_y=False) - - # Add the XIC plot to the second row - for trace in fig_xic.data: - trace.legendgroup = trace.name - fig_m.add_trace(trace, row=1, col=2, secondary_y=True) - - # Update the XIC layout - fig_m.update_layout(fig_xic.layout) - - # Make the y-axis of fig_xic independent - fig_m.update_yaxes(overwrite=True, row=1, col=2, secondary_y=True) - - # Manually adjust the domain of secondary y-axis to only span the first row of the subplot - fig_m['layout']['yaxis3']['domain'] = [0.5, 1.0] - - # Add the XIM plot to the second row - for trace in fig_xim.data: - trace.showlegend = False - trace.legendgroup = trace.name - fig_m.add_trace(trace, row=2, col=1) - - # Update the XIM layout - fig_m.update_layout(fig_xim.layout) - - # Make the x-axis of fig_xim independent - fig_m.update_xaxes(overwrite=True, row=2, col=1) - - # Reverse the x-axis range for the XIM subplot - fig_m.update_xaxes(autorange="reversed", row=2, col=1) - - # Update xaxis properties - fig_m.update_xaxes(title_text="Retention Time [sec]", row=2, col=2) - fig_m.update_xaxes(title_text="Intensity", row=2, col=1) - - # Update yaxis properties - fig_m.update_yaxes(title_text="Intensity", row=1, col=2) - fig_m.update_yaxes(title_text="Ion Mobility", row=2, col=1) - - # Update the layout - fig_m.update_layout( - height=800, - width=1200, - title=self.config.title - ) - - # Reset the config values to org - self.config.show = show_org - - if self.config.show: - fig_m.show() - return fig_m - - if self.config.show: - fig.show() - return fig - - if self.config.plot_type == "lineplot": - return _plotLines(self, data, chromatogramFeatures) - elif self.config.plot_type == "heatmap": - return _plotHeatmap(self, data) - else: - raise ValueError(f"Invalid plot type: {type}") - - def _plotMatplotlib(self, data: DataFrame, chromatogramFeatures: DataFrame = None): - - def _plotLines(self, data: pd.DataFrame, chromatogramFeatures: DataFrame = None, ax=None): - import matplotlib.pyplot as plt - from matplotlib.lines import Line2D - - # Create a figure and axis - if ax is None: - fig, ax = plt.subplots(figsize=(self.config.width/100, self.config.height/100), dpi=100) - else: - fig = ax.get_figure() - - - # Set plot title and axis labels - ax.set_title(self.config.title) - ax.set_xlabel(self.config.xlabel) - ax.set_ylabel(self.config.ylabel) - - # Create a legend - legend_lines = [] - legend_labels = [] - - # Plot each unique annotation - if "Annotation" in data.columns: - for i, (annotation, group_df) in enumerate(data.groupby('Annotation')): - line, = ax.plot(group_df[self.config.x_axis_col], group_df[self.config.y_axis_col], color=self.main_palette[i], linewidth=self.config.lineWidth, ls=self.config.lineStyle) - legend_lines.append(line) - legend_labels.append(annotation) - - # Add legend - matplotlibLegendLoc= LegendConfig._matplotlibLegendLocationMapper(self.config.legend.loc) - legend = ax.legend(legend_lines, legend_labels, loc=matplotlibLegendLoc, bbox_to_anchor=self.config.legend.bbox_to_anchor, title=self.config.legend.title, prop={'size': self.config.legend.fontsize}) - legend.get_title().set_fontsize(str(self.config.legend.fontsize)) - - else: # only one transition - line, = ax.plot(data[self.config.x_axis_col], data[self.config.y_axis_col], color=self.main_palette[0], linewidth=self.config.lineWidth, ls=self.config.lineStyle) - - # Customize the plot - ax.grid(self.config.grid) - - ## add 10% padding to the plot - padding = (data[self.config.y_axis_col].max() - data[self.config.y_axis_col].min() ) * 0.1 - ax.set_xlim(data[self.config.x_axis_col].min(), data[self.config.x_axis_col].max()) - ax.set_ylim(data[self.config.y_axis_col].min(), data[self.config.y_axis_col].max() + padding) - - ##### Plotting chromatogram features ##### - if chromatogramFeatures is not None: - ax.add_artist(legend) - - for idx, (_, feature) in enumerate(chromatogramFeatures.iterrows()): - - ax.vlines(x=feature['leftWidth'], ymin=0, ymax=feature['apexIntensity'], lw=self.config.featureConfig.lineWidth, color=self.feature_palette[idx], ls=self.config.featureConfig.lineStyle) - ax.vlines(x=feature['rightWidth'], ymin=0, ymax=feature['apexIntensity'], lw=self.config.featureConfig.lineWidth, color=self.feature_palette[idx], ls=self.config.featureConfig.lineStyle) - - if self.config.featureConfig.legend.show: - custom_lines = [Line2D([0], [0], color=self.feature_palette[i], lw=2) for i in range(len(chromatogramFeatures))] - if "q_value" in chromatogramFeatures.columns: - legend_labels = [f'Feature {i} (q={feature["q_value"]:.2f})' for i, (_,feature) in enumerate(chromatogramFeatures.iterrows())] - else: - legend_labels = [f'Feature {i}' for i in range(len(chromatogramFeatures))] - - if self.config.featureConfig.legend.show: - - matplotlibLegendLoc= LegendConfig._matplotlibLegendLocationMapper(self.config.featureConfig.legend.loc) - ax.legend(custom_lines, legend_labels, loc=matplotlibLegendLoc, bbox_to_anchor=self.config.featureConfig.legend.bbox_to_anchor, title=self.config.featureConfig.legend.title) - - if self.config.show: - plt.show() - return fig - - def _plotHeatmap(self, data: pd.DataFrame): - import matplotlib.pyplot as plt - - if not self.config.add_marginals: - # Create a figure and axis - fig, ax = plt.subplots(figsize=(self.config.width/100, self.config.height/100), dpi=200, constrained_layout=True) - - # Plot each unique annotation - for annotation, df_group in data.groupby("Annotation"): - x = df_group.rt - y = df_group.im - values = df_group.int - - scatter = ax.scatter(x, y, c=values, cmap='afmhot_r', marker='s', s=20, edgecolors='none') - - # Customize the plot - ax.set_title(self.config.title) - ax.set_xlabel("Retention Time [sec]") - ax.set_ylabel("Ion Mobility") - - else: - # Store original config values that need to be changed for marginals - show_org = self.config.show - self.config.show = False - - # Create a figure and axis - fig, ax = plt.subplots(2, 2, figsize=(self.config.width/100, self.config.height/100), dpi=200) - - # Plot each unique annotation - for annotation, df_group in data.groupby("Annotation"): - x = df_group.rt - y = df_group.im - values = df_group.int - - scatter = ax[1, 1].scatter(x, y, c=values, cmap='afmhot_r', marker='s', s=20, edgecolors='none') - - # Customize the plot - ax[1, 1].set_title(None) - ax[1, 1].set_xlabel("Retention Time [sec]") - ax[1, 1].set_ylabel(None) - ax[1, 1].set_yticklabels([]) - ax[1, 1].set_yticks([]) - - # Integrate the data along the retention time dimension - rt_integrated = self._integrate_data_along_dim(data, 'rt') - _plotLines(self, rt_integrated, chromatogramFeatures, ax=ax[0, 1]) - # Generate a lineplot for XIC - # ax[0, 1].plot(rt_integrated['rt'], rt_integrated['int']) - # ax[0, 1].set_title("Integrated Ion Chromatogram") - ax[0, 1].set_title(None) - ax[0, 1].set_xlabel(None) - ax[0, 1].set_xticklabels([]) - ax[0, 1].set_xticks([]) - ax[0, 1].set_ylabel("Intensity") - ax[0, 1].yaxis.set_ticks_position('right') - ax[0, 1].yaxis.set_label_position('right') - ax[0, 1].yaxis.tick_right() - ax[0, 1].legend_ = None - - # Integrate the data along the ion mobility dimension - im_integrated = self._integrate_data_along_dim(data, 'im') - self.config.x_axis_col = 'int' - self.config.y_axis_col = 'im' - self.config.legend.loc = 'below' - # self.config.y_axis_location = "left" - # self.config.legend.loc = 'below' - _plotLines(self, im_integrated, ax=ax[1, 0]) - # Generate a lineplot for XIM - # ax[1, 0].plot(im_integrated['int'], im_integrated['im']) - ax[1, 0].invert_xaxis() - ax[1, 0].set_title(None) - ax[1, 0].set_xlabel("Intensity") - ax[1, 0].set_ylabel("Ion Mobility") - ax[1, 0].legend_ = None - - - # Hide the first subplot - ax[0, 0].axis('off') - - # Adjust the layout - plt.subplots_adjust(wspace=0, hspace=0) - # plt.tight_layout() - - # Reset the config values to org - self.config.show = show_org - - if self.config.show: - plt.show() - return fig - - if self.config.plot_type == "lineplot": - return _plotLines(self, data, chromatogramFeatures) - elif self.config.plot_type == "heatmap": - return _plotHeatmap(self, data) - else: - raise ValueError(f"Invalid plot type: {type}") - -# ============================================================================= # -## FUNCTIONAL API ## -# ============================================================================= # -def plotChromatogram(chromatogram: pd.DataFrame, - chromatogram_features: pd.DataFrame = None, - plot_type: str = "lineplot", - add_marginals: bool = False, - title: str = "Chromatogram Plot", - show_plot: bool = True, - ion_mobility: bool = False, - width: int = 500, - height: int = 500, - engine: Literal['PLOTLY', 'BOKEH', 'MATPLOTLIB'] = 'PLOTLY', - **kwargs): - """ - Plot a Chromatogram from a MSChromatogram Object - - Args: - chromatogram (DataFrame): DataFrame containing chromatogram data - chromatogram_features (DataFrame, optional): DataFrame containing chromatogram features. Defaults to None. - plot_type (str, optional): Type of plot to generate. Defaults to "lineplot". Can be either "lineplot" or "heatmap". - add_marginals (bool, optional): If True, adds marginal plots for integrated ion chromatogram and ion mobilogram. Defaults to False. - title (str, optional): title of plot. Defaults to "Chromatogram Plot". - show_plot (bool, optional): If True, shows the plot. Defaults to True. - ion_mobility (bool, optional): If True, plots a heatmap of Retention Time vs ion mobility with intensity as the color. Defaults to False. - width (int, optional): width of the figure. Defaults to 500. - height (int, optional): height of the figure. Defaults to 500. - engine (Literal['PLOTLY', 'BOKEH'], optional): Plotting engine to use. Defaults to 'PLOTLY'. Can be either 'PLOTLY' or 'BOKEH' - - Returns: - PLOTLY figure or BOKEH figure depending on engine - """ - - - config = ChromatogramPlotterConfig(title=title, show=show_plot, ion_mobility=ion_mobility, width=width, height=height, plot_type=plot_type, add_marginals=add_marginals, engine=engine, **kwargs) - - plotter = ChromatogramPlotter(config) - return plotter.plot(chromatogram=chromatogram, chromatogramFeatures=chromatogram_features) - diff --git a/pyopenms_viz/MSExperimentPlotter.py b/pyopenms_viz/MSExperimentPlotter.py deleted file mode 100644 index ed6ef1a5..00000000 --- a/pyopenms_viz/MSExperimentPlotter.py +++ /dev/null @@ -1,332 +0,0 @@ -from dataclasses import dataclass -from typing import Literal, Union - -import matplotlib.pyplot as plt -import pandas as pd -import plotly.graph_objects as go -from bokeh.models import ColorBar, ColumnDataSource, HoverTool, PrintfTickFormatter -from bokeh.palettes import Plasma256 -from bokeh.plotting import figure -from bokeh.transform import linear_cmap - -from .BasePlotter import Colors, _BasePlotter, _BasePlotterConfig - - -@dataclass(kw_only=True) -class MSExperimentPlotterConfig(_BasePlotterConfig): - bin_peaks: Union[Literal["auto"], bool] = "auto" - num_RT_bins: int = 50 - num_mz_bins: int = 50 - plot3D: bool = False - - -class MSExperimentPlotter(_BasePlotter): - def __init__(self, config: MSExperimentPlotterConfig, **kwargs) -> None: - """ - Initialize the MSExperimentPlotter with a given configuration and optional parameters. - - Args: - config (MSExperimentPlotterConfig): Configuration settings for the spectrum plotter. - **kwargs: Additional keyword arguments for customization. - """ - super().__init__(config=config, **kwargs) - - def _prepare_data(self, exp: pd.DataFrame) -> pd.DataFrame: - """Prepares data for plotting based on configuration (binning, relative intensity, hover text).""" - if self.config.bin_peaks == True or ( - exp.shape[0] > self.config.num_mz_bins * self.config.num_RT_bins - and self.config.bin_peaks == "auto" - ): - exp["mz"] = pd.cut(exp["mz"], bins=self.config.num_mz_bins) - exp["RT"] = pd.cut(exp["RT"], bins=self.config.num_RT_bins) - - # Group by x and y bins and calculate the mean intensity within each bin - exp = ( - exp.groupby(["mz", "RT"], observed=True) - .agg({"inty": "mean"}) - .reset_index() - ) - exp["mz"] = exp["mz"].apply(lambda interval: interval.mid).astype(float) - exp["RT"] = exp["RT"].apply(lambda interval: interval.mid).astype(float) - exp = exp.fillna(0) - else: - self.config.bin_peaks = False - - if self.config.relative_intensity: - exp["inty"] = exp["inty"] / max(exp["inty"]) * 100 - - exp["hover_text"] = exp.apply( - lambda x: f"m/z: {round(x['mz'], 6)}
RT: {round(x['RT'], 2)}
intensity: {int(x['inty'])}", - axis=1, - ) - - return exp.sort_values("inty") - - def _plotMatplotlib3D( - self, - exp: pd.DataFrame, - ) -> plt.Figure: - """Plot 3D peak map with mz, RT and intensity dimensions. Colored peaks based on intensity.""" - fig = plt.figure( - figsize=(self.config.width / 100, self.config.height / 100), - layout="constrained", - ) - ax = fig.add_subplot(111, projection="3d") - - if self.config.title: - ax.set_title(self.config.title, fontsize=12, loc="left") - ax.set_xlabel( - self.config.ylabel, - fontsize=9, - labelpad=-2, - color=Colors["DARKGRAY"], - style="italic", - ) - ax.set_ylabel( - self.config.xlabel, - fontsize=9, - labelpad=-2, - color=Colors["DARKGRAY"], - ) - ax.set_zlabel("intensity", fontsize=10, color=Colors["DARKGRAY"], labelpad=-2) - for axis in ("x", "y", "z"): - ax.tick_params(axis=axis, labelsize=8, pad=-2, colors=Colors["DARKGRAY"]) - - ax.set_box_aspect(aspect=None, zoom=0.88) - ax.ticklabel_format(axis="z", style="sci", useMathText=True) - ax.grid(color="#FF0000", linewidth=0.8) - ax.xaxis.pane.fill = False - ax.yaxis.pane.fill = False - ax.zaxis.pane.fill = False - ax.view_init(elev=25, azim=-45, roll=0) - - # Plot lines to the bottom with colored based on inty - for i in range(len(exp)): - ax.plot( - [exp["RT"].iloc[i], exp["RT"].iloc[i]], - [exp["inty"].iloc[i], 0], - [exp["mz"].iloc[i], exp["mz"].iloc[i]], - zdir="x", - color=plt.cm.magma_r(exp["inty"].iloc[i] / exp["inty"].max()), - ) - return fig - - def _plotMatplotlib2D( - self, - exp: pd.DataFrame, - ) -> plt.Figure: - """Plot 2D peak map with mz and RT dimensions. Colored peaks based on intensity.""" - if self.config.plot3D: - return self._p - fig, ax = plt.subplots( - figsize=(self.config.width / 100, self.config.height / 100) - ) - if self.config.title: - ax.set_title(self.config.title, fontsize=12, loc="left", pad=20) - ax.set_xlabel(self.config.xlabel, fontsize=10, color=Colors["DARKGRAY"]) - ax.set_ylabel( - self.config.ylabel, fontsize=10, style="italic", color=Colors["DARKGRAY"] - ) - ax.xaxis.label.set_color(Colors["DARKGRAY"]) - ax.tick_params(axis="x", colors=Colors["DARKGRAY"]) - ax.yaxis.label.set_color(Colors["DARKGRAY"]) - ax.tick_params(axis="y", colors=Colors["DARKGRAY"]) - ax.spines[["right", "top"]].set_visible(False) - - scatter = ax.scatter( - exp["RT"], - exp["mz"], - c=exp["inty"], - cmap="magma_r", - s=20, - marker="s", - ) - if self.config.show_legend: - cb = fig.colorbar(scatter, aspect=40) - cb.outline.set_visible(False) - if self.config.relative_intensity: - cb.ax.yaxis.set_major_formatter( - plt.FuncFormatter(lambda x, _: f"{int(x)}%") - ) - else: - cb.formatter.set_powerlimits((0, 0)) - cb.formatter.set_useMathText(True) - return fig - - def _plotMatplotlib( - self, - exp: pd.DataFrame, - ) -> plt.Figure: - """Prepares data and returns Matplotlib 2D or 3D plot.""" - exp = self._prepare_data(exp) - if self.config.plot3D: - return self._plotMatplotlib3D(exp) - return self._plotMatplotlib2D(exp) - - def _plotBokeh2D( - self, - exp: pd.DataFrame, - ) -> figure: - """Plot 2D peak map with mz and RT dimensions. Colored peaks based on intensity.""" - # Initialize figure - p = figure( - title=self.config.title, - x_axis_label=self.config.xlabel, - y_axis_label=self.config.ylabel, - width=self.config.width, - height=self.config.height, - ) - - p.grid.grid_line_color = None - p.border_fill_color = None - p.outline_line_color = None - - mapper = linear_cmap( - field_name="inty", - palette=Plasma256[::-1], - low=exp["inty"].min(), - high=exp["inty"].max(), - ) - source = ColumnDataSource(exp) - p.scatter( - x="RT", - y="mz", - size=6, - source=source, - color=mapper, - marker="square", - ) - # if not self.config.bin_peaks: - hover = HoverTool( - tooltips=""" -
- @hover_text{safe} -
- """ - ) - p.add_tools(hover) - if self.config.show_legend: - # Add color bar - color_bar = ColorBar( - color_mapper=mapper["transform"], - width=8, - location=(0, 0), - formatter=PrintfTickFormatter(format="%e"), - ) - p.add_layout(color_bar, "right") - - return p - - def _plotBokeh( - self, - exp: pd.DataFrame, - ) -> figure: - """Prepares data and returns Bokeh 2D plot.""" - exp = self._prepare_data(exp) - return self._plotBokeh2D(exp) - - def _plotPlotly2D( - self, - exp: pd.DataFrame, - ) -> go.Figure: - """Plot 2D peak map with mz and RT dimensions. Colored peaks based on intensity.""" - layout = go.Layout( - title=dict(text=self.config.title), - xaxis=dict(title=self.config.xlabel), - yaxis=dict(title=self.config.ylabel), - showlegend=self.config.show_legend, - template="simple_white", - dragmode="select", - height=self.config.height, - width=self.config.width, - ) - fig = go.Figure(layout=layout) - fig.add_trace( - go.Scattergl( - name="peaks", - x=exp["RT"], - y=exp["mz"], - mode="markers", - marker=dict( - color=exp["inty"], - colorscale="sunset", - size=8, - symbol="square", - colorbar=( - dict(thickness=8, outlinewidth=0, tickformat=".0e") - if self.config.show_legend - else None - ), - ), - hovertext=exp["hover_text"],# if not self.config.bin_peaks else None, - hoverinfo="text", - showlegend=False, - ) - ) - return fig - - def _plotPlotly( - self, - exp: pd.DataFrame, - ) -> go.Figure: - """Prepares data and returns Plotly 2D plot.""" - exp = self._prepare_data(exp) - return self._plotPlotly2D(exp) - -# ============================================================================= # -## FUNCTIONAL API ## -# ============================================================================= # - - -def plotMSExperiment( - exp: pd.DataFrame, - plot3D: bool = False, - relative_intensity: bool = False, - bin_peaks: Union[Literal["auto"], bool] = "auto", - num_RT_bins: int = 50, - num_mz_bins: int = 50, - width: int = 750, - height: int = 500, - title: str = "Peak Map", - xlabel: str = "RT (s)", - ylabel: str = "m/z", - show_legend: bool = False, - engine: Literal["PLOTLY", "BOKEH", "MATPLOTLIB"] = "PLOTLY", -): - """ - Plots a Spectrum from an MSSpectrum object - - Args: - spectrum (pd.DataFrame): OpenMS MSSpectrum Object - plot3D: (bool = False, optional): Plot peak map 3D with peaks colored based on intensity. Disables colorbar legend. Works with "MATPLOTLIB" engine only. Defaults to False. - relative_intensity (bool, optional): If true, plot relative intensity values. Defaults to False. - bin_peaks: (Union[Literal["auto"], bool], optional): Bin peaks to reduce complexity and improve plotting speed. Hovertext disabled if activated. If set to "auto" any MSExperiment with more then num_RT_bins x num_mz_bins peaks will be binned. Defaults to "auto". - num_RT_bins: (int, optional): Number of bins in RT dimension. Defaults to 50. - num_mz_bins: (int, optional): Number of bins in m/z dimension. Defaults to 50. - width (int, optional): Width of plot. Defaults to 500px. - height (int, optional): Height of plot. Defaults to 500px. - title (str, optional): Plot title. Defaults to "Spectrum Plot". - xlabel (str, optional): X-axis label. Defaults to "m/z". - ylabel (str, optional): Y-axis label. Defaults to "intensity" or "ion mobility". - show_legend (int, optional): Show legend. Defaults to False. - engine (Literal['PLOTLY', 'BOKEH', 'MATPLOTLIB'], optional): Plotting engine to use. Defaults to 'PLOTLY' can be either 'PLOTLY', 'BOKEH' or 'MATPLOTLIB'. - - Returns: - Plot: The generated plot using the specified engine. - """ - config = MSExperimentPlotterConfig( - plot3D=plot3D, - relative_intensity=relative_intensity, - bin_peaks=bin_peaks, - num_RT_bins=num_RT_bins, - num_mz_bins=num_mz_bins, - width=width, - height=height, - title=title, - xlabel=xlabel, - ylabel=ylabel, - show_legend=show_legend, - engine=engine, - ) - plotter = MSExperimentPlotter(config) - return plotter.plot(exp.copy()) diff --git a/pyopenms_viz/SpectrumPlotter.py b/pyopenms_viz/SpectrumPlotter.py deleted file mode 100644 index 5d7d5808..00000000 --- a/pyopenms_viz/SpectrumPlotter.py +++ /dev/null @@ -1,622 +0,0 @@ -import re -from dataclasses import dataclass -from itertools import cycle -from typing import Literal, Optional, Union - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import plotly.graph_objects as go -from bokeh.models import ColorBar, ColumnDataSource, HoverTool, Label, Span -from bokeh.palettes import Plasma256 -from bokeh.plotting import figure -from bokeh.transform import linear_cmap - -from .BasePlotter import Colors, _BasePlotter, _BasePlotterConfig - - -@dataclass(kw_only=True) -class SpectrumPlotterConfig(_BasePlotterConfig): - ion_mobility: bool = False - annotate_mz: bool = False - annotate_ions: bool = False - annotate_sequence: bool = False - mirror_spectrum: bool = (False,) - custom_peak_color: bool = (False,) - custom_annotation_color: bool = (False,) - custom_annotation_text: bool = False - - -class SpectrumPlotter(_BasePlotter): - def __init__(self, config: SpectrumPlotterConfig, **kwargs) -> None: - """ - Initialize the SpectrumPlotter with a given configuration and optional parameters. - - Args: - config (SpectrumPlotterConfig): Configuration settings for the spectrum plotter. - **kwargs: Additional keyword arguments for customization. - """ - super().__init__(config=config, **kwargs) - # If y-axis label is default ("intensity") and ion_mobility is True, update label - if self.config.ylabel == "intensity" and self.config.ion_mobility: - self.config.ylabel = "ion mobility" - - def _prepare_data( - self, - spectrum: Union[pd.DataFrame, list[pd.DataFrame]], - reference_spectrum: Union[pd.DataFrame, list[pd.DataFrame], None], - ) -> tuple[list, list]: - """Prepares data for plotting based on configuration (ensures list format for input spectra, relative intensity, hover text).""" - - # Ensure input spectra dataframes are in lists - if not isinstance(spectrum, list): - spectrum = [spectrum] - - if reference_spectrum is None: - reference_spectrum = [] - elif not isinstance(reference_spectrum, list): - reference_spectrum = [reference_spectrum] - # Convert to relative intensity if required - if self.config.relative_intensity or self.config.mirror_spectrum: - combined_spectra = spectrum + ( - reference_spectrum if reference_spectrum else [] - ) - for df in combined_spectra: - df["intensity"] = df["intensity"] / df["intensity"].max() * 100 - # Add hover text - for spec in spectrum+reference_spectrum: - spec["hover_text_ion_mobility"] = spec.apply( - lambda x: f"{x['native_id']}
m/z: {x['mz']}
ion mobility: {x['ion_mobility']}
intensity: {x['intensity']}", - axis=1, - ) - return spectrum, reference_spectrum - - def _get_ion_color_annotation(self, annotation: str) -> str: - """Retrieve the color associated with a specific ion annotation from a predefined colormap.""" - colormap = { - "a": Colors["PURPLE"], - "b": Colors["BLUE"], - "c": Colors["LIGHTBLUE"], - "x": Colors["YELLOW"], - "y": Colors["RED"], - "z": Colors["ORANGE"], - } - for key in colormap.keys(): - # Exact matches - if annotation == key: - return colormap[key] - # Fragment ions via regex - x = re.search(r"^[abcxyz]{1}[0-9]*[+-]$", annotation) - if x: - return colormap[annotation[0]] - return Colors["DARKGRAY"] - - def _get_peak_color(self, default_color: str, peak: pd.Series) -> str: - """Determine the color of a peak based on custom settings or annotation.""" - if self.config.custom_peak_color and "color_peak" in peak: - return peak["color_peak"] - - if self.config.annotate_ions: - return self._get_annotation_color(peak, default_color) - - return default_color - - def _get_annotation_text(self, peak: pd.Series) -> str: - """Generate the annotation text for a given peak based on the configuration.""" - if "custom_annotation" in peak and self.config.custom_annotation_text: - return peak["custom_annotation"] - - texts = [] - - if self.config.annotate_ions and peak["ion_annotation"] != "none": - texts.append(peak["ion_annotation"]) - - if self.config.annotate_sequence and peak["sequence"]: - texts.append(peak["sequence"]) - - if self.config.annotate_mz: - texts.append(str(peak["mz"])) - - return "
".join(texts) - - def _get_annotation_color( - self, peak: pd.Series, fallback_color: str = "black" - ) -> str: - """Determine the color for annotations based on custom settings or ion type.""" - if "color_annotation" in peak and self.config.custom_annotation_color: - return peak["color_annotation"] - - if self.config.annotate_ions: - return self._get_ion_color_annotation(peak["ion_annotation"]) - - if self.config.custom_peak_color and "color_peak" in peak: - return peak["color_peak"] - - return fallback_color - - def _get_relative_intensity_ticks(self) -> tuple[list[int], list[str]]: - """Generate the ticks and labels for relative intensity on the y-axis.""" - ticks = [0, 25, 50, 75, 100] - labels = ["0%", "25%", "50%", "75%", "100%"] - - if self.config.mirror_spectrum: - mirror_ticks = [-100, -75, -50, -25] - mirror_labels = ["-100%", "-75%", "-50%", "-25%"] - ticks = mirror_ticks + ticks - labels = mirror_labels + labels - - return ticks, labels - - - def _combine_sort_spectra_by_intensity( - self, spectra: list[pd.DataFrame] - ) -> pd.DataFrame: - """Combine and sort spectra by intensity.""" - combined_df = pd.concat(spectra).reset_index(drop=True) - sorted_df = combined_df.sort_values(by="intensity").reset_index(drop=True) - return sorted_df - - def _plotMatplotlib( - self, - spectrum: Union[pd.DataFrame, list[pd.DataFrame]], - reference_spectrum: Optional[Union[pd.DataFrame, list[pd.DataFrame]]] = None, - ) -> plt.Figure: - """Plot the spectrum using Matplotlib.""" - - def plot_spectrum(ax, df, color, mirror=False): - for i, peak in df.iterrows(): - intensity = -peak["intensity"] if mirror else peak["intensity"] - peak_color = self._get_peak_color(color, peak) - ax.plot( - [peak["mz"], peak["mz"]], - [0, intensity], - color=peak_color, - linewidth=1.5, - label=peak["native_id"] if i == 0 else None, - ) - if any( - [ - self.config.annotate_mz, - self.config.annotate_ions, - self.config.annotate_sequence, - self.config.custom_annotation_text, - ] - ): - text = self._get_annotation_text(peak).replace("
", "\n") - annotation_color = self._get_annotation_color(peak, peak_color) - ax.annotate( - text, - xy=(peak["mz"], intensity), - xytext=(1, 0), - textcoords="offset points", - fontsize=8, - color=annotation_color, - ) - - spectrum, reference_spectrum = self._prepare_data( - spectrum, reference_spectrum - ) - - fig, ax = plt.subplots( - figsize=(self.config.width / 100, self.config.height / 100) - ) - ax.set_title(self.config.title, fontsize=12, loc="left", pad=20) - ax.set_xlabel( - self.config.xlabel, fontsize=10, style="italic", color=Colors["DARKGRAY"] - ) - ax.set_ylabel(self.config.ylabel, fontsize=10, color=Colors["DARKGRAY"]) - ax.xaxis.label.set_color(Colors["DARKGRAY"]) - ax.tick_params(axis="x", colors=Colors["DARKGRAY"]) - ax.yaxis.label.set_color(Colors["DARKGRAY"]) - ax.tick_params(axis="y", colors=Colors["DARKGRAY"]) - ax.spines[["right", "top"]].set_visible(False) - - if self.config.ion_mobility: - df = self._combine_sort_spectra_by_intensity(spectrum) - scatter = ax.scatter( - df["mz"], - df["ion_mobility"], - c=df["intensity"], - cmap="plasma_r", - s=20, - marker="s", - ) - if self.config.show_legend: - cb = fig.colorbar(scatter, aspect=40) - cb.outline.set_visible(False) - return fig - - gs_colors = self._get_n_grayscale_colors( - max(len(spectrum), len(reference_spectrum or [])) - ) - colors = cycle(gs_colors) - - for spec in spectrum: - plot_spectrum(ax, spec, next(colors)) - - if self.config.mirror_spectrum: - colors = cycle(gs_colors) - for ref_spec in reference_spectrum: - plot_spectrum(ax, ref_spec, next(colors), mirror=True) - ax.plot(ax.get_xlim(), [0, 0], color="#EEEEEE", linewidth=1.5) - - ax.set_ylim([0 if not self.config.mirror_spectrum else None, None]) - if self.config.relative_intensity or self.config.mirror_spectrum: - ticks, labels = self._get_relative_intensity_ticks() - ax.set_yticks(ticks) - ax.set_yticklabels(labels) - else: - ax.ticklabel_format(axis="both", style="sci", useMathText=True) - - if self.config.show_legend: - ax.legend(loc="best") - - return fig - - def _plotBokeh( - self, - spectrum: Union[pd.DataFrame, list[pd.DataFrame]], - reference_spectrum: Optional[Union[pd.DataFrame, list[pd.DataFrame]]] = None, - ) -> figure: - """Plot the spectrum using Bokeh.""" - - def plot_spectrum(p, df, color, mirror=False): - """Add spectrum plot to figure including annotations based on configuration, default peak color and mirror spectrum.""" - for i, peak in df.iterrows(): - intensity = -peak["intensity"] if mirror else peak["intensity"] - peak_color = self._get_peak_color(color, peak) - if i == 0 and self.config.show_legend: - p.line( - [peak["mz"], peak["mz"]], - [0, intensity], - line_color=peak_color, - line_width=2, - legend_label=peak["native_id"], - ) - else: - p.line( - [peak["mz"], peak["mz"]], - [0, intensity], - line_color=peak_color, - line_width=2, - ) - if any( - [ - self.config.annotate_mz, - self.config.annotate_ions, - self.config.annotate_sequence, - self.config.custom_annotation_text, - ] - ): - text = self._get_annotation_text(peak).replace("
", "\n") - annotation_color = self._get_annotation_color(peak, peak_color) - label = Label( - x=peak["mz"], - y=intensity, - text=text, - text_font_size="8pt", - text_color=annotation_color, - x_offset=1, - y_offset=0, - ) - p.add_layout(label) - - def _get_ion_mobility_plot(p: figure) -> figure: - """Adds ion mobility traces to figure.""" - df = self._combine_sort_spectra_by_intensity(spectrum) - mapper = linear_cmap( - field_name="intensity", - palette=Plasma256[::-1], - low=df["intensity"].min(), - high=df["intensity"].max(), - ) - source = ColumnDataSource(df) - p.scatter( - x="mz", - y="ion_mobility", - size=6, - source=source, - color=mapper, - marker="square", - ) - hover = HoverTool( - tooltips=""" -
- @hover_text_ion_mobility{safe} -
- """ - ) - p.add_tools(hover) - if self.config.show_legend: - color_bar = ColorBar( - color_mapper=mapper["transform"], width=8, location=(0, 0) - ) - p.add_layout(color_bar, "right") - return p - - spectrum, reference_spectrum = self._prepare_data( - spectrum, reference_spectrum - ) - - # Initialize figure - p = figure( - title=self.config.title, - x_axis_label=self.config.xlabel, - y_axis_label=self.config.ylabel, - width=self.config.width, - height=self.config.height, - ) - - p.grid.grid_line_color = None - p.border_fill_color = None - p.outline_line_color = None - - if self.config.ion_mobility: - return _get_ion_mobility_plot(p) - - gs_colors = self._get_n_grayscale_colors( - max(len(spectrum), len(reference_spectrum or [])) - ) - colors = cycle(gs_colors) - - for spec in spectrum: - plot_spectrum(p, spec, next(colors)) - - if self.config.mirror_spectrum: - colors = cycle(gs_colors) - for ref_spec in reference_spectrum: - plot_spectrum(p, ref_spec, next(colors), mirror=True) - zero_line = Span( - location=0, dimension="width", line_color="#EEEEEE", line_width=1.5 - ) - p.add_layout(zero_line) - - if self.config.relative_intensity or self.config.mirror_spectrum: - ticks, labels = self._get_relative_intensity_ticks() - p.yaxis.ticker = ticks - p.yaxis.major_label_overrides = { - tick: label for tick, label in zip(ticks, labels) - } - else: - p.yaxis.formatter.use_scientific = True - - p.y_range.start = -110 if self.config.mirror_spectrum else 0 - # adjust x-axis limits to not cut peaks and annotations - x_values = [ - glyph.data_source.data["x"][0] - for glyph in p.renderers - if hasattr(glyph, "data_source") - ] - xmin = min(x_values) - xmax = max(x_values) - padding = 0.15 * (xmax - xmin) - p.x_range.end = xmax + padding - return p - - def _plotPlotly( - self, - spectrum: Union[pd.DataFrame, list[pd.DataFrame]], - reference_spectrum: Optional[Union[pd.DataFrame, list[pd.DataFrame]]] = None, - ) -> go.Figure: - """Plot the spectrum using Plotly.""" - - def _create_peak_traces( - spectrum: pd.DataFrame, - line_color: str, - intensity_direction: Literal[1, -1] = 1, - ) -> list[go.Scattergl]: - """Create peak traces based on given line color and orientation (-1 for mirror spectra).""" - return [ - go.Scattergl( - x=[peak["mz"]] * 2, - y=[0, intensity_direction * peak["intensity"]], - mode="lines", - line=dict(color=self._get_peak_color(line_color, peak)), - name=peak["native_id"], - text=f"{peak['native_id']}
m/z: {peak['mz']}
intensity: {peak['intensity']}", - hoverinfo="text", - showlegend=(i == 0), - ) - for i, peak in spectrum.iterrows() - ] - - def _create_annotations( - spectra: list[pd.DataFrame], - intensity_sign: Literal[1, -1] = 1, - ) -> list[dict]: - """Create peak annotations based on configuration.""" - if not any( - [ - self.config.annotate_mz, - self.config.annotate_ions, - self.config.annotate_sequence, - self.config.custom_annotation_text, - ] - ): - return [] - - annotations = [] - colors = cycle(self.gs_colors) - for spectrum in spectra: - for _, peak in spectrum.iterrows(): - text = self._get_annotation_text(peak) - color = self._get_annotation_color(peak, next(colors)) - annotations.append( - dict( - x=peak["mz"], - y=intensity_sign * peak["intensity"], - text=text, - showarrow=False, - xanchor="left", - font=dict( - family="Open Sans Mono, monospace", - size=12, - color=color, - ), - ) - ) - return annotations - - def _get_ion_mobility_plot(fig: go.Figure) -> go.Figure: - """Adds ion mobility traces to figure.""" - df = self._combine_sort_spectra_by_intensity(spectrum) - fig.add_trace( - go.Scattergl( - name="peaks", - x=df["mz"], - y=df["ion_mobility"], - mode="markers", - marker=dict( - color=df["intensity"], - colorscale="sunset", - size=8, - symbol="square", - colorbar=( - dict(thickness=8, outlinewidth=0) - if self.config.show_legend - else None - ), - ), - hovertext=df["hover_text_ion_mobility"], - hoverinfo="text", - showlegend=False, - ) - ) - return fig - - spectrum, reference_spectrum = self._prepare_data( - spectrum, reference_spectrum - ) - - layout = go.Layout( - title=dict(text=self.config.title), - xaxis=dict(title=self.config.xlabel), - yaxis=dict(title=self.config.ylabel, rangemode="tozero"), - showlegend=self.config.show_legend, - template="simple_white", - ) - fig = go.Figure(layout=layout) - - if self.config.ion_mobility: - return _get_ion_mobility_plot(fig) - - self.gs_colors = self._get_n_grayscale_colors( - max(len(spectrum), len(reference_spectrum or [])) - ) - colors = cycle(self.gs_colors) - - traces = [] - for spec in spectrum: - traces += _create_peak_traces(spec, next(colors)) - - if self.config.mirror_spectrum: - colors = cycle(self.gs_colors) - for ref_spec in reference_spectrum: - traces += _create_peak_traces( - ref_spec, next(colors), intensity_direction=-1 - ) - - fig.add_traces(traces) - - annotations = _create_annotations(spectrum, 1) - for annotation in annotations: - fig.add_annotation(annotation) - - if self.config.mirror_spectrum: - annotations = _create_annotations(reference_spectrum, -1) - for annotation in annotations: - fig.add_annotation(annotation) - fig.add_hline(y=0, line_color=Colors["LIGHTGRAY"], line_width=2) - - if self.config.relative_intensity or self.config.mirror_spectrum: - ticks, labels = self._get_relative_intensity_ticks() - fig.update_layout( - yaxis=dict(tickmode="array", tickvals=ticks, ticktext=labels), - yaxis_range=[-110 if self.config.mirror_spectrum else 0, 110], - ) - # adjust x-axis limits to not cut peaks and annotations - x_values = [trace.x for trace in fig.data] - xmin = min([min(values) for values in x_values]) - xmax = max([max(values) for values in x_values]) - padding = 0.15 * (xmax - xmin) - fig.update_layout( - xaxis_range=[ - xmin - 1, - xmax + padding, - ] - ) - - return fig - - -# ============================================================================= # -## FUNCTIONAL API ## -# ============================================================================= # - - -def plotSpectrum( - spectrum: Union[pd.DataFrame, list[pd.DataFrame]], - reference_spectrum: Union[pd.DataFrame, list[pd.DataFrame]] = None, - ion_mobility: bool = False, - annotate_mz: bool = False, - annotate_ions: bool = False, - annotate_sequence: bool = False, - mirror_spectrum: bool = False, - relative_intensity: bool = False, - custom_peak_color: bool = False, - custom_annotation_text: bool = False, - custom_annotation_color: bool = False, - width: int = 750, - height: int = 500, - title: str = "Spectrum Plot", - xlabel: str = "m/z", - ylabel: str = "intensity", - show_legend: bool = False, - engine: Literal["PLOTLY", "BOKEH", "MATPLOTLIB"] = "PLOTLY", -): - """ - Plots a Spectrum from an MSSpectrum object - - Args: - spectrum (Union[pd.DataFrame, list[pd.DataFrame]]): OpenMS MSSpectrum Object - reference_spectrum (Union[pd.DataFrame, list[pd.DataFrame]], optional): Optional OpenMS Spectrum object to plot in mirror or used in annotation. Defaults to None. - ion_mobility (bool, optional): If true, plots spectra (not including reference spectra) as heatmap of m/z vs ion mobility with intensity as color. Defaults to False. - annotate_mz (bool, optional): If true, annotate peaks with m/z values. Defaults to False. - annotate_ions (bool, optional): If true, annotate fragment ions. Defaults to False. - annotate_sequence (bool, optional): Annotate peaks based on sequence provided. Defaults to False - mirror_spectrum (bool, optional): If true, plot mirror spectrum. Defaults to True, if no mirror reference_spectrum is provided, this is ignored. - relative_intensity (bool, optional): If true, plot relative intensity values. Defaults to False. - custom_peak_color (bool, optional): If true, plot peaks with colors from "color_peak" column. - custom_annotation_text (bool, optional): If true, annotate peaks with custom text from "custom_annotation" column. Overwrites all other annotations.Use
for line breaks. - custom_annotation_color (bool, optional): If true, plot annotations with colors from "color_annotation" column. - width (int, optional): Width of plot. Defaults to 500px. - height (int, optional): Height of plot. Defaults to 500px. - title (str, optional): Plot title. Defaults to "Spectrum Plot". - xlabel (str, optional): X-axis label. Defaults to "m/z". - ylabel (str, optional): Y-axis label. Defaults to "intensity" or "ion mobility". - show_legend (int, optional): Show legend. Defaults to False. - engine (Literal['PLOTLY', 'BOKEH', 'MATPLOTLIB'], optional): Plotting engine to use. Defaults to 'PLOTLY' can be either 'PLOTLY', 'BOKEH' OR 'MATPLOTLIB'. - - Returns: - Plot: The generated plot using the specified engine. - """ - config = SpectrumPlotterConfig( - ion_mobility=ion_mobility, - annotate_mz=annotate_mz, - mirror_spectrum=mirror_spectrum, - annotate_sequence=annotate_sequence, - annotate_ions=annotate_ions, - relative_intensity=relative_intensity, - custom_peak_color=custom_peak_color, - custom_annotation_text=custom_annotation_text, - custom_annotation_color=custom_annotation_color, - width=width, - height=height, - title=title, - xlabel=xlabel, - ylabel=ylabel, - show_legend=show_legend, - engine=engine, - ) - plotter = SpectrumPlotter(config) - return plotter.plot(spectrum.copy(), reference_spectrum=reference_spectrum.copy()) diff --git a/pyopenms_viz/__init__.py b/pyopenms_viz/__init__.py index 2df4ba87..ad3a6504 100644 --- a/pyopenms_viz/__init__.py +++ b/pyopenms_viz/__init__.py @@ -1,8 +1,169 @@ -""" -pyopenms_viz is a package for visualizing OpenMS data using Bokeh and Plotly. -""" +from pandas.plotting._core import PlotAccessor +from pandas.core.frame import DataFrame +from typing import Any +from pandas.core.dtypes.generic import ABCDataFrame +import types -from .ChromatogramPlotter import plotChromatogram, ChromatogramPlotterConfig, ChromatogramPlotter -from .SpectrumPlotter import plotSpectrum, SpectrumPlotterConfig, SpectrumPlotter -__all__ = [ 'plotChromatogram', 'ChromatogramPlotterConfig', 'ChromatogramPlotter', 'plotSpectrum', 'SpectrumPlotterConfig', 'SpectrumPlotter'] \ No newline at end of file +class PlotAccessor: + """ + Make plots of MassSpec data using dataframes + + """ + + _common_kinds = ("line", "vline", "scatter") + _msdata_kinds = ("chromatogram", "mobilogram", "spectrum", "feature_heatmap") + _all_kinds = _common_kinds + _msdata_kinds + + def __init__(self, data: DataFrame) -> None: + self._parent = data + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + backend_name = kwargs.get("backend", None) + if backend_name is None: + backend_name = "matplotlib" + + plot_backend = _get_plot_backend(backend_name) + + x, y, kind, kwargs = self._get_call_args( + plot_backend.__name__, self._parent, args, kwargs + ) + + if kind not in self._all_kinds: + raise ValueError( + f"{kind} is not a valid plot kind " + f"Valid plot kinds: {self._all_kinds}" + ) + + # Call the plot method of the selected backend + if "backend" in kwargs: + kwargs.pop("backend") + return plot_backend.plot(self._parent, x=x, y=y, kind=kind, **kwargs) + + @staticmethod + def _get_call_args(backend_name: str, data: DataFrame, args, kwargs): + """ + Get the arguments to pass to the plotting backend. + + Parameters + ---------- + backend_name : str + The name of the backend. + data : DataFrame + The data to plot. + args : tuple + The positional arguments passed to the plotting function. + kwargs : dict + The keyword arguments passed to the plotting function. + + Returns + ------- + dict + The arguments to pass to the plotting backend. + """ + if isinstance(data, ABCDataFrame): + arg_def = [ + ("x", None), + ("y", None), + ("kind", "line"), + ("by", None), + ("subplots", None), + ("sharex", None), + ("sharey", None), + ("height", None), + ("width", None), + ("grid", None), + ("toolbar_location", None), + ("fig", None), + ("title", None), + ("xlabel", None), + ("ylabel", None), + ("x_axis_location", None), + ("y_axis_location", None), + ("line_type", None), + ("line_width", None), + ("min_border", None), + ("show_plot", None), + ("legend", None), + ("feature_config", None), + ("_config", None), + ("backend", backend_name), + ] + else: + raise ValueError( + f"Unsupported data type: {type(data).__name__}, expected DataFrame." + ) + + pos_args = {name: value for (name, _), value in zip(arg_def, args)} + + kwargs = dict(arg_def, **pos_args, **kwargs) + + x = kwargs.pop("x", None) + y = kwargs.pop("y", None) + kind = kwargs.pop("kind", "line") + return x, y, kind, kwargs + + +_backends: dict[str, types.ModuleType] = {} + + +def _load_backend(backend: str) -> types.ModuleType: + """ + Load a plotting backend. + + Parameters + ---------- + backend : str + The identifier for the backend. Either "bokeh", "matplotlib", "plotly", + or a module name. + + Returns + ------- + types.ModuleType + The imported backend. + """ + if backend == "bokeh": + try: + module = importlib.import_module("pyopenms_viz.plotting._bokeh") + except ImportError: + raise ImportError( + "Bokeh is required for plotting when the 'bokeh' backend is selected." + ) from None + return module + + elif backend == "matplotlib": + try: + module = importlib.import_module("pyopenms_viz.plotting._matplotlib") + except ImportError: + raise ImportError( + "Matplotlib is required for plotting when the 'matplotlib' backend is selected." + ) from None + return module + + elif backend == "plotly": + try: + module = importlib.import_module("pyopenms_viz.plotting._plotly") + except ImportError: + raise ImportError( + "Plotly is required for plotting when the 'plotly' backend is selected." + ) from None + return module + + raise ValueError( + f"Could not find plotting backend '{backend}'. Needs to be one of 'bokeh', 'matplotlib', or 'plotly'." + ) + + +def _get_plot_backend(backend: str | None = None): + + backend_str: str = backend or "matplotlib" + + if backend_str in _backends: + return _backends[backend_str] + + module = _load_backend(backend_str) + _backends[backend_str] = module + return module + + +__all__ = ["PlotAccessor"] diff --git a/pyopenms_viz/plotting/_bokeh/__init__.py b/pyopenms_viz/_bokeh/__init__.py similarity index 80% rename from pyopenms_viz/plotting/_bokeh/__init__.py rename to pyopenms_viz/_bokeh/__init__.py index f9030883..38547b99 100644 --- a/pyopenms_viz/plotting/_bokeh/__init__.py +++ b/pyopenms_viz/_bokeh/__init__.py @@ -2,8 +2,9 @@ from typing import TYPE_CHECKING +from ..constants import IS_SPHINX_BUILD -from pyopenms_viz.plotting._bokeh.core import ( +from .core import ( BOKEHLinePlot, BOKEHVLinePlot, BOKEHScatterPlot, @@ -14,7 +15,11 @@ ) if TYPE_CHECKING: - from pyopenms_viz.plotting._bokeh.core import BOKEHPlot + from .core import BOKEHPlot + +if IS_SPHINX_BUILD: + from .core import BOKEH_MSPlotter, BOKEHPlot + PLOT_CLASSES: dict[str, type[BOKEHPlot]] = { "line": BOKEHLinePlot, diff --git a/pyopenms_viz/plotting/_bokeh/core.py b/pyopenms_viz/_bokeh/core.py similarity index 96% rename from pyopenms_viz/plotting/_bokeh/core.py rename to pyopenms_viz/_bokeh/core.py index 68ed23b4..eacb121b 100644 --- a/pyopenms_viz/plotting/_bokeh/core.py +++ b/pyopenms_viz/_bokeh/core.py @@ -25,15 +25,15 @@ LinePlot, VLinePlot, ScatterPlot, - ComplexPlot, + BaseMSPlotter, ChromatogramPlot, MobilogramPlot, FeatureHeatmapPlot, SpectrumPlot, - APPEND_PLOT_DOC + APPEND_PLOT_DOC, ) from .._misc import ColorGenerator -from ...constants import PEAK_BOUNDARY_ICON, FEATURE_BOUNDARY_ICON +from ..constants import PEAK_BOUNDARY_ICON, FEATURE_BOUNDARY_ICON class BOKEHPlot(BasePlotter, ABC): @@ -302,7 +302,7 @@ def plot(cls, fig, data, x, y, by: str | None = None, **kwargs): return fig, legend -class BOKEHComplexPlot(ComplexPlot, BOKEHPlot, ABC): +class BOKEH_MSPlotter(BaseMSPlotter, BOKEHPlot, ABC): def get_line_renderer(self, data, x, y, **kwargs) -> None: return BOKEHLinePlot(data, x, y, **kwargs) @@ -334,7 +334,7 @@ def _create_tooltips(self): return TOOLTIPS, None -class BOKEHChromatogramPlot(BOKEHComplexPlot, ChromatogramPlot): +class BOKEHChromatogramPlot(BOKEH_MSPlotter, ChromatogramPlot): """ Class for assembling a Bokeh extracted ion chromatogram plot """ @@ -363,8 +363,8 @@ def _add_peak_boundaries(self, annotation_data): line_dash=self.feature_config.line_type, line_width=self.feature_config.line_width, ) - if 'name' in annotation_data.columns: - use_name = feature['name'] + if "name" in annotation_data.columns: + use_name = feature["name"] else: use_name = f"Feature {idx}" if "q_value" in annotation_data.columns: @@ -404,7 +404,7 @@ class BOKEHMobilogramPlot(BOKEHChromatogramPlot, MobilogramPlot): pass -class BOKEHSpectrumPlot(BOKEHComplexPlot, SpectrumPlot): +class BOKEHSpectrumPlot(BOKEH_MSPlotter, SpectrumPlot): """ Class for assembling a Bokeh spectrum plot """ @@ -412,7 +412,7 @@ class BOKEHSpectrumPlot(BOKEHComplexPlot, SpectrumPlot): pass -class BOKEHFeatureHeatmapPlot(BOKEHComplexPlot, FeatureHeatmapPlot): +class BOKEHFeatureHeatmapPlot(BOKEH_MSPlotter, FeatureHeatmapPlot): """ Class for assembling a Bokeh feature heatmap plot """ @@ -429,7 +429,7 @@ def create_main_plot(self, x, y, z, class_kwargs, other_kwargs): self.fig = scatterPlot.generate( marker="square", line_color=mapper, fill_color=mapper, **other_kwargs ) - + if self.annotation_data is not None: self._add_box_boundaries(self.annotation_data) @@ -489,7 +489,7 @@ def get_manual_bounding_box_coords(self): "y1": f"{self.y}_1", } ) - + def _add_box_boundaries(self, annotation_data): color_gen = ColorGenerator( colormap=self.feature_config.colormap, n=annotation_data.shape[0] @@ -515,10 +515,10 @@ def _add_box_boundaries(self, annotation_data): color=next(color_gen), line_dash=self.feature_config.line_type, line_width=self.feature_config.line_width, - fill_alpha=0 + fill_alpha=0, ) - if 'name' in annotation_data.columns: - use_name = feature['name'] + if "name" in annotation_data.columns: + use_name = feature["name"] else: use_name = f"Feature {idx}" if "q_value" in annotation_data.columns: @@ -536,4 +536,3 @@ def _add_box_boundaries(self, annotation_data): str(self.feature_config.legend.fontsize) + "pt" ) self.fig.add_layout(legend, self.feature_config.legend.loc) - \ No newline at end of file diff --git a/pyopenms_viz/plotting/_config.py b/pyopenms_viz/_config.py similarity index 100% rename from pyopenms_viz/plotting/_config.py rename to pyopenms_viz/_config.py diff --git a/pyopenms_viz/plotting/_core.py b/pyopenms_viz/_core.py similarity index 97% rename from pyopenms_viz/plotting/_core.py rename to pyopenms_viz/_core.py index 929f0c1c..1e96f131 100644 --- a/pyopenms_viz/plotting/_core.py +++ b/pyopenms_viz/_core.py @@ -10,11 +10,7 @@ from pandas.core.dtypes.common import is_integer from pandas.util._decorators import Appender -from ._config import ( - LegendConfig, - FeatureConfig, - _BasePlotterConfig -) +from ._config import LegendConfig, FeatureConfig, _BasePlotterConfig from ._misc import ColorGenerator @@ -137,7 +133,7 @@ def __init__( self.kind = kind self.by = by self.relative_intensity = relative_intensity - + # Plotting attributes self.subplots = subplots self.sharex = sharex @@ -156,21 +152,21 @@ def __init__( self.line_width = line_width self.min_border = min_border self.show_plot = show_plot - + self.legend = legend self.feature_config = feature_config - + self._config = _config - + if _config is not None: self._update_from_config(_config) - + if self.legend is not None and isinstance(self.legend, dict): self.legend = LegendConfig.from_dict(self.legend) - + if self.feature_config is not None and isinstance(self.feature_config, dict): self.feature_config = FeatureConfig.from_dict(self.feature_config) - + ### get x and y data if self._kind in { "line", @@ -390,7 +386,7 @@ def _kind(self): return "scatter" -class ComplexPlot(BasePlotter, ABC): +class BaseMSPlotter(BasePlotter, ABC): """ Abstract class for complex plots, such as chromatograms and mobilograms which are made up of simple plots such as ScatterPlots, VLines and LinePlots. @@ -423,7 +419,7 @@ def _create_tooltips(self): pass -class ChromatogramPlot(BasePlotter, ABC): +class ChromatogramPlot(BaseMSPlotter, ABC): @property def _kind(self): return "chromatogram" @@ -431,10 +427,10 @@ def _kind(self): def __init__( self, data, x, y, annotation_data: DataFrame | None = None, **kwargs ) -> None: - + # Set default config attributes if not passed as keyword arguments kwargs["_config"] = _BasePlotterConfig(kind=self._kind) - + super().__init__(data, x, y, **kwargs) if annotation_data is not None: @@ -500,13 +496,19 @@ def plot(self, data, x, y, **kwargs): self._modify_y_range((0, self.data[y].max()), (0, 0.1)) -class SpectrumPlot(ComplexPlot, ABC): +class SpectrumPlot(BaseMSPlotter, ABC): @property def _kind(self): return "spectrum" def __init__( - self, data, x, y, reference_spectrum: DataFrame | None = None, mirror_spectrum: bool = False, **kwargs + self, + data, + x, + y, + reference_spectrum: DataFrame | None = None, + mirror_spectrum: bool = False, + **kwargs, ) -> None: # Set default config attributes if not passed as keyword arguments @@ -578,16 +580,24 @@ def _prepare_data( return spectrum, reference_spectrum -class FeatureHeatmapPlot(ComplexPlot, ABC): +class FeatureHeatmapPlot(BaseMSPlotter, ABC): # need to inherit from ChromatogramPlot and SpectrumPlot for get_line_renderer and get_vline_renderer methods respectively @property def _kind(self): return "feature_heatmap" def __init__( - self, data, x, y, z, zlabel=None, add_marginals=False, annotation_data: DataFrame | None = None, **kwargs + self, + data, + x, + y, + z, + zlabel=None, + add_marginals=False, + annotation_data: DataFrame | None = None, + **kwargs, ) -> None: - + # Set default config attributes if not passed as keyword arguments kwargs["_config"] = _BasePlotterConfig(kind=self._kind) @@ -596,11 +606,11 @@ def __init__( self.zlabel = zlabel self.add_marginals = add_marginals - + if annotation_data is not None: self.annotation_data = annotation_data.copy() else: - self.annotation_data = None + self.annotation_data = None super().__init__(data, x, y, z=z, **kwargs) self.plot(x, y, z, **kwargs) @@ -666,11 +676,11 @@ def create_x_axis_plot(self, x, z, class_kwargs) -> "figure": x_config = self._config.copy() x_config.ylabel = self.zlabel x_config.y_axis_location = "right" - x_config.legend.show = True + x_config.legend.show = True x_config.legend.loc = "right" color_gen = ColorGenerator() - + # remove legend from class_kwargs to update legend args for x axis plot class_kwargs.pop("legend", None) class_kwargs.pop("ylabel", None) @@ -697,7 +707,7 @@ def create_y_axis_plot(self, y, z, class_kwargs) -> "figure": y_config.y_axis_location = "left" y_config.legend.show = True y_config.legend.loc = "below" - + # remove legend from class_kwargs to update legend args for y axis plot class_kwargs.pop("legend", None) class_kwargs.pop("xlabel", None) @@ -715,7 +725,7 @@ def create_y_axis_plot(self, y, z, class_kwargs) -> "figure": @abstractmethod def combine_plots(self, x_fig, y_fig): pass - + @abstractmethod def _add_box_boundaries(self, annotation_data): """ diff --git a/pyopenms_viz/plotting/_matplotlib/__init__.py b/pyopenms_viz/_matplotlib/__init__.py similarity index 81% rename from pyopenms_viz/plotting/_matplotlib/__init__.py rename to pyopenms_viz/_matplotlib/__init__.py index 2f873236..3b8a4b58 100644 --- a/pyopenms_viz/plotting/_matplotlib/__init__.py +++ b/pyopenms_viz/_matplotlib/__init__.py @@ -1,8 +1,9 @@ from __future__ import annotations from typing import TYPE_CHECKING +from ..constants import IS_SPHINX_BUILD -from pyopenms_viz.plotting._matplotlib.core import ( +from .core import ( MATPLOTLIBLinePlot, MATPLOTLIBVLinePlot, MATPLOTLIBScatterPlot, @@ -13,7 +14,10 @@ ) if TYPE_CHECKING: - from pyopenms_viz.plotting._matplotlib.core import MATPLOTLIBPlot + from .core import MATPLOTLIBPlot + +if IS_SPHINX_BUILD: + from .core import MATPLOTLIB_MSPlotter, MATPLOTLIBPlot PLOT_CLASSES: dict[str, type[MATPLOTLIBPlot]] = { "line": MATPLOTLIBLinePlot, diff --git a/pyopenms_viz/plotting/_matplotlib/core.py b/pyopenms_viz/_matplotlib/core.py similarity index 95% rename from pyopenms_viz/plotting/_matplotlib/core.py rename to pyopenms_viz/_matplotlib/core.py index 494e47e5..40cbfa77 100644 --- a/pyopenms_viz/plotting/_matplotlib/core.py +++ b/pyopenms_viz/_matplotlib/core.py @@ -15,12 +15,12 @@ LinePlot, VLinePlot, ScatterPlot, - ComplexPlot, + BaseMSPlotter, ChromatogramPlot, MobilogramPlot, SpectrumPlot, FeatureHeatmapPlot, - APPEND_PLOT_DOC + APPEND_PLOT_DOC, ) @@ -259,7 +259,7 @@ def plot( return ax, (legend_lines, legend_labels) -class MATPLOTLIBComplexPlot(ComplexPlot, MATPLOTLIBPlot, ABC): +class MATPLOTLIB_MSPlotter(BaseMSPlotter, MATPLOTLIBPlot, ABC): def get_line_renderer(self, data, x, y, **kwargs) -> None: return MATPLOTLIBLinePlot(data, x, y, **kwargs) @@ -277,8 +277,9 @@ def _create_tooltips(self): # No tooltips for MATPLOTLIB because it is not interactive return None, None + @APPEND_PLOT_DOC -class MATPLOTLIBChromatogramPlot(MATPLOTLIBComplexPlot, ChromatogramPlot): +class MATPLOTLIBChromatogramPlot(MATPLOTLIB_MSPlotter, ChromatogramPlot): """ Class for assembling a matplotlib extracted ion chromatogram plot """ @@ -302,7 +303,7 @@ def _add_peak_boundaries(self, annotation_data): ) legend_items = [] - legend_labels = [] + legend_labels = [] for idx, (_, feature) in enumerate(annotation_data.iterrows()): use_color = next(color_gen) left_vlne = self.fig.vlines( @@ -323,8 +324,8 @@ def _add_peak_boundaries(self, annotation_data): ) legend_items.append(left_vlne) - if 'name' in annotation_data.columns: - use_name = feature['name'] + if "name" in annotation_data.columns: + use_name = feature["name"] else: use_name = f"Feature {idx}" if "q_value" in annotation_data.columns: @@ -360,7 +361,7 @@ class MATPLOTLIBMobilogramPlot(MATPLOTLIBChromatogramPlot, MobilogramPlot): @APPEND_PLOT_DOC -class MATPLOTLIBSpectrumPlot(MATPLOTLIBComplexPlot, SpectrumPlot): +class MATPLOTLIBSpectrumPlot(MATPLOTLIB_MSPlotter, SpectrumPlot): """ Class for assembling a matplotlib spectrum plot """ @@ -368,7 +369,7 @@ class MATPLOTLIBSpectrumPlot(MATPLOTLIBComplexPlot, SpectrumPlot): pass -class MATPLOTLIBFeatureHeatmapPlot(MATPLOTLIBComplexPlot, FeatureHeatmapPlot): +class MATPLOTLIBFeatureHeatmapPlot(MATPLOTLIB_MSPlotter, FeatureHeatmapPlot): """ Class for assembling a matplotlib feature heatmap plot """ @@ -428,7 +429,7 @@ def create_y_axis_plot(self, y, z, class_kwargs) -> "figure": y_config.legend.loc = "below" y_config.legend.orientation = "horizontal" y_config.legend.bbox_to_anchor = (1, -0.4) - + # remove legend from class_kwargs to update legend args for y axis plot class_kwargs.pop("legend", None) class_kwargs.pop("xlabel", None) @@ -460,7 +461,7 @@ def create_main_plot(self, x, y, z, class_kwargs, other_kwargs): cmap="afmhot_r", **other_kwargs, ) - + if self.annotation_data is not None: self._add_box_boundaries(self.annotation_data) @@ -482,12 +483,12 @@ def create_main_plot_marginals(self, x, y, z, class_kwargs, other_kwargs): self.ax_grid[1, 1].set_yticklabels([]) self.ax_grid[1, 1].set_yticks([]) self.ax_grid[1, 1].legend_ = None - + def _add_box_boundaries(self, annotation_data): if self.by is not None: legend = self.fig.get_legend() self.fig.add_artist(legend) - + color_gen = ColorGenerator( colormap=self.feature_config.colormap, n=annotation_data.shape[0] ) @@ -504,15 +505,19 @@ def _add_box_boundaries(self, annotation_data): height = abs(y1 - y0) color = next(color_gen) - custom_lines = Rectangle((x0, y0), width, height, - fill=False, - edgecolor=color, - linestyle=self.feature_config.line_type, - linewidth=self.feature_config.line_width) + custom_lines = Rectangle( + (x0, y0), + width, + height, + fill=False, + edgecolor=color, + linestyle=self.feature_config.line_type, + linewidth=self.feature_config.line_width, + ) self.fig.add_patch(custom_lines) - if 'name' in annotation_data.columns: - use_name = feature['name'] + if "name" in annotation_data.columns: + use_name = feature["name"] else: use_name = f"Feature {idx}" if "q_value" in annotation_data.columns: diff --git a/pyopenms_viz/plotting/_misc.py b/pyopenms_viz/_misc.py similarity index 100% rename from pyopenms_viz/plotting/_misc.py rename to pyopenms_viz/_misc.py diff --git a/pyopenms_viz/plotting/_plotly/__init__.py b/pyopenms_viz/_plotly/__init__.py similarity index 74% rename from pyopenms_viz/plotting/_plotly/__init__.py rename to pyopenms_viz/_plotly/__init__.py index 5655072b..4c820125 100644 --- a/pyopenms_viz/plotting/_plotly/__init__.py +++ b/pyopenms_viz/_plotly/__init__.py @@ -1,8 +1,9 @@ from __future__ import annotations from typing import TYPE_CHECKING +from ..constants import IS_SPHINX_BUILD -from pyopenms_viz.plotting._plotly.core import ( +from .core import ( PLOTLYLinePlot, PLOTLYVLinePlot, PLOTLYScatterPlot, @@ -13,9 +14,12 @@ ) if TYPE_CHECKING: - from pyopenms_viz.plotting._plotly.core import PLOTLYPlot + from .core import PLOTLYPlotter -PLOT_CLASSES: dict[str, type[PLOTLYPlot]] = { +if IS_SPHINX_BUILD: + from .core import PLOTLY_MSPlotter, PLOTLYPlotter + +PLOT_CLASSES: dict[str, type[PLOTLYPlotter]] = { "line": PLOTLYLinePlot, "vline": PLOTLYVLinePlot, "scatter": PLOTLYScatterPlot, diff --git a/pyopenms_viz/plotting/_plotly/core.py b/pyopenms_viz/_plotly/core.py similarity index 91% rename from pyopenms_viz/plotting/_plotly/core.py rename to pyopenms_viz/_plotly/core.py index 164ff699..fd9bb5a0 100644 --- a/pyopenms_viz/plotting/_plotly/core.py +++ b/pyopenms_viz/_plotly/core.py @@ -2,7 +2,7 @@ from abc import ABC -from typing import TYPE_CHECKING, Literal, List, Tuple, Union +from typing import List, Tuple, Union import plotly.graph_objects as go from plotly.graph_objs import Figure @@ -17,20 +17,20 @@ LinePlot, VLinePlot, ScatterPlot, - ComplexPlot, + BaseMSPlotter, ChromatogramPlot, MobilogramPlot, SpectrumPlot, FeatureHeatmapPlot, - APPEND_PLOT_DOC + APPEND_PLOT_DOC, ) from .._config import bokeh_line_dash_mapper -from pyopenms_viz.plotting._misc import ColorGenerator -from pyopenms_viz.constants import PEAK_BOUNDARY_ICON, FEATURE_BOUNDARY_ICON +from .._misc import ColorGenerator +from ..constants import PEAK_BOUNDARY_ICON, FEATURE_BOUNDARY_ICON -class PLOTLYPlot(BasePlotter, ABC): +class PLOTLYPlotter(BasePlotter, ABC): """ Base class for assembling a Ploty plot """ @@ -189,7 +189,7 @@ def show(self, **kwargs): self.fig.show(**kwargs) -class PLOTLYLinePlot(PLOTLYPlot, LinePlot): +class PLOTLYLinePlot(PLOTLYPlotter, LinePlot): """ Class for assembling a set of line plots in plotly """ @@ -228,7 +228,7 @@ def plot( # type: ignore[override] return fig, None -class PLOTLYVLinePlot(PLOTLYPlot, VLinePlot): +class PLOTLYVLinePlot(PLOTLYPlotter, VLinePlot): @classmethod @APPEND_PLOT_DOC @@ -280,7 +280,7 @@ def plot(cls, fig, data, x, y, by=None, **kwargs) -> Tuple[Figure, "Legend"]: return fig, None -class PLOTLYScatterPlot(PLOTLYPlot, ScatterPlot): +class PLOTLYScatterPlot(PLOTLYPlotter, ScatterPlot): @classmethod @APPEND_PLOT_DOC @@ -317,7 +317,7 @@ def plot(cls, fig, data, x, y, by=None, **kwargs) -> Tuple[Figure, "Legend"]: return fig, None -class PLOTLYComplexPlot(ComplexPlot, PLOTLYPlot, ABC): +class PLOTLY_MSPlotter(BaseMSPlotter, PLOTLYPlotter, ABC): def get_line_renderer(self, data, x, y, **kwargs) -> None: return PLOTLYLinePlot(data, x, y, **kwargs) @@ -367,7 +367,7 @@ def _create_tooltips(self): return "
".join(TOOLTIPS), column_stack(custom_hover_data) -class PLOTLYChromatogramPlot(PLOTLYComplexPlot, ChromatogramPlot): +class PLOTLYChromatogramPlot(PLOTLY_MSPlotter, ChromatogramPlot): def _add_peak_boundaries(self, annotation_data, **kwargs): color_gen = ColorGenerator( @@ -390,10 +390,12 @@ def _add_peak_boundaries(self, annotation_data, **kwargs): y=[feature["apexIntensity"], 0, 0, feature["apexIntensity"]], opacity=0.5, line=dict( - color = next(color_gen), - dash=bokeh_line_dash_mapper(self.feature_config.line_type, 'plotly'), - width=self.feature_config.line_width + color=next(color_gen), + dash=bokeh_line_dash_mapper( + self.feature_config.line_type, "plotly" ), + width=self.feature_config.line_width, + ), name=legend_label, ) ) @@ -407,7 +409,7 @@ class PLOTLYMobilogramPlot(PLOTLYChromatogramPlot, MobilogramPlot): pass -class PLOTLYSpectrumPlot(PLOTLYComplexPlot, SpectrumPlot): +class PLOTLYSpectrumPlot(PLOTLY_MSPlotter, SpectrumPlot): def _prepare_data( self, spectrum: DataFrame, y: str, reference_spectrum: DataFrame | None ) -> Tuple[List]: @@ -423,7 +425,7 @@ def _prepare_data( return spectrum, reference_spectrum -class PLOTLYFeatureHeatmapPlot(PLOTLYComplexPlot, FeatureHeatmapPlot): +class PLOTLYFeatureHeatmapPlot(PLOTLY_MSPlotter, FeatureHeatmapPlot): def create_main_plot(self, x, y, z, class_kwargs, other_kwargs): scatterPlot = self.get_scatter_renderer(self.data, x, y, **class_kwargs) @@ -440,7 +442,7 @@ def create_main_plot(self, x, y, z, class_kwargs, other_kwargs): ), **other_kwargs, ) - + if self.annotation_data is not None: self._add_box_boundaries(self.annotation_data) @@ -532,7 +534,7 @@ def combine_plots(self, x_fig, y_fig): # Update yaxis properties fig_m.update_yaxes(title_text=self.zlabel, row=1, col=2) fig_m.update_yaxes(title_text=self.ylabel, row=2, col=1) - + # Remove axes for first quadrant fig_m.update_xaxes(visible=False, row=1, col=1) fig_m.update_yaxes(visible=False, row=1, col=1) @@ -554,11 +556,11 @@ def _add_box_boundaries(self, annotation_data, **kwargs): x1 = feature["rightWidth"] y0 = feature["IM_leftWidth"] y1 = feature["IM_rightWidth"] - + color = next(color_gen) - - if 'name' in annotation_data.columns: - use_name = feature['name'] + + if "name" in annotation_data.columns: + use_name = feature["name"] else: use_name = f"Feature {idx}" if "q_value" in annotation_data.columns: @@ -566,17 +568,25 @@ def _add_box_boundaries(self, annotation_data, **kwargs): else: legend_label = f"{use_name}" self.fig.add_trace( - go.Scatter( - x=[x0, x1, x1, x0, x0], # Start and end at the same point to close the shape - y=[y0, y0, y1, y1, y0], - mode='lines', - fill='none', - opacity=0.5, - line=dict( - color = color, - width=self.feature_config.line_width, - dash=bokeh_line_dash_mapper(self.feature_config.line_type, 'plotly') + go.Scatter( + x=[ + x0, + x1, + x1, + x0, + x0, + ], # Start and end at the same point to close the shape + y=[y0, y0, y1, y1, y0], + mode="lines", + fill="none", + opacity=0.5, + line=dict( + color=color, + width=self.feature_config.line_width, + dash=bokeh_line_dash_mapper( + self.feature_config.line_type, "plotly" ), - name=legend_label - ) - ) \ No newline at end of file + ), + name=legend_label, + ) + ) diff --git a/pyopenms_viz/constants.py b/pyopenms_viz/constants.py index 76d6c976..7236c8f3 100644 --- a/pyopenms_viz/constants.py +++ b/pyopenms_viz/constants.py @@ -11,5 +11,22 @@ ###################### ## Icons -PEAK_BOUNDARY_ICON = Image.open(os.path.normpath(os.path.join(PYOPENMS_VIZ_DIRNAME, 'assets/img/peak_boundary.png'))) -FEATURE_BOUNDARY_ICON = Image.open(os.path.normpath(os.path.join(PYOPENMS_VIZ_DIRNAME, 'assets/img/feature_boundary.png'))) \ No newline at end of file +PEAK_BOUNDARY_ICON = Image.open( + os.path.normpath(os.path.join(PYOPENMS_VIZ_DIRNAME, "assets/img/peak_boundary.png")) +) +FEATURE_BOUNDARY_ICON = Image.open( + os.path.normpath( + os.path.join(PYOPENMS_VIZ_DIRNAME, "assets/img/feature_boundary.png") + ) +) + + +###################### +## Determine if running in SPHINX build +IS_SPHINX_BUILD = False +try: + import sphinx + + IS_SPHINX_BUILD = hasattr(sphinx, "application") +except ImportError: + pass # Not running SPHINX diff --git a/pyopenms_viz/datastructures/MSChromatogram.py b/pyopenms_viz/datastructures/MSChromatogram.py deleted file mode 100644 index 417df95d..00000000 --- a/pyopenms_viz/datastructures/MSChromatogram.py +++ /dev/null @@ -1,22 +0,0 @@ -""" -Schema definition for mass spectrometry chromatogram. This is used to store any chromatogram object -""" - -REQUIRED_CHROMATOGRAM_DATAFRAME_COLUMNS = { - "time": "Numeric column representing the retention time (in seconds) of the chromatographic peaks.", - "intensity": "Numeric column representing the intensity (abundance) of the signal at each time point.", -} - -OPTIONAL_METADATA_CHROMATOGRAM_DATAFRAME_COLUMNS = { - "native_id" : "Chromatogram id, necessary if multiple chromatograms are in the same dataframe.", - "chromatogram_type": "Type of chromatogram must be one of: MASS_CHROMATOGRAM, TOTAL_ION_CURRENT_CHROMATOGRAM, SELECTED_ION_CURRENT_CHROMATOGRAM, BASEPEAK_CHROMATOGRAM, SELECTED_ION_MONITORING_CHROMATOGRAM, SELECTED_REACTION_MONITORING_CHROMATOGRAM, ELECTROMAGNETIC_RADIATION_CHROMATOGRAM, ABSORPTION_CHROMATOGRAM, EMISSION_CHROMATOGRAM", - "ms_level": "Integer column indicating the MS level (1 for MS1, 2 for MS2, etc.).", - "sequence": "String column representing the peptide sequence.", - "modified_sequence": "String column representing the modified peptide sequence. Modification can be represented using the UniMod ontology, either with the UniMod Accession (UniMod: 21) or with the UniMod Codename (Phospho).", - "precursor_mz": "Numeric column representing the mass-to-charge ratio (m/z) of the precursor ion.", - "precursor_charge": "Integer column representing the charge state of the precursor ion.", - "product_mz": "Numeric column representing the mass-to-charge ratio (m/z) of the product ion.", - "product_charge": "Integer column representing the charge state of the product ion.", - "annotation": "String column representing the annotation of the spectrum, such as the fragment ion series." -} - diff --git a/pyopenms_viz/datastructures/MSSpectrum.py b/pyopenms_viz/datastructures/MSSpectrum.py deleted file mode 100644 index f69df7ae..00000000 --- a/pyopenms_viz/datastructures/MSSpectrum.py +++ /dev/null @@ -1,22 +0,0 @@ -""" -Schema definition for mass spectrometry spectrum. -""" - -REQUIRED_SPECTRUM_DATAFRAME_COLUMNS = { - "mz": "Numeric column representing the mass-to-charge ratio (m/z) values of the peaks in the mass spectrum.", - "intensity": "Numeric column representing the intensity (abundance) of the peaks in the mass spectrum." -} - -OPTIONAL_METADATA_SPECTRUM_DATAFRAME_COLUMNS = { - "native_id": "String column representing the native identifier (i.e. scan number) of the spectrum.", - "ms_level": "Integer column indicating the MS level (1 for MS1, 2 for MS2, etc.).", - "time": "Numeric column representing the retention time (in seconds) of the spectrum.", - "ion_mobility": "Numeric column representing the ion mobility.", - "sequence": "String column representing the peptide sequence.", - "modified_sequence": "String column representing the modified peptide sequence.", - "precursor_mz": "Numeric column representing the mass-to-charge ratio (m/z) of the precursor ion.", - "precursor_charge": "Integer column representing the charge state of the precursor ion.", - "product_mz": "Numeric column representing the mass-to-charge ratio (m/z) of the product ion.", - "product_charge": "Integer column representing the charge state of the product ion.", - "annotation": "String column representing the annotation of the peak, such as the fragment ion series." -} \ No newline at end of file diff --git a/pyopenms_viz/datastructures/SchemaDefinition.py b/pyopenms_viz/datastructures/SchemaDefinition.py deleted file mode 100644 index a9a2d53f..00000000 --- a/pyopenms_viz/datastructures/SchemaDefinition.py +++ /dev/null @@ -1,180 +0,0 @@ -from enum import Enum -from typing import List, Union -import pandas as pd - -""" Outlines Column Types. Other columns are permitted in the class however they will be ignored """ - -class DataType(Enum): - NUMERIC=0 - STRING=1 - -class ColumnType(Enum): - REQUIRED=0 # Plots cannot be generated without these columns - OPTIONAL=1 # Used to add supplementary information to the plot, but not required for plotting the object - -class ColumnSchema: - def __init__(self, - name: Union[str, List[str]], - dataType: DataType = None, - description: str = None, - columnType: ColumnType = ColumnType.OPTIONAL): - - """Define a column schema - - Args: - name (str): Name of column - dataType (DataType, Optional): Datatype of column. Defaults to None. - description (str, Optional): Description of information that column is storing. Defaults to None. - columnType (ColumnType, optional): Whether column is REQUIRED or OPTIONAL. Defaults to OPTIONAL - """ - - self.name = name - self.dataType = dataType - self.description = description - self.columnType = columnType - - def __str__(self): - return f"ColumnSchema: {self.name} ({self.dataType})" - -class DataFrameSchema: - def __init__(self, - name: str, - columns: List[ColumnSchema], - description: str = None): - - """Define a table schema - - Args: - columns (List[ColumnSchema]): List of ColumnSchema objects - description (str, Optional): Description of table. Defaults to None. - """ - self.name = name - self.columns = columns - self.description = description - - def __str__(self): - return f"GenericTableSchema: {self.name}. Required Columns: {[i.name for i in self.getRequiredColumns()]}" - - def getOptionalColumns(self): - return [col for col in self.columns if col.columnType == ColumnType.OPTIONAL] - - def getRequiredColumns(self): - return [col for col in self.columns if col.columnType == ColumnType.REQUIRED] - - def validateDataFrame(self, df): - for col in self.getRequiredColumns(): - # if column name is a list, then exactly one column must be present - if isinstance(col.name, List): - intersect = set(df.columns).intersection(set(col.name)) - if len(intersect) == 0: - raise ValueError(f"Column {col.name} is required for {self.name} schema") - elif len(intersect) > 1: - raise ValueError(f"Only one of {col.name} columns is allowed for {self.name} schema") - nameFound = False - else: # assume column name is a string - if col.name not in df.columns: - raise ValueError(f"Column {col.name} is required for {self.name} schema") - - if col.dataType == DataType.NUMERIC: - assert pd.str.isnumeric(df[col.name]) - - for col in self.getOptionalColumns(): - if col.dataType == DataType.NUMERIC: - assert pd.str.isnumeric(df[col.name]) - return True - - -############################################################################################################ -#### DEFINED CHROMATOGRAM SCHEMAS #### -############################################################################################################ - -REQUIRED_CHROMATOGRAM_COLUMNS = [ - ColumnSchema("time", - dataType=DataType.NUMERIC, - description="Numeric column representing the retention time (in seconds) of the chromatographic peaks.", - columnType=ColumnType.REQUIRED), - - ColumnSchema("intensity", - dataType=DataType.NUMERIC, - description="Numeric column representing the intensity (abundance) of the signal at each time point.", - columnType=ColumnType.REQUIRED)] - -MULTIPLE_CHROMATOGRAMS_COLUMNS = REQUIRED_CHROMATOGRAM_COLUMNS + [ ColumnSchema("label", - dataType=DataType.STRING, - description="Chromatogram label, necessary if multiple chromatograms are in the same dataframe. This column is used to label the chromatograms", columnType=ColumnType.REQUIRED)] - -CHROMATOGRAM = DataFrameSchema(name='chromatogram', columns=REQUIRED_CHROMATOGRAM_COLUMNS, description="Base chromatogram schema") -MULTIPLE_CHROMATOGRAMS = DataFrameSchema(name='multiple chromatograms', columns=REQUIRED_CHROMATOGRAM_COLUMNS + REQUIRED_CHROMATOGRAM_COLUMNS, description="Multiple chromatograms in a single DataFrame schema") - -############################################################################################################ -#### DEFINED SPECTRUM SCHEMAS #### -############################################################################################################ -SPECTRUM_COLUMNS = [ - ColumnSchema("mz", - dataType=DataType.NUMERIC, - description="Numeric column representing the mass-to-charge ratio (m/z) values of the peaks in the mass spectrum.", - columnType=ColumnType.REQUIRED), - - ColumnSchema("intensity", - dataType=DataType.NUMERIC, - description="Numeric column representing the intensity (abundance) of the peaks in the mass spectrum.", - columnType=ColumnType.REQUIRED), - - ColumnSchema("annotation", - dataType=DataType.STRING, - description="String column representing the annotation of the spectrum, such as the fragment ion series.", - columnType=ColumnType.OPTIONAL) -] - -MULTIPLE_SPECTRA_COLUMNS = SPECTRUM_COLUMNS + [ ColumnSchema(name=["label", "native_id"], - dataType=DataType.STRING, - description="Spectrum label, necessary if multiple spectra are in the same dataframe. This column is used to label the spectra", columnType=ColumnType.REQUIRED)] - -SPECTRUM = DataFrameSchema(name='spectrum', columns=SPECTRUM_COLUMNS, description="Base spectrum schema, only a single spectrum is present") -MULTIPLE_SPECTRA = DataFrameSchema(name='multiple spectra', columns=SPECTRUM_COLUMNS + SPECTRUM_COLUMNS, description="Multiple spectra in a single DataFrame") - -############################################################################################################ -#### DEFINED CHROMATOGRAM FEATURE SCHEMAS #### -############################################################################################################ -REQUIRED_CHROMATOGRAM_FEATURE_COLUMNS = [ ColumnSchema(name=['RT', 'rt_apex'], - dataType=DataType.NUMERIC, - description="Numeric column representing the retention time of the peak apex.")] - -# Note: only include columns that could be used in plotting (e.g. exclude area because this will not be plotted) -CHROMATOGRAM_FEATURE = DataFrameSchema(name='chromatogram feature', - columns=[ - ColumnSchema(name=["rt_apex", 'RT'], - dataType=DataType.NUMERIC, - description="Numeric column representing the retention time of the peak apex."), - - ColumnSchema(name=["feature_id", 'id', 'label'], - dataType=DataType.STRING, - description="Represent the Id of the feature", - columnType=ColumnType.OPTIONAL), - - ColumnSchema(name=["left_width", 'left_boundary'], - dataType=DataType.NUMERIC, - description="Numeric column representing the left boundary of the peak.", - columnType=ColumnType.OPTIONAL), - - ColumnSchema(name=["right_width", 'right_boundary'], - dataType=DataType.NUMERIC, - description="Numeric column representing the right boundary of the peak.", - columnType=ColumnType.OPTIONAL), - - ColumnSchema(name=["q_value"], - dataType=DataType.NUMERIC, - description="Numeric column representing the q-value of the peak.", - columnType=ColumnType.OPTIONAL), - - ColumnSchema(name=["rank"], - dataType=DataType.NUMERIC, - description="Numeric column representing the rank of the peak. (1 is best rank)", - columnType=ColumnType.OPTIONAL) - ], - description="Chromatogram feature schema") - -############################################################################################################ -#### DEFINED SPECTRUM FEATURE SCHEMAS #### -############################################################################################################ -##TODO, Is this needed? \ No newline at end of file diff --git a/pyopenms_viz/plotting/__init__.py b/pyopenms_viz/plotting/__init__.py deleted file mode 100644 index a99f028d..00000000 --- a/pyopenms_viz/plotting/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from pandas.plotting._core import PlotAccessor - -__all__ = ["PlotAccessor"] diff --git a/pyopenms_viz/pyopenmsviz.py b/pyopenms_viz/pyopenmsviz.py deleted file mode 100644 index d85b9095..00000000 --- a/pyopenms_viz/pyopenmsviz.py +++ /dev/null @@ -1,14 +0,0 @@ -from pandas.core.frame import DataFrame - -from pyopenms_viz.plotting._core import PlotAccessor - - -class pyopenmsviz: - def __init__(self, data: DataFrame) -> None: - # Validate the input data frame. - if not isinstance(data, DataFrame): - raise TypeError(f"Input data must be a pandas DataFrame, not {type(data)}") - self._data = data - - def plot(self, *args, **kwargs): - return PlotAccessor(self._data)(*args, **kwargs) \ No newline at end of file diff --git a/pyopenms_viz/testing/BokehSnapshotExtension.py b/pyopenms_viz/testing/BokehSnapshotExtension.py deleted file mode 100644 index 07eba69c..00000000 --- a/pyopenms_viz/testing/BokehSnapshotExtension.py +++ /dev/null @@ -1,155 +0,0 @@ -""" -pyopenms_viz/testing/BokehSnapshotExtension -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -""" -from typing import Any -from bokeh.embed import file_html -import json -from syrupy.data import SnapshotCollection -from syrupy.extensions.single_file import SingleFileSnapshotExtension -from syrupy.types import SerializableData -from bokeh.resources import CDN -from html.parser import HTMLParser - -class BokehHTMLParser(HTMLParser): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.recording = False # boolean flag to indicate if we are currently recording the data - self.bokehJson = None # data to extract - - def handle_starttag(self, tag, attrs): - if tag == 'script' and self.bokehJson is None: - attrs_dict = dict(attrs) - if attrs_dict.get('type') == 'application/json': - self.recording = True - - def handle_endtag(self, tag): - if tag == 'script' and self.recording: - self.recording = False - - def handle_data(self, data): - if self.recording and self.bokehJson is None: - self.bokehJson = data - -class BokehSnapshotExtension(SingleFileSnapshotExtension): - """ - Handles Bokeh Snapshots. Snapshots are stored as html files and the bokeh .json output from the html files are compared. - """ - _file_extension = "html" - - def matches(self, *, serialized_data, snapshot_data): - """ - Determine if the serialized data matches the snapshot data. - - Args: - serialized_data: Data produced by the test - snapshot_data: Saved data from a previous test run - - """ - json_snapshot = self.extract_bokeh_json(snapshot_data) - json_serialized = self.extract_bokeh_json(serialized_data) - - # get the keys which store the json - # NOTE: keys are unique identifiers and are not supposed to be equal - # but the json objects they contain should be equal - key_json_snapshot = list(json_snapshot.keys())[0] - key_json_serialized = list(json_serialized.keys())[0] - - return BokehSnapshotExtension.compare_json(json_snapshot[key_json_snapshot], json_serialized[key_json_serialized]) - - def extract_bokeh_json(self, html: str) -> json: - """ - Extract the bokeh json from the html file. - - Args: - html (str): string containing the html data - - Returns: - json: bokeh json found in the html - """ - parser = BokehHTMLParser() - parser.feed(html) - return json.loads(parser.bokehJson) - - @staticmethod - def compare_json(json1, json2): - """ - Compare two bokeh json objects. This function acts recursively - - Args: - json1: first object - json2: second object - - Returns: - bool: True if the objects are equal, False otherwise - """ - if isinstance(json1, dict) and isinstance(json2, dict): - for key in json1.keys(): - if key not in json2: - print(f'Key {key} not in second json') - return False - elif key not in ['id', 'root_ids']: # add keys to ignore here - pass - elif not BokehSnapshotExtension.compare_json(json1[key], json2[key]): - print(f'Values for key {key} not equal') - return False - return True - elif isinstance(json1, list) and isinstance(json2, list): - if len(json1) != len(json2): - print('Lists have different lengths') - return False - # lists are unordered so we need to compare every element one by one - for i in json1: - if isinstance(i, dict): - # find the corresponding dictionary in json2 - for j in json2: - if j['type'] == i['type']: - if not BokehSnapshotExtension.compare_json(i, j): - print(f'Element {i} not equal to {j}') - return False - return True - print(f'Element {i} not in second list') - return False - else: - return json1[i] == json2[i] - return True - else: - if json1 != json2: - print(f'Values not equal: {json1} != {json2}') - return json1 == json2 - - def _read_snapshot_data_from_location( - self, *, snapshot_location: str, snapshot_name: str, session_id: str - ): - # see https://github.com/tophat/syrupy/blob/f4bc8453466af2cfa75cdda1d50d67bc8c4396c3/src/syrupy/extensions/base.py#L139 - try: - with open(snapshot_location, 'r') as f: - a = f.read() - return a - except OSError: - return None - - @classmethod - def _write_snapshot_collection( - cls, *, snapshot_collection: SnapshotCollection - ) -> None: - # see https://github.com/tophat/syrupy/blob/f4bc8453466af2cfa75cdda1d50d67bc8c4396c3/src/syrupy/extensions/base.py#L161 - - filepath, data = ( - snapshot_collection.location, - next(iter(snapshot_collection)).data, - ) - with open(filepath, 'w') as f: - f.write(data) - - def serialize(self, data: SerializableData, **kwargs: Any) -> str: - """ - Serialize the bokeh plot as an html string (which is output to a file) - - Args: - data (SerializableData): Data to serialize - - Returns: - str: html string - """ - return file_html(data, CDN) \ No newline at end of file diff --git a/pyopenms_viz/testing/PlotlySnapshotExtension.py b/pyopenms_viz/testing/PlotlySnapshotExtension.py deleted file mode 100644 index 4caee48d..00000000 --- a/pyopenms_viz/testing/PlotlySnapshotExtension.py +++ /dev/null @@ -1,94 +0,0 @@ -from typing import Any -from syrupy.data import SnapshotCollection -from syrupy.extensions.single_file import SingleFileSnapshotExtension -from syrupy.types import SerializableData -from plotly.io import to_json -import json -import math - -class PlotlySnapshotExtension(SingleFileSnapshotExtension): - """ - Handles Plotly Snapshots. Snapshots are stored as json files and the json output from the files are compared. - """ - _file_extension = "json" - - def matches(self, *, serialized_data, snapshot_data): - json1 = json.loads(serialized_data) - json2 = json.loads(snapshot_data) - return PlotlySnapshotExtension.compare_json(json1, json2) - - @staticmethod - def compare_json(json1, json2) -> bool: - """ - Compare two plotly json objects. This function acts recursively - - Args: - json1: first json - json2: second json - - Returns: - bool: True if the objects are equal, False otherwise - """ - if isinstance(json1, dict) and isinstance(json2, dict): - for key in json1.keys(): - if key not in json2: - print(f'Key {key} not in second json') - return False - if not PlotlySnapshotExtension.compare_json(json1[key], json2[key]): - print(f'Values for key {key} not equal') - return False - return True - elif isinstance(json1, list) and isinstance(json2, list): - if len(json1) != len(json2): - print('Lists have different lengths') - return False - for i, j in zip(json1, json2): - if not PlotlySnapshotExtension.compare_json(i, j): - return False - return True - else: - if isinstance(json1, float): - if not math.isclose(json1, json2): - print(f'Values not equal: {json1} != {json2}') - return False - else: - if json1 != json2: - print(f'Values not equal: {json1} != {json2}') - return False - return True - - def _read_snapshot_data_from_location( - self, *, snapshot_location: str, snapshot_name: str, session_id: str - ): - # see https://github.com/tophat/syrupy/blob/f4bc8453466af2cfa75cdda1d50d67bc8c4396c3/src/syrupy/extensions/base.py#L139 - try: - with open(snapshot_location, 'r') as f: - a = f.read() - return a - except OSError: - return None - - @classmethod - def _write_snapshot_collection( - cls, *, snapshot_collection: SnapshotCollection - ) -> None: - # see https://github.com/tophat/syrupy/blob/f4bc8453466af2cfa75cdda1d50d67bc8c4396c3/src/syrupy/extensions/base.py#L161 - - filepath, data = ( - snapshot_collection.location, - next(iter(snapshot_collection)).data, - ) - with open(filepath, 'w') as f: - f.write(data) - - def serialize(self, data: SerializableData, **kwargs: Any) -> str: - """ - Serialize the data to a json string - - Args: - data (SerializableData): plotly data to serialize - - Returns: - str: json string of plotly plot - """ - return to_json(data, pretty=True, engine='json') \ No newline at end of file diff --git a/pyopenms_viz/testing/__init__.py b/pyopenms_viz/testing/__init__.py deleted file mode 100644 index 462624fd..00000000 --- a/pyopenms_viz/testing/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -""" -pyopenms_viz/testing -~~~~~~~~~~~~~~~~ - -This package contains classes for testing pyopenms_viz. SnapShotExtension classes are based off of syrupy snapshots -""" - -from .BokehSnapshotExtension import BokehSnapshotExtension -from .PlotlySnapshotExtension import PlotlySnapshotExtension - -__all__ = [ - "BokehSnapshotExtension", - "PlotlySnapshotExtension"] \ No newline at end of file diff --git a/pyopenms_viz/util/_decorators.py b/pyopenms_viz/util/_decorators.py deleted file mode 100644 index ba26458f..00000000 --- a/pyopenms_viz/util/_decorators.py +++ /dev/null @@ -1,32 +0,0 @@ -from dataclasses import fields - -def filter_unexpected_fields(cls): - """ - Decorator function that filters unexpected fields from the keyword arguments passed to the class constructor. - - Args: - cls: The class to decorate. - - Returns: - The decorated class. - - Example: - @filter_unexpected_fields - class MyClass: - def __init__(self, name, age): - self.name = name - self.age = age - - obj = MyClass(name='John', age=25, gender='Male') - # The 'gender' keyword argument will be filtered out and not passed to the class constructor. - """ - original_init = cls.__init__ - - def new_init(self, *args, **kwargs): - expected_fields = {field.name for field in fields(cls)} - cleaned_kwargs = {key: value for key, value in kwargs.items() if key in expected_fields} - print(f"WARNING: Filtered out unexpected fields: {set(kwargs) - expected_fields}") - original_init(self, *args, **cleaned_kwargs) - - cls.__init__ = new_init - return cls \ No newline at end of file diff --git a/setup.py b/setup.py index 4b5a5f84..0a354ac3 100644 --- a/setup.py +++ b/setup.py @@ -22,9 +22,9 @@ ], entry_points={ "pandas_plotting_backends": [ - "pomsvib = pyopenms_viz.plotting._bokeh", - "pomsvip = pyopenms_viz.plotting._plotly", - "pomsvim = pyopenms_viz.plotting._matplotlib" + "MSBokeh = pyopenms_viz._bokeh", + "MSPlotly = pyopenms_viz._plotly", + "MSMatplotlib = pyopenms_viz._matplotlib", ], }, classifiers=[