Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add transformer entry point #116

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions pmml-evaluator-example/src/main/resources/regression.pmml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
<PMML xmlns="http://www.dmg.org/PMML-4_2" version="4.2">
<!-- Example document: a regression model whose only predictor is a derived field. -->
<Header copyright="DMG.org"/>
<!-- Declared fields: X and y are both continuous doubles. -->
<DataDictionary numberOfFields="2">
<DataField name="X" optype="continuous" dataType="double"/>
<DataField name="y" optype="continuous" dataType="double"/>
</DataDictionary>
<TransformationDictionary>
<!-- Y1 applies the built-in "pow" function to field X and the constant 2.0. -->
<DerivedField name="Y1" dataType="double" optype="continuous">
<Apply function="pow">
<FieldRef field="X"/>
<Constant>2.0</Constant>
</Apply>
</DerivedField>
</TransformationDictionary>
<RegressionModel modelName="Simple linear regression" functionName="regression" algorithmName="linearRegression" targetFieldName="y">
<!-- X is listed without a usageType; y is marked as the target. -->
<MiningSchema>
<MiningField name="X"/>
<MiningField name="y" usageType="target"/>
</MiningSchema>
<!-- One term: coefficient 2.0 on Y1 with exponent 1, and an intercept of 0.0. -->
<RegressionTable intercept="0.0">
<NumericPredictor name="Y1" exponent="1" coefficient="2.0"/>
</RegressionTable>
</RegressionModel>
</PMML>
107 changes: 107 additions & 0 deletions pmml-evaluator/src/main/java/org/jpmml/evaluator/Transformer.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package org.jpmml.evaluator;

import java.util.*;

import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import org.dmg.pmml.*;

/**
 * <p>
 * Entry point for evaluating the {@link DerivedField}s of a {@link PMML} document's
 * {@link TransformationDictionary} on their own, using values supplied for the
 * {@link DataField}s of its {@link DataDictionary}.
 * </p>
 */
public class Transformer {

    /** Per-{@link DataDictionary} index of data fields by name, shared across instances. */
    private static final LoadingCache<DataDictionary, Map<FieldName, DataField>> dataFieldCache = CacheUtil.buildLoadingCache(new CacheLoader<DataDictionary, Map<FieldName, DataField>>() {

        @Override
        public Map<FieldName, DataField> load(DataDictionary dataDictionary) {
            return IndexableUtil.buildMap(dataDictionary.getDataFields());
        }
    });

    /** Per-{@link TransformationDictionary} index of derived fields by name, shared across instances. */
    private static final LoadingCache<TransformationDictionary, Map<FieldName, DerivedField>> derivedFieldCache = CacheUtil.buildLoadingCache(new CacheLoader<TransformationDictionary, Map<FieldName, DerivedField>>() {

        @Override
        public Map<FieldName, DerivedField> load(TransformationDictionary transformationDictionary) {
            return IndexableUtil.buildMap(transformationDictionary.getDerivedFields());
        }
    });

    private final PMML pmml;

    private final Map<FieldName, DataField> dataFields;

    private final Map<FieldName, DerivedField> derivedFields;


    /**
     * Indexes the data fields and derived fields of the given document.
     *
     * @param pmml The document to transform with; must not be {@code null}.
     * @throws NullPointerException If {@code pmml} is {@code null}.
     * @throws MissingElementException If the document has no DataDictionary element.
     */
    public Transformer(PMML pmml) {
        this.pmml = Objects.requireNonNull(pmml, "pmml");

        DataDictionary dataDictionary = pmml.getDataDictionary();
        if (dataDictionary == null) {
            throw new MissingElementException(pmml, PMMLElements.PMML_DATADICTIONARY);
        }

        if (dataDictionary.hasDataFields()) {
            this.dataFields = CacheUtil.getValue(dataDictionary, Transformer.dataFieldCache);
        } else {
            this.dataFields = Collections.emptyMap();
        }

        // The TransformationDictionary is optional; without it there is simply nothing to derive.
        TransformationDictionary transformationDictionary = pmml.getTransformationDictionary();
        if (transformationDictionary != null && transformationDictionary.hasDerivedFields()) {
            this.derivedFields = CacheUtil.getValue(transformationDictionary, Transformer.derivedFieldCache);
        } else {
            this.derivedFields = Collections.emptyMap();
        }
    }

