forked from Elvenson/xgboost-go
-
Notifications
You must be signed in to change notification settings - Fork 0
/
xgbensemble.go
39 lines (34 loc) · 893 Bytes
/
xgbensemble.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
package xgboost
import (
	"fmt"

	"github.com/Elvenson/xgboost-go/mat"
)
// xgbEnsemble holds the trees of an ensemble model together with the
// metadata reported by its accessor methods.
type xgbEnsemble struct {
	// Trees holds every tree of the ensemble. Predictions index into it with
	// a class-interleaved layout (tree k*numClasses+i belongs to class i).
	Trees []*xgbTree
	// name is the value reported by Name.
	name string
	// numClasses is the number of output classes reported by NumClasses.
	// numFeat is presumably the number of input features; it is not read by
	// the methods in this file — NOTE(review): confirm against the loader.
	numClasses, numFeat int
}
// Name reports the name stored for this ensemble model.
func (e *xgbEnsemble) Name() string {
	return e.name
}
// NumClasses returns the number of output classes of this ensemble model,
// which is also the length of the vector returned by PredictInner.
func (e *xgbEnsemble) NumClasses() int {
	return e.numClasses
}
// PredictInner returns the prediction of this ensemble model for features.
//
// The trees are read in a class-interleaved layout: tree k*numClasses+i
// contributes to class i, so each consecutive group of numClasses trees holds
// one tree per class. The result has numClasses entries. Each entry is the sum
// of the outputs of that class's trees. No further transformation is applied here.
//
// If len(Trees) is not a multiple of numClasses, the trailing trees that do
// not form a complete group are ignored.
//
// An error is returned if the ensemble has a non-positive class count, or if
// any tree fails to predict. On error the returned vector is nil.
func (e *xgbEnsemble) PredictInner(features mat.SparseVector) (mat.Vector, error) {
	// Guard the division below. A zero class count would otherwise panic with
	// an integer divide by zero, and a negative one would panic in make.
	if e.numClasses <= 0 {
		return nil, fmt.Errorf("invalid number of classes: %d", e.numClasses)
	}
	pred := make([]float64, e.numClasses)
	// Number of trees for 1 class.
	numTreesPerClass := len(e.Trees) / e.numClasses
	for i := 0; i < e.numClasses; i++ {
		for k := 0; k < numTreesPerClass; k++ {
			idx := k*e.numClasses + i
			p, err := e.Trees[idx].predict(features)
			if err != nil {
				// Propagate instead of returning a nil error with an empty
				// vector, which callers could not distinguish from success.
				return nil, fmt.Errorf("predicting tree %d: %w", idx, err)
			}
			pred[i] += p
		}
	}
	return pred, nil
}