From a46dae94059b8067de0a8b65fea42e3a1d4a62af Mon Sep 17 00:00:00 2001 From: Calina Cenan Date: Thu, 2 May 2024 07:05:19 +0000 Subject: [PATCH] Plot log loss. --- pdr_backend/sim/dash_plots/view_elements.py | 53 +++++++-------------- pdr_backend/sim/sim_plotter.py | 19 ++++++++ pdr_backend/sim/test/test_dash_plots.py | 9 ++-- 3 files changed, 39 insertions(+), 42 deletions(-) diff --git a/pdr_backend/sim/dash_plots/view_elements.py b/pdr_backend/sim/dash_plots/view_elements.py index 5d88fea38..c82da8732 100644 --- a/pdr_backend/sim/dash_plots/view_elements.py +++ b/pdr_backend/sim/dash_plots/view_elements.py @@ -14,6 +14,7 @@ "aimodel_varimps", "aimodel_response", "f1_precision_recall_vs_time", + "log_loss_vs_time", ] empty_slider = dcc.Slider( @@ -59,49 +60,27 @@ def get_header_elements(run_id, st, ts): ] +def side_by_side_graphs(figures, name1, name2): + return html.Div( + [ + dcc.Graph(figure=figures[name1], id=name1, style={"width": "50%"}), + dcc.Graph(figure=figures[name2], id=name2, style={"width": "50%"}), + ], + style={"display": "flex", "justifyContent": "space-between"}, + ) + + def arrange_figures(figures): return [ + side_by_side_graphs(figures, "pdr_profit_vs_time", "trader_profit_vs_time"), html.Div( [ - dcc.Graph(figure=figures["pdr_profit_vs_time"], style={"width": "50%"}), - dcc.Graph( - figure=figures["trader_profit_vs_time"], style={"width": "50%"} - ), - ], - style={"display": "flex", "justifyContent": "space-between"}, - ), - html.Div( - [ - dcc.Graph(figure=figures["accuracy_vs_time"]), - ] - ), - html.Div( - [ - dcc.Graph( - figure=figures["pdr_profit_vs_ptrue"], style={"width": "50%"} - ), - dcc.Graph( - figure=figures["trader_profit_vs_ptrue"], style={"width": "50%"} - ), - ], - style={"display": "flex", "justifyContent": "space-between"}, - ), - html.Div( - [ - dcc.Graph( - figure=figures["aimodel_varimps"], - id="aimodel_varimps", - style={"width": "50%"}, - ), - dcc.Graph(figure=figures["aimodel_response"], style={"width": "50%"}), - ], - style={"display": "flex", "justifyContent": "space-between"}, - ), - html.Div( - [ - dcc.Graph(figure=figures["f1_precision_recall_vs_time"]), + dcc.Graph(figure=figures["accuracy_vs_time"], id="accuracy_vs_time"), ] ), + side_by_side_graphs(figures, "pdr_profit_vs_ptrue", "trader_profit_vs_ptrue"), + side_by_side_graphs(figures, "aimodel_varimps", "aimodel_response"), + side_by_side_graphs(figures, "f1_precision_recall_vs_time", "log_loss_vs_time"), ] diff --git a/pdr_backend/sim/sim_plotter.py b/pdr_backend/sim/sim_plotter.py index d2e4fba7f..d5008f890 100644 --- a/pdr_backend/sim/sim_plotter.py +++ b/pdr_backend/sim/sim_plotter.py @@ -311,6 +311,25 @@ def plot_trader_profit_vs_ptrue(self): return fig + @enforce_types + def plot_log_loss_vs_time(self): + clm = self.st.clm + s = f"log loss = {clm.losses[-1]:.4f}" + + y = "log loss" + df = pd.DataFrame(clm.losses, columns=[y]) + df["time"] = range(len(clm.losses)) + + fig = go.Figure( + go.Scatter(x=df["time"], y=df[y], mode="lines", name="log loss") + ) + + fig.update_layout(title=s) + fig.update_xaxes(title="time") + fig.update_yaxes(title=y) + + return fig + def file_age_in_seconds(pathname): stat_result = os.stat(pathname) diff --git a/pdr_backend/sim/test/test_dash_plots.py b/pdr_backend/sim/test/test_dash_plots.py index 6d92e9a30..39d072287 100644 --- a/pdr_backend/sim/test/test_dash_plots.py +++ b/pdr_backend/sim/test/test_dash_plots.py @@ -39,16 +39,12 @@ def test_arrange_figures(): figures = {key: Figure() for key in figure_names} result = arrange_figures(figures) seen = set() - count = 0 for div in result: for graph in div.children: - count += 1 if hasattr(graph, "id"): seen.add(graph.id) - assert count == len(figure_names) - assert seen == {"aimodel_varimps"} # only one with id present for now - set(figure_names) + assert len(seen) == len(figure_names) def test_snapshot_slider(): @@ -91,8 +87,11 @@ def test_get_figures_by_state(): mock_sim_plotter.plot_pdr_profit_vs_ptrue.return_value = Figure() mock_sim_plotter.plot_trader_profit_vs_ptrue.return_value = Figure() mock_sim_plotter.plot_f1_precision_recall_vs_time.return_value = Figure() + mock_sim_plotter.plot_log_loss_vs_time.return_value = Figure() + plotdata = Mock() plotdata.colnames = ["var1", "var2"] + mock_sim_plotter.aimodel_plotdata = plotdata with patch(