From 7c6f970dbaefdf27b9ef6070179068e50c0d8c2d Mon Sep 17 00:00:00 2001 From: neo4j-oss-build Date: Mon, 23 May 2022 14:28:30 +0200 Subject: [PATCH] Fixes #2654: The custom procedure validation fails with integer input types (#2702) (#2913) --- .../java/apoc/custom/CypherProcedures.java | 9 ++-- .../apoc/custom/CypherProceduresTest.java | 51 ++++++++++++++++++- 2 files changed, 54 insertions(+), 6 deletions(-) diff --git a/full/src/main/java/apoc/custom/CypherProcedures.java b/full/src/main/java/apoc/custom/CypherProcedures.java index 72673ab497..26d85d3b4e 100644 --- a/full/src/main/java/apoc/custom/CypherProcedures.java +++ b/full/src/main/java/apoc/custom/CypherProcedures.java @@ -24,6 +24,7 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -174,12 +175,12 @@ private void validateFunction(String statement, List input) { } private void validateProcedure(String statement, List input, List output, Mode mode) { - - final Set inputSet = input.stream().map(FieldSignature::name).collect(Collectors.toSet()); + final Set outputSet = output.stream().map(FieldSignature::name).collect(Collectors.toSet()); - api.executeTransactionally("EXPLAIN " + statement, - inputSet.stream().collect(Collectors.toMap(i -> i, i -> i)), + api.executeTransactionally("EXPLAIN " + statement, + input.stream().collect(HashMap::new, + (map, value) -> map.put(value.name(), null), HashMap::putAll), result -> { if (!DEFAULT_MAP_OUTPUT.equals(output)) { checkOutputParams(outputSet, result.columns()); diff --git a/full/src/test/java/apoc/custom/CypherProceduresTest.java b/full/src/test/java/apoc/custom/CypherProceduresTest.java index 04efd693e9..c9de6f7ed7 100644 --- a/full/src/test/java/apoc/custom/CypherProceduresTest.java +++ b/full/src/test/java/apoc/custom/CypherProceduresTest.java @@ -15,9 +15,11 @@ import org.neo4j.graphdb.Node; import org.neo4j.graphdb.QueryExecutionException; import org.neo4j.graphdb.Transaction; +import org.neo4j.internal.helpers.collection.Iterators; import org.neo4j.test.rule.DbmsRule; import org.neo4j.test.rule.ImpermanentDbmsRule; +import java.time.LocalDateTime; import java.util.Collections; import java.util.List; import java.util.Map; @@ -110,6 +112,51 @@ public void registerConcreteParameterAndReturnStatement() throws Exception { db.executeTransactionally("call apoc.custom.asProcedure('answer','RETURN $input as answer','read',[['answer','number']],[['input','int','42']])"); TestUtil.testCall(db, "call custom.answer()", (row) -> assertEquals(42L, row.get("answer"))); } + + @Test + public void testValidationProceduresIssue2654() { + db.executeTransactionally("CALL apoc.custom.declareProcedure('doubleProc(input::INT) :: (answer::INT)', 'RETURN $input * 2 AS answer')"); + TestUtil.testCall(db, "CALL custom.doubleProc(4);", (r) -> assertEquals(8L, r.get("answer"))); + + db.executeTransactionally("CALL apoc.custom.declareProcedure('testValTwo(input::INT) :: (answer::INT)', 'RETURN $input ^ 2 AS answer')"); + TestUtil.testCall(db, "CALL custom.testValTwo(4);", (r) -> assertEquals(16D, r.get("answer"))); + + db.executeTransactionally("CALL apoc.custom.declareProcedure('testValThree(input::MAP, power :: LONG) :: (answer::INT)', 'RETURN $input.a ^ $power AS answer')"); + TestUtil.testCall(db, "CALL custom.testValThree({a: 2}, 3);", (r) -> assertEquals(8D, r.get("answer"))); + + db.executeTransactionally("CALL apoc.custom.declareProcedure($signature, $query)", + Map.of("signature", "testValFour(input::INT, power::NUMBER) :: (answer::INT)", + "query", "UNWIND range(0, $power) AS power RETURN $input ^ power AS answer")); + + TestUtil.testResult(db, "CALL custom.testValFour(2, 3)", + (r) -> assertEquals(List.of(1D, 2D, 4D, 8D), Iterators.asList(r.columnAs("answer")))); + + db.executeTransactionally("CALL apoc.custom.declareProcedure($signature, $query)", + Map.of("signature", "multiProc(input::LOCALDATETIME, minus::INT) :: (first::INT, second:: STRING, third::DATETIME)", + "query", "WITH $input AS input RETURN input.year - $minus AS first, toString(input) as second, input as third")); + + TestUtil.testCall(db, "CALL custom.multiProc(localdatetime('2020'), 3);", (r) -> { + assertEquals(2017L, r.get("first")); + assertEquals("2020-01-01T00:00:00", r.get("second")); + assertEquals(LocalDateTime.of(2020, 1, 1, 0, 0, 0, 0), r.get("third")); + }); + } + + @Test + public void testValidationFunctionsIssue2654() { + db.executeTransactionally("CALL apoc.custom.declareFunction('double(input::INT) :: INT', 'RETURN $input * 2 AS answer')"); + TestUtil.testCall(db, "RETURN custom.double(4) AS answer", (r) -> assertEquals(8L, r.get("answer"))); + + db.executeTransactionally("CALL apoc.custom.declareFunction('testValOne(input::INT) :: INT', 'RETURN $input ^ 2 AS answer')"); + TestUtil.testCall(db, "RETURN custom.testValOne(3) as result", (r) -> assertEquals(9D, r.get("result"))); + + db.executeTransactionally("CALL apoc.custom.declareFunction($signature, $query)", + Map.of("signature", "multiFun(point:: POINT, input ::DATETIME, duration :: DURATION, minus = 1 ::INT) :: STRING", + "query", "RETURN toString($duration) + ', ' + toString($input.epochMillis - $minus) + ', ' + toString($point) as result")); + + TestUtil.testCall(db, "RETURN custom.multiFun(point({x: 1, y:1}), datetime('2020'), duration('P5M1DT12H')) as result", + (r) -> assertEquals("P5M1DT12H, 1577836799999, point({x: 1.0, y: 1.0, crs: 'cartesian'})", r.get("result"))); + } @Test public void testAllParameterTypes() throws Exception { @@ -244,8 +291,8 @@ public void registerParameterStatementFunction() throws Exception { @Test public void registerConcreteParameterAndReturnStatementFunction() throws Exception { - db.executeTransactionally("call apoc.custom.asFunction('answer','RETURN $input as answer','long',[['input','number']])"); - TestUtil.testCall(db, "return custom.answer(42) as answer", (row) -> assertEquals(42L, row.get("answer"))); + db.executeTransactionally("call apoc.custom.asFunction('answer','RETURN $input.a as answer','long',[['input','map']])"); + TestUtil.testCall(db, "return custom.answer({a: 42}) as answer", (row) -> assertEquals(42L, row.get("answer"))); } @Test