From 375c388b36a36d89a47ac9b008eb3167beec4eaa Mon Sep 17 00:00:00 2001 From: Gemma Lamont Date: Mon, 23 Dec 2024 11:58:37 +0100 Subject: [PATCH] Add other types to give a more robust grouping proc --- core/src/main/java/apoc/nodes/Grouping.java | 206 ++++++++++++++++-- .../test/java/apoc/nodes/GroupingTest.java | 170 ++++++++++++++- 2 files changed, 358 insertions(+), 18 deletions(-) diff --git a/core/src/main/java/apoc/nodes/Grouping.java b/core/src/main/java/apoc/nodes/Grouping.java index 3026d5c99..a7afba8f3 100644 --- a/core/src/main/java/apoc/nodes/Grouping.java +++ b/core/src/main/java/apoc/nodes/Grouping.java @@ -21,16 +21,23 @@ import static java.util.Collections.*; import apoc.Pools; +import apoc.convert.ConvertUtils; import apoc.result.VirtualNode; import apoc.result.VirtualRelationship; import apoc.util.Util; import apoc.util.collection.Iterables; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.time.OffsetTime; +import java.time.ZonedDateTime; import java.util.*; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; import java.util.stream.Collectors; import java.util.stream.Stream; +import org.neo4j.exceptions.ArithmeticException; import org.neo4j.graphdb.*; import org.neo4j.logging.Log; import org.neo4j.procedure.Context; @@ -38,6 +45,8 @@ import org.neo4j.procedure.Name; import org.neo4j.procedure.NotThreadSafe; import org.neo4j.procedure.Procedure; +import org.neo4j.values.storable.DurationValue; +import org.neo4j.values.storable.PointValue; /** * @author mh @@ -368,6 +377,10 @@ private , T extends Entity> C fixAggregates(C pcs) { double[] values = (double[]) v; entry.setValue(values[1] == 0 ? 0 : values[0] / values[1]); } + if (k.matches("^avg_.+") && v instanceof DurationValue) { + Long count = ((Number) pc.getProperty(k + "_count", 0)).longValue(); + entry.setValue(divDurationValue((DurationValue) v, count)); + } if (k.matches("^collect_.+") && v instanceof Collection) { entry.setValue(((Collection) v).toArray()); } @@ -376,6 +389,20 @@ private , T extends Entity> C fixAggregates(C pcs) { return pcs; } + public DurationValue divDurationValue(DurationValue div, Long number) { + double divisor = number.doubleValue(); + + try { + return div.approximate( + (double) div.get("months").longValue() / divisor, + (double) div.get("days").longValue() / divisor, + (double) div.get("seconds").longValue() / divisor, + (double) div.get("nanoseconds").longValue() / divisor); + } catch (ArithmeticException | java.lang.ArithmeticException e) { + return div; + } + } + private void aggregate(Entity pc, Map> aggregations, Map properties) { aggregations.forEach((k2, aggNames) -> { for (String aggName : aggNames) { @@ -395,28 +422,36 @@ private void aggregate(Entity pc, Map> aggregations, Map> aggregations, Map col1, Collection col2) { + Iterator it1 = col1.iterator(); + Iterator it2 = col2.iterator(); + + while (it1.hasNext() && it2.hasNext()) { + Object elem1 = it1.next(); + Object elem2 = it2.next(); + + // Compare elements recursively + int comparison = compareElements(elem1, elem2); + if (comparison != 0) { + return comparison < 0; // Return true if col1 < col2 + } + } + + // If one collection runs out of elements, it is smaller + return col1.size() < col2.size(); + } + + private int compareElements(Object elem1, Object elem2) { + if (elem1 == null && elem2 == null) { + return 0; + } + if (elem1 == null) { + return -1; + } + if (elem2 == null) { + return 1; + } + + if (elem1 instanceof Comparable && elem2 instanceof Comparable) { + // Cast to Comparable and compare + return ((Comparable) elem1).compareTo(elem2); + } + + // If elements are not directly comparable, use orderOfType + return Integer.compare(orderOfType(elem1), orderOfType(elem2)); + } + + private Object returnMinOfDifferentValues(Object prop, Object value) { + return orderOfType(prop) < orderOfType(value) ? prop : value; + } + + private Object returnMaxOfDifferentValues(Object prop, Object value) { + return orderOfType(prop) < orderOfType(value) ? value : prop; + } + + private int orderOfType(Object value) { + if (value != null && value.getClass().isArray()) { + return 0; + } + return switch (value) { + case null -> 11; + case Collection ignored -> 0; + case PointValue ignored -> 1; + case ZonedDateTime ignored -> 2; + case LocalDateTime ignored -> 3; + case LocalDate ignored -> 4; + case OffsetTime ignored -> 5; + case LocalTime ignored -> 6; + case DurationValue ignored -> 7; + case String ignored -> 8; + case Boolean ignored -> 9; + case Number ignored -> 10; + default -> 12; + }; + } + /** * Returns the properties for the given node according to the specified keys. If a node does not have a property * assigned to given key, the value is set to {@code null}. diff --git a/core/src/test/java/apoc/nodes/GroupingTest.java b/core/src/test/java/apoc/nodes/GroupingTest.java index 82253320b..0c7a2a98f 100644 --- a/core/src/test/java/apoc/nodes/GroupingTest.java +++ b/core/src/test/java/apoc/nodes/GroupingTest.java @@ -21,6 +21,7 @@ import static apoc.util.TestUtil.testResult; import static apoc.util.Util.map; import static org.junit.Assert.*; +import static org.neo4j.configuration.GraphDatabaseSettings.procedure_unrestricted; import apoc.util.TestUtil; import apoc.util.collection.Iterators; @@ -35,6 +36,7 @@ import org.neo4j.graphdb.Relationship; import org.neo4j.test.rule.DbmsRule; import org.neo4j.test.rule.ImpermanentDbmsRule; +import org.neo4j.values.storable.DurationValue; /** * @author mh @@ -43,11 +45,11 @@ public class GroupingTest { @Rule - public DbmsRule db = new ImpermanentDbmsRule(); + public DbmsRule db = new ImpermanentDbmsRule().withSetting(procedure_unrestricted, List.of("apoc*")); @Before public void setUp() { - TestUtil.registerProcedure(db, Grouping.class); + TestUtil.registerProcedure(db, Grouping.class, Nodes.class); } @After @@ -159,6 +161,170 @@ public void testRemoveOrphans() { TestUtil.testCallCount(db, "CALL apoc.nodes.group(['User'],['gender'],null,{orphans:true})", 1); } + @Test + public void testGroupWithDatetimes() { + db.executeTransactionally( + """ + UNWIND range(1, 1000) AS minutes + CREATE (f:Foo { + created_at: datetime({ year: 2019, month: 3, day: 23 }) - duration({minutes: minutes}) + }) + SET f.created_at_hour = datetime.truncate('hour', f.created_at) + """); + TestUtil.testCallCount( + db, + "CALL apoc.nodes.group(['Foo'], ['created_at_hour'], [{ created_at: 'min' }]) YIELD node\n" + + "RETURN node", + 17); + + // Delete nodes + db.executeTransactionally("MATCH (n:Foo) DELETE n"); + } + + public static class TestObject { + final String testValues; // Values to be inserted as nodes + final String expectedMin; // Expected minimum value + final String expectedMax; // Expected maximum value + + public TestObject(String testValues, String expectedMin, String expectedMax) { + this.testValues = testValues; + this.expectedMin = expectedMin; + this.expectedMax = expectedMax; + } + } + + @Test + public void testGroupWithVariousProperties() { + List testObjects = List.of( + new TestObject("42, 99, 12, 34", "12", "99"), + new TestObject("\"alpha\", \"beta\", \"zeta\"", "\"alpha\"", "\"zeta\""), + new TestObject("true, false, true", "false", "true"), + new TestObject( + "datetime({ year: 2022, month: 1, day: 1 }), datetime({ year: 2021, month: 1, day: 1 }), datetime({ year: 2023, month: 1, day: 1 })", + "datetime({ year: 2021, month: 1, day: 1 })", + "datetime({ year: 2023, month: 1, day: 1 })"), + new TestObject( + "localdatetime({ year: 2022, month: 1, day: 1 }), localdatetime({ year: 2021, month: 1, day: 1 }), localdatetime({ year: 2023, month: 1, day: 1 })", + "localdatetime({ year: 2021, month: 1, day: 1 })", + "localdatetime({ year: 2023, month: 1, day: 1 })"), + new TestObject( + "localtime({hour: 10, minute: 30, second: 1}), localtime({hour: 4, minute: 23, second: 3}), localtime({hour: 6, minute: 33, second: 15})", + "localtime({hour: 4, minute: 23, second: 3})", + "localtime({hour: 10, minute: 30, second: 1})"), + new TestObject( + "time({hour: 10, minute: 30, second: 1}), time({hour: 4, minute: 23, second: 3}), time({hour: 6, minute: 33, second: 15})", + "time({hour: 4, minute: 23, second: 3})", + "time({hour: 10, minute: 30, second: 1})"), + new TestObject( + "date({ year: 2022, month: 1, day: 1 }), date({ year: 2021, month: 1, day: 1 }), date({ year: 2023, month: 1, day: 1 })", + "date({ year: 2021, month: 1, day: 1 })", + "date({ year: 2023, month: 1, day: 1 })"), + new TestObject( + "duration('P11DT16H12M'), duration('P1DT16H12M'), duration('P1DT20H12M')", + "duration('P1DT16H12M')", + "duration('P11DT16H12M')"), + new TestObject("[1, 0, 3], [1, 2, 3]", "[1, 0, 3]", "[1, 2, 3]"), + // Mixed values + new TestObject("1, [1, 2, 3], false", "[1, 2, 3]", "1"), + new TestObject("1, [1, 2, 3], null", "[1, 2, 3]", "1"), + new TestObject("duration('P11DT16H12M'), \"alpha\", false", "duration('P11DT16H12M')", "false"), + new TestObject( + "date({ year: 2022, month: 1, day: 1 }), localtime({hour: 10, minute: 30, second: 1}), datetime({ year: 2022, month: 1, day: 1 })", + "datetime({ year: 2022, month: 1, day: 1 })", + "localtime({hour: 10, minute: 30, second: 1})")); + + for (TestObject testObject : testObjects) { + runTestForProperty(testObject); + } + } + + private void runTestForProperty(TestObject testObject) { + String testValues = testObject.testValues; + String expectedMin = testObject.expectedMin; + String expectedMax = testObject.expectedMax; + + // Create nodes in the database + db.executeTransactionally(String.format( + """ + UNWIND [%s] AS value + CREATE (n:Test { testValue: value, groupKey: 1 }) + """, + testValues)); + + // Test for minimum value + TestUtil.testCall( + db, + String.format( + """ + CALL apoc.nodes.group(['Test'], ['groupKey'], [{ testValue: 'min' }]) YIELD node + WITH apoc.any.property(node, 'min_testValue') AS result + RETURN result = %s AS value, result + """, + expectedMin), + (row) -> assertTrue( + "Testing: " + testValues + "; expected: " + expectedMin + "; but got: " + row.get("result"), + (Boolean) row.get("value"))); + + // Test for maximum value + TestUtil.testCall( + db, + String.format( + """ + CALL apoc.nodes.group(['Test'], ['groupKey'], [{ testValue: 'max' }]) YIELD node + WITH apoc.any.property(node, 'max_testValue') AS result + RETURN result = %s AS value, result + """, + expectedMax), + (row) -> assertTrue( + "Testing: " + testValues + "; expected: " + expectedMax + "; but got: " + row.get("result"), + (Boolean) row.get("value"))); + + // Delete nodes + db.executeTransactionally("MATCH (n:Test) DELETE n"); + } + + @Test + public void testSumAndAvg() { + // Create nodes in the database + db.executeTransactionally( + """ + UNWIND [ + [duration('P11DT16H12M'), 1, 3], + [duration('P1DT16H12M'), 2, duration('P1DT16H12M')], + [duration('P3DT20H12M'), 3, 4]] AS value + CREATE (n:Test { durationValue: value[0], intValue: value[1], mixedValue: value[2], groupKey: 1 }) + """); + + // Test for sum value + TestUtil.testCall( + db, + """ + CALL apoc.nodes.group(['Test'], ['groupKey'], [{ durationValue: 'sum', intValue: 'sum' }]) YIELD node + WITH apoc.any.property(node, 'sum_durationValue') AS sum_durationValue, apoc.any.property(node, 'sum_intValue') AS sum_intValue + RETURN sum_durationValue, sum_intValue + """, + (row) -> { + assertEquals(DurationValue.duration(0, 15, 189360, 0), row.get("sum_durationValue")); + assertEquals(6L, row.get("sum_intValue")); + }); + + // Test for avg value + TestUtil.testCall( + db, + """ + CALL apoc.nodes.group(['Test'], ['groupKey'], [{ durationValue: 'avg', intValue: 'avg' }]) YIELD node + WITH apoc.any.property(node, 'avg_durationValue') AS avg_durationValue, apoc.any.property(node, 'avg_intValue') AS avg_intValue + RETURN avg_durationValue, avg_intValue + """, + (row) -> { + assertEquals(DurationValue.duration(0, 5, 126240, 0), row.get("avg_durationValue")); + assertEquals(2.0, row.get("avg_intValue")); + }); + + // Delete nodes + db.executeTransactionally("MATCH (n:Test) DELETE n"); + } + @Test public void testSelfRels() { db.executeTransactionally("CREATE (u:User {gender:'male'})-[:REL]->(u)");