-
-
Notifications
You must be signed in to change notification settings - Fork 51
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
276 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
from fakeredis import _msgs as msgs | ||
from fakeredis._command_args_parsing import extract_args | ||
from fakeredis._commands import command, Key, Float | ||
from fakeredis._helpers import SimpleError | ||
from fakeredis._zset import ZSet | ||
from fakeredis.geo import geohash | ||
from fakeredis.geo.haversine import distance | ||
|
||
|
||
class GeoCommandsMixin: | ||
@command(name='GEOADD', fixed=(Key(ZSet),), repeat=(bytes,)) | ||
def geoadd(self, key, *args): | ||
(xx, nx, ch), data = extract_args( | ||
args, ('nx', 'xx', 'ch'), | ||
error_on_unexpected=False, left_from_first_unexpected=True) | ||
if xx and nx: | ||
raise SimpleError(msgs.NX_XX_GT_LT_ERROR_MSG) | ||
if len(data) == 0 or len(data) % 3 != 0: | ||
raise SimpleError(msgs.SYNTAX_ERROR_MSG) | ||
zset = key.value | ||
old_len, changed_items = len(zset), 0 | ||
for i in range(0, len(data), 3): | ||
long, lat, name = Float.decode(data[i + 0]), Float.decode(data[i + 1]), data[i + 2] | ||
if (name in zset and not xx) or (name not in zset and not nx): | ||
if zset.add(name, geohash.encode(lat, long, 10)): | ||
changed_items += 1 | ||
if changed_items: | ||
key.updated() | ||
if ch: | ||
return changed_items | ||
return len(zset) - old_len | ||
|
||
@command(name='GEOHASH', fixed=(Key(ZSet), bytes), repeat=(bytes,)) | ||
def geohash(self, key, *members): | ||
hashes = map(key.value.get, members) | ||
geohash_list = [((x + '0').encode() if x is not None else x) for x in hashes] | ||
return geohash_list | ||
|
||
@command(name='GEOPOS', fixed=(Key(ZSet), bytes), repeat=(bytes,)) | ||
def geopos(self, key, *members): | ||
gospositions = map( | ||
lambda x: geohash.decode(x) if x is not None else x, | ||
map(key.value.get, members)) | ||
res = [([self._encodefloat(x[1], humanfriendly=False), | ||
self._encodefloat(x[0], humanfriendly=False)] | ||
if x is not None else None) | ||
for x in gospositions] | ||
return res | ||
|
||
@command(name='GEODIST', fixed=(Key(ZSet), bytes, bytes), repeat=(bytes,)) | ||
def geodist(self, key, m1, m2, *args): | ||
geohashes = [key.value.get(m1), key.value.get(m2)] | ||
if any(elem is None for elem in geohashes): | ||
return None | ||
geo_locs = [geohash.decode(x) for x in geohashes] | ||
res = distance((geo_locs[0][0], geo_locs[0][1]), | ||
(geo_locs[1][0], geo_locs[1][1])) | ||
unit = 1 | ||
if len(args) == 1: | ||
unit_str = args[0].decode().lower() | ||
if unit_str == 'km': | ||
unit = 0.001 | ||
elif unit_str == 'mi': | ||
unit = 0.000621371 | ||
elif unit_str == 'ft': | ||
unit = 3.28084 | ||
return res * unit |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
# Note: the alphabet in geohash differs from the common base32 | ||
# alphabet described in IETF's RFC 4648 | ||
# (http://tools.ietf.org/html/rfc4648) | ||
from typing import Tuple | ||
|
||
base32 = '0123456789bcdefghjkmnpqrstuvwxyz' | ||
decodemap = {base32[i]: i for i in range(len(base32))} | ||
|
||
|
||
def decode(geohash: str) -> Tuple[float, float, float, float]: | ||
""" | ||
Decode the geohash to its exact values, including the error | ||
margins of the result. Returns four float values: latitude, | ||
longitude, the plus/minus error for latitude (as a positive | ||
number) and the plus/minus error for longitude (as a positive | ||
number). | ||
""" | ||
lat_interval, lon_interval = (-90.0, 90.0), (-180.0, 180.0) | ||
lat_err, lon_err = 90.0, 180.0 | ||
is_longitude = True | ||
for c in geohash: | ||
cd = decodemap[c] | ||
for mask in [16, 8, 4, 2, 1]: | ||
if is_longitude: # adds longitude info | ||
lon_err /= 2 | ||
if cd & mask: | ||
lon_interval = ((lon_interval[0] + lon_interval[1]) / 2, lon_interval[1]) | ||
else: | ||
lon_interval = (lon_interval[0], (lon_interval[0] + lon_interval[1]) / 2) | ||
else: # adds latitude info | ||
lat_err /= 2 | ||
if cd & mask: | ||
lat_interval = ((lat_interval[0] + lat_interval[1]) / 2, lat_interval[1]) | ||
else: | ||
lat_interval = (lat_interval[0], (lat_interval[0] + lat_interval[1]) / 2) | ||
is_longitude = not is_longitude | ||
lat = (lat_interval[0] + lat_interval[1]) / 2 | ||
lon = (lon_interval[0] + lon_interval[1]) / 2 | ||
return lat, lon, lat_err, lon_err | ||
|
||
|
||
def encode(latitude: float, longitude: float, precision=12) -> str: | ||
""" | ||
Encode a position given in float arguments latitude, longitude to | ||
a geohash which will have the character count precision. | ||
""" | ||
lat_interval, lon_interval = (-90.0, 90.0), (-180.0, 180.0) | ||
geohash, bits = [], [16, 8, 4, 2, 1] | ||
bit, ch = 0, 0 | ||
is_longitude = True | ||
|
||
def next_interval(curr: float, interval: Tuple[float, float], ch: int) -> Tuple[Tuple[float, float], int]: | ||
mid = (interval[0] + interval[1]) / 2 | ||
if curr > mid: | ||
ch |= bits[bit] | ||
return (mid, interval[1]), ch | ||
else: | ||
return (interval[0], mid), ch | ||
|
||
while len(geohash) < precision: | ||
if is_longitude: | ||
lon_interval, ch = next_interval(longitude, lon_interval, ch) | ||
else: | ||
lat_interval, ch = next_interval(latitude, lat_interval, ch) | ||
is_longitude = not is_longitude | ||
if bit < 4: | ||
bit += 1 | ||
else: | ||
geohash += base32[ch] | ||
bit = 0 | ||
ch = 0 | ||
return ''.join(geohash) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import math | ||
from typing import Tuple | ||
|
||
|
||
# class GeoMember: | ||
# def __init__(self, name: bytes, lat: float, long: float): | ||
# self.name = name | ||
# self.long = long | ||
# self.lat = lat | ||
# | ||
# @staticmethod | ||
# def from_bytes_tuple(t: Tuple[bytes, bytes, bytes]) -> 'GeoMember': | ||
# long = Float.decode(t[0]) | ||
# lat = Float.decode(t[1]) | ||
# name = t[2] | ||
# return GeoMember(name, lat, long) | ||
# | ||
# def geohash(self): | ||
# return geohash.encode(self.lat, self.long) | ||
|
||
|
||
def distance(origin: Tuple[float, float], destination: Tuple[float, float]) -> float: | ||
"""Calculate the Haversine distance in meters.""" | ||
radius = 6372797.560856 # Earth's quatratic mean radius for WGS-84 | ||
|
||
lat1, lon1, lat2, lon2 = map( | ||
math.radians, [origin[0], origin[1], destination[0], destination[1]]) | ||
|
||
dlon = lon2 - lon1 | ||
dlat = lat2 - lat1 | ||
a = math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2 | ||
c = 2 * math.asin(math.sqrt(a)) | ||
|
||
return c * radius |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import pytest | ||
import redis | ||
|
||
|
||
def test_geoadd(r: redis.Redis): | ||
values = ((2.1909389952632, 41.433791470673, "place1") + | ||
(2.1873744593677, 41.406342043777, "place2",)) | ||
assert r.geoadd("barcelona", values) == 2 | ||
assert r.zcard("barcelona") == 2 | ||
|
||
values = (2.1909389952632, 41.433791470673, "place1") | ||
assert r.geoadd("a", values) == 1 | ||
values = ((2.1909389952632, 31.433791470673, "place1") + | ||
(2.1873744593677, 41.406342043777, "place2",)) | ||
assert r.geoadd("a", values, ch=True) == 2 | ||
assert r.zrange("a", 0, -1) == [b"place1", b"place2"] | ||
|
||
with pytest.raises(redis.RedisError): | ||
r.geoadd("barcelona", (1, 2)) | ||
|
||
|
||
def test_geoadd_xx(r: redis.Redis): | ||
values = ((2.1909389952632, 41.433791470673, "place1") + | ||
(2.1873744593677, 41.406342043777, "place2",)) | ||
assert r.geoadd("a", values) == 2 | ||
values = ( | ||
(2.1909389952632, 41.433791470673, "place1") | ||
+ (2.1873744593677, 41.406342043777, "place2") | ||
+ (2.1804738294738, 41.405647879212, "place3") | ||
) | ||
assert r.geoadd("a", values, nx=True) == 1 | ||
assert r.zrange("a", 0, -1) == [b"place3", b"place2", b"place1"] | ||
|
||
|
||
def test_geoadd_ch(r: redis.Redis): | ||
values = (2.1909389952632, 41.433791470673, "place1") | ||
assert r.geoadd("a", values) == 1 | ||
values = (2.1909389952632, 31.433791470673, "place1") + ( | ||
2.1873744593677, | ||
41.406342043777, | ||
"place2", | ||
) | ||
assert r.geoadd("a", values, ch=True) == 2 | ||
assert r.zrange("a", 0, -1) == [b"place1", b"place2"] | ||
|
||
|
||
def test_geohash(r: redis.Redis): | ||
values = ((2.1909389952632, 41.433791470673, "place1") + | ||
(2.1873744593677, 41.406342043777, "place2",)) | ||
r.geoadd("barcelona", values) | ||
assert r.geohash("barcelona", "place1", "place2", "place3") == [ | ||
"sp3e9yg3kd0", | ||
"sp3e9cbc3t0", | ||
None, | ||
] | ||
|
||
|
||
def test_geopos(r: redis.Redis): | ||
values = ((2.1909389952632, 41.433791470673, "place1") + | ||
(2.1873744593677, 41.406342043777, "place2",)) | ||
r.geoadd("barcelona", values) | ||
# small errors may be introduced. | ||
assert r.geopos("barcelona", "place1", "place4", "place2") == [ | ||
pytest.approx((2.1909389952632, 41.433791470673), 0.00001), | ||
None, | ||
pytest.approx((2.1873744593677, 41.406342043777), 0.00001), | ||
] | ||
|
||
|
||
def test_geodist(r: redis.Redis): | ||
values = ((2.1909389952632, 41.433791470673, "place1") + | ||
(2.1873744593677, 41.406342043777, "place2",)) | ||
assert r.geoadd("barcelona", values) == 2 | ||
assert r.geodist("barcelona", "place1", "place2") == pytest.approx(3067.4157, 0.0001) | ||
|
||
|
||
def test_geodist_units(r: redis.Redis): | ||
values = ((2.1909389952632, 41.433791470673, "place1") + | ||
(2.1873744593677, 41.406342043777, "place2",)) | ||
r.geoadd("barcelona", values) | ||
assert r.geodist("barcelona", "place1", "place2", "km") == pytest.approx(3.0674, 0.0001) | ||
assert r.geodist("barcelona", "place1", "place2", "mi") == pytest.approx(1.906, 0.0001) | ||
assert r.geodist("barcelona", "place1", "place2", "ft") == pytest.approx(10063.6998, 0.0001) | ||
with pytest.raises(redis.RedisError): | ||
assert r.geodist("x", "y", "z", "inches") | ||
|
||
|
||
def test_geodist_missing_one_member(r: redis.Redis): | ||
values = (2.1909389952632, 41.433791470673, "place1") | ||
r.geoadd("barcelona", values) | ||
assert r.geodist("barcelona", "place1", "missing_member", "km") is None |