Skip to content

Commit

Permalink
fix: mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
EvolveArt committed Jul 5, 2024
1 parent d8114b8 commit 8ea98ae
Show file tree
Hide file tree
Showing 16 changed files with 142 additions and 99 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ async def fetch(
else:
logger.debug("Skipping StarknetAMM for non supported pair: %s", pair)

return list(await asyncio.gather(*entries, return_exceptions=True))
return list(await asyncio.gather(*entries, return_exceptions=True)) # type: ignore[call-overload]

def _construct(self, pair: Pair, result: float) -> SpotEntry:
price_int = int(result * (10 ** pair.decimals()))
Expand Down
15 changes: 11 additions & 4 deletions pragma-sdk/pragma_sdk/common/fetchers/future_fetchers/binance.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async def _fetch_volume(
volume_arr.append((element["symbol"], int(element["quoteVolume"])))
return volume_arr

async def fetch_pair(
async def fetch_pair( # type: ignore[override]
self, pair: Pair, session: ClientSession
) -> List[FutureEntry] | PublisherFetchError:
filtered_data = []
Expand All @@ -59,7 +59,7 @@ async def fetch_pair(
async def fetch(
self, session: ClientSession
) -> List[Entry | PublisherFetchError | BaseException]:
entries: Sequence[FutureEntry | PublisherFetchError] = []
entries: List[Entry | PublisherFetchError | BaseException] = []
for pair in self.pairs:
future_entries = await self.fetch_pair(pair, session)
if isinstance(future_entries, list):
Expand All @@ -71,14 +71,21 @@ async def fetch(
def format_url(self, pair: Optional[Pair] = None) -> str:
return self.BASE_URL

def _retrieve_volume(self, pair: Pair, volume_arr) -> int:
def _retrieve_volume(
self, pair: Pair, volume_arr: List[Tuple[str, int]] | PublisherFetchError
) -> int:
if isinstance(volume_arr, PublisherFetchError):
return 0
for list_pair, list_vol in volume_arr:
if pair == list_pair:
return list_vol
return 0

def _construct(
self, pair: Pair, result: Any, volume_arr: List[int]
self,
pair: Pair,
result: Any,
volume_arr: List[Tuple[str, int]] | PublisherFetchError,
) -> List[FutureEntry]:
result_arr = []
decimals = pair.decimals()
Expand Down
13 changes: 7 additions & 6 deletions pragma-sdk/pragma_sdk/common/fetchers/future_fetchers/okx.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ async def fetch_expiry_timestamp(
def format_expiry_timestamp_url(self, instrument_id: str) -> str:
return f"{self.TIMESTAMP_URL}?instType=FUTURES&instId={instrument_id}"

async def fetch_pair(
async def fetch_pair( # type: ignore[override]
self, pair: Pair, session: ClientSession
) -> PublisherFetchError | List[Entry]:
url = self.format_url(pair)
future_entries = []
future_entries: List[Entry] = []
async with session.get(url) as resp:
if resp.status == 404:
return PublisherFetchError(f"No data found for {pair} from OKX")
Expand All @@ -63,15 +63,16 @@ async def fetch_pair(
expiry_timestamp = await self.fetch_expiry_timestamp(
pair, result["data"][i]["instId"], session
)
future_entries.append(
self._construct(pair, result["data"][i], expiry_timestamp)
)
if not isinstance(expiry_timestamp, PublisherFetchError):
future_entries.append(
self._construct(pair, result["data"][i], expiry_timestamp)
)
return future_entries

async def fetch(
self, session: ClientSession
) -> List[Entry | PublisherFetchError | BaseException]:
entries = []
entries: List[Entry | PublisherFetchError | BaseException] = []
for pair in self.pairs:
future_entries = await self.fetch_pair(pair, session)
if isinstance(future_entries, list):
Expand Down
6 changes: 4 additions & 2 deletions pragma-sdk/pragma_sdk/common/types/client.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
from typing import List, Optional, Any
from typing import List, Optional, Union
from abc import ABC, abstractmethod

from pragma_sdk.common.types.entry import Entry
from pragma_sdk.common.types.types import ExecutionConfig
from pragma_sdk.common.utils import add_sync_methods
from pragma_sdk.offchain.types import PublishEntriesAPIResult
from pragma_sdk.onchain.types.types import PublishEntriesOnChainResult


@add_sync_methods
class PragmaClient(ABC):
@abstractmethod
async def publish_entries(
self, entries: List[Entry], execution_config: Optional[ExecutionConfig] = None
) -> Any:
) -> Union[PublishEntriesAPIResult, PublishEntriesOnChainResult]:
"""
Publish entries to some destination.
Expand Down
64 changes: 35 additions & 29 deletions pragma-sdk/pragma_sdk/offchain/client.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import asyncio
import time
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple, Union

import aiohttp
from pragma_sdk.onchain.types.types import PublishEntriesOnChainResult
from requests import HTTPError
from starknet_py.net.models import StarknetChainId
from starknet_py.net.signer.stark_curve_signer import KeyPair, StarkCurveSigner

from pragma_sdk.common.types.entry import Entry, FutureEntry, SpotEntry
from pragma_sdk.common.types.types import AggregationMode, DataTypes
from pragma_sdk.common.types.types import AggregationMode, DataTypes, ExecutionConfig
from pragma_sdk.common.utils import add_sync_methods, get_cur_from_pair
from pragma_sdk.offchain.signer import OffchainSigner
from pragma_sdk.offchain.types import Interval
from pragma_sdk.offchain.types import Interval, PublishEntriesAPIResult

from pragma_sdk.common.types.client import PragmaClient

Expand Down Expand Up @@ -58,9 +59,9 @@ def __init__(
async def get_ohlc(
self,
pair: str,
timestamp: int = None,
interval: Interval = None,
aggregation: AggregationMode = None,
timestamp: Optional[int] = None,
interval: Optional[Interval] = None,
aggregation: Optional[AggregationMode] = None,
) -> "EntryResult":
"""
Retrieve OHLC data from the Pragma API.
Expand Down Expand Up @@ -98,9 +99,9 @@ async def get_ohlc(
async with aiohttp.ClientSession() as session:
async with session.get(
url, headers=headers, params=path_params
) as response:
status_code: int = response.status
response: Dict = await response.json()
) as response_raw:
status_code: int = response_raw.status
response: Dict = await response_raw.json()
if status_code == 200:
print(f"Success: {response}")
print("Get Ohlc successful")
Expand All @@ -111,8 +112,8 @@ async def get_ohlc(
return EntryResult(pair_id=response["pair_id"], data=response["data"])

async def publish_entries(
self, entries: List[Entry]
) -> (Optional[Dict], Optional[Dict]): # type: ignore
self, entries: List[Entry], _execution_config: Optional[ExecutionConfig] = None
) -> Union[PublishEntriesAPIResult, PublishEntriesOnChainResult]:
"""
Publishes spot and future entries to the Pragma API.
This function accepts both type of entries - but they need to be sent through
Expand All @@ -123,8 +124,8 @@ async def publish_entries(
"""
# We accept both types of entries - but they need to be sent through
# different endpoints & signed differently, so we split them here.
spot_entries: list[SpotEntry] = []
future_entries: list[FutureEntry] = []
spot_entries: List[Entry] = []
future_entries: List[Entry] = []

for entry in entries:
if isinstance(entry, SpotEntry):
Expand All @@ -140,7 +141,7 @@ async def publish_entries(
return spot_response, future_response

async def _create_entries(
self, entries: List[Entry], data_type: Optional[DataTypes] = DataTypes.SPOT
self, entries: List[Entry], data_type: DataTypes = DataTypes.SPOT
) -> Optional[Dict]:
"""
Publishes entries to the Pragma API & returns the http response.
Expand All @@ -152,7 +153,7 @@ async def _create_entries(
raise PragmaAPIError("No offchain signer set")

if len(entries) == 0:
return
return None

assert all(isinstance(entry, type(entries[0])) for entry in entries)

Expand All @@ -175,9 +176,9 @@ async def _create_entries(
}

async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, json=data) as response:
status_code: int = response.status
response: Dict = await response.json()
async with session.post(url, headers=headers, json=data) as response_raw:
status_code: int = response_raw.status
response: Dict = await response_raw.json()
if status_code == 200:
print(f"Success: {response}")
print("Publish successful")
Expand All @@ -189,10 +190,10 @@ async def _create_entries(
async def get_entry(
self,
pair: str,
timestamp: int = None,
interval: Interval = None,
aggregation: AggregationMode = None,
routing: bool = None,
timestamp: Optional[int] = None,
interval: Optional[Interval] = None,
aggregation: Optional[AggregationMode] = None,
routing: Optional[bool] = None,
) -> "EntryResult":
"""
Get data aggregated on the Pragma API.
Expand Down Expand Up @@ -225,9 +226,9 @@ async def get_entry(
}

async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers, params=params) as response:
status_code: int = response.status
response: Dict = await response.json()
async with session.get(url, headers=headers, params=params) as response_raw:
status_code: int = response_raw.status
response: Dict = await response_raw.json()
if status_code == 200:
print(f"Success: {response}")
print("Get Data successful")
Expand Down Expand Up @@ -270,9 +271,9 @@ async def get_volatility(self, pair: str, start: int, end: int):
url = f"{self.api_base_url}{endpoint}"
# Send GET request with headers
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers, params=params) as response:
status_code: int = response.status
response: Dict = await response.json()
async with session.get(url, headers=headers, params=params) as response_raw:
status_code: int = response_raw.status
response: Dict = await response_raw.json()
if status_code == 200:
print(f"Success: {response}")
print("Get Volatility successful")
Expand All @@ -296,7 +297,12 @@ def get_endpoint_publish_offchain(data_type: DataTypes):

class EntryResult:
def __init__(
self, pair_id, data, num_sources_aggregated=0, timestamp=None, decimals=None
self,
pair_id: str,
data: Any,
num_sources_aggregated: int = 0,
timestamp: Optional[int] = None,
decimals: Optional[int] = None,
):
self.pair_id = pair_id
self.data = data
Expand Down
6 changes: 3 additions & 3 deletions pragma-sdk/pragma_sdk/offchain/signer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List, Optional, Tuple

from starknet_py.net.signer.stark_curve_signer import StarkCurveSigner
from starknet_py.utils.typed_data import TypedData
Expand Down Expand Up @@ -47,7 +47,7 @@ def build_publish_message(
},
}
if data_type == DataTypes.FUTURE:
message["types"]["Entry"] = message["types"]["Entry"] + [
message["types"]["Entry"] += [ # type: ignore[index]
{"name": "expiration_timestamp", "type": "felt"},
]

Expand All @@ -64,7 +64,7 @@ def __init__(self, signer: StarkCurveSigner):

def sign_publish_message(
self, entries: List[Entry], data_type: Optional[DataTypes] = DataTypes.SPOT
) -> (List[int], int): # type: ignore
) -> Tuple[List[int], int]:
"""
Sign a publish message
Expand Down
3 changes: 3 additions & 0 deletions pragma-sdk/pragma_sdk/offchain/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from enum import StrEnum, unique
from typing import Dict, Optional, Tuple

PublishEntriesAPIResult = Tuple[Optional[Dict], Optional[Dict]]


@unique
Expand Down
21 changes: 12 additions & 9 deletions pragma-sdk/pragma_sdk/onchain/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
from typing import List, Optional
from typing import List, Optional, Union

from pragma_sdk.offchain.types import PublishEntriesAPIResult
from pragma_sdk.onchain.types.types import NetworkName, PublishEntriesOnChainResult
from starknet_py.net.account.account import Account
from starknet_py.net.full_node_client import FullNodeClient
from starknet_py.net.client import Client
Expand Down Expand Up @@ -30,7 +32,7 @@
logger.setLevel(logging.INFO)


class PragmaOnChainClient(
class PragmaOnChainClient( # type: ignore[misc]
PragmaClient,
NonceMixin,
OracleMixin,
Expand All @@ -39,6 +41,7 @@ class PragmaOnChainClient(
):
"""
Client for interacting with Pragma on Starknet.
:param network: Target network for the client.
Can be a URL string, or one of
``"mainnet"``, ``"sepolia"`` or ``"devnet"``
Expand All @@ -64,7 +67,7 @@ def __init__(
account_contract_address: Optional[Address] = None,
contract_addresses_config: Optional[ContractAddresses] = None,
port: Optional[int] = None,
chain_name: Optional[str] = None,
chain_name: Optional[NetworkName] = None,
execution_config: Optional[ExecutionConfig] = None,
):
full_node_client: FullNodeClient = get_full_node_client_from_network(
Expand All @@ -73,13 +76,13 @@ def __init__(
self.full_node_client = full_node_client
self.client = full_node_client

if network.startswith("http") and chain_name is None:
if network.startswith("http") and chain_name is None: # type: ignore[union-attr]
raise ClientException(
f"Network provided is a URL: {network} but `chain_name` is not provided."
)

self.network = (
network if not (network.startswith("http") and chain_name) else chain_name
network if not (network.startswith("http") and chain_name) else chain_name # type: ignore[union-attr]
)

if execution_config is not None:
Expand All @@ -93,14 +96,14 @@ def __init__(
)

if not contract_addresses_config:
contract_addresses_config = CONTRACT_ADDRESSES[self.network]
contract_addresses_config = CONTRACT_ADDRESSES[self.network] # type: ignore[index]

self.contract_addresses_config = contract_addresses_config
self._setup_contracts()

async def publish_entries(
self, entries: List[Entry], execution_config: Optional[ExecutionConfig] = None
) -> List[InvokeResult]:
) -> Union[PublishEntriesAPIResult, PublishEntriesOnChainResult]:
"""
Publish entries on-chain.
Expand Down Expand Up @@ -145,7 +148,7 @@ async def get_balance(self, account_contract_address, token_address=None) -> int
key_pair=KeyPair.from_private_key(1),
chain=CHAIN_IDS[self.network],
)
return await client.get_balance(token_address)
return await client.get_balance(token_address) # type: ignore[no-any-return]

def set_account(
self,
Expand Down Expand Up @@ -185,7 +188,7 @@ def account_address(self) -> Address:
Return the account address.
"""

return self.account.address
return self.account.address # type: ignore[no-any-return]

def init_stats_contract(
self,
Expand Down
2 changes: 1 addition & 1 deletion pragma-sdk/pragma_sdk/onchain/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from starknet_py.net.models.chains import StarknetChainId


CHAIN_IDS: Dict[Network, int] = {
CHAIN_IDS: Dict[Network, StarknetChainId] = {
"devnet": StarknetChainId.MAINNET,
"mainnet": StarknetChainId.MAINNET,
"sepolia": StarknetChainId.SEPOLIA,
Expand Down
Loading

0 comments on commit 8ea98ae

Please sign in to comment.