Skip to content

Commit

Permalink
Vectorization of cMAB Predict (#61)
Browse files Browse the repository at this point in the history
### Changes
 * Changed the predict method of BaseCmabBernoulli in cmab.py to a vectorized version that samples the whole context matrix at once, instead of looping over rows.
shaharbar1 authored Sep 24, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent dfaab15 commit 9c15f78
Showing 2 changed files with 27 additions and 28 deletions.
53 changes: 26 additions & 27 deletions pybandits/cmab.py
Original file line number Diff line number Diff line change
@@ -123,33 +123,32 @@ def predict(
probs = len(context) * [{k: 0.5 for k in valid_actions}] # all probs are set to 0.5
weighted_sums = len(context) * [{k: 0 for k in valid_actions}] # all weighted sums are set to 0
else:
selected_actions: List[ActionId] = []
probs: List[Dict[ActionId, Probability]] = []
weighted_sums: List[Dict[ActionId, float]] = []

# sample_proba() and select_action() each row of the context
for i in range(len(context)):
# p is a dict of the sampled probability "prob" and weighted_sum "ws", e.g.
#
# p = {'a1': ([0.5], [200]), 'a2': ([0.4], [100]), ...}
# | | | |
# prob ws prob ws
p = {
action: model.sample_proba(context=context[i].reshape(1, -1)) # reshape row i-th to (1, n_features)
for action, model in self.actions.items()
if action in valid_actions
}

prob = {a: x[0][0] for a, x in p.items()} # e.g. prob = {'a1': 0.5, 'a2': 0.4, ...}
ws = {a: x[1][0] for a, x in p.items()} # e.g. ws = {'a1': 200, 'a2': 100, ...}

# select either "prob" or "ws" to use as input argument in select_actions()
p_to_select_action = prob if self.predict_with_proba else ws

# predict actions, probs, weighted_sums
selected_actions.append(self._select_epsilon_greedy_action(p=p_to_select_action, actions=self.actions))
probs.append(prob)
weighted_sums.append(ws)
# p is a dict of the sampled probability "prob" and weighted_sum "ws", e.g.
#
# p = {'a1': ([0.5, 0.2, 0.3], [200, 100, 130]), 'a2': ([0.4, 0.5, 0.6], [180, 200, 230]), ...}
# | | | |
# prob ws prob ws
p = {
action: model.sample_proba(context=context) # sample probabilities for the entire context matrix
for action, model in self.actions.items()
if action in valid_actions
}

prob = {a: x[0] for a, x in p.items()} # e.g. prob = {'a1': [0.5, 0.4, ...], 'a2': [0.4, 0.3, ...], ...}
ws = {a: x[1] for a, x in p.items()} # e.g. ws = {'a1': [200, 100, ...], 'a2': [100, 50, ...], ...}

# select either "prob" or "ws" as the input argument to _select_epsilon_greedy_action()
p_to_select_action = prob if self.predict_with_proba else ws

# predict actions, probs, weighted_sums
selected_actions = [
self._select_epsilon_greedy_action(
p={a: p_to_select_action[a][i] for a in p_to_select_action}, actions=self.actions
)
for i in range(len(context))
]
probs = [{a: prob[a][i] for a in prob} for i in range(len(context))]
weighted_sums = [{a: ws[a][i] for a in ws} for i in range(len(context))]

return selected_actions, probs, weighted_sums

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pybandits"
version = "0.6.0"
version = "0.6.1"
description = "Python Multi-Armed Bandit Library"
authors = [
"Dario d'Andrea <[email protected]>",

0 comments on commit 9c15f78

Please sign in to comment.