Skip to content

Commit

Permalink
Hotfix/sg 000 improve breaking change detection with imports (#1420)
Browse files Browse the repository at this point in the history
* fix

* update unitetst to work with SG

---------

Co-authored-by: Eugene Khvedchenya <[email protected]>
  • Loading branch information
Louis-Dupont and BloodAxe authored Aug 29, 2023
1 parent 4daedbd commit 56a7b99
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 27 deletions.
31 changes: 18 additions & 13 deletions tests/breaking_change_tests/breaking_changes_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,16 @@ def extract_code_breaking_changes(module_path: str, source_code: str, current_co
)
)

# IMPORTS - Check import ONLY if __init__ file and ignores non-SG imports.
current_imports = parse_imports(code=current_code)
if module_path.endswith("__init__.py"):
source_imports = parse_imports(code=source_code)
breaking_changes.imports_removed = [
ImportRemoved(import_name=source_import, line_num=0)
for source_import in source_imports
if (source_import not in current_imports) and ("super_gradients" in source_import)
]

# FUNCTION SIGNATURES
source_functions_signatures = parse_functions_signatures(source_code)
current_functions_signatures = parse_functions_signatures(current_code)
Expand Down Expand Up @@ -191,21 +201,16 @@ def extract_code_breaking_changes(module_path: str, source_code: str, current_co
)

else:
# FunctionRemoved
breaking_changes.functions_removed.append(
FunctionRemoved(
function_name=function_name,
line_num=source_function_signature.line_num,
# Count a function as removed only if it was removed AND it was not added in the imports!
imported_function_names = current_imports.values()
if function_name not in imported_function_names:
breaking_changes.functions_removed.append(
FunctionRemoved(
function_name=function_name,
line_num=source_function_signature.line_num,
)
)
)

# Check import ONLY if __init__ file.
if module_path.endswith("__init__.py"):
source_imports = parse_imports(code=source_code)
current_imports = parse_imports(code=current_code)
breaking_changes.imports_removed = [
ImportRemoved(import_name=source_import, line_num=0) for source_import in source_imports if source_import not in current_imports
]
return breaking_changes


Expand Down
28 changes: 14 additions & 14 deletions tests/breaking_change_tests/unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@

class TestBreakingChangeDetection(unittest.TestCase):
def test_module_removed(self):
old_code = "import package.missing_module"
old_code = "import super_gradients.missing_module"
new_code = ""
self.assertEqual(parse_imports(old_code), {"package.missing_module": "package.missing_module"})
self.assertEqual(parse_imports(old_code), {"super_gradients.missing_module": "super_gradients.missing_module"})
self.assertEqual(parse_imports(new_code), {})

# Imports not checked in regular modules
Expand All @@ -25,11 +25,11 @@ def test_module_removed(self):

# Attributes of the breaking change
breaking_change = breaking_changes.imports_removed[0]
self.assertEqual(breaking_change.import_name, "package.missing_module")
self.assertEqual(breaking_change.import_name, "super_gradients.missing_module")

def test_module_renamed(self):
old_code = "import old_name"
new_code = "import new_name"
old_code = "import super_gradients"
new_code = "import new_module"
self.assertNotEqual(parse_imports(old_code), parse_imports(new_code))

# Imports not checked in regular modules
Expand All @@ -42,11 +42,11 @@ def test_module_renamed(self):

# Check the attributes of the breaking change
breaking_change = breaking_changes.imports_removed[0]
self.assertEqual(breaking_change.import_name, "old_name")
self.assertEqual(breaking_change.import_name, "super_gradients")

def test_module_location_changed(self):
old_code = "from package import my_module"
new_code = "from package.subpackage import my_module"
old_code = "from super_gradients import my_module"
new_code = "from super_gradients.subpackage import my_module"
self.assertNotEqual(parse_imports(old_code), parse_imports(new_code))

# Imports not checked in regular modules
Expand All @@ -59,13 +59,13 @@ def test_module_location_changed(self):

# Check the attributes of the breaking change
breaking_change = breaking_changes.imports_removed[0]
self.assertEqual(breaking_change.import_name, "package.my_module")
self.assertEqual(breaking_change.import_name, "super_gradients.my_module")

def test_dependency_version_changed(self):
"""We want to be sensitive to source, not alias! (i.e. we want to distinguish between v1 and v2)"""

old_code = "import library_v1 as library"
new_code = "import library_v2 as library"
old_code = "import super_gradients.library_v1 as library"
new_code = "import super_gradients.library_v2 as library"
self.assertNotEqual(parse_imports(old_code), parse_imports(new_code))

# Imports not checked in regular modules
Expand All @@ -78,7 +78,7 @@ def test_dependency_version_changed(self):

# Check the attributes of the breaking change
breaking_change = breaking_changes.imports_removed[0]
self.assertEqual(breaking_change.import_name, "library_v1")
self.assertEqual(breaking_change.import_name, "super_gradients.library_v1")

def test_function_removed(self):
old_code = "def old_function(): pass"
Expand Down Expand Up @@ -149,8 +149,8 @@ def test_no_changes(self):
self.assertEqual(len(breaking_changes.imports_removed), 0)

def test_multiple_changes(self):
old_code = "import module1\nfrom module2 import function1\ndef my_function(param1, param2): pass"
new_code = "import module3\ndef my_function(param1, param2, param3): pass"
old_code = "import super_gradients.module1\nfrom super_gradients.module2 import function1\ndef my_function(param1, param2): pass"
new_code = "import super_gradients.module3\ndef my_function(param1, param2, param3): pass"

# Imports not checked in regular modules
breaking_changes = extract_code_breaking_changes("module.py", old_code, new_code)
Expand Down

0 comments on commit 56a7b99

Please sign in to comment.