Skip to content

Commit

Permalink
modified code to black style #14
Browse files: browse the repository at this point in the history
  • Loading branch information
WilliamLwj committed Mar 9, 2023
1 parent f5f8eec commit addd94b
Show file tree
Hide file tree
Showing 37 changed files with 226 additions and 143 deletions.
26 changes: 18 additions & 8 deletions CITATION.cff
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,38 @@ authors:
- family-names: "Li"
given-names: "Wenjie"
orcid: "https://orcid.org/0000-0000-0000-0000"
- family-names: "Song"
given-names: "Qifan"
- family-names: "Li"
given-names: "Haoze"
orcid: "https://orcid.org/0000-0000-0000-0000"
- family-names: "Honorio"
given-names: "Jean"
orcid: "https://orcid.org/0000-0000-0000-0000"
- family-names: "Song"
given-names: "Qifan"
orcid: "https://orcid.org/0000-0000-0000-0000"
title: "PyXAB - A Python Library for X-Armed Bandit and Online Blackbox Optimization Algorithms"
version: 0.0.0
doi: 10.48550/arXiv.2106.09215
date-released: 2022-10-21
url: "https://github.com/WilliamLwj/PyXAB/"
doi: 10.48550/ARXIV.2303.04030
date-released: 2023-03-07
url: "https://arxiv.org/abs/2303.04030"
preferred-citation:
type: article
authors:
- family-names: "Li"
given-names: "Wenjie"
orcid: "https://orcid.org/0000-0000-0000-0000"
- family-names: "Song"
given-names: "Qifan"
- family-names: "Li"
given-names: "Haoze"
orcid: "https://orcid.org/0000-0000-0000-0000"
- family-names: "Honorio"
given-names: "Jean"
orcid: "https://orcid.org/0000-0000-0000-0000"
- family-names: "Song"
given-names: "Qifan"
orcid: "https://orcid.org/0000-0000-0000-0000"
title: "PyXAB - A Python Library for X-Armed Bandit and Online Blackbox Optimization Algorithms"
year: 2022
year: 2023
doi: 10.48550/ARXIV.2303.04030
date-released: 2023-03-07
url: "https://arxiv.org/abs/2303.04030"
journal: arXiv
40 changes: 19 additions & 21 deletions PyXAB/algos/HCT.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ def compute_t_plus(x):
return np.power(2, np.ceil(np.log(x) / np.log(2)))



class HCT_node(P_node):
"""
Implementation of HCT node
"""

def __init__(self, depth, index, parent, domain):
"""
Initialization of the HCT node
Expand Down Expand Up @@ -81,11 +81,12 @@ def compute_u_value(self, nu, rho, c, delta_tilde):
self.u_value = np.inf
else:
self.mean_reward = np.sum(np.array(self.rewards)) / self.visited_times
self.u_value = self.mean_reward \
+ nu * (rho ** self.get_depth()) \
+ math.sqrt(
c ** 2 * math.log(1 / delta_tilde) / self.visited_times
)
self.u_value = (
self.mean_reward
+ nu * (rho ** self.get_depth())
+ math.sqrt(c ** 2 * math.log(1 / delta_tilde) / self.visited_times)
)

def update_b_value(self, b_value):
"""
The function to update the b_{h,i} value of the node
Expand Down Expand Up @@ -203,26 +204,25 @@ def optTraverse(self):
for i in range(1, self.partition.get_depth() + 1):
self.tau_h.append(
np.ceil(
self.c**2
self.c ** 2
* math.log(1 / delta_tilde)
* self.rho ** (-2 * i)
/ self.nu**2
/ self.nu ** 2
)
)

curr_node = self.partition.get_root()
path = [curr_node]

while (
curr_node.get_visited_times()
>= self.tau_h[curr_node.get_depth()]
curr_node.get_visited_times() >= self.tau_h[curr_node.get_depth()]
and curr_node.get_children() is not None
):
children = curr_node.get_children()
maxchild = children[0]
for child in children[1:]:

if (child.get_b_value() >= maxchild.get_b_value()):
if child.get_b_value() >= maxchild.get_b_value():
maxchild = child

curr_node = maxchild
Expand Down Expand Up @@ -265,8 +265,9 @@ def updateUvalueTree(self):
node_list = self.partition.get_node_list()
for layer in node_list:
for node in layer:
node.compute_u_value(nu=self.nu, rho=self.rho, c=self.c, delta_tilde=delta_tilde)

node.compute_u_value(
nu=self.nu, rho=self.rho, c=self.c, delta_tilde=delta_tilde
)

def updateBackwardTree(self):
"""
Expand All @@ -285,11 +286,9 @@ def updateBackwardTree(self):
if children is None:
node.update_b_value(node.get_u_value())
else:
tempB = - np.inf
tempB = -np.inf
for child in node.get_children():
tempB = np.maximum(
tempB, child.get_b_value()
)
tempB = np.maximum(tempB, child.get_b_value())

node.update_b_value(np.minimum(node.get_u_value(), tempB))

Expand All @@ -312,8 +311,6 @@ def expand(self, parent):
else:
self.partition.make_children(parent=parent, newlayer=False)



def updateAllTree(self, path, reward):
"""
The function to update everything in the tree
Expand All @@ -336,13 +333,14 @@ def updateAllTree(self, path, reward):
self.updateUvalueTree()
self.updateBackwardTree()


