Skip to content

Commit

Permalink
Simplify staker sampling and add unit tests for proper sampling distribution
Browse files Browse the repository at this point in the history
  • Loading branch information
fjarri committed Jun 8, 2020
1 parent 14bb63a commit 44cb58b
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 46 deletions.
102 changes: 64 additions & 38 deletions nucypher/blockchain/eth/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
along with nucypher. If not, see <https://www.gnu.org/licenses/>.
"""

from bisect import bisect_right
from itertools import accumulate
import random

import math
import sys
from constant_sorrow.constants import ( # type: ignore
Expand Down Expand Up @@ -237,7 +238,7 @@ class NotEnoughStakers(Exception):
def get_staker_population(self) -> int:
    """
    Return the number of stakers on the blockchain.

    The value is read by calling the contract's ``getStakersLength`` function.
    """
    stakers_length_call = self.contract.functions.getStakersLength()
    population = stakers_length_call.call()
    return population

@contract_api(CONTRACT_CALL)
def get_current_period(self) -> Period:
"""Returns the current period"""
Expand Down Expand Up @@ -321,7 +322,7 @@ def get_all_locked_tokens(self, periods: int, pagination_size: Optional[int] = N
#
# StakingEscrow Contract API
#

@contract_api(CONTRACT_CALL)
def get_global_locked_tokens(self, at_period: Optional[Period] = None) -> NuNits:
"""
Expand Down Expand Up @@ -435,7 +436,7 @@ def batch_deposit(self,
dry_run: bool = False,
gas_limit: Optional[Wei] = None
) -> Union[TxReceipt, Wei]:

