diff --git a/src/pyelq/plotting/plot.py b/src/pyelq/plotting/plot.py index 609394c..99ac722 100644 --- a/src/pyelq/plotting/plot.py +++ b/src/pyelq/plotting/plot.py @@ -12,7 +12,7 @@ import warnings from copy import deepcopy from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Callable, Type, Union +from typing import TYPE_CHECKING, Any, Callable, Type, Union import numpy as np import pandas as pd @@ -20,17 +20,20 @@ import plotly.graph_objects as go from geojson import Feature, FeatureCollection from openmcmc.mcmc import MCMC -from scipy.ndimage import label from shapely import geometry from pyelq.component.background import TemporalBackground from pyelq.component.error_model import ErrorModel from pyelq.component.offset import PerSensor from pyelq.component.source_model import SlabAndSpike, SourceModel -from pyelq.coordinate_system import ENU, LLA +from pyelq.coordinate_system import LLA from pyelq.dispersion_model.gaussian_plume import GaussianPlume from pyelq.sensor.sensor import Sensor, SensorGroup -from pyelq.support_functions.post_processing import is_regularly_spaced, calculate_rectangular_statistics, create_LLA_polygons_from_XY_points +from pyelq.support_functions.post_processing import ( + calculate_rectangular_statistics, + create_lla_polygons_from_xy_points, + is_regularly_spaced, +) if TYPE_CHECKING: from pyelq.model import ELQModel @@ -128,7 +131,7 @@ def plot_quantiles_from_array( return fig -def create_trace_specifics(object_to_plot: Union[Type[SlabAndSpike], SourceModel, MCMC], **kwargs) -> dict: +def create_trace_specifics(object_to_plot: Union[Type[SlabAndSpike], SourceModel, MCMC], **kwargs: Any) -> dict: """Specification of different traces of single variables. Provides all details for plots where we want to plot a single variable as a line plot. Based on the object_to_plot @@ -137,7 +140,7 @@ def create_trace_specifics(object_to_plot: Union[Type[SlabAndSpike], SourceModel Args: object_to_plot (Union[Type[SlabAndSpike], SourceModel, MCMC]): Object which we want to plot a single variable from - **kwargs (dict): Additional key word arguments, e.g. burn_in or dict_key, used in some specific plots but not + **kwargs (Any): Additional key word arguments, e.g. burn_in or dict_key, used in some specific plots but not applicable to all. Returns: @@ -202,7 +205,7 @@ def create_trace_specifics(object_to_plot: Union[Type[SlabAndSpike], SourceModel def create_plot_specifics( - object_to_plot: Union[ErrorModel, PerSensor, MCMC], sensor_object: SensorGroup, plot_type: str = "", **kwargs + object_to_plot: Union[ErrorModel, PerSensor, MCMC], sensor_object: SensorGroup, plot_type: str = "", **kwargs: Any ) -> dict: """Specification of different traces where we want to plot a trace for each sensor. @@ -217,7 +220,7 @@ def create_plot_specifics( object_to_plot (Union[ErrorModel, PerSensor, MCMC]): Object which we want to plot a single variable from sensor_object (SensorGroup): SensorGroup object associated with the object_to_plot plot_type (str, optional): String specifying either a line or a box plot. - **kwargs (dict): Additional key word arguments, e.g. burn_in or dict_key, used in some specific plots but not + **kwargs (Any): Additional key word arguments, e.g. burn_in or dict_key, used in some specific plots but not applicable to all. Returns: @@ -311,7 +314,7 @@ def plot_single_scatter( y_values: np.ndarray, color: str, name: str, - **kwargs, + **kwargs: Any, ) -> go.Figure: """Plots a single scatter trace on the supplied figure object. @@ -321,8 +324,8 @@ def plot_single_scatter( y_values (np.ndarray): Numpy array containing the y-values to use in plotting. color (str): RGB color string to use for this trace. name (str): String name to show in the legend. - **kwargs (dict): Additional key word arguments, e.g. burn_in, legend_group, show_legend, used in some specific plots - but not applicable to all. + **kwargs (Any): Additional key word arguments, e.g. burn_in, legend_group, show_legend, used in some specific + plots but not applicable to all. Returns: fig (go.Figure): Plotly figure with the trace added to it. @@ -378,7 +381,7 @@ def plot_single_box(fig: go.Figure, y_values: np.ndarray, color: str, name: str) def plot_polygons_on_map( - polygons: Union[np.ndarray, list], values: np.ndarray, opacity: float, map_color_scale: str, **kwargs + polygons: Union[np.ndarray, list], values: np.ndarray, opacity: float, map_color_scale: str, **kwargs: Any ) -> go.Choroplethmapbox: """Plot a set of polygons on a map. @@ -388,8 +391,8 @@ def plot_polygons_on_map( used in coloring the polygons on the map. opacity (float): Float between 0 and 1 specifying the opacity of the polygon fill color. map_color_scale (str): The string which defines which plotly color scale. - **kwargs (dict): Additional key word arguments which can be passed on the go.Choroplethmapbox object (will override - the default values as specified in this function) + **kwargs (Any): Additional key word arguments which can be passed on the go.Choroplethmapbox object + (will override the default values as specified in this function) Returns: trace: go.Choroplethmapbox trace with the colored polygons which can be added to a go.Figure object. @@ -578,7 +581,7 @@ def show_all(self, renderer="browser"): for fig in self.figure_dict.values(): fig.show(renderer=renderer) - def plot_single_trace(self, object_to_plot: Union[Type[SlabAndSpike], SourceModel, MCMC], **kwargs): + def plot_single_trace(self, object_to_plot: Union[Type[SlabAndSpike], SourceModel, MCMC], **kwargs: Any): """Plotting a trace of a single variable. Depending on the object to plot it creates a figure which is stored in the figure_dict attribute. @@ -586,8 +589,8 @@ def plot_single_trace(self, object_to_plot: Union[Type[SlabAndSpike], SourceMode Args: object_to_plot (Union[Type[SlabAndSpike], SourceModel, MCMC]): The object from which to plot a variable - **kwargs (dict): Additional key word arguments, e.g. burn_in, legend_group, show_legend, dict_key, used in some - specific plots but not applicable to all. + **kwargs (Any): Additional key word arguments, e.g. burn_in, legend_group, show_legend, dict_key, used in + some specific plots but not applicable to all. """ plot_specifics = create_trace_specifics(object_to_plot=object_to_plot, **kwargs) @@ -633,7 +636,7 @@ def plot_trace_per_sensor( object_to_plot: Union[ErrorModel, PerSensor, MCMC], sensor_object: Union[SensorGroup, Sensor], plot_type: str, - **kwargs, + **kwargs: Any, ): """Plotting a trace of a single variable per sensor. @@ -644,8 +647,8 @@ def plot_trace_per_sensor( object_to_plot (Union[ErrorModel, PerSensor, MCMC]): The object which to plot a variable from sensor_object (Union[SensorGroup, Sensor]): Sensor object associated with the object_to_plot plot_type (str): String specifying a line or box plot. - **kwargs (dict): Additional key word arguments, e.g. burn_in, legend_group, show_legend, dict_key, used in some - specific plots but not applicable to all. + **kwargs (Any): Additional key word arguments, e.g. burn_in, legend_group, show_legend, dict_key, used in + some specific plots but not applicable to all. """ if isinstance(sensor_object, Sensor): @@ -790,7 +793,7 @@ def plot_fitted_values_per_sensor( self.figure_dict[dict_key] = fig - def plot_emission_rate_estimates(self, source_model_object, y_axis_type="linear", **kwargs): + def plot_emission_rate_estimates(self, source_model_object, y_axis_type="linear", **kwargs: Any): """Plot the emission rate estimates source model object against MCMC iteration. Based on the inputs it plots the results of the mcmc analysis, being the estimated emission rate values for @@ -804,7 +807,7 @@ def plot_emission_rate_estimates(self, source_model_object, y_axis_type="linear" Args: source_model_object (SourceModel): Source model object which contains the estimated emission rate estimates. y_axis_type (str, optional): String to indicate whether the y-axis should be linear of log scale. - **kwargs (dict): Additional key word arguments, e.g. burn_in, dict_key, used in some specific plots but not + **kwargs (Any): Additional key word arguments, e.g. burn_in, dict_key, used in some specific plots but not applicable to all. """ @@ -913,7 +916,7 @@ def create_empty_mapbox_figure(self, dict_key: str = "map_plot") -> None: ) def plot_values_on_map( - self, dict_key: str, coordinates: LLA, values: np.ndarray, aggregate_function: Callable = np.sum, **kwargs + self, dict_key: str, coordinates: LLA, values: np.ndarray, aggregate_function: Callable = np.sum, **kwargs: Any ): """Plot values on a map based on coordinates. @@ -923,7 +926,7 @@ def plot_values_on_map( values (np.ndarray): Numpy array of values consistent with coordinates to plot on the map aggregate_function (Callable, optional): Function which to apply on the data in each hexagonal bin to aggregate the data and visualise the result. - **kwargs (dict): Additional keyword arguments for plotting behaviour (opacity, map_color_scale, num_hexagons, + **kwargs (Any): Additional keyword arguments for plotting behaviour (opacity, map_color_scale, num_hexagons, show_positions) """ @@ -976,8 +979,23 @@ def plot_quantification_results_on_map( burn_in: int = 0, show_summary_results: bool = True, ): - """Placeholder for the quantification plots.""" + """Function to create a map with the quantification results of the model object. + This function takes the ELQModel object and calculates the statistics for the quantification results. It then + populates the figure dictionary with three different maps showing the normalized count, median emission rate + and the inter-quartile range of the emission rate estimates. + + Args: + model_object (ELQModel): ELQModel object containing the quantification results + bin_size_x (float, optional): Size of the bins in the x-direction. Defaults to 1. + bin_size_y (float, optional): Size of the bins in the y-direction. Defaults to 1. + normalized_count_limit (float, optional): Limit for the normalized count to show on the map. + Defaults to 0.005. + burn_in (int, optional): Number of burn-in iterations to discard before calculating the statistics. + Defaults to 0. + show_summary_results (bool, optional): Flag to show the summary results on the map. Defaults to True. + + """ ref_latitude = model_object.components["source"].dispersion_model.source_map.location.ref_latitude ref_longitude = model_object.components["source"].dispersion_model.source_map.location.ref_longitude ref_altitude = model_object.components["source"].dispersion_model.source_map.location.ref_altitude @@ -985,7 +1003,7 @@ def plot_quantification_results_on_map( datetime_min_string = model_object.sensor_object.time.min().strftime("%d-%b-%Y, %H:%M:%S") datetime_max_string = model_object.sensor_object.time.max().strftime("%d-%b-%Y, %H:%M:%S") - result_weighted, overall_count, normalized_count, count_boolean, enu_points, summary_result = ( + result_weighted, _, normalized_count, count_boolean, enu_points, summary_result = ( calculate_rectangular_statistics( model_object=model_object, bin_size_x=bin_size_x, @@ -995,7 +1013,7 @@ def plot_quantification_results_on_map( ) ) - polygons = create_LLA_polygons_from_XY_points( + polygons = create_lla_polygons_from_xy_points( points_array=enu_points, ref_latitude=ref_latitude, ref_longitude=ref_longitude, @@ -1132,17 +1150,11 @@ def create_summary_trace( ) -> go.Scattermapbox: """Helper function to create the summary information to plot on top of map type plots. - We identify all blobs of estimates which appear close together on the map by looking at connected pixels in the - count_boolean array. Next we find the summary statistics for all estimates in that blob like overall median and - IQR estimate, mean location and the likelihood of that blob. - - When multiple sources are present in the same blob at the same iteration we first sum those emission rate - estimates before taking the median. - - The summary statistics are also printed out on screen. + We use the summary result calculated through the support functions module to create a trace which contains + the summary information for each source location. Args: - summary_result + summary_result (pd.DataFrame): DataFrame containing the summary information for each source location. Returns: summary_trace (go.Scattermapbox): Trace with summary information to plot on top of map type plots.