Skip to content

Commit

Permalink
Copy over get_all_subclasses (#39)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored Sep 6, 2022
1 parent 32a04eb commit 9b0578f
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 0 deletions.
14 changes: 14 additions & 0 deletions src/lightning_utilities/core/inheritance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import Iterator, Set, Type


def get_all_subclasses_iterator(cls: Type) -> Iterator[Type]:
def recurse(cl: Type) -> Iterator[Type]:
for subclass in cl.__subclasses__():
yield subclass
yield from recurse(subclass)

yield from recurse(cls)


def get_all_subclasses(cls: Type) -> Set[Type]:
return set(get_all_subclasses_iterator(cls))
24 changes: 24 additions & 0 deletions tests/unittests/core/test_inheritance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from lightning_utilities.core.inheritance import get_all_subclasses


def test_get_all_subclasses():
class A1:
...

class A2(A1):
...

class B1:
...

class B2(B1):
...

class C(A2, B2):
...

assert get_all_subclasses(A1) == {A2, C}
assert get_all_subclasses(A2) == {C}
assert get_all_subclasses(B1) == {B2, C}
assert get_all_subclasses(B2) == {C}
assert get_all_subclasses(C) == set()

0 comments on commit 9b0578f

Please sign in to comment.