Skip to content

Commit

Permalink
Add other types to give a more robust grouping proc
Browse files Browse the repository at this point in the history
  • Loading branch information
gem-neo4j committed Dec 23, 2024
1 parent e43ecc4 commit b179557
Show file tree
Hide file tree
Showing 2 changed files with 285 additions and 14 deletions.
171 changes: 159 additions & 12 deletions core/src/main/java/apoc/nodes/Grouping.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,16 @@
import static java.util.Collections.*;

import apoc.Pools;
import apoc.convert.ConvertUtils;
import apoc.result.VirtualNode;
import apoc.result.VirtualRelationship;
import apoc.util.Util;
import apoc.util.collection.Iterables;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.LocalTime;
import java.time.OffsetTime;
import java.time.ZonedDateTime;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
Expand All @@ -38,6 +44,8 @@
import org.neo4j.procedure.Name;
import org.neo4j.procedure.NotThreadSafe;
import org.neo4j.procedure.Procedure;
import org.neo4j.values.storable.DurationValue;
import org.neo4j.values.storable.PointValue;

/**
* @author mh
Expand Down Expand Up @@ -395,24 +403,24 @@ private void aggregate(Entity pc, Map<String, List<String>> aggregations, Map<St
pc.setProperty(key, ((Number) pc.getProperty(key, 0)).longValue() + 1);
break;
case "sum":
pc.setProperty(
key, ((Number) pc.getProperty(key, 0)).doubleValue() + Util.toDouble(value));
if (value instanceof DurationValue) {
DurationValue dv =
(DurationValue) pc.getProperty(key, DurationValue.duration(0, 0, 0, 0));
pc.setProperty(key, ((DurationValue) value).add(dv));
} else if (value instanceof Number) {
pc.setProperty(
key,
((Number) pc.getProperty(key, 0)).doubleValue() + Util.toDouble(value));
}
break;
case "min":
pc.setProperty(
key,
Math.min(
((Number) pc.getProperty(key, Double.MAX_VALUE)).doubleValue(),
Util.toDouble(value)));
pc.setProperty(key, getMin(key, pc, value));
break;
case "max":
pc.setProperty(
key,
Math.max(
((Number) pc.getProperty(key, Double.MIN_VALUE)).doubleValue(),
Util.toDouble(value)));
pc.setProperty(key, getMax(key, pc, value));
break;
case "avg": {
// TODO; Add duration option
double[] avg = (double[]) pc.getProperty(key, new double[2]);
avg[0] += Util.toDouble(value);
avg[1] += 1;
Expand All @@ -426,6 +434,145 @@ private void aggregate(Entity pc, Map<String, List<String>> aggregations, Map<St
});
}

private Object getMin(String key, Entity pc, Object value) {
Object prop = pc.getProperty(key);

if (prop == null) {
return value;
}

if (isComparableTypes(prop, value)) {
return compareValues(prop, value) ? prop : value;
}

return returnMinOfDifferentValues(prop, value);
}

private Object getMax(String key, Entity pc, Object value) {
Object prop = pc.getProperty(key);

if (prop == null) {
return value;
}

if (isComparableTypes(prop, value)) {
return compareValues(prop, value) ? value : prop;
}

return returnMaxOfDifferentValues(prop, value);
}

private boolean isComparableTypes(Object prop, Object value) {
return (prop instanceof ZonedDateTime && value instanceof ZonedDateTime)
|| (prop instanceof LocalDateTime && value instanceof LocalDateTime)
|| (prop instanceof LocalDate && value instanceof LocalDate)
|| (prop instanceof OffsetTime && value instanceof OffsetTime)
|| (prop instanceof LocalTime && value instanceof LocalTime)
|| (prop instanceof DurationValue && value instanceof DurationValue)
|| (prop instanceof String && value instanceof String)
|| (prop instanceof Boolean && value instanceof Boolean)
|| (prop instanceof Number && value instanceof Number)
|| ((prop instanceof Collection || prop.getClass().isArray())
&& (value instanceof Collection || value.getClass().isArray()))
|| (prop instanceof PointValue && value instanceof PointValue);
}

private boolean compareValues(Object prop, Object value) {
if (prop instanceof ZonedDateTime pZonedDateTime) {
return ((ZonedDateTime) value).isAfter(pZonedDateTime);
} else if (prop instanceof LocalDateTime pLocalDateTime) {
return ((LocalDateTime) value).isAfter(pLocalDateTime);
} else if (prop instanceof LocalDate pLocalDate) {
return ((LocalDate) value).isAfter(pLocalDate);
} else if (prop instanceof OffsetTime pOffsetTime) {
return ((OffsetTime) value).isAfter(pOffsetTime);
} else if (prop instanceof LocalTime pLocalTime) {
return ((LocalTime) value).isAfter(pLocalTime);
} else if (prop instanceof DurationValue pDurationValue) {
return pDurationValue.compareTo((DurationValue) value) < 0;
} else if (prop instanceof String pString) {
return pString.compareTo((String) value) < 0;
} else if (prop instanceof Boolean pBool) {
return !pBool; // Return `false` if `prop` is `false`
} else if (prop instanceof Number pNumber) {
return pNumber.doubleValue() < Util.toDouble(value);
} else if ((prop instanceof Collection || prop.getClass().isArray())
&& (value instanceof Collection || value.getClass().isArray())) {
return compareCollections(ConvertUtils.convertToList(prop), ConvertUtils.convertToList(value));
} else if (prop instanceof PointValue pPoint && value instanceof PointValue vPoint) {
return pPoint.compareTo(vPoint) < 0;
}
return false; // Default fallback (shouldn't reach here for comparable types)
}

private boolean compareCollections(Collection<?> col1, Collection<?> col2) {
Iterator<?> it1 = col1.iterator();
Iterator<?> it2 = col2.iterator();

while (it1.hasNext() && it2.hasNext()) {
Object elem1 = it1.next();
Object elem2 = it2.next();

// Compare elements recursively
int comparison = compareElements(elem1, elem2);
if (comparison != 0) {
return comparison < 0; // Return true if col1 < col2
}
}

// If one collection runs out of elements, it is smaller
return col1.size() < col2.size();
}

private int compareElements(Object elem1, Object elem2) {
if (elem1 == null && elem2 == null) {
return 0;
}
if (elem1 == null) {
return -1;
}
if (elem2 == null) {
return 1;
}

if (elem1 instanceof Comparable && elem2 instanceof Comparable) {
// Cast to Comparable and compare
return ((Comparable<Object>) elem1).compareTo(elem2);
}

// If elements are not directly comparable, use orderOfType
return Integer.compare(orderOfType(elem1), orderOfType(elem2));
}

private Object returnMinOfDifferentValues(Object prop, Object value) {
return orderOfType(prop) < orderOfType(value) ? prop : value;
}

private Object returnMaxOfDifferentValues(Object prop, Object value) {
return orderOfType(prop) < orderOfType(value) ? value : prop;
}

private int orderOfType(Object value) {
if (value != null && value.getClass().isArray()) {
return 0;
}
return switch (value) {
case null -> 11;
case Collection ignored -> 0;
case PointValue ignored -> 1;
case ZonedDateTime ignored -> 2;
case LocalDateTime ignored -> 3;
case LocalDate ignored -> 4;
case OffsetTime ignored -> 5;
case LocalTime ignored -> 6;
case DurationValue ignored -> 7;
case String ignored -> 8;
case Boolean ignored -> 9;
case Number ignored -> 10;
default -> 12;
};
}

/**
* Returns the properties for the given node according to the specified keys. If a node does not have a property
* assigned to given key, the value is set to {@code null}.
Expand Down
128 changes: 126 additions & 2 deletions core/src/test/java/apoc/nodes/GroupingTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import static apoc.util.TestUtil.testResult;
import static apoc.util.Util.map;
import static org.junit.Assert.*;
import static org.neo4j.configuration.GraphDatabaseSettings.procedure_unrestricted;

import apoc.util.TestUtil;
import apoc.util.collection.Iterators;
Expand All @@ -43,11 +44,11 @@
public class GroupingTest {

@Rule
public DbmsRule db = new ImpermanentDbmsRule();
public DbmsRule db = new ImpermanentDbmsRule().withSetting(procedure_unrestricted, List.of("apoc*"));

@Before
public void setUp() {
TestUtil.registerProcedure(db, Grouping.class);
TestUtil.registerProcedure(db, Grouping.class, Nodes.class);
}

@After
Expand Down Expand Up @@ -159,6 +160,129 @@ public void testRemoveOrphans() {
TestUtil.testCallCount(db, "CALL apoc.nodes.group(['User'],['gender'],null,{orphans:true})", 1);
}

@Test
public void testGroupWithDatetimes() {
db.executeTransactionally(
"""
UNWIND range(1, 1000) AS minutes
CREATE (f:Foo {
created_at: datetime({ year: 2019, month: 3, day: 23 }) - duration({minutes: minutes})
})
SET f.created_at_hour = datetime.truncate('hour', f.created_at)
""");
TestUtil.testCallCount(
db,
"CALL apoc.nodes.group(['Foo'], ['created_at_hour'], [{ created_at: 'min' }]) YIELD node\n"
+ "RETURN node",
17);
}

public class TestObject {
final String testValues; // Values to be inserted as nodes
final String expectedMin; // Expected minimum value
final String expectedMax; // Expected maximum value

public TestObject(String testValues, String expectedMin, String expectedMax) {
this.testValues = testValues;
this.expectedMin = expectedMin;
this.expectedMax = expectedMax;
}
}

@Test
public void testGroupWithVariousProperties() {
List<TestObject> testObjects = List.of(
new TestObject("42, 99, 12, 34", "12", "99"),
new TestObject("\"alpha\", \"beta\", \"zeta\"", "\"alpha\"", "\"zeta\""),
new TestObject("true, false, true", "false", "true"),
new TestObject(
"datetime({ year: 2022, month: 1, day: 1 }), datetime({ year: 2021, month: 1, day: 1 }), datetime({ year: 2023, month: 1, day: 1 })",
"datetime({ year: 2021, month: 1, day: 1 })",
"datetime({ year: 2023, month: 1, day: 1 })"),
new TestObject(
"localdatetime({ year: 2022, month: 1, day: 1 }), localdatetime({ year: 2021, month: 1, day: 1 }), localdatetime({ year: 2023, month: 1, day: 1 })",
"localdatetime({ year: 2021, month: 1, day: 1 })",
"localdatetime({ year: 2023, month: 1, day: 1 })"),
new TestObject(
"localtime({hour: 10, minute: 30, second: 1}), localtime({hour: 4, minute: 23, second: 3}), localtime({hour: 6, minute: 33, second: 15})",
"localtime({hour: 4, minute: 23, second: 3})",
"localtime({hour: 10, minute: 30, second: 1})"),
new TestObject(
"time({hour: 10, minute: 30, second: 1}), time({hour: 4, minute: 23, second: 3}), time({hour: 6, minute: 33, second: 15})",
"time({hour: 4, minute: 23, second: 3})",
"time({hour: 10, minute: 30, second: 1})"),
new TestObject(
"date({ year: 2022, month: 1, day: 1 }), date({ year: 2021, month: 1, day: 1 }), date({ year: 2023, month: 1, day: 1 })",
"date({ year: 2021, month: 1, day: 1 })",
"date({ year: 2023, month: 1, day: 1 })"),
new TestObject(
"duration('P11DT16H12M'), duration('P1DT16H12M'), duration('P1DT20H12M')",
"duration('P1DT16H12M')",
"duration('P11DT16H12M')"),
new TestObject("[1, 0, 3], [1, 2, 3]", "[1, 0, 3]", "[1, 2, 3]"),
// Mixed values
new TestObject("1, [1, 2, 3], false", "[1, 2, 3]", "1"),
new TestObject("1, [1, 2, 3], null", "[1, 2, 3]", "1"),
new TestObject("duration('P11DT16H12M'), \"alpha\", false", "duration('P11DT16H12M')", "false"),
new TestObject(
"date({ year: 2022, month: 1, day: 1 }), localtime({hour: 10, minute: 30, second: 1}), datetime({ year: 2022, month: 1, day: 1 })",
"datetime({ year: 2022, month: 1, day: 1 })",
"localtime({hour: 10, minute: 30, second: 1})"));

for (TestObject testObject : testObjects) {
runTestForProperty(testObject);
}
}

private void runTestForProperty(TestObject testObject) {
String testValues = testObject.testValues;
String expectedMin = testObject.expectedMin;
String expectedMax = testObject.expectedMax;

// Create nodes in the database
db.executeTransactionally(String.format(
"""
UNWIND [%s] AS value
CREATE (n:Test { testValue: value, groupKey: 1 })
""",
testValues));

// Test for minimum value
TestUtil.testCall(
db,
String.format(
"""
CALL apoc.nodes.group(['Test'], ['groupKey'], [{ testValue: 'min' }]) YIELD node
WITH apoc.any.property(node, 'min_testValue') AS result
RETURN result = %s AS value, result
""",
expectedMin),
(row) -> assertTrue(
"Testing: " + testValues + "; expected: " + expectedMin + "; but got: " + row.get("result"),
(Boolean) row.get("value")));

// Test for maximum value
TestUtil.testCall(
db,
String.format(
"""
CALL apoc.nodes.group(['Test'], ['groupKey'], [{ testValue: 'max' }]) YIELD node
WITH apoc.any.property(node, 'max_testValue') AS result
RETURN result = %s AS value, result
""",
expectedMax),
(row) -> assertTrue(
"Testing: " + testValues + "; expected: " + expectedMax + "; but got: " + row.get("result"),
(Boolean) row.get("value")));

// Delete nodes
db.executeTransactionally(
String.format("""
MATCH (n:Test)
DELETE n
""", testValues));
}

@Test
public void testSelfRels() {
db.executeTransactionally("CREATE (u:User {gender:'male'})-[:REL]->(u)");
Expand Down

0 comments on commit b179557

Please sign in to comment.