Skip to content

Commit

Permalink
Merge pull request #69 from Arondondon/development
Browse files Browse the repository at this point in the history
Implemented channels caching
  • Loading branch information
sassless authored Oct 1, 2024
2 parents a9ab0d0 + daf0e0b commit 344dfaa
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 46 deletions.
5 changes: 4 additions & 1 deletion snet/sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import google.protobuf.internal.api_implementation

from snet.sdk.mpe.payment_channel_provider import PaymentChannelProvider

with warnings.catch_warnings():
# Suppress the eth-typing package`s warnings related to some new networks
warnings.filterwarnings("ignore", "Network .* does not have a valid ChainId. eth-typing should be "
Expand Down Expand Up @@ -76,6 +78,7 @@ def __init__(self, sdk_config: Config, metadata_provider=None):
self.registry_contract = get_contract_object(self.web3, "Registry", _registry_contract_address)

self.account = Account(self.web3, sdk_config, self.mpe_contract)
self.payment_channel_provider = PaymentChannelProvider(self.web3, self.mpe_contract)

def create_service_client(self, org_id: str, service_id: str, group_name=None,
payment_channel_management_strategy=None,
Expand Down Expand Up @@ -122,7 +125,7 @@ def create_service_client(self, org_id: str, service_id: str, group_name=None,
pb2_module = self.get_module_by_keyword(org_id, service_id, keyword="pb2.py")

service_client = ServiceClient(org_id, service_id, service_metadata, group, service_stub, strategy,
options, self.mpe_contract, self.account, self.web3, pb2_module)
options, self.mpe_contract, self.account, self.web3, pb2_module, self.payment_channel_provider)
return service_client

def get_service_stub(self, org_id: str, service_id: str) -> ServiceStub:
Expand Down
3 changes: 0 additions & 3 deletions snet/sdk/mpe/mpe_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@ def __init__(self, w3, address=None):
self.contract = get_contract_object(self.web3, "MultiPartyEscrow")
else:
self.contract = get_contract_object(self.web3, "MultiPartyEscrow", address)
self.event_topics = [self.web3.keccak(
text="ChannelOpen(uint256,uint256,address,address,address,bytes32,uint256,uint256)").hex()]
self.deployment_block = get_contract_deployment_block(self.web3, "MultiPartyEscrow")

def balance(self, address):
return self.contract.functions.balances(address).call()
Expand Down
126 changes: 92 additions & 34 deletions snet/sdk/mpe/payment_channel_provider.py
Original file line number Diff line number Diff line change
@@ -1,66 +1,124 @@
from pathlib import Path

from web3._utils.events import get_event_data
from eth_abi.codec import ABICodec
import pickle

from snet.sdk.mpe.payment_channel import PaymentChannel
from snet.contracts import get_contract_deployment_block


BLOCKS_PER_BATCH = 5000
CHANNELS_DIR = Path.home().joinpath(".snet", "cache", "mpe")


class PaymentChannelProvider(object):
def __init__(self, w3, payment_channel_state_service_client, mpe_contract):
def __init__(self, w3, mpe_contract):
self.web3 = w3

self.mpe_contract = mpe_contract
self.event_topics = [self.web3.keccak(
text="ChannelOpen(uint256,uint256,address,address,address,bytes32,uint256,uint256)").hex()]
self.deployment_block = get_contract_deployment_block(self.web3, "MultiPartyEscrow")
self.payment_channel_state_service_client = payment_channel_state_service_client

def get_past_open_channels(self, account, payment_address, group_id, starting_block_number=0, to_block_number=None):
if to_block_number is None:
to_block_number = self.web3.eth.block_number

if starting_block_number == 0:
starting_block_number = self.deployment_block

self.mpe_address = mpe_contract.contract.address
self.channels_file = CHANNELS_DIR.joinpath(str(self.mpe_address), "channels.pickle")

def update_cache(self):
channels = []
last_read_block = self.deployment_block

if not self.channels_file.exists():
print(f"Channels cache is empty. Caching may take some time when first accessing channels.\nCaching in progress...")
self.channels_file.parent.mkdir(parents=True, exist_ok=True)
with open(self.channels_file, "wb") as f:
empty_dict = {
"last_read_block": last_read_block,
"channels": channels
}
pickle.dump(empty_dict, f)
else:
with open(self.channels_file, "rb") as f:
load_dict = pickle.load(f)
last_read_block = load_dict["last_read_block"]
channels = load_dict["channels"]

current_block_number = self.web3.eth.block_number

if last_read_block < current_block_number:
new_channels = self._get_all_channels_from_blockchain_logs_to_dicts(last_read_block, current_block_number)
channels = channels + new_channels
last_read_block = current_block_number

with open(self.channels_file, "wb") as f:
dict_to_save = {
"last_read_block": last_read_block,
"channels": channels
}
pickle.dump(dict_to_save, f)

def _event_data_args_to_dict(self, event_data):
return {
"channel_id": event_data["channelId"],
"sender": event_data["sender"],
"signer": event_data["signer"],
"recipient": event_data["recipient"],
"group_id": event_data["groupId"],
}

def _get_all_channels_from_blockchain_logs_to_dicts(self, starting_block_number, to_block_number):
codec: ABICodec = self.web3.codec

logs = []
from_block = starting_block_number
while from_block <= to_block_number:
to_block = min(from_block + BLOCKS_PER_BATCH, to_block_number)
logs = logs + self.web3.eth.get_logs({"fromBlock": from_block, "toBlock": to_block,
"address": self.mpe_contract.contract.address,
"topics": self.event_topics})
logs = logs + self.web3.eth.get_logs({"fromBlock": from_block,
"toBlock": to_block,
"address": self.mpe_address,
"topics": self.event_topics})
from_block = to_block + 1

event_abi = self.mpe_contract.contract._find_matching_event_abi(event_name="ChannelOpen")
channels_opened = list(filter(
lambda
channel: (channel.sender == account.address or channel.signer == account.signer_address) and channel.recipient ==
payment_address and channel.groupId == group_id,

[get_event_data(codec, event_abi, l)["args"] for l in logs]
))
return list(map(lambda channel: PaymentChannel(channel["channelId"], self.web3, account,
self.payment_channel_state_service_client, self.mpe_contract),
channels_opened))

def open_channel(self, account, amount, expiration, payment_address, group_id):
event_data_list = [get_event_data(codec, event_abi, l)["args"] for l in logs]
channels_opened = list(map(self._event_data_args_to_dict, event_data_list))

return channels_opened

def _get_channels_from_cache(self):
self.update_cache()
with open(self.channels_file, "rb") as f:
load_dict = pickle.load(f)
return load_dict["channels"]

def get_past_open_channels(self, account, payment_address, group_id, payment_channel_state_service_client):

dict_channels = self._get_channels_from_cache()

channels_opened = list(filter(lambda channel: (channel["sender"] == account.address
or channel["signer"] == account.signer_address)
and channel["recipient"] == payment_address
and channel["group_id"] == group_id,
dict_channels))

return list(map(lambda channel: PaymentChannel(channel["channel_id"],
self.web3,
account,
payment_channel_state_service_client,
self.mpe_contract),
channels_opened))

def open_channel(self, account, amount, expiration, payment_address, group_id, payment_channel_state_service_client):
receipt = self.mpe_contract.open_channel(account, payment_address, group_id, amount, expiration)
return self._get_newly_opened_channel(receipt, account, payment_address, group_id)
return self._get_newly_opened_channel(account, payment_address, group_id, receipt, payment_channel_state_service_client)

def deposit_and_open_channel(self, account, amount, expiration, payment_address, group_id):
receipt = self.mpe_contract.deposit_and_open_channel(account, payment_address, group_id, amount,
expiration)
return self._get_newly_opened_channel(receipt, account, payment_address, group_id)
def deposit_and_open_channel(self, account, amount, expiration, payment_address, group_id, payment_channel_state_service_client):
receipt = self.mpe_contract.deposit_and_open_channel(account, payment_address, group_id, amount, expiration)
return self._get_newly_opened_channel(account, payment_address, group_id, receipt, payment_channel_state_service_client)

def _get_newly_opened_channel(self, receipt,account, payment_address, group_id):
open_channels = self.get_past_open_channels(account, payment_address, group_id, receipt["blockNumber"],
receipt["blockNumber"])
if len(open_channels) == 0:
def _get_newly_opened_channel(self, account, payment_address, group_id, receipt, payment_channel_state_service_client):
open_channels = self.get_past_open_channels(account, payment_address, group_id, payment_channel_state_service_client)
if not open_channels:
raise Exception(f"Error while opening channel, please check transaction {receipt.transactionHash.hex()} ")
return open_channels[0]
return open_channels[-1]

16 changes: 8 additions & 8 deletions snet/sdk/service_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from snet.sdk.utils.utils import RESOURCES_PATH, add_to_path, find_file_by_keyword

import snet.sdk.generic_client_interceptor as generic_client_interceptor
from snet.sdk.mpe.payment_channel_provider import PaymentChannelProvider


class _ClientCallDetails(
Expand All @@ -26,7 +25,7 @@ class _ClientCallDetails(

class ServiceClient:
def __init__(self, org_id, service_id, service_metadata, group, service_stub, payment_strategy,
options, mpe_contract, account, sdk_web3, pb2_module):
options, mpe_contract, account, sdk_web3, pb2_module, payment_channel_provider):
self.org_id = org_id
self.service_id = service_id
self.options = options
Expand All @@ -38,9 +37,8 @@ def __init__(self, org_id, service_id, service_metadata, group, service_stub, pa
self.__base_grpc_channel = self._get_grpc_channel()
self.grpc_channel = grpc.intercept_channel(self.__base_grpc_channel,
generic_client_interceptor.create(self._intercept_call))
self.payment_channel_provider = PaymentChannelProvider(sdk_web3,
self._generate_payment_channel_state_service_client(),
mpe_contract)
self.payment_channel_provider = payment_channel_provider
self.payment_channel_state_service_client = self._generate_payment_channel_state_service_client()
self.service = self._generate_grpc_stub(service_stub)
self.pb2_module = importlib.import_module(pb2_module) if isinstance(pb2_module, str) else pb2_module
self.payment_channels = []
Expand Down Expand Up @@ -122,7 +120,8 @@ def load_open_channels(self):
payment_address = self.group["payment"]["payment_address"]
group_id = base64.b64decode(str(self.group["group_id"]))
new_payment_channels = self.payment_channel_provider.get_past_open_channels(self.account, payment_address,
group_id, self.last_read_block)
group_id,
self.payment_channel_state_service_client)
self.payment_channels = self.payment_channels + \
self._filter_existing_channels_from_new_payment_channels(new_payment_channels)
self.last_read_block = current_block_number
Expand Down Expand Up @@ -150,13 +149,14 @@ def open_channel(self, amount, expiration):
payment_address = self.group["payment"]["payment_address"]
group_id = base64.b64decode(str(self.group["group_id"]))
return self.payment_channel_provider.open_channel(self.account, amount, expiration, payment_address,
group_id)
group_id, self.payment_channel_state_service_client)

def deposit_and_open_channel(self, amount, expiration):
payment_address = self.group["payment"]["payment_address"]
group_id = base64.b64decode(str(self.group["group_id"]))
return self.payment_channel_provider.deposit_and_open_channel(self.account, amount, expiration,
payment_address, group_id)
payment_address, group_id,
self.payment_channel_state_service_client)

def get_price(self):
return self.group["pricing"][0]["price_in_cogs"]
Expand Down

0 comments on commit 344dfaa

Please sign in to comment.