Skip to content

Commit

Permalink
Fix #750: [Sim] Refactor plot code in sim_engine (PR #755)
Browse files Browse the repository at this point in the history
* split sim_engine.py into 3 modules: sim_engine.py, sim_plotter.py, sim_state.py. 
* each plot is modularized into its own method
  • Loading branch information
trentmc authored Mar 6, 2024
1 parent 9284b60 commit 45bdf2e
Show file tree
Hide file tree
Showing 10 changed files with 288 additions and 204 deletions.
File renamed without changes.
2 changes: 1 addition & 1 deletion pdr_backend/aimodel/test/test_aimodel_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from enforce_typing import enforce_types
from pytest import approx

from pdr_backend.aimodel.plot_model import plot_model
from pdr_backend.aimodel.model_plotter import plot_model
from pdr_backend.aimodel.aimodel_data_factory import AimodelDataFactory
from pdr_backend.aimodel.aimodel_factory import AimodelFactory
from pdr_backend.ppss.aimodel_ss import AimodelSS, aimodel_ss_test_dict
Expand Down
File renamed without changes.
4 changes: 2 additions & 2 deletions pdr_backend/cli/test/test_cli_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def test_do_sim(monkeypatch):
mock_f = Mock()
monkeypatch.setattr(f"{_CLI_PATH}.SimEngine.run", mock_f)

with patch("pdr_backend.sim.sim_engine.plt.show"):
with patch("pdr_backend.sim.sim_plotter.plt.show"):
do_sim(MockArgParser_PPSS_NETWORK().parse_args())

mock_f.assert_called()
Expand All @@ -385,7 +385,7 @@ def test_do_main(monkeypatch, capfd):
mock_f = Mock()
monkeypatch.setattr(f"{_CLI_PATH}.SimEngine.run", mock_f)

with patch("pdr_backend.sim.sim_engine.plt.show"):
with patch("pdr_backend.sim.sim_plotter.plt.show"):
with patch("sys.argv", ["pdr", "sim", "ppss.yaml"]):
_do_main()

Expand Down
206 changes: 6 additions & 200 deletions pdr_backend/sim/sim_engine.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,22 @@
import copy
import logging
import os
from typing import Dict, List

from enforce_typing import enforce_types
from matplotlib import gridspec
import matplotlib.pyplot as plt
import numpy as np
from numpy.random import random
import polars as pl
from statsmodels.stats.proportion import proportion_confint

from pdr_backend.aimodel.aimodel_data_factory import AimodelDataFactory
from pdr_backend.aimodel.aimodel_factory import AimodelFactory
from pdr_backend.aimodel.plot_model import plot_model
from pdr_backend.lake.ohlcv_data_factory import OhlcvDataFactory
from pdr_backend.ppss.ppss import PPSS
from pdr_backend.util.currency_types import Eth
from pdr_backend.sim.sim_state import SimState
from pdr_backend.sim.sim_plotter import SimPlotter
from pdr_backend.util.mathutil import classif_acc
from pdr_backend.util.time_types import UnixTimeMs

# Module-wide logger, registered under the name "sim_engine".
logger = logging.getLogger("sim_engine")
# Value passed as `fontsize=` to the axis labels and titles drawn by the
# plotting code in this file.
FONTSIZE = 9


# pylint: disable=too-many-instance-attributes
class SimEngineState:
    """Mutable bookkeeping for one simulation run.

    Holds the current token holdings (as plain floats) plus the per-iteration
    result lists that are appended to while the simulation loops.
    """

    def __init__(self, init_holdings: Dict[str, Eth]):
        """Convert the initial holdings to floats and reset the result lists.

        Args:
          init_holdings: token name -> amount; each amount's `amt_eth`
            attribute is read and converted with float().
        """
        holdings: Dict[str, float] = {}
        for tok, amt in init_holdings.items():
            holdings[tok] = float(amt.amt_eth)
        self.holdings = holdings

        self.init_loop_attributes()

    def init_loop_attributes(self):
        """(Re)start every per-iteration result list as a fresh empty list."""
        # classifier results
        self.accs_train: List[float] = []
        self.ybools_test: List[float] = []
        self.ybools_testhat: List[float] = []
        self.probs_up: List[float] = []
        self.corrects: List[bool] = []

        # profits
        self.trader_profits_USD: List[float] = []
        self.pdr_profits_OCEAN: List[float] = []


