-
Notifications
You must be signed in to change notification settings - Fork 117
/
Copy pathTargetEncoder.java
117 lines (90 loc) · 3.51 KB
/
TargetEncoder.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
/*
* Copyright (c) 2024 Villu Ruusmann
*
* This file is part of JPMML-SkLearn
*
* JPMML-SkLearn is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-SkLearn is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-SkLearn. If not, see <http://www.gnu.org/licenses/>.
*/
package sklearn.preprocessing;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.MapValues;
import org.dmg.pmml.OpType;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ExpressionUtil;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ValueUtil;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.sklearn.SkLearnEncoder;
/**
 * Converter for a fitted <code>TargetEncoder</code> object.
 *
 * <p>
 * Each input feature is replaced by a continuous, double-valued derived field,
 * whose value is looked up from that feature's category-to-encoding table.
 * The tables are read from the <code>encodings_</code> attribute, and the lookup default
 * from the <code>target_mean_</code> attribute.
 * </p>
 */
public class TargetEncoder extends BaseEncoder {

    public TargetEncoder(String module, String name){
        super(module, name);
    }

    /**
     * Builds one lookup-based derived field per input feature.
     *
     * @param features The input features; must be as numerous as the category and encoding lists.
     * @param encoder The encoder that receives the categorical declarations and the derived fields.
     *
     * @return The continuous features that wrap the newly created derived fields, in input order.
     */
    @Override
    public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encoder){
        List<List<Object>> allCategories = getCategories();
        List<List<Number>> allEncodings = getEncodings();
        Number defaultEncoding = getTargetMean();

        // The returned value is not used any further. The getter is invoked for its own sake:
        // it hands the list of accepted values to getEnum (presumably rejecting anything else - confirm).
        getTargetType();

        ClassDictUtil.checkSize(features.size(), allCategories, allEncodings);

        List<Feature> encodedFeatures = new ArrayList<>();

        for(int column = 0; column < features.size(); column++){
            encodedFeatures.add(encodeFeature(features.get(column), allCategories.get(column), allEncodings.get(column), defaultEncoding, encoder));
        }

        return encodedFeatures;
    }

    /**
     * Encodes a single feature as a <code>MapValues</code> lookup.
     *
     * @param feature The input feature.
     * @param categories The category values of the feature, possibly including a NaN or <code>null</code> entry.
     * @param encodings The encoded values; positionally aligned with the category values.
     * @param defaultEncoding The lookup default value.
     * @param encoder The encoder that owns the fields.
     */
    private Feature encodeFeature(Feature feature, List<Object> categories, List<Number> encodings, Number defaultEncoding, SkLearnEncoder encoder){
        ClassDictUtil.checkSize(categories, encodings);

        // Either the NaN entry, or null when there is no NaN entry.
        // In the latter case, the index lookup below searches for a literal null entry instead.
        Object missingCategory = getMissingCategory(categories);

        int missingIndex = categories.indexOf(missingCategory);

        List<Object> validCategories = categories;
        List<Number> validEncodings = encodings;
        Number missingEncoding = null;

        if(missingIndex >= 0){
            // Work on copies, so that the lists held by this object stay intact
            validCategories = new ArrayList<>(categories);
            validCategories.remove(missingIndex);

            validEncodings = new ArrayList<>(encodings);
            missingEncoding = validEncodings.remove(missingIndex);
        }

        encoder.toCategorical(feature.getName(), validCategories);

        MapValues mapValues = ExpressionUtil.createMapValues(feature.getName(), validCategories, validEncodings);
        mapValues.setMapMissingTo(missingEncoding);
        mapValues.setDefaultValue(defaultEncoding);

        DerivedField derivedField = encoder.createDerivedField(createFieldName("targetEncoder", feature), OpType.CONTINUOUS, DataType.DOUBLE, mapValues);

        return new ContinuousFeature(encoder, derivedField);
    }

    /**
     * @return The value of the <code>encodings_</code> attribute; one list of numbers per feature.
     */
    public List<List<Number>> getEncodings(){
        return getArrayList("encodings_", Number.class);
    }

    /**
     * @return The value of the <code>target_mean_</code> attribute.
     */
    public Number getTargetMean(){
        return getNumber("target_mean_");
    }

    /**
     * @return The value of the <code>target_type_</code> attribute;
     * the accepted values are <code>binary</code> and <code>continuous</code>.
     */
    public String getTargetType(){
        List<String> acceptedTargetTypes = Arrays.asList(TARGETTYPE_BINARY, TARGETTYPE_CONTINUOUS);

        return getEnum("target_type_", this::getString, acceptedTargetTypes);
    }

    /**
     * Finds the first NaN entry.
     *
     * @return The first entry for which <code>ValueUtil.isNaN</code> holds, or <code>null</code> if there is none.
     */
    private static Object getMissingCategory(List<?> categories){
        Object nanCategory = null;

        for(Object category : categories){

            if(ValueUtil.isNaN(category)){
                nanCategory = category;

                break;
            }
        }

        return nanCategory;
    }

    private static final String TARGETTYPE_BINARY = "binary";
    private static final String TARGETTYPE_CONTINUOUS = "continuous";
}