-
Notifications
You must be signed in to change notification settings - Fork 50
/
Copy pathDecisionTree.py
315 lines (244 loc) · 9.67 KB
/
DecisionTree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
# Data wrangling
import pandas as pd
# Array math
import numpy as np
# Quick value count calculator
from collections import Counter
class Node:
    """
    Node of a binary-classification decision tree.

    Each node stores the observations that reached it, its GINI impurity and
    majority class, and (after ``grow_tree``) the split that sends
    observations to the ``left`` / ``right`` child nodes. The GINI
    computations only look at the labels ``0`` and ``1``.
    """

    def __init__(
        self,
        Y: list,
        X: pd.DataFrame,
        min_samples_split=None,
        max_depth=None,
        depth=None,
        node_type=None,
        rule=None
    ):
        """
        Y: target labels of the observations in this node
        X: feature frame, one row per entry of Y
        min_samples_split: minimum number of observations needed to attempt a split (default 20)
        max_depth: maximum depth of the tree (default 5)
        depth: depth of this node (default 0, i.e. the root)
        node_type: label used when printing (default 'root')
        rule: human readable split rule that leads to this node (default "")
        """
        # Saving the data to the node
        self.Y = Y
        self.X = X

        # Saving the hyper parameters. Only None selects the default, so an
        # explicit 0 (e.g. max_depth=0 for a single-node tree) is honoured.
        self.min_samples_split = 20 if min_samples_split is None else min_samples_split
        self.max_depth = 5 if max_depth is None else max_depth

        # Default current depth of node
        self.depth = 0 if depth is None else depth

        # Extracting all the features
        self.features = list(self.X.columns)

        # Type of node
        self.node_type = node_type if node_type else 'root'

        # Rule for splitting
        self.rule = rule if rule else ""

        # Calculating the counts of Y in the node
        self.counts = Counter(Y)

        # Getting the GINI impurity based on the Y distribution
        self.gini_impurity = self.get_GINI()

        # The node predicts its most frequent class. The ascending sort is
        # stable, so on a tie the class that was seen last wins.
        counts_sorted = sorted(self.counts.items(), key=lambda item: item[1])
        self.yhat = counts_sorted[-1][0] if counts_sorted else None

        # Saving the number of observations in the node
        self.n = len(Y)

        # Initiating the left and right nodes as empty nodes
        self.left = None
        self.right = None

        # Default values for splits; they stay None on nodes that never split
        self.best_feature = None
        self.best_value = None

    @staticmethod
    def GINI_impurity(y1_count: int, y2_count: int) -> float:
        """
        Given the observation counts of a binary class calculate the GINI impurity.

        None counts are treated as 0; an empty node has impurity 0.0.
        """
        # Ensuring the correct types
        if y1_count is None:
            y1_count = 0
        if y2_count is None:
            y2_count = 0

        # Getting the total observations
        n = y1_count + y2_count

        # If n is 0 then we return the lowest possible gini impurity
        if n == 0:
            return 0.0

        # Getting the probability to see each of the classes
        p1 = y1_count / n
        p2 = y2_count / n

        # Calculating and returning the GINI impurity
        return 1 - (p1 ** 2 + p2 ** 2)

    @staticmethod
    def ma(x: np.ndarray, window: int) -> np.ndarray:
        """
        Calculates the moving average of the given array over `window` consecutive items.
        """
        return np.convolve(x, np.ones(window), 'valid') / window

    def get_GINI(self):
        """
        Function to calculate the GINI impurity of a node
        """
        # Getting the 0 and 1 counts
        y1_count, y2_count = self.counts.get(0, 0), self.counts.get(1, 0)

        # Getting the GINI impurity
        return self.GINI_impurity(y1_count, y2_count)

    def best_split(self) -> tuple:
        """
        Given the X features and Y targets calculates the best split
        for a decision tree.

        Candidate thresholds are the midpoints between consecutive distinct
        values of each feature; rows containing any missing value are ignored.
        Returns (best_feature, best_value), or (None, None) when no candidate
        yields a strictly positive GINI gain.
        """
        # Creating a dataset for splitting
        df = self.X.copy()
        df['Y'] = self.Y

        # Dropping missing values once; this does not depend on the feature
        data = df.dropna()

        # Getting the GINI impurity for the base input
        GINI_base = self.get_GINI()

        # Finding which split yields the best GINI gain
        max_gain = 0

        # Default best feature and split
        best_feature = None
        best_value = None

        for feature in self.features:
            # Sorted distinct values of the feature
            distinct = data[feature].sort_values().unique()

            # With fewer than two distinct values there is no midpoint to test
            if len(distinct) < 2:
                continue

            for value in self.ma(distinct, 2):
                # Splitting the dataset with the same rule grow_tree applies
                goes_left = data[feature] <= value
                left_counts = Counter(data.loc[goes_left, 'Y'])
                right_counts = Counter(data.loc[~goes_left, 'Y'])

                # Getting the Y distribution from the dicts
                y0_left, y1_left = left_counts.get(0, 0), left_counts.get(1, 0)
                y0_right, y1_right = right_counts.get(0, 0), right_counts.get(1, 0)

                # Getting the obs count from the left and the right data splits
                n_left = y0_left + y1_left
                n_right = y0_right + y1_right
                n_total = n_left + n_right

                # No 0/1 labels on either side: nothing to weigh, skip the candidate
                if n_total == 0:
                    continue

                # Getting the left and right gini impurities
                gini_left = self.GINI_impurity(y0_left, y1_left)
                gini_right = self.GINI_impurity(y0_right, y1_right)

                # Calculating the weighted GINI impurity
                wGINI = (n_left / n_total) * gini_left + (n_right / n_total) * gini_right

                # Calculating the GINI gain
                GINIgain = GINI_base - wGINI

                # Checking if this is the best split so far
                if GINIgain > max_gain:
                    best_feature = feature
                    best_value = value

                    # Setting the best gain to the current one
                    max_gain = GINIgain

        return (best_feature, best_value)

    def grow_tree(self):
        """
        Recursive method to create the decision tree.

        A node is split only while it is shallower than max_depth, holds at
        least min_samples_split observations and a split with positive GINI
        gain exists.
        """
        # Stopping criteria based on the hyper parameters
        if self.depth >= self.max_depth or self.n < self.min_samples_split:
            return

        # Getting the best split
        best_feature, best_value = self.best_split()
        if best_feature is None:
            return

        # Saving the best split to the current node
        self.best_feature = best_feature
        self.best_value = best_value

        # Making a df from the data
        df = self.X.copy()
        df['Y'] = self.Y

        # Getting the left and right data; rows whose split feature is missing
        # satisfy neither comparison and are passed to neither child
        left_df = df[df[best_feature] <= best_value]
        right_df = df[df[best_feature] > best_value]

        # Creating and growing the left node
        self.left = Node(
            left_df['Y'].values.tolist(),
            left_df[self.features],
            depth=self.depth + 1,
            max_depth=self.max_depth,
            min_samples_split=self.min_samples_split,
            node_type='left_node',
            rule=f"{best_feature} <= {round(best_value, 3)}"
        )
        self.left.grow_tree()

        # Creating and growing the right node
        self.right = Node(
            right_df['Y'].values.tolist(),
            right_df[self.features],
            depth=self.depth + 1,
            max_depth=self.max_depth,
            min_samples_split=self.min_samples_split,
            node_type='right_node',
            rule=f"{best_feature} > {round(best_value, 3)}"
        )
        self.right.grow_tree()

    def print_info(self, width=4):
        """
        Method to print the information about the node
        """
        # Defining the number of spaces
        const = int(self.depth * width ** 1.5)
        spaces = "-" * const

        if self.node_type == 'root':
            print("Root")
        else:
            print(f"|{spaces} Split rule: {self.rule}")
        print(f"{' ' * const} | GINI impurity of the node: {round(self.gini_impurity, 2)}")
        print(f"{' ' * const} | Class distribution in the node: {dict(self.counts)}")
        print(f"{' ' * const} | Predicted class: {self.yhat}")

    def print_tree(self):
        """
        Prints the whole tree from the current node to the bottom
        """
        self.print_info()

        if self.left is not None:
            self.left.print_tree()

        if self.right is not None:
            self.right.print_tree()

    def predict(self, X: pd.DataFrame):
        """
        Batch prediction method.

        Returns a list with one predicted class per row of X; X must contain
        the feature columns this node was built with.
        """
        return [
            self.predict_obs({feature: row[feature] for feature in self.features})
            for _, row in X.iterrows()
        ]

    def predict_obs(self, values: dict) -> int:
        """
        Method to predict the class given a set of features.

        values maps feature name -> value for a single observation. A feature
        missing from values compares as None against the threshold, which
        raises TypeError.
        """
        cur_node = self

        # Traversing the nodes all the way to the bottom. A node that never
        # split keeps best_feature as None, which marks it as a leaf.
        while cur_node.best_feature is not None:
            # Same rule grow_tree used to partition the training data
            if values.get(cur_node.best_feature) <= cur_node.best_value:
                next_node = cur_node.left
            else:
                next_node = cur_node.right

            # Defensive stop in case a child was never attached
            if next_node is None:
                break
            cur_node = next_node

        return cur_node.yhat
if __name__ == '__main__':
    # Columns used as model inputs and the column used as the target
    feature_cols = ['Age', 'Fare']
    target_col = 'Survived'

    # Reading the data, keeping only the needed columns and complete rows
    frame = pd.read_csv("data/classification/train.csv")
    frame = frame[feature_cols + [target_col]].dropna()

    # Constructing the X and Y matrices
    X = frame[feature_cols]
    Y = frame[target_col].values.tolist()

    # Initiating the root node and growing the tree from it
    tree = Node(Y, X, max_depth=3, min_samples_split=100)
    tree.grow_tree()

    # Printing the tree information
    tree.print_tree()

    # Predicting on the training features and showing them next to the predictions
    scored = X.copy()
    scored['yhat'] = tree.predict(scored)
    print(scored)