    /**
     * <p>
     * Gets a short description of the {@link Transformer}.
     * </p>
     */
    public String getSummary() {
        return "Transformer";
    }

    /**
     * @return The data field with the given name, or {@code null} if none is indexed under it.
     */
    public DataField getDataField(FieldName name) {
        return this.dataFields.get(name);
    }

    /**
     * @return The derived field with the given name, or {@code null} if none is indexed under it.
     */
    public DerivedField getDerivedField(FieldName name) {
        return this.derivedFields.get(name);
    }

    /**
     * <p>
     * Gets the transformed output fields.
     * </p>
     *
     * @return A new list holding the indexed derived fields; changes to it do not affect this object.
     */
    public List<DerivedField> getTransformFields() {
        // ArrayList rather than LinkedList: callers index into the result.
        return new ArrayList<>(this.derivedFields.values());
    }

    /**
     * <p>
     * Gets the input fields.
     * </p>
     *
     * @return A new list holding the indexed data fields; changes to it do not affect this object.
     */
    public List<DataField> getArgumentFields() {
        return new ArrayList<>(this.dataFields.values());
    }

    /**
     * Evaluates every derived field against the given arguments, using a fresh context.
     *
     * @param arguments Values keyed by data field name.
     * @return One entry per derived field, keyed by its name.
     */
    public Map<FieldName, ?> evaluate(Map<FieldName, ?> arguments) {
        TransformerContext context = new TransformerContext(this);
        context.setArguments(arguments);

        return evaluate(context);
    }

    /**
     * Evaluates every field returned by {@link #getTransformFields()} in the given context.
     *
     * @param context The context to evaluate the derived fields in.
     * @return One entry per derived field, in the iteration order of {@link #getTransformFields()},
     * mapping its name to the value returned by the context.
     */
    public Map<FieldName, ?> evaluate(TransformerContext context) {
        Map<FieldName, Object> result = new LinkedHashMap<>();

        // getTransformFields() already hands back a private copy, so no second copy is needed.
        for (DerivedField derivedField : getTransformFields()) {
            FieldName name = derivedField.getName();

            FieldValue value = context.evaluate(name);

            result.put(name, value);
        }

        return result;
    }

