forked from dmitryikh/leaves
-
Notifications
You must be signed in to change notification settings - Fork 6
/
xgblinear_io.go
95 lines (87 loc) · 3.13 KB
/
xgblinear_io.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
package leaves
import (
"bufio"
"fmt"
"github.com/dmitryikh/leaves/internal/xgjson"
"os"
"github.com/dmitryikh/leaves/internal/xgbin"
"github.com/dmitryikh/leaves/transformation"
)
// XGBLinearFromReader reads XGBoost's 'gblinear' model in binary format from
// `reader`.
//
// If loadTransformation is true, the output transformation is selected from
// the objective stored in the model header; only "binary:logistic" is
// recognised and any other objective results in an error. If it is false, the
// raw transformation is used.
//
// The header is validated before the model body is parsed, so a model produced
// by another booster is rejected with a descriptive error rather than with
// whatever error parsing its body as gblinear would produce.
func XGBLinearFromReader(reader *bufio.Reader, loadTransformation bool) (*Ensemble, error) {
	// To support versions after 1.0.0.
	// NOTE(review): anything ReadBinf returns is discarded here; confirm that
	// it cannot report a failure that should abort loading.
	xgbin.ReadBinf(reader)

	// reading header info
	header, err := xgbin.ReadModelHeader(reader)
	if err != nil {
		return nil, err
	}
	if header.NameGbm != "gblinear" {
		return nil, fmt.Errorf("only gblinear is supported (got %s). Use XGEnsembleFrom.. for gbtree", header.NameGbm)
	}
	if header.Param.NumFeatures == 0 {
		return nil, fmt.Errorf("zero number of features")
	}

	// reading gblinear model (only once the header is known to describe one)
	gbLinearModel, err := xgbin.ReadGBLinearModel(reader)
	if err != nil {
		return nil, err
	}

	e := &xgLinear{}
	if header.Param.MajorVersion > 0 {
		// Versioned header: class count, base score and feature count are all
		// taken from the learner parameters in the header.
		e.nRawOutputGroups = getNRawOutputGroups(header.Param.NumClass)
		e.BaseScore = calculateBaseScoreFromLearnerParam(float64(header.Param.BaseScore))
		e.NumFeature = int(header.Param.NumFeatures)
	} else {
		// Major version 0: group and feature counts come from the gblinear
		// model parameters, and the base score is used as stored.
		e.nRawOutputGroups = getNRawOutputGroups(gbLinearModel.Param.NumOutputGroup)
		e.BaseScore = float64(header.Param.BaseScore)
		e.NumFeature = int(gbLinearModel.Param.NumFeature)
	}
	e.Weights = gbLinearModel.Weights

	var transform transformation.Transform
	transform = &transformation.TransformRaw{NumOutputGroups: e.nRawOutputGroups}
	if loadTransformation {
		if header.NameObj != "binary:logistic" {
			return nil, fmt.Errorf("unknown transformation function '%s'", header.NameObj)
		}
		transform = &transformation.TransformLogistic{}
	}
	return &Ensemble{e, transform}, nil
}
// XGBLinearFromFile reads XGBoost's 'gblinear' model from the file `filename`.
//
// The file is first handed to xgbLinearFromJson. If that returns any error,
// the error is dropped and the file is opened and parsed with
// XGBLinearFromReader instead; the result of that second attempt is what the
// caller receives.
func XGBLinearFromFile(filename string, loadTransformation bool) (*Ensemble, error) {
	// First attempt: JSON. A failure here is not reported, it only triggers
	// the fallback below.
	jsonModel, jsonErr := xgbLinearFromJson(filename, loadTransformation)
	if jsonErr == nil {
		return jsonModel, nil
	}

	// Fallback: open the file and parse it through the reader-based loader.
	file, err := os.Open(filename)
	if err != nil {
		return nil, err
	}
	defer file.Close()

	return XGBLinearFromReader(bufio.NewReader(file), loadTransformation)
}
// xgbLinearFromJson builds a gblinear Ensemble from the file `filename` as
// parsed by xgjson.ReadGBLinear.
//
// If loadTransformation is true, the objective name must be "binary:logistic",
// which selects the logistic transformation; any other name is an error. If it
// is false, the raw transformation is used.
func xgbLinearFromJson(filename string, loadTransformation bool) (*Ensemble, error) {
	parsed, err := xgjson.ReadGBLinear(filename)
	if err != nil {
		return nil, err
	}

	learner := parsed.Learner
	params := learner.LearnerModelParam

	model := &xgLinear{}
	model.nRawOutputGroups = getNRawOutputGroups(params.NumClass)
	model.NumFeature = int(params.NumFeatures)
	model.Weights = learner.GradientBooster.Model.Weights
	model.BaseScore = calculateBaseScoreFromLearnerParam(float64(params.BaseScore))

	// Without a requested transformation the raw outputs are passed through.
	if !loadTransformation {
		return &Ensemble{model, &transformation.TransformRaw{NumOutputGroups: model.nRawOutputGroups}}, nil
	}

	// Only the logistic objective maps to a known transformation.
	if objective := learner.Objective.Name; objective != "binary:logistic" {
		return nil, fmt.Errorf("unknown transformation function '%s'", objective)
	}
	return &Ensemble{model, &transformation.TransformLogistic{}}, nil
}