diff --git a/pyproject.toml b/pyproject.toml index 68731fdf..c74c88b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -131,6 +131,7 @@ ignore = [ "E501", # line too long, handled by black "B008", # do not perform function calls in argument defaults "B905", # ignore zip() without an explicit strict= parameter, only support with python >3.10 + "F811", # ignore multiple definitions, handled by mypy ] diff --git a/tests/fixtures/cql2text.cql2 b/tests/fixtures/cql2text.cql2 new file mode 100644 index 00000000..4ed7334b --- /dev/null +++ b/tests/fixtures/cql2text.cql2 @@ -0,0 +1,69 @@ +"id" = 'fa7e1920-9107-422d-a3db-c468cbc5d6df' +"id" <> 'fa7e1920-9107-422d-a3db-c468cbc5d6df' +"value" < 10 +"value" > 10 +"value" <= 10 +"value" >= 10 +"name" LIKE 'foo%' +"name" NOT LIKE 'foo%' +NOT "name" LIKE 'foo%' +"value" BETWEEN 10 AND 20 +"value" NOT BETWEEN 10 AND 20 +NOT "value" BETWEEN 10 AND 20 +"value" IN (1.0, 2.0, 3.0) +"value" NOT IN ('a', 'b', 'c') +NOT "value" IN ('a', 'b', 'c') +"value" IS NULL +"value" IS NOT NULL +NOT "value" IS NULL +"name" NOT LIKE 'foo%' AND "value" > 10 +(NOT "name" LIKE 'foo%' AND "value" > 10) +"value" IS NULL OR "value" BETWEEN 10 AND 20 +("value" IS NULL OR "value" BETWEEN 10 AND 20) +S_INTERSECTS("geometry", BBOX(-128.098193, -1.1, -99999.0, 180.0, 90.0, 100000.0)) +S_EQUALS( POLYGON ( (-0.333333 89.0, -102.723546 -0.5, -179.0 -89.0, -1.9 89.0, -0.0 89.0, 2.00001 -1.9, -0.333333 89.0) ), "geometry" ) +S_EQUALS(POLYGON ((-0.333333 89.0, -102.723546 -0.5, -179.0 -89.0, -1.9 89.0, -0.0 89.0, 2.00001 -1.9, -0.333333 89.0)), "geometry") +S_DISJOINT("geometry", MULTIPOLYGON (((144.022387 45.176126, -1.1 0.0, 180.0 47.808086, 144.022387 45.176126)))) +S_TOUCHES("geometry", MULTILINESTRING ((-1.9 -0.99999, 75.292574 1.5, -0.5 -4.016458, -31.708594 -74.743801, 179.0 -90.0),(-1.9 -1.1, 1.5 8.547371))) +S_WITHIN(POLYGON ((-49.88024 0.5 -75993.341684, -1.5 -0.99999 -100000.0, 0.0 0.5 -0.333333, -49.88024 0.5 -75993.341684), (-65.887123 2.00001 
-100000.0, 0.333333 -53.017711 -79471.332949, 180.0 0.0 1852.616704, -65.887123 2.00001 -100000.0)), "geometry") +S_WITHIN(POLYGON Z ((-49.88024 0.5 -75993.341684, -1.5 -0.99999 -100000.0, 0.0 0.5 -0.333333, -49.88024 0.5 -75993.341684), (-65.887123 2.00001 -100000.0, 0.333333 -53.017711 -79471.332949, 180.0 0.0 1852.616704, -65.887123 2.00001 -100000.0)), "geometry") +S_OVERLAPS("geometry", BBOX(-179.912109, 1.9, 180.0, 16.897016)) +S_CROSSES("geometry", LINESTRING (172.03086 1.5, 1.1 -90.0, -159.757695 0.99999, -180.0 0.5, -12.111235 81.336403, -0.5 64.43958, 0.0 81.991815, -155.93831 90.0)) +S_CONTAINS("geometry", POINT (-3.508362 -1.754181)) +T_AFTER("updated_at", DATE('2010-02-10')) +T_BEFORE(updated_at, TIMESTAMP('2012-08-10T05:30:00Z')) +T_BEFORE("updated_at", TIMESTAMP('2012-08-10T05:30:00.000000Z')) +T_CONTAINS(INTERVAL('2000-01-01T00:00:00Z', '2005-01-10T01:01:01.393216Z'), "updated_at") +T_CONTAINS(INTERVAL('2000-01-01T00:00:00.000000Z', '2005-01-10T01:01:01.393216Z'), "updated_at") +T_DISJOINT(INTERVAL('..', '2005-01-10T01:01:01.393216Z'), "coverage_date") +T_DURING(INTERVAL("created_at", "updated_at"), INTERVAL('2005-01-10', '2010-02-10')) +T_EQUALS("updated_at", DATE('1851-04-29')) +T_FINISHEDBY("coverage_date", INTERVAL('1991-10-07T08:21:06.393262Z', '2010-02-10T05:29:20.073225Z')) +T_FINISHES("coverage_dates", INTERVAL('1991-10-07', '2010-02-10T05:29:20.073225Z')) +T_INTERSECTS("coverage_date", INTERVAL('1991-10-07T08:21:06.393262Z', '2010-02-10T05:29:20.073225Z')) +T_MEETS(INTERVAL('2005-01-10', '2010-02-10'), "coverage_dates") +T_METBY(INTERVAL('2010-02-10T05:29:20.073225Z', '2010-10-07'), "coverage_dates") +T_OVERLAPPEDBY(INTERVAL('1991-10-07T08:21:06.393262Z', '2010-02-10T05:29:20.073225Z'), "coverage_dates") +T_OVERLAPS("coverage_date", INTERVAL('1991-10-07T08:21:06.393262Z', '1992-10-09T08:08:08.393473Z')) +T_STARTEDBY(INTERVAL('1991-10-07T08:21:06.393262Z', '2010-02-10T05:29:20.073225Z'), "coverage_dates") +T_STARTS("coverage_dates", 
INTERVAL('1991-10-07T08:21:06.393262Z', '..')) +Foo("geometry") = TRUE +FALSE <> Bar("geometry", 100, 'a', 'b', FALSE) +ACCENTI("owner") = ACCENTI('Beyoncé') +CASEI("owner") = CASEI('somebody else') +"value" > ("foo" + 10) +"value" < ("foo" - 10) +"value" <> (22.1 * "foo") +"value" = (2 / "foo") +"value" <= (2 ^ "foo") +0 = ("foo" % 2) +1 = ("foo" div 2) +A_CONTAINEDBY("values", ('a', 'b', 'c')) +A_CONTAINS("values", ('a', 'b', 'c')) +A_EQUALS(('a', TRUE, 1.0, 8), "values") +A_OVERLAPS("values", (TIMESTAMP('2012-08-10T05:30:00.000000Z'), DATE('2010-02-10'), FALSE)) +S_EQUALS(MULTIPOINT (180.0 -0.5, 179.0 -47.121701, 180.0 -0.0, 33.470475 -0.99999, 179.0 -15.333062), "geometry") +S_EQUALS(GEOMETRYCOLLECTION (POINT (1.9 2.00001), POINT (0.0 -2.00001), MULTILINESTRING ((-2.00001 -0.0, -77.292642 -0.5, -87.515626 -0.0, -180.0 12.502773, 21.204842 -1.5, -21.878857 -90.0)), POINT (1.9 0.5), LINESTRING (179.0 1.179148, -148.192487 -65.007816, 0.5 0.333333)), "geometry") +value = - foo * 2.0 + "bar" / 6.1234 - "x" ^ 2.0 +"value" = ((((-1 * "foo") * 2.0) + ("bar" / 6.1234)) - ("x" ^ 2.0)) +"name" LIKE CASEI('FOO%') diff --git a/tests/fixtures/cql2text_asyncpg.asql b/tests/fixtures/cql2text_asyncpg.asql new file mode 100644 index 00000000..d04682a8 --- /dev/null +++ b/tests/fixtures/cql2text_asyncpg.asql @@ -0,0 +1,69 @@ +'id = $1', ['fa7e1920-9107-422d-a3db-c468cbc5d6df'] +'id != $1', ['fa7e1920-9107-422d-a3db-c468cbc5d6df'] +'value < $1', [10.0] +'value > $1', [10.0] +'value <= $1', [10.0] +'value >= $1', [10.0] +'name LIKE $1', ['foo%'] +'NOT name LIKE $1', ['foo%'] +'NOT name LIKE $1', ['foo%'] +'value BETWEEN $1 AND $2', [10.0, 20.0] +'NOT (value BETWEEN $1 AND $2)', [10.0, 20.0] +'NOT (value BETWEEN $1 AND $2)', [10.0, 20.0] +'value = ANY(ARRAY[$1, $2, $3])', [1.0, 2.0, 3.0] +'NOT value = ANY(ARRAY[$1, $2, $3])', ['a', 'b', 'c'] +'NOT value = ANY(ARRAY[$1, $2, $3])', ['a', 'b', 'c'] +'value IS NULL', [] +'NOT value IS NULL', [] +'NOT value IS NULL', [] +'NOT name LIKE 
$1 AND value > $2', ['foo%', 10.0] +'NOT name LIKE $1 AND value > $2', ['foo%', 10.0] +'value IS NULL OR value BETWEEN $1 AND $2', [10.0, 20.0] +'value IS NULL OR value BETWEEN $1 AND $2', [10.0, 20.0] +'ST_INTERSECTS(geometry, ST_MAKEENVELOPE($1, $2, $3, $4, $5))', [-128.098193, -1.1, 180.0, 90.0, 4326] +'ST_EQUALS($1::geometry, geometry)', ['SRID=4326;POLYGON ((-0.333333 89.0, -102.723546 -0.5, -179.0 -89.0, -1.9 89.0, -0.0 89.0, 2.00001 -1.9, -0.333333 89.0))'] +'ST_EQUALS($1::geometry, geometry)', ['SRID=4326;POLYGON ((-0.333333 89.0, -102.723546 -0.5, -179.0 -89.0, -1.9 89.0, -0.0 89.0, 2.00001 -1.9, -0.333333 89.0))'] +'ST_DISJOINT(geometry, $1::geometry)', ['SRID=4326;MULTIPOLYGON (((144.022387 45.176126, -1.1 0.0, 180.0 47.808086, 144.022387 45.176126)))'] +'ST_TOUCHES(geometry, $1::geometry)', ['SRID=4326;MULTILINESTRING ((-1.9 -0.99999, 75.292574 1.5, -0.5 -4.016458, -31.708594 -74.743801, 179.0 -90.0), (-1.9 -1.1, 1.5 8.547371))'] +'ST_WITHIN($1::geometry, geometry)', ['SRID=4326;POLYGON Z ((-49.88024 0.5 -75993.341684, -1.5 -0.99999 -100000.0, 0.0 0.5 -0.333333, -49.88024 0.5 -75993.341684), (-65.887123 2.00001 -100000.0, 0.333333 -53.017711 -79471.332949, 180.0 0.0 1852.616704, -65.887123 2.00001 -100000.0))'] +'ST_WITHIN($1::geometry, geometry)', ['SRID=4326;POLYGON Z ((-49.88024 0.5 -75993.341684, -1.5 -0.99999 -100000.0, 0.0 0.5 -0.333333, -49.88024 0.5 -75993.341684), (-65.887123 2.00001 -100000.0, 0.333333 -53.017711 -79471.332949, 180.0 0.0 1852.616704, -65.887123 2.00001 -100000.0))'] +'ST_OVERLAPS(geometry, ST_MAKEENVELOPE($1, $2, $3, $4, $5))', [-179.912109, 1.9, 180.0, 16.897016, 4326] +'ST_CROSSES(geometry, $1::geometry)', ['SRID=4326;LINESTRING (172.03086 1.5, 1.1 -90.0, -159.757695 0.99999, -180.0 0.5, -12.111235 81.336403, -0.5 64.43958, 0.0 81.991815, -155.93831 90.0)'] +'ST_CONTAINS(geometry, $1::geometry)', ['SRID=4326;POINT (-3.508362 -1.754181)'] +'$1::date < "updated_at"', ['2010-02-10'] +'"updated_at" < $1::timestamptz', 
['2012-08-10T05:30:00+00:00'] +'"updated_at" < $1::timestamptz', ['2012-08-10T05:30:00+00:00'] +'"updated_at" > $1::timestamptz AND "updated_at" < $2::timestamptz', ['2000-01-01T00:00:00+00:00', '2005-01-10T01:01:01.393216+00:00'] +'"updated_at" > $1::timestamptz AND "updated_at" < $2::timestamptz', ['2000-01-01T00:00:00+00:00', '2005-01-10T01:01:01.393216+00:00'] +'$1::timestamptz > "coverage_date" OR $2::timestamptz < "coverage_date"', ['-infinity', '2005-01-10T01:01:01.393216+00:00'] +'"created_at" > $1::date AND "updated_at" < $2::date', ['2005-01-10', '2010-02-10'] +'"updated_at" = $1::date AND "updated_at" = $2::date', ['1851-04-29', '1851-04-29'] +'$1::timestamptz > "coverage_date" AND $2::timestamptz = "coverage_date"', ['1991-10-07T08:21:06.393262+00:00', '2010-02-10T05:29:20.073225+00:00'] +'"coverage_dates" > $1::date AND "coverage_dates" = $2::timestamptz', ['1991-10-07', '2010-02-10T05:29:20.073225+00:00'] +'"coverage_date" <= $1::timestamptz AND "coverage_date" >= $2::timestamptz', ['2010-02-10T05:29:20.073225+00:00', '1991-10-07T08:21:06.393262+00:00'] +'$1::date = "coverage_dates"', ['2010-02-10'] +'"coverage_dates" = $1::timestamptz', ['2010-02-10T05:29:20.073225+00:00'] +'"coverage_dates" < $1::timestamptz AND "coverage_dates" > $2::timestamptz AND "coverage_dates" < $3::timestamptz', ['1991-10-07T08:21:06.393262+00:00', '2010-02-10T05:29:20.073225+00:00', '2010-02-10T05:29:20.073225+00:00'] +'"coverage_date" < $1::timestamptz AND "coverage_date" > $2::timestamptz AND "coverage_date" < $3::timestamptz', ['1991-10-07T08:21:06.393262+00:00', '1992-10-09T08:08:08.393473+00:00', '1992-10-09T08:08:08.393473+00:00'] +'"coverage_dates" = $1::timestamptz AND "coverage_dates" < $2::timestamptz', ['1991-10-07T08:21:06.393262+00:00', '2010-02-10T05:29:20.073225+00:00'] +'"coverage_dates" = $1::timestamptz AND "coverage_dates" < $2::timestamptz', ['1991-10-07T08:21:06.393262+00:00', 'infinity'] +'FOO(geometry) IS $1', [True] +'$1 IS NOT BAR(geometry, $2, $3, 
$4, $5)', [False, 100.0, 'a', 'b', False] +'UNACCENT(owner) = UNACCENT($1)', ['Beyoncé'] +'LOWER(owner) = LOWER($1)', ['somebody else'] +'value > (foo + $1)', [10.0] +'value < (foo - $1)', [10.0] +'value != ($1 * foo)', [22.1] +'value = ($1 / foo)', [2.0] +'value <= POWER($1, foo)', [2.0] +'$1 = MOD(foo, $2)', [0.0, 2.0] +'$1 = (foo / $2)', [1.0, 2.0] +'values <@ ARRAY[$1, $2, $3]', ['a', 'b', 'c'] +'values @> ARRAY[$1, $2, $3]', ['a', 'b', 'c'] +'ARRAY[$1, $2, $3, $4] = values', ['a', True, 1.0, 8.0] +'values && ARRAY[$1::timestamptz, $2::date, $3]', ['2012-08-10T05:30:00+00:00', '2010-02-10', False] +'ST_EQUALS(MULTIPOINT($1 - $2, $3 - $4, $5 - $6, $7 - $8, $9 - $10), geometry)', [180.0, 0.5, 179.0, 47.121701, 180.0, 0.0, 33.470475, 0.99999, 179.0, 15.333062] +'ST_EQUALS($1::geometry, geometry)', ['SRID=4326;GEOMETRYCOLLECTION (POINT (1.9 2.00001), POINT (0.0 -2.00001), MULTILINESTRING ((-2.00001 -0.0, -77.292642 -0.5, -87.515626 -0.0, -180.0 12.502773, 21.204842 -1.5, -21.878857 -90.0)), POINT (1.9 0.5), LINESTRING (179.0 1.179148, -148.192487 -65.007816, 0.5 0.333333))'] +'value = (($1 * foo) * $2) + (bar / $3) - POWER(x, $4)', [-1, 2.0, 6.1234, 2.0] +'value = (($1 * foo) * $2) + (bar / $3) - POWER(x, $4)', [-1.0, 2.0, 6.1234, 2.0] +'name ILIKE $1', ['FOO%'] diff --git a/tests/fixtures/cql2text_rawsql.sql b/tests/fixtures/cql2text_rawsql.sql new file mode 100644 index 00000000..52ae4534 --- /dev/null +++ b/tests/fixtures/cql2text_rawsql.sql @@ -0,0 +1,69 @@ +id = 'fa7e1920-9107-422d-a3db-c468cbc5d6df' +id != 'fa7e1920-9107-422d-a3db-c468cbc5d6df' +value < 10.0 +value > 10.0 +value <= 10.0 +value >= 10.0 +name LIKE 'foo%' +NOT name LIKE 'foo%' +NOT name LIKE 'foo%' +value BETWEEN 10.0 AND 20.0 +NOT (value BETWEEN 10.0 AND 20.0) +NOT (value BETWEEN 10.0 AND 20.0) +value = ANY(ARRAY[1.0, 2.0, 3.0]) +NOT value = ANY(ARRAY['a', 'b', 'c']) +NOT value = ANY(ARRAY['a', 'b', 'c']) +value IS NULL +NOT value IS NULL +NOT value IS NULL +NOT name LIKE 'foo%' AND value > 
10.0 +NOT name LIKE 'foo%' AND value > 10.0 +value IS NULL OR value BETWEEN 10.0 AND 20.0 +value IS NULL OR value BETWEEN 10.0 AND 20.0 +ST_INTERSECTS(geometry, ST_MAKEENVELOPE(-128.098193, -1.1, 180.0, 90.0, 4326)) +ST_EQUALS('SRID=4326;POLYGON ((-0.333333 89.0, -102.723546 -0.5, -179.0 -89.0, -1.9 89.0, -0.0 89.0, 2.00001 -1.9, -0.333333 89.0))'::geometry, geometry) +ST_EQUALS('SRID=4326;POLYGON ((-0.333333 89.0, -102.723546 -0.5, -179.0 -89.0, -1.9 89.0, -0.0 89.0, 2.00001 -1.9, -0.333333 89.0))'::geometry, geometry) +ST_DISJOINT(geometry, 'SRID=4326;MULTIPOLYGON (((144.022387 45.176126, -1.1 0.0, 180.0 47.808086, 144.022387 45.176126)))'::geometry) +ST_TOUCHES(geometry, 'SRID=4326;MULTILINESTRING ((-1.9 -0.99999, 75.292574 1.5, -0.5 -4.016458, -31.708594 -74.743801, 179.0 -90.0), (-1.9 -1.1, 1.5 8.547371))'::geometry) +ST_WITHIN('SRID=4326;POLYGON Z ((-49.88024 0.5 -75993.341684, -1.5 -0.99999 -100000.0, 0.0 0.5 -0.333333, -49.88024 0.5 -75993.341684), (-65.887123 2.00001 -100000.0, 0.333333 -53.017711 -79471.332949, 180.0 0.0 1852.616704, -65.887123 2.00001 -100000.0))'::geometry, geometry) +ST_WITHIN('SRID=4326;POLYGON Z ((-49.88024 0.5 -75993.341684, -1.5 -0.99999 -100000.0, 0.0 0.5 -0.333333, -49.88024 0.5 -75993.341684), (-65.887123 2.00001 -100000.0, 0.333333 -53.017711 -79471.332949, 180.0 0.0 1852.616704, -65.887123 2.00001 -100000.0))'::geometry, geometry) +ST_OVERLAPS(geometry, ST_MAKEENVELOPE(-179.912109, 1.9, 180.0, 16.897016, 4326)) +ST_CROSSES(geometry, 'SRID=4326;LINESTRING (172.03086 1.5, 1.1 -90.0, -159.757695 0.99999, -180.0 0.5, -12.111235 81.336403, -0.5 64.43958, 0.0 81.991815, -155.93831 90.0)'::geometry) +ST_CONTAINS(geometry, 'SRID=4326;POINT (-3.508362 -1.754181)'::geometry) +'2010-02-10'::date < "updated_at" +"updated_at" < '2012-08-10T05:30:00+00:00'::timestamptz +"updated_at" < '2012-08-10T05:30:00+00:00'::timestamptz +"updated_at" > '2000-01-01T00:00:00+00:00'::timestamptz AND "updated_at" < 
'2005-01-10T01:01:01.393216+00:00'::timestamptz +"updated_at" > '2000-01-01T00:00:00+00:00'::timestamptz AND "updated_at" < '2005-01-10T01:01:01.393216+00:00'::timestamptz +'-infinity'::timestamptz > "coverage_date" OR '2005-01-10T01:01:01.393216+00:00'::timestamptz < "coverage_date" +"created_at" > '2005-01-10'::date AND "updated_at" < '2010-02-10'::date +"updated_at" = '1851-04-29'::date AND "updated_at" = '1851-04-29'::date +'1991-10-07T08:21:06.393262+00:00'::timestamptz > "coverage_date" AND '2010-02-10T05:29:20.073225+00:00'::timestamptz = "coverage_date" +"coverage_dates" > '1991-10-07'::date AND "coverage_dates" = '2010-02-10T05:29:20.073225+00:00'::timestamptz +"coverage_date" <= '2010-02-10T05:29:20.073225+00:00'::timestamptz AND "coverage_date" >= '1991-10-07T08:21:06.393262+00:00'::timestamptz +'2010-02-10'::date = "coverage_dates" +"coverage_dates" = '2010-02-10T05:29:20.073225+00:00'::timestamptz +"coverage_dates" < '1991-10-07T08:21:06.393262+00:00'::timestamptz AND "coverage_dates" > '2010-02-10T05:29:20.073225+00:00'::timestamptz AND "coverage_dates" < '2010-02-10T05:29:20.073225+00:00'::timestamptz +"coverage_date" < '1991-10-07T08:21:06.393262+00:00'::timestamptz AND "coverage_date" > '1992-10-09T08:08:08.393473+00:00'::timestamptz AND "coverage_date" < '1992-10-09T08:08:08.393473+00:00'::timestamptz +"coverage_dates" = '1991-10-07T08:21:06.393262+00:00'::timestamptz AND "coverage_dates" < '2010-02-10T05:29:20.073225+00:00'::timestamptz +"coverage_dates" = '1991-10-07T08:21:06.393262+00:00'::timestamptz AND "coverage_dates" < 'infinity'::timestamptz +FOO(geometry) IS True +False IS NOT BAR(geometry, 100.0, 'a', 'b', False) +UNACCENT(owner) = UNACCENT('Beyoncé') +LOWER(owner) = LOWER('somebody else') +value > (foo + 10.0) +value < (foo - 10.0) +value != (22.1 * foo) +value = (2.0 / foo) +value <= POWER(2.0, foo) +0.0 = MOD(foo, 2.0) +1.0 = (foo / 2.0) +values <@ ARRAY['a', 'b', 'c'] +values @> ARRAY['a', 'b', 'c'] +ARRAY['a', True, 1.0, 8.0] = 
values +values && ARRAY['2012-08-10T05:30:00+00:00'::timestamptz, '2010-02-10'::date, False] +ST_EQUALS(MULTIPOINT(180.0 - 0.5, 179.0 - 47.121701, 180.0 - 0.0, 33.470475 - 0.99999, 179.0 - 15.333062), geometry) +ST_EQUALS('SRID=4326;GEOMETRYCOLLECTION (POINT (1.9 2.00001), POINT (0.0 -2.00001), MULTILINESTRING ((-2.00001 -0.0, -77.292642 -0.5, -87.515626 -0.0, -180.0 12.502773, 21.204842 -1.5, -21.878857 -90.0)), POINT (1.9 0.5), LINESTRING (179.0 1.179148, -148.192487 -65.007816, 0.5 0.333333))'::geometry, geometry) +value = ((-1 * foo) * 2.0) + (bar / 6.1234) - POWER(x, 2.0) +value = ((-1.0 * foo) * 2.0) + (bar / 6.1234) - POWER(x, 2.0) +name ILIKE 'FOO%' diff --git a/tests/routes/test_items.py b/tests/routes/test_items.py index 16fc0e02..5d27b101 100644 --- a/tests/routes/test_items.py +++ b/tests/routes/test_items.py @@ -247,6 +247,7 @@ def test_items_filter_cql_ids(app): assert response.status_code == 200 assert response.headers["content-type"] == "application/geo+json" body = response.json() + print(body) assert len(body["features"]) == 1 assert body["numberMatched"] == 1 assert body["numberReturned"] == 1 @@ -257,7 +258,7 @@ def test_items_filter_cql_ids(app): response = app.get( "/collections/public.landsat_wrs/items?filter-lang=cql2-text&filter=ogc_fid IN (1,2)" ) - + print(response.json()) assert response.status_code == 200 assert response.headers["content-type"] == "application/geo+json" body = response.json() @@ -327,7 +328,7 @@ def test_items_properties_filter_cql2(app): assert body["features"][0]["properties"]["row"] == 10 Items.model_validate(body) - filter_query = {"op": "isNull", "args": [{"property": "numeric"}]} + filter_query = {"op": "isNull", "args": {"property": "numeric"}} response = app.get( f"/collections/public.my_data/items?filter-lang=cql2-json&filter=&filter={json.dumps(filter_query)}" ) diff --git a/tipg/collections.py b/tipg/collections.py index 711364a2..f2c35078 100644 --- a/tipg/collections.py +++ b/tipg/collections.py @@ -2,14 
+2,13 @@ import datetime import re +from functools import cached_property from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union -from buildpg import RawDangerous as raw -from buildpg import asyncpg, clauses -from buildpg import funcs as pg_funcs -from buildpg import logic, render -from ciso8601 import parse_rfc3339 +import asyncpg +from fastapi import FastAPI from morecantile import Tile, TileMatrixSet +from pgmini import And, Or, Select, With, Param from pydantic import BaseModel, Field, model_validator from pygeofilter.ast import AstType @@ -21,35 +20,34 @@ InvalidPropertyName, MissingDatetimeColumn, ) -from tipg.filter.evaluate import to_filter +from tipg.filter.cql2sql import CQL2SQL from tipg.filter.filters import bbox_to_wkt from tipg.logger import logger from tipg.model import Extent +from tipg.query import ( + NULL, + Bbox, + Count, + F, + P, + Table, + Transform, + build, + date_param, + ensure_list, + raw_query, + simplified, +) +import pgmini from tipg.settings import FeaturesSettings, MVTSettings, TableSettings -from fastapi import FastAPI - mvt_settings = MVTSettings() features_settings = FeaturesSettings() def debug_query(q, *p): """Utility to print raw statement to use for debugging.""" - - qsub = re.sub(r"\$([0-9]+)", r"{\1}", q) - - def quote_str(s): - """Quote strings.""" - - if s is None: - return "null" - elif isinstance(s, str): - return f"'{s}'" - else: - return s - - p = [quote_str(s) for s in p] - logger.debug(qsub.format(None, *p)) + logger.debug(raw_query(q, *p)) # Links to geojson schema @@ -169,6 +167,16 @@ class Collection(BaseModel): datetime_column: Optional[Column] = None parameters: List[Parameter] = [] + @cached_property + def T(self): + """Returns a pgmini Table.""" + tablecls = type( + self.table, + (Table,), + {f'"{c.name}"': c.type for c in self.properties}, + ) + return tablecls(self.table) + @property def extent(self) -> Optional[Extent]: """Return extent.""" @@ -228,7 +236,7 @@ def dt_bounds(self) -> 
Optional[List[str]]: return None - @property + @cached_property def crs(self): """Return crs of set geometry column.""" if self.geometry_column: @@ -301,25 +309,13 @@ def get_column(self, property_name: str) -> Optional[Column]: return None def _select_no_geo(self, properties: Optional[List[str]], addid: bool = True): - nocomma = False columns = self.columns(properties) - if columns: - sel = logic.as_sql_block(clauses.Select(columns)) - else: - sel = logic.as_sql_block(raw("SELECT ")) - nocomma = True + cols = self.T.cols(columns) if addid: - if self.id_column: - id_clause = logic.V(self.id_column).as_("tipg_id") - else: - id_clause = raw(" ROW_NUMBER () OVER () AS tipg_id ") - if nocomma: - sel = clauses.Clauses(sel, id_clause) - else: - sel = sel.comma(id_clause) + cols.append(self.T.create_tipg_id(self.id_column)) - return logic.as_sql_block(sel) + return Select(*cols) def _select( self, @@ -330,25 +326,22 @@ def _select( geom_as_wkt: bool = False, ): sel = self._select_no_geo(properties) - geom = self._geom(geometry_column, bbox_only, simplify) + + gcol = None if geom_as_wkt: if geom: - sel = sel.comma(logic.Func("ST_AsEWKT", geom).as_("tipg_geom")) + gcol = F("ST_AsEWKT", geom) else: - sel = sel.comma(pg_funcs.cast(None, "text").as_("tipg_geom")) + gcol = NULL("text") else: if geom: - sel = sel.comma( - pg_funcs.cast(logic.Func("ST_AsGeoJSON", geom), "json").as_( - "tipg_geom" - ) - ) + gcol = F("ST_AsGeoJSON", geom).Cast("json") else: - sel = sel.comma(pg_funcs.cast(None, "json").as_("tipg_geom")) + gcol = NULL("json") - return sel + return sel.AddColumns(gcol.As("tipg_geom")) def _select_mvt( self, @@ -358,79 +351,57 @@ def _select_mvt( tile: Tile, ): """Create MVT from intersecting geometries.""" - geom = pg_funcs.cast(logic.V(geometry_column.name), "geometry") + + geom = self.T.get(geometry_column.name).Cast("geometry") # make sure the geometries do not overflow the TMS bbox if not tms.is_valid(tile): - geom = logic.Func( + geom = F( "ST_Intersection", - 
logic.Func("ST_MakeEnvelope", *tms.bbox, 4326), - logic.Func( - "ST_Transform", - geom, - pg_funcs.cast(4326, "int"), - ), + Bbox(tms.bbox), + Transform(geom), ) # Transform the geometries to TMS CRS using EPSG code if tms_srid := tms.crs.to_epsg(): - transform_logic = logic.Func( - "ST_Transform", - geom, - pg_funcs.cast(tms_srid, "int"), - ) + transform_logic = Transform(geom, tms_srid) # Transform the geometries to TMS CRS using PROJ String else: tms_proj = tms.crs.to_proj4() - transform_logic = logic.Func( - "ST_Transform", - geom, - pg_funcs.cast(tms_proj, "text"), - ) + transform_logic = Transform(geom, tms_proj) bbox = tms.xy_bounds(tile) - sel = self._select_no_geo(properties, addid=False).comma( - logic.Func( + sel = self._select_no_geo(properties, addid=False).AddColumns( + F( "ST_AsMVTGeom", transform_logic, - logic.Func( + F( "ST_Segmentize", - logic.Func( - "ST_MakeEnvelope", - bbox.left, - bbox.bottom, - bbox.right, - bbox.top, - ), + Bbox(bbox), bbox.right - bbox.left, ), mvt_settings.tile_resolution, mvt_settings.tile_buffer, mvt_settings.tile_clip, - ).as_("geom") + ).As("geom") ) return sel def _select_count(self): - return clauses.Select(pg_funcs.count("*")) + return Select(Count()) def _from(self, function_parameters: Optional[Dict[str, str]]): if self.type == "Function": if not function_parameters: - return clauses.From(self.id) + raw("()") + return F(self.id) params = [] for p in self.parameters: if p.name in function_parameters: - params.append( - pg_funcs.cast( - pg_funcs.cast(function_parameters[p.name], "text"), - p.type, - ) - ) - return clauses.From(logic.Func(self.id, *params)) - return clauses.From(self.id) + params.append(P(function_parameters[p.name]).Cast(p.type)) + return F(self.id, *params) + return self.T def _geom( self, @@ -441,23 +412,19 @@ def _geom( if geometry_column is None: return None - g = pg_funcs.cast(logic.V(geometry_column.name), "geometry") + g = self.T.get(geometry_column.name).Cast("geometry") - # Reproject to 
WGS64 if needed + # Reproject to WGS84 if needed if geometry_column.srid != 4326: - g = logic.Func("ST_Transform", g, pg_funcs.cast(4326, "int")) + g = Transform(g) # Return BBOX Only if bbox_only: - g = logic.Func("ST_Envelope", g) + g = F("ST_Envelope", g) # Simplify the geometry elif simplify: - g = logic.Func( - "ST_SnapToGrid", - logic.Func("ST_Simplify", g, simplify), - simplify, - ) + g = simplified(g, simplify) return g @@ -467,35 +434,23 @@ def _where( # noqa: C901 datetime: Optional[List[str]] = None, bbox: Optional[List[float]] = None, properties: Optional[List[Tuple[str, Any]]] = None, - cql: Optional[AstType] = None, + cql: Optional[Any] = None, geom: Optional[str] = None, dt: Optional[str] = None, tile: Optional[Tile] = None, tms: Optional[TileMatrixSet] = None, ): """Construct WHERE query.""" - wheres = [logic.S(True)] + wheres = [] # `ids` filter if ids is not None: - if len(ids) == 1: - wheres.append( - logic.V(self.id_column) - == pg_funcs.cast( - pg_funcs.cast(ids[0], "text"), self.id_column_info.type - ) - ) - else: - w = [ - logic.V(self.id_column) - == logic.S( - pg_funcs.cast( - pg_funcs.cast(i, "text"), self.id_column_info.type - ) - ) - for i in ids - ] - wheres.append(pg_funcs.OR(*w)) + ids = ensure_list(ids) + w = [ + self.T.get(self.id_column) == P(i).Cast(self.id_column_info.type) + for i in ids + ] + wheres.append(Or(*w)) # `properties filter if properties is not None: @@ -504,23 +459,21 @@ def _where( # noqa: C901 col = self.get_column(prop) if not col: raise InvalidPropertyName(f"Invalid property name: {prop}") + dbcol = self.T.get(col.name) - w.append( - logic.V(col.name) - == logic.S(pg_funcs.cast(pg_funcs.cast(val, "text"), col.type)) - ) + w.append(dbcol == P(val).Cast(col.type)) if w: - wheres.append(pg_funcs.AND(*w)) + wheres.append(And(*w)) # `bbox` filter geometry_column = self.get_geometry_column(geom) if bbox is not None and geometry_column is not None: wheres.append( - logic.Func( + F( "ST_Intersects", - 
logic.S(bbox_to_wkt(bbox)), - logic.V(geometry_column.name), + Bbox(bbox), + self.T.get(geometry_column.name), ) ) @@ -539,7 +492,10 @@ def _where( # noqa: C901 # `CQL` filter if cql is not None: - wheres.append(to_filter(cql, [p.name for p in self.properties])) + print('ADDING CQL FILTER', cql) + cqlt = CQL2SQL(self) + print('CQLT', cqlt) + wheres.append(cqlt.sql(cql)) if tile and tms and geometry_column: # Get Tile Bounds in Geographic CRS (usually epsg:4326) @@ -550,41 +506,29 @@ def _where( # noqa: C901 right, top = tms.truncate_lnglat(right, top) wheres.append( - logic.Func( + F( "ST_Intersects", - logic.Func( - "ST_Transform", - logic.Func( + Transform( + F( "ST_Segmentize", - logic.Func( - "ST_MakeEnvelope", - left, - bottom, - right, - top, - 4326, - ), + Bbox((left, bottom, right, top)), right - left, ), - pg_funcs.cast(geometry_column.srid, "int"), + geometry_column.srid, ), - logic.V(geometry_column.name), + self.T.get(geometry_column.name), ) ) - return clauses.Where(pg_funcs.AND(*wheres)) + return wheres def _datetime_filter_to_sql(self, interval: List[str], dt_name: str): + datecol = self.T.get(dt_name) if len(interval) == 1: - return logic.V(dt_name) == logic.S( - pg_funcs.cast(parse_rfc3339(interval[0]), "timestamptz") - ) - + return datecol == date_param(interval[0]) else: - start = ( - parse_rfc3339(interval[0]) if interval[0] not in ["..", ""] else None - ) - end = parse_rfc3339(interval[1]) if interval[1] not in ["..", ""] else None + start = interval[0] if interval[0] not in ["..", ""] else None + end = interval[1] if interval[1] not in ["..", ""] else None if start is None and end is None: raise InvalidDatetime( @@ -595,47 +539,49 @@ def _datetime_filter_to_sql(self, interval: List[str], dt_name: str): raise InvalidDatetime("Start datetime cannot be before end datetime.") if not start: - return logic.V(dt_name) <= logic.S(pg_funcs.cast(end, "timestamptz")) + return datecol <= date_param(end) elif not end: - return logic.V(dt_name) >= 
logic.S(pg_funcs.cast(start, "timestamptz")) + return datecol >= date_param(start) else: - return pg_funcs.AND( - logic.V(dt_name) >= logic.S(pg_funcs.cast(start, "timestamptz")), - logic.V(dt_name) < logic.S(pg_funcs.cast(end, "timestamptz")), - ) + return And(datecol >= date_param(start), datecol < date_param(end)) def _sortby(self, sortby: Optional[str]): + print('SORTBY', sortby) sorts = [] if sortby: for s in sortby.strip().split(","): - parts = re.match( - "^(?P<direction>[+-]?)(?P<column>.*)$", s - ).groupdict() # type:ignore + parts = re.match("^(?P<direction>[+-]?)(?P<column>.*)$", s).groupdict() # type:ignore direction = parts["direction"] column = parts["column"].strip() if self.get_column(column): + colexpr = self.T.get(column) if direction == "-": - sorts.append(logic.V(column).desc()) + sorts.append(colexpr.Desc()) else: - sorts.append(logic.V(column)) + sorts.append(colexpr.Asc()) else: raise InvalidPropertyName(f"Property {column} does not exist.") else: if self.id_column is not None: - sorts.append(logic.V(self.id_column)) + print('sorting by id column') + idcol = self.T.get(self.id_column) + print(idcol) + print(idcol.Asc()) + sorts.append(idcol.Asc()) else: - sorts.append(logic.V(self.properties[0].name)) - - return clauses.OrderBy(*sorts) + print('sorting by first column') + sorts.append(self.T.get(self.properties[0].name).Asc()) + print('SORTS', sorts) + return sorts async def _features_query( self, *, - pool: asyncpg.BuildPgPool, + pool: asyncpg.Pool, ids_filter: Optional[List[str]] = None, bbox_filter: Optional[List[float]] = None, datetime_filter: Optional[List[str]] = None, @@ -655,17 +601,16 @@ async def _features_query( """Build Features query.""" limit = limit or features_settings.default_features_limit offset = offset or 0 - - c = clauses.Clauses( + query = ( self._select( properties=properties, geometry_column=self.get_geometry_column(geom), bbox_only=bbox_only, simplify=simplify, geom_as_wkt=geom_as_wkt, - ), - self._from(function_parameters), - self._where( + ) + 
.From(self._from(function_parameters)) + .Where(*self._where( ids=ids_filter, datetime=datetime_filter, bbox=bbox_filter, @@ -673,13 +618,14 @@ async def _features_query( cql=cql_filter, geom=geom, dt=dt, - ), - self._sortby(sortby), - clauses.Limit(limit), - clauses.Offset(offset), + )) + .OrderBy(*self._sortby(sortby)) + .Limit(limit) + .Offset(offset) ) - q, p = render(":c", c=c) + q, p = build(query) + async with pool.acquire() as conn: for r in await conn.fetch(q, *p): props = dict(r) @@ -691,7 +637,7 @@ async def _features_query( async def _features_count_query( self, *, - pool: asyncpg.BuildPgPool, + pool: asyncpg.Pool, ids_filter: Optional[List[str]] = None, bbox_filter: Optional[List[float]] = None, datetime_filter: Optional[List[str]] = None, @@ -702,28 +648,27 @@ async def _features_count_query( function_parameters: Optional[Dict[str, str]], ) -> int: """Build features COUNT query.""" - c = clauses.Clauses( - self._select_count(), - self._from(function_parameters), - self._where( - ids=ids_filter, - datetime=datetime_filter, - bbox=bbox_filter, - properties=properties_filter, - cql=cql_filter, - geom=geom, - dt=dt, - ), - ) + query = self._select_count( + ).From( + self._from(function_parameters) + ).Where(*self._where( + ids=ids_filter, + datetime=datetime_filter, + bbox=bbox_filter, + properties=properties_filter, + cql=cql_filter, + geom=geom, + dt=dt, + )) - q, p = render(":c", c=c) + q, p = build(query) async with pool.acquire() as conn: count = await conn.fetchval(q, *p) return count async def features( self, - pool: asyncpg.BuildPgPool, + pool: asyncpg.Pool, *, ids_filter: Optional[List[str]] = None, bbox_filter: Optional[List[float]] = None, @@ -798,7 +743,7 @@ async def features( async def get_tile( self, *, - pool: asyncpg.BuildPgPool, + pool: asyncpg.Pool, tms: TileMatrixSet, tile: Tile, ids_filter: Optional[List[str]] = None, @@ -824,38 +769,46 @@ async def get_tile( raise InvalidLimit( f"Limit can not be set higher than the 
`tipg_max_features_per_tile` setting of {mvt_settings.max_features_per_tile}" ) + mvtlayername = ( + self.table if mvt_settings.set_mvt_layername is True else "default" + ) - c = clauses.Clauses( + baseq = ( self._select_mvt( properties=properties, geometry_column=geometry_column, tms=tms, tile=tile, - ), - self._from(function_parameters), - self._where( - ids=ids_filter, - datetime=datetime_filter, - bbox=bbox_filter, - properties=properties_filter, - cql=cql_filter, - geom=geom, - dt=dt, - tms=tms, - tile=tile, - ), - clauses.Limit(limit), + ) + .From(self._from(function_parameters)) + .Where( + *self._where( + ids=ids_filter, + datetime=datetime_filter, + bbox=bbox_filter, + properties=properties_filter, + cql=cql_filter, + geom=geom, + dt=dt, + tms=tms, + tile=tile, + ) + ) + .Limit(limit).Subquery('baseq') ) + print('BASEQ', build(baseq, 'raw')) + query = With(baseq).Select( + F( + "ST_AsMVT", + pgmini.raw.Raw('baseq.*'), + mvtlayername + ) + ).From(baseq) - q, p = render( - """ - WITH - t AS (:c) - SELECT ST_AsMVT(t.*, :l) FROM t - """, - c=c, - l=self.table if mvt_settings.set_mvt_layername is True else "default", - ) + q, p = build(query) + print('WITHQUERY', query) + print('-----') + print(build(query, 'raw')) debug_query(q, *p) async with pool.acquire() as conn: @@ -898,7 +851,7 @@ class Catalog(TypedDict): async def get_collection_index( # noqa: C901 - db_pool: asyncpg.BuildPgPool, + db_pool: asyncpg.Pool, schemas: Optional[List[str]] = None, tables: Optional[List[str]] = None, exclude_tables: Optional[List[str]] = None, @@ -913,35 +866,24 @@ async def get_collection_index( # noqa: C901 """Fetch Table and Functions index.""" schemas = schemas or ["public"] - query = """ - SELECT pg_temp.tipg_catalog( - :schemas, - :tables, - :exclude_tables, - :exclude_table_schemas, - :functions, - :exclude_functions, - :exclude_function_schemas, - :spatial, - :spatial_extent, - :datetime_extent - ); - """ # noqa: W605 + queryf = F( + "pg_temp.tipg_catalog", + schemas, 
+ tables, + exclude_tables, + exclude_table_schemas, + functions, + exclude_functions, + exclude_function_schemas, + spatial, + spatial_extent, + datetime_extent, + ) + query = Select(queryf.STAR).From(queryf) + q, p = build(query) async with db_pool.acquire() as conn: - rows = await conn.fetch_b( - query, - schemas=schemas, - tables=tables, - exclude_tables=exclude_tables, - exclude_table_schemas=exclude_table_schemas, - functions=functions, - exclude_functions=exclude_functions, - exclude_function_schemas=exclude_function_schemas, - spatial=spatial, - spatial_extent=spatial_extent, - datetime_extent=datetime_extent, - ) + rows = await conn.fetch(q, *p) catalog: Dict[str, Collection] = {} table_settings = TableSettings() diff --git a/tipg/database.py b/tipg/database.py index e7cac221..ca59e06d 100644 --- a/tipg/database.py +++ b/tipg/database.py @@ -4,7 +4,7 @@ from typing import List, Optional import orjson -from buildpg import asyncpg +import asyncpg from tipg.logger import logger from tipg.settings import PostgresSettings @@ -80,7 +80,7 @@ async def connect_to_db( con_init = connection_factory(schemas, user_sql_files) - app.state.pool = await asyncpg.create_pool_b( + app.state.pool = await asyncpg.create_pool( str(settings.database_url), min_size=settings.db_min_conn_size, max_size=settings.db_max_conn_size, diff --git a/tipg/dependencies.py b/tipg/dependencies.py index c8d7b676..323ad11f 100644 --- a/tipg/dependencies.py +++ b/tipg/dependencies.py @@ -6,9 +6,8 @@ from ciso8601 import parse_rfc3339 from morecantile import Tile from morecantile import tms as default_tms -from pygeofilter.ast import AstType -from pygeofilter.parsers.cql2_json import parse as cql2_json_parser -from pygeofilter.parsers.cql2_text import parse as cql2_text_parser +from pycql2.cql2_transformer import parser, transformer +from pycql2.cql2_pydantic import BooleanExpression from typing_extensions import Annotated from tipg.collections import Catalog, Collection, CollectionList @@ -189,7 
+188,7 @@ def bbox_query( bbox: Annotated[ Optional[str], Query(description="Spatial Filter."), - ] = None + ] = None, ) -> Optional[List[float]]: """BBox dependency.""" if bbox: @@ -290,14 +289,19 @@ def filter_query( alias="filter-lang", ), ] = None, -) -> Optional[AstType]: +) -> Optional[BooleanExpression]: """Parse Filter Query.""" + print('PARSING CQL2', type(query), filter_lang, query) if query is not None: if filter_lang == "cql2-json": - return cql2_json_parser(query) + print('PARSING AS JSON') + model = BooleanExpression.model_validate_json(query) + print('MODEL', model) + return model # default to cql2-text - return cql2_text_parser(query) + print('PARSING AS TEXT') + return transformer.transform(parser.parse(query)) return None diff --git a/tipg/filter/cql2sql.py b/tipg/filter/cql2sql.py new file mode 100644 index 00000000..9688ef6f --- /dev/null +++ b/tipg/filter/cql2sql.py @@ -0,0 +1,519 @@ +"""Tools to convert CQL2 into PostgreSQL SQL.""" +from datetime import date, datetime +from inspect import signature +import re +from typing import Any, Callable, Dict, List +from typing import Literal as TypeLiteral +from typing import Optional, Tuple, Union + +import pgmini +from pgmini.utils import CompileABC +from plum import dispatch, overload +from pycql2.cql2_pydantic import ( + Accenti, + AndOrExpression, + ArithmeticExpression, + Array, + ArrayExpression, + ArrayPredicate, + BboxLiteral, + BinaryComparisonPredicate, + BooleanExpression, + Casei, + CharacterExpression, + DateInstant, + Function, + FunctionRef, + GeometryLiteral, + IntervalInstance, + IsBetweenPredicate, + IsInListPredicate, + IsLikePredicate, + IsNullPredicate, + NotExpression, + NumericExpression, + PatternExpression, + PropertyRef, + SpatialPredicate, + TemporalPredicate, + TimestampInstant, +) +from pycql2.cql2_transformer import parser, transformer +from pydantic import BaseModel +from tipg import collections + +from tipg.query import NULL, F, Table, build, strip_ident, ensure_list, P, 
Param + +transform = transformer.transform +parse = parser.parse + + + + + +class Operator: + """Filter Operators.""" + + OPERATORS: Dict[str, Callable] = { + "==": lambda f, a: f == a, + "=": lambda f, a: f == a, + "eq": lambda f, a: f == a, + "!=": lambda f, a: f != a, + "<>": lambda f, a: f != a, + "ne": lambda f, a: f != a, + ">": lambda f, a: f > a, + "gt": lambda f, a: f > a, + "<": lambda f, a: f < a, + "lt": lambda f, a: f < a, + ">=": lambda f, a: f >= a, + "ge": lambda f, a: f >= a, + "<=": lambda f, a: f <= a, + "le": lambda f, a: f <= a, + "+": lambda f, a: f + a, + "-": lambda f, a: f - a, + "*": lambda f, a: f * a, + "/": lambda f, a: f / a, + "div": lambda f, a: f / a, + "^": lambda f, a: F("power", f, a), + "%": lambda f, a: F("mod", f, a), + } + + def __init__(self, operator: Optional[str] = None) -> None: + """Init.""" + if not operator: + operator = "==" + + if operator not in self.OPERATORS: + msg = f"Operator `{operator}` not valid." + raise Exception(msg) + self.operator = operator + self.function = self.OPERATORS[operator] + self.arity = len(signature(self.function).parameters) + + +class CQL2SQL: + """Class to convert CQL2 to SQL.""" + + def __init__(self, collection: Optional["Collection"] = None) -> None: + """Init Class.""" + if collection is not None: + self.collection = collection + cols={} + for ccol in collection.properties: + print('CCOL', ccol) + cols[ccol.name] = ccol.type + # cols = {f'{ccol.name}': ccol.type for ccol in collection.properties} + print(cols) + self._cols=cols + tablecls = type( + self.collection.table, + (Table,), + cols, + ) + self.table = tablecls(self.collection.table) + else: + self.collection = None + self.table = Table("mytable") + + def col(self, c: Union[str, PropertyRef]) -> pgmini.column.Column: + """Return a column.""" + if isinstance(c, PropertyRef): + return self.table.get(c.property) + return self.table.get(c) + + def get_args(self, e: BaseModel) -> List[CompileABC]: + """Sqlify all args, always 
return as list.""" + if hasattr(e, "args"): + return [self.sql(arg) for arg in ensure_list(e.args)] + return [] + + @overload + def sql(self, e: NotExpression) -> CompileABC: # type: ignore[no-redef] + """Get Not Expression.""" + args = self.get_args(e) + return pgmini.operators.Not(args[0]) + + def casei_accenti_arg(self, e): + """Checks if arg is case insensitive.""" + casei = False + accenti = False + + if isinstance(e, Casei): + arg = self.sql(e.casei) + casei = True + elif hasattr(e, "root") and isinstance(e.root, Casei): + arg = self.sql(e.root.casei) + casei = True + elif isinstance(e, Accenti): + arg = self.sql(e.accenti) + accenti = True + elif hasattr(e, "root") and isinstance(e.root, Accenti): + arg = self.sql(e.root.accenti) + accenti = True + else: + arg = self.sql(e) + return arg, casei, accenti + + def get_args_casei_accenti(self, e) -> Tuple[Any, bool, bool]: + """Check if arg is case insensitive.""" + args, casei, accenti = zip( + *[self.casei_accenti_arg(arg) for arg in ensure_list(e.args)], + ) + return args, any(casei), any(accenti) + + @overload + def sql(self, e: IsLikePredicate) -> CompileABC: # type: ignore[no-redef] + """Get Like Expression.""" + args, useilike, unaccent = self.get_args_casei_accenti(e) + if unaccent: + args = [F("unaccent", arg) for arg in args] + left, right = args + if useilike: + return left.Ilike(right) + return left.Like(right) + + @overload + def sql(self, e: IsBetweenPredicate) -> CompileABC: # type: ignore[no-redef] + """Get Between Expression.""" + left, low, high = self.get_args(e) + return left.Between(low, high) + + @overload + def sql(self, e: IsInListPredicate) -> CompileABC: # type: ignore[no-redef] + """Get In Expression.""" + left = self.sql(e.args[0]) + args = [(left == self.sql(arg)) for arg in e.args[1]] + # args = pgmini.array.Array([self.sql(arg) for arg in e.args[1]]) + # return left.Any(args) + return pgmini.Or(*args) + + @overload + def sql(self, e: IsNullPredicate) -> CompileABC: # type: 
ignore[no-redef] + """Get Null Expression.""" + print('getting null expression', e) + left = self.get_args(e)[0] + return left.Is(NULL()) + + @overload + def sql(self, e: BinaryComparisonPredicate) -> CompileABC: # type: ignore[no-redef] + """Get binary comparisons expression.""" + print('getting binary comparison predicate', e) + args, casei, unaccent = self.get_args_casei_accenti(e) + if casei: + args = [F("lower", arg) for arg in args] + if unaccent: + args = [F("unaccent", arg) for arg in args] + op = Operator(e.op) + return op.function(*args) + + @overload + def sql(self, e: ArithmeticExpression) -> CompileABC: # type: ignore[no-redef] + """Get operators expression.""" + args = self.get_args(e) + op = Operator(e.op) + return op.function(*args) + + @overload + def sql(self, e: AndOrExpression) -> CompileABC: # type: ignore[no-redef] + """Get and/or expression.""" + print('getting and/or', e) + args = self.get_args(e) + if e.op == "or": + return pgmini.Or(*args) + return pgmini.And(*args) + + @overload + def sql(self, e: BooleanExpression) -> CompileABC: # type: ignore[no-redef] + """Get boolean expression.""" + print('getting boolean', e) + if isinstance(e.root, bool): + if e.root: + return P(True) + return P(False) + return self.sql(e.root) + + @overload + def sql(self, e: PropertyRef) -> CompileABC: # type: ignore[no-redef] + """Get property expression.""" + return self.col(e.property) + + @overload + def sql( # type: ignore[no-redef] + self, + e: NumericExpression, + ) -> CompileABC: + """Get buildsql for character expression.""" + print('BUILDING Numeric EXPRESSION') + print(self, e, type(e)) + if isinstance(e, float): + if e.is_integer(): + print('Is Integer') + e = int(e) + if hasattr(e, "root"): + print(e.root, type(e.root)) + if isinstance(e.root, PropertyRef): + return self.sql(e.root) + return P(e.root) + return P(e) + @overload + def sql( # type: ignore[no-redef] + self, + e: Union[ + CharacterExpression, + PatternExpression + ], + ) -> CompileABC: + 
"""Get buildsql for character expression.""" + print('BUILDING EXPRESSION') + print(e, type(e)) + if hasattr(e, "root"): + print(e.root, type(e.root)) + if isinstance(e.root, PropertyRef): + return self.sql(e.root) + return P(e.root) + return P(e) + + @overload + def sql(self, e: BboxLiteral) -> CompileABC: # type: ignore[no-redef] + """Get BBox expression.""" + box = e.bbox + if len(box) == 4: + return F("ST_MAKEENVELOPE", *[P(b) for b in box], P(4326)) + if len(box) == 6: + return F( + "ST_MAKEENVELOPE", + P(box[0]), + P(box[1]), + P(box[3]), + P(box[4]), + P(4326), + ) + return None + + @overload + def sql(self, e: GeometryLiteral) -> CompileABC: # type: ignore[no-redef] + """Get wkt expression for geometry.""" + wkt = e.root.wkt + if not wkt.startswith("SRID"): + wkt = "SRID=4326;" + wkt + return P(wkt).Cast("geometry") + + def get_collection_geom_info( + self, + col = None, + ) -> Union[Tuple[None, None], Tuple[int, str]]: + """Get geometry/geography and srid from collectin.""" + print('get_collection_geom', col, type(col)) + if self.collection and col and hasattr(col, '_name'): + name = col._name + ccol = self.collection.get_column(strip_ident(name)) + if ccol: + return ccol.srid, ccol.type + elif isinstance(col, Param): + matches = re.match(r'SRID=(\d+);.*', col._value) + print('MATCHES', matches) + if matches: + return int(matches.group(1)), 'geometry' + return None, None + + @overload + def sql(self, e: SpatialPredicate) -> CompileABC: # type: ignore[no-redef] + """Get buildsql for spatial predicate.""" + print('getting spatial args', e.model_dump_json()) + op = e.op.upper().replace("S_", "ST_") + args = self.get_args(e) + left = args[0] + right = args[1] + print(left, right) + + lsrid, ltyp = self.get_collection_geom_info(left) + rsrid, rtyp = self.get_collection_geom_info(right) + print('Types/SRIDS', lsrid, ltyp, rsrid, rtyp) + + if ( + (lsrid == rsrid and ltyp == rtyp) + or (lsrid == 4326 and ltyp == "geometry" and rtyp is None) + or (rsrid == 4326 and 
rtyp == "geometry" and ltyp is None) + ): + return F(op, left, right) + + if ltyp == "geography" and rtyp is None: + right = right.Cast("geography") + elif lsrid != 4326 and rsrid is None: + right = F("ST_TRANSFORM", right, lsrid) + elif rtyp == "geography" and ltyp is None: + left = left.Cast("geography") + elif rsrid != 4326 and lsrid is None: + left = F("ST_TRANSFORM", left, rsrid) + elif ltyp == "geography" and rtyp == "geometry": + right = right.Cast("geography") + elif ltyp == "geometry" and rtyp == "geography": + left = left.Cast("geography") + elif (lsrid != rsrid) or (lsrid is None and rsrid is None): + right = F("ST_TRANSFORM", right, F("ST_SRID", left)) + + return F(op, left, right) + + temporal_opposites = { + "T_AFTER": "T_BEFORE", + "T_METBY": "T_MEETS", + "T_OVERLAPPEDBY": "T_OVERLAPS", + "T_STARTEDBY": "T_STARTS", + "T_CONTAINS": "T_DURING", + "T_FINISHEDBY": "T_FINISHES", + } + + temporal_ops = { + "T_BEFORE": lambda ll, lh, rl, rh: lh < rl, + "T_MEETS": lambda ll, lh, rl, rh: lh == rl, + "T_OVERLAPS": lambda ll, lh, rl, rh: pgmini.And(ll < rl, lh > rh, lh < rh), + "T_STARTS": lambda ll, lh, rl, rh: pgmini.And(ll == rl, lh < rh), + "T_DURING": lambda ll, lh, rl, rh: pgmini.And(ll > rl, lh < rh), + "T_FINISHES": lambda ll, lh, rl, rh: pgmini.And(ll > rl, lh == rh), + "T_EQUALS": lambda ll, lh, rl, rh: pgmini.And(ll == rl, lh == rh), + "T_DISJOINT": lambda ll, lh, rl, rh: pgmini.Or(ll > rh, lh < rl), + "T_INTERSECTS": lambda ll, lh, rl, rh: pgmini.And(ll <= rh, lh >= rl), + } + + @overload + def sql(self, e: DateInstant) -> CompileABC: # type: ignore[no-redef] + """Get date instant date.""" + return self.sql(e.date) + + @overload + def sql(self, e: TimestampInstant) -> CompileABC: # type: ignore[no-redef] + """Get TimestampInstant datetime.""" + return self.sql(e.timestamp) + + @overload + def sql(self, e: date) -> CompileABC: # type: ignore[no-redef] + """Get date expression.""" + return P(e.isoformat()).Cast("date") + + @overload + def sql(self, e: 
datetime) -> CompileABC: # type: ignore[no-redef] + """Get datetime expression.""" + return P(e.isoformat()).Cast("timestamptz") + + @overload + def sql(self, e: IntervalInstance) -> Tuple[CompileABC, CompileABC]: # type: ignore[no-redef] + """Get Interval Expression.""" + if e.interval[0].root == "..": + lower = P("-infinity").Cast("timestamptz") + else: + lower = self.sql(e.interval[0].root) + + if e.interval[1].root == "..": + upper = P("infinity").Cast("timestamptz") + else: + upper = self.sql(e.interval[1].root) + return lower, upper + + def get_temporal_args(self, arg) -> Tuple[CompileABC, CompileABC]: + """Get Temporal Arguments.""" + if isinstance(arg, PropertyRef): + colout = self.col(arg) + if self.collection: + col = self.collection.get_column(arg._name.strip('"')) + if col.type.endswith("range"): + return F("lower", self.col(colout)), F("upper", self.col(colout)) + else: + return colout, colout + if isinstance(arg, IntervalInstance): + return self.sql(arg) + if isinstance( + arg, + ( + DateInstant, + TimestampInstant, + ), + ): + val = self.sql(arg) + return val, val + return None + + @overload + def sql(self, e: TemporalPredicate) -> CompileABC: # type: ignore[no-redef] + """Get buildsql for temporal predicate.""" + op = e.op.upper() + left = self.get_temporal_args(e.args[0]) + right = self.get_temporal_args(e.args[1]) + if op in self.temporal_opposites: + op = self.temporal_opposites[op] + tmp = right + right = left + left = tmp + if op == "ANYINTERACTS": + op = "T_INTERSECTS" + ll, lh = left + rl, rh = right + return self.temporal_ops[op](ll, lh, rl, rh) + + @overload + def sql(self, e: Array) -> CompileABC: # type: ignore[no-redef] + """Get Array Expression.""" + print('Getting Array Expression', e.root) + vals = [self.sql(val) for val in e.root] + return pgmini.array.Array(vals) + + @overload + def sql(self, e: ArrayExpression) -> Tuple[CompileABC, CompileABC]: # type: ignore[no-redef] + """Get root of array expression.""" + print('getting root of 
array expression', e) + tuple = e.root + return self.sql(tuple[0]), self.sql(tuple[1]) + + @overload + def sql(self, e: ArrayPredicate) -> CompileABC: # type: ignore[no-redef] + """Get Array predicate expression.""" + op = e.op.upper() + left, right = self.sql(e.args) + if op == "A_CONTAINEDBY": + return left.Op("<@", right) + if op == "A_CONTAINS": + return left.Op("@>", right) + if op == "A_EQUALS": + return left == right + if op == "A_OVERLAPS": + return left.Op("&&", right) + return None + + @overload + def sql(self, e: FunctionRef) -> CompileABC: # type: ignore[no-redef] + """Get reference to function.""" + return self.sql(e.function) + + @overload + def sql(self, e: Function) -> CompileABC: # type: ignore[no-redef] + """Get Function Expression.""" + op = e.name if hasattr(e, "name") else e.op + args = self.get_args(e) + return F(op, *args) + + @dispatch + def sql(self, e): # type: ignore[no-redef] + """Fall through.""" + pass + + +def cql2pgmini( + query: str, + collection: Optional["Collection"] = None +) -> Union[Tuple[str, List[Any]], str]: + """Convert cql2 text into Pgmini expression.""" + print('PARSING:', query) + cql = transform(parse(query)) + T = CQL2SQL(collection) + return T.sql(cql) + + +def cql2sql( + query: str, + collection: Optional["Collection"] = None, + driver: TypeLiteral["asyncpg", "psycopg", "raw"] = "asyncpg", + table_in_column: bool = False, +) -> Union[Tuple[str, List[Any]], str]: + """Convert cql2 text into Postgres SQL.""" + out = cql2pgmini(query, collection) + + return build(out, driver=driver, table_in_column=table_in_column) diff --git a/tipg/query.py b/tipg/query.py new file mode 100644 index 00000000..a0382226 --- /dev/null +++ b/tipg/query.py @@ -0,0 +1,282 @@ +"""Helpers for using pgmini.""" +import re +from contextvars import copy_context +from typing import List +from typing import Literal as TypeLiteral +from typing import Optional, Tuple, Union +import attrs +from pgmini.marks import Marks + +import pgmini +from 
pgmini.alias import extract_alias +from pgmini.utils import ( + CTX_ALIAS_ONLY, + CTX_CTE, + CTX_DISABLE_TABLE_IN_COLUMN, + CTX_FORCE_CAST_BRACKETS, + CTX_TABLES, + CompileABC, +) +from datetime import date, datetime + +def is_integer(n): + """Check if a value is an integer.""" + try: + float(n) + except ValueError: + return False + else: + return float(n).is_integer() + + +def NULL(type: Optional[str] = None): + """Return typed NULL.""" + if type is None: + return pgmini.literal.NULL + return pgmini.literal.NULL.As(type) + + +def strip_ident(s: str) -> str: + """Strip quotes from identifier.""" + if s.startswith('"') and s.endswith('"'): + return s[1:-1] + return s + +def quote_ident_part(s: str) -> str: + """Quote Identifiers.""" + s = strip_ident(s) + s = s.strip() + if s in ('AS','ASC','DESC'): + return s + if re.match(r"^[a-z][a-z_]*$", s): + return s + if re.match(r"^[a-zA-Z][\w\d_]*$", s): + return f'"{s}"' + raise TypeError(f"{s} is not a valid identifier") + + +def quote_ident(s: str) -> str: + """Quote qualified identifiers.""" + outspacearr = [] + splitbyspace = s.split(" ") + for spacesplit in splitbyspace: + splitbycast = spacesplit.split("::") + outsplitarr=[] + for castsplit in splitbycast: + splitbydot = castsplit.split(".") + outsplitarr.append(".".join(map(quote_ident_part, splitbydot))) + outbysplit = "::".join(outsplitarr) + outspacearr.append(outbysplit) + out = " ".join(outspacearr) + return out + + +def F(name: str, *args): + """Run Postgres Function.""" + if re.match(r"^[a-zA-Z_]+(\.[a-zA-Z_]+)?$", name): + return pgmini.func._Func(x_name=name, x_params=args) + else: + raise TypeError( + f"Cannot Create {name}" "Only functions that match ^[a-zA-Z_]+ allowed", + ) + + +def Transform(g, srid: Union[int, str] = 4326): + """Transform geometry.""" + if is_integer(srid): + return F("ST_Transform", g, P(srid).Cast("int")) + else: + return F("ST_Transform", g, P(srid).Cast("text")) + + +def Bbox(box, srid: int = 4326): + """Return Bounding Box.""" + 
print('BBOX', box, type(box)) + box = list(box) + #if isinstance(box, (list, tuple)): + if len(box) == 4: + left, bottom, right, top = box + elif len(box) == 6: + left = box[0] + bottom = box[1] + right = box[3] + top = box[4] + # else: + # left = box.left + # bottom = box.bottom + # right = box.right + # top = box.top + + out = F("ST_MakeEnvelope", left, bottom, right, top, srid) + print(out) + return out + +def Count(): + """Return Count.""" + return F('count','*') + +def date_param(val): + """Make a parameter from date/time""" + if isinstance(val, (date, datetime)): + val = val.isoformat() + return P(val).Cast("timestamptz") + +def simplified(geom, tolerance): + return F( + "ST_SnapToGrid", + F("ST_Simplify", geom, tolerance), + tolerance, + ) + +def row_num(alias: str ='row'): + """Return Row Number.""" + return F("row_number").Over().As(alias) + +class Table(pgmini.Table): + """PgMini Table with useful functions.""" + + def get(self, attr: str) -> pgmini.column.Column: + """Get attribute via string.""" + return C(attr, self) + + def create_tipg_id(self, id_column: str): + """Create ID column using existing primary key or row number.""" + if id_column: + id_column_col = Column(id_column, self) + return id_column_col.As("tipg_id") + return row_num("tipg_id") + + def cols(self, colnames: List[str]): + """Return pgmini columns from list of names.""" + return [self.get(c) for c in colnames] + + + + +class Column(pgmini.column.Column): + """PGMini Column extended to ensure identifier quoting.""" + + def _build(self, params: list | dict) -> str: + out = super()._build(params) + return quote_ident(out) + + def Desc(self): + if self._marks: + marks = attrs.evolve(self._marks, order_by='DESC') + else: + marks = Marks(order_by='DESC') + return attrs.evolve(self, x_marks=marks) + + def Asc(self): + if self._marks: + marks = attrs.evolve(self._marks, order_by='ASC') + else: + marks = Marks(order_by='ASC') + return attrs.evolve(self, x_marks=marks) + + def NullsFirst(self): 
+ if self._marks: + marks = attrs.evolve(self._marks, order_by_nulls='FIRST') + else: + marks = Marks(order_by_nulls='FIRST') + return attrs.evolve(self, x_marks=marks) + + def NullsLast(self): + if self._marks: + marks = attrs.evolve(self._marks, order_by_nulls='LAST') + else: + marks = Marks(order_by_nulls='LAST') + return attrs.evolve(self, x_marks=marks) + +class Param(pgmini.param.Param): + """Make sure that params with a text value have + an initial cast to text.""" + def _build(self, params: list | dict) -> str: + if alias := extract_alias(self): + return alias + + index = len(params) + 1 + if isinstance(params, list): + params.append(self._value) + res = '$%d' % index + else: + params[f'p{index}'] = self._value + res = f'%(p{index})s' + + if isinstance(self._value, str): + print('Param value is a string', self._value) + res = f'{res}::text' + + if self._marks: + res = self._marks.build(res) + print(f"Built Parm {self}, {self._value}, {type(self._value)}, {res}") + return res + +P = Param + + +def C(name: str, table: Optional[Table] = Table("t")): + """Return a pgmini column.""" + return Column(name, table) + + +def raw_query(q, *p): + """Utility to print raw statement to use for debugging.""" + qsub = re.sub(r"\$([0-9]+)", r"{\1}", q) + + def quote_str(s): + """Quote strings.""" + if s is None: + return "null" + elif isinstance(s, str): + return f"'{s}'" + else: + return s + + p = [quote_str(s) for s in p] + return qsub.format(None, *p) + + +def build( + item: CompileABC, + driver: TypeLiteral["asyncpg", "psycopg", "raw"] = "asyncpg", + table_in_column: bool = False, +) -> Union[Tuple[str, list], str]: + """Build a SQL Query from CQL2 pydantic model. + Return as raw SQL or as a tuple of sql and parameters + ready for asyncpg or psycopg parameter binding. 
+ """ + + def run(): + CTX_FORCE_CAST_BRACKETS.set(False) + CTX_CTE.set(()) + CTX_TABLES.set(()) + CTX_ALIAS_ONLY.set(False) + CTX_DISABLE_TABLE_IN_COLUMN.set(not table_in_column) + + if driver == "psycopg": + params = {} + else: + params = [] + + query = item._build(params) + print(f"QUERY: {query}") + print("PARAMS", params) + if driver == "raw": + return raw_query(query, *params) + else: + return query, params + + return copy_context().run(run) + +def ensure_list(s) -> list: + """Makes sure that variable is treated as list.""" + if s is None: + return [] + if isinstance(s, list): + return s + if isinstance(s, set): + return list(s) + if isinstance(s, tuple): + return list(s) + return [s]