Skip to content

Commit

Permalink
feat: enhance FilteredFileAdapter to handle flexible filtering for po…
Browse files Browse the repository at this point in the history
…licies and roles (#360)

* feat: optimize filtered file adapter policy loading

* style: standardize whitespace and formatting in filtered_file_adapter.py

* feat: add test

* test: improve test_load_filtered_policy_with_comments in test_filter.py

* test: update test description for mixed filter
  • Loading branch information
HashCookie authored Nov 25, 2024
1 parent 7b64b85 commit 936d5f6
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 12 deletions.
28 changes: 17 additions & 11 deletions casbin/persist/adapters/filtered_file_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,25 +52,28 @@ def load_filtered_policy(self, model, filter):

try:
filter_value = [filter.__dict__["P"]] + [filter.__dict__["G"]]
is_empty_filter = all(not f for f in filter_value) or all(
all(not x.strip() for x in f) if f else True for f in filter_value
)
if is_empty_filter:
return self.load_policy(model)
except:
raise RuntimeError("invalid filter type")

self.load_filtered_policy_file(model, filter_value, persist.load_policy_line)
self.filtered = True

def load_filtered_policy_file(self, model, filter, hanlder):
def load_filtered_policy_file(self, model, filter, handler):
with open(self._file_path, "rb") as file:
while True:
line = file.readline()
for line in file:
line = line.decode().strip()
if line == "\n":
if not line or line == "\n":
continue
if not line:
break

if filter_line(line, filter):
continue

hanlder(line, model)
handler(line, model)

# is_filtered returns true if the loaded policy has been filtered.
def is_filtered(self):
Expand All @@ -92,10 +95,13 @@ def filter_line(line, filter):
return True
filter_slice = []

if p[0].strip() == "p":
filter_slice = filter[0]
elif p[0].strip() == "g":
if p[0].strip() == "g":
if not filter[1] or all(not x.strip() for x in filter[1]):
return False
filter_slice = filter[1]
elif p[0].strip() == "p":
filter_slice = filter[0]

return filter_words(p, filter_slice)


Expand All @@ -104,7 +110,7 @@ def filter_words(line, filter):
return True
skip_line = False
for i, v in enumerate(filter):
if len(v) > 0 and (v.strip() != line[i + 1].strip()):
if v and v.strip() and (v.strip() != line[i + 1].strip()):
skip_line = True
break

Expand Down
176 changes: 175 additions & 1 deletion tests/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import casbin
import os
from unittest import TestCase
import casbin
from tests.test_enforcer import get_examples
from casbin.persist.adapters import FilteredFileAdapter
from casbin.persist.adapters.filtered_file_adapter import filter_line, filter_words


class Filter:
Expand Down Expand Up @@ -141,3 +143,175 @@ def test_filtered_adapter_invalid_filepath(self):

with self.assertRaises(RuntimeError):
e.load_filtered_policy(None)

def test_empty_filter_array(self):
"""Test filter for empty array."""
adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv"))
e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter)
filter = Filter()
filter.P = []
filter.G = []

e.load_filtered_policy(filter)
self.assertTrue(e.has_policy(["admin", "domain1", "data1", "read"]))
self.assertTrue(e.has_policy(["admin", "domain2", "data2", "read"]))

def test_empty_string_filter(self):
"""Test the filter for all empty strings."""
adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv"))
e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter)
filter = Filter()
filter.P = ["", "", ""]
filter.G = ["", "", ""]

e.load_filtered_policy(filter)
self.assertTrue(e.has_policy(["admin", "domain1", "data1", "read"]))
self.assertTrue(e.has_policy(["admin", "domain2", "data2", "read"]))

def test_mixed_empty_filter(self):
"""Test the filter for mixed empty and non-empty strings."""
adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv"))
e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter)
filter = Filter()
filter.P = ["", "domain1", ""]
filter.G = ["", "", "domain1"]

e.load_filtered_policy(filter)
self.assertTrue(e.has_policy(["admin", "domain1", "data1", "read"]))
self.assertFalse(e.has_policy(["admin", "domain2", "data2", "read"]))