self.updateRewardTree(path, reward)

end_node = path[-1]
en_depth = end_node.get_depth()

end_node.compute_u_value(nu=self.nu, rho=self.rho, c=self.c, delta_tilde=delta_tilde)
end_node.compute_u_value(
nu=self.nu, rho=self.rho, c=self.c, delta_tilde=delta_tilde
)

self.updateBackwardTree()

Expand Down
24 changes: 9 additions & 15 deletions PyXAB/algos/HOO.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
from PyXAB.partition.Node import P_node
import pdb


class HOO_node(P_node):
"""
Implementation of the HOO_node
"""

def __init__(self, depth, index, parent, domain):
"""
Initialization of the HOO node
Expand Down Expand Up @@ -76,12 +78,9 @@ def compute_u_value(self, nu, rho, rounds):
self.b_value = np.inf
else:
self.mean_reward = np.sum(np.array(self.rewards)) / self.visited_times
UCB = math.sqrt(
2 * math.log(rounds) / self.visited_times
)
UCB = math.sqrt(2 * math.log(rounds) / self.visited_times)
self.u_value = self.mean_reward + UCB + nu * (rho ** self.depth)


def update_b_value(self, b_value):
"""
The function to update the b_{h,i} value of the node
Expand Down Expand Up @@ -142,6 +141,7 @@ class T_HOO(Algorithm):
"""
Implementation of the T_HOO algorithm
"""

def __init__(self, nu=1, rho=0.5, rounds=1000, domain=None, partition=None):
"""
Initialization of the T_HOO algorithm
Expand Down Expand Up @@ -190,11 +190,9 @@ def optTraverse(self):

while curr_node.get_children() is not None:
children = curr_node.get_children()
maxchild = children[0]
for child in children[1: ]:
if (
child.get_b_value() >= maxchild.get_b_value()
):
maxchild = children[0]
for child in children[1:]:
if child.get_b_value() >= maxchild.get_b_value():
maxchild = child

curr_node = maxchild
Expand Down Expand Up @@ -254,11 +252,9 @@ def updateBackwardTree(self):
if children is None:
node.update_b_value(node.get_u_value())
else:
tempB = - np.inf
tempB = -np.inf
for child in node.get_children():
tempB = np.maximum(
tempB, child.get_b_value()
)
tempB = np.maximum(tempB, child.get_b_value())
node.update_b_value(np.minimum(node.get_u_value(), tempB))

def expand(self, parent):
Expand Down Expand Up @@ -320,7 +316,6 @@ def pull(self, time):
curr_node, self.path = self.optTraverse()
return curr_node.get_cpoint()


def receive_reward(self, time, reward):
"""
The receive_reward function of T_HOO to obtain the reward and update the Statistics
Expand All @@ -337,4 +332,3 @@ def receive_reward(self, time, reward):
"""
self.updateAllTree(self.path, reward)

2 changes: 2 additions & 0 deletions PyXAB/algos/StoSOO.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class StoSOO_node(P_node):
"""
Implementation of the node in the StoSOO algorithm
"""

def __init__(self, depth, index, parent, domain):
"""
Initialization of the StoSOO node
Expand Down Expand Up @@ -115,6 +116,7 @@ class StoSOO(Algorithm):
"""
The implementation of the StoSOO algorithm (Valko et al., 2013)
"""

def __init__(
self, n=100, k=None, h_max=100, delta=None, domain=None, partition=None
):
Expand Down
8 changes: 4 additions & 4 deletions PyXAB/algos/StroquOOL.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def pull(self, time):
node = node_list[self.curr_depth][i]
if (
node.not_opened()
and node.get_visited_times() >= 2**self.curr_p
and node.get_visited_times() >= 2 ** self.curr_p
):
node.compute_mean_reward()
if node.get_mean_reward() >= max_reward:
Expand All @@ -137,11 +137,11 @@ def pull(self, time):
self.chosen.append(self.max_node.get_children()[0])
self.chosen.append(self.max_node.get_children()[1])
# evaluate children
if self.iteration <= self.time_stamp + 2**self.curr_p:
if self.iteration <= self.time_stamp + 2 ** self.curr_p:
self.curr_node = self.max_node.get_children()[0]
return self.max_node.get_children()[0].get_cpoint()
if (
self.time_stamp + 2**self.curr_p
self.time_stamp + 2 ** self.curr_p
< self.iteration
<= self.time_stamp + 2 ** (self.curr_p + 1)
):
Expand All @@ -163,7 +163,7 @@ def pull(self, time):
max_value = -np.inf
max_node = None
for i in range(len(self.chosen)):
if self.chosen[i].get_visited_times() >= 2**p:
if self.chosen[i].get_visited_times() >= 2 ** p:
if self.chosen[i].get_mean_reward() >= max_value:
max_value = self.chosen[i].get_mean_reward()
max_node = self.chosen[i]
Expand Down
Loading

0 comments on commit addd94b

Please sign in to comment.