Partial review done #2

Open · wants to merge 1 commit into base: master
10 changes: 7 additions & 3 deletions adaptive_softmax/bandits_softmax.py
@@ -18,6 +18,7 @@ def generate_weighted_permutation(weights: np.ndarray, gen=np.random.default_rng
logits = np.log(weights) - np.log(np.sum(weights))
perturbed_logits = logits + gen.gumbel(size=logits.size)
permutation = perturbed_logits.argsort()[::-1]

return permutation, logits, perturbed_logits

class BanditsSoftmax:
@@ -65,8 +66,9 @@ def __init__(
query_importance_sampling=True,
randomized_hadamard_transform=False,
verbose=False,
seed=42):

seed=42,
):
# TODO(colin): Why are so many class members prefixed with underscores?
assert len(A.shape) == 2, 'A must be a 2D array'

self.n = A.shape[0]
Expand Down Expand Up @@ -94,7 +96,7 @@ def __init__(
self._permutation, self._logits, self._perturbed_logits = generate_weighted_permutation(self._atom_weights, gen=self._gen)

q = (self._atom_weights / (np.sum(self._atom_weights)) )[np.newaxis, :]
q[q == 0 | np.isnan(q)] = 1 # NOTE 0-weight columns will never be selected
q[q == 0 | np.isnan(q)] = 1 # NOTE 0-weight columns will never be selected
self._est_atom_sig2 = np.max(np.sum((self._A / q / self.d) ** 2 * q, axis=1))
self._est_query_sig2 = None
self._sparse_columns = None
@@ -111,8 +113,10 @@ def __init__(
print(f'Query importance sampling: {self.query_importance_sampling}')
print(f'Randomized Hadamard transform: {self.randomized_hadamard_transform}')
print(f'Permutation:\n{self._permutation}')

if atom_importance_sampling:
print(f'Atom weights:\n{self._atom_weights}')

if randomized_hadamard_transform:
print(f'Columns 0-padded: {A.shape[1]} --> {self.d}')

159 changes: 84 additions & 75 deletions adaptive_softmax/sftm.py
@@ -1,6 +1,6 @@
import numpy as np
from typing import Tuple
from math import log, ceil, sqrt, exp
from math import log, ceil, sqrt
from scipy.special import logsumexp, softmax

from adaptive_softmax.bandits_softmax import BanditsSoftmax
@@ -18,48 +18,46 @@ class SFTM:
A : np.ndarray
The atom matrix A of shape (n, d) for the matrix-vector multiplication.
temperature : float, optional
The temperature parameter for the softmax function, by default 1.0.
The temperature parameter for the softmax function (default 1.0).
multiplicative_error : float, optional
The multiplicative error parameter for the PAC guarantee, by default 3e-1.
The multiplicative error parameter for the PAC guarantee, epsilon (default 3e-1).
failure_probability : float, optional
The failure probability parameter for the PAC guarantee, by default 1e-1.
The failure probability parameter for the PAC guarantee, delta (default 1e-1).
noise_bound : float, optional
The noise bound parameter for entries of the matrix-vector multiplication,
by default None.
The noise bound parameter for entries of the matrix-vector multiplication (default None).
fudge_pull : float, optional
The multiplier for the number of pulls used in the bandits algorithm to
account for loose bounds, by default 1.0.
account for loose bounds (default 1.0).
fudge_sigma2 : float, optional
The multiplier for the variance used in the bandits algorithm to account
for loose bounds, by default 1.0.
for loose bounds (default 1.0).
atom_importance_sampling : bool, optional
The flag to enable atom-based importance sampling in the bandits algorithm,
by default True.
The flag to enable atom-based importance sampling in the bandits algorithm (default True).
query_importance_sampling : bool, optional
The flag to enable query-based importance sampling in the bandits algorithm,
by default True.
The flag to enable query-based importance sampling in the bandits algorithm (default True).
randomized_hadamard_transform : bool, optional
The flag to enable randomized Hadamard transform of the atom matrix A
The flag to enable randomized Hadamard transform of the atom matrix A (default False).
verbose : bool, optional
The flag to enable verbose output, by default False.
The flag to enable verbose output (default False).
seed : int, optional
The seed for the random number generator used in the bandits algorithm, by
default 42.
The seed for the random number generator used in the bandits algorithm (default 42).
"""

def __init__(self,
A: np.ndarray,
temperature: float = 1.0,
multiplicative_error: float = 3e-1,
failure_probability: float = 1e-1,
noise_bound: float = None,
fudge_pull: float = 1.0,
fudge_sigma2: float = 1.0,
atom_importance_sampling: bool = True,
query_importance_sampling: bool = True,
randomized_hadamard_transform: bool = False,
verbose: bool = False,
seed=42):
def __init__(
self,
A: np.ndarray,
temperature: float = 1.0,
multiplicative_error: float = 3e-1,
failure_probability: float = 1e-1,
noise_bound: float = None,
fudge_pull: float = 1.0,
fudge_sigma2: float = 1.0,
atom_importance_sampling: bool = True,
query_importance_sampling: bool = True,
randomized_hadamard_transform: bool = False,
verbose: bool = False,
seed=42
):
self.A = A
self.n = A.shape[0]
self.d = A.shape[1]
@@ -70,11 +68,11 @@ def __init__(self,
self.verbose = verbose

if self.verbose:
print(f"Initializing SFTM for a matrix of shape ({self.n}, {self.d})...")
print(f"Initializing SFTM for a matrix of shape ({self.n} x {self.d})...")
print("Parameters:")
print(f"\t-temperature: {self.temperature}")
print(f"\t-multiplicative_error: {self.multiplicative_error}")
print(f"\t-failure_probability: {self.failure_probability}")
print(f"\t- temperature: {self.temperature}")
print(f"\t- multiplicative_error: {self.multiplicative_error}")
print(f"\t- failure_probability: {self.failure_probability}")

self.bandits = BanditsSoftmax(
A,
@@ -92,19 +90,19 @@ def __init__(self,
print("SFTM initialized.")
print("")

def softmax(self, x: np.ndarray, k: int=1) -> Tuple[np.ndarray, np.ndarray]:
def softmax(self, x: np.ndarray, k: int = 1) -> Tuple[np.ndarray, np.ndarray]:
"""
Computes the true softmax, returning the top-k indices and the softmax.

@param x: The query vector x of shape (d,).
@param k: The number of elements to return, by default 1.
@param k: The number of elements to return (default 1).
@return: The top-k indices and the softmax.
"""
mu = (self.A @ x) * self.temperature
top_k = np.sort(np.argpartition(mu, -k)[-k:])
return top_k, softmax(mu)

def adaptive_softmax(self, x: np.ndarray, k: int=1) -> Tuple[int, float]:
def adaptive_softmax(self, x: np.ndarray, k: int = 1) -> Tuple[int, float]:
"""
Computes the approximate softmax using the SFTM algorithm, returning the
top-k indices, the approximate softmax for these indices, and the
@@ -114,45 +112,48 @@ def adaptive_softmax(self, x: np.ndarray, k: int=1) -> Tuple[int, float]:
Efficient Softmax Approximation."

@param x: The query vector x of shape (d,).
@param k: The number of elements to return, by default 1.
@param k: The number of elements to return (default 1).
@return: The top-k indices, the approximate softmax, and the normalizing
constant.
constant Z.
"""

if self.verbose:
print(f"Computing adaptive softmax for query vector {x}...")

self.bandits.set_query(x)

bta = self.temperature
beta = self.temperature
eps = self.multiplicative_error
dlt = self.failure_probability
delta = self.failure_probability
sig2 = self.noise_bound if self.noise_bound is not None else self.bandits.variance

if self.verbose:
print(f"Noise bound: {sig2}")

i_star_hat = self.best_arms(dlt/2, bta, sig2, k)
# TODO(@colin): Did we decide whether this should be delta/2 or delta?
i_star_hat = self.best_arms(delta/2, beta, sig2, k)
# TODO(@colin): if i_star_hat is wrong, won't mu_star_hat also be the wrong value?
mu_star_hat = self.bandits.exact_values(i_star_hat)
log_S_hat = self.log_norm_estimation(bta, eps, dlt/2, sig2)
# TODO(@colin): Did we decide whether this should be delta/2 or delta?
log_S_hat = self.log_norm_estimation(beta, eps, delta/2, sig2)

if self.verbose:
print(f"Top-{k} arms: {i_star_hat}")
print(f"Estimated logit values: {mu_star_hat}")
print(f"Estimated log normalizing constant: {log_S_hat}")

return i_star_hat, np.exp(bta * mu_star_hat - log_S_hat), np.exp(log_S_hat)
def best_arms(self, dlt: float, bta: float, sig2: float, k: int) -> np.ndarray:
return i_star_hat, np.exp(beta * mu_star_hat - log_S_hat), np.exp(log_S_hat)

def best_arms(self, delta: float, beta: float, sig2: float, k: int) -> np.ndarray:
"""
Finds the top-k arms with the highest estimated logit values.

This method uses a round-based PAC bandits algorithm based on Algorithm 3
from the paper, "Distributed Exploration in Multi-Armed Bandits" by Hillel
et al. (2013).
@param dlt: The failure probability parameter.
@param bta: The temperature parameter.

@param delta: The failure probability parameter.
@param beta: The temperature parameter.
@param sig2: The noise bound parameter.
@param k: The number of arms to return.
@return: The top-k arms with the highest estimated logit values.
@@ -162,25 +163,31 @@ def best_arms(self, dlt: float, bta: float, sig2: float, k: int) -> np.ndarray:

n = self.n
d = self.bandits.max_pulls
T0 = int(ceil(17 * (bta ** 2) * sig2 * log(6 * n / dlt)))
T0 = int(ceil(17 * (beta ** 2) * sig2 * log(6 * n / delta)))

if self.verbose:
print(f"Initial number of pulls: {T0}")

# initialize parameters
confidence_set = np.arange(n)
num_pulls = T0
estimates = np.zeros(n)

while True:

keep_pulling = True
while keep_pulling is True:
# pull arms and update confidence interval
estimates = self.bandits.batch_pull(confidence_set, it=fpc(num_pulls, d))
confidence_interval = sqrt(2 * sig2 * log(6 * n * log(d) / dlt) / num_pulls)
confidence_interval = sqrt(2 * sig2 * log(6 * n * log(d) / delta) / num_pulls)

# update confidence set
keep = estimates >= np.max(estimates) - confidence_interval

# TODO(@colin): I don't think this is exactly correct. It may be the case that an arm is
# removed at some point, but then np.max(estimates) moves down and the arm gets added back later.
# The current implementation would say that arm has been pulled num_pulls times, but it's been pulled
# fewer times. For this reason, I think it's actually better to make num_pulls an *array* of how many
# times each arm has been pulled, and then update the confidence interval for each arm separately according
# to its number of pulls. This is how we did it in several other projects, see BanditPAM, FastForest, or BanditMIPS
# for examples.

Comment on lines +183 to +190
Contributor:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Tavor pointed me to the algorithm I used (Algorithm 3 from "Distributed Exploration in Multi-Armed Bandits" by Hillel et al., 2013), which eliminates arms entirely once they fall outside the confidence interval (CI). I checked with Tavor whether our current CIs (which differ from those used in the paper) are compatible with this elimination step, and he said yes, but there may have been a miscommunication.
  2. There is already an array in bandits_softmax.py storing how many times each arm has truly been pulled. However, all arms in the confidence set at a given round have been pulled the same number of times and therefore share the same confidence interval, because the bandits algorithm runs before log norm estimation or any other operations.
  3. There may be confusion about why the true number of arm pulls differs from the value in num_pulls. This is because the confidence interval itself is not finite-population-corrected (FPC). Instead, we keep the confidence interval the same and adjust the number of pulls actually made to however many would be needed to achieve that interval (roughly the same as num_pulls while it is below d / 2, but much lower past that point); see the sketch below. While this is not a bug, I did change this behavior in a new branch to allow true exponential growth of the number of arm pulls (previously, FPC made the true pulls approach d quite slowly).
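
For readers following along, here is a minimal sketch of the pull-adjustment idea in point 3. The actual `fpc` helper is defined elsewhere in the repo and is not shown in this diff, so the formula below (a standard without-replacement variance correction, solved for the true pull count) is an assumption used only for illustration:

```python
import numpy as np

def fpc_sketch(num_pulls: int, d: int) -> int:
    """Hypothetical stand-in for the repo's `fpc` helper (not shown in this diff).

    Keep the uncorrected confidence interval c / sqrt(num_pulls), and return the
    number of true pulls t whose FPC-corrected interval,
    c / sqrt(t) * sqrt((d - t) / (d - 1)), matches it. Solving gives
    t = num_pulls * d / (num_pulls + d - 1): close to num_pulls when num_pulls is
    small relative to d, and saturating near d as num_pulls grows.
    """
    t = num_pulls * d / (num_pulls + d - 1)
    return min(int(np.ceil(t)), d)

# Example with d = 10_000 columns: small budgets are barely corrected,
# large nominal budgets saturate at d true pulls.
for nominal in (100, 5_000, 100_000, 10_000_000):
    print(nominal, fpc_sketch(nominal, 10_000))
```

Under this assumed form, the stopping condition `fpc(num_pulls, d) >= d` in best_arms corresponds to the point where the corrected pull count has effectively reached the full dimension d.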

if self.verbose:
print(f"Number of pulls: {num_pulls}")
print(f"FPC-adjusted number of pulls: {fpc(num_pulls, d)}")
@@ -190,6 +197,7 @@ def best_arms(self, dlt: float, bta: float, sig2: float, k: int) -> np.ndarray:

# check stopping condition
if np.sum(keep) <= k or fpc(num_pulls, d) >= d:
keep_pulling = False
break

# update parameters
@@ -198,7 +206,7 @@ def best_arms(self, dlt: float, bta: float, sig2: float, k: int) -> np.ndarray:

return confidence_set[np.argsort(estimates)[-k:]]

def estimate_arm_logits(self, arms: np.ndarray, bta: float, eps: float, dlt: float, sig2: float) -> np.ndarray:
def estimate_arm_logits(self, arms: np.ndarray, beta: float, eps: float, delta: float, sig2: float) -> np.ndarray:
"""
Estimates the logit values of the specified arms with PAC guarantees.

@@ -207,59 +215,60 @@ def estimate_arm_logits(self, arms: np.ndarray, bta: float, eps: float, dlt: flo
paper.

@param arms: The indices of the arms to estimate.
@param bta: The temperature parameter.
@param eps: The multiplicative error parameter.
@param dlt: The failure probability parameter.
@param beta: The temperature parameter.
@param eps: The multiplicative error parameter.
@param delta: The failure probability parameter.
@param sig2: The noise bound parameter.
@return: The estimated logit values of the specified arms.
"""
if self.verbose:
print(f"Estimating logit values for arms {arms}...")

d = self.bandits.max_pulls
T = int(ceil(32 * (sig2) * (bta ** 2) * log(2 / dlt) / (eps ** 2)))
T = int(ceil(32 * sig2 * (beta ** 2) * log(2 / delta) / (eps ** 2)))
return self.bandits.pull(arms, its=np.array(fpc(T, d)))
def log_norm_estimation(self, bta: float, eps: float, dlt: float, sig2: float) -> float:

def log_norm_estimation(self, beta: float, eps: float, delta: float, sig2: float) -> float:
"""
Estimates the log normalizing constant of the softmax function with PAC
Estimates the log normalizing constant of the softmax function with PAC
guarantees.

This method is based on Algorithm 2 of the paper, "Adaptive Sampling for
Efficient Softmax Approximation."

@param bta: The temperature parameter.
@param beta: The temperature parameter.
@param eps: The multiplicative error parameter.
@param dlt: The failure probability parameter.
@param delta: The failure probability parameter.
@param sig2: The noise bound parameter.
@return: The estimated log normalizing constant of the softmax function.
"""

n = self.n
d = self.bandits.max_pulls

T0 = int(ceil(17 * (bta ** 2) * sig2 * log(6 * n / dlt)))
C = np.sqrt(2 * sig2 * log(6 * n / dlt) / T0)
T0 = int(ceil(17 * (beta ** 2) * sig2 * log(6 * n / delta)))
C = np.sqrt(2 * sig2 * log(6 * n / delta) / T0)

if self.verbose:
print("Estimating log normalizing constant of the softmax function...")
print(f"Initial number of pulls: {T0}")
print(f"Confidence interval constant: {C}")

# initial estimates
mu_hat = self.bandits.pull(np.arange(n), its=np.full(shape=n, fill_value=fpc(T0, d)))

if self.verbose:
print(f"Initial estimates: {mu_hat}")

log_alpha = bta * (mu_hat - C)
log_gamma = bta * (mu_hat - C) / 2
log_alpha = beta * (mu_hat - C)
log_gamma = beta * (mu_hat - C) / 2
log_alpha_sum = logsumexp(log_alpha)
log_gamma_sum = logsumexp(log_gamma)

# adapt sample sizes based on initial estimates
log_b = log(17 * (bta ** 2) * sig2 * log(6 * n / dlt))
log_c = log(16 * sqrt(2) * sig2 * log(6 * n / dlt) / eps) + 2 * log_gamma_sum - log_alpha_sum
log_d = log(16 * sig2 * log(12 / dlt) / (eps ** 2))
log_b = log(17 * (beta ** 2) * sig2 * log(6 * n / delta))
log_c = log(16 * sqrt(2) * sig2 * log(6 * n / delta) / eps) + 2 * log_gamma_sum - log_alpha_sum
log_d = log(16 * sig2 * log(12 / delta) / (eps ** 2))

it = np.exp(log_b)
it = np.maximum(it, np.exp(log_c + log_gamma - log_gamma_sum))
@@ -269,12 +278,12 @@ def log_norm_estimation(self, bta: float, eps: float, dlt: float, sig2: float) -
if self.verbose:
print(f"Adaptive sample sizes: {it}")

# make updated estimates
# make updated estimates
mu_hat = self.bandits.pull(np.arange(n), its=fpc(it, d))

if self.verbose:
print(f"Updated estimates: {mu_hat}")
print(f"Estimated log normalizing constant: {logsumexp(bta * mu_hat)}")
print(f"Estimated log normalizing constant: {logsumexp(beta * mu_hat)}")

return logsumexp(bta * mu_hat)
return logsumexp(beta * mu_hat)
