Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #1263: Transform lake data into prediction signals #1297

Merged
merged 12 commits into from
Jun 26, 2024
76 changes: 70 additions & 6 deletions pdr_backend/sim/sim_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import uuid
import time
from typing import Optional
from typing import Optional, Dict

import numpy as np
import polars as pl
Expand All @@ -25,9 +25,10 @@
from pdr_backend.sim.sim_state import SimState
from pdr_backend.util.strutil import shift_one_earlier
from pdr_backend.util.time_types import UnixTimeMs
from pdr_backend.lake.duckdb_data_store import DuckDBDataStore
from pdr_backend.subgraph.subgraph_feed_contracts import query_feed_contracts
from pdr_backend.lake.etl import ETL
from pdr_backend.lake.gql_data_factory import GQLDataFactory
from pdr_backend.lake.duckdb_data_store import DuckDBDataStore

logger = logging.getLogger("sim_engine")

Expand Down Expand Up @@ -62,6 +63,8 @@ def __init__(
else:
self.multi_id = str(uuid.uuid4())

self.crt_trained_model: Optional[Aimodel] = None
self.prediction_dataset: Optional[Dict[int, float]] = None
self.model: Optional[Aimodel] = None

@property
Expand Down Expand Up @@ -96,6 +99,10 @@ def run(self):
if not chain_prediction_data:
return

self.prediction_dataset = self._get_prediction_dataset(
UnixTimeMs(self.ppss.lake_ss.st_timestamp).to_seconds(),
UnixTimeMs(self.ppss.lake_ss.fin_timestamp).to_seconds(),
)

for test_i in range(self.ppss.sim_ss.test_n):
self.run_one_iter(test_i, mergedohlcv_df)
Expand All @@ -112,6 +119,7 @@ def run_one_iter(self, test_i: int, mergedohlcv_df: pl.DataFrame):
revenue = pdr_ss.revenue.amt_eth

testshift = ppss.sim_ss.test_n - test_i - 1 # eg [99, 98, .., 2, 1, 0]

data_f = AimodelDataFactory(pdr_ss) # type: ignore[arg-type]
predict_feed = self.predict_train_feedset.predict
train_feeds = self.predict_train_feedset.train_on
Expand Down Expand Up @@ -155,8 +163,20 @@ def run_one_iter(self, test_i: int, mergedohlcv_df: pl.DataFrame):
ut = UnixTimeMs(recent_ut - testshift * timeframe.ms)

# predict price direction
prob_up: float = self.model.predict_ptrue(X_test)[0] # in [0.0, 1.0]
prob_down: float = 1.0 - prob_up
if self.ppss.sim_ss.use_own_model is not False:
prob_up: float = self.model.predict_ptrue(X_test)[0] # in [0.0, 1.0]
else:
ut_seconds = ut.to_seconds()
if (
self.prediction_dataset is not None
and ut_seconds in self.prediction_dataset
):
# check if the current slot is in the keys
prob_up = self.prediction_dataset[ut_seconds]
else:
return

prob_down: Optional[float] = 1.0 - prob_up
conf_up = (prob_up - 0.5) * 2.0 # to range [0,1]
conf_down = (prob_down - 0.5) * 2.0 # to range [0,1]
conf_threshold = self.ppss.trader_ss.sim_confidence_threshold
Expand Down Expand Up @@ -271,6 +291,51 @@ def save_state(self, i: int, N: int):
return True, False

@enforce_types
def _get_prediction_dataset(
    self, start_slot: int, end_slot: int
) -> Dict[int, Optional[float]]:
    """Build {slot -> probUp} prediction signals from the lake's payout data.

    Queries the subgraph for feed contracts owned by the configured owner,
    picks the one matching this engine's predict feed (pair + timeframe),
    then reads stake-weighted "probability up" values per slot out of the
    DuckDB lake.

    @arguments
      start_slot -- int, exclusive lower bound on slot (unix seconds)
      end_slot -- int, exclusive upper bound on slot (unix seconds)

    @return
      dict mapping slot (int, unix seconds) -> probUp (float in [0.0, 1.0]).
      Slots where no stake was placed (probUp would be NULL) are omitted,
      as is everything when no matching contract exists.
    """
    contracts = query_feed_contracts(
        self.ppss.web3_pp.subgraph_url,
        self.ppss.web3_pp.owner_addrs,
    )

    # 5m feeds run on 300s epochs; otherwise assume 1h (3600s) epochs
    seconds_per_epoch = 300 if self.predict_feed.timeframe == "5m" else 3600

    # Filter contracts down to the correct token pair and timeframe
    contracts_to_use = [
        addr
        for addr, feed in contracts.items()
        if feed.symbol == self.predict_feed.pair.pair_str
        and feed.seconds_per_epoch == seconds_per_epoch
    ]

    # No matching feed contract -> no signals available
    if not contracts_to_use:
        return {}

    # probUp = stake-weighted fraction staked "up". NULL only when nothing
    # was staked at all: roundSumStakesUp == 0 with stakes present is a
    # legitimate 0.0 (everyone staked down), not missing data.
    query = f"""
        SELECT
            slot,
            CASE
                WHEN roundSumStakes = 0.0 THEN NULL
                ELSE roundSumStakesUp / roundSumStakes
            END AS probUp
        FROM
            pdr_payouts
        WHERE
            slot > {start_slot}
            AND slot < {end_slot}
            AND ID LIKE '{contracts_to_use[0]}%'
    """

    db = DuckDBDataStore(self.ppss.lake_ss.lake_dir)
    df: pl.DataFrame = db.query_data(query)

    # Keep only slots where a probability could actually be computed
    return {
        slot: prob
        for slot, prob in zip(df["slot"], df["probUp"])
        if prob is not None
    }

def _get_past_predictions_from_chain(self, ppss: PPSS):
# calculate needed data start date
current_time_s = int(time.time())
Expand All @@ -293,8 +358,7 @@ def _get_past_predictions_from_chain(self, ppss: PPSS):

# fetch data from subgraph
gql_data_factory = GQLDataFactory(ppss)
etl = ETL(ppss, gql_data_factory)
etl.do_etl()
gql_data_factory._update()
time.sleep(3)

# check if required data exists in the data base
Expand Down
55 changes: 49 additions & 6 deletions pdr_backend/sim/test/test_sim_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from pdr_backend.sim.dash_plots.callbacks import get_callbacks
from pdr_backend.sim.dash_plots.view_elements import get_layout
from pdr_backend.sim.sim_engine import SimEngine
from pdr_backend.lake.duckdb_data_store import DuckDBDataStore
from pdr_backend.util.time_types import UnixTimeMs


@enforce_types
Expand Down Expand Up @@ -102,6 +104,47 @@ def test_sim_engine(tmpdir, check_chromedriver, dash_duo):
dash_duo.find_element(f"#{figure_name}")


@enforce_types
def test_get_prediction_dataset(tmpdir):
    """End-to-end: fetch chain predictions into the lake, then verify
    _get_prediction_dataset returns a slot-keyed dict covering the lake data."""
    yaml_path = os.path.abspath("ppss.yaml")
    ppss_dict = PPSS.constructor_dict(yaml_path)

    ppss_dict["lake_ss"]["lake_dir"] = os.path.join(tmpdir, "lake_data")
    ppss_dict["lake_ss"]["st_timestr"] = "2 hours ago"
    ppss_dict["trader_ss"]["feed.timeframe"] = "5m"
    ppss_dict["sim_ss"]["test_n"] = 20
    ppss = PPSS(d=ppss_dict, network="sapphire-mainnet")

    feedsets = ppss.predictoor_ss.predict_train_feedsets
    sim_engine = SimEngine(ppss, feedsets[0])

    # Populate the lake from the chain
    sim_engine._get_past_predictions_from_chain(ppss)

    # The duckdb file must now exist inside the lake directory
    assert os.path.exists(ppss.lake_ss.lake_dir)
    assert os.path.exists(os.path.join(ppss.lake_ss.lake_dir, "duckdb.db"))

    start_s = UnixTimeMs(ppss.lake_ss.st_timestamp).to_seconds()
    end_s = UnixTimeMs(ppss.lake_ss.fin_timestamp).to_seconds()
    prediction_dataset = sim_engine._get_prediction_dataset(start_s, end_s)

    # Pull one known slot straight from the DB and confirm it is covered
    db = DuckDBDataStore(ppss.lake_ss.lake_dir)
    test_query = f"""
        SELECT
            slot
        FROM pdr_payouts
        WHERE
            slot > {start_s}
        LIMIT 1"""
    df = db.query_data(test_query)

    assert isinstance(prediction_dataset, dict)
    assert df["slot"][0] in prediction_dataset


def test_get_past_predictions_from_chain():
s = os.path.abspath("ppss.yaml")
d = PPSS.constructor_dict(s)
Expand All @@ -124,10 +167,10 @@ def test_get_past_predictions_from_chain():
shutil.rmtree(path)

# needs to be inspected and fixed
# d["sim_ss"]["test_n"] = 20
# ppss = PPSS(d=d, network="sapphire-mainnet")
# print(ppss.lake_ss)
d["sim_ss"]["test_n"] = 20
ppss = PPSS(d=d, network="sapphire-mainnet")
print(ppss.lake_ss)

# sim_engine = SimEngine(ppss, feedsets[0])
# resp = sim_engine._get_past_predictions_from_chain(ppss)
# assert resp is True
sim_engine = SimEngine(ppss, feedsets[0])
resp = sim_engine._get_past_predictions_from_chain(ppss)
assert resp is True
Loading