-
Notifications
You must be signed in to change notification settings - Fork 0
/
env.py
61 lines (49 loc) · 1.51 KB
/
env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import numpy as np
from utils import random_cov
class LinearContextualBandit:
    """Contextual bandit whose expected rewards are linear in the context.

    Arm ``a`` is described by the parameter vector ``w[a]``. For a context
    ``x`` its expected reward is ``x @ w[a]``, and every observed reward adds
    ``sigma``-scaled standard-normal noise.
    """

    def __init__(self, w, sigma=1, x_norm=1):
        """Create the environment.

        Args:
            w: 2-D array-like of shape (K, d); row ``a`` is arm ``a``'s
                parameter vector.
            sigma: Scale of the Gaussian noise added to observed rewards.
            x_norm: Euclidean norm of every context returned by ``sample_x``.
        """
        # asarray returns the same object for an ndarray, so existing callers
        # still see ``self.w is w``; it also accepts nested lists.
        w = np.asarray(w)
        self.K = w.shape[0]
        self.d = w.shape[1]
        self.w = w
        self.x_norm = x_norm
        self.sigma = sigma

    def sample_x(self):
        """Draw a context uniformly at random from the sphere of radius ``x_norm``.

        Returns:
            A length-``d`` vector whose Euclidean norm equals ``x_norm``.
        """
        x = np.random.randn(self.d)
        # A normalized isotropic Gaussian gives a uniformly random direction.
        # Multiply (rather than divide) by x_norm so that ||x|| == x_norm.
        return self.x_norm * x / np.linalg.norm(x, 2)

    def best_r(self, x):
        """Return the largest expected reward over all arms for context ``x``.

        Returns ``-inf`` when the bandit has no arms (K == 0).
        """
        # Row ``a`` of ``self.w @ x`` is arm a's expected reward; ``initial``
        # keeps the empty-arm case returning -inf instead of raising.
        return np.max(self.w @ x, initial=-np.inf)

    def sample_r(self, x, a):
        """Pull arm ``a`` in context ``x``.

        Returns:
            A tuple ``(real_r, noisy_r)``: the noiseless expected reward and
            that reward plus ``sigma``-scaled standard-normal noise.
        """
        real_r = x @ self.w[a]
        noise = self.sigma * np.random.randn()
        noisy_r = real_r + noise
        return real_r, noisy_r
class CifarBandit:
    """Bandit built from a labelled dataset (one arm per class).

    Each round picks one example uniformly at random. Pulling the arm equal
    to that example's label has expected reward 1; any other arm has
    expected reward 0. Observed rewards add ``sigma``-scaled Gaussian noise.
    """

    def __init__(self, features, labels, sigma=1):
        """Store the dataset and noise scale.

        Args:
            features: Per-example feature rows used as contexts.
            labels: Per-example class labels, compared against the pulled arm.
            sigma: Scale of the Gaussian noise added to observed rewards.
        """
        self.features = features
        self.labels = labels
        self.sigma = sigma
        self.N = len(features)
        # Largest row norm, i.e. a bound on the norm of any returned context.
        self.x_norm = np.linalg.norm(features, axis=1).max()
        # Index of the most recently sampled example; None until sample_x runs.
        self.idx = None

    def sample_x(self):
        """Pick a random example, remember its index, and return its features."""
        chosen = np.random.randint(self.N)
        self.idx = chosen
        return self.features[chosen]

    def best_r(self, x):
        """Return the best achievable expected reward (always 1).

        ``x`` is ignored: the arm matching the sampled label always pays 1
        (presumes every label is a valid arm index — confirm with callers).
        """
        return 1.

    def sample_r(self, x, a):
        """Pull arm ``a`` for the most recently sampled example.

        ``x`` is not consulted; the reward is decided by ``self.idx``.

        Returns:
            ``(real_r, noisy_r)`` where ``real_r`` is 1.0 if ``a`` matches
            the example's label and 0.0 otherwise.

        Raises:
            ValueError: if ``sample_x`` has not been called yet.
        """
        current = self.idx
        if current is None:
            raise ValueError("x wasn't sampled yet")
        real_r = 1. if a == self.labels[current] else 0.
        return real_r, real_r + self.sigma * np.random.randn()