diff --git a/model/src/main/java/com/msd/gin/halyard/model/AbstractArrayLiteral.java b/model/src/main/java/com/msd/gin/halyard/model/AbstractArrayLiteral.java new file mode 100644 index 000000000..6166ad28b --- /dev/null +++ b/model/src/main/java/com/msd/gin/halyard/model/AbstractArrayLiteral.java @@ -0,0 +1,135 @@ +package com.msd.gin.halyard.model; + +import java.util.Arrays; + +import org.eclipse.rdf4j.model.IRI; +import org.eclipse.rdf4j.model.Literal; +import org.eclipse.rdf4j.model.Value; +import org.eclipse.rdf4j.model.ValueFactory; +import org.eclipse.rdf4j.model.base.CoreDatatype; +import org.eclipse.rdf4j.model.base.CoreDatatype.XSD; +import org.eclipse.rdf4j.model.util.Values; +import org.eclipse.rdf4j.query.algebra.evaluation.ValueExprEvaluationException; + +import com.msd.gin.halyard.model.vocabulary.HALYARD; + +public abstract class AbstractArrayLiteral extends AbstractDataLiteral implements ObjectLiteral { + private static final long serialVersionUID = -6423024672894102212L; + + public static boolean isArrayLiteral(Value v) { + return v != null && v.isLiteral() && HALYARD.ARRAY_TYPE.equals(((Literal)v).getDatatype()); + } + + public static Value[] toValues(Object[] oarr, ValueFactory vf) { + Value[] varr = new Value[oarr.length]; + for (int i=0; i createFromValues(Value[] values) { + AbstractArrayLiteral arrLiteral = null; + if (values.length > 0) { + Literal l = asLiteral(values[0]); + if (l.getCoreDatatype().asXSDDatatypeOrNull() == XSD.FLOAT) { + float[] farr = new float[values.length]; + farr[0] = l.floatValue(); + for (int i=1; i other = (AbstractArrayLiteral) o; + return Arrays.equals(elements(), other.elements()); + } else { + return super.equals(o); + } + } +} diff --git a/model/src/main/java/com/msd/gin/halyard/model/ArrayLiteral.java b/model/src/main/java/com/msd/gin/halyard/model/ArrayLiteral.java deleted file mode 100644 index 3c421e661..000000000 --- a/model/src/main/java/com/msd/gin/halyard/model/ArrayLiteral.java +++ /dev/null @@ -1,134 +0,0 @@ -package com.msd.gin.halyard.model; - -import java.util.Arrays; - -import org.eclipse.rdf4j.model.IRI; -import org.eclipse.rdf4j.model.Literal; -import org.eclipse.rdf4j.model.Value; -import org.eclipse.rdf4j.model.ValueFactory; -import org.eclipse.rdf4j.model.base.CoreDatatype; -import org.eclipse.rdf4j.model.base.CoreDatatype.XSD; -import org.eclipse.rdf4j.model.util.Values; -import org.eclipse.rdf4j.query.algebra.evaluation.ValueExprEvaluationException; -import org.json.JSONArray; - -import com.msd.gin.halyard.model.vocabulary.HALYARD; - -public final class ArrayLiteral extends AbstractDataLiteral implements ObjectLiteral { - private static final long serialVersionUID = -6399155325720068478L; - - public static boolean isArrayLiteral(Value v) { - return v != null && v.isLiteral() && HALYARD.ARRAY_TYPE.equals(((Literal)v).getDatatype()); - } - - public static Object[] objectArray(Literal l) { - if (l instanceof ArrayLiteral) { - return ((ArrayLiteral)l).values; - } else { - return parse(l.getLabel()); - } - } - - public static Value[] toValues(Object[] oarr, ValueFactory vf) { - Value[] varr = new Value[oarr.length]; - for (int i=0; i { + private static final long serialVersionUID = -5786623818445151749L; + + private final float[] values; + + public FloatArrayLiteral(float... values) { + this.values = values; + } + + @Override + public String getLabel() { + JSONArray arr = new JSONArray(); + for (float o : this.values) { + arr.put(o); + } + return arr.toString(0); + } + + @Override + public float[] objectValue() { + return values; + } + + @Override + public Object[] elements() { + Object[] arr = new Object[values.length]; + for (int i=0; i { + private static final long serialVersionUID = 6409948385114916596L; + + private final Object[] values; + + public ObjectArrayLiteral(String s) { + this.values = parse(s); + } + + public ObjectArrayLiteral(Object... values) { + this.values = values; + } + + @Override + public String getLabel() { + JSONArray arr = new JSONArray(); + for (Object o : this.values) { + arr.put(o); + } + return arr.toString(0); + } + + @Override + public Object[] objectValue() { + return values; + } + + @Override + public Object[] elements() { + return values; + } + + @Override + public int length() { + return values.length; + } + + public static Object[] objectArray(Literal l) { + if (l instanceof AbstractArrayLiteral) { + return ((AbstractArrayLiteral)l).elements(); + } else { + return parse(l.getLabel()); + } + } + + private static Object[] parse(CharSequence s) { + JSONArray arr = new JSONArray(s.toString()); + int len = arr.length(); + Object[] values = new Object[len]; + for (int i=0; i) { return new MapLiteral((Map)v); } else { diff --git a/queryalgebra/src/test/java/com/msd/gin/halyard/query/algebra/evaluation/function/ArrayTest.java b/queryalgebra/src/test/java/com/msd/gin/halyard/query/algebra/evaluation/function/ArrayTest.java index 98fcb0e55..b84dda480 100644 --- a/queryalgebra/src/test/java/com/msd/gin/halyard/query/algebra/evaluation/function/ArrayTest.java +++ b/queryalgebra/src/test/java/com/msd/gin/halyard/query/algebra/evaluation/function/ArrayTest.java @@ -3,13 +3,15 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; +import org.eclipse.rdf4j.model.Literal; import org.eclipse.rdf4j.model.Value; import org.eclipse.rdf4j.model.ValueFactory; import org.eclipse.rdf4j.query.algebra.evaluation.TripleSource; import org.eclipse.rdf4j.query.algebra.evaluation.ValueExprEvaluationException; import org.junit.jupiter.api.Test; -import com.msd.gin.halyard.model.ArrayLiteral; +import com.msd.gin.halyard.model.FloatArrayLiteral; +import com.msd.gin.halyard.model.ObjectArrayLiteral; import com.msd.gin.halyard.query.algebra.evaluation.EmptyTripleSource; public class ArrayTest { @@ -19,9 +21,34 @@ public void test() { ValueFactory vf = ts.getValueFactory(); Value v1 = vf.createLiteral("foobar"); Value v2 = vf.createLiteral(5); - ArrayLiteral l = (ArrayLiteral) new Array().evaluate(ts, v1, v2); - assertEquals("foobar", l.objectValue()[0]); - assertEquals(5, l.objectValue()[1]); + Object[] objs = ObjectArrayLiteral.objectArray((Literal) new Array().evaluate(ts, v1, v2)); + assertEquals("foobar", objs[0]); + assertEquals(5, objs[1]); + } + + @Test + public void test_float() { + TripleSource ts = new EmptyTripleSource(); + ValueFactory vf = ts.getValueFactory(); + Value v1 = vf.createLiteral(2.5f); + Value v2 = vf.createLiteral(5f); + Object[] objs = ObjectArrayLiteral.objectArray((Literal) new Array().evaluate(ts, v1, v2)); + assertEquals(2.5f, objs[0]); + assertEquals(5f, objs[1]); + float[] farr = FloatArrayLiteral.floatArray((Literal) new Array().evaluate(ts, v1, v2)); + assertEquals(2.5f, farr[0]); + assertEquals(5f, farr[1]); + } + + @Test + public void test_mixed() { + TripleSource ts = new EmptyTripleSource(); + ValueFactory vf = ts.getValueFactory(); + Value v1 = vf.createLiteral(-0.15f); + Value v2 = vf.createLiteral("foobar"); + Object[] objs = ObjectArrayLiteral.objectArray((Literal) new Array().evaluate(ts, v1, v2)); + assertEquals(-0.15f, objs[0]); + assertEquals("foobar", objs[1]); } @Test diff --git a/queryalgebra/src/test/java/com/msd/gin/halyard/query/algebra/evaluation/function/CosineSimilarityTest.java b/queryalgebra/src/test/java/com/msd/gin/halyard/query/algebra/evaluation/function/CosineSimilarityTest.java index bb7de8591..7bcebd1b6 100644 --- a/queryalgebra/src/test/java/com/msd/gin/halyard/query/algebra/evaluation/function/CosineSimilarityTest.java +++ b/queryalgebra/src/test/java/com/msd/gin/halyard/query/algebra/evaluation/function/CosineSimilarityTest.java @@ -7,15 +7,25 @@ import org.eclipse.rdf4j.query.algebra.evaluation.TripleSource; import org.junit.jupiter.api.Test; -import com.msd.gin.halyard.model.ArrayLiteral; +import com.msd.gin.halyard.model.FloatArrayLiteral; +import com.msd.gin.halyard.model.ObjectArrayLiteral; import com.msd.gin.halyard.query.algebra.evaluation.EmptyTripleSource; public class CosineSimilarityTest { @Test public void test() { TripleSource ts = new EmptyTripleSource(); - Value v1 = new ArrayLiteral(1, 0); - Value v2 = new ArrayLiteral(1, 1); + Value v1 = new FloatArrayLiteral(1.0f, 0.0f); + Value v2 = new FloatArrayLiteral(1.0f, 1.0f); + Literal l = (Literal) new CosineSimilarity().evaluate(ts, v1, v2); + assertEquals(1.0/Math.sqrt(2), l.doubleValue()); + } + + @Test + public void test_objectArray() { + TripleSource ts = new EmptyTripleSource(); + Value v1 = new ObjectArrayLiteral(1.0f, 1.0f); + Value v2 = new ObjectArrayLiteral(1.0f, 0.0f); Literal l = (Literal) new CosineSimilarity().evaluate(ts, v1, v2); assertEquals(1.0/Math.sqrt(2), l.doubleValue()); } diff --git a/queryalgebra/src/test/java/com/msd/gin/halyard/query/algebra/evaluation/function/DynamicFunctionRegistryTest.java b/queryalgebra/src/test/java/com/msd/gin/halyard/query/algebra/evaluation/function/DynamicFunctionRegistryTest.java index 3a3e96e0b..0a2ced5d9 100644 --- a/queryalgebra/src/test/java/com/msd/gin/halyard/query/algebra/evaluation/function/DynamicFunctionRegistryTest.java +++ b/queryalgebra/src/test/java/com/msd/gin/halyard/query/algebra/evaluation/function/DynamicFunctionRegistryTest.java @@ -13,8 +13,8 @@ import org.eclipse.rdf4j.query.algebra.evaluation.TripleSource; import org.junit.jupiter.api.Test; -import com.msd.gin.halyard.model.ArrayLiteral; import com.msd.gin.halyard.model.MapLiteral; +import com.msd.gin.halyard.model.ObjectArrayLiteral; import com.msd.gin.halyard.query.algebra.evaluation.EmptyTripleSource; public class DynamicFunctionRegistryTest { @@ -54,8 +54,8 @@ public void testIsWholeNumber() { public void testArray_put() { TripleSource ts = new EmptyTripleSource(); ValueFactory vf = ts.getValueFactory(); - ArrayLiteral result = (ArrayLiteral) new DynamicFunctionRegistry().get("http://www.w3.org/2005/xpath-functions/array#put").get().evaluate(ts, - new ArrayLiteral("foo", "bar"), vf.createLiteral(2), vf.createLiteral(5)); + ObjectArrayLiteral result = (ObjectArrayLiteral) new DynamicFunctionRegistry().get("http://www.w3.org/2005/xpath-functions/array#put").get().evaluate(ts, + new ObjectArrayLiteral("foo", "bar"), vf.createLiteral(2), vf.createLiteral(5)); assertEquals(2, result.objectValue().length); // NB: xsd:ints get coerced to xsd:integers assertEquals(BigInteger.valueOf(5), result.objectValue()[1]); @@ -65,7 +65,7 @@ public void testArray_put() { public void testArray_size() { TripleSource ts = new EmptyTripleSource(); ValueFactory vf = ts.getValueFactory(); - ArrayLiteral array = new ArrayLiteral("foo", "bar"); + ObjectArrayLiteral array = new ObjectArrayLiteral("foo", "bar"); // check works for non-specialist Literal type Literal unparsedArray = vf.createLiteral(array.getLabel(), array.getDatatype()); Literal result = (Literal) new DynamicFunctionRegistry().get("http://www.w3.org/2005/xpath-functions/array#size").get().evaluate(ts, diff --git a/sail/src/main/java/com/msd/gin/halyard/sail/model/embedding/function/VectorEmbedding.java b/sail/src/main/java/com/msd/gin/halyard/sail/model/embedding/function/VectorEmbedding.java index 5f3f82d7b..c5bce8c0d 100644 --- a/sail/src/main/java/com/msd/gin/halyard/sail/model/embedding/function/VectorEmbedding.java +++ b/sail/src/main/java/com/msd/gin/halyard/sail/model/embedding/function/VectorEmbedding.java @@ -1,6 +1,6 @@ package com.msd.gin.halyard.sail.model.embedding.function; -import com.msd.gin.halyard.model.ArrayLiteral; +import com.msd.gin.halyard.model.FloatArrayLiteral; import com.msd.gin.halyard.model.vocabulary.HALYARD; import com.msd.gin.halyard.query.algebra.evaluation.ExtendedTripleSource; @@ -44,10 +44,6 @@ public Value evaluate(TripleSource ts, Value... args) throws ValueExprEvaluation ExtendedTripleSource extTs = (ExtendedTripleSource) ts; Response resp = extTs.getQueryHelper(EmbeddingModel.class).embed(l.getLabel()); float[] vec = resp.content().vector(); - Object[] arr = new Object[vec.length]; - for (int i = 0; i < vec.length; i++) { - arr[i] = Float.valueOf(vec[i]); - } - return new ArrayLiteral(arr); + return new FloatArrayLiteral(vec); } } diff --git a/sail/src/main/java/com/msd/gin/halyard/sail/search/KNNTupleFunction.java b/sail/src/main/java/com/msd/gin/halyard/sail/search/KNNTupleFunction.java index bbc280c95..c51d9c07f 100644 --- a/sail/src/main/java/com/msd/gin/halyard/sail/search/KNNTupleFunction.java +++ b/sail/src/main/java/com/msd/gin/halyard/sail/search/KNNTupleFunction.java @@ -2,7 +2,7 @@ import com.msd.gin.halyard.common.RDFFactory; import com.msd.gin.halyard.common.StatementIndices; -import com.msd.gin.halyard.model.ArrayLiteral; +import com.msd.gin.halyard.model.ObjectArrayLiteral; import com.msd.gin.halyard.model.ObjectLiteral; import com.msd.gin.halyard.model.vocabulary.HALYARD; import com.msd.gin.halyard.query.algebra.evaluation.ExtendedTripleSource; @@ -46,7 +46,7 @@ public CloseableIteration> evaluate(TripleSource throw new QueryEvaluationException("Invalid query value"); } int argPos = 0; - Object[] query = ArrayLiteral.objectArray((Literal) args[argPos++]); + Object[] query = ObjectArrayLiteral.objectArray((Literal) args[argPos++]); Float[] vec = new Float[query.length]; for (int i = 0; i < query.length; i++) { vec[i] = ((Number) query[i]).floatValue(); diff --git a/sail/src/main/java/com/msd/gin/halyard/sail/search/SearchTupleFunction.java b/sail/src/main/java/com/msd/gin/halyard/sail/search/SearchTupleFunction.java index 042241c1f..73db3d781 100644 --- a/sail/src/main/java/com/msd/gin/halyard/sail/search/SearchTupleFunction.java +++ b/sail/src/main/java/com/msd/gin/halyard/sail/search/SearchTupleFunction.java @@ -3,7 +3,8 @@ import com.google.common.collect.Lists; import com.msd.gin.halyard.common.RDFFactory; import com.msd.gin.halyard.common.StatementIndices; -import com.msd.gin.halyard.model.ArrayLiteral; +import com.msd.gin.halyard.model.FloatArrayLiteral; +import com.msd.gin.halyard.model.ObjectArrayLiteral; import com.msd.gin.halyard.model.ObjectLiteral; import com.msd.gin.halyard.model.vocabulary.HALYARD; import com.msd.gin.halyard.query.algebra.evaluation.ExtendedTripleSource; @@ -123,15 +124,11 @@ protected List convert(List> matchValues) t for (MatchParams.FieldParams fieldParams : matchParams.fields) { Literal l; if (SearchDocument.VECTOR_FIELD.equals(fieldParams.name)) { - Object[] arr = new Object[doc.vector.length]; - for (int k = 0; k < doc.vector.length; k++) { - arr[k] = doc.vector[k]; - } - l = new ArrayLiteral(arr); + l = new FloatArrayLiteral(doc.vector); } else { Object v = doc.getAdditionalField(fieldParams.name); if (v instanceof List) { - l = new ArrayLiteral(((List) v).toArray()); + l = new ObjectArrayLiteral(((List) v).toArray()); } else if (v != null) { l = Values.literal(valueFactory, v, false); } else { diff --git a/sail/src/main/java/com/msd/gin/halyard/sail/search/function/EscapeTerm.java b/sail/src/main/java/com/msd/gin/halyard/sail/search/function/EscapeTerm.java index bdef5e288..df281de8a 100644 --- a/sail/src/main/java/com/msd/gin/halyard/sail/search/function/EscapeTerm.java +++ b/sail/src/main/java/com/msd/gin/halyard/sail/search/function/EscapeTerm.java @@ -1,6 +1,6 @@ package com.msd.gin.halyard.sail.search.function; -import com.msd.gin.halyard.model.ArrayLiteral; +import com.msd.gin.halyard.model.ObjectArrayLiteral; import com.msd.gin.halyard.model.vocabulary.HALYARD; import java.util.Locale; @@ -36,7 +36,7 @@ public Value evaluate(ValueFactory valueFactory, Value... args) throws ValueExpr } Literal l = (Literal) args[0]; if (HALYARD.ARRAY_TYPE.equals(l.getDatatype())) { - Object[] entries = ArrayLiteral.objectArray(l); + Object[] entries = ObjectArrayLiteral.objectArray(l); Object[] escaped = new Object[entries.length]; for (int i = 0; i < entries.length; i++) { Object o = entries[i]; @@ -45,7 +45,7 @@ public Value evaluate(ValueFactory valueFactory, Value... args) throws ValueExpr } escaped[i] = o; } - return new ArrayLiteral(escaped); + return new ObjectArrayLiteral(escaped); } else { String s = l.getLabel(); return valueFactory.createLiteral(escape(s)); diff --git a/sail/src/main/java/com/msd/gin/halyard/sail/search/function/GroupTerms.java b/sail/src/main/java/com/msd/gin/halyard/sail/search/function/GroupTerms.java index 1a4d48480..7d7155f69 100644 --- a/sail/src/main/java/com/msd/gin/halyard/sail/search/function/GroupTerms.java +++ b/sail/src/main/java/com/msd/gin/halyard/sail/search/function/GroupTerms.java @@ -1,6 +1,6 @@ package com.msd.gin.halyard.sail.search.function; -import com.msd.gin.halyard.model.ArrayLiteral; +import com.msd.gin.halyard.model.ObjectArrayLiteral; import com.msd.gin.halyard.model.vocabulary.HALYARD; import org.eclipse.rdf4j.model.Literal; @@ -48,7 +48,7 @@ public Value evaluate(ValueFactory valueFactory, Value... args) throws ValueExpr } Literal l = (Literal) arg; if (HALYARD.ARRAY_TYPE.equals(l.getDatatype())) { - Object[] entries = ArrayLiteral.objectArray(l); + Object[] entries = ObjectArrayLiteral.objectArray(l); for (Object entry : entries) { String s = entry.toString(); if (!s.isEmpty()) { diff --git a/sail/src/test/java/com/msd/gin/halyard/sail/search/function/EscapeTermTest.java b/sail/src/test/java/com/msd/gin/halyard/sail/search/function/EscapeTermTest.java index 83b7b4663..a0d7dc2e9 100644 --- a/sail/src/test/java/com/msd/gin/halyard/sail/search/function/EscapeTermTest.java +++ b/sail/src/test/java/com/msd/gin/halyard/sail/search/function/EscapeTermTest.java @@ -1,6 +1,6 @@ package com.msd.gin.halyard.sail.search.function; -import com.msd.gin.halyard.model.ArrayLiteral; +import com.msd.gin.halyard.model.ObjectArrayLiteral; import org.eclipse.rdf4j.model.Literal; import org.eclipse.rdf4j.model.ValueFactory; @@ -62,7 +62,7 @@ public void testEscapeUrl() { @Test public void testEscapeList() { - Literal escaped = (Literal) new EscapeTerm().evaluate(vf, new ArrayLiteral(":", "/")); - assertEquals(new ArrayLiteral("\\:", "\\/"), escaped); + Literal escaped = (Literal) new EscapeTerm().evaluate(vf, new ObjectArrayLiteral(":", "/")); + assertEquals(new ObjectArrayLiteral("\\:", "\\/"), escaped); } } diff --git a/sail/src/test/java/com/msd/gin/halyard/sail/search/function/GroupTermsTest.java b/sail/src/test/java/com/msd/gin/halyard/sail/search/function/GroupTermsTest.java index 449298f7d..ded0e4df9 100644 --- a/sail/src/test/java/com/msd/gin/halyard/sail/search/function/GroupTermsTest.java +++ b/sail/src/test/java/com/msd/gin/halyard/sail/search/function/GroupTermsTest.java @@ -1,6 +1,6 @@ package com.msd.gin.halyard.sail.search.function; -import com.msd.gin.halyard.model.ArrayLiteral; +import com.msd.gin.halyard.model.ObjectArrayLiteral; import org.eclipse.rdf4j.model.Literal; import org.eclipse.rdf4j.model.ValueFactory; @@ -32,7 +32,7 @@ public void testEmpty() { @Test public void testArray() { - Literal group = (Literal) new GroupTerms().evaluate(vf, vf.createLiteral("AND"), new ArrayLiteral("foo", "bar")); + Literal group = (Literal) new GroupTerms().evaluate(vf, vf.createLiteral("AND"), new ObjectArrayLiteral("foo", "bar")); assertEquals("(foo AND bar)", group.stringValue()); } } diff --git a/spin/src/main/java/com/msd/gin/halyard/spin/function/spif/ForEach.java b/spin/src/main/java/com/msd/gin/halyard/spin/function/spif/ForEach.java index 6df692bc1..84ab05687 100644 --- a/spin/src/main/java/com/msd/gin/halyard/spin/function/spif/ForEach.java +++ b/spin/src/main/java/com/msd/gin/halyard/spin/function/spif/ForEach.java @@ -24,7 +24,8 @@ import org.eclipse.rdf4j.query.QueryEvaluationException; import com.google.common.collect.Iterators; -import com.msd.gin.halyard.model.ArrayLiteral; +import com.msd.gin.halyard.model.AbstractArrayLiteral; +import com.msd.gin.halyard.model.ObjectArrayLiteral; import com.msd.gin.halyard.model.TupleLiteral; import com.msd.gin.halyard.spin.function.InverseMagicProperty; @@ -43,8 +44,8 @@ public CloseableIteration> evaluate( Arrays.stream(args).flatMap(v -> { if (TupleLiteral.isTupleLiteral(v)) { return Arrays.stream(TupleLiteral.valueArray((Literal)v, valueFactory)); - } else if (ArrayLiteral.isArrayLiteral(v)) { - return Arrays.stream(ArrayLiteral.toValues(ArrayLiteral.objectArray((Literal)v), valueFactory)); + } else if (AbstractArrayLiteral.isArrayLiteral(v)) { + return Arrays.stream(AbstractArrayLiteral.toValues(ObjectArrayLiteral.objectArray((Literal)v), valueFactory)); } else { return Stream.of(v); } diff --git a/spin/src/test/java/com/msd/gin/halyard/spin/SailSpifTest.java b/spin/src/test/java/com/msd/gin/halyard/spin/SailSpifTest.java index f2bd27b3c..b1fb80dc5 100644 --- a/spin/src/test/java/com/msd/gin/halyard/spin/SailSpifTest.java +++ b/spin/src/test/java/com/msd/gin/halyard/spin/SailSpifTest.java @@ -25,7 +25,7 @@ import org.junit.Before; import org.junit.Test; -import com.msd.gin.halyard.model.ArrayLiteral; +import com.msd.gin.halyard.model.ObjectArrayLiteral; import com.msd.gin.halyard.model.TupleLiteral; /** @@ -177,7 +177,7 @@ public void testForEachTupleLiteral() throws Exception { @Test public void testForEachArrayLiteral() throws Exception { - String al = NTriplesUtil.toNTriplesString(new ArrayLiteral(2, 3)); + String al = NTriplesUtil.toNTriplesString(new ObjectArrayLiteral(2, 3)); TupleQuery tq = conn.prepareTupleQuery(QueryLanguage.SPARQL, "prefix spif: prefix halyard: " + "select ?x where {?x spif:foreach (1 "+ al + " 4)}"); try (TupleQueryResult tqr = tq.evaluate()) { for (int i = 1; i <= 4; i++) { diff --git a/strategy/src/main/java/com/msd/gin/halyard/strategy/MathOpEvaluator.java b/strategy/src/main/java/com/msd/gin/halyard/strategy/MathOpEvaluator.java index 32d51365f..b0c0eab99 100644 --- a/strategy/src/main/java/com/msd/gin/halyard/strategy/MathOpEvaluator.java +++ b/strategy/src/main/java/com/msd/gin/halyard/strategy/MathOpEvaluator.java @@ -1,6 +1,7 @@ package com.msd.gin.halyard.strategy; -import com.msd.gin.halyard.model.ArrayLiteral; +import com.msd.gin.halyard.model.FloatArrayLiteral; +import com.msd.gin.halyard.model.ObjectArrayLiteral; import com.msd.gin.halyard.model.vocabulary.HALYARD; import javax.annotation.concurrent.ThreadSafe; @@ -47,23 +48,55 @@ public Literal evaluate(Literal a, Literal b, MathOp op, ValueFactory vf) { } private static Literal operationBetweenVectors(Literal a, Literal b, MathOp op, ValueFactory vf) { - Object[] aarr = ArrayLiteral.objectArray(a); - Object[] barr = ArrayLiteral.objectArray(b); - if (aarr.length != barr.length) { - throw new ValueExprEvaluationException("Arrays have incompatible dimensions"); - } - try { + if ((a instanceof FloatArrayLiteral) && (b instanceof FloatArrayLiteral)) { + float[] aarr = ((FloatArrayLiteral) a).objectValue(); + float[] barr = ((FloatArrayLiteral) b).objectValue(); + if (aarr.length != barr.length) { + throw new ValueExprEvaluationException("Arrays have incompatible dimensions"); + } switch (op) { case PLUS: - return new ArrayLiteral(add(aarr, barr)); + return new FloatArrayLiteral(add(aarr, barr)); case MINUS: - return new ArrayLiteral(subtract(aarr, barr)); + return new FloatArrayLiteral(subtract(aarr, barr)); default: throw new AssertionError("Unsupported operator: " + op); } - } catch (ClassCastException ex) { - throw new ValueExprEvaluationException(ex); + } else { + Object[] aarr = ObjectArrayLiteral.objectArray(a); + Object[] barr = ObjectArrayLiteral.objectArray(b); + if (aarr.length != barr.length) { + throw new ValueExprEvaluationException("Arrays have incompatible dimensions"); + } + try { + switch (op) { + case PLUS: + return new ObjectArrayLiteral(add(aarr, barr)); + case MINUS: + return new ObjectArrayLiteral(subtract(aarr, barr)); + default: + throw new AssertionError("Unsupported operator: " + op); + } + } catch (ClassCastException ex) { + throw new ValueExprEvaluationException(ex); + } + } + } + + private static float[] add(float[] a, float[] b) { + float[] y = new float[a.length]; + for (int i=0; i { // NB: values.size() is expensive Value[] varr = values.toArray(new Value[0]); - return (varr.length > 0) ? ArrayLiteral.createFromValues(varr) : null; + return (varr.length > 0) ? AbstractArrayLiteral.createFromValues(varr) : null; }); } } diff --git a/strategy/src/test/java/com/msd/gin/halyard/strategy/HalyardStrategyExtendedTest.java b/strategy/src/test/java/com/msd/gin/halyard/strategy/HalyardStrategyExtendedTest.java index e7c533cd4..6c542b119 100644 --- a/strategy/src/test/java/com/msd/gin/halyard/strategy/HalyardStrategyExtendedTest.java +++ b/strategy/src/test/java/com/msd/gin/halyard/strategy/HalyardStrategyExtendedTest.java @@ -19,7 +19,7 @@ import static junit.framework.TestCase.assertFalse; import static junit.framework.TestCase.assertTrue; -import com.msd.gin.halyard.model.ArrayLiteral; +import com.msd.gin.halyard.model.ObjectArrayLiteral; import com.msd.gin.halyard.model.vocabulary.HALYARD; import com.msd.gin.halyard.model.vocabulary.SCHEMA_ORG; @@ -320,16 +320,16 @@ public void testConstantIn() { @Test public void testVectorMath() { - Literal a = new ArrayLiteral(3, 2.0, 0.5f); - Literal b = new ArrayLiteral(-5.15f, 1, 2.32); + Literal a = new ObjectArrayLiteral(3, 2.0, 0.5f); + Literal b = new ObjectArrayLiteral(-5.15f, 1, 2.32); String q = "SELECT (?a+?b as ?sum) (?a-?b as ?diff) (3*?a as ?mult) (?b/0.25 as ?div) { BIND("+a+" as ?a) BIND("+b+" as ?b) }"; try (TupleQueryResult res = con.prepareTupleQuery(q).evaluate()) { assertTrue(res.hasNext()); BindingSet bs = res.next(); - assertArrayEquals(new Object[] {-2.1500000000000004,3,2.82}, ArrayLiteral.objectArray((Literal)bs.getValue("sum"))); - assertArrayEquals(new Object[] {8.15,1,-1.8199999999999998}, ArrayLiteral.objectArray((Literal)bs.getValue("diff"))); - assertArrayEquals(new Object[] {9l,6l,1.5}, ArrayLiteral.objectArray((Literal)bs.getValue("mult"))); - assertArrayEquals(new Object[] {-20.6,4.0,9.28}, ArrayLiteral.objectArray((Literal)bs.getValue("div"))); + assertArrayEquals(new Object[] {-2.1500000000000004,3,2.82}, ObjectArrayLiteral.objectArray((Literal)bs.getValue("sum"))); + assertArrayEquals(new Object[] {8.15,1,-1.8199999999999998}, ObjectArrayLiteral.objectArray((Literal)bs.getValue("diff"))); + assertArrayEquals(new Object[] {9l,6l,1.5}, ObjectArrayLiteral.objectArray((Literal)bs.getValue("mult"))); + assertArrayEquals(new Object[] {-20.6,4.0,9.28}, ObjectArrayLiteral.objectArray((Literal)bs.getValue("div"))); } } diff --git a/strategy/src/test/java/com/msd/gin/halyard/strategy/MathOpEvaluatorTest.java b/strategy/src/test/java/com/msd/gin/halyard/strategy/MathOpEvaluatorTest.java new file mode 100644 index 000000000..1539caeec --- /dev/null +++ b/strategy/src/test/java/com/msd/gin/halyard/strategy/MathOpEvaluatorTest.java @@ -0,0 +1,41 @@ +package com.msd.gin.halyard.strategy; + +import com.msd.gin.halyard.model.FloatArrayLiteral; + +import org.eclipse.rdf4j.model.Literal; +import org.eclipse.rdf4j.model.ValueFactory; +import org.eclipse.rdf4j.model.impl.SimpleValueFactory; +import org.eclipse.rdf4j.query.algebra.MathExpr.MathOp; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; + +public class MathOpEvaluatorTest { + @Test + public void testAdd_floatVector() { + ValueFactory vf = SimpleValueFactory.getInstance(); + Literal r = new MathOpEvaluator().evaluate(new FloatArrayLiteral(0.7f, -1.2f), new FloatArrayLiteral(0.1f,0.5f), MathOp.PLUS, vf); + assertArrayEquals(new float[] {0.8f, -0.70000005f}, FloatArrayLiteral.floatArray(r)); + } + + @Test + public void testSubtract_floatVector() { + ValueFactory vf = SimpleValueFactory.getInstance(); + Literal r = new MathOpEvaluator().evaluate(new FloatArrayLiteral(0.7f, -1.2f), new FloatArrayLiteral(0.1f,0.5f), MathOp.MINUS, vf); + assertArrayEquals(new float[] {0.59999996f, -1.7f}, FloatArrayLiteral.floatArray(r)); + } + + @Test + public void testScalarMultiply_floatVector() { + ValueFactory vf = SimpleValueFactory.getInstance(); + Literal r = new MathOpEvaluator().evaluate(vf.createLiteral(0.7f), new FloatArrayLiteral(0.1f,0.5f), MathOp.MULTIPLY, vf); + assertArrayEquals(new float[] {0.07f, 0.35f}, FloatArrayLiteral.floatArray(r)); + } + + @Test + public void testScalarDivide_floatVector() { + ValueFactory vf = SimpleValueFactory.getInstance(); + Literal r = new MathOpEvaluator().evaluate(new FloatArrayLiteral(-0.7f, 1.2f), vf.createLiteral(-0.15f), MathOp.DIVIDE, vf); + assertArrayEquals(new float[] {4.6666665f, -8.0f}, FloatArrayLiteral.floatArray(r)); + } +} diff --git a/tools/src/main/java/com/msd/gin/halyard/tools/HalyardElasticIndexer.java b/tools/src/main/java/com/msd/gin/halyard/tools/HalyardElasticIndexer.java index 56bb630ea..5c6c8b6c2 100644 --- a/tools/src/main/java/com/msd/gin/halyard/tools/HalyardElasticIndexer.java +++ b/tools/src/main/java/com/msd/gin/halyard/tools/HalyardElasticIndexer.java @@ -22,7 +22,7 @@ import com.msd.gin.halyard.common.SSLSettings; import com.msd.gin.halyard.common.StatementIndex; import com.msd.gin.halyard.common.StatementIndices; -import com.msd.gin.halyard.model.ArrayLiteral; +import com.msd.gin.halyard.model.ObjectArrayLiteral; import com.msd.gin.halyard.model.TermRole; import com.msd.gin.halyard.model.TupleLiteral; import com.msd.gin.halyard.model.vocabulary.HALYARD; @@ -33,6 +33,7 @@ import java.io.InputStream; import java.io.OutputStream; import java.net.HttpURLConnection; +import java.net.URI; import java.net.URL; import java.nio.charset.StandardCharsets; import java.security.GeneralSecurityException; @@ -234,7 +235,7 @@ public int run(CommandLine cmd) throws Exception { configureIRI(cmd, 'g', null); String source = getConf().get(SOURCE_NAME_PROPERTY); String target = getConf().get(INDEX_URL_PROPERTY); - URL targetUrl = new URL(target); + URL targetUrl = new URI(target).toURL(); String indexName = targetUrl.getPath().substring(1); boolean createIndex = getConf().getBoolean(CREATE_INDEX_PROPERTY, false); String snapshotPath = getConf().get(SNAPSHOT_PATH_PROPERTY); @@ -492,7 +493,7 @@ private Object toObject(Value v) throws IOException { return l.getLabel(); } } else if (HALYARD.ARRAY_TYPE.equals(l.getDatatype())) { - return ArrayLiteral.objectArray(l); + return ObjectArrayLiteral.objectArray(l); } else if (HALYARD.TUPLE_TYPE.equals(l.getDatatype())) { Value[] varr = TupleLiteral.valueArray(l, valueFactory); Object[] oarr = new Object[varr.length];