From 7309fad445a87feec992c2d79451396cdef8c74a Mon Sep 17 00:00:00 2001 From: Villu Ruusmann Date: Fri, 5 Jul 2019 15:16:21 +0300 Subject: [PATCH] Refactored the encoding of PMML elements. Fixes #11 --- .../org/jpmml/converter/ModelEncoder.java | 72 +++++++++---------- 1 file changed, 35 insertions(+), 37 deletions(-) diff --git a/src/main/java/org/jpmml/converter/ModelEncoder.java b/src/main/java/org/jpmml/converter/ModelEncoder.java index 0a4f5fe3..372b1a54 100644 --- a/src/main/java/org/jpmml/converter/ModelEncoder.java +++ b/src/main/java/org/jpmml/converter/ModelEncoder.java @@ -23,7 +23,6 @@ import java.util.List; import java.util.Map; -import org.dmg.pmml.DataDictionary; import org.dmg.pmml.DataField; import org.dmg.pmml.FieldName; import org.dmg.pmml.MiningField; @@ -48,58 +47,53 @@ public class ModelEncoder extends PMMLEncoder { public PMML encodePMML(Model model){ - List transformers = getTransformers(); - PMML pmml = encodePMML(); + List transformers = getTransformers(); if(transformers.size() > 0){ List models = new ArrayList<>(transformers); - models.add(model); - model = MiningModelUtil.createModelChain(models); - } - - pmml.addModels(model); - - VisitorBattery modelCleanerBattery = new ModelCleanerBattery(); - if(modelCleanerBattery.size() > 0){ - modelCleanerBattery.applyTo(pmml); - } + if(model != null){ + models.add(model); + } - MiningSchema miningSchema = model.getMiningSchema(); + model = MiningModelUtil.createModelChain(models); + } // End if - List miningFields = miningSchema.getMiningFields(); - for(MiningField miningField : miningFields){ - FieldName name = miningField.getName(); + if(model != null){ + pmml.addModels(model); - List decorators = getDecorators(name); - if(decorators == null){ - continue; + VisitorBattery modelCleanerBattery = new ModelCleanerBattery(); + if(modelCleanerBattery.size() > 0){ + modelCleanerBattery.applyTo(pmml); } - DataField dataField = getDataField(name); - if(dataField == null){ - throw new IllegalArgumentException(); - } + MiningSchema miningSchema = model.getMiningSchema(); - for(Decorator decorator : decorators){ - decorator.decorate(dataField, miningField); - } - } + List miningFields = miningSchema.getMiningFields(); + for(MiningField miningField : miningFields){ + FieldName name = miningField.getName(); - DataDictionary dataDictionary = pmml.getDataDictionary(); + DataField dataField = getDataField(name); + if(dataField == null){ + throw new IllegalArgumentException("Field " + name.getValue() + " is not referentiable"); + } - List dataFields = dataDictionary.getDataFields(); - for(DataField dataField : dataFields){ - UnivariateStats univariateStats = getUnivariateStats(dataField.getName()); + List decorators = getDecorators(name); + if(decorators != null){ - if(univariateStats == null){ - continue; - } + for(Decorator decorator : decorators){ + decorator.decorate(dataField, miningField); + } + } - ModelStats modelStats = ModelUtil.ensureModelStats(model); + UnivariateStats univariateStats = getUnivariateStats(name); + if(univariateStats != null){ + ModelStats modelStats = ModelUtil.ensureModelStats(model); - modelStats.addUnivariateStats(univariateStats); + modelStats.addUnivariateStats(univariateStats); + } + } } Visitor pmmlCleaner = new AttributeCleaner(); @@ -120,6 +114,10 @@ public List getDecorators(FieldName name){ return this.decorators.get(name); } + public void addDecorator(DataField dataField, Decorator decorator){ + addDecorator(dataField.getName(), decorator); + } + public void addDecorator(FieldName name, Decorator decorator){ List decorators = this.decorators.get(name);