-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Convert streamlit dashboard to Dash. Main cases.
- Loading branch information
Showing
6 changed files
with
128 additions
and
87 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,87 +1,142 @@ | ||
import os | ||
import sys | ||
import time | ||
from dash import Dash, dcc, html, Input, Output, callback, State | ||
import plotly | ||
from pdr_backend.sim.sim_plotter import SimPlotter | ||
from pathlib import Path | ||
import os | ||
from pdr_backend.aimodel import aimodel_plotter | ||
|
||
import streamlit | ||
|
||
from pdr_backend.aimodel import aimodel_plotter | ||
from pdr_backend.sim.sim_plotter import SimPlotter | ||
# TODO: run with specific state -> check args in Readme | ||
# TODO: run with specific ports for multiple instances -> check args in Readme | ||
# TODO: display slider to select state and add callback | ||
# TODO: handle clickdata in varimps callback | ||
|
||
|
||
app = Dash(__name__) | ||
app.config["suppress_callback_exceptions"] = True | ||
|
||
streamlit.set_page_config(layout="wide") | ||
|
||
title = streamlit.empty() | ||
subtitle = streamlit.empty() | ||
inputs = streamlit.empty() | ||
c1, c2, c3 = streamlit.columns((1, 1, 2)) | ||
c4, c5 = streamlit.columns((1, 1)) | ||
c6, c7 = streamlit.columns((1, 1)) | ||
c8, _ = streamlit.columns((1, 1)) | ||
|
||
canvas = { | ||
"pdr_profit_vs_time": c1.empty(), | ||
"trader_profit_vs_time": c2.empty(), | ||
"accuracy_vs_time": c3.empty(), | ||
"pdr_profit_vs_ptrue": c4.empty(), | ||
"trader_profit_vs_ptrue": c5.empty(), | ||
"aimodel_varimps": c6.empty(), | ||
"aimodel_response": c7.empty(), | ||
"f1_precision_recall_vs_time": c8.empty(), | ||
} | ||
|
||
last_ts = None | ||
sim_plotter = SimPlotter() | ||
canvas = [ | ||
"pdr_profit_vs_time", | ||
"trader_profit_vs_time", | ||
"accuracy_vs_time", | ||
"pdr_profit_vs_ptrue", | ||
"trader_profit_vs_ptrue", | ||
"aimodel_varimps", | ||
"aimodel_response", | ||
"f1_precision_recall_vs_time", | ||
] | ||
empty_graphs_template = [ | ||
dcc.Graph(figure=plotly.graph_objs.Figure(), id=key) for key in canvas | ||
] | ||
|
||
|
||
app.layout = html.Div( | ||
html.Div( | ||
[ | ||
html.Div(empty_graphs_template, id="live-graphs"), | ||
dcc.Interval( | ||
id="interval-component", | ||
interval=3 * 1000, # in milliseconds | ||
n_intervals=0, | ||
disabled=False, | ||
), | ||
] | ||
) | ||
) | ||
|
||
|
||
@app.callback( | ||
Output("interval-component", "disabled"), | ||
[Input("sim_current_ts", "className")], | ||
[State("interval-component", "disabled")], | ||
) | ||
def callback_func_start_stop_interval(value, disabled_state): | ||
# stop refreshing if final state was reached | ||
return value == "finalState" | ||
|
||
|
||
def get_latest_state_id(): | ||
path = sorted(Path("sim_state").iterdir(), key=os.path.getmtime)[-1] | ||
return str(path).replace("sim_state/", "") | ||
|
||
|
||
state_id = sys.argv[1] if len(sys.argv) > 1 else get_latest_state_id() | ||
subtitle.markdown(f"Simulation ID: {state_id}") | ||
|
||
|
||
def load_canvas_on_state(ts): | ||
titletext = f"Iter #{st.iter_number} ({ts})" if ts != "final" else "Final sim state" | ||
title.title(titletext) | ||
|
||
def get_figures_by_state(): | ||
figures = {} | ||
for key in canvas: | ||
if not key.startswith("aimodel"): | ||
fig = getattr(sim_plotter, f"plot_{key}")() | ||
else: | ||
func_name = getattr(aimodel_plotter, f"plot_{key}") | ||
fig = func_name(sim_plotter.aimodel_plotdata) | ||
|
||
canvas[key].plotly_chart(fig, use_container_width=True, theme="streamlit") | ||
figures[key] = fig | ||
|
||
return figures | ||
|
||
while True: | ||
try: | ||
sim_plotter.load_state(state_id) | ||
break | ||
except Exception as e: | ||
time.sleep(3) | ||
title.title(f"Waiting for sim state... {e}") | ||
continue | ||
|
||
while True: | ||
@callback( | ||
Output("live-graphs", "children"), | ||
Input("interval-component", "n_intervals"), | ||
Input("aimodel_varimps", "clickData"), | ||
) | ||
def update_graph_live(n, clickData): | ||
state_id = get_latest_state_id() | ||
|
||
try: | ||
st, new_ts = sim_plotter.load_state(state_id) | ||
except EOFError: | ||
time.sleep(1) | ||
continue | ||
|
||
if new_ts == last_ts: | ||
time.sleep(1) | ||
continue | ||
|
||
load_canvas_on_state(new_ts) | ||
last_ts = new_ts | ||
|
||
if last_ts == "final": | ||
snapshots = SimPlotter.available_snapshots() | ||
timestamp = inputs.select_slider("Go to snapshot", snapshots, value="final") | ||
st, new_ts = sim_plotter.load_state(state_id, timestamp) | ||
load_canvas_on_state(timestamp) | ||
break | ||
st, ts = sim_plotter.load_state(state_id) | ||
except Exception as e: | ||
return [ | ||
html.Div( | ||
[html.H2(f"Error/waiting: {e}", id="sim_state_text")] | ||
+ empty_graphs_template, | ||
id="live-graphs", | ||
), | ||
] | ||
|
||
figures = get_figures_by_state() | ||
|
||
return [ | ||
html.H2(f"Simulation ID: {state_id}", id="sim_state_text"), | ||
html.H2( | ||
f"Iter #{st.iter_number} ({ts})" if ts != "final" else "Final sim state", | ||
id="sim_current_ts", | ||
className="finalState" if ts == "final" else "runningState", | ||
), | ||
html.H2(str(clickData)), | ||
html.Div( | ||
[ | ||
dcc.Graph(figure=figures["pdr_profit_vs_time"]), | ||
dcc.Graph(figure=figures["trader_profit_vs_time"]), | ||
], | ||
style={"display": "flex"}, | ||
), | ||
html.Div( | ||
[ | ||
dcc.Graph(figure=figures["accuracy_vs_time"]), | ||
] | ||
), | ||
html.Div( | ||
[ | ||
dcc.Graph(figure=figures["pdr_profit_vs_ptrue"]), | ||
dcc.Graph(figure=figures["trader_profit_vs_ptrue"]), | ||
], | ||
style={"display": "flex"}, | ||
), | ||
html.Div( | ||
[ | ||
dcc.Graph(figure=figures["aimodel_varimps"], id="aimodel_varimps"), | ||
dcc.Graph(figure=figures["aimodel_response"]), | ||
], | ||
style={"display": "flex"}, | ||
), | ||
html.Div( | ||
[ | ||
dcc.Graph(figure=figures["f1_precision_recall_vs_time"]), | ||
] | ||
), | ||
] | ||
|
||
|
||
if __name__ == "__main__": | ||
app.run(debug=True) |
This file was deleted.
Oops, something went wrong.