Skip to content

Commit

Permalink
Cleaning up and documenting plot.py
Browse files Browse the repository at this point in the history
Signed-off-by: bvandekerkhof <[email protected]>
  • Loading branch information
bvandekerkhof committed Nov 1, 2024
1 parent b9ea923 commit 2f4f0fb
Showing 1 changed file with 48 additions and 36 deletions.
84 changes: 48 additions & 36 deletions src/pyelq/plotting/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,28 @@
import warnings
from copy import deepcopy
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Callable, Type, Union
from typing import TYPE_CHECKING, Any, Callable, Type, Union

import numpy as np
import pandas as pd
import plotly.figure_factory as ff
import plotly.graph_objects as go
from geojson import Feature, FeatureCollection
from openmcmc.mcmc import MCMC
from scipy.ndimage import label
from shapely import geometry

from pyelq.component.background import TemporalBackground
from pyelq.component.error_model import ErrorModel
from pyelq.component.offset import PerSensor
from pyelq.component.source_model import SlabAndSpike, SourceModel
from pyelq.coordinate_system import ENU, LLA
from pyelq.coordinate_system import LLA
from pyelq.dispersion_model.gaussian_plume import GaussianPlume
from pyelq.sensor.sensor import Sensor, SensorGroup
from pyelq.support_functions.post_processing import is_regularly_spaced, calculate_rectangular_statistics, create_LLA_polygons_from_XY_points
from pyelq.support_functions.post_processing import (
calculate_rectangular_statistics,
create_lla_polygons_from_xy_points,
is_regularly_spaced,
)

