-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
/
Copy pathcustom_softmax.py
186 lines (144 loc) · 5.93 KB
/
custom_softmax.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
'''
Demo for creating customized multi-class objective function
===========================================================
This demo is only applicable to XGBoost versions after 1.0.0 (exclusive): up to and
including that version, XGBoost returned transformed predictions for multi-class
objective functions. More details in the comments.
See :doc:`/tutorials/custom_metric_obj` and :doc:`/tutorials/advanced_custom_obj` for
detailed tutorial and notes.
'''
import argparse
import numpy as np
from matplotlib import pyplot as plt
import xgboost as xgb
# Fix the seed so the generated data (and therefore the demo output) is
# reproducible from run to run.
np.random.seed(1994)

kRows = 100  # number of samples
kCols = 10  # number of features
kClasses = 4  # number of classes
kRounds = 10  # number of boosting rounds.

# Generate some random data for demo.
X = np.random.randn(kRows, kCols)
# Draw labels from [0, kClasses) so they always agree with the `num_class`
# value handed to XGBoost, instead of repeating the literal class count.
y = np.random.randint(0, kClasses, size=kRows)

m = xgb.DMatrix(X, y)
def softmax(x):
    '''Softmax function with x as input vector.

    The exponentials are normalised by their sum over *all* elements of
    ``x``, so ``x`` is expected to be a single vector of raw scores.

    Parameters
    ----------
    x :
        Array-like of raw scores.

    Returns
    -------
    numpy.ndarray
        Array of the same shape as ``x`` whose elements sum to 1 (an empty
        array when ``x`` is empty).
    '''
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalise, and ``np.max`` rejects empty input.
        return np.exp(x)
    # Subtracting the maximum does not change the result mathematically
    # (the common factor cancels in the ratio) but keeps every exponent
    # <= 0, so ``np.exp`` can no longer overflow to ``inf`` and produce NaN.
    e = np.exp(x - np.max(x))
    return e / np.sum(e)
def softprob_obj(predt: np.ndarray, data: xgb.DMatrix):
    '''Loss function. Computing the gradient and upper bound on the
    Hessian with a diagonal structure for XGBoost (note that this is
    not the true Hessian).

    Reimplements the `multi:softprob` inside XGBoost.

    Parameters
    ----------
    predt :
        Raw predictions of shape ``(kRows, kClasses)``.
    data :
        Matrix providing the labels and, optionally, per-row weights.

    Returns
    -------
    grad, hess :
        Two float arrays of shape ``(kRows, kClasses)``.

    Raises
    ------
    ValueError
        If any label lies outside ``[0, kClasses)``.
    '''
    labels = data.get_label()
    weights = data.get_weight()
    if weights.size == 0:
        # Use 1 as weight if we don't have custom weight.  Kept 1-D so that
        # ``weights[r]`` is a scalar in both branches.
        weights = np.ones(kRows, dtype=float)

    # The prediction is of shape (rows, classes), each element in a row
    # represents a raw prediction (leaf weight, hasn't gone through softmax
    # yet).  In XGBoost 1.0.0, the prediction is transformed by a softmax
    # function, fixed in later versions.
    assert predt.shape == (kRows, kClasses)

    # Every label must be a valid class index.  Checked once up front;
    # written with ``not np.all(...)`` so that NaN labels are rejected too.
    if not np.all((labels >= 0) & (labels < kClasses)):
        raise ValueError(
            'labels must be in the range [0, {})'.format(kClasses))

    grad = np.zeros((kRows, kClasses), dtype=float)
    hess = np.zeros((kRows, kClasses), dtype=float)

    # Lower bound for the hessian so it never collapses to zero.
    eps = 1e-6

    # compute the gradient and hessian upper bound, slow iterations in Python,
    # only suitable for demo.  The native XGBoost core is written to be
    # robust against numeric overflow in `exp`; here that depends entirely
    # on the `softmax` helper.
    for r in range(predt.shape[0]):
        target = labels[r]
        w = weights[r]
        p = softmax(predt[r, :])
        for c in range(predt.shape[1]):
            # Gradient of the cross-entropy w.r.t. the raw score: p - 1 for
            # the true class, p for every other class.
            g = p[c] - 1.0 if c == target else p[c]
            grad[r, c] = g * w
            hess[r, c] = max(float(2.0 * p[c] * (1.0 - p[c]) * w), eps)

    # After 2.1.0, pass the gradient as it is.
    return grad, hess
