From ed00059c21ae1a97f6a4351d692f7a745a3fddeb Mon Sep 17 00:00:00 2001 From: Daniel M Date: Fri, 24 Feb 2023 13:17:09 -0500 Subject: [PATCH] Implement GEO commands --- fakeredis/_msgs.py | 1 + fakeredis/commands_mixins/geo_mixin.py | 152 ++++++++++++++++++------- test/test_mixins/test_geo_commands.py | 104 ++++++++++------- 3 files changed, 175 insertions(+), 82 deletions(-) diff --git a/fakeredis/_msgs.py b/fakeredis/_msgs.py index ab08960b..a5f11fdc 100644 --- a/fakeredis/_msgs.py +++ b/fakeredis/_msgs.py @@ -63,3 +63,4 @@ FLAG_NO_SCRIPT = 's' # Command not allowed in scripts FLAG_LEAVE_EMPTY_VAL = 'v' FLAG_TRANSACTION = 't' +GEO_UNSUPPORTED_UNIT = 'unsupported unit provided. please use M, KM, FT, MI' diff --git a/fakeredis/commands_mixins/geo_mixin.py b/fakeredis/commands_mixins/geo_mixin.py index dc4a6eae..80551a9d 100644 --- a/fakeredis/commands_mixins/geo_mixin.py +++ b/fakeredis/commands_mixins/geo_mixin.py @@ -1,37 +1,94 @@ import sys from collections import namedtuple -from typing import List, Optional, Any +from typing import List, Any from fakeredis import _msgs as msgs from fakeredis._command_args_parsing import extract_args -from fakeredis._commands import command, Key, Float +from fakeredis._commands import command, Key, Float, CommandItem from fakeredis._helpers import SimpleError from fakeredis._zset import ZSet from fakeredis.geo import geohash from fakeredis.geo.haversine import distance +UNIT_TO_M = {'km': 0.001, 'mi': 0.000621371, 'ft': 3.28084, 'm': 1} + def translate_meters_to_unit(unit_arg: bytes) -> float: - unit_str = unit_arg.decode().lower() - if unit_str == 'km': - unit = 0.001 - elif unit_str == 'mi': - unit = 0.000621371 - elif unit_str == 'ft': - unit = 3.28084 - else: # meter - unit = 1 + """number of meters in a unit. + :param unit_arg: unit name (km, mi, ft, m) + :returns: number of meters in unit + """ + unit = UNIT_TO_M.get(unit_arg.decode().lower()) + if unit is None: + raise SimpleError(msgs.GEO_UNSUPPORTED_UNIT) return unit GeoResult = namedtuple('GeoResult', 'name long lat hash distance') +def _parse_results( + items: List[GeoResult], + withcoord: bool, withdist: bool) -> List[Any]: + """Parse list of GeoResults to redis response + :param withcoord: include coordinates in response + :param withdist: include distance in response + :returns: Parsed list + """ + res = list() + for item in items: + new_item = [item.name, ] + if withdist: + new_item.append(Float.encode(item.distance, False)) + if withcoord: + new_item.append([Float.encode(item.long, False), + Float.encode(item.lat, False)]) + if len(new_item) == 1: + new_item = new_item[0] + res.append(new_item) + return res + + +def _find_near( + zset: ZSet, + lat: float, long: float, radius: float, + conv: float, count: int, count_any: bool, desc: bool) -> List[GeoResult]: + """Find items within area (lat,long)+radius + :param zset: list of items to check + :param lat: latitude + :param long: longitude + :param radius: radius in whatever units + :param conv: conversion of radius to meters + :param count: number of results to give + :param count_any: should we return any results that match? (vs. sorted) + :param desc: should results be sorted descending order? + :returns: List of GeoResults + """ + results = list() + for name, _hash in zset.items(): + p_lat, p_long, _, _ = geohash.decode(_hash) + dist = distance((p_lat, p_long), (lat, long)) * conv + if dist < radius: + results.append(GeoResult(name, p_long, p_lat, _hash, dist)) + if count_any and len(results) >= count: + break + results = sorted(results, key=lambda x: x.distance, reverse=desc) + if count: + results = results[:count] + return results + + class GeoCommandsMixin: # TODO - # GEORADIUS, GEORADIUS_RO, - # GEORADIUSBYMEMBER, GEORADIUSBYMEMBER_RO, # GEOSEARCH, GEOSEARCHSTORE + def _store_geo_results(self, item_name: bytes, geo_results: List[GeoResult], scoredist: bool) -> int: + db_item = CommandItem(item_name, self._db, item=self._db.get(item_name), default=ZSet()) + db_item.value = ZSet() + for item in geo_results: + val = item.distance if scoredist else item.hash + db_item.value.add(item.name, val) + db_item.writeback() + return len(geo_results) @command(name='GEOADD', fixed=(Key(ZSet),), repeat=(bytes,)) def geoadd(self, key, *args): @@ -83,42 +140,51 @@ def geodist(self, key, m1, m2, *args): unit = translate_meters_to_unit(args[0]) if len(args) == 1 else 1 return res * unit - def _parse_results( - self, items: List[GeoResult], - withcoord: bool, withdist: bool, withhash: bool, - count: Optional[int], desc: bool) -> List[Any]: - items = sorted(items, key=lambda x: x.distance, reverse=desc) - if count: - items = items[:count] - res = list() - for item in items: - new_item = [item.name, ] - if withdist: - new_item.append(self._encodefloat(item.distance, False)) - if withcoord: - new_item.append([self._encodefloat(item.long, False), - self._encodefloat(item.lat, False)]) - if len(new_item) == 1: - new_item = new_item[0] - res.append(new_item) - return res + def _search( + self, key, long, lat, radius, conv, + withcoord, withdist, withhash, count, count_any, desc, store, storedist): + zset = key.value + geo_results = _find_near(zset, lat, long, radius, conv, count, count_any, desc) + + if store: + self._store_geo_results(store, geo_results, scoredist=False) + return len(geo_results) + if storedist: + self._store_geo_results(storedist, geo_results, scoredist=True) + return len(geo_results) + ret = _parse_results(geo_results, withcoord, withdist) + return ret + + @command(name='GEORADIUS_RO', fixed=(Key(ZSet), Float, Float, Float), repeat=(bytes,)) + def georadius_ro(self, key, long, lat, radius, *args): + (withcoord, withdist, withhash, count, count_any, desc), left_args = extract_args( + args, ('withcoord', 'withdist', 'withhash', '+count', 'any', 'desc',), + error_on_unexpected=False, left_from_first_unexpected=False) + count = count or sys.maxsize + conv = translate_meters_to_unit(args[0]) if len(args) >= 1 else 1 + return self._search( + key, long, lat, radius, conv, + withcoord, withdist, withhash, count, count_any, desc, False, False) @command(name='GEORADIUS', fixed=(Key(ZSet), Float, Float, Float), repeat=(bytes,)) def georadius(self, key, long, lat, radius, *args): - zset = key.value - results = list() (withcoord, withdist, withhash, count, count_any, desc, store, storedist), left_args = extract_args( args, ('withcoord', 'withdist', 'withhash', '+count', 'any', 'desc', '*store', '*storedist'), error_on_unexpected=False, left_from_first_unexpected=False) - unit = translate_meters_to_unit(args[0]) if len(args) >= 1 else 1 count = count or sys.maxsize + conv = translate_meters_to_unit(args[0]) if len(args) >= 1 else 1 + return self._search( + key, long, lat, radius, conv, + withcoord, withdist, withhash, count, count_any, desc, store, storedist) - for name, _hash in zset.items(): - p_lat, p_long, _, _ = geohash.decode(_hash) - dist = distance((p_lat, p_long), (lat, long)) * unit - if dist < radius: - results.append(GeoResult(name, p_long, p_lat, _hash, dist)) - if count_any and len(results) >= count: - break + @command(name='GEORADIUSBYMEMBER', fixed=(Key(ZSet), bytes, Float), repeat=(bytes,)) + def georadiusbymember(self, key, member_name, radius, *args): + member_score = key.value.get(member_name) + lat, long, _, _ = geohash.decode(member_score) + return self.georadius(key, long, lat, radius, *args) - return self._parse_results(results, withcoord, withdist, withhash, count, desc) + @command(name='GEORADIUSBYMEMBER_RO', fixed=(Key(ZSet), bytes, Float), repeat=(bytes,)) + def georadiusbymember_ro(self, key, member_name, radius, *args): + member_score = key.value.get(member_name) + lat, long, _, _ = geohash.decode(member_score) + return self.georadius_ro(key, long, lat, radius, *args) diff --git a/test/test_mixins/test_geo_commands.py b/test/test_mixins/test_geo_commands.py index 7efa43bf..d23068ed 100644 --- a/test/test_mixins/test_geo_commands.py +++ b/test/test_mixins/test_geo_commands.py @@ -1,6 +1,10 @@ +from typing import Dict, Any + import pytest import redis +from test import testtools + def test_geoadd(r: redis.Redis): values = ((2.1909389952632, 41.433791470673, "place1") + @@ -10,13 +14,20 @@ def test_geoadd(r: redis.Redis): values = (2.1909389952632, 41.433791470673, "place1") assert r.geoadd("a", values) == 1 + values = ((2.1909389952632, 31.433791470673, "place1") + (2.1873744593677, 41.406342043777, "place2",)) assert r.geoadd("a", values, ch=True) == 2 assert r.zrange("a", 0, -1) == [b"place1", b"place2"] - with pytest.raises(redis.RedisError): + with pytest.raises(redis.DataError): r.geoadd("barcelona", (1, 2)) + with pytest.raises(redis.DataError): + r.geoadd("t", values, ch=True, nx=True, xx=True) + with pytest.raises(redis.ResponseError): + testtools.raw_command(r, "geoadd", "barcelona", "1", "2") + with pytest.raises(redis.ResponseError): + testtools.raw_command(r, "geoadd", "barcelona", "nx", "xx", *values, ) def test_geoadd_xx(r: redis.Redis): @@ -91,29 +102,22 @@ def test_geodist_missing_one_member(r: redis.Redis): assert r.geodist("barcelona", "place1", "missing_member", "km") is None -def test_georadius(r: redis.Redis): - values = ((2.1909389952632, 41.433791470673, "place1") + - (2.1873744593677, 41.406342043777, b"\x80place2")) - - r.geoadd("barcelona", values) - assert r.georadius("barcelona", 2.191, 41.433, 1000) == [b"place1"] - assert r.georadius("barcelona", 2.187, 41.406, 1000) == [b"\x80place2"] - - -def test_georadius_no_values(r: redis.Redis): - values = ((2.1909389952632, 41.433791470673, "place1") + - (2.1873744593677, 41.406342043777, "place2",)) - - r.geoadd("barcelona", values) - assert r.georadius("barcelona", 1, 2, 1000) == [] - - -def test_georadius_units(r: redis.Redis): +@pytest.mark.parametrize( + "long,lat,radius,extra,expected", [ + (2.191, 41.433, 1000, {}, [b"place1"]), + (2.187, 41.406, 1000, {}, [b"place2"]), + (1, 2, 1000, {}, []), + (2.191, 41.433, 1, {"unit": "km"}, [b"place1"]), + (2.191, 41.433, 3000, {"count": 1}, [b"place1"]), + ]) +def test_georadius( + r: redis.Redis, long: float, lat: float, radius: float, + extra: Dict[str, Any], + expected): values = ((2.1909389952632, 41.433791470673, "place1") + - (2.1873744593677, 41.406342043777, "place2",)) - + (2.1873744593677, 41.406342043777, b"place2")) r.geoadd("barcelona", values) - assert r.georadius("barcelona", 2.191, 41.433, 1, unit="km") == [b"place1"] + assert r.georadius("barcelona", long, lat, radius, **extra) == expected def test_georadius_with(r: redis.Redis): @@ -121,26 +125,15 @@ def test_georadius_with(r: redis.Redis): (2.1873744593677, 41.406342043777, "place2",)) r.geoadd("barcelona", values) - - # test a bunch of combinations to test the parse response - # function. + # test a bunch of combinations to test the parse response function. res = r.georadius("barcelona", 2.191, 41.433, 1, unit="km", withdist=True, withcoord=True, ) - assert res == [pytest.approx([ - b"place1", - 0.0881, - pytest.approx((2.19093829393386841, 41.43379028184083523), 0.0001) - ], 0.001)] + assert res == [pytest.approx([b"place1", 0.0881, pytest.approx((2.1909, 41.4337), 0.0001)], 0.001)] res = r.georadius("barcelona", 2.191, 41.433, 1, unit="km", withdist=True, withcoord=True) - assert res == [pytest.approx([ - b"place1", - 0.0881, - pytest.approx((2.19093829393386841, 41.43379028184083523), 0.0001) - ], 0.001)] + assert res == [pytest.approx([b"place1", 0.0881, pytest.approx((2.1909, 41.4337), 0.0001)], 0.001)] - assert r.georadius( - "barcelona", 2.191, 41.433, 1, unit="km", withcoord=True - ) == [[b"place1", pytest.approx((2.19093829393386841, 41.43379028184083523), 0.0001)]] + res = r.georadius("barcelona", 2.191, 41.433, 1, unit="km", withcoord=True) + assert res == [[b"place1", pytest.approx((2.1909, 41.4337), 0.0001)]] # test no values. assert (r.georadius("barcelona", 2, 1, 1, unit="km", withdist=True, withcoord=True, ) == []) @@ -151,6 +144,39 @@ def test_georadius_count(r: redis.Redis): (2.1873744593677, 41.406342043777, "place2",)) r.geoadd("barcelona", values) - assert r.georadius("barcelona", 2.191, 41.433, 3000, count=1) == [b"place1"] + + assert r.georadius("barcelona", 2.191, 41.433, 3000, count=1, store='barcelona') == 1 + assert r.georadius("barcelona", 2.191, 41.433, 3000, store_dist='extract') == 1 + assert r.zcard("extract") == 1 res = r.georadius("barcelona", 2.191, 41.433, 3000, count=1, any=True) assert (res == [b"place2"]) or res == [b'place1'] + + values = ((13.361389, 38.115556, "Palermo") + + (15.087269, 37.502669, "Catania",)) + + r.geoadd("Sicily", values) + assert testtools.raw_command( + r, "GEORADIUS", "Sicily", "15", "37", "200", "km", + "STOREDIST", "neardist", "STORE", "near") == 2 + assert r.zcard("near") == 2 + assert r.zcard("neardist") == 0 + + +def test_georadius_errors(r: redis.Redis): + values = ((13.361389, 38.115556, "Palermo") + + (15.087269, 37.502669, "Catania",)) + + r.geoadd("Sicily", values) + + with pytest.raises(redis.DataError): # Unsupported unit + r.georadius("barcelona", 2.191, 41.433, 3000, unit='dsf') + with pytest.raises(redis.ResponseError): # Unsupported unit + testtools.raw_command( + r, "GEORADIUS", "Sicily", "15", "37", "200", "ddds", + "STOREDIST", "neardist", "STORE", "near") + + bad_values = (13.361389, 38.115556, "Palermo", 15.087269, "Catania",) + with pytest.raises(redis.DataError): + r.geoadd('newgroup', bad_values) + with pytest.raises(redis.ResponseError): + testtools.raw_command(r, 'geoadd', 'newgroup', *bad_values)