Skip to content

Commit

Permalink
Refactored the encoding of PMML elements. Fixes #11
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Jul 5, 2019
1 parent 34cf9cc commit 7309fad
Showing 1 changed file with 35 additions and 37 deletions.
72 changes: 35 additions & 37 deletions src/main/java/org/jpmml/converter/ModelEncoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import java.util.List;
import java.util.Map;

import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DataField;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningField;
Expand All @@ -48,58 +47,53 @@ public class ModelEncoder extends PMMLEncoder {


public PMML encodePMML(Model model){
List<Model> transformers = getTransformers();

PMML pmml = encodePMML();

List<Model> transformers = getTransformers();
if(transformers.size() > 0){
List<Model> models = new ArrayList<>(transformers);
models.add(model);

model = MiningModelUtil.createModelChain(models);
}

pmml.addModels(model);

VisitorBattery modelCleanerBattery = new ModelCleanerBattery();
if(modelCleanerBattery.size() > 0){
modelCleanerBattery.applyTo(pmml);
}
if(model != null){
models.add(model);
}

MiningSchema miningSchema = model.getMiningSchema();
model = MiningModelUtil.createModelChain(models);
} // End if

List<MiningField> miningFields = miningSchema.getMiningFields();
for(MiningField miningField : miningFields){
FieldName name = miningField.getName();
if(model != null){
pmml.addModels(model);

List<Decorator> decorators = getDecorators(name);
if(decorators == null){
continue;
VisitorBattery modelCleanerBattery = new ModelCleanerBattery();
if(modelCleanerBattery.size() > 0){
modelCleanerBattery.applyTo(pmml);
}

DataField dataField = getDataField(name);
if(dataField == null){
throw new IllegalArgumentException();
}
MiningSchema miningSchema = model.getMiningSchema();

for(Decorator decorator : decorators){
decorator.decorate(dataField, miningField);
}
}
List<MiningField> miningFields = miningSchema.getMiningFields();
for(MiningField miningField : miningFields){
FieldName name = miningField.getName();

DataDictionary dataDictionary = pmml.getDataDictionary();
DataField dataField = getDataField(name);
if(dataField == null){
throw new IllegalArgumentException("Field " + name.getValue() + " is not referentiable");
}

List<DataField> dataFields = dataDictionary.getDataFields();
for(DataField dataField : dataFields){
UnivariateStats univariateStats = getUnivariateStats(dataField.getName());
List<Decorator> decorators = getDecorators(name);
if(decorators != null){

if(univariateStats == null){
continue;
}
for(Decorator decorator : decorators){
decorator.decorate(dataField, miningField);
}
}

ModelStats modelStats = ModelUtil.ensureModelStats(model);
UnivariateStats univariateStats = getUnivariateStats(name);
if(univariateStats != null){
ModelStats modelStats = ModelUtil.ensureModelStats(model);

modelStats.addUnivariateStats(univariateStats);
modelStats.addUnivariateStats(univariateStats);
}
}
}

Visitor pmmlCleaner = new AttributeCleaner();
Expand All @@ -120,6 +114,10 @@ public List<Decorator> getDecorators(FieldName name){
return this.decorators.get(name);
}

public void addDecorator(DataField dataField, Decorator decorator){
addDecorator(dataField.getName(), decorator);
}

public void addDecorator(FieldName name, Decorator decorator){
List<Decorator> decorators = this.decorators.get(name);

Expand Down

0 comments on commit 7309fad

Please sign in to comment.