Fix #1263: Transform lake data into prediction signals #1297
Changes from 5 commits
@@ -2,7 +2,7 @@
 import os
 import uuid
 import time
-from typing import Optional
+from typing import Optional, Dict
 
 import numpy as np
 import polars as pl
@@ -25,9 +25,10 @@
 from pdr_backend.sim.sim_state import SimState
 from pdr_backend.util.strutil import shift_one_earlier
 from pdr_backend.util.time_types import UnixTimeMs
-from pdr_backend.lake.duckdb_data_store import DuckDBDataStore
+from pdr_backend.subgraph.subgraph_feed_contracts import query_feed_contracts
 from pdr_backend.lake.etl import ETL
 from pdr_backend.lake.gql_data_factory import GQLDataFactory
+from pdr_backend.lake.duckdb_data_store import DuckDBDataStore
 
 logger = logging.getLogger("sim_engine")
 
@@ -62,6 +63,8 @@ def __init__(
         else:
             self.multi_id = str(uuid.uuid4())
 
         self.crt_trained_model: Optional[Aimodel] = None
+        self.prediction_dataset: Optional[Dict[int, float]] = None
+        self.model: Optional[Aimodel] = None
 
     @property
@@ -96,6 +99,10 @@ def run(self):
         if not chain_prediction_data:
             return
 
+        self.prediction_dataset = self._get_prediction_dataset(
+            UnixTimeMs(self.ppss.lake_ss.st_timestamp).to_seconds(),
+            UnixTimeMs(self.ppss.lake_ss.fin_timestamp).to_seconds(),
+        )
 
         for test_i in range(self.ppss.sim_ss.test_n):
            self.run_one_iter(test_i, mergedohlcv_df)
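An aside, not part of the diff: lake timestamps are UnixTimeMs (milliseconds), while payout slots are keyed in seconds, hence the to_seconds() calls above. Roughly:

    # Illustration only: lake timestamp (ms) -> slot bound (s).
    st_timestamp_ms = 1717200000000
    start_slot = st_timestamp_ms // 1000  # 1717200000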
@@ -112,6 +119,7 @@ def run_one_iter(self, test_i: int, mergedohlcv_df: pl.DataFrame):
         revenue = pdr_ss.revenue.amt_eth
 
         testshift = ppss.sim_ss.test_n - test_i - 1  # eg [99, 98, .., 2, 1, 0]
+
         data_f = AimodelDataFactory(pdr_ss)  # type: ignore[arg-type]
         predict_feed = self.predict_train_feedset.predict
         train_feeds = self.predict_train_feedset.train_on
@@ -155,8 +163,20 @@ def run_one_iter(self, test_i: int, mergedohlcv_df: pl.DataFrame):
         ut = UnixTimeMs(recent_ut - testshift * timeframe.ms)
 
         # predict price direction
-        prob_up: float = self.model.predict_ptrue(X_test)[0]  # in [0.0, 1.0]
-        prob_down: float = 1.0 - prob_up
+        if self.ppss.sim_ss.use_own_model is not False:
+            prob_up: float = self.model.predict_ptrue(X_test)[0]  # in [0.0, 1.0]
+        else:
+            ut_seconds = ut.to_seconds()
+            if (
+                self.prediction_dataset is not None
+                and ut_seconds in self.prediction_dataset
+            ):
+                # check if the current slot is in the keys
+                prob_up = self.prediction_dataset[ut_seconds]
+            else:
+                return
+
+        prob_down: Optional[float] = 1.0 - prob_up
         conf_up = (prob_up - 0.5) * 2.0  # to range [0,1]
         conf_down = (prob_down - 0.5) * 2.0  # to range [0,1]
         conf_threshold = self.ppss.trader_ss.sim_confidence_threshold
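An aside, not part of the diff: the confidence mapping above rescales the winning side's probability from [0.5, 1.0] onto [0, 1]; the losing side comes out negative, so it cannot clear a positive threshold. A quick worked example:

    # Worked example of the mapping (illustrative values).
    prob_up = 0.7
    prob_down = 1.0 - prob_up            # 0.3
    conf_up = (prob_up - 0.5) * 2.0      # 0.4
    conf_down = (prob_down - 0.5) * 2.0  # -0.4: no "down" signal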
@@ -271,6 +291,51 @@ def save_state(self, i: int, N: int):
         return True, False
 
+    @enforce_types
+    def _get_prediction_dataset(
+        self, start_slot: int, end_slot: int
+    ) -> Dict[int, Optional[float]]:
+        contracts = query_feed_contracts(
+            self.ppss.web3_pp.subgraph_url,
+            self.ppss.web3_pp.owner_addrs,
+        )
+
+        sPE = 300 if self.predict_feed.timeframe == "5m" else 3600
+        # Filter contracts with the correct token pair and timeframe
+        contract_to_use = [
+            addr
+            for addr, feed in contracts.items()
+            if feed.symbol
+            == f"{self.predict_feed.pair.base_str}/{self.predict_feed.pair.quote_str}"
+            and feed.seconds_per_epoch == sPE
+        ]

[Inline review comment on the feed.symbol line, truncated: "you can use …"]

+        query = f"""
+            SELECT
+                slot,
+                CASE
+                    WHEN roundSumStakes = 0.0 THEN NULL
+                    WHEN roundSumStakesUp = 0.0 THEN NULL
+                    ELSE roundSumStakesUp / roundSumStakes
+                END AS probUp
+            FROM
+                pdr_payouts
+            WHERE
+                slot > {start_slot}
+                AND slot < {end_slot}
+                AND ID LIKE '{contract_to_use[0]}%'
+        """

[Inline review comment on the "WHEN roundSumStakesUp = 0.0" line, truncated: "for this case when …"]

+        db = DuckDBDataStore(self.ppss.lake_ss.lake_dir)
+        df: pl.DataFrame = db.query_data(query)
+
+        result_dict = {}
+
+        for i in range(len(df)):
+            if df["probUp"][i] is not None:
+                result_dict[df["slot"][i]] = df["probUp"][i]
+
+        return result_dict
+
     def _get_past_predictions_from_chain(self, ppss: PPSS):
         # calculate needed data start date
         current_time_s = int(time.time())
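An aside, not part of the diff: probUp is the stake-weighted fraction of a round's stakes placed on "up", so the method returns a slot-to-probability map. A hypothetical illustration of its shape (values invented):

    # Illustrative only: shape of the dict returned by _get_prediction_dataset.
    prediction_dataset = {
        1717200000: 0.62,  # 62% of this round's stakes were on "up"
        1717200300: 0.48,
    }
    prob_up = prediction_dataset.get(1717200300)  # None-safe lookup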
@@ -293,8 +358,7 @@ def _get_past_predictions_from_chain(self, ppss: PPSS):
 
         # fetch data from subgraph
         gql_data_factory = GQLDataFactory(ppss)
-        etl = ETL(ppss, gql_data_factory)
-        etl.do_etl()
+        gql_data_factory._update()
         time.sleep(3)
 
         # check if required data exists in the data base
Review comment: could you rename this and the test to _get_predictions_signals_data? This way is more specific.
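A minimal sketch of the suggested rename (the class stub and body are placeholders; the corresponding test would be renamed the same way):

    from typing import Dict, Optional

    class SimEngine:  # stub for illustration
        def _get_predictions_signals_data(
            self, start_slot: int, end_slot: int
        ) -> Dict[int, Optional[float]]:
            ...  # body unchanged from _get_prediction_dataset in the diff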