Fix #772 - TableRegistry Implementation #781

Merged: 5 commits, Mar 12, 2024
1 change: 0 additions & 1 deletion pdr_backend/accuracy/test/test_app.py
@@ -25,7 +25,6 @@

@enforce_types
def test_calculate_prediction_result():

# Test the calculate_prediction_result function with expected inputs
result = calculate_prediction_result(150.0, 200.0)
assert result
1 change: 0 additions & 1 deletion pdr_backend/aimodel/aimodel.py
@@ -4,7 +4,6 @@

@enforce_types
class Aimodel:

def __init__(self, skm, scaler, imps: np.ndarray):
self._skm = skm # sklearn model
self._scaler = scaler # for scaling X-inputs
40 changes: 14 additions & 26 deletions pdr_backend/lake/etl.py
@@ -1,14 +1,13 @@
from typing import Dict
import time

from pdr_backend.ppss.ppss import PPSS
from pdr_backend.lake.gql_data_factory import GQLDataFactory
from pdr_backend.lake.table import Table
from pdr_backend.lake.table_bronze_pdr_predictions import (
bronze_pdr_predictions_table_name,
bronze_pdr_predictions_schema,
get_bronze_pdr_predictions_data_with_SQL,
)
from pdr_backend.lake.table_registry import TableRegistry


class ETL:
@@ -25,9 +24,16 @@ class ETL:

def __init__(self, ppss: PPSS, gql_data_factory: GQLDataFactory):
self.ppss = ppss

self.gql_data_factory = gql_data_factory
self.tables: Dict[str, Table] = {}

TableRegistry().register_table(
bronze_pdr_predictions_table_name,
(
bronze_pdr_predictions_table_name,
bronze_pdr_predictions_schema,
self.ppss,
),
)

def do_etl(self):
"""
@@ -48,18 +54,6 @@ def do_etl(self):
except Exception as e:
print(f"Error when executing ETL: {e}")

def do_sync_step(self):
"""
@description
Call data factory to fetch data and update lake
The sync will try 3 times to fetch from data_factory, and update the local gql_dfs
"""
gql_tables = self.gql_data_factory.get_gql_tables()

# rather than override the whole dict, we update the dict
for key in gql_tables:
self.tables[key] = gql_tables[key]

def do_bronze_step(self):
"""
@description
@@ -83,14 +77,8 @@ def update_bronze_pdr_predictions(self):
@description
Update bronze_pdr_predictions table
"""
if bronze_pdr_predictions_table_name not in self.tables:
# Load existing bronze tables
table = Table(
bronze_pdr_predictions_table_name,
bronze_pdr_predictions_schema,
self.ppss,
)
self.tables[bronze_pdr_predictions_table_name] = table

print("update_bronze_pdr_predictions - Update bronze_pdr_predictions table.")
data = get_bronze_pdr_predictions_data_with_SQL(self.ppss)
table.append_to_storage(data)
TableRegistry().get_table(bronze_pdr_predictions_table_name).append_to_storage(
data
)
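Taken together, the etl.py changes drop ETL's private `self.tables` dict: the bronze table is registered once in `__init__` and resolved by name wherever it is needed. A minimal sketch of the pattern, using only names from the diff above; `refresh_bronze_predictions` is a hypothetical helper and `ppss` is assumed to be already built:

```python
from pdr_backend.ppss.ppss import PPSS
from pdr_backend.lake.table_registry import TableRegistry
from pdr_backend.lake.table_bronze_pdr_predictions import (
    bronze_pdr_predictions_table_name,
    bronze_pdr_predictions_schema,
    get_bronze_pdr_predictions_data_with_SQL,
)


def refresh_bronze_predictions(ppss: PPSS):
    # register once; the args tuple mirrors Table(name, schema, ppss)
    TableRegistry().register_table(
        bronze_pdr_predictions_table_name,
        (bronze_pdr_predictions_table_name, bronze_pdr_predictions_schema, ppss),
    )

    # every TableRegistry() call returns the same singleton, so any other
    # module can resolve the same table by name
    data = get_bronze_pdr_predictions_data_with_SQL(ppss)
    table = TableRegistry().get_table(bronze_pdr_predictions_table_name)
    table.append_to_storage(data)
```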
52 changes: 32 additions & 20 deletions pdr_backend/lake/gql_data_factory.py
@@ -33,6 +33,7 @@
from pdr_backend.util.time_types import UnixTimeMs
from pdr_backend.lake.plutil import _object_list_to_df
from pdr_backend.lake.table_pdr_predictions import _transform_timestamp_to_ms
from pdr_backend.lake.table_registry import TableRegistry

logger = logging.getLogger("gql_data_factory")

@@ -63,31 +64,40 @@ def __init__(self, ppss: PPSS):

# configure all tables that will be recorded onto lake
self.record_config = {
"tables": {
"pdr_predictions": Table(
predictions_table_name,
predictions_schema,
ppss,
),
"pdr_subscriptions": Table(
subscriptions_table_name,
subscriptions_schema,
ppss,
),
"pdr_truevals": Table(truevals_table_name, truevals_schema, ppss),
"pdr_payouts": Table(payouts_table_name, payouts_schema, ppss),
},
"fetch_functions": {
"pdr_predictions": fetch_filtered_predictions,
"pdr_subscriptions": fetch_filtered_subscriptions,
"pdr_truevals": fetch_truevals,
"pdr_payouts": fetch_payouts,
predictions_table_name: fetch_filtered_predictions,
subscriptions_table_name: fetch_filtered_subscriptions,
truevals_table_name: fetch_truevals,
payouts_table_name: fetch_payouts,
},
"config": {
"contract_list": contract_list,
},
"gql_tables": [
predictions_table_name,
subscriptions_table_name,
truevals_table_name,
payouts_table_name,
],
}

TableRegistry().register_tables(
{
predictions_table_name: (
predictions_table_name,
predictions_schema,
self.ppss,
),
subscriptions_table_name: (
subscriptions_table_name,
subscriptions_schema,
self.ppss,
),
truevals_table_name: (truevals_table_name, truevals_schema, self.ppss),
payouts_table_name: (payouts_table_name, payouts_schema, self.ppss),
}
)

@enforce_types
def get_gql_tables(self) -> Dict[str, Table]:
"""
@@ -110,7 +120,7 @@ def get_gql_tables(self) -> Dict[str, Table]:
self._update()
logger.info("Get historical data across many subgraphs. Done.")

return self.record_config["tables"]
return TableRegistry().get_tables(self.record_config["gql_tables"])

@enforce_types
def _do_subgraph_fetch(
@@ -225,7 +235,9 @@ def _update(self):
fin_ut -- a timestamp, in ms, in UTC
"""

for _, table in self.record_config["tables"].items():
for _, table in (
TableRegistry().get_tables(self.record_config["gql_tables"]).items()
):
st_ut = self._calc_start_ut(table)
fin_ut = self.ppss.lake_ss.fin_timestamp
print(f" Aim to fetch data from start time: {st_ut.pretty_timestr()}")
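After this change, `record_config` no longer owns `Table` objects: it keeps fetch functions keyed by table name plus a `gql_tables` name list, and the tables themselves live in the registry. A hedged sketch of how the two halves pair up in a fetch loop; `iter_gql_tables` is illustrative only, not a function in this PR:

```python
from typing import Dict

from pdr_backend.lake.table_registry import TableRegistry


def iter_gql_tables(record_config: Dict) -> None:
    # resolve the registered Table objects, preserving "gql_tables" order
    tables = TableRegistry().get_tables(record_config["gql_tables"])
    for name, table in tables.items():
        fetch_fn = record_config["fetch_functions"][name]
        # fetch_fn pulls fresh rows from the subgraph; table persists them
        print(f"{name}: fetch via {fetch_fn.__name__}, store via {table}")
```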
9 changes: 3 additions & 6 deletions pdr_backend/lake/table_bronze_pdr_predictions.py
@@ -1,9 +1,8 @@
import polars as pl
from polars import Boolean, Float64, Int64, Utf8

from pdr_backend.lake.table import Table
from pdr_backend.ppss.ppss import PPSS

from pdr_backend.lake.persistent_data_store import PersistentDataStore

bronze_pdr_predictions_table_name = "bronze_pdr_predictions"

@@ -32,11 +31,9 @@ def get_bronze_pdr_predictions_data_with_SQL(ppss: PPSS) -> pl.DataFrame:
Get the bronze pdr predictions data
"""
# get the table
table = Table(
bronze_pdr_predictions_table_name, bronze_pdr_predictions_schema, ppss
)
PDS = PersistentDataStore(ppss.lake_ss.parquet_dir)

return table.PDS.query_data(
return PDS.query_data(
f"""
SELECT
pdr_predictions.ID as ID,
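`get_bronze_pdr_predictions_data_with_SQL` now opens the store directly instead of building a throwaway `Table` just to reach its `PDS` attribute. A minimal sketch of that direct-query pattern, assuming `query_data` returns a polars DataFrame as this PR's test assertions imply; `count_predictions` is a hypothetical helper:

```python
import polars as pl

from pdr_backend.ppss.ppss import PPSS
from pdr_backend.lake.persistent_data_store import PersistentDataStore


def count_predictions(ppss: PPSS) -> int:
    # the store is rooted at the lake's parquet directory
    pds = PersistentDataStore(ppss.lake_ss.parquet_dir)
    df: pl.DataFrame = pds.query_data("SELECT * FROM pdr_predictions")
    return len(df)
```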
55 changes: 55 additions & 0 deletions pdr_backend/lake/table_registry.py
@@ -0,0 +1,55 @@
from typing import Optional, Dict, Tuple, List
from polars.type_aliases import SchemaDict
from enforce_typing import enforce_types

from pdr_backend.ppss.ppss import PPSS
from pdr_backend.lake.table import Table


class TableRegistry:
_instance: Optional["TableRegistry"] = None
_tables: Dict[str, Table] = {}

def __new__(cls):
if cls._instance is None:
cls._instance = super(TableRegistry, cls).__new__(cls)
return cls._instance

@enforce_types
def register_table(self, table_name: str, table_args: Tuple[str, SchemaDict, PPSS]):
if table_name in self._tables:
pass
self._tables[table_name] = Table(*table_args)
return self._tables[table_name]

@enforce_types
def register_tables(self, tables: Dict[str, Tuple[str, SchemaDict, PPSS]]):
for table_name, table_args in tables.items():
self.register_table(table_name, table_args)

@enforce_types
def get_table(self, table_name: str):
return self._tables.get(table_name)

@enforce_types
def get_tables(self, table_names: Optional[List] = None):
if table_names is None:
table_names = list(self._tables.keys())

target_tables = [self._tables.get(table_name) for table_name in table_names]

# do it this way to avoid returning None
target_tables = [table for table in target_tables if table is not None]

# zip into a dictionary to preserve the requested order
return dict(zip(table_names, target_tables))

@enforce_types
def unregister_table(self, table_name):
self._tables.pop(table_name, None)
return True

@enforce_types
def clear_tables(self):
self._tables = {}
return True
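Since `__new__` returns the one cached instance, every `TableRegistry()` call shares the same `_tables` dict. A short usage sketch of the class above; `ppss` is assumed to be an already-configured PPSS instance:

```python
from pdr_backend.lake.table_registry import TableRegistry
from pdr_backend.lake.table_pdr_truevals import truevals_schema, truevals_table_name

r1 = TableRegistry()
r2 = TableRegistry()
assert r1 is r2  # __new__ hands back the one cached instance

# register_table(name, args) builds Table(*args) and stores it under name
r1.register_table(truevals_table_name, (truevals_table_name, truevals_schema, ppss))

# lookups work from any handle; unknown names come back as None via dict.get
assert r2.get_table(truevals_table_name) is r1.get_table(truevals_table_name)
assert r2.get_table("no_such_table") is None

# get_tables() with no argument returns every registered table, name -> Table
assert list(r2.get_tables().keys()) == [truevals_table_name]

r2.clear_tables()  # wipes the shared dict; handy between tests
```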
6 changes: 6 additions & 0 deletions pdr_backend/lake/test/resources.py
@@ -14,6 +14,7 @@
from pdr_backend.ppss.lake_ss import LakeSS
from pdr_backend.ppss.ppss import mock_ppss
from pdr_backend.ppss.web3_pp import mock_web3_pp
from pdr_backend.lake.table_registry import TableRegistry


@enforce_types
@@ -33,6 +34,11 @@ def _lake_ss_1feed(tmpdir, feed, st_timestr=None, fin_timestr=None):
return ss, ohlcv_data_factory


@enforce_types
def _clean_up_table_registry():
TableRegistry()._tables = {}


@enforce_types
def _gql_data_factory(tmpdir, feed, st_timestr=None, fin_timestr=None):
network = "sapphire-mainnet"
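Because `TableRegistry` is process-wide singleton state, tables registered by one test would otherwise leak into the next; `_clean_up_table_registry` resets that shared dict. A sketch of how a test module might apply it, e.g. as a pytest fixture (the fixture itself is hypothetical, not part of this PR):

```python
import pytest

from pdr_backend.lake.table_registry import TableRegistry
from pdr_backend.lake.test.resources import _clean_up_table_registry


@pytest.fixture(autouse=True)
def fresh_table_registry():
    # wipe singleton state around each test so registrations don't leak
    _clean_up_table_registry()
    yield
    _clean_up_table_registry()


def test_registry_starts_empty():
    assert len(TableRegistry().get_tables()) == 0
```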
15 changes: 7 additions & 8 deletions pdr_backend/lake/test/test_etl.py
@@ -13,6 +13,8 @@
from pdr_backend.lake.table_pdr_truevals import truevals_schema, truevals_table_name
from pdr_backend.lake.table_pdr_payouts import payouts_schema, payouts_table_name
from pdr_backend.lake.test.conftest import _clean_up_persistent_data_store
from pdr_backend.lake.table_registry import TableRegistry
from pdr_backend.lake.test.resources import _clean_up_table_registry

# ETL code-coverage
# Step 1. ETL -> do_sync_step()
@@ -40,6 +42,7 @@ def test_setup_etl(
tmpdir,
):
_clean_up_persistent_data_store(tmpdir)
_clean_up_table_registry()

# setup test start-end date
st_timestr = "2023-11-02_0:00"
@@ -84,15 +87,10 @@ def test_setup_etl(

assert etl is not None
assert etl.gql_data_factory == gql_data_factory
assert len(etl.tables) == 0

# Work 2: Complete ETL sync step - Assert 3 gql_dfs
etl.do_sync_step()

pds_instance = _get_test_PDS(tmpdir)

# Assert original gql has 6 predictions, but we only got 5 due to date
assert len(etl.tables) == 3
pdr_predictions_df = pds_instance.query_data("SELECT * FROM pdr_predictions")
assert len(pdr_predictions_df) == 5
assert len(_gql_datafactory_etl_predictions_df) == 6
@@ -110,6 +108,7 @@ def test_setup_etl(
assert len(pdr_payouts_df) == 4
assert len(pdr_predictions_df) == 5
assert len(pdr_truevals_df) == 5
assert len(TableRegistry().get_tables()) == 5


@enforce_types
@@ -123,6 +122,8 @@ def test_etl_do_bronze_step(
tmpdir,
):
_clean_up_persistent_data_store(tmpdir)
_clean_up_table_registry()

# please note date, including Nov 1st
st_timestr = "2023-11-01_0:00"
fin_timestr = "2023-11-07_0:00"
@@ -159,16 +160,14 @@ def test_etl_do_bronze_step(
# Work 1: Initialize ETL
etl = ETL(ppss, gql_data_factory)

# Work 2: Do sync
etl.do_sync_step()

pds_instance = _get_test_PDS(tmpdir)
pdr_predictions_records = pds_instance.query_data(
f"""
SELECT * FROM {predictions_table_name}
"""
)
assert len(pdr_predictions_records) == 6

# Work 3: Do bronze
etl.do_bronze_step()
