-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathkdtree_util.py
66 lines (59 loc) · 1.87 KB
/
kdtree_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from collections import deque

import numpy as np
def get_median_idx(X, idxs, feature):
    """Return the index (taken from ``idxs``) of the median point along ``feature``.

    The candidates are sorted by ``X[i][feature]`` and the element at sorted
    position ``n // 2`` is returned, i.e. the upper median when ``n`` is even.
    Python's sort is stable, so points with equal values keep their relative
    order from ``idxs``.

    Args:
        X: 2-D indexable point set, accessed as ``X[i][feature]``.
        idxs: iterable of row indices into ``X`` to consider.
        feature: column index to take the median along.

    Returns:
        The row index of the median point.

    Raises:
        IndexError: if ``idxs`` is empty.
    """
    ordered = sorted(idxs, key=lambda i: X[i][feature])
    return ordered[len(ordered) // 2]
def get_variance(X, idxs, feature):
    """Return the population variance of column ``feature`` over rows ``idxs``.

    Computed in two passes (mean first, then the mean squared deviation).
    The one-pass ``E[x^2] - E[x]^2`` shortcut loses almost all precision to
    cancellation when the values are large relative to their spread, and can
    even come out slightly negative. Subtracting the mean first also avoids
    squaring large integers, which could overflow fixed-width integer dtypes.

    Args:
        X: 2-D indexable point set, accessed as ``X[idx][feature]``.
        idxs: iterable of row indices into ``X``.
        feature: column index whose variance is computed.

    Returns:
        The population variance (divided by ``n``, not ``n - 1``).

    Raises:
        ZeroDivisionError: if ``idxs`` is empty.
    """
    values = [X[idx][feature] for idx in idxs]
    n = len(values)
    mean = sum(values) / n
    return sum((v - mean) ** 2 for v in values) / n
def choose_feature(X, idxs):
    """Pick the splitting feature: the column with the largest variance over ``idxs``.

    Columns are scanned in ascending order. On a tie the lowest column index
    wins, because ``max`` keeps the first maximal element.

    Args:
        X: 2-D indexable point set; its column count is taken from ``X[0]``.
        idxs: row indices into ``X`` to measure the variance over.

    Returns:
        The column index with maximal variance.
    """
    n_features = len(X[0])
    return max(range(n_features), key=lambda col: get_variance(X, idxs, col))
def split_feature(X, idxs, feature, median_idx):
    """Partition ``idxs`` around the value of point ``median_idx`` on ``feature``.

    A point goes to the left group if it is the median point itself or its
    value is strictly smaller than the median's. Every other point goes to the
    right group, including points whose value equals the median's. Both groups
    keep the iteration order of ``idxs``.

    Args:
        X: 2-D indexable point set, accessed as ``X[i][feature]``.
        idxs: row indices to partition.
        feature: column index compared against the pivot value.
        median_idx: row index whose value on ``feature`` is the pivot.

    Returns:
        A two-element list ``[left, right]`` of index lists.
    """
    pivot = X[median_idx][feature]
    left, right = [], []
    for i in idxs:
        # The median point is routed left without reading its value again.
        goes_left = i == median_idx or X[i][feature] < pivot
        (left if goes_left else right).append(i)
    return [left, right]
def split_by_kdtree(points, threshold):
    """Partition the row indices of ``points`` into k-d-tree-style leaf groups.

    Groups larger than ``threshold`` are split breadth-first. Each split picks
    the feature with the largest variance over the group (``choose_feature``)
    and divides the group at that feature's median point (``get_median_idx``
    and ``split_feature``).

    Args:
        points: 2-D indexable point set with a ``.shape`` attribute, accessed
            as ``points[i][feature]`` (presumably a NumPy array of shape
            (num_points, num_features) -- confirm against callers).
        threshold: a group is split further only while its size exceeds this.

    Returns:
        A list of index groups that together contain every row index exactly
        once. If no split is needed the single group is ``range(num_points)``;
        otherwise the groups are lists, in breadth-first split order. A
        single-point group cannot be split, so with ``threshold < 1`` groups of
        one point are returned even though they exceed the threshold.
    """
    num_points = points.shape[0]
    all_idxs = range(num_points)
    if num_points <= threshold:
        return [all_idxs]

    candidate_list = []
    # FIFO work queue: deque.popleft() is O(1), unlike list.pop(0).
    pending = deque([all_idxs])
    while pending:
        this_idxs = pending.popleft()
        if len(this_idxs) < 2:
            # Nothing left to split (only reachable when threshold < 1);
            # re-queueing this group would loop forever.
            candidate_list.append(this_idxs)
            continue

        dim = choose_feature(points, this_idxs)
        median_id = get_median_idx(points, this_idxs, dim)
        idxs_left, idxs_right = split_feature(points, this_idxs, dim, median_id)
        if not idxs_right:
            # split_feature always keeps the median point on the left. For a
            # two-point group with distinct values, both points land on the
            # left and the group would be re-queued unchanged forever. Moving
            # the median across guarantees that each split makes progress.
            idxs_left.remove(median_id)
            idxs_right.append(median_id)

        for part in (idxs_left, idxs_right):
            if len(part) > threshold:
                pending.append(part)
            else:
                candidate_list.append(part)
    return candidate_list