diff --git a/READMEs/predictoor.md b/READMEs/predictoor.md index 35e392c00..704790ede 100644 --- a/READMEs/predictoor.md +++ b/READMEs/predictoor.md @@ -88,7 +88,7 @@ cd ~/code/pdr-backend # or wherever your pdr-backend dir is source venv/bin/activate #display real-time plots of the simulation -streamlit run sim_plots.py +sim_plots ``` "Predict" actions are _two-sided_: it does one "up" prediction tx, and one "down" tx, with more stake to the higher-confidence direction. Two-sided is more profitable than one-sided prediction. @@ -101,10 +101,10 @@ To see simulation CLI options: `pdr sim -h`. Simulation uses Python [logging](https://docs.python.org/3/howto/logging.html) framework. Configure it via [`logging.yaml`](../logging.yaml). [Here's](https://medium.com/@cyberdud3/a-step-by-step-guide-to-configuring-python-logging-with-yaml-files-914baea5a0e5) a tutorial on yaml settings. -By default, streamlit plots the latest sim (even if it is still running). To enable plotting for a specific run, e.g. if you used multisim or manually triggered different simulations, the sim engine assigns unique ids to each run. -Select that unique id from the `sim_state` folder, and run `streamlit run sim_plots.py ` e.g. `streamlit run sim_plots.py 97f9633c-a78c-4865-9cc6-b5152c9500a3` +By default, Dash plots the latest sim (even if it is still running). To enable plotting for a specific run, e.g. if you used multisim or manually triggered different simulations, the sim engine assigns unique ids to each run. +Select that unique id from the `sim_state` folder, and run `sim_plots --run_id ` e.g. `sim_plots --run-id 97f9633c-a78c-4865-9cc6-b5152c9500a3` -You can run many instances of streamlit at once, with different URLs. +You can run many instances of Dash at once, with different URLs. To run on different ports, use the `--port` argument. ## 3. Run Predictoor Bot on Sapphire Testnet diff --git a/READMEs/trader.md b/READMEs/trader.md index b1512f081..c9ccbe576 100644 --- a/READMEs/trader.md +++ b/READMEs/trader.md @@ -80,7 +80,7 @@ cd ~/code/pdr-backend # or wherever your pdr-backend dir is source venv/bin/activate #display real-time plots of the simulation -streamlit run sim_plots.py +sim_plots ``` "Predict" actions are _two-sided_: it does one "up" prediction tx, and one "down" tx, with more stake to the higher-confidence direction. Two-sided is more profitable than one-sided prediction. @@ -93,10 +93,10 @@ To see simulation CLI options: `pdr sim -h`. Simulation uses Python [logging](https://docs.python.org/3/howto/logging.html) framework. Configure it via [`logging.yaml`](../logging.yaml). [Here's](https://medium.com/@cyberdud3/a-step-by-step-guide-to-configuring-python-logging-with-yaml-files-914baea5a0e5) a tutorial on yaml settings. -By default, streamlit plots the latest sim (even if it is still running). To enable plotting for a specific run, e.g. if you used multisim or manually triggered different simulations, the sim engine assigns unique ids to each run. -Select that unique id from the `sim_state` folder, and run `streamlit run sim_plots.py ` e.g. `streamlit run sim_plots.py 97f9633c-a78c-4865-9cc6-b5152c9500a3` +By default, Dash plots the latest sim (even if it is still running). To enable plotting for a specific run, e.g. if you used multisim or manually triggered different simulations, the sim engine assigns unique ids to each run. +Select that unique id from the `sim_state` folder, and run `sim_plots --run_id ` e.g. `sim_plots --run_id 97f9633c-a78c-4865-9cc6-b5152c9500a3` -You can run many instances of streamlit at once, with different URLs. +You can run many instances of Dash at once, with different URLs. To run on different ports, use the `--port` argument. ## Run Trader Bot on Sapphire Testnet diff --git a/mypy.ini b/mypy.ini index abfb3b2c7..3c4fa436a 100644 --- a/mypy.ini +++ b/mypy.ini @@ -11,6 +11,9 @@ ignore_missing_imports = True [mypy-ccxt.*] ignore_missing_imports = True +[mypy-dash.*] +ignore_missing_imports = True + [mypy-ecies.*] ignore_missing_imports = True diff --git a/pdr_backend/aimodel/aimodel_plotdata.py b/pdr_backend/aimodel/aimodel_plotdata.py index 90c9e79fb..b8b800a30 100644 --- a/pdr_backend/aimodel/aimodel_plotdata.py +++ b/pdr_backend/aimodel/aimodel_plotdata.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional from enforce_typing import enforce_types import numpy as np @@ -6,7 +6,6 @@ from pdr_backend.aimodel.aimodel import Aimodel -@enforce_types class AimodelPlotdata: """Simple class to manage many inputs going into plot_model.""" @@ -17,6 +16,7 @@ def __init__( ytrue_train: np.ndarray, colnames: List[str], slicing_x: np.ndarray, + sweep_vars: Optional[List[int]] = None, ): """ @arguments @@ -24,7 +24,10 @@ def __init__( X_train -- 2d array [sample_i, var_i]:cont_value -- model trn inputs ytrue_train -- 1d array [sample_i]:bool_value -- model trn outputs colnames -- [var_i]:str -- name for each of the X inputs - slicing_x -- arrat [dim_i]:floatval - when >2 dims, plot about this pt + slicing_x -- array [var_i]:floatval - values for non-sweep vars + sweep_vars -- list with [sweepvar_i] or [sweepvar_i, sweepvar_j] + -- If 1 entry, do line plot (1 var), where y-axis is response + -- If 2 entries, do contour plot (2 vars), where z-axis is response """ # preconditions assert X_train.shape[1] == len(colnames) == slicing_x.shape[0], ( @@ -36,6 +39,7 @@ def __init__( X_train.shape[0], ytrue_train.shape[0], ) + assert sweep_vars is None or len(sweep_vars) in [1, 2] # set values self.model = model @@ -43,8 +47,22 @@ def __init__( self.ytrue_train = ytrue_train self.colnames = colnames self.slicing_x = slicing_x + self.sweep_vars = sweep_vars @property + @enforce_types def n(self) -> int: """Number of input dimensions == # columns in X""" return self.X_train.shape[1] + + @property + @enforce_types + def n_sweep(self) -> int: + """Number of variables to sweep in the plot""" + if self.sweep_vars is None: + return 0 + + if self.n == 1: + return 1 + + return len(self.sweep_vars) diff --git a/pdr_backend/aimodel/aimodel_plotter.py b/pdr_backend/aimodel/aimodel_plotter.py index 1f869b061..53a4c2620 100644 --- a/pdr_backend/aimodel/aimodel_plotter.py +++ b/pdr_backend/aimodel/aimodel_plotter.py @@ -1,32 +1,25 @@ -from enforce_typing import enforce_types import numpy as np import plotly.graph_objects as go +from enforce_typing import enforce_types from pdr_backend.aimodel.aimodel_plotdata import AimodelPlotdata @enforce_types -def plot_aimodel_response( - aimodel_plotdata: AimodelPlotdata, -): +def plot_aimodel_response(aimodel_plotdata: AimodelPlotdata): """ @description - Plot the model response in a line plot (1 var) contour plot (>1 vars) + Plot the model response in a line plot (1 var) or contour plot (2 vars). And overlay X-data. (Training data or otherwise.) - If the model has >2 vars, it plots the 2 most important vars. - - @arguments - aimodel_plotdata -- holds: - model -- Aimodel - X_train -- array [sample_i][var_i]:floatval -- trn model inputs (or other) - ytrue_train -- array [sample_i]:boolval -- trn model outputs (or other) - colnames -- list [var_i]:X_column_name - slicing_x -- arrat [var_i]:floatval - when >2 dims, plot about this pt - fig_ax -- None or (fig, ax) to easily embed into existing plot - legend_loc -- eg "upper left". Applies only to contour plots. """ - if aimodel_plotdata.n == 1: - return _plot_aimodel_lineplot(aimodel_plotdata) + d = aimodel_plotdata + # assert d.n_sweep in [1, 2] + + if d.n_sweep == 1 and d.n == 1 or d.n == 1: + return _plot_aimodel_lineplot_1var(aimodel_plotdata) + + if d.n_sweep == 1 and d.n > 1: + return _plot_aimodel_lineplot_nvars(aimodel_plotdata) return _plot_aimodel_contour(aimodel_plotdata) @@ -35,15 +28,17 @@ def plot_aimodel_response( @enforce_types -def _plot_aimodel_lineplot(aimodel_plotdata: AimodelPlotdata): +def _plot_aimodel_lineplot_1var(aimodel_plotdata: AimodelPlotdata): """ @description - Plot the model, when there's 1 input x-var. Use a line plot. + Do a 1d lineplot, when exactly 1 input x-var Will fail if not 1 var. + Because one var total, we can show more info of true-vs-actual """ # aimodel data - assert aimodel_plotdata.n == 1 d = aimodel_plotdata + assert d.n == 1 + assert d.n_sweep == 1 X, ytrue = d.X_train, d.ytrue_train x = X[:, 0] @@ -128,6 +123,49 @@ def _plot_aimodel_lineplot(aimodel_plotdata: AimodelPlotdata): return fig_bars +@enforce_types +def _plot_aimodel_lineplot_nvars(aimodel_plotdata: AimodelPlotdata): + """ + @description + Do a 1d lineplot, when >1 input x-var, and we have chosen the var. + Because >1 var total, we can show more info of true-vs-actual + """ + # input data + d = aimodel_plotdata + assert d.n >= 1 + assert d.n_sweep == 1 + + # construct sweep_x + sweepvar_i = d.sweep_vars[0] # type: ignore[index] + mn_x, mx_x = min(d.X_train[:, sweepvar_i]), max(d.X_train[:, sweepvar_i]) + N = 200 + sweep_x = np.linspace(mn_x, mx_x, N) + + # construct X + X = np.empty((N, d.n), dtype=float) + X[:, sweepvar_i] = sweep_x + for var_i in range(d.n): + if var_i == sweepvar_i: + continue + X[:, var_i] = d.slicing_x[var_i] + + # calc model response + yptrue = d.model.predict_ptrue(X) # [sample_i]: prob_of_being_true + + # line plot: model response surface + fig_line = go.Figure( + data=go.Scatter( + x=sweep_x, + y=yptrue, + mode="lines", + line={"color": "#636EFA"}, + name="model prob(true)", + ) + ) + + return fig_line + + @enforce_types def _plot_aimodel_contour( aimodel_plotdata: AimodelPlotdata, @@ -251,6 +289,8 @@ def plot_aimodel_varimps(d: AimodelPlotdata): varnames = d.colnames n = len(varnames) + sweep_vars = d.sweep_vars if hasattr(d, "sweep_vars") else [] + # if >40 vars, truncate to top 40+1 if n > 40: rest_avg = sum(imps_avg[40:]) @@ -273,6 +313,7 @@ def plot_aimodel_varimps(d: AimodelPlotdata): imps_stddev = imps_stddev * 100.0 labelalias = {} + colors = [] for i in range(n): # avoid overlap in figure by giving different labels, @@ -281,12 +322,15 @@ def plot_aimodel_varimps(d: AimodelPlotdata): varnames[i] = f"var{i}" labelalias[varnames[i]] = "" + colors.append("#636EFA" if not sweep_vars or i in sweep_vars else "#D4D5DD") + fig_bars = go.Figure( data=go.Bar( x=imps_avg, y=varnames, error_x={"type": "data", "array": imps_stddev * 2}, orientation="h", + marker_color=colors, ) ) diff --git a/pdr_backend/aimodel/test/test_aimodel_factory.py b/pdr_backend/aimodel/test/test_aimodel_factory.py index 196666e37..972695ed9 100644 --- a/pdr_backend/aimodel/test/test_aimodel_factory.py +++ b/pdr_backend/aimodel/test/test_aimodel_factory.py @@ -135,7 +135,7 @@ def test_aimodel_accuracy_from_create_xy(): @enforce_types -def test_aimodel_factory_1var_main(): +def test_aimodel_factory_1varmodel_lineplot(): """1 input var. It will plot that var on both axes""" # settings, factory ss = AimodelSS(aimodel_ss_test_dict(approach="LinearLogistic")) @@ -159,7 +159,64 @@ def test_aimodel_factory_1var_main(): # plot colnames = ["x0"] slicing_x = np.array([0.1]) # arbitrary - aimodel_plotdata = AimodelPlotdata(model, X, ytrue, colnames, slicing_x) + sweep_vars = [0] + aimodel_plotdata = AimodelPlotdata( + model, + X, + ytrue, + colnames, + slicing_x, + sweep_vars, + ) + figure = plot_aimodel_response(aimodel_plotdata) + assert isinstance(figure, Figure) + + if SHOW_PLOT: + figure.show() + + +@enforce_types +def test_aimodel_factory_5varmodel_lineplot(): + """5 input vars; sweep 1 var.""" + # settings, factory + ss = AimodelSS(aimodel_ss_test_dict(approach="LinearLogistic")) + factory = AimodelFactory(ss) + + # data + N = 1000 + mn, mx = -10.0, +10.0 + X = np.random.uniform(mn, mx, (N, 5)) + ycont = ( + 10.0 + + 0.1 * X[:, 0] + + 1.0 * X[:, 1] + + 2.0 * X[:, 2] + + 3.0 * X[:, 3] + + 4.0 * X[:, 4] + ) + y_thr = np.average(ycont) # avg gives good class balance + ytrue = ycont > y_thr + + # build model + model = factory.build(X, ytrue, show_warnings=False) + + # test variable importances + imps = model.importance_per_var() + assert len(imps) == 5 + assert imps[0] < imps[1] < imps[2] < imps[3] < imps[4] + + # plot + colnames = ["x0", "x1", "x2", "x3", "x4"] + slicing_x = np.array([0.1] * 5) # arbitrary + sweep_vars = [2] # x2 + aimodel_plotdata = AimodelPlotdata( + model, + X, + ytrue, + colnames, + slicing_x, + sweep_vars, + ) figure = plot_aimodel_response(aimodel_plotdata) assert isinstance(figure, Figure) diff --git a/pdr_backend/cli/test/test_cli_module.py b/pdr_backend/cli/test/test_cli_module.py index b6a509085..3b85d7f73 100644 --- a/pdr_backend/cli/test/test_cli_module.py +++ b/pdr_backend/cli/test/test_cli_module.py @@ -408,7 +408,7 @@ def test_do_topup(monkeypatch): @enforce_types -def test_do_main(monkeypatch, capfd): +def test_do_main(capfd): with patch("sys.argv", ["pdr", "help"]): with pytest.raises(SystemExit): _do_main() @@ -420,11 +420,3 @@ def test_do_main(monkeypatch, capfd): _do_main() assert "Predictoor tool" in capfd.readouterr().out - - mock_f = Mock() - monkeypatch.setattr(f"{_CLI_PATH}.SimEngine.run", mock_f) - - with patch("sys.argv", ["streamlit_entrypoint.py", "sim", "ppss.yaml"]): - _do_main() - - assert mock_f.called diff --git a/pdr_backend/sim/sim_plotter.py b/pdr_backend/sim/sim_plotter.py index 744443d7f..02a6b1e50 100644 --- a/pdr_backend/sim/sim_plotter.py +++ b/pdr_backend/sim/sim_plotter.py @@ -27,11 +27,11 @@ def __init__( self.multi_id = None @staticmethod - def available_snapshots(): - all_state_files = glob.glob("sim_state/st_*.pkl") + def available_snapshots(multi_id): + all_state_files = glob.glob(f"sim_state/{multi_id}/st_*.pkl") all_timestamps = [ - f.replace("sim_state/st_", "").replace(".pkl", "") + f.replace(f"sim_state/{multi_id}/st_", "").replace(".pkl", "") for f in all_state_files if "final" not in f ] diff --git a/pdr_dash_plots/__init__.py b/pdr_dash_plots/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pdr_dash_plots/callbacks.py b/pdr_dash_plots/callbacks.py new file mode 100644 index 000000000..5b54bf340 --- /dev/null +++ b/pdr_dash_plots/callbacks.py @@ -0,0 +1,81 @@ +from dash import Input, Output, State + +from pdr_backend.sim.sim_plotter import SimPlotter +from pdr_dash_plots.util import get_figures_by_state, get_latest_run_id +from pdr_dash_plots.view_elements import ( + arrange_figures, + get_header_elements, + get_waiting_template, + non_final_state_div, + selected_var_checklist, + snapshot_slider, +) + + +def get_callbacks(app): + @app.callback( + Output("interval-component", "disabled"), + [Input("sim_current_ts", "className")], + [State("interval-component", "disabled")], + ) + # pylint: disable=unused-argument + def callback_func_start_stop_interval(value, disabled_state): + # stop refreshing if final state was reached + return value == "finalState" + + @app.callback( + Output("selected_vars", "value"), + Input("aimodel_varimps", "clickData"), + State("selected_vars", "value"), + ) + def update_selected_vars(clickData, selected_vars): + if clickData is None: + return selected_vars + + label = clickData["points"][0]["y"] + if label in selected_vars: + selected_vars.remove(label) + else: + selected_vars.append(label) + + return selected_vars + + @app.callback( + Output("live-graphs", "children"), + Input("interval-component", "n_intervals"), + Input("selected_vars", "value"), + Input("state_slider", "value"), + State("selected_vars", "value"), + ) + # pylint: disable=unused-argument + def update_graph_live(n, selected_vars, slider_value, selected_vars_old): + run_id = app.run_id if app.run_id else get_latest_run_id() + set_ts = None + + if slider_value is not None: + snapshots = SimPlotter.available_snapshots(run_id) + set_ts = snapshots[slider_value] + + sim_plotter = SimPlotter() + + try: + st, ts = sim_plotter.load_state(run_id, set_ts) + except Exception as e: + return [get_waiting_template(e)] + + elements = get_header_elements(run_id, st, ts) + + slider = ( + snapshot_slider(run_id, set_ts, slider_value) + if ts == "final" + else non_final_state_div + ) + elements.append(slider) + + state_options = sim_plotter.aimodel_plotdata.colnames + elements.append(selected_var_checklist(state_options, selected_vars_old)) + + figures = get_figures_by_state(sim_plotter, selected_vars) + elements = elements + arrange_figures(figures) + + return elements diff --git a/pdr_dash_plots/util.py b/pdr_dash_plots/util.py new file mode 100644 index 000000000..c804fbbe2 --- /dev/null +++ b/pdr_dash_plots/util.py @@ -0,0 +1,37 @@ +import os +from pathlib import Path + +from pdr_backend.aimodel import aimodel_plotter +from pdr_backend.sim.sim_plotter import SimPlotter +from pdr_dash_plots.view_elements import figure_names + + +def get_latest_run_id(): + path = sorted(Path("sim_state").iterdir(), key=os.path.getmtime)[-1] + return str(path).replace("sim_state/", "") + + +def get_all_run_names(): + path = Path("sim_state").iterdir() + return [str(p).replace("sim_state/", "") for p in path] + + +def get_figures_by_state(sim_plotter: SimPlotter, selected_vars): + figures = {} + + for key in figure_names: + if not key.startswith("aimodel"): + fig = getattr(sim_plotter, f"plot_{key}")() + else: + if key in ["aimodel_response", "aimodel_varimps"]: + sweep_vars = [] + for var in selected_vars: + sweep_vars.append(sim_plotter.aimodel_plotdata.colnames.index(var)) + sim_plotter.aimodel_plotdata.sweep_vars = sweep_vars + + func_name = getattr(aimodel_plotter, f"plot_{key}") + fig = func_name(sim_plotter.aimodel_plotdata) + + figures[key] = fig + + return figures diff --git a/pdr_dash_plots/view_elements.py b/pdr_dash_plots/view_elements.py new file mode 100644 index 000000000..fa5adcf5a --- /dev/null +++ b/pdr_dash_plots/view_elements.py @@ -0,0 +1,127 @@ +from dash import dcc, html +from plotly.graph_objs import Figure + +from pdr_backend.sim.sim_plotter import SimPlotter + +figure_names = [ + "pdr_profit_vs_time", + "trader_profit_vs_time", + "accuracy_vs_time", + "pdr_profit_vs_ptrue", + "trader_profit_vs_ptrue", + "aimodel_varimps", + "aimodel_response", + "f1_precision_recall_vs_time", +] + +empty_slider = dcc.Slider( + id="state_slider", + min=0, + max=0, + step=1, + disabled=True, +) + +empty_selected_vars = dcc.Checklist([], [], id="selected_vars") + +non_final_state_div = html.Div( + [empty_slider, empty_selected_vars], + style={"display": "none"}, +) + + +empty_graphs_template = html.Div( + [dcc.Graph(figure=Figure(), id=key) for key in figure_names] + + [empty_slider, empty_selected_vars], + style={"display": "none"}, +) + + +def get_waiting_template(err): + return html.Div( + [html.H2(f"Error/waiting: {err}", id="sim_state_text")] + + [empty_graphs_template], + id="live-graphs", + ) + + +def get_header_elements(run_id, st, ts): + return [ + html.H2(f"Simulation ID: {run_id}", id="sim_state_text"), + html.H2( + f"Iter #{st.iter_number} ({ts})" if ts != "final" else "Final sim state", + id="sim_current_ts", + # stops refreshing if final state was reached. Do not remove this class! + className="finalState" if ts == "final" else "runningState", + ), + ] + + +def arrange_figures(figures): + return [ + html.Div( + [ + dcc.Graph( + figure=figures["pdr_profit_vs_time"], style={"width": "100%"} + ), + dcc.Graph( + figure=figures["trader_profit_vs_time"], style={"width": "100%"} + ), + ], + style={"display": "flex", "justifyContent": "space-between"}, + ), + html.Div( + [ + dcc.Graph(figure=figures["accuracy_vs_time"]), + ] + ), + html.Div( + [ + dcc.Graph( + figure=figures["pdr_profit_vs_ptrue"], style={"width": "100%"} + ), + dcc.Graph( + figure=figures["trader_profit_vs_ptrue"], style={"width": "100%"} + ), + ], + style={"display": "flex", "justifyContent": "space-between"}, + ), + html.Div( + [ + dcc.Graph( + figure=figures["aimodel_varimps"], + id="aimodel_varimps", + style={"width": "100%"}, + ), + dcc.Graph(figure=figures["aimodel_response"], style={"width": "100%"}), + ], + style={"display": "flex", "justifyContent": "space-between"}, + ), + html.Div( + [ + dcc.Graph(figure=figures["f1_precision_recall_vs_time"]), + ] + ), + ] + + +def snapshot_slider(run_id, set_ts, slider_value): + snapshots = SimPlotter.available_snapshots(run_id)[:-1] + marks = {i: f"{s.replace('_', '')[:-4]}" for i, s in enumerate(snapshots)} + marks[len(snapshots)] = "final" + + return dcc.Slider( + id="state_slider", + marks=marks, + value=len(snapshots) if not set_ts else slider_value, + step=1, + ) + + +def selected_var_checklist(state_options, selected_vars_old): + return dcc.Checklist( + options=[{"label": var, "value": var} for var in state_options], + value=selected_vars_old, + id="selected_vars", + style={"display": "none"}, + ) diff --git a/setup.py b/setup.py index 1a7b90b9c..718577079 100644 --- a/setup.py +++ b/setup.py @@ -12,6 +12,7 @@ "bumpversion", "ccxt==4.3.11", "coverage", + "dash==2.16.1", "enforce_typing", "eth-account==0.11.0", "eth-keys==0.5.1", @@ -42,7 +43,6 @@ "types-requests==2.31.0.20240406", "web3==6.17.2", "sapphire.py==0.2.2", - "streamlit==1.33.0", "typeguard==4.2.1", "ocean-contracts==2.0.4", # install this last ] diff --git a/sim_plots b/sim_plots new file mode 100755 index 000000000..94f5c0eb7 --- /dev/null +++ b/sim_plots @@ -0,0 +1,63 @@ +#!/usr/bin/env python +import argparse + +from dash import Dash, dcc, html + +from pdr_dash_plots.callbacks import get_callbacks +from pdr_dash_plots.view_elements import empty_graphs_template +from pdr_dash_plots.util import get_all_run_names + +app = Dash(__name__) +app.config["suppress_callback_exceptions"] = True +app.layout = html.Div( + html.Div( + [ + html.Div(empty_graphs_template, id="live-graphs"), + dcc.Interval( + id="interval-component", + interval=3 * 1000, # in milliseconds + n_intervals=0, + disabled=False, + ), + ] + ) +) + +get_callbacks(app) + + +def validate_run_id(run_id): + if run_id not in get_all_run_names(): + raise ValueError(f"Invalid run_id: {run_id}") + + return run_id + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog='sim plots', + description='A script to visualize simulation data generated by pdr sim.', + epilog='powered by Dash, Plotly, and Flask.' + ) + + parser.add_argument( + '--run_id', + help=( + 'The run_id of the simulation to visualize. ' + 'If not provided, the latest run_id will be used.' + ), + type=validate_run_id + ) + + parser.add_argument( + '--port', + nargs='?', + help='The port to run the server on. Default is 8050.', + type=int, + default=8050 + ) + + args = parser.parse_args() + + app.run_id = args.run_id + app.run(debug=True, port=args.port) diff --git a/sim_plots.py b/sim_plots.py deleted file mode 100755 index 441296278..000000000 --- a/sim_plots.py +++ /dev/null @@ -1,87 +0,0 @@ -import os -import sys -import time -from pathlib import Path - -import streamlit - -from pdr_backend.aimodel import aimodel_plotter -from pdr_backend.sim.sim_plotter import SimPlotter - -streamlit.set_page_config(layout="wide") - -title = streamlit.empty() -subtitle = streamlit.empty() -inputs = streamlit.empty() -c1, c2, c3 = streamlit.columns((1, 1, 2)) -c4, c5 = streamlit.columns((1, 1)) -c6, c7 = streamlit.columns((1, 1)) -c8, _ = streamlit.columns((1, 1)) - -canvas = { - "pdr_profit_vs_time": c1.empty(), - "trader_profit_vs_time": c2.empty(), - "accuracy_vs_time": c3.empty(), - "pdr_profit_vs_ptrue": c4.empty(), - "trader_profit_vs_ptrue": c5.empty(), - "aimodel_varimps": c6.empty(), - "aimodel_response": c7.empty(), - "f1_precision_recall_vs_time": c8.empty(), -} - -last_ts = None -sim_plotter = SimPlotter() - - -def get_latest_state_id(): - path = sorted(Path("sim_state").iterdir(), key=os.path.getmtime)[-1] - return str(path).replace("sim_state/", "") - - -state_id = sys.argv[1] if len(sys.argv) > 1 else get_latest_state_id() -subtitle.markdown(f"Simulation ID: {state_id}") - - -def load_canvas_on_state(ts): - titletext = f"Iter #{st.iter_number} ({ts})" if ts != "final" else "Final sim state" - title.title(titletext) - - for key in canvas: - if not key.startswith("aimodel"): - fig = getattr(sim_plotter, f"plot_{key}")() - else: - func_name = getattr(aimodel_plotter, f"plot_{key}") - fig = func_name(sim_plotter.aimodel_plotdata) - - canvas[key].plotly_chart(fig, use_container_width=True, theme="streamlit") - - -while True: - try: - sim_plotter.load_state(state_id) - break - except Exception as e: - time.sleep(3) - title.title(f"Waiting for sim state... {e}") - continue - -while True: - try: - st, new_ts = sim_plotter.load_state(state_id) - except EOFError: - time.sleep(1) - continue - - if new_ts == last_ts: - time.sleep(1) - continue - - load_canvas_on_state(new_ts) - last_ts = new_ts - - if last_ts == "final": - snapshots = SimPlotter.available_snapshots() - timestamp = inputs.select_slider("Go to snapshot", snapshots, value="final") - st, new_ts = sim_plotter.load_state(state_id, timestamp) - load_canvas_on_state(timestamp) - break diff --git a/streamlit_entrypoint.py b/streamlit_entrypoint.py deleted file mode 100755 index af274c96f..000000000 --- a/streamlit_entrypoint.py +++ /dev/null @@ -1,6 +0,0 @@ -#!/usr/bin/env python - -from pdr_backend.cli import cli_module - -if __name__ == "__main__": - cli_module._do_main()