Skip to content

Commit

Permalink
Expose staker sampling as an iterator (ish)
Browse files Browse the repository at this point in the history
  • Loading branch information
fjarri committed Jun 9, 2020
1 parent 44cb58b commit b3a031f
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 113 deletions.
14 changes: 5 additions & 9 deletions nucypher/blockchain/eth/actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@
PolicyManagerAgent,
PreallocationEscrowAgent,
StakingEscrowAgent,
WorkLockAgent
WorkLockAgent,
StakersReservoir,
)
from nucypher.blockchain.eth.constants import NULL_ADDRESS
from nucypher.blockchain.eth.decorators import (
Expand Down Expand Up @@ -1458,16 +1459,11 @@ def generate_policy_parameters(self,
payload = {**blockchain_payload, **policy_end_time}
return payload

def recruit(self, quantity: int, **options) -> List[str]:
def get_stakers_reservoir(self, **options) -> StakersReservoir:
"""
Uses sampling logic to gather stakers from the blockchain and
caches the resulting node ethereum addresses.
:param quantity: Number of ursulas to sample from the blockchain.
Get a sampler object containing the currently registered stakers.
"""
staker_addresses = self.staking_agent.sample(quantity=quantity, **options)
return staker_addresses
return self.staking_agent.get_stakers_reservoir(**options)

def create_policy(self, *args, **kwargs):
"""
Expand Down
86 changes: 40 additions & 46 deletions nucypher/blockchain/eth/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,56 +659,23 @@ def swarm(self) -> Iterable[ChecksumAddress]:
yield staker_address

@contract_api(CONTRACT_CALL)
def sample(self,
quantity: int,
duration: int,
pagination_size: Optional[int] = None
) -> List[ChecksumAddress]:
"""
Select n random Stakers, according to their stake distribution.
The returned addresses are shuffled.
See full diagram here: https://github.com/nucypher/kms-whitepaper/blob/master/pdf/miners-ruler.pdf
This method implements the Probability Proportional to Size (PPS) sampling algorithm.
In few words, the algorithm places in a line all active stakes that have locked tokens for
at least `duration` periods; a staker is selected if an input point is within its stake.
For example:
```
Stakes: |----- S0 ----|--------- S1 ---------|-- S2 --|---- S3 ---|-S4-|----- S5 -----|
Points: ....R0.......................R1..................R2...............R3...........
```
In this case, Stakers 0, 1, 3 and 5 will be selected.
Only stakers which made a commitment to the current period (in the previous period) are used.
"""
def get_stakers_reservoir(self,
duration: int,
without: Iterable[ChecksumAddress] = [],
pagination_size: Optional[int] = None) -> 'StakersReservoir':
n_tokens, stakers_map = self.get_all_active_stakers(periods=duration,
pagination_size=pagination_size)

n_tokens, stakers_map = self.get_all_active_stakers(periods=duration, pagination_size=pagination_size)
self.log.debug(f"Got {len(stakers_map)} stakers with {n_tokens} total tokens")

# TODO: can be implemented as an iterator if necessary, where the user can
# sample addresses one by one without calling get_all_active_stakers() repeatedly.
for address in without:
del stakers_map[address]

if n_tokens == 0:
# TODO: or is it enough to just make sure the number of remaining stakers is non-zero?
if sum(stakers_map.values()) == 0:
raise self.NotEnoughStakers('There are no locked tokens for duration {}.'.format(duration))

if quantity > len(stakers_map):
raise self.NotEnoughStakers(f'Cannot sample {quantity} out of {len(stakers)} total stakers')

addresses = list(stakers_map.keys())
tokens = list(stakers_map.values())
sampler = WeightedSampler(addresses, tokens)

system_random = random.SystemRandom()
sampled_addresses = sampler.sample_no_replacement(system_random, quantity)

# Randomize the output to avoid the largest stakers always being the first in the list
system_random.shuffle(sampled_addresses) # inplace

self.log.debug(f"Sampled {len(addresses)} stakers: {list(sampled_addresses)}")

return sampled_addresses
return StakersReservoir(stakers_map)

@contract_api(CONTRACT_CALL)
def get_completed_work(self, bidder_address: ChecksumAddress) -> Work:
Expand Down Expand Up @@ -1584,7 +1551,10 @@ def sample_no_replacement(self, rng, quantity: int) -> list:
(does not mutate the object and only applies to the current invocation of the method).
"""

if quantity > len(self.totals):
if quantity == 0:
return []

if quantity > len(self):
raise ValueError("Cannot sample more than the total amount of elements without replacement")

totals = self.totals.copy()
Expand All @@ -1603,3 +1573,27 @@ def sample_no_replacement(self, rng, quantity: int) -> list:
totals[j] -= weight

return samples

def __len__(self):
return len(self.totals)


class StakersReservoir:

def __init__(self, stakers_map):
addresses = list(stakers_map.keys())
tokens = list(stakers_map.values())
self._sampler = WeightedSampler(addresses, tokens)
self._rng = random.SystemRandom()

def __len__(self):
return len(self._sampler)

def draw(self, quantity):
if quantity > len(self):
raise StakingEscrowAgent.NotEnoughStakers(f'Cannot sample {quantity} out of {len(self)} total stakers')

return self._sampler.sample_no_replacement(self._rng, quantity)

def draw_at_most(self, quantity):
return self.draw(min(quantity, len(self)))
9 changes: 5 additions & 4 deletions nucypher/network/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from twisted.internet import defer, reactor, task
from twisted.internet.threads import deferToThread
from twisted.logger import Logger
from typing import Set, Tuple, Union
from typing import Set, Tuple, Union, Iterable
from umbral.signing import Signature

import nucypher
Expand Down Expand Up @@ -604,9 +604,10 @@ def keep_learning_about_nodes(self):
# TODO: Allow the user to set eagerness? 1712
self.learn_from_teacher_node(eager=False)

def learn_about_specific_nodes(self, addresses: Set):
self._node_ids_to_learn_about_immediately.update(addresses) # hmmmm
self.learn_about_nodes_now()
def learn_about_specific_nodes(self, addresses: Iterable):
if len(addresses) > 0:
self._node_ids_to_learn_about_immediately.update(addresses) # hmmmm
self.learn_about_nodes_now()

# TODO: Dehydrate these next two methods. NRN

Expand Down
106 changes: 59 additions & 47 deletions nucypher/policy/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from bytestring_splitter import BytestringSplitter, VariableLengthBytestring
from constant_sorrow.constants import NOT_SIGNED, UNKNOWN_KFRAG
from twisted.logger import Logger
from typing import Generator, List, Set
from typing import Generator, List, Set, Optional
from umbral.keys import UmbralPublicKey
from umbral.kfrags import KFrag

Expand Down Expand Up @@ -381,7 +381,7 @@ def consider_arrangement(self, network_middleware, ursula, arrangement) -> bool:

def make_arrangements(self,
network_middleware: RestMiddleware,
handpicked_ursulas: Set[Ursula] = None,
handpicked_ursulas: Optional[Set[Ursula]] = None,
*args, **kwargs,
) -> None:

Expand All @@ -408,11 +408,12 @@ def make_arrangement(self, ursula: Ursula, *args, **kwargs):
raise NotImplementedError

@abstractmethod
def sample_essential(self, quantity: int, handpicked_ursulas: Set[Ursula] = None) -> Set[Ursula]:
def sample_essential(self, quantity: int, handpicked_ursulas: Set[Ursula]) -> Set[Ursula]:
raise NotImplementedError

def sample(self, handpicked_ursulas: Set[Ursula] = None) -> Set[Ursula]:
selected_ursulas = set(handpicked_ursulas) if handpicked_ursulas else set()
def sample(self, handpicked_ursulas: Optional[Set[Ursula]] = None) -> Set[Ursula]:
handpicked_ursulas = handpicked_ursulas if handpicked_ursulas else set()
selected_ursulas = set(handpicked_ursulas)

# Calculate the target sample quantity
target_sample_quantity = self.n - len(selected_ursulas)
Expand Down Expand Up @@ -475,11 +476,11 @@ def make_arrangements(self, *args, **kwargs) -> None:
"Pass them here as handpicked_ursulas.".format(self.n)
raise self.MoreKFragsThanArrangements(error) # TODO: NotEnoughUrsulas where in the exception tree is this?

def sample_essential(self, quantity: int, handpicked_ursulas: Set[Ursula] = None) -> Set[Ursula]:
def sample_essential(self, quantity: int, handpicked_ursulas: Set[Ursula]) -> Set[Ursula]:
known_nodes = self.alice.known_nodes
if handpicked_ursulas:
# Prevent re-sampling of handpicked ursulas.
known_nodes = set(known_nodes) - set(handpicked_ursulas)
known_nodes = set(known_nodes) - handpicked_ursulas
sampled_ursulas = set(random.sample(k=quantity, population=list(known_nodes)))
return sampled_ursulas

Expand Down Expand Up @@ -572,57 +573,68 @@ def generate_policy_parameters(n: int,
params = dict(rate=rate, value=value)
return params

def __find_ursulas(self,
ether_addresses: List[str],
target_quantity: int,
timeout: int = 10) -> set: # TODO #843: Make timeout configurable
def sample_essential(self,
quantity: int,
handpicked_ursulas: Set[Ursula],
learner_timeout: int = 1,
timeout: int = 10) -> Set[Ursula]:

start_time = maya.now() # marker for timeout calculation
selected_addresses = set(handpicked_ursulas)
quantity_remaining = quantity

found_ursulas, unknown_addresses = set(), deque()
while len(found_ursulas) < target_quantity: # until there are enough Ursulas
# Need to sample some stakers

delta = maya.now() - start_time # check for a timeout
if delta.total_seconds() >= timeout:
missing_nodes = ', '.join(a for a in unknown_addresses)
raise RuntimeError("Timed out after {} seconds; Cannot find {}.".format(timeout, missing_nodes))
reservoir = self.alice.get_stakers_reservoir(duration=self.duration_periods,
without=handpicked_ursulas)
if len(reservoir) < quantity_remaining:
error = f"Cannot create policy with {quantity} arrangements"
raise self.NotEnoughBlockchainUrsulas(error)

# Select an ether_address: Prefer the selection pool, then unknowns queue
if ether_addresses:
ether_address = ether_addresses.pop()
else:
ether_address = unknown_addresses.popleft()
to_check = reservoir.draw(quantity_remaining)

try:
# Check if this is a known node.
selected_ursula = self.alice.known_nodes[ether_address]
# Sample stakers in a loop and feed them to the learner to check
# until we have enough in `selected_addresses`.

except KeyError:
# Unknown Node
self.alice.learn_about_specific_nodes({ether_address}) # enter address in learning loop
unknown_addresses.append(ether_address)
continue
start_time = maya.now()
new_to_check = to_check

while True:

# Check if the sampled addresses are already known.
# If we're lucky, we won't have to wait for the learner iteration to finish.
known = list(filter(lambda x: x in self.alice.known_nodes, to_check))
to_check = list(filter(lambda x: x not in self.alice.known_nodes, to_check))

known = known[:min(len(known), quantity_remaining)] # we only need so many
selected_addresses.update(known)
quantity_remaining -= len(known)

if quantity_remaining == 0:
break
else:
# Known Node
found_ursulas.add(selected_ursula) # We already knew, or just learned about this ursula
new_to_check = reservoir.draw_at_most(quantity_remaining)
to_check.extend(new_to_check)

return found_ursulas
# Feed newly sampled stakers to the learner
self.alice.learn_about_specific_nodes(new_to_check)

def sample_essential(self, quantity: int, handpicked_ursulas: Set[Ursula] = None) -> Set[Ursula]:
# TODO: Prevent re-sampling of handpicked ursulas.
selected_addresses = set()
try:
sampled_addresses = self.alice.recruit(quantity=quantity,
duration=self.duration_periods)
except StakingEscrowAgent.NotEnoughStakers as e:
error = f"Cannot create policy with {quantity} arrangements: {e}"
raise self.NotEnoughBlockchainUrsulas(error)
# TODO: would be nice to wait for the learner to finish an iteration here,
# because if it hasn't, we really have nothing to do.
time.sleep(learner_timeout)

delta = maya.now() - start_time
if delta.total_seconds() >= timeout:
still_checking = ', '.join(to_check)
raise RuntimeError(f"Timed out after {timeout} seconds; "
f"need {quantity} more, still checking {still_checking}.")

found_ursulas = list(selected_addresses)

# Randomize the output to avoid the largest stakers always being the first in the list
system_random = random.SystemRandom()
system_random.shuffle(found_ursulas) # inplace

# Capture the selection and search the network for those Ursulas
selected_addresses.update(sampled_addresses)
found_ursulas = self.__find_ursulas(sampled_addresses, quantity)
return found_ursulas
return set(found_ursulas)

def publish_to_blockchain(self) -> dict:

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ def test_sampling_distribution(testerchain, token, deploy_contract, token_econom
sampled, failed = 0, 0
while sampled < SAMPLES:
try:
addresses = set(staking_agent.sample(quantity=quantity, duration=1))
reservoir = staking_agent.get_stakers_reservoir(duration=1)
addresses = set(reservoir.draw(quantity))
addresses.discard(NULL_ADDRESS)
except staking_agent.NotEnoughStakers:
failed += 1
Expand Down
17 changes: 11 additions & 6 deletions tests/acceptance/blockchain/agents/test_staking_escrow_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,20 +150,25 @@ def test_sample_stakers(agency):
_token_agent, staking_agent, _policy_agent = agency
stakers_population = staking_agent.get_staker_population()

reservoir = staking_agent.get_stakers_reservoir(duration=1)
with pytest.raises(StakingEscrowAgent.NotEnoughStakers):
staking_agent.sample(quantity=stakers_population + 1, duration=1) # One more than we have deployed
reservoir.draw(stakers_population + 1) # One more than we have deployed

stakers = staking_agent.sample(quantity=3, duration=5)
reservoir = staking_agent.get_stakers_reservoir(duration=5)
stakers = reservoir.draw(3)
assert len(stakers) == 3 # Three...
assert len(set(stakers)) == 3 # ...unique addresses

# Same but with pagination
stakers = staking_agent.sample(quantity=3, duration=5, pagination_size=1)
reservoir = staking_agent.get_stakers_reservoir(duration=5, pagination_size=1)
stakers = reservoir.draw(3)
assert len(stakers) == 3
assert len(set(stakers)) == 3
light = staking_agent.blockchain.is_light
staking_agent.blockchain.is_light = not light
stakers = staking_agent.sample(quantity=3, duration=5)

reservoir = staking_agent.get_stakers_reservoir(duration=5)
stakers = reservoir.draw(3)
assert len(stakers) == 3
assert len(set(stakers)) == 3
staking_agent.blockchain.is_light = light
Expand Down Expand Up @@ -261,13 +266,13 @@ def test_lock_restaking(agency, testerchain, test_registry):
staking_agent = ContractAgency.get_agent(StakingEscrowAgent, registry=test_registry)
current_period = staking_agent.get_current_period()
terminal_period = current_period + 2

assert staking_agent.is_restaking(staker_account)
assert not staking_agent.is_restaking_locked(staker_account)
receipt = staking_agent.lock_restaking(staker_account, release_period=terminal_period)
assert receipt['status'] == 1, "Transaction Rejected"
assert staking_agent.is_restaking_locked(staker_account)

testerchain.time_travel(periods=2) # Wait for re-staking lock to be released.
assert not staking_agent.is_restaking_locked(staker_account)

Expand Down

0 comments on commit b3a031f

Please sign in to comment.