-
Notifications
You must be signed in to change notification settings - Fork 121
/
segment_tree.py
129 lines (107 loc) · 4.17 KB
/
segment_tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# -*- coding: utf-8 -*-
"""Segment tree for Prioritized Replay Buffer."""
import operator
from typing import Callable
class SegmentTree:
    """Create SegmentTree.

    Taken from OpenAI baselines github repository:
    https://github.com/openai/baselines/blob/master/baselines/common/segment_tree.py

    The tree is stored implicitly in a flat list: node 1 is the root, node
    ``i`` has children ``2 * i`` and ``2 * i + 1``, and the leaves holding the
    array values live at indices ``capacity .. 2 * capacity - 1`` (index 0 is
    never used).

    Attributes:
        capacity (int): number of leaves; a positive power of 2
        tree (list): flat array of ``2 * capacity`` node values
        operation (function): binary function combining two child values
    """

    def __init__(self, capacity: int, operation: Callable, init_value: float):
        """Initialization.

        Args:
            capacity (int): number of leaves; must be a positive power of 2
            operation (function): binary function used to combine child nodes
            init_value (float): value every node starts with
        """
        assert (
            capacity > 0 and capacity & (capacity - 1) == 0
        ), "capacity must be positive and a power of 2."
        self.capacity = capacity
        # Slots are only ever reassigned, never mutated in place, so sharing
        # the single initial object across all slots is safe.
        self.tree = [init_value] * (2 * capacity)
        self.operation = operation

    def _operate_helper(
        self, start: int, end: int, node: int, node_start: int, node_end: int
    ) -> float:
        """Returns result of operation in segment.

        Combines leaves ``start..end`` (inclusive) within the subtree rooted
        at ``node``, which covers leaves ``node_start..node_end``. The caller
        must guarantee ``node_start <= start <= end <= node_end``; otherwise
        the recursion never reaches a matching node.
        """
        if start == node_start and end == node_end:
            return self.tree[node]
        mid = (node_start + node_end) // 2
        if end <= mid:
            # Query lies entirely in the left child.
            return self._operate_helper(start, end, 2 * node, node_start, mid)
        if start > mid:
            # Query lies entirely in the right child.
            return self._operate_helper(start, end, 2 * node + 1, mid + 1, node_end)
        # Query straddles the midpoint: combine the two halves.
        return self.operation(
            self._operate_helper(start, mid, 2 * node, node_start, mid),
            self._operate_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end),
        )

    def operate(self, start: int = 0, end: int = 0) -> float:
        """Returns result of applying `self.operation`.

        Reduces ``arr[start], ..., arr[end - 1]``; ``end`` is exclusive. A
        non-positive ``end`` is taken relative to ``capacity``, so the
        default ``end=0`` covers the whole array and ``end=-1`` stops before
        the last leaf.

        Raises:
            AssertionError: if the resolved range is empty or out of bounds.
        """
        if end <= 0:
            end += self.capacity
        # Convert the exclusive end into the inclusive index the helper uses.
        end -= 1
        # Without this guard an empty/out-of-bounds range recurses forever
        # (RecursionError) or indexes past the end of self.tree.
        assert 0 <= start <= end < self.capacity, (
            "empty or out-of-bounds range [{}, {})".format(start, end + 1)
        )
        return self._operate_helper(start, end, 1, 0, self.capacity - 1)

    def __setitem__(self, idx: int, val: float) -> None:
        """Set value in tree.

        Writes leaf ``idx`` and recomputes every ancestor up to the root.
        """
        # Same bound as __getitem__: a negative idx would otherwise land on an
        # internal node and silently corrupt the aggregated values.
        assert 0 <= idx < self.capacity
        idx += self.capacity
        self.tree[idx] = val
        idx //= 2
        while idx >= 1:
            self.tree[idx] = self.operation(self.tree[2 * idx], self.tree[2 * idx + 1])
            idx //= 2

    def __getitem__(self, idx: int) -> float:
        """Get real value in leaf node of tree."""
        assert 0 <= idx < self.capacity
        return self.tree[self.capacity + idx]
class SumSegmentTree(SegmentTree):
    """Create SumSegmentTree.

    Taken from OpenAI baselines github repository:
    https://github.com/openai/baselines/blob/master/baselines/common/segment_tree.py
    """

    def __init__(self, capacity: int):
        """Initialization.

        Args:
            capacity (int): number of leaves; must be a positive power of 2
        """
        super().__init__(capacity=capacity, operation=operator.add, init_value=0.0)

    def sum(self, start: int = 0, end: int = 0) -> float:
        """Returns arr[start] + ... + arr[end - 1].

        ``end`` is exclusive; a non-positive ``end`` is taken relative to
        ``capacity``, so the default call sums the whole array.
        """
        return super().operate(start, end)

    def retrieve(self, upperbound: float) -> int:
        """Find the index whose prefix sum first exceeds the upper bound.

        Returns the smallest index ``i`` such that
        ``arr[0] + ... + arr[i] > upperbound``.

        If ``upperbound`` is at or beyond the total sum (allowed by the
        ``1e-5`` tolerance below, or reached through floating-point drift
        while descending), the search never enters a subtree whose sum is
        zero. It therefore returns the last index holding a positive value
        rather than an empty (zero) leaf. If every value is zero it returns 0.
        """
        assert 0 <= upperbound <= self.sum() + 1e-5, "upperbound: {}".format(upperbound)
        idx = 1
        while idx < self.capacity:  # while non-leaf
            left = 2 * idx
            right = left + 1
            # Step right only when the target lies past the left subtree AND the
            # right subtree actually holds mass; otherwise an out-of-range
            # upperbound would walk down the right spine into empty leaves.
            if self.tree[left] > upperbound or self.tree[right] <= 0:
                idx = left
            else:
                upperbound -= self.tree[left]
                idx = right
        return idx - self.capacity
class MinSegmentTree(SegmentTree):
    """Create MinSegmentTree.

    Taken from OpenAI baselines github repository:
    https://github.com/openai/baselines/blob/master/baselines/common/segment_tree.py
    """

    def __init__(self, capacity: int):
        """Initialization.

        Args:
            capacity (int): number of leaves; must be a positive power of 2
        """
        # `min` here is the builtin: method bodies do not see class-scope names.
        super().__init__(capacity=capacity, operation=min, init_value=float("inf"))

    def min(self, start: int = 0, end: int = 0) -> float:
        """Returns min(arr[start], ..., arr[end - 1]).

        ``end`` is exclusive; a non-positive ``end`` is taken relative to
        ``capacity``, so the default call covers the whole array. Leaves that
        were never set hold ``float("inf")``.
        """
        return super().operate(start, end)