diff --git a/fakeredis/commands_mixins/geo_mixin.py b/fakeredis/commands_mixins/geo_mixin.py index 70289a31..dc4a6eae 100644 --- a/fakeredis/commands_mixins/geo_mixin.py +++ b/fakeredis/commands_mixins/geo_mixin.py @@ -1,3 +1,7 @@ +import sys +from collections import namedtuple +from typing import List, Optional, Any + from fakeredis import _msgs as msgs from fakeredis._command_args_parsing import extract_args from fakeredis._commands import command, Key, Float @@ -7,7 +11,28 @@ from fakeredis.geo.haversine import distance +def translate_meters_to_unit(unit_arg: bytes) -> float: + unit_str = unit_arg.decode().lower() + if unit_str == 'km': + unit = 0.001 + elif unit_str == 'mi': + unit = 0.000621371 + elif unit_str == 'ft': + unit = 3.28084 + else: # meter + unit = 1 + return unit + + +GeoResult = namedtuple('GeoResult', 'name long lat hash distance') + + class GeoCommandsMixin: + # TODO + # GEORADIUS, GEORADIUS_RO, + # GEORADIUSBYMEMBER, GEORADIUSBYMEMBER_RO, + # GEOSEARCH, GEOSEARCHSTORE + @command(name='GEOADD', fixed=(Key(ZSet),), repeat=(bytes,)) def geoadd(self, key, *args): (xx, nx, ch), data = extract_args( @@ -55,13 +80,45 @@ def geodist(self, key, m1, m2, *args): geo_locs = [geohash.decode(x) for x in geohashes] res = distance((geo_locs[0][0], geo_locs[0][1]), (geo_locs[1][0], geo_locs[1][1])) - unit = 1 - if len(args) == 1: - unit_str = args[0].decode().lower() - if unit_str == 'km': - unit = 0.001 - elif unit_str == 'mi': - unit = 0.000621371 - elif unit_str == 'ft': - unit = 3.28084 + unit = translate_meters_to_unit(args[0]) if len(args) == 1 else 1 return res * unit + + def _parse_results( + self, items: List[GeoResult], + withcoord: bool, withdist: bool, withhash: bool, + count: Optional[int], desc: bool) -> List[Any]: + items = sorted(items, key=lambda x: x.distance, reverse=desc) + if count: + items = items[:count] + res = list() + for item in items: + new_item = [item.name, ] + if withdist: + new_item.append(self._encodefloat(item.distance, False)) + if withcoord: + new_item.append([self._encodefloat(item.long, False), + self._encodefloat(item.lat, False)]) + if len(new_item) == 1: + new_item = new_item[0] + res.append(new_item) + return res + + @command(name='GEORADIUS', fixed=(Key(ZSet), Float, Float, Float), repeat=(bytes,)) + def georadius(self, key, long, lat, radius, *args): + zset = key.value + results = list() + (withcoord, withdist, withhash, count, count_any, desc, store, storedist), left_args = extract_args( + args, ('withcoord', 'withdist', 'withhash', '+count', 'any', 'desc', '*store', '*storedist'), + error_on_unexpected=False, left_from_first_unexpected=False) + unit = translate_meters_to_unit(args[0]) if len(args) >= 1 else 1 + count = count or sys.maxsize + + for name, _hash in zset.items(): + p_lat, p_long, _, _ = geohash.decode(_hash) + dist = distance((p_lat, p_long), (lat, long)) * unit + if dist < radius: + results.append(GeoResult(name, p_long, p_lat, _hash, dist)) + if count_any and len(results) >= count: + break + + return self._parse_results(results, withcoord, withdist, withhash, count, desc) diff --git a/test/test_mixins/test_geo_commands.py b/test/test_mixins/test_geo_commands.py index 2751409d..7efa43bf 100644 --- a/test/test_mixins/test_geo_commands.py +++ b/test/test_mixins/test_geo_commands.py @@ -89,3 +89,68 @@ def test_geodist_missing_one_member(r: redis.Redis): values = (2.1909389952632, 41.433791470673, "place1") r.geoadd("barcelona", values) assert r.geodist("barcelona", "place1", "missing_member", "km") is None + + +def test_georadius(r: redis.Redis): + values = ((2.1909389952632, 41.433791470673, "place1") + + (2.1873744593677, 41.406342043777, b"\x80place2")) + + r.geoadd("barcelona", values) + assert r.georadius("barcelona", 2.191, 41.433, 1000) == [b"place1"] + assert r.georadius("barcelona", 2.187, 41.406, 1000) == [b"\x80place2"] + + +def test_georadius_no_values(r: redis.Redis): + values = ((2.1909389952632, 41.433791470673, "place1") + + (2.1873744593677, 41.406342043777, "place2",)) + + r.geoadd("barcelona", values) + assert r.georadius("barcelona", 1, 2, 1000) == [] + + +def test_georadius_units(r: redis.Redis): + values = ((2.1909389952632, 41.433791470673, "place1") + + (2.1873744593677, 41.406342043777, "place2",)) + + r.geoadd("barcelona", values) + assert r.georadius("barcelona", 2.191, 41.433, 1, unit="km") == [b"place1"] + + +def test_georadius_with(r: redis.Redis): + values = ((2.1909389952632, 41.433791470673, "place1") + + (2.1873744593677, 41.406342043777, "place2",)) + + r.geoadd("barcelona", values) + + # test a bunch of combinations to test the parse response + # function. + res = r.georadius("barcelona", 2.191, 41.433, 1, unit="km", withdist=True, withcoord=True, ) + assert res == [pytest.approx([ + b"place1", + 0.0881, + pytest.approx((2.19093829393386841, 41.43379028184083523), 0.0001) + ], 0.001)] + + res = r.georadius("barcelona", 2.191, 41.433, 1, unit="km", withdist=True, withcoord=True) + assert res == [pytest.approx([ + b"place1", + 0.0881, + pytest.approx((2.19093829393386841, 41.43379028184083523), 0.0001) + ], 0.001)] + + assert r.georadius( + "barcelona", 2.191, 41.433, 1, unit="km", withcoord=True + ) == [[b"place1", pytest.approx((2.19093829393386841, 41.43379028184083523), 0.0001)]] + + # test no values. + assert (r.georadius("barcelona", 2, 1, 1, unit="km", withdist=True, withcoord=True, ) == []) + + +def test_georadius_count(r: redis.Redis): + values = ((2.1909389952632, 41.433791470673, "place1") + + (2.1873744593677, 41.406342043777, "place2",)) + + r.geoadd("barcelona", values) + assert r.georadius("barcelona", 2.191, 41.433, 3000, count=1) == [b"place1"] + res = r.georadius("barcelona", 2.191, 41.433, 3000, count=1, any=True) + assert (res == [b"place2"]) or res == [b'place1']