Skip to content

Commit

Permalink
fix: improve KeyMatch and add tests
Browse files Browse the repository at this point in the history
Signed-off-by: Zixuan Liu <[email protected]>
  • Loading branch information
nodece committed Nov 5, 2020
1 parent 6b52db1 commit f62a2b2
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 22 deletions.
25 changes: 7 additions & 18 deletions casbin/util/builtin_operators.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import re
import ipaddress


KEY_MATCH2_PATTERN = re.compile(r'(.*):[^\/]+(.*)')
KEY_MATCH3_PATTERN = re.compile(r'(.*){[^\/]+}(.*)')
KEY_MATCH2_PATTERN = re.compile(r'(.*?):[^\/]+(.*?)')
KEY_MATCH3_PATTERN = re.compile(r'(.*?){[^\/]+}(.*?)')


def key_match(key1, key2):
Expand Down Expand Up @@ -35,14 +34,9 @@ def key_match2(key1, key2):
"""

key2 = key2.replace("/*", "/.*")
key2 = KEY_MATCH2_PATTERN.sub(r'\g<1>[^\/]+\g<2>', key2, 0)

while True:
if "/:" not in key2:
break

key2 = "^" + KEY_MATCH2_PATTERN.sub(r'\g<1>[^\/]+\g<2>', key2, 0) + "$"

return regex_match(key1, key2)
return regex_match(key1, "^" + key2 + "$")


def key_match2_func(*args):
Expand All @@ -58,14 +52,9 @@ def key_match3(key1, key2):
"""

key2 = key2.replace("/*", "/.*")
key2 = KEY_MATCH3_PATTERN.sub(r'\g<1>[^\/]+\g<2>', key2, 0)

while True:
if "{" not in key2:
break

key2 = KEY_MATCH3_PATTERN.sub(r'\g<1>[^\/]+\g<2>', key2, 0)

return regex_match(key1, key2)
return regex_match(key1, "^" + key2 + "$")


