diff --git a/pragma-sdk/pragma_sdk/common/fetchers/fetchers/starknetamm.py b/pragma-sdk/pragma_sdk/common/fetchers/fetchers/starknetamm.py index a0d4b390..430142af 100644 --- a/pragma-sdk/pragma_sdk/common/fetchers/fetchers/starknetamm.py +++ b/pragma-sdk/pragma_sdk/common/fetchers/fetchers/starknetamm.py @@ -78,7 +78,7 @@ async def fetch( else: logger.debug("Skipping StarknetAMM for non supported pair: %s", pair) - return list(await asyncio.gather(*entries, return_exceptions=True)) + return list(await asyncio.gather(*entries, return_exceptions=True)) # type: ignore[call-overload] def _construct(self, pair: Pair, result: float) -> SpotEntry: price_int = int(result * (10 ** pair.decimals())) diff --git a/pragma-sdk/pragma_sdk/common/fetchers/future_fetchers/binance.py b/pragma-sdk/pragma_sdk/common/fetchers/future_fetchers/binance.py index 6bf8c86f..9e0dde33 100644 --- a/pragma-sdk/pragma_sdk/common/fetchers/future_fetchers/binance.py +++ b/pragma-sdk/pragma_sdk/common/fetchers/future_fetchers/binance.py @@ -33,7 +33,7 @@ async def _fetch_volume( volume_arr.append((element["symbol"], int(element["quoteVolume"]))) return volume_arr - async def fetch_pair( + async def fetch_pair( # type: ignore[override] self, pair: Pair, session: ClientSession ) -> List[FutureEntry] | PublisherFetchError: filtered_data = [] @@ -59,7 +59,7 @@ async def fetch_pair( async def fetch( self, session: ClientSession ) -> List[Entry | PublisherFetchError | BaseException]: - entries: Sequence[FutureEntry | PublisherFetchError] = [] + entries: List[Entry | PublisherFetchError | BaseException] = [] for pair in self.pairs: future_entries = await self.fetch_pair(pair, session) if isinstance(future_entries, list): @@ -71,14 +71,21 @@ async def fetch( def format_url(self, pair: Optional[Pair] = None) -> str: return self.BASE_URL - def _retrieve_volume(self, pair: Pair, volume_arr) -> int: + def _retrieve_volume( + self, pair: Pair, volume_arr: List[Tuple[str, int]] | PublisherFetchError + ) -> int: + if isinstance(volume_arr, PublisherFetchError): + return 0 for list_pair, list_vol in volume_arr: if pair == list_pair: return list_vol return 0 def _construct( - self, pair: Pair, result: Any, volume_arr: List[int] + self, + pair: Pair, + result: Any, + volume_arr: List[Tuple[str, int]] | PublisherFetchError, ) -> List[FutureEntry]: result_arr = [] decimals = pair.decimals() diff --git a/pragma-sdk/pragma_sdk/common/fetchers/future_fetchers/okx.py b/pragma-sdk/pragma_sdk/common/fetchers/future_fetchers/okx.py index e57139f2..5a89fac1 100644 --- a/pragma-sdk/pragma_sdk/common/fetchers/future_fetchers/okx.py +++ b/pragma-sdk/pragma_sdk/common/fetchers/future_fetchers/okx.py @@ -36,11 +36,11 @@ async def fetch_expiry_timestamp( def format_expiry_timestamp_url(self, instrument_id: str) -> str: return f"{self.TIMESTAMP_URL}?instType=FUTURES&instId={instrument_id}" - async def fetch_pair( + async def fetch_pair( # type: ignore[override] self, pair: Pair, session: ClientSession ) -> PublisherFetchError | List[Entry]: url = self.format_url(pair) - future_entries = [] + future_entries: List[Entry] = [] async with session.get(url) as resp: if resp.status == 404: return PublisherFetchError(f"No data found for {pair} from OKX") @@ -63,15 +63,16 @@ async def fetch_pair( expiry_timestamp = await self.fetch_expiry_timestamp( pair, result["data"][i]["instId"], session ) - future_entries.append( - self._construct(pair, result["data"][i], expiry_timestamp) - ) + if not isinstance(expiry_timestamp, PublisherFetchError): + future_entries.append( + self._construct(pair, result["data"][i], expiry_timestamp) + ) return future_entries async def fetch( self, session: ClientSession ) -> List[Entry | PublisherFetchError | BaseException]: - entries = [] + entries: List[Entry | PublisherFetchError | BaseException] = [] for pair in self.pairs: future_entries = await self.fetch_pair(pair, session) if isinstance(future_entries, list): diff --git a/pragma-sdk/pragma_sdk/common/types/client.py b/pragma-sdk/pragma_sdk/common/types/client.py index e16dd8e3..1f70642e 100644 --- a/pragma-sdk/pragma_sdk/common/types/client.py +++ b/pragma-sdk/pragma_sdk/common/types/client.py @@ -1,9 +1,11 @@ -from typing import List, Optional, Any +from typing import List, Optional, Union from abc import ABC, abstractmethod from pragma_sdk.common.types.entry import Entry from pragma_sdk.common.types.types import ExecutionConfig from pragma_sdk.common.utils import add_sync_methods +from pragma_sdk.offchain.types import PublishEntriesAPIResult +from pragma_sdk.onchain.types.types import PublishEntriesOnChainResult @add_sync_methods @@ -11,7 +13,7 @@ class PragmaClient(ABC): @abstractmethod async def publish_entries( self, entries: List[Entry], execution_config: Optional[ExecutionConfig] = None - ) -> Any: + ) -> Union[PublishEntriesAPIResult, PublishEntriesOnChainResult]: """ Publish entries to some destination. diff --git a/pragma-sdk/pragma_sdk/offchain/client.py b/pragma-sdk/pragma_sdk/offchain/client.py index bed9e0e3..4e0aac61 100644 --- a/pragma-sdk/pragma_sdk/offchain/client.py +++ b/pragma-sdk/pragma_sdk/offchain/client.py @@ -1,17 +1,18 @@ import asyncio import time -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple, Union import aiohttp +from pragma_sdk.onchain.types.types import PublishEntriesOnChainResult from requests import HTTPError from starknet_py.net.models import StarknetChainId from starknet_py.net.signer.stark_curve_signer import KeyPair, StarkCurveSigner from pragma_sdk.common.types.entry import Entry, FutureEntry, SpotEntry -from pragma_sdk.common.types.types import AggregationMode, DataTypes +from pragma_sdk.common.types.types import AggregationMode, DataTypes, ExecutionConfig from pragma_sdk.common.utils import add_sync_methods, get_cur_from_pair from pragma_sdk.offchain.signer import OffchainSigner -from pragma_sdk.offchain.types import Interval +from pragma_sdk.offchain.types import Interval, PublishEntriesAPIResult from pragma_sdk.common.types.client import PragmaClient @@ -58,9 +59,9 @@ def __init__( async def get_ohlc( self, pair: str, - timestamp: int = None, - interval: Interval = None, - aggregation: AggregationMode = None, + timestamp: Optional[int] = None, + interval: Optional[Interval] = None, + aggregation: Optional[AggregationMode] = None, ) -> "EntryResult": """ Retrieve OHLC data from the Pragma API. @@ -98,9 +99,9 @@ async def get_ohlc( async with aiohttp.ClientSession() as session: async with session.get( url, headers=headers, params=path_params - ) as response: - status_code: int = response.status - response: Dict = await response.json() + ) as response_raw: + status_code: int = response_raw.status + response: Dict = await response_raw.json() if status_code == 200: print(f"Success: {response}") print("Get Ohlc successful") @@ -111,8 +112,8 @@ async def get_ohlc( return EntryResult(pair_id=response["pair_id"], data=response["data"]) async def publish_entries( - self, entries: List[Entry] - ) -> (Optional[Dict], Optional[Dict]): # type: ignore + self, entries: List[Entry], _execution_config: Optional[ExecutionConfig] = None + ) -> Union[PublishEntriesAPIResult, PublishEntriesOnChainResult]: """ Publishes spot and future entries to the Pragma API. This function accepts both type of entries - but they need to be sent through @@ -123,8 +124,8 @@ async def publish_entries( """ # We accept both types of entries - but they need to be sent through # different endpoints & signed differently, so we split them here. - spot_entries: list[SpotEntry] = [] - future_entries: list[FutureEntry] = [] + spot_entries: List[Entry] = [] + future_entries: List[Entry] = [] for entry in entries: if isinstance(entry, SpotEntry): @@ -140,7 +141,7 @@ async def publish_entries( return spot_response, future_response async def _create_entries( - self, entries: List[Entry], data_type: Optional[DataTypes] = DataTypes.SPOT + self, entries: List[Entry], data_type: DataTypes = DataTypes.SPOT ) -> Optional[Dict]: """ Publishes entries to the Pragma API & returns the http response. @@ -152,7 +153,7 @@ async def _create_entries( raise PragmaAPIError("No offchain signer set") if len(entries) == 0: - return + return None assert all(isinstance(entry, type(entries[0])) for entry in entries) @@ -175,9 +176,9 @@ async def _create_entries( } async with aiohttp.ClientSession() as session: - async with session.post(url, headers=headers, json=data) as response: - status_code: int = response.status - response: Dict = await response.json() + async with session.post(url, headers=headers, json=data) as response_raw: + status_code: int = response_raw.status + response: Dict = await response_raw.json() if status_code == 200: print(f"Success: {response}") print("Publish successful") @@ -189,10 +190,10 @@ async def _create_entries( async def get_entry( self, pair: str, - timestamp: int = None, - interval: Interval = None, - aggregation: AggregationMode = None, - routing: bool = None, + timestamp: Optional[int] = None, + interval: Optional[Interval] = None, + aggregation: Optional[AggregationMode] = None, + routing: Optional[bool] = None, ) -> "EntryResult": """ Get data aggregated on the Pragma API. @@ -225,9 +226,9 @@ async def get_entry( } async with aiohttp.ClientSession() as session: - async with session.get(url, headers=headers, params=params) as response: - status_code: int = response.status - response: Dict = await response.json() + async with session.get(url, headers=headers, params=params) as response_raw: + status_code: int = response_raw.status + response: Dict = await response_raw.json() if status_code == 200: print(f"Success: {response}") print("Get Data successful") @@ -270,9 +271,9 @@ async def get_volatility(self, pair: str, start: int, end: int): url = f"{self.api_base_url}{endpoint}" # Send GET request with headers async with aiohttp.ClientSession() as session: - async with session.get(url, headers=headers, params=params) as response: - status_code: int = response.status - response: Dict = await response.json() + async with session.get(url, headers=headers, params=params) as response_raw: + status_code: int = response_raw.status + response: Dict = await response_raw.json() if status_code == 200: print(f"Success: {response}") print("Get Volatility successful") @@ -296,7 +297,12 @@ def get_endpoint_publish_offchain(data_type: DataTypes): class EntryResult: def __init__( - self, pair_id, data, num_sources_aggregated=0, timestamp=None, decimals=None + self, + pair_id: str, + data: Any, + num_sources_aggregated: int = 0, + timestamp: Optional[int] = None, + decimals: Optional[int] = None, ): self.pair_id = pair_id self.data = data diff --git a/pragma-sdk/pragma_sdk/offchain/signer.py b/pragma-sdk/pragma_sdk/offchain/signer.py index ce536865..26bb0b81 100644 --- a/pragma-sdk/pragma_sdk/offchain/signer.py +++ b/pragma-sdk/pragma_sdk/offchain/signer.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List, Optional, Tuple from starknet_py.net.signer.stark_curve_signer import StarkCurveSigner from starknet_py.utils.typed_data import TypedData @@ -47,7 +47,7 @@ def build_publish_message( }, } if data_type == DataTypes.FUTURE: - message["types"]["Entry"] = message["types"]["Entry"] + [ + message["types"]["Entry"] += [ # type: ignore[index] {"name": "expiration_timestamp", "type": "felt"}, ] @@ -64,7 +64,7 @@ def __init__(self, signer: StarkCurveSigner): def sign_publish_message( self, entries: List[Entry], data_type: Optional[DataTypes] = DataTypes.SPOT - ) -> (List[int], int): # type: ignore + ) -> Tuple[List[int], int]: """ Sign a publish message diff --git a/pragma-sdk/pragma_sdk/offchain/types.py b/pragma-sdk/pragma_sdk/offchain/types.py index e7c648a5..341310c8 100644 --- a/pragma-sdk/pragma_sdk/offchain/types.py +++ b/pragma-sdk/pragma_sdk/offchain/types.py @@ -1,4 +1,7 @@ from enum import StrEnum, unique +from typing import Dict, Optional, Tuple + +PublishEntriesAPIResult = Tuple[Optional[Dict], Optional[Dict]] @unique diff --git a/pragma-sdk/pragma_sdk/onchain/client.py b/pragma-sdk/pragma_sdk/onchain/client.py index 34ad7630..eaddc8d9 100644 --- a/pragma-sdk/pragma_sdk/onchain/client.py +++ b/pragma-sdk/pragma_sdk/onchain/client.py @@ -1,6 +1,8 @@ import logging -from typing import List, Optional +from typing import List, Optional, Union +from pragma_sdk.offchain.types import PublishEntriesAPIResult +from pragma_sdk.onchain.types.types import NetworkName, PublishEntriesOnChainResult from starknet_py.net.account.account import Account from starknet_py.net.full_node_client import FullNodeClient from starknet_py.net.client import Client @@ -30,7 +32,7 @@ logger.setLevel(logging.INFO) -class PragmaOnChainClient( +class PragmaOnChainClient( # type: ignore[misc] PragmaClient, NonceMixin, OracleMixin, @@ -39,6 +41,7 @@ class PragmaOnChainClient( ): """ Client for interacting with Pragma on Starknet. + :param network: Target network for the client. Can be a URL string, or one of ``"mainnet"``, ``"sepolia"`` or ``"devnet"`` @@ -64,7 +67,7 @@ def __init__( account_contract_address: Optional[Address] = None, contract_addresses_config: Optional[ContractAddresses] = None, port: Optional[int] = None, - chain_name: Optional[str] = None, + chain_name: Optional[NetworkName] = None, execution_config: Optional[ExecutionConfig] = None, ): full_node_client: FullNodeClient = get_full_node_client_from_network( @@ -73,13 +76,13 @@ def __init__( self.full_node_client = full_node_client self.client = full_node_client - if network.startswith("http") and chain_name is None: + if network.startswith("http") and chain_name is None: # type: ignore[union-attr] raise ClientException( f"Network provided is a URL: {network} but `chain_name` is not provided." ) self.network = ( - network if not (network.startswith("http") and chain_name) else chain_name + network if not (network.startswith("http") and chain_name) else chain_name # type: ignore[union-attr] ) if execution_config is not None: @@ -93,14 +96,14 @@ def __init__( ) if not contract_addresses_config: - contract_addresses_config = CONTRACT_ADDRESSES[self.network] + contract_addresses_config = CONTRACT_ADDRESSES[self.network] # type: ignore[index] self.contract_addresses_config = contract_addresses_config self._setup_contracts() async def publish_entries( self, entries: List[Entry], execution_config: Optional[ExecutionConfig] = None - ) -> List[InvokeResult]: + ) -> Union[PublishEntriesAPIResult, PublishEntriesOnChainResult]: """ Publish entries on-chain. @@ -145,7 +148,7 @@ async def get_balance(self, account_contract_address, token_address=None) -> int key_pair=KeyPair.from_private_key(1), chain=CHAIN_IDS[self.network], ) - return await client.get_balance(token_address) + return await client.get_balance(token_address) # type: ignore[no-any-return] def set_account( self, @@ -185,7 +188,7 @@ def account_address(self) -> Address: Return the account address. """ - return self.account.address + return self.account.address # type: ignore[no-any-return] def init_stats_contract( self, diff --git a/pragma-sdk/pragma_sdk/onchain/constants.py b/pragma-sdk/pragma_sdk/onchain/constants.py index 281cb253..182a1853 100644 --- a/pragma-sdk/pragma_sdk/onchain/constants.py +++ b/pragma-sdk/pragma_sdk/onchain/constants.py @@ -4,7 +4,7 @@ from starknet_py.net.models.chains import StarknetChainId -CHAIN_IDS: Dict[Network, int] = { +CHAIN_IDS: Dict[Network, StarknetChainId] = { "devnet": StarknetChainId.MAINNET, "mainnet": StarknetChainId.MAINNET, "sepolia": StarknetChainId.SEPOLIA, diff --git a/pragma-sdk/pragma_sdk/onchain/mixins/__init__.py b/pragma-sdk/pragma_sdk/onchain/mixins/__init__.py index 35fca0ce..1abe8e0c 100644 --- a/pragma-sdk/pragma_sdk/onchain/mixins/__init__.py +++ b/pragma-sdk/pragma_sdk/onchain/mixins/__init__.py @@ -5,9 +5,9 @@ from pragma_sdk.onchain.mixins.summary import SummaryStatsMixin __all__ = [ - NonceMixin, - OracleMixin, - PublisherRegistryMixin, - RandomnessMixin, - SummaryStatsMixin, + "NonceMixin", + "OracleMixin", + "PublisherRegistryMixin", + "RandomnessMixin", + "SummaryStatsMixin", ] diff --git a/pragma-sdk/pragma_sdk/onchain/mixins/nonce.py b/pragma-sdk/pragma_sdk/onchain/mixins/nonce.py index 844220cb..1d98373b 100644 --- a/pragma-sdk/pragma_sdk/onchain/mixins/nonce.py +++ b/pragma-sdk/pragma_sdk/onchain/mixins/nonce.py @@ -63,7 +63,7 @@ async def track_nonce( self, nonce: int, transaction_hash: int, - ): + ) -> None: """ Callback function to track the nonce of a transaction. Will update the nonce_dict and pending_nonce attributes. @@ -118,7 +118,8 @@ async def get_nonce( self.account_contract_address, block_number=block_number, ) - return nonce + + return int(nonce) async def get_status( self, diff --git a/pragma-sdk/pragma_sdk/onchain/mixins/oracle.py b/pragma-sdk/pragma_sdk/onchain/mixins/oracle.py index f635686b..3f432fa3 100644 --- a/pragma-sdk/pragma_sdk/onchain/mixins/oracle.py +++ b/pragma-sdk/pragma_sdk/onchain/mixins/oracle.py @@ -1,10 +1,11 @@ import time -from typing import Dict, List, Optional +from typing import Awaitable, Callable, Coroutine, Dict, List, Optional from deprecated import deprecated from starknet_py.contract import InvokeResult from starknet_py.net.account.account import Account from starknet_py.net.client import Client +from starknet_py.net.client_models import SentTransactionResponse from pragma_sdk.onchain.types import Contract @@ -27,6 +28,9 @@ class OracleMixin: client: Client account: Account execution_config: ExecutionConfig + oracle: Contract + is_user_client: bool = False + track_nonce: Callable[[object, int, int], Coroutine[None, None, None]] @deprecated async def publish_spot_entry( @@ -76,8 +80,12 @@ async def publish_many( invocations: List[InvokeResult] = [] - spot_entries = [entry for entry in entries if isinstance(entry, SpotEntry)] - future_entries = [entry for entry in entries if isinstance(entry, FutureEntry)] + spot_entries: List[Entry] = [ + entry for entry in entries if isinstance(entry, SpotEntry) + ] + future_entries: List[Entry] = [ + entry for entry in entries if isinstance(entry, FutureEntry) + ] invocations.extend( await self._publish_entries(spot_entries, DataTypes.SPOT, execution_config) @@ -195,7 +203,7 @@ async def get_spot( self, pair_id: str | int, aggregation_mode: AggregationMode = AggregationMode.MEDIAN, - sources: List[str | int] = None, + sources: Optional[List[str | int]] = None, block_number="latest", ) -> OracleResponse: """ @@ -242,7 +250,7 @@ async def get_future( pair_id: str | int, expiry_timestamp: int, aggregation_mode: AggregationMode = AggregationMode.MEDIAN, - sources: List[str | int] = None, + sources: Optional[List[str | int]] = None, block_number="latest", ) -> OracleResponse: """ @@ -299,7 +307,7 @@ async def get_decimals(self, asset: Asset, block_number="latest") -> Decimals: block_number=block_number, ) - return response + return response # type: ignore[no-any-return] # TODO (#000): Fix future checkpoints async def set_future_checkpoints( @@ -409,7 +417,7 @@ async def get_admin_address(self) -> Address: """ (response,) = await self.oracle.functions["get_admin_address"].call() - return response + return response # type: ignore[no-any-return] async def update_oracle( self, diff --git a/pragma-sdk/pragma_sdk/onchain/mixins/publisher_registry.py b/pragma-sdk/pragma_sdk/onchain/mixins/publisher_registry.py index 9d77c113..d0ca13bf 100644 --- a/pragma-sdk/pragma_sdk/onchain/mixins/publisher_registry.py +++ b/pragma-sdk/pragma_sdk/onchain/mixins/publisher_registry.py @@ -24,7 +24,8 @@ async def get_all_publishers(self) -> List[int]: (publishers,) = await self.publisher_registry.functions[ "get_all_publishers" ].call() - return publishers + + return publishers # type: ignore[no-any-return] async def get_publisher_address(self, publisher: str) -> int: """ @@ -37,7 +38,8 @@ async def get_publisher_address(self, publisher: str) -> int: (address,) = await self.publisher_registry.functions[ "get_publisher_address" ].call(publisher) - return address + + return address # type: ignore[no-any-return] async def get_publisher_sources(self, publisher: str) -> List[int]: """ @@ -50,7 +52,8 @@ async def get_publisher_sources(self, publisher: str) -> List[int]: (sources,) = await self.publisher_registry.functions[ "get_publisher_sources" ].call(publisher) - return sources + + return sources # type: ignore[no-any-return] async def add_publisher( self, @@ -75,7 +78,7 @@ async def add_publisher( publisher_address, execution_config=execution_config, ) - return invocation + return invocation # type: ignore[no-any-return] async def add_source_for_publisher( self, @@ -102,7 +105,7 @@ async def add_source_for_publisher( str_to_felt(source), execution_config=execution_config, ) - return invocation + return invocation # type: ignore[no-any-return] async def add_sources_for_publisher( self, @@ -129,7 +132,7 @@ async def add_sources_for_publisher( [str_to_felt(source) for source in sources], execution_config=execution_config, ) - return invocation + return invocation # type: ignore[no-any-return] async def update_publisher_address( self, @@ -156,4 +159,4 @@ async def update_publisher_address( publisher_address, execution_config=execution_config, ) - return invocation + return invocation # type: ignore[no-any-return] diff --git a/pragma-sdk/pragma_sdk/onchain/mixins/randomness.py b/pragma-sdk/pragma_sdk/onchain/mixins/randomness.py index 5f3e8671..1cd6d3c0 100644 --- a/pragma-sdk/pragma_sdk/onchain/mixins/randomness.py +++ b/pragma-sdk/pragma_sdk/onchain/mixins/randomness.py @@ -6,6 +6,8 @@ from starknet_py.net.client import Client from starknet_py.net.client_errors import ClientError from starknet_py.net.client_models import EstimatedFee, EventsChunk +from starknet_py.net.full_node_client import FullNodeClient +from starknet_py.net.account.account import Account from pragma_sdk.onchain.abis.abi import ABIS from pragma_sdk.onchain.constants import RANDOMNESS_REQUEST_EVENT_SELECTOR @@ -28,7 +30,10 @@ class RandomnessMixin: client: Client - randomness: Optional[Contract] = None + randomness: Contract + account: Optional[Account] = None + is_user_client: bool = False + full_node_client: FullNodeClient def init_randomness_contract(self, contract_address: Address): provider = self.account if self.account else self.client @@ -226,7 +231,7 @@ async def get_total_fees(self, caller_address: Address, request_id: int) -> int: caller_address, request_id ) - return response + return response # type: ignore[no-any-return] async def compute_premium_fee(self, caller_address: Address) -> int: """ @@ -240,7 +245,7 @@ async def compute_premium_fee(self, caller_address: Address) -> int: caller_address ) - return response + return response # type: ignore[no-any-return] async def requestor_current_request_id(self, caller_address: Address) -> int: """ @@ -253,7 +258,7 @@ async def requestor_current_request_id(self, caller_address: Address) -> int: caller_address ) - return response + return response # type: ignore[no-any-return] async def get_pending_requests( self, @@ -275,7 +280,7 @@ async def get_pending_requests( max_len, ) - return response + return response # type: ignore[no-any-return] async def cancel_random_request( self, @@ -309,7 +314,7 @@ async def estimate_gas_cancel_random_op( self, vrf_cancel_params: VRFCancelParams, execution_config: Optional[ExecutionConfig] = None, - ): + ) -> EstimatedFee: """ Estimate the gas for the cancel_random_request operation. @@ -333,7 +338,7 @@ async def estimate_gas_cancel_random_op( max_fee=execution_config.max_fee, ) estimate_fee = await prepared_call.estimate_fee() - return estimate_fee + return estimate_fee # type: ignore[no-any-return] async def refund_operation( self, @@ -447,14 +452,14 @@ async def handle_random( seed = self._build_request_seed(event, block_hash) - beta_string, pi_string, _ = create_randomness(sk, seed) - beta_string = int.from_bytes(beta_string, sys.byteorder) + beta_string, pi_string, _ = create_randomness(sk, seed) # type: ignore[arg-type] + beta_string = int.from_bytes(beta_string, sys.byteorder) # type: ignore[arg-type, assignment] proof = [ - int.from_bytes(p, sys.byteorder) + int.from_bytes(p, sys.byteorder) # type: ignore[arg-type] for p in [pi_string[:31], pi_string[31:62], pi_string[62:]] ] - random_words = [beta_string] + random_words: List[int] = [beta_string] # type: ignore[list-item] vrf_submit_params = VRFSubmitParams( request_id=event.request_id, @@ -496,7 +501,7 @@ async def _get_randomness_requests_events( logger.info(f"Got {len(event_list.events)} events") return event_list - def _build_request_seed(self, event: RandomnessRequest, block_hash: int): + def _build_request_seed(self, event: RandomnessRequest, block_hash: int) -> int: """ Build the request seed. The seed is the hash of the request id, the block hash, the event seed and the caller address. @@ -505,7 +510,7 @@ def _build_request_seed(self, event: RandomnessRequest, block_hash: int): :param block_hash: The block hash. :return: The generated seed. """ - return ( + return int( event.request_id.to_bytes(8, sys.byteorder) + block_hash.to_bytes(32, sys.byteorder) + event.seed.to_bytes(32, sys.byteorder) diff --git a/pragma-sdk/pragma_sdk/onchain/types/types.py b/pragma-sdk/pragma_sdk/onchain/types/types.py index e0ddc3d4..a79fcd95 100644 --- a/pragma-sdk/pragma_sdk/onchain/types/types.py +++ b/pragma-sdk/pragma_sdk/onchain/types/types.py @@ -1,26 +1,28 @@ from enum import StrEnum, unique from collections import namedtuple from typing import Optional, Literal, List, Any, Dict -from pragma_sdk.common.types.asset import Asset from pydantic import HttpUrl - from dataclasses import dataclass +from pragma_sdk.common.types.asset import Asset from pragma_sdk.common.types.types import Address, AggregationMode +from starknet_py.contract import InvokeResult + ContractAddresses = namedtuple( "ContractAddresses", ["publisher_registry_address", "oracle_proxy_addresss", "summary_stats_address"], ) -Network = ( - HttpUrl - | Literal[ - "devnet", - "mainnet", - "sepolia", - ] -) +NetworkName = Literal[ + "devnet", + "mainnet", + "sepolia", +] + +Network = HttpUrl | NetworkName + +PublishEntriesOnChainResult = List[InvokeResult] @unique diff --git a/pragma-sdk/pyproject.toml b/pragma-sdk/pyproject.toml index 5a08017f..f8bc338d 100644 --- a/pragma-sdk/pyproject.toml +++ b/pragma-sdk/pyproject.toml @@ -72,6 +72,8 @@ test_all_unit = "coverage run -m pytest --net=devnet -v --reruns 5 --only-rerun test_hop_handler = "coverage run -m pytest --net=devnet -v --reruns 5 --only-rerun aiohttp.client_exceptions.ClientConnectorError tests/unit/hop_handler_test.py -s" test_index_aggregation = "coverage run -m pytest --net=devnet -v --reruns 5 --only-rerun aiohttp.client_exceptions.ClientConnectorError tests/unit/index_aggregation_test.py -s" +test_all = "coverage run -m pytest --net=devnet -v --reruns 5 --only-rerun aiohttp.client_exceptions.ClientConnectorError tests/ -s" + check_circular_imports = "poetry run python tests/check_circular_imports.py"