Skip to content

Commit

Permalink
Fix #20, handle function calls without names
Browse files Browse the repository at this point in the history
  • Loading branch information
mschwager committed Dec 9, 2024
1 parent 549a934 commit ad43839
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 53 deletions.
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[flake8]
ignore = E501,W504,H601
ignore = E501,W503,H601
include = cohesion,tests
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@ This project adheres to [Semantic Versioning](http://semver.org/).
This project adheres to [CHANGELOG](http://keepachangelog.com/).

## [Unreleased]
### Added
- Python 3.12, 3.13, and 3.14 support

### Removed
- Python 3.7, and 3.8 support

### Fixed
- Function calls without names ([#20](https://github.com/mschwager/cohesion/issues/20))

## [1.1.0]
### Added
Expand Down
8 changes: 4 additions & 4 deletions cohesion/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,14 @@ def class_cohesion_percentage(self, class_name):
)

total_class_variable_count = (
len(self.structure[class_name]["variables"]) *
len(self.structure[class_name]["functions"])
len(self.structure[class_name]["variables"])
* len(self.structure[class_name]["functions"])
)

if total_class_variable_count != 0.0:
class_percentage = round((
total_function_variable_count /
total_class_variable_count
total_function_variable_count
/ total_class_variable_count
) * 100, 2)
else:
class_percentage = 0.0
Expand Down
29 changes: 17 additions & 12 deletions cohesion/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,12 @@ def get_object_name(obj):
ast.FunctionDef: "name",
ast.ClassDef: "name",
ast.Subscript: "value",
ast.arg: "arg",
}

# This is a new ast type in Python 3
if hasattr(ast, "arg"):
name_dispatch[ast.arg] = "arg"

while not isinstance(obj, str):
assert type(obj) in name_dispatch
if type(obj) not in name_dispatch:
return None
obj = getattr(obj, name_dispatch[type(obj)])

return obj
Expand All @@ -47,14 +45,18 @@ def is_class_method_bound(method, arg_name=BOUND_METHOD_ARGUMENT_NAME):

first_arg_name = get_object_name(first_arg)

return first_arg_name == arg_name
return first_arg_name is not None and first_arg_name == arg_name


def class_method_has_decorator(method, decorator):
"""
Return whether a class method has a specific decorator
"""
return decorator in [get_object_name(d) for d in method.decorator_list]
return decorator in [
object_name
for dec in method.decorator_list
if (object_name := get_object_name(dec)) is not None
]


def is_class_method_classmethod(method):
Expand Down Expand Up @@ -101,13 +103,14 @@ def get_instance_variables(node, bound_name_classifier=BOUND_METHOD_ARGUMENT_NAM
node_attributes = [
child
for child in ast.walk(node)
if isinstance(child, ast.Attribute) and
get_attribute_name_id(child) == bound_name_classifier
if isinstance(child, ast.Attribute)
and get_attribute_name_id(child) == bound_name_classifier
]
node_function_call_names = [
get_object_name(child)
object_name
for child in ast.walk(node)
if isinstance(child, ast.Call)
and (object_name := get_object_name(child)) is not None
]
node_instance_variables = [
attribute
Expand All @@ -123,8 +126,9 @@ def get_all_class_variable_names_used_in_method(method):
given method
"""
return {
get_object_name(variable)
object_name
for variable in get_instance_variables(method)
if (object_name := get_object_name(variable)) is not None
}


Expand All @@ -141,8 +145,9 @@ def get_all_class_variable_names(cls):
given class
"""
return {
get_object_name(variable)
object_name
for variable in get_all_class_variables(cls)
if (object_name := get_object_name(variable)) is not None
}


Expand Down
12 changes: 0 additions & 12 deletions tests/test_filesystem.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#!/usr/bin/env python

import collections
import os
import textwrap
import unittest
Expand All @@ -19,17 +18,6 @@ def tearDown(self):
# It is no longer necessary to add self.tearDownPyfakefs()
pass

def assertCountEqual(self, first, second):
"""
Test whether two sequences contain the same elements.
This exists in Python 3, but not Python 2.
"""
self.assertEqual(
collections.Counter(list(first)),
collections.Counter(list(second))
)

def test_get_file_contents(self):
filename = os.path.join("directory", "filename.py")

Expand Down
12 changes: 0 additions & 12 deletions tests/test_module.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#!/usr/bin/env python

import collections
import os
import textwrap
import unittest
Expand All @@ -15,17 +14,6 @@ class TestModule(unittest.TestCase):
def assertEmpty(self, iterable):
self.assertEqual(len(iterable), 0)

def assertCountEqual(self, first, second):
"""
Test whether two sequences contain the same elements.
This exists in Python 3, but not Python 2.
"""
self.assertEqual(
collections.Counter(list(first)),
collections.Counter(list(second))
)

def test_module_empty(self):
python_string = textwrap.dedent("")

Expand Down
41 changes: 29 additions & 12 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#!/usr/bin/env python

import ast
import collections
import textwrap
import unittest

Expand All @@ -13,17 +12,6 @@ class TestParser(unittest.TestCase):
def assertEmpty(self, iterable):
self.assertEqual(len(iterable), 0)

def assertCountEqual(self, first, second):
"""
Test whether two sequences contain the same elements.
This exists in Python 3, but not Python 2.
"""
self.assertEqual(
collections.Counter(list(first)),
collections.Counter(list(second))
)

def test_valid_syntax(self):
python_string = textwrap.dedent("""
a = 5
Expand Down Expand Up @@ -705,6 +693,35 @@ def func(self):

self.assertCountEqual(result, expected)

def test_get_all_class_variable_names_missing_function_name(self):
python_string = textwrap.dedent("""
class C():
def a(self):
print('a')
def b(self):
print('b')
def f(self, choice):
(self.a if choice else self.b)()
c = C()
c.f(True)
c.f(False)
""")

node = parser.get_ast_node_from_string(python_string)
classes = parser.get_module_classes(node)
result = [
name
for cls in classes
for name in parser.get_all_class_variable_names(cls)
]
expected = ['a', 'b']

self.assertCountEqual(result, expected)


if __name__ == "__main__":
unittest.main()

0 comments on commit ad43839

Please sign in to comment.