# pylint: disable=too-many-instance-attributes
Expand All @@ -57,15 +34,15 @@ def __init__(self, ppss: PPSS):

self.ppss = ppss

self.st = SimEngineState(
self.st = SimState(
copy.copy(self.ppss.trader_ss.init_holdings),
)

self.plot_state = None
self.sim_plotter = None
if self.ppss.sim_ss.do_plot:
n = self.ppss.predictoor_ss.aimodel_ss.n # num input vars
include_contour = n == 2
self.plot_state = PlotState(include_contour)
self.sim_plotter = SimPlotter(self.ppss, self.st, include_contour)

self.logfile = ""

Expand Down Expand Up @@ -244,9 +221,7 @@ def run_one_iter(self, test_i: int, mergedohlcv_df: pl.DataFrame):

# plot
if self.do_plot(test_i, self.ppss.sim_ss.test_n):
self.plot_state.make_plot( # type: ignore[union-attr]
self.st,
self.ppss,
self.sim_plotter.make_plot( # type: ignore[union-attr]
model,
X_train,
ybool_train,
Expand Down Expand Up @@ -340,172 +315,3 @@ def do_plot(self, i: int, N: int):
return False

return True


@enforce_types
class PlotState:
    """Owns the matplotlib figure for the sim plots and redraws it.

    Subplot layout (2-row grid):
      - ax00 (row 0, col 0): cumulative predictoor profit vs time
      - ax01 (row 0, cols 1-2): % correct vs time, with confidence band
      - ax10 (row 1, col 0): cumulative trader profit vs time
      - ax11 (row 1, col 1): scatter of predictoor profit vs prob(up)
      - ax12 (row 1, col 2): scatter of trader profit vs prob(up)
      - ax03 (both rows, col 3): model contour; only if include_contour

    Plotting is incremental: each make_plot() call draws only the points
    that were added since the previous call.
    """

    def __init__(self, include_contour: bool):
        """Create the figure and axes, then show the window interactively.

        Args:
          include_contour: if True, add a 4th grid column holding the
            model contour plot (self.ax03).
        """
        self.include_contour = include_contour

        fig = plt.figure()
        self.fig = fig

        # the contour plot, when present, gets its own wide 4th column
        if include_contour:
            gs = gridspec.GridSpec(2, 4, width_ratios=[5, 1, 1, 5])
        else:
            gs = gridspec.GridSpec(2, 3, width_ratios=[5, 1, 1])

        self.ax00 = fig.add_subplot(gs[0, 0])
        self.ax01 = fig.add_subplot(gs[0, 1:3])
        self.ax10 = fig.add_subplot(gs[1, 0])
        self.ax11 = fig.add_subplot(gs[1, 1])
        self.ax12 = fig.add_subplot(gs[1, 2])
        if include_contour:
            self.ax03 = fig.add_subplot(gs[:, 3])

        # x values, and "% correct" estimate / lower / upper bounds,
        # for every point plotted so far
        self.x: List[float] = []
        self.y01_est: List[float] = []
        self.y01_l: List[float] = []
        self.y01_u: List[float] = []

        # one-time axis decoration (labels, margins) is keyed off this flag
        self.plotted_before: bool = False
        plt.ion()
        plt.show()

    def make_plot(self, st, ppss, model, X_train, ybool_train, colnames):
        """Draw the newest simulation results onto every subplot.

        Args:
          st: sim state; reads pdr_profits_OCEAN, trader_profits_USD,
            corrects and probs_up.
          ppss: settings; reads predictoor_ss.stake_amount.amt_eth.
          model, X_train, ybool_train: forwarded to plot_model() for the
            contour plot (only used if include_contour).
          colnames: input column names; each is passed through
            _shift_one_earlier() to label the contour plot.

        Returns without drawing if there are no results yet.
        """
        stake_amt = ppss.predictoor_ss.stake_amount.amt_eth

        N = len(st.pdr_profits_OCEAN)
        if N == 0:
            # nothing to plot; indexing and min()/max() below need >=1 point
            return
        N_done = len(self.x)  # what # points have been plotted previously

        # set x
        self.x = list(range(0, N))
        next_x = _slice(self.x, N_done, N)
        next_hx = [next_x[0], next_x[-1]]  # horizontal x

        self._plot_pdr_profit_vs_time(st, next_x, next_hx, N_done, N)
        self._plot_pct_correct_vs_time(st, next_x, next_hx, N_done, N)
        if self.include_contour:
            self._plot_model_contour(model, X_train, ybool_train, colnames)
        self._plot_trader_profit_vs_time(st, next_x, next_hx, N_done, N)

        # row 1, col 1: 1d scatter of predictoor profits
        mnp, mxp = -stake_amt, +stake_amt
        self._plot_profit_scatter(
            self.ax11, "pdr", "OCEAN", mnp, mxp, st, st.pdr_profits_OCEAN, N_done, N
        )

        # row 1, col 2: 1d scatter of trader profits
        mnp, mxp = min(st.trader_profits_USD), max(st.trader_profits_USD)
        self._plot_profit_scatter(
            self.ax12, "trader", "USD", mnp, mxp, st, st.trader_profits_USD, N_done, N
        )

        self._finalize_figure()
        self.plotted_before = True

    def _plot_pdr_profit_vs_time(self, st, next_x, next_hx, N_done, N):
        """Row 0, col 0: cumulative predictoor profit vs time."""
        ax = self.ax00
        y00 = list(np.cumsum(st.pdr_profits_OCEAN))
        next_y00 = _slice(y00, N_done, N)
        ax.plot(next_x, next_y00, c="g")
        ax.plot(next_hx, [0, 0], c="0.2", ls="--", lw=1)  # zero-profit line
        s = f"Predictoor profit vs time. Current:{y00[-1]:.2f} OCEAN"
        _set_title(ax, s)
        if not self.plotted_before:
            ax.set_ylabel("predictoor profit (OCEAN)", fontsize=FONTSIZE)
            ax.set_xlabel("time", fontsize=FONTSIZE)
            _ylabel_on_right(ax)
            ax.margins(0.005, 0.05)

    def _plot_pct_correct_vs_time(self, st, next_x, next_hx, N_done, N):
        """Row 0, cols 1-2: running % correct with its confidence band."""
        ax = self.ax01

        # Keep a running count rather than re-summing the whole prefix for
        # each new point.
        # NOTE(review): assumes st.corrects has at least N entries, i.e. one
        # per entry of st.pdr_profits_OCEAN -- confirm against the caller.
        n_correct = sum(st.corrects[:N_done])
        for i in range(N_done, N):
            n_correct += st.corrects[i]
            n_trials = i + 1
            lower, upper = proportion_confint(count=n_correct, nobs=n_trials)
            self.y01_est.append(n_correct / n_trials * 100)
            self.y01_l.append(lower * 100)
            self.y01_u.append(upper * 100)
        next_y01_est = _slice(self.y01_est, N_done, N)
        next_y01_l = _slice(self.y01_l, N_done, N)
        next_y01_u = _slice(self.y01_u, N_done, N)

        ax.plot(next_x, next_y01_est, "green")
        ax.fill_between(next_x, next_y01_l, next_y01_u, color="0.9")
        ax.plot(next_hx, [50, 50], c="0.2", ls="--", lw=1)  # 50% line
        ax.set_ylim(bottom=40, top=60)
        now_s = f"{self.y01_est[-1]:.2f}% "
        now_s += f"[{self.y01_l[-1]:.2f}%, {self.y01_u[-1]:.2f}%]"
        _set_title(ax, f"% correct vs time. Current: {now_s}")
        if not self.plotted_before:
            ax.set_xlabel("time", fontsize=FONTSIZE)
            ax.set_ylabel("% correct", fontsize=FONTSIZE)
            _ylabel_on_right(ax)
            ax.margins(0.01, 0.01)

    def _plot_model_contour(self, model, X_train, ybool_train, colnames):
        """Col 3, both rows: model contour, drawn by plot_model()."""
        ax = self.ax03
        labels = tuple(_shift_one_earlier(colname) for colname in colnames)
        plot_model(model, X_train, ybool_train, labels, (self.fig, ax))
        if not self.plotted_before:
            ax.margins(0.01, 0.01)

    def _plot_trader_profit_vs_time(self, st, next_x, next_hx, N_done, N):
        """Row 1, col 0: cumulative trader profit vs time."""
        ax = self.ax10
        y10 = list(np.cumsum(st.trader_profits_USD))
        next_y10 = _slice(y10, N_done, N)
        ax.plot(next_x, next_y10, c="b")
        ax.plot(next_hx, [0, 0], c="0.2", ls="--", lw=1)  # zero-profit line
        _set_title(ax, f"Trader profit vs time. Current: ${y10[-1]:.2f}")
        if not self.plotted_before:
            ax.set_xlabel("time", fontsize=FONTSIZE)
            ax.set_ylabel("trader profit (USD)", fontsize=FONTSIZE)
            _ylabel_on_right(ax)
            ax.margins(0.005, 0.05)

    # pylint: disable=too-many-arguments
    def _plot_profit_scatter(
        self, ax, actor, denomin, mnp, mxp, st, st_profits, N_done, N
    ):
        """Scatter the newest profits against prob(up) on `ax`.

        `mnp` / `mxp` are the bottom / top of the vertical line drawn at
        prob(up) = 0.5. The title shows the average over all of `st_profits`.
        """
        next_probs_up = _slice(st.probs_up, N_done, N)
        next_profits = _slice(st_profits, N_done, N)
        c = (random(), random(), random())  # random RGB color
        ax.scatter(next_probs_up, next_profits, color=c, s=1)
        avg = np.average(st_profits)
        s = f"{actor} profit distr'n. avg={avg:.2f} {denomin}"
        _set_title(ax, s)
        ax.plot([0.5, 0.5], [mnp, mxp], c="0.2", ls="-", lw=1)
        if not self.plotted_before:
            ax.plot([0.0, 1.0], [0, 0], c="0.2", ls="--", lw=1)
            _set_xlabel(ax, "prob(up)")
            _set_ylabel(ax, f"{actor} profit ({denomin})")
            _ylabel_on_right(ax)
            ax.margins(0.05, 0.05)

    def _finalize_figure(self):
        """Resize and re-layout the figure, then let the GUI redraw."""
        HEIGHT = 7.5  # magic number
        WIDTH = int(HEIGHT * 3.2)  # magic number
        self.fig.set_size_inches(WIDTH, HEIGHT)
        self.fig.tight_layout(pad=0.5, h_pad=1.0, w_pad=1.0)
        plt.pause(0.001)


def _shift_one_earlier(s: str):
"""eg 'binance:BTC/USDT:close:t-3' -> 'binance:BTC/USDT:close:t-2'"""
val = int(s[-1])
return s[:-1] + str(val - 1)


def _set_xlabel(ax, s: str):
    """Set the x-axis label of `ax` to `s`, using the shared FONTSIZE."""
    text_kwargs = {"fontsize": FONTSIZE}
    ax.set_xlabel(s, **text_kwargs)


def _set_ylabel(ax, s: str):
    """Set the y-axis label of `ax` to `s`, using the shared FONTSIZE."""
    text_kwargs = {"fontsize": FONTSIZE}
    ax.set_ylabel(s, **text_kwargs)


def _set_title(ax, s: str):
    """Set the title of `ax` to `s`, in bold, using the shared FONTSIZE."""
    text_kwargs = {"fontsize": FONTSIZE, "fontweight": "bold"}
    ax.set_title(s, **text_kwargs)


def _slice(a: list, N_done: int, N: int) -> list:
return [a[i] for i in range(max(0, N_done - 1), N)]


def _ylabel_on_right(ax):
ax.yaxis.tick_right()
ax.yaxis.set_label_position("right")


def _del_lines(ax):
for l in ax.lines:
l.remove()
Loading

0 comments on commit 45bdf2e

Please sign in to comment.