diff --git a/regress/expected/expr.out b/regress/expected/expr.out index 696a1b437..fb2bac9f6 100644 --- a/regress/expected/expr.out +++ b/regress/expected/expr.out @@ -198,7 +198,28 @@ $$RETURN {bool: true, int: 1} IN ['str', 1, 1.0, true, null, {bool: true, int: 1 t (1 row) +SELECT * FROM cypher('expr', +$$RETURN 1 IN [1.0, [NULL]]$$) AS r(c boolean); + c +--- + t +(1 row) + +SELECT * FROM cypher('expr', +$$RETURN [NULL] IN [1.0, [NULL]]$$) AS r(c boolean); + c +--- + t +(1 row) + -- should return SQL null, nothing +SELECT * FROM cypher('expr', +$$RETURN true IN NULL $$) AS r(c boolean); + c +--- + +(1 row) + SELECT * FROM cypher('expr', $$RETURN null IN ['str', 1, 1.0, true, null]$$) AS r(c boolean); c @@ -220,39 +241,81 @@ $$RETURN 'str' IN null $$) AS r(c boolean); (1 row) --- should all return false SELECT * FROM cypher('expr', $$RETURN 0 IN ['str', 1, 1.0, true, null]$$) AS r(c boolean); c --- - f + (1 row) SELECT * FROM cypher('expr', $$RETURN 1.1 IN ['str', 1, 1.0, true, null]$$) AS r(c boolean); c --- - f + (1 row) SELECT * FROM cypher('expr', $$RETURN 'Str' IN ['str', 1, 1.0, true, null]$$) AS r(c boolean); c --- - f + (1 row) SELECT * FROM cypher('expr', $$RETURN [1,3,5,[2,4,5]] IN ['str', 1, 1.0, true, null, [1,3,5,[2,4,6]]]$$) AS r(c boolean); c --- - f + (1 row) SELECT * FROM cypher('expr', $$RETURN {bool: true, int: 2} IN ['str', 1, 1.0, true, null, {bool: true, int: 1}, [1,3,5,[2,4,6]]]$$) AS r(c boolean); c --- + +(1 row) + +-- should return false +SELECT * FROM cypher('expr', +$$RETURN 'str' IN ['StR', 1, true]$$) AS r(c boolean); + c +--- + f +(1 row) + +SELECT * FROM cypher('expr', +$$RETURN 2 IN ['StR', 1, true]$$) AS r(c boolean); + c +--- + f +(1 row) + +SELECT * FROM cypher('expr', +$$RETURN false IN ['StR', 1, true]$$) AS r(c boolean); + c +--- + f +(1 row) + +SELECT * FROM cypher('expr', +$$RETURN [1,2] IN ['StR', 1, 2, true]$$) AS r(c boolean); + c +--- + f +(1 row) + +SELECT * FROM cypher('expr', +$$RETURN 1 in [[1]]$$) AS r(c boolean); + c +--- + f +(1 row) + +SELECT * FROM cypher('expr', +$$RETURN 1 IN [[null]]$$) AS r(c boolean); + c +--- f (1 row) @@ -260,9 +323,13 @@ $$RETURN {bool: true, int: 2} IN ['str', 1, 1.0, true, null, {bool: true, int: 1 SELECT * FROM cypher('expr', $$RETURN null IN 'str' $$) AS r(c boolean); ERROR: object of IN must be a list +LINE 2: $$RETURN null IN 'str' $$) AS r(c boolean); + ^ SELECT * FROM cypher('expr', $$RETURN 'str' IN 'str' $$) AS r(c boolean); ERROR: object of IN must be a list +LINE 2: $$RETURN 'str' IN 'str' $$) AS r(c boolean); + ^ -- list access SELECT * FROM cypher('expr', $$RETURN [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10][0]$$) AS r(c agtype); diff --git a/regress/sql/expr.sql b/regress/sql/expr.sql index 3dfee31c7..d3c816d04 100644 --- a/regress/sql/expr.sql +++ b/regress/sql/expr.sql @@ -121,14 +121,19 @@ SELECT * FROM cypher('expr', $$RETURN [1,3,5,[2,4,6]] IN ['str', 1, 1.0, true, null, [1,3,5,[2,4,6]]]$$) AS r(c boolean); SELECT * FROM cypher('expr', $$RETURN {bool: true, int: 1} IN ['str', 1, 1.0, true, null, {bool: true, int: 1}, [1,3,5,[2,4,6]]]$$) AS r(c boolean); +SELECT * FROM cypher('expr', +$$RETURN 1 IN [1.0, [NULL]]$$) AS r(c boolean); +SELECT * FROM cypher('expr', +$$RETURN [NULL] IN [1.0, [NULL]]$$) AS r(c boolean); -- should return SQL null, nothing SELECT * FROM cypher('expr', +$$RETURN true IN NULL $$) AS r(c boolean); +SELECT * FROM cypher('expr', $$RETURN null IN ['str', 1, 1.0, true, null]$$) AS r(c boolean); SELECT * FROM cypher('expr', $$RETURN null IN ['str', 1, 1.0, true]$$) AS r(c boolean); SELECT * FROM cypher('expr', $$RETURN 'str' IN null $$) AS r(c boolean); --- should all return false SELECT * FROM cypher('expr', $$RETURN 0 IN ['str', 1, 1.0, true, null]$$) AS r(c boolean); SELECT * FROM cypher('expr', @@ -139,6 +144,19 @@ SELECT * FROM cypher('expr', $$RETURN [1,3,5,[2,4,5]] IN ['str', 1, 1.0, true, null, [1,3,5,[2,4,6]]]$$) AS r(c boolean); SELECT * FROM cypher('expr', $$RETURN {bool: true, int: 2} IN ['str', 1, 1.0, true, null, {bool: true, int: 1}, [1,3,5,[2,4,6]]]$$) AS r(c boolean); +-- should return false +SELECT * FROM cypher('expr', +$$RETURN 'str' IN ['StR', 1, true]$$) AS r(c boolean); +SELECT * FROM cypher('expr', +$$RETURN 2 IN ['StR', 1, true]$$) AS r(c boolean); +SELECT * FROM cypher('expr', +$$RETURN false IN ['StR', 1, true]$$) AS r(c boolean); +SELECT * FROM cypher('expr', +$$RETURN [1,2] IN ['StR', 1, 2, true]$$) AS r(c boolean); +SELECT * FROM cypher('expr', +$$RETURN 1 in [[1]]$$) AS r(c boolean); +SELECT * FROM cypher('expr', +$$RETURN 1 IN [[null]]$$) AS r(c boolean); -- should error - ERROR: object of IN must be a list SELECT * FROM cypher('expr', $$RETURN null IN 'str' $$) AS r(c boolean); diff --git a/src/backend/parser/cypher_expr.c b/src/backend/parser/cypher_expr.c index 18d500ba9..5ec48e27a 100644 --- a/src/backend/parser/cypher_expr.c +++ b/src/backend/parser/cypher_expr.c @@ -26,6 +26,7 @@ #include "miscadmin.h" #include "nodes/nodeFuncs.h" +#include "optimizer/optimizer.h" #include "parser/parse_coerce.h" #include "parser/parse_collate.h" #include "parser/parse_func.h" @@ -90,6 +91,7 @@ static Node *transform_WholeRowRef(ParseState *pstate, RangeTblEntry *rte, static ArrayExpr *make_agtype_array_expr(List *args); static Node *transform_column_ref_for_indirection(cypher_parsestate *cpstate, ColumnRef *cr); +static bool verify_common_type(Oid common_type, List *exprs); /* transform a cypher expression */ Node *transform_cypher_expr(cypher_parsestate *cpstate, Node *expr, @@ -483,26 +485,153 @@ static Node *transform_cypher_comparison_aexpr_OP(cypher_parsestate *cpstate, return (Node *)transform_AEXPR_OP(cpstate, n); } +/* copied over from PostgreSQL version 13 function of the same name */ +static bool verify_common_type(Oid common_type, List *exprs) +{ + ListCell *lc; + + foreach(lc, exprs) + { + Node *nexpr = (Node *) lfirst(lc); + Oid ntype = exprType(nexpr); + + if (!can_coerce_type(1, &ntype, &common_type, COERCION_IMPLICIT)) + { + return false; + } + } + return true; +} static Node *transform_AEXPR_IN(cypher_parsestate *cpstate, A_Expr *a) { - Oid func_in_oid; - FuncExpr *result; - List *args = NIL; + ParseState *pstate = (ParseState *)cpstate; + cypher_list *rexpr; + Node *result = NULL; + Node *lexpr; + List *rexprs; + List *rvars; + List *rnonvars; + bool useOr; + ListCell *l; + + /* Check for null arguments in the list to return NULL*/ + if (!is_ag_node(a->rexpr, cypher_list)) + { + if (nodeTag(a->rexpr) == T_A_Const) + { + A_Const *r_a_const = (A_Const*)a->rexpr; + if (r_a_const->val.type == T_Null) + { + return (Node *)makeConst(AGTYPEOID, -1, InvalidOid, -1, + (Datum)NULL, true, false); + } + } - args = lappend(args, transform_cypher_expr_recurse(cpstate, a->rexpr)); - args = lappend(args, transform_cypher_expr_recurse(cpstate, a->lexpr)); + ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE), + errmsg("object of IN must be a list"))); + } - /* get the agtype_access_slice function */ - func_in_oid = get_ag_func_oid("agtype_in_operator", 2, AGTYPEOID, - AGTYPEOID); + Assert(is_ag_node(a->rexpr, cypher_list)); + + // If the operator is <>, combine with AND not OR. + if (strcmp(strVal(linitial(a->name)), "<>") == 0) + { + useOr = false; + } + else + { + useOr = true; + } + + lexpr = transform_cypher_expr_recurse(cpstate, a->lexpr); + + rexprs = rvars = rnonvars = NIL; + + rexpr = (cypher_list *)a->rexpr; + + foreach(l, (List *) rexpr->elems) + { + Node *rexpr = transform_cypher_expr_recurse(cpstate, lfirst(l)); + + rexprs = lappend(rexprs, rexpr); + if (contain_vars_of_level(rexpr, 0)) + { + rvars = lappend(rvars, rexpr); + } + else + { + rnonvars = lappend(rnonvars, rexpr); + } + } - result = makeFuncExpr(func_in_oid, AGTYPEOID, args, InvalidOid, InvalidOid, - COERCE_EXPLICIT_CALL); + /* + * ScalarArrayOpExpr is only going to be useful if there's more than one + * non-Var righthand item. + */ + if (list_length(rnonvars) > 1) + { + List *allexprs; + Oid scalar_type; + List *aexprs; + ArrayExpr *newa; + + allexprs = list_concat(list_make1(lexpr), rnonvars); - result->location = exprLocation(a->lexpr); + scalar_type = AGTYPEOID; + + Assert(verify_common_type(scalar_type, allexprs)); + /* + * coerce all the right-hand non-Var inputs to the common type + * and build an ArrayExpr for them. + */ + + aexprs = NIL; + foreach(l, rnonvars) + { + Node *rexpr = (Node *) lfirst(l); - return (Node *)result; + rexpr = coerce_to_common_type(pstate, rexpr, AGTYPEOID, "IN"); + aexprs = lappend(aexprs, rexpr); + } + newa = makeNode(ArrayExpr); + newa->array_typeid = get_array_type(AGTYPEOID); + /* array_collid will be set by parse_collate.c */ + newa->element_typeid = AGTYPEOID; + newa->elements = aexprs; + newa->multidims = false; + result = (Node *) make_scalar_array_op(pstate, a->name, useOr, + lexpr, (Node *) newa, + a->location); + + /* Consider only the Vars (if any) in the loop below */ + rexprs = rvars; + } + + // Must do it the hard way, with a boolean expression tree. + foreach(l, rexprs) + { + Node *rexpr = (Node *) lfirst(l); + Node *cmp; + + // Ordinary scalar operator + cmp = (Node *) make_op(pstate, a->name, copyObject(lexpr), rexpr, + pstate->p_last_srf, a->location); + + cmp = coerce_to_boolean(pstate, cmp, "IN"); + if (result == NULL) + { + result = cmp; + } + else + { + result = (Node *) makeBoolExpr(useOr ? OR_EXPR : AND_EXPR, + list_make2(result, cmp), + a->location); + } + } + + return result; } static Node *transform_BoolExpr(cypher_parsestate *cpstate, BoolExpr *expr) diff --git a/src/backend/utils/adt/agtype_ops.c b/src/backend/utils/adt/agtype_ops.c index 7a1d3aed4..5309e1334 100644 --- a/src/backend/utils/adt/agtype_ops.c +++ b/src/backend/utils/adt/agtype_ops.c @@ -1401,7 +1401,7 @@ Datum agtype_exists_all_agtype(PG_FUNCTION_ARGS) PG_FUNCTION_INFO_V1(agtype_contains); /* - * <@ operator for agtype. Returns true if the right agtype path/value entries + * @> operator for agtype. Returns true if the right agtype path/value entries * contained at the top level within the left agtype value */ Datum agtype_contains(PG_FUNCTION_ARGS)