Skip to content

Commit

Permalink
Draft add 1-var interactivity.
Browse files Browse the repository at this point in the history
  • Loading branch information
trentmc authored and calina-c committed Apr 26, 2024
1 parent e0a943c commit 09c02ba
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 39 deletions.
24 changes: 21 additions & 3 deletions pdr_backend/aimodel/aimodel_plotdata.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from typing import List
from typing import List, Optional

from enforce_typing import enforce_types
import numpy as np

from pdr_backend.aimodel.aimodel import Aimodel


@enforce_types
class AimodelPlotdata:
"""Simple class to manage many inputs going into plot_model."""

Expand All @@ -17,14 +16,18 @@ def __init__(
ytrue_train: np.ndarray,
colnames: List[str],
slicing_x: np.ndarray,
sweep_vars: Optional[List[int]] = None,
):
"""
@arguments
model -- Aimodel
X_train -- 2d array [sample_i, var_i]:cont_value -- model trn inputs
ytrue_train -- 1d array [sample_i]:bool_value -- model trn outputs
colnames -- [var_i]:str -- name for each of the X inputs
slicing_x -- arrat [dim_i]:floatval - when >2 dims, plot about this pt
slicing_x -- array [var_i]:floatval - values for non-sweep vars
sweep_vars -- list with [sweepvar_i] or [sweepvar_i, sweepvar_j]
-- If 1 entry, do line plot (1 var), where y-axis is response
-- If 2 entries, do contour plot (2 vars), where z-axis is response
"""
# preconditions
assert X_train.shape[1] == len(colnames) == slicing_x.shape[0], (
Expand All @@ -36,15 +39,30 @@ def __init__(
X_train.shape[0],
ytrue_train.shape[0],
)
assert sweep_vars is None or len(sweep_vars) in [1, 2]

# set values
self.model = model
self.X_train = X_train
self.ytrue_train = ytrue_train
self.colnames = colnames
self.slicing_x = slicing_x
self.sweep_vars = sweep_vars

@property
@enforce_types
def n(self) -> int:
"""Number of input dimensions == # columns in X"""
return self.X_train.shape[1]

@property
@enforce_types
def n_sweep(self) -> int:
"""Number of variables to sweep in the plot"""
if self.sweep_vars is None:
return 0

if self.n == 1:
return 1

return len(self.sweep_vars)
89 changes: 58 additions & 31 deletions pdr_backend/aimodel/aimodel_plotter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Optional

import numpy as np
import plotly.graph_objects as go
from enforce_typing import enforce_types
Expand All @@ -8,31 +6,20 @@


@enforce_types
def plot_aimodel_response(
aimodel_plotdata: AimodelPlotdata,
label: Optional[str] = None,
):
def plot_aimodel_response(aimodel_plotdata: AimodelPlotdata):
"""
@description
Plot the model response in a line plot (1 var) contour plot (>1 vars)
Plot the model response in a line plot (1 var) or contour plot (2 vars).
And overlay X-data. (Training data or otherwise.)
If the model has >2 vars, it plots the 2 most important vars.
@arguments
aimodel_plotdata -- holds:
model -- Aimodel
X_train -- array [sample_i][var_i]:floatval -- trn model inputs (or other)
ytrue_train -- array [sample_i]:boolval -- trn model outputs (or other)
colnames -- list [var_i]:X_column_name
slicing_x -- arrat [var_i]:floatval - when >2 dims, plot about this pt
fig_ax -- None or (fig, ax) to easily embed into existing plot
legend_loc -- eg "upper left". Applies only to contour plots.
"""
if aimodel_plotdata.n == 1:
return _plot_aimodel_lineplot(aimodel_plotdata)
d = aimodel_plotdata
# assert d.n_sweep in [1, 2]

if d.n_sweep == 1 and d.n == 1 or d.n == 1:
return _plot_aimodel_lineplot_1var(aimodel_plotdata)

if label is not None:
return _plot_aimodel_lineplot(aimodel_plotdata, label=label)
if d.n_sweep == 1 and d.n > 1:
return _plot_aimodel_lineplot_nvars(aimodel_plotdata)

return _plot_aimodel_contour(aimodel_plotdata)

Expand All @@ -41,25 +28,22 @@ def plot_aimodel_response(


@enforce_types
def _plot_aimodel_lineplot(
aimodel_plotdata: AimodelPlotdata, label: Optional[str] = None
):
def _plot_aimodel_lineplot_1var(aimodel_plotdata: AimodelPlotdata):
"""
@description
Plot the model, when there's 1 input x-var. Use a line plot.
Do a 1d lineplot, when exactly 1 input x-var
Will fail if not 1 var.
Because one var total, we can show more info of true-vs-actual
"""
# aimodel data
assert aimodel_plotdata.n == 1 or label in aimodel_plotdata.colnames
d = aimodel_plotdata
assert d.n == 1
assert d.n_sweep == 1
X, ytrue = d.X_train, d.ytrue_train

label_index = d.colnames.index(label) if label is not None else 0
x = X[:, label_index]
x = X[:, 0]
N = len(x)

# TODO: fix err

# calc mesh_X = uniform grid
mesh_x = np.linspace(min(x), max(x), 200)
mesh_N = len(mesh_x)
Expand Down Expand Up @@ -139,6 +123,49 @@ def _plot_aimodel_lineplot(
return fig_bars


@enforce_types
def _plot_aimodel_lineplot_nvars(aimodel_plotdata: AimodelPlotdata):
"""
@description
Do a 1d lineplot, when >1 input x-var, and we have chosen the var.
Because >1 var total, we can show more info of true-vs-actual
"""
# input data
d = aimodel_plotdata
assert d.n >= 1
assert d.n_sweep == 1

# construct sweep_x
sweepvar_i = d.sweep_vars[0] # type: ignore[index]
mn_x, mx_x = min(d.X_train[:, sweepvar_i]), max(d.X_train[:, sweepvar_i])
N = 200
sweep_x = np.linspace(mn_x, mx_x, N)

# construct X
X = np.empty((N, d.n), dtype=float)
X[:, sweepvar_i] = sweep_x
for var_i in range(d.n):
if var_i == sweepvar_i:
continue
X[:, var_i] = d.slicing_x[var_i]

# calc model response
yptrue = d.model.predict_ptrue(X) # [sample_i]: prob_of_being_true

# line plot: model response surface
fig_line = go.Figure(
data=go.Scatter(
x=sweep_x,
y=yptrue,
mode="lines",
line={"color": "gray"},
name="model prob(true)",
)
)

return fig_line


@enforce_types
def _plot_aimodel_contour(
aimodel_plotdata: AimodelPlotdata,
Expand Down
61 changes: 59 additions & 2 deletions pdr_backend/aimodel/test/test_aimodel_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def test_aimodel_accuracy_from_create_xy():


@enforce_types
def test_aimodel_factory_1var_main():
def test_aimodel_factory_1varmodel_lineplot():
"""1 input var. It will plot that var on both axes"""
# settings, factory
ss = AimodelSS(aimodel_ss_test_dict(approach="LinearLogistic"))
Expand All @@ -159,7 +159,64 @@ def test_aimodel_factory_1var_main():
# plot
colnames = ["x0"]
slicing_x = np.array([0.1]) # arbitrary
aimodel_plotdata = AimodelPlotdata(model, X, ytrue, colnames, slicing_x)
sweep_vars = [0]
aimodel_plotdata = AimodelPlotdata(
model,
X,
ytrue,
colnames,
slicing_x,
sweep_vars,
)
figure = plot_aimodel_response(aimodel_plotdata)
assert isinstance(figure, Figure)

if SHOW_PLOT:
figure.show()


@enforce_types
def test_aimodel_factory_5varmodel_lineplot():
"""5 input vars; sweep 1 var."""
# settings, factory
ss = AimodelSS(aimodel_ss_test_dict(approach="LinearLogistic"))
factory = AimodelFactory(ss)

# data
N = 1000
mn, mx = -10.0, +10.0
X = np.random.uniform(mn, mx, (N, 5))
ycont = (
10.0
+ 0.1 * X[:, 0]
+ 1.0 * X[:, 1]
+ 2.0 * X[:, 2]
+ 3.0 * X[:, 3]
+ 4.0 * X[:, 4]
)
y_thr = np.average(ycont) # avg gives good class balance
ytrue = ycont > y_thr

# build model
model = factory.build(X, ytrue, show_warnings=False)

# test variable importances
imps = model.importance_per_var()
assert len(imps) == 5
assert imps[0] < imps[1] < imps[2] < imps[3] < imps[4]

# plot
colnames = ["x0", "x1", "x2", "x3", "x4"]
slicing_x = np.array([0.1] * 5) # arbitrary
sweep_vars = [2] # x2
aimodel_plotdata = AimodelPlotdata(
model,
X,
ytrue,
colnames,
slicing_x,
sweep_vars,
)
figure = plot_aimodel_response(aimodel_plotdata)
assert isinstance(figure, Figure)

Expand Down
5 changes: 3 additions & 2 deletions pdr_dash_plots/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ def get_figures_by_state(sim_plotter: SimPlotter, clickData=None):
elif key == "aimodel_response":
func_name = getattr(aimodel_plotter, f"plot_{key}")
label = clickData["points"][0]["y"] if clickData else None
fig = aimodel_plotter.plot_aimodel_response(
sim_plotter.aimodel_plotdata, label
sim_plotter.aimodel_plotdata.sweep_vars = (
[sim_plotter.aimodel_plotdata.colnames.index(label)] if label else None
)
fig = aimodel_plotter.plot_aimodel_response(sim_plotter.aimodel_plotdata)
else:
func_name = getattr(aimodel_plotter, f"plot_{key}")
fig = func_name(sim_plotter.aimodel_plotdata)
Expand Down
2 changes: 1 addition & 1 deletion sim_plots
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ from pdr_dash_plots.callbacks import get_callbacks
from pdr_dash_plots.view_elements import empty_graphs_template
from pdr_dash_plots.util import get_all_run_names

# TODO: handle clickdata in varimps callback
# TODO: handle deselection of varimps!
# TODO: CSS/HTML layout tweaks

app = Dash(__name__)
Expand Down

0 comments on commit 09c02ba

Please sign in to comment.