Skip to content

Commit

Permalink
chore: simplify compare_policy_report.py to be compatible with python36
Browse files Browse the repository at this point in the history
Signed-off-by: Nathan Nguyen <[email protected]>
  • Loading branch information
nathanwn committed Oct 18, 2023
1 parent a6001b0 commit 5389861
Showing 1 changed file with 38 additions and 39 deletions.
77 changes: 38 additions & 39 deletions tests/policy_engine/compare_policy_reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
# Copyright (c) 2023 - 2023, Oracle and/or its affiliates. All rights reserved.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/.

"""This module checks the policy engine report against expected results."""
"""This script checks the policy engine report against expected results.
Note: Make sure that this script is compatible with Python 3.6 (system python version on OL8).
"""

import json
import logging
import sys
Expand All @@ -15,7 +19,7 @@
logger.setLevel(logging.DEBUG)


def check_policies(results: list, expectations: list) -> bool:
def check_policies(results: list, expectations: list) -> str | None:
"""
Compare result policies against expected policies.
Expand All @@ -28,12 +32,9 @@ def check_policies(results: list, expectations: list) -> bool:
Returns
-------
bool
Returns True if successful.
Raises
------
ValueError
str | None
Returns an error message if ``results`` do not match ``expectations``.
Returns ``None`` otherwise.
"""
# If not empty, policy is always a list with one item.
# For example, Datalog declaration for failed_policies is `failed_policies(policy_id: symbol)`.
Expand All @@ -42,22 +43,25 @@ def check_policies(results: list, expectations: list) -> bool:
# Iterate through the rows returned by the policy engine.
for index, exp in enumerate(expectations):
res = results[index]
if (fails := abs(len(res) - len(exp))) > 0:

fails = abs(len(res) - len(exp))
if fails > 0:
fail_count += fails
continue

# Do not check the first element, which is the primary key.
c_fail = Counter(exp[1:])
c_fail.subtract(Counter(res[1:]))
if (fails := len([value for value in c_fail.values() if value != 0])) > 0:

fails = len([value for value in c_fail.values() if value != 0])
if fails > 0:
fail_count += fails
else:
fail_count += abs(len(results) - len(expectations))

if fail_count > 0:
raise ValueError(
f"Results do not match in {fail_count} item(s): Result is {results} but expected {expectations}"
)
return True
return f"Results do not match in {fail_count} item(s): Result is {results} but expected {expectations}"
return None


def main() -> int:
Expand All @@ -67,34 +71,29 @@ def main() -> int:
with open(sys.argv[1], encoding="utf8") as res_file, open(sys.argv[2], encoding="utf8") as exp_file:
result = json.load(res_file)
expected = json.load(exp_file)

try:
check_policies(result["failed_policies"], expected["failed_policies"])
except ValueError as error:
return_code = 1
logger.error("failed_policies: %s", error)

try:
check_policies(result["passed_policies"], expected["passed_policies"])
except ValueError as error:
return_code = 1
logger.error("passed_policies: %s", error)

try:
check_policies(result["component_violates_policy"], expected["component_violates_policy"])
except ValueError as error:
return_code = 1
logger.error("component_violates_policy: %s", error)

try:
check_policies(result["component_satisfies_policy"], expected["component_satisfies_policy"])
except ValueError as error:
return_code = 1
logger.error("component_satisfies_policy: %s", error)

except FileNotFoundError as error:
logger.error(error)
return 1

error_msg = check_policies(result["failed_policies"], expected["failed_policies"])
if error_msg:
return_code = 1
logger.error("failed_policies: %s", error_msg)

error_msg = check_policies(result["passed_policies"], expected["passed_policies"])
if error_msg:
return_code = 1
logger.error("passed_policies: %s", error_msg)

error_msg = check_policies(result["component_violates_policy"], expected["component_violates_policy"])
if error_msg:
return_code = 1
logger.error("component_violates_policy: %s", error_msg)

error_msg = check_policies(result["component_satisfies_policy"], expected["component_satisfies_policy"])
if error_msg:
return_code = 1
logger.error("component_satisfies_policy: %s", error_msg)

return return_code

Expand Down

0 comments on commit 5389861

Please sign in to comment.