Kl_Ucb implementation #1657

Open · wants to merge 31 commits into main (base)

Changes from all commits · 31 commits
6bf81df
v1test
W0lfgunbl00d Nov 5, 2024
6aa4869
Update rls.py
W0lfgunbl00d Nov 5, 2024
814fd8a
Added an v0 adpredictor
Mo3ad-S Nov 9, 2024
b309747
Added an v0 adpredictor
Mo3ad-S Nov 9, 2024
069cee9
Merge branch 'online-ml:main' into main
slach31 Nov 16, 2024
6a229d8
adpredictor algorithm
Mo3ad-S Nov 16, 2024
e0e6c75
add adpredictor
Mo3ad-S Nov 16, 2024
ff8c617
added an adpredictor function
Mo3ad-S Nov 16, 2024
67e7e14
remooved adpredictor here
Mo3ad-S Nov 16, 2024
6f43ec8
fixed bugs
Mo3ad-S Nov 17, 2024
89cd67e
fixed bugs
Mo3ad-S Nov 17, 2024
54a94e3
removed rls
Mo3ad-S Nov 17, 2024
4a7bc49
Fix test pre commit
Nov 17, 2024
7311788
Fixed imports
Nov 17, 2024
648b1a4
adjusted the adpredictor algorithm
Mo3ad-S Nov 26, 2024
322b924
Merge branch 'ad_predict' of https://github.com/slach31/riverIDLIB in…
Mo3ad-S Nov 26, 2024
dcb0f98
updated the rest of the project
Mo3ad-S Nov 26, 2024
2855d11
Merge branch 'online-ml:main' into ad_predict
slach31 Nov 26, 2024
1b82c24
modified defaultdict
Mo3ad-S Nov 26, 2024
11a0282
add KL-UCB skeleton
Nov 26, 2024
4d76a0f
added a first version of klucb
Mo3ad-S Nov 27, 2024
e518b8c
Merge branch 'online-ml:main' into KlUcb
slach31 Nov 27, 2024
df3beaa
added a new version of the algorithm
Mo3ad-S Nov 27, 2024
49aecc8
added a new version of the algorithm
Mo3ad-S Nov 27, 2024
65cb16c
added a first version of kl_ucb
Mo3ad-S Nov 27, 2024
7b54884
first round of tests KLUCB
Nov 27, 2024
d87fbaa
fix implementation KLUCB
Nov 27, 2024
165fbd3
Merge branch 'online-ml:main' into KlUcb
slach31 Nov 30, 2024
d5ca2ce
a new version of Kl_Ucb
Mo3ad-S Nov 30, 2024
9a28b0c
a new version of Kl_Ucb
Mo3ad-S Nov 30, 2024
0d60194
Final version of Kl_Ucb
Mo3ad-S Nov 30, 2024
2 changes: 2 additions & 0 deletions river/bandit/__init__.py
@@ -13,6 +13,7 @@
from .epsilon_greedy import EpsilonGreedy
from .evaluate import evaluate, evaluate_offline
from .exp3 import Exp3
from .kl_ucb import KLUCB
from .lin_ucb import LinUCBDisjoint
from .random import RandomPolicy
from .thompson import ThompsonSampling
@@ -31,4 +32,5 @@
"ThompsonSampling",
"UCB",
"RandomPolicy",
"KLUCB",
]
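
This export makes the new policy importable from `river.bandit`, which is the import path the docstring example in `kl_ucb.py` below relies on. A quick smoke test of the import, as a sketch assuming the package is installed from this branch:

from river.bandit import KLUCB

policy = KLUCB(n_arms=3, horizon=100, c=1)
print(type(policy).__name__)  # KLUCB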
196 changes: 196 additions & 0 deletions river/bandit/kl_ucb.py
@@ -0,0 +1,196 @@
from __future__ import annotations

import math
import random


class KLUCB:
"""

KL-UCB is an algorithm for solving the multi-armed bandit problem. It uses Kullback-Leibler (KL)
divergence to calculate upper confidence bounds (UCBs) for each arm. The algorithm aims to balance
exploration (trying different arms) and exploitation (selecting the best-performing arm) in a principled way.

Parameters
----------
n_arms (int):
The total number of arms available for selection.
horizon (int):
The total number of time steps or trials during which the algorithm will run.
c (float, default=0):
A scaling parameter for the confidence bound. Larger values promote exploration,
while smaller values favor exploitation.

Attributes
----------
arm_count (list[int]):
A list where each element tracks the number of times an arm has been selected.
rewards (list[float]):
A list where each element accumulates the total rewards received from pulling each arm.
t (int):
The current time step in the algorithm.

Methods
-------
update(arm, reward):
Updates the statistics for the selected arm based on the observed reward.

kl_divergence(p, q):
Computes the Kullback-Leibler (KL) divergence between probabilities `p` and `q`.
This measures how one probability distribution differs from another.

kl_index(arm):
Calculates the KL-UCB index for a specific arm using binary search to determine the upper bound.

pull_arm(arm):
Simulates pulling an arm by generating a reward based on the empirical mean reward for that arm.


Examples
--------

>>> import random
>>> from river.bandit import KLUCB

>>> n_arms = 3
>>> horizon = 100
>>> c = 1
>>> klucb = KLUCB(n_arms=n_arms, horizon=horizon, c=c)

>>> random.seed(42)

>>> def calculate_reward(arm):
... # Example: Bernoulli reward based on each arm's true probability (for testing)
... true_probabilities = [0.3, 0.5, 0.7] # Example probabilities for each arm
... return 1 if random.random() < true_probabilities[arm] else 0
>>> # Initialize tracking variables
>>> selected_arms = []
>>> total_reward = 0
>>> cumulative_rewards = []
>>> for t in range(1, horizon + 1):
... klucb.t = t
... indices = [klucb.kl_index(arm) for arm in range(n_arms)]
... chosen_arm = indices.index(max(indices))
... reward = calculate_reward(chosen_arm)
... klucb.update(chosen_arm, reward)
... selected_arms.append(chosen_arm)
... total_reward += reward
... cumulative_rewards.append(total_reward)


>>> print("Selected arms:", selected_arms)
Selected arms: [0, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]



>>> print("Cumulative rewards:", cumulative_rewards)
Cumulative rewards: [0, 1, 2, 3, 3, 3, 3, 4, 5, 6, 7, 7, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, 11, 12, 12, 13, 14, 15, 15, 16, 16, 16, 17, 17, 18, 19, 19, 20, 20, 20, 20, 21, 22, 23, 24, 25, 26, 27, 27, 28, 29, 30, 31, 31, 31, 31, 32, 32, 33, 34, 34, 34, 34, 35, 35, 35, 36, 37, 38, 39, 40, 40, 40, 41, 41, 42, 42, 42, 43, 44, 44, 45, 45, 45, 46, 47, 47, 48, 49, 50, 51, 52, 52, 53, 54, 55, 55, 56, 56, 56]



>>> print(f"Total Reward: {total_reward}")
Total Reward: 56

"""

def __init__(self, n_arms, horizon, c=0):
self.n_arms = n_arms
self.horizon = horizon
self.c = c
self.arm_count = [1 for _ in range(n_arms)]
self.rewards = [0.0 for _ in range(n_arms)]
self.t = 0

def update(self, arm, reward):
"""
Updates the number of times the arm has been pulled and the cumulative reward
for the given arm. Also increments the current time step.

Parameters
----------
arm (int): The index of the arm that was pulled.
reward (float): The reward obtained from pulling the arm.
"""
self.arm_count[arm] += 1
self.rewards[arm] += reward
self.t += 1

def kl_divergence(self, p, q):
"""
Computes the Kullback-Leibler (KL) divergence between two probabilities `p` and `q`.

Parameters
----------
p (float): The first probability (true distribution).
q (float): The second probability (approximated distribution).

Returns
-------
float: The KL divergence value. Returns infinity if `q` is not a valid probability.
"""

if p == 0:
return float("inf") if q >= 1 else -math.log(1 - q)
elif p == 1:
return float("inf") if q <= 0 else -math.log(q)
elif q <= 0 or q >= 1:
return float("inf")
return p * math.log(p / q) + (1 - p) * math.log((1 - p) / (1 - q))

def kl_index(self, arm):
"""
Computes the KL-UCB index for a given arm using binary search.
This determines the upper confidence bound for the arm.

Parameters
----------
arm (int): The index of the arm to compute the index for.

Returns
-------
float: The KL-UCB index for the arm.
"""

n_t = self.arm_count[arm]
if n_t == 0:
return float("inf") # Unseen arm
empirical_mean = self.rewards[arm] / n_t
log_t_over_n = math.log(self.t + 1) / n_t
c_factor = self.c * log_t_over_n

# Binary search to find the q that satisfies the KL-UCB condition
low = empirical_mean
high = 1.0
for _ in range(100): # Fixed number of iterations for binary search
mid = (low + high) / 2
kl = self.kl_divergence(empirical_mean, mid)
if kl > c_factor:
high = mid
else:
low = mid
return low

def pull_arm(self, arm):
"""
Simulates pulling an arm by generating a reward based on its empirical mean.

Parameters
----------
arm (int): The index of the arm to pull.

Returns
-------
int: 1 if the arm yields a reward, 0 otherwise.
"""
prob = self.rewards[arm] / self.arm_count[arm]
return 1 if random.random() < prob else 0

@staticmethod
def _unit_test_params():
"""
Returns a list of dictionaries with parameters to initialize the KLUCB class
for unit testing.
"""
return [
{"n_arms": 2, "horizon": 100, "c": 0.5},
{"n_arms": 5, "horizon": 1000, "c": 1.0},
{"n_arms": 10, "horizon": 500, "c": 0.1},
]
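
As the `kl_index` docstring explains, the binary search looks for the largest mean q at or above the empirical mean whose divergence from it stays within the exploration budget c * log(t + 1) / n. The sketch below checks that property against the class above; the counts, rewards and time step are set by hand purely for illustration and are not part of the PR.

import math

from river.bandit import KLUCB

policy = KLUCB(n_arms=2, horizon=100, c=1)

# Pretend arm 0 has been pulled 20 times for a total reward of 12 by time step 50.
policy.arm_count[0] = 20
policy.rewards[0] = 12.0
policy.t = 50

p_hat = policy.rewards[0] / policy.arm_count[0]                   # empirical mean: 0.6
budget = policy.c * math.log(policy.t + 1) / policy.arm_count[0]  # exploration budget

q = policy.kl_index(0)
print(round(q, 3))  # upper confidence bound for arm 0

# The returned index sits above the empirical mean and (approximately) saturates the budget.
assert q >= p_hat
assert abs(policy.kl_divergence(p_hat, q) - budget) < 1e-6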
2 changes: 2 additions & 0 deletions river/linear_model/__init__.py
@@ -3,6 +3,7 @@
from __future__ import annotations

from . import base
from .adpredictor import AdPredictor
from .alma import ALMAClassifier
from .bayesian_lin_reg import BayesianLinearRegression
from .lin_reg import LinearRegression
@@ -21,4 +22,5 @@
"PARegressor",
"Perceptron",
"SoftmaxRegression",
"AdPredictor",
]
156 changes: 156 additions & 0 deletions river/linear_model/adpredictor.py
@@ -0,0 +1,156 @@
from __future__ import annotations

import collections
import math

from river.base.classifier import Classifier


def default_mean():
return 0.0


def default_variance():
return 1.0


class AdPredictor(Classifier):
"""
AdPredictor is a machine learning algorithm designed to predict the probability of user
clicks on online advertisements. This algorithm plays a crucial role in computational advertising, where predicting
click-through rates (CTR) is essential for optimizing ad placements and maximizing revenue.
Parameters
----------
beta (float, default=0.1):
A smoothing parameter that regulates the weight updates. Smaller values allow for finer updates,
while larger values can accelerate convergence but may risk instability.
prior_probability (float, default=0.5):
The prior click probability. This value sets the bias weight and so determines the model's
predictions before any data has been observed.

epsilon (float, default=0.1):
A variance dynamics parameter that controls how the model balances prior knowledge and learned information.
Larger values prioritize prior knowledge, while smaller values favor data-driven updates.

num_features (int, default=10):
The maximum number of features the model can handle. This parameter affects scalability and efficiency,
especially for high-dimensional data.

Attributes
----------
means (defaultdict):
Maps each feature to the current estimate (mean) of its weight.
variances (defaultdict):
Maps each feature to the uncertainty (variance) associated with that weight estimate.

bias_weight (float):
The weight corresponding to the model bias, initialized using the prior_probability.
This attribute allows the model to make predictions even when no features are active.

Examples
--------

>>> from river.linear_model import AdPredictor
>>> adpredictor = AdPredictor(beta=0.1, prior_probability=0.5, epsilon=0.1, num_features=5)
>>> data = [
...     ({"feature1": 1, "feature2": 1}, 1),
...     ({"feature1": 1, "feature3": 1}, 0),
...     ({"feature2": 1, "feature4": 1}, 1),
...     ({"feature1": 1, "feature2": 1, "feature3": 1}, 0),
...     ({"feature4": 1, "feature5": 1}, 1),
... ]
>>> def train_and_test(model, data):
... for x, y in data:
... pred_before = model.predict_one(x)
... model.learn_one(x, y)
... pred_after = model.predict_one(x)
... print(f"Features: {x} | True label: {y} | Prediction before training: {pred_before:.4f} | Prediction after training: {pred_after:.4f}")

>>> train_and_test(adpredictor, data)
Features: {'feature1': 1, 'feature2': 1} | True label: 1 | Prediction before training: 0.5000 | Prediction after training: 0.7230
Features: {'feature1': 1, 'feature3': 1} | True label: 0 | Prediction before training: 0.6065 | Prediction after training: 0.3650
Features: {'feature2': 1, 'feature4': 1} | True label: 1 | Prediction before training: 0.6065 | Prediction after training: 0.7761
Features: {'feature1': 1, 'feature2': 1, 'feature3': 1} | True label: 0 | Prediction before training: 0.5455 | Prediction after training: 0.3197
Features: {'feature4': 1, 'feature5': 1} | True label: 1 | Prediction before training: 0.5888 | Prediction after training: 0.7699

"""

def __init__(self, beta=0.1, prior_probability=0.5, epsilon=0.1, num_features=10):
# Initialization of model parameters
self.beta = beta
self.prior_probability = prior_probability
self.epsilon = epsilon
self.num_features = num_features
# Per-feature weight posteriors: one defaultdict of means and one of variances

self.means = collections.defaultdict(default_mean)
self.variances = collections.defaultdict(default_variance)

# Initialize bias weight based on prior probability
self.bias_weight = self.prior_bias_weight()

def prior_bias_weight(self):
# Calculate initial bias weight using prior probability

return math.log(self.prior_probability / (1 - self.prior_probability)) / self.beta

def _active_mean_variance(self, features):
"""_active_mean_variance(features) (method):
Computes the cumulative mean and variance for all active features in a sample,
including the bias. This is crucial for making predictions."""
# Calculate total mean and variance for all active features

total_mean = sum(self.means[f] for f in features) + self.bias_weight
total_variance = sum(self.variances[f] for f in features) + self.beta**2
return total_mean, total_variance

def predict_one(self, x):
# Generate a probability prediction for one sample
features = x.keys()
total_mean, total_variance = self._active_mean_variance(features)
# Logistic sigmoid applied to mean/std, used here in place of the Gaussian CDF
return 1 / (1 + math.exp(-total_mean / math.sqrt(total_variance)))

def learn_one(self, x, y):
# Online learning step to update the model with one sample
features = x.keys()
y = 1 if y else -1
total_mean, total_variance = self._active_mean_variance(features)
v, w = self.gaussian_corrections(y * total_mean / math.sqrt(total_variance))

# Update mean and variance for each feature in the sample
for feature in features:
mean = self.means[feature]
variance = self.variances[feature]

mean_delta = y * variance / math.sqrt(total_variance) * v # Update mean
variance_multiplier = 1.0 - variance / total_variance * w # Update variance

# Update weight
self.means[feature] = mean + mean_delta
self.variances[feature] = variance * variance_multiplier

def gaussian_corrections(self, score):
"""gaussian_corrections(score) (method):
Implements Bayesian update corrections using the Gaussian probability density function (PDF)
and cumulative density function (CDF)."""
# CDF calculation for Gaussian correction
cdf = 1 / (1 + math.exp(-score))
pdf = math.exp(-0.5 * score**2) / math.sqrt(2 * math.pi) # PDF calculation
v = pdf / cdf # Correction factor for mean update
w = v * (v + score) # Correction factor for variance update
return v, w

def _apply_dynamics(self, feature):
"""Regularize a feature weight by blending its learned mean and variance with the prior
(mean 0, variance 1), keeping a balance between prior beliefs and observed data.
Returns the adjusted (mean, variance) pair for the given feature."""
prior_mean = 0.0
prior_variance = 1.0
mean = self.means[feature]
variance = self.variances[feature]
# Blend the learned variance with the prior variance
adjusted_variance = (
variance
* prior_variance
/ ((1.0 - self.epsilon) * prior_variance + self.epsilon * variance)
)
# Blend the learned mean with the prior mean (zero), weighted by the adjusted variance
adjusted_mean = adjusted_variance * (
(1.0 - self.epsilon) * mean / variance
+ self.epsilon * prior_mean / prior_variance
)
return adjusted_mean, adjusted_variance