Skip to content

Commit

Permalink
Implement GEORADIUS
Browse files Browse the repository at this point in the history
  • Loading branch information
cunla committed Feb 24, 2023
1 parent cf18595 commit 6d729f6
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 9 deletions.
75 changes: 66 additions & 9 deletions fakeredis/commands_mixins/geo_mixin.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import sys
from collections import namedtuple
from typing import List, Optional, Any

from fakeredis import _msgs as msgs
from fakeredis._command_args_parsing import extract_args
from fakeredis._commands import command, Key, Float
Expand All @@ -7,7 +11,28 @@
from fakeredis.geo.haversine import distance


def translate_meters_to_unit(unit_arg: bytes) -> float:
unit_str = unit_arg.decode().lower()
if unit_str == 'km':
unit = 0.001
elif unit_str == 'mi':
unit = 0.000621371
elif unit_str == 'ft':
unit = 3.28084
else: # meter
unit = 1
return unit


GeoResult = namedtuple('GeoResult', 'name long lat hash distance')


class GeoCommandsMixin:
# TODO
# GEORADIUS, GEORADIUS_RO,
# GEORADIUSBYMEMBER, GEORADIUSBYMEMBER_RO,
# GEOSEARCH, GEOSEARCHSTORE

@command(name='GEOADD', fixed=(Key(ZSet),), repeat=(bytes,))
def geoadd(self, key, *args):
(xx, nx, ch), data = extract_args(
Expand Down Expand Up @@ -55,13 +80,45 @@ def geodist(self, key, m1, m2, *args):
geo_locs = [geohash.decode(x) for x in geohashes]
res = distance((geo_locs[0][0], geo_locs[0][1]),
(geo_locs[1][0], geo_locs[1][1]))
unit = 1
if len(args) == 1:
unit_str = args[0].decode().lower()
if unit_str == 'km':
unit = 0.001
elif unit_str == 'mi':
unit = 0.000621371
elif unit_str == 'ft':
unit = 3.28084
unit = translate_meters_to_unit(args[0]) if len(args) == 1 else 1
return res * unit

def _parse_results(
self, items: List[GeoResult],
withcoord: bool, withdist: bool, withhash: bool,
count: Optional[int], desc: bool) -> List[Any]:
items = sorted(items, key=lambda x: x.distance, reverse=desc)
if count:
items = items[:count]
res = list()
for item in items:
new_item = [item.name, ]
if withdist:
new_item.append(self._encodefloat(item.distance, False))
if withcoord:
new_item.append([self._encodefloat(item.long, False),
self._encodefloat(item.lat, False)])
if len(new_item) == 1:
new_item = new_item[0]
res.append(new_item)
return res

@command(name='GEORADIUS', fixed=(Key(ZSet), Float, Float, Float), repeat=(bytes,))
def georadius(self, key, long, lat, radius, *args):
zset = key.value
results = list()
(withcoord, withdist, withhash, count, count_any, desc, store, storedist), left_args = extract_args(
args, ('withcoord', 'withdist', 'withhash', '+count', 'any', 'desc', '*store', '*storedist'),
error_on_unexpected=False, left_from_first_unexpected=False)
unit = translate_meters_to_unit(args[0]) if len(args) >= 1 else 1
count = count or sys.maxsize

for name, _hash in zset.items():
p_lat, p_long, _, _ = geohash.decode(_hash)
dist = distance((p_lat, p_long), (lat, long)) * unit
if dist < radius:
results.append(GeoResult(name, p_long, p_lat, _hash, dist))
if count_any and len(results) >= count:
break

return self._parse_results(results, withcoord, withdist, withhash, count, desc)
65 changes: 65 additions & 0 deletions test/test_mixins/test_geo_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,68 @@ def test_geodist_missing_one_member(r: redis.Redis):
values = (2.1909389952632, 41.433791470673, "place1")
r.geoadd("barcelona", values)
assert r.geodist("barcelona", "place1", "missing_member", "km") is None


def test_georadius(r: redis.Redis):
values = ((2.1909389952632, 41.433791470673, "place1") +
(2.1873744593677, 41.406342043777, b"\x80place2"))

r.geoadd("barcelona", values)
assert r.georadius("barcelona", 2.191, 41.433, 1000) == [b"place1"]
assert r.georadius("barcelona", 2.187, 41.406, 1000) == [b"\x80place2"]


def test_georadius_no_values(r: redis.Redis):
values = ((2.1909389952632, 41.433791470673, "place1") +
(2.1873744593677, 41.406342043777, "place2",))

r.geoadd("barcelona", values)
assert r.georadius("barcelona", 1, 2, 1000) == []


def test_georadius_units(r: redis.Redis):
values = ((2.1909389952632, 41.433791470673, "place1") +
(2.1873744593677, 41.406342043777, "place2",))

r.geoadd("barcelona", values)
assert r.georadius("barcelona", 2.191, 41.433, 1, unit="km") == [b"place1"]


def test_georadius_with(r: redis.Redis):
values = ((2.1909389952632, 41.433791470673, "place1") +
(2.1873744593677, 41.406342043777, "place2",))

r.geoadd("barcelona", values)

# test a bunch of combinations to test the parse response
# function.
res = r.georadius("barcelona", 2.191, 41.433, 1, unit="km", withdist=True, withcoord=True, )
assert res == [pytest.approx([
b"place1",
0.0881,
pytest.approx((2.19093829393386841, 41.43379028184083523), 0.0001)
], 0.001)]

res = r.georadius("barcelona", 2.191, 41.433, 1, unit="km", withdist=True, withcoord=True)
assert res == [pytest.approx([
b"place1",
0.0881,
pytest.approx((2.19093829393386841, 41.43379028184083523), 0.0001)
], 0.001)]

assert r.georadius(
"barcelona", 2.191, 41.433, 1, unit="km", withcoord=True
) == [[b"place1", pytest.approx((2.19093829393386841, 41.43379028184083523), 0.0001)]]

# test no values.
assert (r.georadius("barcelona", 2, 1, 1, unit="km", withdist=True, withcoord=True, ) == [])


def test_georadius_count(r: redis.Redis):
values = ((2.1909389952632, 41.433791470673, "place1") +
(2.1873744593677, 41.406342043777, "place2",))

r.geoadd("barcelona", values)
assert r.georadius("barcelona", 2.191, 41.433, 3000, count=1) == [b"place1"]
res = r.georadius("barcelona", 2.191, 41.433, 3000, count=1, any=True)
assert (res == [b"place2"]) or res == [b'place1']

0 comments on commit 6d729f6

Please sign in to comment.