def key_match3_func(*args):
Expand Down Expand Up @@ -100,7 +89,7 @@ def ip_match(ip1, ip2):
"""
ip1 = ipaddress.ip_address(ip1)
try:
network = ipaddress.ip_network(ip2, strict=True)
network = ipaddress.ip_network(ip2, strict=False)
return ip1 in network
except ValueError:
return ip1 == ip2
Expand Down
59 changes: 55 additions & 4 deletions tests/util/test_builtin_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
class TestBuiltinOperators(TestCase):

def test_key_match(self):
self.assertFalse(util.key_match_func("/foo", "/"))
self.assertTrue(util.key_match_func("/foo", "/foo"))
self.assertTrue(util.key_match_func("/foo", "/foo*"))
self.assertFalse(util.key_match_func("/foo", "/foo/*"))
Expand All @@ -15,15 +16,18 @@ def test_key_match(self):
self.assertTrue(util.key_match_func("/foobar", "/foo*"))
self.assertFalse(util.key_match_func("/foobar", "/foo/*"))

self.assertFalse(util.key_match2_func("/alice/all", "/:/all"))

def test_key_match2(self):
self.assertFalse(util.key_match2_func("/foo", "/"))
self.assertTrue(util.key_match2_func("/foo", "/foo"))
self.assertTrue(util.key_match2_func("/foo", "/foo*"))
self.assertFalse(util.key_match2_func("/foo", "/foo/*"))
self.assertTrue(util.key_match2_func("/foo/bar", "/foo")) # different with KeyMatch.
self.assertTrue(util.key_match2_func("/foo/bar", "/foo*"))
self.assertFalse(util.key_match2_func("/foo/bar", "/foo")) # different with KeyMatch.
self.assertFalse(util.key_match2_func("/foo/bar", "/foo*"))
self.assertTrue(util.key_match2_func("/foo/bar", "/foo/*"))
self.assertTrue(util.key_match2_func("/foobar", "/foo")) # different with KeyMatch.
self.assertTrue(util.key_match2_func("/foobar", "/foo*"))
self.assertFalse(util.key_match2_func("/foobar", "/foo")) # different with KeyMatch.
self.assertFalse(util.key_match2_func("/foobar", "/foo*"))
self.assertFalse(util.key_match2_func("/foobar", "/foo/*"))

self.assertFalse(util.key_match2_func("/", "/:resource"))
Expand All @@ -42,3 +46,50 @@ def test_key_match2(self):
self.assertTrue(util.key_match2_func("/alice/all", "/:id/all"))
self.assertFalse(util.key_match2_func("/alice", "/:id/all"))
self.assertFalse(util.key_match2_func("/alice/all", "/:id"))

self.assertFalse(util.key_match2_func("/alice/all", "/:/all"))

def test_key_match3(self):
self.assertTrue(util.key_match3_func("/foo", "/foo"))
self.assertTrue(util.key_match3_func("/foo", "/foo*"))
self.assertFalse(util.key_match3_func("/foo", "/foo/*"))
self.assertFalse(util.key_match3_func("/foo/bar", "/foo"))
self.assertFalse(util.key_match3_func("/foo/bar", "/foo*"))
self.assertTrue(util.key_match3_func("/foo/bar", "/foo/*"))
self.assertFalse(util.key_match3_func("/foobar", "/foo"))
self.assertFalse(util.key_match3_func("/foobar", "/foo*"))
self.assertFalse(util.key_match3_func("/foobar", "/foo/*"))

self.assertFalse(util.key_match3_func("/", "/{resource}"))
self.assertTrue(util.key_match3_func("/resource1", "/{resource}"))
self.assertFalse(util.key_match3_func("/myid", "/{id}/using/{resId}"))
self.assertTrue(util.key_match3_func("/myid/using/myresid", "/{id}/using/{resId}"))

self.assertFalse(util.key_match3_func("/proxy/myid", "/proxy/{id}/*"))
self.assertTrue(util.key_match3_func("/proxy/myid/", "/proxy/{id}/*"))
self.assertTrue(util.key_match3_func("/proxy/myid/res", "/proxy/{id}/*"))
self.assertTrue(util.key_match3_func("/proxy/myid/res/res2", "/proxy/{id}/*"))
self.assertTrue(util.key_match3_func("/proxy/myid/res/res2/res3", "/proxy/{id}/*"))
self.assertFalse(util.key_match3_func("/proxy/", "/proxy/{id}/*"))

self.assertFalse(util.key_match3_func("/myid/using/myresid", "/{id/using/{resId}"))

def test_regex_match(self):
self.assertTrue(util.regex_match_func("/topic/create", "/topic/create"))
self.assertTrue(util.regex_match_func("/topic/create/123", "/topic/create"))
self.assertFalse(util.regex_match_func("/topic/delete", "/topic/create"))
self.assertFalse(util.regex_match_func("/topic/edit", "/topic/edit/[0-9]+"))
self.assertTrue(util.regex_match_func("/topic/edit/123", "/topic/edit/[0-9]+"))
self.assertFalse(util.regex_match_func("/topic/edit/abc", "/topic/edit/[0-9]+"))
self.assertFalse(util.regex_match_func("/foo/delete/123", "/topic/delete/[0-9]+"))
self.assertTrue(util.regex_match_func("/topic/delete/0", "/topic/delete/[0-9]+"))
self.assertFalse(util.regex_match_func("/topic/edit/123s", "/topic/delete/[0-9]+"))

def test_ip_match(self):
self.assertTrue(util.ip_match_func("192.168.2.123", "192.168.2.0/24"))
self.assertFalse(util.ip_match_func("192.168.2.123", "192.168.3.0/24"))
self.assertTrue(util.ip_match_func("192.168.2.123", "192.168.2.0/16"))
self.assertTrue(util.ip_match_func("192.168.2.123", "192.168.2.123"))
self.assertTrue(util.ip_match_func("192.168.2.123", "192.168.2.123/32"))
self.assertTrue(util.ip_match_func("10.0.0.11", "10.0.0.0/8"))
self.assertFalse(util.ip_match_func("11.0.0.123", "10.0.0.0/8"))

0 comments on commit f62a2b2

Please sign in to comment.