    public PMML getPMML() {
        return pmml;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package org.jpmml.evaluator;

import org.dmg.pmml.DataField;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.FieldName;

import java.util.Collections;
import java.util.Map;

/**
 * An {@link EvaluationContext} that resolves field names against a {@link Transformer}:
 * data fields take their value from the supplied arguments, and derived fields are
 * computed from their expression.
 */
public class TransformerContext extends EvaluationContext {

    private final Transformer transformer;

    private Map<FieldName, ?> arguments = Collections.emptyMap();


    TransformerContext(Transformer transformer) {
        this.transformer = transformer;
    }

    private Transformer getTransformer() {
        return transformer;
    }

    public Map<FieldName, ?> getArguments() {
        return this.arguments;
    }

    public void setArguments(Map<FieldName, ?> arguments) {
        this.arguments = arguments;
    }

    /**
     * Resolves a field by name. Data fields are looked up first, then derived fields.
     *
     * @param name The name of the field
     * @return The declared field value
     * @throws MissingFieldException If the transformer knows neither a data field
     * nor a derived field by that name
     */
    @Override
    protected FieldValue resolve(FieldName name) {
        Transformer owner = getTransformer();

        // Data fields: the value comes straight from the supplied arguments
        DataField dataField = owner.getDataField(name);
        if (dataField != null) {
            Object supplied = getArguments().get(name);
            if (supplied != null) {
                return declare(name, supplied);
            }

            return declareMissing(name);
        }

        // Derived fields: the value is computed from the field's expression
        DerivedField derivedField = owner.getDerivedField(name);
        if (derivedField == null) {
            throw new MissingFieldException(name);
        }

        FieldValue computed = ExpressionUtil.evaluateTypedExpressionContainer(derivedField, this);

        return declare(name, computed);
    }

    /**
     * Declares a field as missing.
     *
     * @param name The name of the field
     * @return The result of declaring the field with a {@code null} value
     */
    private FieldValue declareMissing(FieldName name) {
        // The variable is typed as FieldValue on purpose: it selects 'declare(FieldName,FieldValue)'.
        // Per the original author's note, 'declare(FieldName,Object)' would throw for this input.
        FieldValue missing = null;

        return declare(name, missing);
    }

    /**
     * Not supported by this context.
     *
     * @throws UnsupportedOperationException Always
     */
    @Override
    protected FieldValue prepare(FieldName name, Object value) {
        throw new UnsupportedOperationException();
    }

    /**
     * Delegates to the superclass and then replaces the arguments with an empty map.
     */
    @Override
    public void reset(boolean purge) {
        super.reset(purge);

        this.arguments = Collections.emptyMap();
    }
}
158 changes: 158 additions & 0 deletions pmml-evaluator/src/test/java/org/jpmml/evaluator/TransformerTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
package org.jpmml.evaluator;

import org.dmg.pmml.Apply;
import org.dmg.pmml.Constant;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.jpmml.model.PMMLUtil;
import org.junit.Before;
import org.junit.Test;

import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

/**
 * Tests {@link Transformer} against the {@code basicTransformation.pmml} fixture, which
 * declares the data fields X and Z and the derived fields Y1, spec, Y2 and Y3.
 */
public class TransformerTest {

    /** Fixture location, relative to the module directory the tests are run from. */
    private static final String FIXTURE_PATH = "src/test/resources/pmml/basicTransformation.pmml";

    private Transformer transformer;

    private static PMML readPMML(File file) throws Exception {
        try (InputStream is = new FileInputStream(file)) {
            return PMMLUtil.unmarshal(is);
        }
    }

    @Before
    public void setUp() throws Exception {
        transformer = new Transformer(readPMML(new File(FIXTURE_PATH)));
    }

    @Test
    public void getSummary() {
        // JUnit's contract is assertEquals(expected, actual); keeping that order keeps failure messages truthful.
        assertEquals("Transformer", transformer.getSummary());
    }

    @Test
    public void getDataField() {
        DataField xField = transformer.getDataField(new FieldName("X"));
        assertEquals(DataType.DOUBLE, xField.getDataType());
        assertEquals("X", xField.getName().getValue());
        assertEquals(OpType.CONTINUOUS, xField.getOpType());
    }

    @Test
    public void getDerivedField() {
        DerivedField derivedField = transformer.getDerivedField(new FieldName("Y1"));
        assertEquals("Y1", derivedField.getName().getValue());
        assertEquals(DataType.DOUBLE, derivedField.getDataType());
        assertEquals(OpType.CONTINUOUS, derivedField.getOpType());
        Expression exp = derivedField.getExpression();
        // assertTrue instead of the 'assert' keyword: the check must run even when JVM assertions are disabled.
        assertTrue(exp instanceof Constant);
        assertEquals("1.0", ((Constant) exp).getValue());

        derivedField = transformer.getDerivedField(new FieldName("spec"));
        assertEquals("spec", derivedField.getName().getValue());
        assertEquals(DataType.STRING, derivedField.getDataType());
        assertEquals(OpType.CATEGORICAL, derivedField.getOpType());
        exp = derivedField.getExpression();
        assertTrue(exp instanceof Constant);
        assertEquals("CAT", ((Constant) exp).getValue());

        derivedField = transformer.getDerivedField(new FieldName("Y3"));
        assertEquals("Y3", derivedField.getName().getValue());
        assertEquals(DataType.DOUBLE, derivedField.getDataType());
        assertEquals(OpType.CONTINUOUS, derivedField.getOpType());
        exp = derivedField.getExpression();
        assertTrue(exp instanceof Apply);
    }

    @Test
    public void getTransformFields() {
        List<DerivedField> transformFields = transformer.getTransformFields();
        assertEquals(4, transformFields.size());
        assertEquals("Y1", transformFields.get(0).getName().getValue());
        assertEquals("spec", transformFields.get(1).getName().getValue());
        assertEquals("Y2", transformFields.get(2).getName().getValue());
        assertEquals("Y3", transformFields.get(3).getName().getValue());
    }

    @Test
    public void getArgumentFields() {
        List<DataField> argumentFields = transformer.getArgumentFields();
        assertEquals(2, argumentFields.size());
        assertEquals("X", argumentFields.get(0).getName().getValue());
        assertEquals("Z", argumentFields.get(1).getName().getValue());
    }

    @Test
    public void evaluate() {
        Map<String, String> requestArguments = new HashMap<>();
        requestArguments.put("X", "2.0");
        requestArguments.put("Z", "2");

        Map<FieldName, FieldValue> arguments = getArgumentsFromRequest(requestArguments);
        Map<FieldName, ?> result = transformer.evaluate(arguments);

        assertEquals(4, result.size());

        assertEquals(1.0, continuousValueOf(result, "Y1"));

        Object spec = result.get(new FieldName("spec"));
        assertTrue(spec instanceof CategoricalValue);
        assertEquals("CAT", ((CategoricalValue) spec).getValue());

        assertEquals(2.0, continuousValueOf(result, "Y2"));
        assertEquals(4.0, continuousValueOf(result, "Y3"));

        // Re-evaluate with a different Z to confirm Y3 follows its input.
        requestArguments.put("Z", "4");
        arguments = getArgumentsFromRequest(requestArguments);
        result = transformer.evaluate(arguments);
        assertEquals(16.0, continuousValueOf(result, "Y3"));
    }

    @Test(expected = FunctionException.class)
    public void evaluateMissingArgument() {
        // Z is deliberately left out of the request.
        Map<String, String> requestArguments = new HashMap<>();
        requestArguments.put("X", "2.0");

        Map<FieldName, FieldValue> arguments = getArgumentsFromRequest(requestArguments);
        transformer.evaluate(arguments);
    }

    /**
     * Looks up a result entry, checks that it is a {@link ContinuousValue} and returns its raw value.
     */
    private static Object continuousValueOf(Map<FieldName, ?> result, String name) {
        Object value = result.get(new FieldName(name));
        assertTrue(name + " should be a ContinuousValue", value instanceof ContinuousValue);
        return ((ContinuousValue) value).getValue();
    }

    /**
     * Builds an arguments map with one entry per argument field of the transformer, creating each
     * value from the matching request string (or from {@code null} when the request has no such key).
     */
    private Map<FieldName, FieldValue> getArgumentsFromRequest(Map<String, String> requestArguments) {
        Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
        List<DataField> argumentFields = transformer.getArgumentFields();
        for (DataField argumentField : argumentFields) {
            FieldName activeName = argumentField.getName();
            String key = activeName.getValue();
            Object value = requestArguments.get(key);

            FieldValue fieldValue = FieldValueUtil
                    .create(argumentField.getDataType(), argumentField.getOpType(), value);

            arguments.put(activeName, fieldValue);
        }
        return arguments;
    }
}
23 changes: 23 additions & 0 deletions pmml-evaluator/src/test/resources/pmml/basicTransformation.pmml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
<PMML xmlns="http://www.dmg.org/PMML-4_0" version="4.0">
    <!-- Test fixture for TransformerTest: two input fields and four derived fields, no model. -->
    <!-- The PMML schema requires a Header as the first child of PMML; an empty one suffices here. -->
    <Header/>
    <DataDictionary>
        <DataField name="X" optype="continuous" dataType="double"/>
        <DataField name="Z" optype="continuous" dataType="double"/>
    </DataDictionary>
    <TransformationDictionary>
        <!-- Y1: the numeric constant 1.0 -->
        <DerivedField name="Y1" dataType="double" optype="continuous">
            <Constant>1.0</Constant>
        </DerivedField>
        <!-- spec: the string constant CAT -->
        <DerivedField name="spec" dataType="string" optype="categorical">
            <Constant>CAT</Constant>
        </DerivedField>
        <!-- Y2: a plain reference to field X -->
        <DerivedField name="Y2" dataType="double" optype="continuous">
            <FieldRef field="X"/>
        </DerivedField>
        <!-- Y3: the built-in "pow" function applied to field Z and the constant 2.0 -->
        <DerivedField name="Y3" dataType="double" optype="continuous">
            <Apply function="pow">
                <FieldRef field="Z"/>
                <Constant>2.0</Constant>
            </Apply>
        </DerivedField>
    </TransformationDictionary>
</PMML>