Skip to content

Commit

Permalink
Add support for vector arithmetic.
Browse files Browse the repository at this point in the history
  • Loading branch information
Mark Hale committed Oct 9, 2024
1 parent f940975 commit c98645d
Show file tree
Hide file tree
Showing 14 changed files with 310 additions and 126 deletions.
7 changes: 4 additions & 3 deletions spin/src/main/java/com/msd/gin/halyard/spin/function/Add.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
import org.eclipse.rdf4j.query.algebra.MathExpr.MathOp;
import org.eclipse.rdf4j.query.algebra.evaluation.ValueExprEvaluationException;
import org.eclipse.rdf4j.query.algebra.evaluation.function.BinaryFunction;
import org.eclipse.rdf4j.query.algebra.evaluation.util.MathUtil;

import com.msd.gin.halyard.strategy.MathOpEvaluator;

public class Add extends BinaryFunction {

Expand All @@ -28,8 +29,8 @@ public String getURI() {

@Override
protected Value evaluate(ValueFactory valueFactory, Value arg1, Value arg2) throws ValueExprEvaluationException {
if (arg1 instanceof Literal && arg2 instanceof Literal) {
return MathUtil.compute((Literal) arg1, (Literal) arg2, MathOp.PLUS);
if (arg1.isLiteral() && arg2.isLiteral()) {
return new MathOpEvaluator().evaluate((Literal) arg1, (Literal) arg2, MathOp.PLUS, valueFactory);
}

throw new ValueExprEvaluationException("Both arguments must be numeric literals");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
import org.eclipse.rdf4j.query.algebra.MathExpr.MathOp;
import org.eclipse.rdf4j.query.algebra.evaluation.ValueExprEvaluationException;
import org.eclipse.rdf4j.query.algebra.evaluation.function.BinaryFunction;
import org.eclipse.rdf4j.query.algebra.evaluation.util.MathUtil;

import com.msd.gin.halyard.strategy.MathOpEvaluator;

public class Divide extends BinaryFunction {

Expand All @@ -28,8 +29,8 @@ public String getURI() {

@Override
protected Value evaluate(ValueFactory valueFactory, Value arg1, Value arg2) throws ValueExprEvaluationException {
if (arg1 instanceof Literal && arg2 instanceof Literal) {
return MathUtil.compute((Literal) arg1, (Literal) arg2, MathOp.DIVIDE);
if (arg1.isLiteral() && arg2.isLiteral()) {
return new MathOpEvaluator().evaluate((Literal) arg1, (Literal) arg2, MathOp.DIVIDE, valueFactory);
}

throw new ValueExprEvaluationException("Both arguments must be numeric literals");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
import org.eclipse.rdf4j.query.algebra.MathExpr.MathOp;
import org.eclipse.rdf4j.query.algebra.evaluation.ValueExprEvaluationException;
import org.eclipse.rdf4j.query.algebra.evaluation.function.BinaryFunction;
import org.eclipse.rdf4j.query.algebra.evaluation.util.MathUtil;

import com.msd.gin.halyard.strategy.MathOpEvaluator;

public class Multiply extends BinaryFunction {

Expand All @@ -28,8 +29,8 @@ public String getURI() {

@Override
protected Value evaluate(ValueFactory valueFactory, Value arg1, Value arg2) throws ValueExprEvaluationException {
if (arg1 instanceof Literal && arg2 instanceof Literal) {
return MathUtil.compute((Literal) arg1, (Literal) arg2, MathOp.MULTIPLY);
if (arg1.isLiteral() && arg2.isLiteral()) {
return new MathOpEvaluator().evaluate((Literal) arg1, (Literal) arg2, MathOp.MULTIPLY, valueFactory);
}

throw new ValueExprEvaluationException("Both arguments must be numeric literals");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
import org.eclipse.rdf4j.query.algebra.MathExpr.MathOp;
import org.eclipse.rdf4j.query.algebra.evaluation.ValueExprEvaluationException;
import org.eclipse.rdf4j.query.algebra.evaluation.function.BinaryFunction;
import org.eclipse.rdf4j.query.algebra.evaluation.util.MathUtil;

import com.msd.gin.halyard.strategy.MathOpEvaluator;

public class Subtract extends BinaryFunction {

Expand All @@ -28,8 +29,8 @@ public String getURI() {

@Override
protected Value evaluate(ValueFactory valueFactory, Value arg1, Value arg2) throws ValueExprEvaluationException {
if (arg1 instanceof Literal && arg2 instanceof Literal) {
return MathUtil.compute((Literal) arg1, (Literal) arg2, MathOp.MINUS);
if (arg1.isLiteral() && arg2.isLiteral()) {
return new MathOpEvaluator().evaluate((Literal) arg1, (Literal) arg2, MathOp.MINUS, valueFactory);
}

throw new ValueExprEvaluationException("Both arguments must be numeric literals");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ public class HalyardEvaluationStrategy implements EvaluationStrategy {
* Evaluates ValueExpr expressions and all implementations of that interface
*/
private final HalyardValueExprEvaluation valueEval;
private final MathOpEvaluator mathOpEval = new MathOpEvaluator();

private final boolean isStrict = false;

Expand Down Expand Up @@ -218,6 +219,10 @@ boolean isStrict() {
return isStrict;
}

MathOpEvaluator getMathOpEvaluator() {
return mathOpEval;
}

StrategyConfig getConfig() {
return config;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1226,7 +1226,7 @@ private BindingSetPipeEvaluationStep precompileGroup(final Group group, final Qu
ValueExpr arg = ((UnaryValueOperator)op).getArg();
QueryValueStepEvaluator evaluator;
if (arg != null) {
evaluator = new QueryValueStepEvaluator(parentStrategy.precompile(arg, evalContext));
evaluator = new QueryValueStepEvaluator(parentStrategy.precompile(arg, evalContext), tripleSource.getValueFactory());
} else {
evaluator = null;
}
Expand Down Expand Up @@ -1461,12 +1461,12 @@ private Supplier<Aggregator<?,?,?>> getAggregatorFactory(AggregateOperator opera
} else if (operator instanceof Sum) {
return () -> {
Predicate<Value> distinct = isDistinct ? createDistinctValues() : (Predicate<Value>) ALWAYS_TRUE;
return ThreadSafeAggregator.create(new SumAggregateFunction(), distinct, new NumberCollector());
return ThreadSafeAggregator.create(new SumAggregateFunction(), distinct, new NumberCollector(parentStrategy.getMathOpEvaluator()));
};
} else if (operator instanceof Avg) {
return () -> {
Predicate<Value> distinct = isDistinct ? createDistinctValues() : (Predicate<Value>) ALWAYS_TRUE;
return ThreadSafeAggregator.create(new AvgAggregateFunction(), distinct, new AvgCollector());
return ThreadSafeAggregator.create(new AvgAggregateFunction(), distinct, new AvgCollector(parentStrategy.getMathOpEvaluator()));
};
} else if (operator instanceof Sample) {
return () -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@
import org.eclipse.rdf4j.query.algebra.evaluation.function.datetime.Now;
import org.eclipse.rdf4j.query.algebra.evaluation.impl.QueryEvaluationContext;
import org.eclipse.rdf4j.query.algebra.evaluation.util.QueryEvaluationUtility;
import org.eclipse.rdf4j.query.algebra.evaluation.util.XMLDatatypeMathUtil;
import org.eclipse.rdf4j.query.impl.EmptyBindingSet;

/**
Expand Down Expand Up @@ -1061,7 +1060,7 @@ private ValuePipeEvaluationStep precompileMathExpr(MathExpr node, QueryEvaluatio
Value leftVal = leftValue.get();
Value rightVal = rightValue.get();
if (leftVal.isLiteral() && rightVal.isLiteral()) {
return ValueOrError.of(() -> XMLDatatypeMathUtil.compute((Literal)leftVal, (Literal)rightVal, node.getOperator(), valueFactory));
return ValueOrError.of(() -> parentStrategy.getMathOpEvaluator().evaluate((Literal)leftVal, (Literal)rightVal, node.getOperator(), valueFactory));
} else {
return ValueOrError.fail("Both arguments must be numeric literals");
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
package com.msd.gin.halyard.strategy;

import com.msd.gin.halyard.model.ArrayLiteral;
import com.msd.gin.halyard.model.vocabulary.HALYARD;

import javax.annotation.concurrent.ThreadSafe;

import org.eclipse.rdf4j.model.IRI;
import org.eclipse.rdf4j.model.Literal;
import org.eclipse.rdf4j.model.ValueFactory;
import org.eclipse.rdf4j.model.base.CoreDatatype;
import org.eclipse.rdf4j.query.algebra.MathExpr.MathOp;
import org.eclipse.rdf4j.query.algebra.evaluation.ValueExprEvaluationException;
import org.eclipse.rdf4j.query.algebra.evaluation.util.XMLDatatypeMathUtil;

@ThreadSafe
public class MathOpEvaluator {
public Literal evaluate(Literal a, Literal b, MathOp op, ValueFactory vf) {
try {
return XMLDatatypeMathUtil.compute(a, b, op, vf);
} catch (ValueExprEvaluationException ex) {
IRI adt = a.getDatatype();
IRI bdt = b.getDatatype();
boolean aisvec = HALYARD.ARRAY_TYPE.equals(adt);
boolean bisvec = HALYARD.ARRAY_TYPE.equals(bdt);
if (aisvec) {
if (bisvec) {
return operationBetweenVectors(a, b, op, vf);
} else if (op == MathOp.DIVIDE) {
CoreDatatype.XSD bcdt = b.getCoreDatatype().asXSDDatatypeOrNull();
if (bcdt != null && bcdt.isNumericDatatype()) {
return operationVectorDivideScalar(a, b, op, vf);
}
}
} else if (bisvec) {
if (aisvec) {
return operationBetweenVectors(a, b, op, vf);
} else if (op == MathOp.MULTIPLY) {
CoreDatatype.XSD acdt = a.getCoreDatatype().asXSDDatatypeOrNull();
if (acdt != null && acdt.isNumericDatatype()) {
return operationScalarMultiplyVector(a, b, op, vf);
}
}
}
throw ex;
}
}

private static Literal operationBetweenVectors(Literal a, Literal b, MathOp op, ValueFactory vf) {
Object[] aarr = ArrayLiteral.objectArray(a);
Object[] barr = ArrayLiteral.objectArray(b);
if (aarr.length != barr.length) {
throw new ValueExprEvaluationException("Arrays have incompatible dimensions");
}
try {
switch (op) {
case PLUS:
return new ArrayLiteral(add(aarr, barr));
case MINUS:
return new ArrayLiteral(subtract(aarr, barr));
default:
throw new AssertionError("Unsupported operator: " + op);
}
} catch (ClassCastException ex) {
throw new ValueExprEvaluationException(ex);
}
}

private static Object[] add(Object[] a, Object[] b) {
Object[] y = new Object[a.length];
for (int i=0; i<a.length; i++) {
if (a[i] instanceof Double || b[i] instanceof Double) {
y[i] = ((Number) a[i]).doubleValue() + ((Number) b[i]).doubleValue();
} else if (a[i] instanceof Float || b[i] instanceof Float) {
y[i] = ((Number) a[i]).floatValue() + ((Number) b[i]).floatValue();
} else if (a[i] instanceof Long || b[i] instanceof Long) {
y[i] = ((Number) a[i]).longValue() + ((Number) b[i]).longValue();
} else {
y[i] = ((Number) a[i]).intValue() + ((Number) b[i]).intValue();
}
}
return y;
}

private static Object[] subtract(Object[] a, Object[] b) {
Object[] y = new Object[a.length];
for (int i=0; i<a.length; i++) {
if (a[i] instanceof Double || b[i] instanceof Double) {
y[i] = ((Number) a[i]).doubleValue() - ((Number) b[i]).doubleValue();
} else if (a[i] instanceof Float || b[i] instanceof Float) {
y[i] = ((Number) a[i]).floatValue() - ((Number) b[i]).floatValue();
} else if (a[i] instanceof Long || b[i] instanceof Long) {
y[i] = ((Number) a[i]).longValue() - ((Number) b[i]).longValue();
} else {
y[i] = ((Number) a[i]).intValue() - ((Number) b[i]).intValue();
}
}
return y;
}

private static Literal operationScalarMultiplyVector(Literal scalar, Literal vec, MathOp op, ValueFactory vf) {
CoreDatatype.XSD sdt = scalar.getCoreDatatype().asXSDDatatype().get();
Object[] arr = ArrayLiteral.objectArray(vec);
Object[] y = new Object[arr.length];
try {
for (int i=0; i<arr.length; i++) {
if (sdt == CoreDatatype.XSD.DOUBLE || sdt == CoreDatatype.XSD.DECIMAL || arr[i] instanceof Double) {
y[i] = scalar.doubleValue() * ((Number) arr[i]).doubleValue();
} else if (sdt == CoreDatatype.XSD.FLOAT || arr[i] instanceof Float) {
y[i] = scalar.floatValue() * ((Number) arr[i]).floatValue();
} else if (sdt == CoreDatatype.XSD.LONG || sdt == CoreDatatype.XSD.INTEGER || arr[i] instanceof Long) {
y[i] = scalar.longValue() * ((Number) arr[i]).longValue();
} else {
y[i] = scalar.intValue() * ((Number) arr[i]).intValue();
}
}
} catch (ClassCastException ex) {
throw new ValueExprEvaluationException(ex);
}
return new ArrayLiteral(y);
}

private static Literal operationVectorDivideScalar(Literal vec, Literal scalar, MathOp op, ValueFactory vf) {
CoreDatatype.XSD sdt = scalar.getCoreDatatype().asXSDDatatype().get();
Object[] arr = ArrayLiteral.objectArray(vec);
Object[] y = new Object[arr.length];
try {
for (int i=0; i<arr.length; i++) {
if (sdt == CoreDatatype.XSD.DOUBLE || sdt == CoreDatatype.XSD.DECIMAL || arr[i] instanceof Double) {
y[i] = ((Number) arr[i]).doubleValue() / scalar.doubleValue();
} else if (sdt == CoreDatatype.XSD.FLOAT || arr[i] instanceof Float) {
y[i] = ((Number) arr[i]).floatValue() / scalar.floatValue();
} else if (sdt == CoreDatatype.XSD.LONG || sdt == CoreDatatype.XSD.INTEGER || arr[i] instanceof Long) {
y[i] = ((Number) arr[i]).doubleValue() / scalar.doubleValue();
} else {
y[i] = ((Number) arr[i]).floatValue() / scalar.floatValue();
}
}
} catch (ClassCastException ex) {
throw new ValueExprEvaluationException(ex);
}
return new ArrayLiteral(y);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,18 @@
import java.util.function.Function;

import org.eclipse.rdf4j.model.Value;
import org.eclipse.rdf4j.model.ValueFactory;
import org.eclipse.rdf4j.query.BindingSet;
import org.eclipse.rdf4j.query.algebra.evaluation.QueryValueEvaluationStep;
import org.eclipse.rdf4j.query.algebra.evaluation.ValueExprEvaluationException;

public final class QueryValueStepEvaluator implements Function<BindingSet, Value> {
private final QueryValueEvaluationStep step;
private final ValueFactory vf;

QueryValueStepEvaluator(QueryValueEvaluationStep step) {
QueryValueStepEvaluator(QueryValueEvaluationStep step, ValueFactory vf) {
this.step = step;
this.vf = vf;
}

@Override
Expand All @@ -22,4 +25,8 @@ public Value apply(BindingSet bs) {
return null; // treat missing or invalid expressions as null
}
}

public ValueFactory getValueFactory() {
return vf;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import org.eclipse.rdf4j.model.Literal;
import org.eclipse.rdf4j.model.Value;
import org.eclipse.rdf4j.model.base.CoreDatatype;
import org.eclipse.rdf4j.query.BindingSet;
import org.eclipse.rdf4j.query.algebra.evaluation.ValueExprEvaluationException;

Expand All @@ -25,20 +24,18 @@ public void processAggregate(BindingSet bs, Predicate<Value> distinctPredicate,
if (v.isLiteral()) {
if (distinctPredicate.test(v)) {
Literal nextLiteral = (Literal) v;
// check if the literal is numeric.
CoreDatatype coreDatatype = nextLiteral.getCoreDatatype();
if (coreDatatype.isXSDDatatype() && ((CoreDatatype.XSD) coreDatatype).isNumericDatatype()) {
col.addValue(nextLiteral);
} else {
col.setError(new ValueExprEvaluationException("not a number: " + v));
try {
col.addValue(nextLiteral, evaluationStep.getValueFactory());
col.incrementCount();
} catch (ValueExprEvaluationException ex) {
col.setError(ex);
}
col.incrementCount();
}
} else {
// we do not actually throw the exception yet, but record it and
// stop further processing. The exception will be thrown when
// getValue() is invoked.
col.setError(new ValueExprEvaluationException("not a number: " + v));
col.setError(new ValueExprEvaluationException("not a literal: " + v));
}
}
}
Expand Down
Loading

0 comments on commit c98645d

Please sign in to comment.