Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix failing chains pytests #867

Merged
merged 2 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Change Log

## [0.5.2] (Unreleased)

- Refactor `TableChain` to include `_searched` attribute. #867

## [0.5.1] (March 7, 2024)

### Infrastructure
Expand Down
74 changes: 40 additions & 34 deletions src/spyglass/utils/dj_chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,13 @@ class TableChain:
_link_symbol : str
Symbol used to represent the link between parent and child. Hardcoded
to " -> ".
_has_link : bool
has_link : bool
Cached attribute to store whether parent is linked to child. False if
child is not in parent.descendants or nx.NetworkXNoPath is raised by
nx.shortest_path.
_has_directed_link : bool
True if directed graph is used to find path. False if undirected graph.
link_type : str
'directed' or 'undirected' based on whether path is found with directed
or undirected graph. None if no path is found.
graph : nx.DiGraph
Directed graph of parent's dependencies from datajoint.connection.
names : List[str]
Expand Down Expand Up @@ -175,18 +176,19 @@ def __init__(self, parent: Table, child: Table, connection=None):
self._link_symbol = " -> "
self.parent = parent
self.child = child
self._has_link = True
self._has_directed_link = None
self.link_type = None
self._searched = False

if child.full_table_name not in self.graph.nodes:
logger.warning(
"Can't find item in graph. Try importing: "
+ f"{child.full_table_name}"
)
self._searched = True

def __str__(self):
"""Return string representation of chain: parent -> child."""
if not self._has_link:
if not self.has_link:
return "No link"
return (
to_camel_case(self.parent.table_name)
Expand All @@ -196,19 +198,22 @@ def __str__(self):

def __repr__(self):
"""Return full representation of chain: parent -> {links} -> child."""
return (
"Chain: "
+ self._link_symbol.join([t.table_name for t in self.objects])
if self.names
else "No link"
if not self.has_link:
return "No link"
return "Chain: " + self._link_symbol.join(
[t.table_name for t in self.objects]
)

def __len__(self):
"""Return number of tables in chain."""
if not self.has_link:
return 0
return len(self.names)

def __getitem__(self, index: Union[int, str]) -> dj.FreeTable:
"""Return FreeTable object at index."""
if not self.has_link:
return None
if isinstance(index, str):
for i, name in enumerate(self.names):
if index in name:
Expand All @@ -219,10 +224,12 @@ def __getitem__(self, index: Union[int, str]) -> dj.FreeTable:
def has_link(self) -> bool:
"""Return True if parent is linked to child.

Cached as hidden attribute _has_link to set False if nx.NetworkXNoPath
is raised by nx.shortest_path.
If not searched, search for path. If searched and no link is found,
return False. If searched and link is found, return True.
"""
return self._has_link
if not self._searched:
_ = self.path
return self.link_type is not None

def pk_link(self, src, trg, data) -> float:
"""Return 1 if data["primary"] else float("inf").
Expand All @@ -242,7 +249,7 @@ def find_path(self, directed=True) -> OrderedDict:
If True, use directed graph. If False, use undirected graph.
Defaults to True. Undirected permits paths to traverse from merge
part-parent -> merge part -> merge table. Undirected excludes
PERIPHERAL_TABLES likne interval_list, nwbfile, etc.
PERIPHERAL_TABLES like interval_list, nwbfile, etc.

Returns
-------
Expand All @@ -265,6 +272,9 @@ def find_path(self, directed=True) -> OrderedDict:
path = nx.shortest_path(self.graph, source, target)
except nx.NetworkXNoPath:
return None
except nx.NodeNotFound:
self._searched = True
return None

ret = OrderedDict()
prev_table = None
Expand All @@ -283,48 +293,44 @@ def find_path(self, directed=True) -> OrderedDict:
@cached_property
def path(self) -> OrderedDict:
"""Return list of full table names in chain."""
if not self._has_link:
if self._searched and not self.has_link:
return None

link = None
if link := self.find_path(directed=True):
self._has_directed_link = True
self.link_type = "directed"
elif link := self.find_path(directed=False):
self._has_directed_link = False
self.link_type = "undirected"
self._searched = True

if link:
return link

self._has_link = False
return None
return link

@cached_property
def names(self) -> List[str]:
"""Return list of full table names in chain."""
if self._has_link:
return list(self.path.keys())
return None
if not self.has_link:
return None
return list(self.path.keys())

@cached_property
def objects(self) -> List[dj.FreeTable]:
"""Return list of FreeTable objects for each table in chain.

Unused. Preserved for future debugging.
"""
if self._has_link:
return [v["free_table"] for v in self.path.values()]
return None
if not self.has_link:
return None
return [v["free_table"] for v in self.path.values()]

@cached_property
def attr_maps(self) -> List[dict]:
"""Return list of attribute maps for each table in chain.

Unused. Preserved for future debugging.
"""
#
if self._has_link:
return [v["attr_map"] for v in self.path.values()]
return None
if not self.has_link:
return None
return [v["attr_map"] for v in self.path.values()]

def join(
self, restriction: str = None, reverse_order: bool = False
Expand All @@ -339,7 +345,7 @@ def join(
reverse_order : bool, optional
If True, join tables in reverse order. Defaults to False.
"""
if not self._has_link:
if not self.has_link:
return None

restriction = restriction or self.parent.restriction or True
Expand Down
19 changes: 11 additions & 8 deletions tests/utils/test_chains.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from datajoint.utils import to_camel_case


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -31,16 +32,13 @@ def test_invalid_chain(Nwbfile, pos_merge_tables, TableChain):
def test_chain_str(chain):
"""Test that the str of a TableChain object is as expected."""
chain = chain
str_got = str(chain)
str_exp = (
chain.parent.table_name + chain._link_symbol + chain.child.table_name
)
assert str_got == str_exp, "Unexpected str of TableChain object."
parent = to_camel_case(chain.parent.table_name)
child = to_camel_case(chain.child.table_name)

str_got = str(chain)
str_exp = parent + chain._link_symbol + child

def test_chain_str_no_link(no_link_chain):
"""Test that the str of a TableChain object with no link is as expected."""
assert str(no_link_chain) == "No link", "Unexpected str of no link chain."
assert str_got == str_exp, "Unexpected str of TableChain object."


def test_chain_repr(chain):
Expand All @@ -66,3 +64,8 @@ def test_chain_getitem(chain):

def test_nolink_join(no_link_chain):
assert no_link_chain.join() is None, "Unexpected join of no link chain."


def test_chain_str_no_link(no_link_chain):
"""Test that the str of a TableChain object with no link is as expected."""
assert str(no_link_chain) == "No link", "Unexpected str of no link chain."
Loading