min_gas_batch_deposit: Wei = Wei(250_000) # TODO: move elsewhere?
if gas_limit and gas_limit < min_gas_batch_deposit:
raise ValueError(f"{gas_limit} is not enough gas for any batch deposit")
Expand Down Expand Up @@ -618,7 +619,7 @@ def set_snapshots(self, staker_address: ChecksumAddress, activate: bool) -> TxRe
@contract_api(CONTRACT_CALL)
def staking_parameters(self) -> StakingEscrowParameters:
parameter_signatures = (

# Period
'secondsPerPeriod', # Seconds in single period

Expand Down Expand Up @@ -661,15 +662,11 @@ def swarm(self) -> Iterable[ChecksumAddress]:
def sample(self,
quantity: int,
duration: int,
additional_ursulas: float = 1.5,
attempts: int = 5,
pagination_size: Optional[int] = None
) -> List[ChecksumAddress]:
"""
Select n random Stakers, according to their stake distribution.
The returned addresses are shuffled, so one can request more than needed and
throw away those which do not respond.
The returned addresses are shuffled.
See full diagram here: https://github.com/nucypher/kms-whitepaper/blob/master/pdf/miners-ruler.pdf
Expand All @@ -688,42 +685,30 @@ def sample(self,
Only stakers which made a commitment to the current period (in the previous period) are used.
"""

system_random = random.SystemRandom()
n_tokens, stakers_map = self.get_all_active_stakers(periods=duration, pagination_size=pagination_size)

# TODO: can be implemented as an iterator if necessary, where the user can
# sample addresses one by one without calling get_all_active_stakers() repeatedly.

if n_tokens == 0:
raise self.NotEnoughStakers('There are no locked tokens for duration {}.'.format(duration))

sample_size = quantity
for _ in range(attempts):
sample_size = math.ceil(sample_size * additional_ursulas)
points = sorted(system_random.randrange(n_tokens) for _ in range(sample_size))
self.log.debug(f"Sampling {sample_size} stakers with random points: {points}")
if quantity > len(stakers_map):
raise self.NotEnoughStakers(f'Cannot sample {quantity} out of {len(stakers)} total stakers')

addresses = set()
stakers = list(stakers_map.items())
addresses = list(stakers_map.keys())
tokens = list(stakers_map.values())
sampler = WeightedSampler(addresses, tokens)

point_index = 0
sum_of_locked_tokens = 0
staker_index = 0
stakers_len = len(stakers)
while staker_index < stakers_len and point_index < sample_size:
current_staker = stakers[staker_index][0]
staker_tokens = stakers[staker_index][1]
next_sum_value = sum_of_locked_tokens + staker_tokens
system_random = random.SystemRandom()
sampled_addresses = sampler.sample_no_replacement(system_random, quantity)

point = points[point_index]
if sum_of_locked_tokens <= point < next_sum_value:
addresses.add(to_checksum_address(current_staker))
point_index += 1
else:
staker_index += 1
sum_of_locked_tokens = next_sum_value
# Randomize the output to avoid the largest stakers always being the first in the list
system_random.shuffle(sampled_addresses) # inplace

self.log.debug(f"Sampled {len(addresses)} stakers: {list(addresses)}")
if len(addresses) >= quantity:
return system_random.sample(addresses, quantity)
self.log.debug(f"Sampled {len(addresses)} stakers: {list(sampled_addresses)}")

raise self.NotEnoughStakers('Selection failed after {} attempts'.format(attempts))
return sampled_addresses

@contract_api(CONTRACT_CALL)
def get_completed_work(self, bidder_address: ChecksumAddress) -> Work:
Expand Down Expand Up @@ -1229,7 +1214,7 @@ def withdraw_compensation(self, checksum_address: ChecksumAddress) -> TxReceipt:
def check_claim(self, checksum_address: ChecksumAddress) -> bool:
    """
    Report whether the given address has claimed.

    Calls the contract's ``workInfo`` function for ``checksum_address`` and
    returns the third field of the result (index 2) coerced to ``bool``.
    NOTE(review): index 2 is presumably the "claimed" flag of the work info
    record -- confirm against the contract ABI.
    """
    work_info = self.contract.functions.workInfo(checksum_address).call()
    return bool(work_info[2])

#
# Internal
#
Expand Down Expand Up @@ -1577,3 +1562,44 @@ def get_agent_by_contract_name(cls,
agent_class: Type[EthereumContractAgent] = getattr(agents_module, agent_name)
agent: EthereumContractAgent = cls.get_agent(agent_class=agent_class, registry=registry, provider_uri=provider_uri)
return agent


class WeightedSampler:
    """
    Samples random elements with probabilities proportional to given weights.
    """

    def __init__(self, elements: Iterable, weights: Iterable[int]):
        """
        Prepares the cumulative weight table used for sampling.

        ``elements`` and ``weights`` are matched by position and must contain
        the same number of items; a ``ValueError`` is raised otherwise.
        Weights are expected to be non-negative integers.
        """
        # Materialize both inputs: the signature accepts any iterable
        # (generators, dict views, ...), but we need `len()` here and
        # positional indexing in `sample_no_replacement()`.
        elements = list(elements)
        weights = list(weights)
        if len(elements) != len(weights):
            raise ValueError(f"The number of elements ({len(elements)}) does not match "
                             f"the number of weights ({len(weights)})")

        # totals[i] is the sum of weights[0..i]; element `i` owns the
        # half-open interval [totals[i-1], totals[i]) of "token positions".
        self.totals = list(accumulate(weights))
        self.elements = elements

    def sample_no_replacement(self, rng, quantity: int) -> list:
        """
        Samples ``quantity`` of elements from the internal array.
        The probability of an element to appear is proportional
        to the weight provided to the constructor.
        The elements will not repeat; every time an element is sampled its weight is set to 0.
        (does not mutate the object and only applies to the current invocation of the method).

        ``rng`` must provide a ``randint(a, b)`` method returning an integer
        in the inclusive range ``[a, b]`` (e.g. ``random.SystemRandom()``).

        Raises ``ValueError`` if ``quantity`` exceeds the number of elements,
        or if the elements with non-zero weight run out before ``quantity``
        elements have been drawn.
        """

        if quantity > len(self.totals):
            raise ValueError("Cannot sample more than the total amount of elements without replacement")

        # Work on a copy so that repeated invocations are independent.
        totals = self.totals.copy()
        samples = []

        for _ in range(quantity):
            remaining_weight = totals[-1]
            if remaining_weight <= 0:
                # Only zero-weight elements are left; they can never be drawn.
                raise ValueError("Cannot sample any more elements: all the remaining weights are zero")

            position = rng.randint(0, remaining_weight - 1)
            # `bisect_right` finds the first cumulative total strictly greater
            # than `position`, which skips zero-weight (and already sampled) elements.
            idx = bisect_right(totals, position)
            samples.append(self.elements[idx])

            # Adjust the totals so that they correspond
            # to the weight of the element `idx` being set to 0.
            prev_total = totals[idx - 1] if idx > 0 else 0
            weight = totals[idx] - prev_total
            for j in range(idx, len(totals)):
                totals[j] -= weight

        return samples
7 changes: 1 addition & 6 deletions nucypher/policy/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,6 @@ def __init__(self,
self.treasure_map = TreasureMap(m=m)
self.expiration = expiration

# Keep track of this stuff
self.selection_buffer = 1

self._accepted_arrangements = set() # type: Set[Arrangement]
self._rejected_arrangements = set() # type: Set[Arrangement]
self._spare_candidates = set() # type: Set[Ursula]
Expand Down Expand Up @@ -532,7 +529,6 @@ def __init__(self,

super().__init__(alice=alice, expiration=expiration, *args, **kwargs)

self.selection_buffer = 1.5
self.validate_fee_value()

def validate_fee_value(self) -> None:
Expand Down Expand Up @@ -618,8 +614,7 @@ def sample_essential(self, quantity: int, handpicked_ursulas: Set[Ursula] = None
selected_addresses = set()
try:
sampled_addresses = self.alice.recruit(quantity=quantity,
duration=self.duration_periods,
additional_ursulas=self.selection_buffer)
duration=self.duration_periods)
except StakingEscrowAgent.NotEnoughStakers as e:
error = f"Cannot create policy with {quantity} arrangements: {e}"
raise self.NotEnoughBlockchainUrsulas(error)
Expand Down
44 changes: 42 additions & 2 deletions tests/acceptance/blockchain/agents/test_sampling_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
"""

from collections import Counter
from itertools import permutations
import random

import pytest

from nucypher.blockchain.economics import BaseEconomics
from nucypher.blockchain.eth.agents import StakingEscrowAgent
from nucypher.blockchain.eth.agents import StakingEscrowAgent, WeightedSampler
from nucypher.blockchain.eth.constants import NULL_ADDRESS, STAKING_ESCROW_CONTRACT_NAME


Expand Down Expand Up @@ -115,7 +117,7 @@ def test_sampling_distribution(testerchain, token, deploy_contract, token_econom
sampled, failed = 0, 0
while sampled < SAMPLES:
try:
addresses = set(staking_agent.sample(quantity=quantity, additional_ursulas=1, duration=1))
addresses = set(staking_agent.sample(quantity=quantity, duration=1))
addresses.discard(NULL_ADDRESS)
except staking_agent.NotEnoughStakers:
failed += 1
Expand All @@ -134,3 +136,41 @@ def test_sampling_distribution(testerchain, token, deploy_contract, token_econom
assert abs_error < ERROR_TOLERANCE

# TODO: Test something wrt to % of failed


def probability_reference_no_replacement(weights, idxs):
    """
    Reference probability of drawing, without replacement and in exactly
    the given order, the elements at the (distinct) positions ``idxs``
    when each element is picked proportionally to its entry in ``weights``.
    """
    assert len(idxs) == len(set(idxs))

    remaining_weight = sum(weights)
    probability = 1
    for chosen in idxs:
        chosen_weight = weights[chosen]
        # Chance of picking this element among those still available...
        probability = probability * (chosen_weight / remaining_weight)
        # ...after which it is no longer available for the next draw.
        remaining_weight = remaining_weight - chosen_weight
    return probability


@pytest.mark.parametrize('sample_size', [1, 2, 3])
def test_weighted_sampler(sample_size):
    """
    Draws many ordered samples of ``sample_size`` elements and checks that the
    observed frequency of every possible ordered outcome is close to the
    reference probability of sampling without replacement.
    """
    weights = [1, 9, 100, 2, 18, 70]
    elements = list(range(len(weights)))
    samples = 100000

    rng = random.SystemRandom()
    sampler = WeightedSampler(elements, weights)

    # Count how many times each ordered outcome was produced.
    counter = Counter(
        tuple(sampler.sample_no_replacement(rng, sample_size))
        for _ in range(samples)
    )

    scale = samples**0.5
    for idxs in permutations(elements, sample_size):
        test_prob = counter[idxs] / samples
        ref_prob = probability_reference_no_replacement(weights, idxs)

        # A rough estimate to check probabilities.
        # A little too forgiving for samples with smaller probabilities,
        # but can go up to 0.5 on occasion.
        assert abs(test_prob - ref_prob) * scale < 1

0 comments on commit 44cb58b

Please sign in to comment.