def predict(booster: xgb.Booster, X):
    '''A customized prediction function that converts raw prediction to
    target class.

    Parameters
    ----------
    booster :
        Trained booster used to produce the raw predictions.
    X :
        Data forwarded unchanged to ``booster.predict``.

    Returns
    -------
    numpy.ndarray
        1-D float array with one predicted class index per prediction row.
    '''
    # Output margin means we want to obtain the raw prediction obtained from
    # tree leaf weight.
    predt = booster.predict(X, output_margin=True)
    # Pick the class with the maximum score for every row.  The scores are
    # not strictly probabilities as they haven't gone through softmax yet
    # (so they don't sum to 1), but the argmax is the same.
    # The output length follows the prediction itself rather than the global
    # row count, so inputs of any size are handled correctly.
    return np.array([np.argmax(row) for row in predt], dtype=float)
def merror(predt: np.ndarray, dtrain: xgb.DMatrix):
    '''Custom multi-class classification error metric.

    Returns the metric name ``'PyMError'`` together with the fraction of
    rows whose highest-scoring class differs from the label.
    '''
    labels = dtrain.get_label()

    # Like custom objective, the predt is untransformed leaf weight when custom
    # objective is provided.
    # With the use of `custom_metric` parameter in train function, custom metric
    # receives raw input only when custom objective is also being used.
    # Otherwise custom metric will receive transformed prediction.
    assert predt.shape == (kRows, kClasses)

    # Highest-scoring class per row, stored as floats.
    predicted = np.argmax(predt, axis=1).astype(float)
    assert labels.shape == predicted.shape

    # 1.0 for every misclassified row, 0.0 otherwise.
    mistakes = np.where(labels != predicted, 1.0, 0.0)
    error_rate = np.sum(mistakes) / kRows
    return 'PyMError', error_rate
def plot_history(custom_results, native_results):
    '''Plot the training error history of both runs in two stacked panels.

    The upper panel shows the ``PyMError`` history from ``custom_results``,
    the lower panel the ``merror`` history from ``native_results``.
    '''
    _, (top_ax, bottom_ax) = plt.subplots(2, 1)

    # One point per boosting round.
    rounds = np.arange(kRounds)

    top_ax.plot(rounds, custom_results['train']['PyMError'],
                label='Custom objective')
    top_ax.legend()

    bottom_ax.plot(rounds, native_results['train']['merror'],
                   label='multi:softmax')
    bottom_ax.legend()

    plt.show()
def main(args):
    '''Train with the custom and the built-in objective and compare them.

    ``args.plot`` controls whether the evaluation history is plotted
    afterwards (any value other than 0 enables plotting).
    '''
    # Use our custom objective function
    custom_params = {
        'num_class': kClasses,
        'disable_default_eval_metric': True,
    }
    custom_results = {}
    booster_custom = xgb.train(
        custom_params,
        m,
        num_boost_round=kRounds,
        obj=softprob_obj,
        custom_metric=merror,
        evals_result=custom_results,
        evals=[(m, 'train')],
    )
    predt_custom = predict(booster_custom, m)

    # Use the same objective function defined in XGBoost.
    native_params = {
        'num_class': kClasses,
        "objective": "multi:softmax",
        'eval_metric': 'merror',
    }
    native_results = {}
    booster_native = xgb.train(
        native_params,
        m,
        num_boost_round=kRounds,
        evals_result=native_results,
        evals=[(m, 'train')],
    )
    predt_native = booster_native.predict(m)

    # We are reimplementing the loss function in XGBoost, so it should
    # be the same for normal cases.
    assert np.all(predt_custom == predt_native)
    np.testing.assert_allclose(
        custom_results['train']['PyMError'],
        native_results['train']['merror'],
    )

    if args.plot == 0:
        return
    plot_history(custom_results, native_results)
if __name__ == '__main__':
    # Command-line entry point: parse the single `--plot` switch and run
    # the demo.
    parser = argparse.ArgumentParser(
        description='Arguments for custom softmax objective function demo.',
    )
    parser.add_argument(
        '--plot', type=int, default=1,
        help='Set to 0 to disable plotting the evaluation history.',
    )
    args = parser.parse_args()
    main(args)