diff --git a/docs/redis-commands/Redis.md b/docs/redis-commands/Redis.md index 5ba6af97..c20c95f9 100644 --- a/docs/redis-commands/Redis.md +++ b/docs/redis-commands/Redis.md @@ -1180,26 +1180,28 @@ Set the debug mode for executed scripts. Kill the script currently in execution. +## geo commands -### Unsupported geo commands -> To implement support for a command, see [here](/guides/implement-command/) - -#### [GEOADD](https://redis.io/commands/geoadd/) +### [GEOADD](https://redis.io/commands/geoadd/) Add one or more geospatial items in the geospatial index represented using a sorted set -#### [GEODIST](https://redis.io/commands/geodist/) +### [GEODIST](https://redis.io/commands/geodist/) Returns the distance between two members of a geospatial index -#### [GEOHASH](https://redis.io/commands/geohash/) +### [GEOHASH](https://redis.io/commands/geohash/) Returns members of a geospatial index as standard geohash strings -#### [GEOPOS](https://redis.io/commands/geopos/) +### [GEOPOS](https://redis.io/commands/geopos/) Returns longitude and latitude of members of a geospatial index + +### Unsupported geo commands +> To implement support for a command, see [here](/guides/implement-command/) + #### [GEORADIUS](https://redis.io/commands/georadius/) Query a sorted set representing a geospatial index to fetch members matching a given maximum distance from a point diff --git a/fakeredis/_fakesocket.py b/fakeredis/_fakesocket.py index c0104a14..29b6108d 100644 --- a/fakeredis/_fakesocket.py +++ b/fakeredis/_fakesocket.py @@ -3,6 +3,7 @@ from .commands_mixins.bitmap_mixin import BitmapCommandsMixin from .commands_mixins.connection_mixin import ConnectionCommandsMixin from .commands_mixins.generic_mixin import GenericCommandsMixin +from .commands_mixins.geo_mixin import GeoCommandsMixin from .commands_mixins.hash_mixin import HashCommandsMixin from .commands_mixins.list_mixin import ListCommandsMixin from .commands_mixins.pubsub_mixin import PubSubCommandsMixin @@ -31,6 +32,7 @@ class FakeSocket( SortedSetCommandsMixin, StreamsCommandsMixin, JSONCommandsMixin, + GeoCommandsMixin, ): def __init__(self, server): diff --git a/fakeredis/commands_mixins/geo_mixin.py b/fakeredis/commands_mixins/geo_mixin.py new file mode 100644 index 00000000..70289a31 --- /dev/null +++ b/fakeredis/commands_mixins/geo_mixin.py @@ -0,0 +1,67 @@ +from fakeredis import _msgs as msgs +from fakeredis._command_args_parsing import extract_args +from fakeredis._commands import command, Key, Float +from fakeredis._helpers import SimpleError +from fakeredis._zset import ZSet +from fakeredis.geo import geohash +from fakeredis.geo.haversine import distance + + +class GeoCommandsMixin: + @command(name='GEOADD', fixed=(Key(ZSet),), repeat=(bytes,)) + def geoadd(self, key, *args): + (xx, nx, ch), data = extract_args( + args, ('nx', 'xx', 'ch'), + error_on_unexpected=False, left_from_first_unexpected=True) + if xx and nx: + raise SimpleError(msgs.NX_XX_GT_LT_ERROR_MSG) + if len(data) == 0 or len(data) % 3 != 0: + raise SimpleError(msgs.SYNTAX_ERROR_MSG) + zset = key.value + old_len, changed_items = len(zset), 0 + for i in range(0, len(data), 3): + long, lat, name = Float.decode(data[i + 0]), Float.decode(data[i + 1]), data[i + 2] + if (name in zset and not xx) or (name not in zset and not nx): + if zset.add(name, geohash.encode(lat, long, 10)): + changed_items += 1 + if changed_items: + key.updated() + if ch: + return changed_items + return len(zset) - old_len + + @command(name='GEOHASH', fixed=(Key(ZSet), bytes), repeat=(bytes,)) + def geohash(self, key, *members): + hashes = map(key.value.get, members) + geohash_list = [((x + '0').encode() if x is not None else x) for x in hashes] + return geohash_list + + @command(name='GEOPOS', fixed=(Key(ZSet), bytes), repeat=(bytes,)) + def geopos(self, key, *members): + gospositions = map( + lambda x: geohash.decode(x) if x is not None else x, + map(key.value.get, members)) + res = [([self._encodefloat(x[1], humanfriendly=False), + self._encodefloat(x[0], humanfriendly=False)] + if x is not None else None) + for x in gospositions] + return res + + @command(name='GEODIST', fixed=(Key(ZSet), bytes, bytes), repeat=(bytes,)) + def geodist(self, key, m1, m2, *args): + geohashes = [key.value.get(m1), key.value.get(m2)] + if any(elem is None for elem in geohashes): + return None + geo_locs = [geohash.decode(x) for x in geohashes] + res = distance((geo_locs[0][0], geo_locs[0][1]), + (geo_locs[1][0], geo_locs[1][1])) + unit = 1 + if len(args) == 1: + unit_str = args[0].decode().lower() + if unit_str == 'km': + unit = 0.001 + elif unit_str == 'mi': + unit = 0.000621371 + elif unit_str == 'ft': + unit = 3.28084 + return res * unit diff --git a/fakeredis/commands_mixins/sortedset_mixin.py b/fakeredis/commands_mixins/sortedset_mixin.py index 9b18b2da..63e6a35a 100644 --- a/fakeredis/commands_mixins/sortedset_mixin.py +++ b/fakeredis/commands_mixins/sortedset_mixin.py @@ -395,7 +395,7 @@ def zunionstore(self, dest, numkeys, *args): def zinterstore(self, dest, numkeys, *args): return self._zunioninter('ZINTERSTORE', dest, numkeys, *args) - @command(name="zmscore", fixed=(Key(ZSet), bytes), repeat=(bytes,)) + @command(name="ZMSCORE", fixed=(Key(ZSet), bytes), repeat=(bytes,)) def zmscore(self, key: CommandItem, *members: Union[str, bytes]) -> list[Optional[float]]: """Get the scores associated with the specified members in the sorted set stored at key. diff --git a/fakeredis/geo/__init__.py b/fakeredis/geo/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fakeredis/geo/geohash.py b/fakeredis/geo/geohash.py new file mode 100644 index 00000000..e8f14b3f --- /dev/null +++ b/fakeredis/geo/geohash.py @@ -0,0 +1,72 @@ +# Note: the alphabet in geohash differs from the common base32 +# alphabet described in IETF's RFC 4648 +# (http://tools.ietf.org/html/rfc4648) +from typing import Tuple + +base32 = '0123456789bcdefghjkmnpqrstuvwxyz' +decodemap = {base32[i]: i for i in range(len(base32))} + + +def decode(geohash: str) -> Tuple[float, float, float, float]: + """ + Decode the geohash to its exact values, including the error + margins of the result. Returns four float values: latitude, + longitude, the plus/minus error for latitude (as a positive + number) and the plus/minus error for longitude (as a positive + number). + """ + lat_interval, lon_interval = (-90.0, 90.0), (-180.0, 180.0) + lat_err, lon_err = 90.0, 180.0 + is_longitude = True + for c in geohash: + cd = decodemap[c] + for mask in [16, 8, 4, 2, 1]: + if is_longitude: # adds longitude info + lon_err /= 2 + if cd & mask: + lon_interval = ((lon_interval[0] + lon_interval[1]) / 2, lon_interval[1]) + else: + lon_interval = (lon_interval[0], (lon_interval[0] + lon_interval[1]) / 2) + else: # adds latitude info + lat_err /= 2 + if cd & mask: + lat_interval = ((lat_interval[0] + lat_interval[1]) / 2, lat_interval[1]) + else: + lat_interval = (lat_interval[0], (lat_interval[0] + lat_interval[1]) / 2) + is_longitude = not is_longitude + lat = (lat_interval[0] + lat_interval[1]) / 2 + lon = (lon_interval[0] + lon_interval[1]) / 2 + return lat, lon, lat_err, lon_err + + +def encode(latitude: float, longitude: float, precision=12) -> str: + """ + Encode a position given in float arguments latitude, longitude to + a geohash which will have the character count precision. + """ + lat_interval, lon_interval = (-90.0, 90.0), (-180.0, 180.0) + geohash, bits = [], [16, 8, 4, 2, 1] + bit, ch = 0, 0 + is_longitude = True + + def next_interval(curr: float, interval: Tuple[float, float], ch: int) -> Tuple[Tuple[float, float], int]: + mid = (interval[0] + interval[1]) / 2 + if curr > mid: + ch |= bits[bit] + return (mid, interval[1]), ch + else: + return (interval[0], mid), ch + + while len(geohash) < precision: + if is_longitude: + lon_interval, ch = next_interval(longitude, lon_interval, ch) + else: + lat_interval, ch = next_interval(latitude, lat_interval, ch) + is_longitude = not is_longitude + if bit < 4: + bit += 1 + else: + geohash += base32[ch] + bit = 0 + ch = 0 + return ''.join(geohash) diff --git a/fakeredis/geo/haversine.py b/fakeredis/geo/haversine.py new file mode 100644 index 00000000..99a7216f --- /dev/null +++ b/fakeredis/geo/haversine.py @@ -0,0 +1,34 @@ +import math +from typing import Tuple + + +# class GeoMember: +# def __init__(self, name: bytes, lat: float, long: float): +# self.name = name +# self.long = long +# self.lat = lat +# +# @staticmethod +# def from_bytes_tuple(t: Tuple[bytes, bytes, bytes]) -> 'GeoMember': +# long = Float.decode(t[0]) +# lat = Float.decode(t[1]) +# name = t[2] +# return GeoMember(name, lat, long) +# +# def geohash(self): +# return geohash.encode(self.lat, self.long) + + +def distance(origin: Tuple[float, float], destination: Tuple[float, float]) -> float: + """Calculate the Haversine distance in meters.""" + radius = 6372797.560856 # Earth's quatratic mean radius for WGS-84 + + lat1, lon1, lat2, lon2 = map( + math.radians, [origin[0], origin[1], destination[0], destination[1]]) + + dlon = lon2 - lon1 + dlat = lat2 - lat1 + a = math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2 + c = 2 * math.asin(math.sqrt(a)) + + return c * radius diff --git a/test/test_mixins/test_geo_commands.py b/test/test_mixins/test_geo_commands.py new file mode 100644 index 00000000..2751409d --- /dev/null +++ b/test/test_mixins/test_geo_commands.py @@ -0,0 +1,91 @@ +import pytest +import redis + + +def test_geoadd(r: redis.Redis): + values = ((2.1909389952632, 41.433791470673, "place1") + + (2.1873744593677, 41.406342043777, "place2",)) + assert r.geoadd("barcelona", values) == 2 + assert r.zcard("barcelona") == 2 + + values = (2.1909389952632, 41.433791470673, "place1") + assert r.geoadd("a", values) == 1 + values = ((2.1909389952632, 31.433791470673, "place1") + + (2.1873744593677, 41.406342043777, "place2",)) + assert r.geoadd("a", values, ch=True) == 2 + assert r.zrange("a", 0, -1) == [b"place1", b"place2"] + + with pytest.raises(redis.RedisError): + r.geoadd("barcelona", (1, 2)) + + +def test_geoadd_xx(r: redis.Redis): + values = ((2.1909389952632, 41.433791470673, "place1") + + (2.1873744593677, 41.406342043777, "place2",)) + assert r.geoadd("a", values) == 2 + values = ( + (2.1909389952632, 41.433791470673, "place1") + + (2.1873744593677, 41.406342043777, "place2") + + (2.1804738294738, 41.405647879212, "place3") + ) + assert r.geoadd("a", values, nx=True) == 1 + assert r.zrange("a", 0, -1) == [b"place3", b"place2", b"place1"] + + +def test_geoadd_ch(r: redis.Redis): + values = (2.1909389952632, 41.433791470673, "place1") + assert r.geoadd("a", values) == 1 + values = (2.1909389952632, 31.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) + assert r.geoadd("a", values, ch=True) == 2 + assert r.zrange("a", 0, -1) == [b"place1", b"place2"] + + +def test_geohash(r: redis.Redis): + values = ((2.1909389952632, 41.433791470673, "place1") + + (2.1873744593677, 41.406342043777, "place2",)) + r.geoadd("barcelona", values) + assert r.geohash("barcelona", "place1", "place2", "place3") == [ + "sp3e9yg3kd0", + "sp3e9cbc3t0", + None, + ] + + +def test_geopos(r: redis.Redis): + values = ((2.1909389952632, 41.433791470673, "place1") + + (2.1873744593677, 41.406342043777, "place2",)) + r.geoadd("barcelona", values) + # small errors may be introduced. + assert r.geopos("barcelona", "place1", "place4", "place2") == [ + pytest.approx((2.1909389952632, 41.433791470673), 0.00001), + None, + pytest.approx((2.1873744593677, 41.406342043777), 0.00001), + ] + + +def test_geodist(r: redis.Redis): + values = ((2.1909389952632, 41.433791470673, "place1") + + (2.1873744593677, 41.406342043777, "place2",)) + assert r.geoadd("barcelona", values) == 2 + assert r.geodist("barcelona", "place1", "place2") == pytest.approx(3067.4157, 0.0001) + + +def test_geodist_units(r: redis.Redis): + values = ((2.1909389952632, 41.433791470673, "place1") + + (2.1873744593677, 41.406342043777, "place2",)) + r.geoadd("barcelona", values) + assert r.geodist("barcelona", "place1", "place2", "km") == pytest.approx(3.0674, 0.0001) + assert r.geodist("barcelona", "place1", "place2", "mi") == pytest.approx(1.906, 0.0001) + assert r.geodist("barcelona", "place1", "place2", "ft") == pytest.approx(10063.6998, 0.0001) + with pytest.raises(redis.RedisError): + assert r.geodist("x", "y", "z", "inches") + + +def test_geodist_missing_one_member(r: redis.Redis): + values = (2.1909389952632, 41.433791470673, "place1") + r.geoadd("barcelona", values) + assert r.geodist("barcelona", "place1", "missing_member", "km") is None