def test_nonexistent_domain_filter(self):
"""Testing the filter for a non-existent domain."""
adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv"))
e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter)
filter = Filter()
filter.P = ["", "domain3"]
filter.G = ["", "", "domain3"]

e.load_filtered_policy(filter)
self.assertFalse(e.has_policy(["admin", "domain3", "data3", "read"]))

def test_empty_filter_array(self):
"""Test filter for empty array."""
adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv"))
e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter)
filter = Filter()
filter.P = []
filter.G = []

try:
e.load_filtered_policy(filter)
except:
raise RuntimeError("unexpected error with empty filter arrays")

self.assertFalse(e.is_filtered(), "Adapter should not be marked as filtered with empty filters")

self.assertTrue(e.has_policy(["admin", "domain1", "data1", "read"]))
self.assertTrue(e.has_policy(["admin", "domain2", "data2", "read"]))

def test_empty_string_filter(self):
"""Test the filter for all empty strings."""
adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv"))
e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter)
filter = Filter()
filter.P = ["", "", ""]
filter.G = ["", "", ""]

try:
e.load_filtered_policy(filter)
except:
raise RuntimeError("unexpected error with empty string filters")

self.assertFalse(e.is_filtered(), "Adapter should not be marked as filtered with empty string filters")

try:
e.save_policy()
except:
raise RuntimeError("unexpected error in SavePolicy with empty string filters")

self.assertTrue(e.has_policy(["admin", "domain1", "data1", "read"]))
self.assertTrue(e.has_policy(["admin", "domain2", "data2", "read"]))

def test_mixed_empty_filter(self):
"""Test the filter for mixed empty and non-empty strings."""
adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv"))
e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter)
filter = Filter()
filter.P = ["", "domain1", ""]
filter.G = ["", "", "domain1"]

try:
e.load_filtered_policy(filter)
except:
raise RuntimeError("unexpected error with mixed empty filters")

self.assertTrue(e.is_filtered(), "Adapter should be marked as filtered")

with self.assertRaises(RuntimeError):
e.save_policy()

self.assertTrue(e.has_policy(["admin", "domain1", "data1", "read"]))
self.assertFalse(e.has_policy(["admin", "domain2", "data2", "read"]))

def test_whitespace_filter(self):
"""Test the filter for all blank characters."""
adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv"))
e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter)
filter = Filter()
filter.P = [" ", " ", "\t"]
filter.G = ["\n", " ", " "]

e.load_filtered_policy(filter)

self.assertFalse(e.is_filtered())
self.assertTrue(e.has_policy(["admin", "domain1", "data1", "read"]))
self.assertTrue(e.has_policy(["admin", "domain2", "data2", "read"]))

def test_filter_line_edge_cases(self):
"""Test the boundary cases of the filter_line function."""
adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv"))

self.assertFalse(filter_line("", [[""], [""]]))

self.assertFalse(filter_line("invalid_line", [[""], [""]]))

self.assertFalse(filter_line("p, admin, domain1, data1, read", None))

def test_filter_words_edge_cases(self):
"""Test the boundary cases of the filter_words function."""
self.assertTrue(filter_words(["p"], ["filter1", "filter2"]))

self.assertFalse(filter_words(["p", "admin", "domain1"], []))

line = ["admin", "domain1", "data*", "read"]
filter = ["", "", "data1", ""]
self.assertTrue(filter_words(line, filter))

def test_load_filtered_policy_with_comments(self):
"""Test loading filtering policies with comments."""
import tempfile
import shutil

with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
with open(get_examples("rbac_with_domains_policy.csv"), "r") as source:
shutil.copyfileobj(source, temp_file)

temp_file.write("\n# This is a comment\np, admin, domain1, data3, read")
temp_file.flush()

temp_path = temp_file.name

try:
adapter = FilteredFileAdapter(temp_path)
e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter)
filter = Filter()
filter.P = ["", "domain1"]
filter.G = ["", "", "domain1"]

e.load_filtered_policy(filter)
self.assertTrue(e.has_policy(["admin", "domain1", "data3", "read"]))
finally:
try:
os.unlink(temp_path)
except OSError:
pass

0 comments on commit 936d5f6

Please sign in to comment.