if TYPE_CHECKING:
from pyelq.model import ELQModel
Expand Down Expand Up @@ -128,7 +131,7 @@ def plot_quantiles_from_array(
return fig


def create_trace_specifics(object_to_plot: Union[Type[SlabAndSpike], SourceModel, MCMC], **kwargs) -> dict:
def create_trace_specifics(object_to_plot: Union[Type[SlabAndSpike], SourceModel, MCMC], **kwargs: Any) -> dict:
"""Specification of different traces of single variables.
Provides all details for plots where we want to plot a single variable as a line plot. Based on the object_to_plot
Expand All @@ -137,7 +140,7 @@ def create_trace_specifics(object_to_plot: Union[Type[SlabAndSpike], SourceModel
Args:
object_to_plot (Union[Type[SlabAndSpike], SourceModel, MCMC]): Object which we want to plot a single
variable from
**kwargs (dict): Additional key word arguments, e.g. burn_in or dict_key, used in some specific plots but not
**kwargs (Any): Additional key word arguments, e.g. burn_in or dict_key, used in some specific plots but not
applicable to all.
Returns:
Expand Down Expand Up @@ -202,7 +205,7 @@ def create_trace_specifics(object_to_plot: Union[Type[SlabAndSpike], SourceModel


def create_plot_specifics(
object_to_plot: Union[ErrorModel, PerSensor, MCMC], sensor_object: SensorGroup, plot_type: str = "", **kwargs
object_to_plot: Union[ErrorModel, PerSensor, MCMC], sensor_object: SensorGroup, plot_type: str = "", **kwargs: Any
) -> dict:
"""Specification of different traces where we want to plot a trace for each sensor.
Expand All @@ -217,7 +220,7 @@ def create_plot_specifics(
object_to_plot (Union[ErrorModel, PerSensor, MCMC]): Object which we want to plot a single variable from
sensor_object (SensorGroup): SensorGroup object associated with the object_to_plot
plot_type (str, optional): String specifying either a line or a box plot.
**kwargs (dict): Additional key word arguments, e.g. burn_in or dict_key, used in some specific plots but not
**kwargs (Any): Additional key word arguments, e.g. burn_in or dict_key, used in some specific plots but not
applicable to all.
Returns:
Expand Down Expand Up @@ -311,7 +314,7 @@ def plot_single_scatter(
y_values: np.ndarray,
color: str,
name: str,
**kwargs,
**kwargs: Any,
) -> go.Figure:
"""Plots a single scatter trace on the supplied figure object.
Expand All @@ -321,8 +324,8 @@ def plot_single_scatter(
y_values (np.ndarray): Numpy array containing the y-values to use in plotting.
color (str): RGB color string to use for this trace.
name (str): String name to show in the legend.
**kwargs (dict): Additional key word arguments, e.g. burn_in, legend_group, show_legend, used in some specific plots
but not applicable to all.
**kwargs (Any): Additional key word arguments, e.g. burn_in, legend_group, show_legend, used in some specific
plots but not applicable to all.
Returns:
fig (go.Figure): Plotly figure with the trace added to it.
Expand Down Expand Up @@ -378,7 +381,7 @@ def plot_single_box(fig: go.Figure, y_values: np.ndarray, color: str, name: str)


def plot_polygons_on_map(
polygons: Union[np.ndarray, list], values: np.ndarray, opacity: float, map_color_scale: str, **kwargs
polygons: Union[np.ndarray, list], values: np.ndarray, opacity: float, map_color_scale: str, **kwargs: Any
) -> go.Choroplethmapbox:
"""Plot a set of polygons on a map.
Expand All @@ -388,8 +391,8 @@ def plot_polygons_on_map(
used in coloring the polygons on the map.
opacity (float): Float between 0 and 1 specifying the opacity of the polygon fill color.
map_color_scale (str): The string which defines which plotly color scale.
**kwargs (dict): Additional key word arguments which can be passed on the go.Choroplethmapbox object (will override
the default values as specified in this function)
**kwargs (Any): Additional key word arguments which can be passed on the go.Choroplethmapbox object
(will override the default values as specified in this function)
Returns:
trace: go.Choroplethmapbox trace with the colored polygons which can be added to a go.Figure object.
Expand Down Expand Up @@ -578,16 +581,16 @@ def show_all(self, renderer="browser"):
for fig in self.figure_dict.values():
fig.show(renderer=renderer)

def plot_single_trace(self, object_to_plot: Union[Type[SlabAndSpike], SourceModel, MCMC], **kwargs):
def plot_single_trace(self, object_to_plot: Union[Type[SlabAndSpike], SourceModel, MCMC], **kwargs: Any):
"""Plotting a trace of a single variable.
Depending on the object to plot it creates a figure which is stored in the figure_dict attribute.
First it grabs all the specifics needed for the plot and then plots the trace.
Args:
object_to_plot (Union[Type[SlabAndSpike], SourceModel, MCMC]): The object from which to plot a variable
**kwargs (dict): Additional key word arguments, e.g. burn_in, legend_group, show_legend, dict_key, used in some
specific plots but not applicable to all.
**kwargs (Any): Additional key word arguments, e.g. burn_in, legend_group, show_legend, dict_key, used in
some specific plots but not applicable to all.
"""
plot_specifics = create_trace_specifics(object_to_plot=object_to_plot, **kwargs)
Expand Down Expand Up @@ -633,7 +636,7 @@ def plot_trace_per_sensor(
object_to_plot: Union[ErrorModel, PerSensor, MCMC],
sensor_object: Union[SensorGroup, Sensor],
plot_type: str,
**kwargs,
**kwargs: Any,
):
"""Plotting a trace of a single variable per sensor.
Expand All @@ -644,8 +647,8 @@ def plot_trace_per_sensor(
object_to_plot (Union[ErrorModel, PerSensor, MCMC]): The object which to plot a variable from
sensor_object (Union[SensorGroup, Sensor]): Sensor object associated with the object_to_plot
plot_type (str): String specifying a line or box plot.
**kwargs (dict): Additional key word arguments, e.g. burn_in, legend_group, show_legend, dict_key, used in some
specific plots but not applicable to all.
**kwargs (Any): Additional key word arguments, e.g. burn_in, legend_group, show_legend, dict_key, used in
some specific plots but not applicable to all.
"""
if isinstance(sensor_object, Sensor):
Expand Down Expand Up @@ -790,7 +793,7 @@ def plot_fitted_values_per_sensor(

self.figure_dict[dict_key] = fig

def plot_emission_rate_estimates(self, source_model_object, y_axis_type="linear", **kwargs):
def plot_emission_rate_estimates(self, source_model_object, y_axis_type="linear", **kwargs: Any):
"""Plot the emission rate estimates source model object against MCMC iteration.
Based on the inputs it plots the results of the mcmc analysis, being the estimated emission rate values for
Expand All @@ -804,7 +807,7 @@ def plot_emission_rate_estimates(self, source_model_object, y_axis_type="linear"
Args:
source_model_object (SourceModel): Source model object which contains the estimated emission rate estimates.
y_axis_type (str, optional): String to indicate whether the y-axis should be linear of log scale.
**kwargs (dict): Additional key word arguments, e.g. burn_in, dict_key, used in some specific plots but not
**kwargs (Any): Additional key word arguments, e.g. burn_in, dict_key, used in some specific plots but not
applicable to all.
"""
Expand Down Expand Up @@ -913,7 +916,7 @@ def create_empty_mapbox_figure(self, dict_key: str = "map_plot") -> None:
)

def plot_values_on_map(
self, dict_key: str, coordinates: LLA, values: np.ndarray, aggregate_function: Callable = np.sum, **kwargs
self, dict_key: str, coordinates: LLA, values: np.ndarray, aggregate_function: Callable = np.sum, **kwargs: Any
):
"""Plot values on a map based on coordinates.
Expand All @@ -923,7 +926,7 @@ def plot_values_on_map(
values (np.ndarray): Numpy array of values consistent with coordinates to plot on the map
aggregate_function (Callable, optional): Function which to apply on the data in each hexagonal bin to
aggregate the data and visualise the result.
**kwargs (dict): Additional keyword arguments for plotting behaviour (opacity, map_color_scale, num_hexagons,
**kwargs (Any): Additional keyword arguments for plotting behaviour (opacity, map_color_scale, num_hexagons,
show_positions)
"""
Expand Down Expand Up @@ -976,16 +979,31 @@ def plot_quantification_results_on_map(
burn_in: int = 0,
show_summary_results: bool = True,
):
"""Placeholder for the quantification plots."""
"""Function to create a map with the quantification results of the model object.
This function takes the ELQModel object and calculates the statistics for the quantification results. It then
populates the figure dictionary with three different maps showing the normalized count, median emission rate
and the inter-quartile range of the emission rate estimates.
Args:
model_object (ELQModel): ELQModel object containing the quantification results
bin_size_x (float, optional): Size of the bins in the x-direction. Defaults to 1.
bin_size_y (float, optional): Size of the bins in the y-direction. Defaults to 1.
normalized_count_limit (float, optional): Limit for the normalized count to show on the map.
Defaults to 0.005.
burn_in (int, optional): Number of burn-in iterations to discard before calculating the statistics.
Defaults to 0.
show_summary_results (bool, optional): Flag to show the summary results on the map. Defaults to True.
"""
ref_latitude = model_object.components["source"].dispersion_model.source_map.location.ref_latitude
ref_longitude = model_object.components["source"].dispersion_model.source_map.location.ref_longitude
ref_altitude = model_object.components["source"].dispersion_model.source_map.location.ref_altitude

datetime_min_string = model_object.sensor_object.time.min().strftime("%d-%b-%Y, %H:%M:%S")
datetime_max_string = model_object.sensor_object.time.max().strftime("%d-%b-%Y, %H:%M:%S")

result_weighted, overall_count, normalized_count, count_boolean, enu_points, summary_result = (
result_weighted, _, normalized_count, count_boolean, enu_points, summary_result = (
calculate_rectangular_statistics(
model_object=model_object,
bin_size_x=bin_size_x,
Expand All @@ -995,7 +1013,7 @@ def plot_quantification_results_on_map(
)
)

polygons = create_LLA_polygons_from_XY_points(
polygons = create_lla_polygons_from_xy_points(
points_array=enu_points,
ref_latitude=ref_latitude,
ref_longitude=ref_longitude,
Expand Down Expand Up @@ -1132,17 +1150,11 @@ def create_summary_trace(
) -> go.Scattermapbox:
"""Helper function to create the summary information to plot on top of map type plots.
We identify all blobs of estimates which appear close together on the map by looking at connected pixels in the
count_boolean array. Next we find the summary statistics for all estimates in that blob like overall median and
IQR estimate, mean location and the likelihood of that blob.
When multiple sources are present in the same blob at the same iteration we first sum those emission rate
estimates before taking the median.
The summary statistics are also printed out on screen.
We use the summary result calculated through the support functions module to create a trace which contains
the summary information for each source location.
Args:
summary_result
summary_result (pd.DataFrame): DataFrame containing the summary information for each source location.
Returns:
summary_trace (go.Scattermapbox): Trace with summary information to plot on top of map type plots.
Expand Down

0 comments on commit 2f4f0fb

Please sign in to comment.